Skip to content

Commit

Permalink
fix: modified negative apporximations
Browse files Browse the repository at this point in the history
  • Loading branch information
charliekush committed Nov 4, 2024
1 parent 5359adc commit 9049992
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 51 deletions.
23 changes: 13 additions & 10 deletions src/approx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
#include <stack>


Approx::Approx(std::string raw_input, std::string diffVar, double value)
Approx::Approx(std::string raw_input, std::string diffVar)
{
this->value = value;

this->diffVar = std::make_shared<Variable>(diffVar);
Tokenizer parser(raw_input);
auto parsed = parser.tokenize();
Expand All @@ -25,14 +25,14 @@ Approx::Approx(std::string raw_input, std::string diffVar, double value)

this->derivative = Derivative(raw_input,"x").solve();
}
std::pair<double,double> Approx::approximate()
std::pair<double,double> Approx::approximate(double value)
{


double originalApprox = approximate(this->root, this->diffVar, this->value);
double originalApprox = approximate(this->root, this->diffVar, value);

double derivativeApprox = approximate(this->derivative, this->diffVar,
this->value);
double derivativeApprox = approximate(this->derivative,
this->diffVar, value);
return std::make_pair(originalApprox, derivativeApprox);
}
std::shared_ptr<Token> Approx::replaceToken(nodePtr node,
Expand Down Expand Up @@ -61,19 +61,20 @@ std::shared_ptr<Token> Approx::replaceToken(nodePtr node,
else if (node->getType() == TokenType::FUNCTION)
{
auto func = std::dynamic_pointer_cast<Function>(node->getToken());
auto subExpr = func->getSubExprTree();
auto subExpr = func->getSubExprTree()->copyTree();
auto funcIter = Lookup::functionLookup.find(node->getStr());

auto subApprox = approximate(subExpr, wrt, value);
double subApprox = approximate(subExpr, wrt, value);
if (subApprox == DBL_MAX)
{
func->getSubExprTree()->printTree();
throw std::runtime_error("Function subexpr not a number");
}
auto newSubRoot = std::make_shared<Number>(std::to_string(subApprox),
subApprox);
func->setSubExprTree(std::make_shared<ExpressionNode>(newSubRoot));
outToken = func;
auto newFunc = std::make_shared<Function>(func->getStr());
newFunc->setSubExprTree(std::make_shared<ExpressionNode>(newSubRoot));
outToken = newFunc;


}
Expand All @@ -83,6 +84,7 @@ double Approx::approximate(nodePtr root, std::shared_ptr<Variable> wrt,
double value)
{
auto rootCopy = root->copyTree();
std::cout << "Approximating: " << TextConverter::convertToText(rootCopy) << "\n";
std::stack<nodePtr> stack;
nodePtr current = rootCopy;
while (current != nullptr || !stack.empty()) {
Expand All @@ -103,6 +105,7 @@ double Approx::approximate(nodePtr root, std::shared_ptr<Variable> wrt,

bool tempBool = Arithmetic::floatSimplification;
Arithmetic::floatSimplification = true;
std::cout << "To simplify: " << TextConverter::convertToText(rootCopy) << "\n";
try
{
TreeFixer::simplify(rootCopy);
Expand Down
8 changes: 3 additions & 5 deletions src/approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@ class Approx
typedef std::shared_ptr<Number> numPtr;
nodePtr root;
nodePtr derivative;
double value;
std::shared_ptr<Variable> diffVar;
static std::shared_ptr<Token> replaceToken(nodePtr node,
std::shared_ptr<Variable> wrt,
double value);
public:
Approx(std::string raw_input, std::string diffVar, double value);
Approx(std::string raw_input, std::string diffVar);

std::pair<double,double> approximate();
std::pair<double,double> approximate(double value);
static double approximate(nodePtr node,
std::shared_ptr<Variable> wrt,
double value);
std::shared_ptr<Variable> wrt, double value);

};

Expand Down
77 changes: 52 additions & 25 deletions src/arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

#include <cmath>
#include <iostream>
#include "tree_fixer.hpp"


bool Arithmetic::floatSimplification = true;

std::shared_ptr<Number> Arithmetic::performOperation(const operation& op,
numPtr left, numPtr right, bool isDivision = false)
{

// Handle division separately
if (isDivision)
{
Expand All @@ -20,7 +22,7 @@ std::shared_ptr<Number> Arithmetic::performOperation(const operation& op,
// If evenly divisible
if (left->getInt() % right->getInt() == 0)
{
int result = left->getInt() / right->getInt();
int result = (left->getInt()) / (right->getInt());
// return int
auto out = std::make_shared<Number>(std::to_string(result),
result);
Expand All @@ -29,14 +31,14 @@ std::shared_ptr<Number> Arithmetic::performOperation(const operation& op,
}
else
{

double result = static_cast<double>(left->getInt()) /
right->getInt();
( right->getInt());
// Return a float
auto out = std::make_shared<Number>(std::to_string(result),
result);
out->setNegative(result < 0);
if (std::fmod(result, 1) != 0 &&
if (std::fmod(result, 1) != 0 &&
!Arithmetic::floatSimplification)
{
return nullptr;
Expand All @@ -50,11 +52,12 @@ std::shared_ptr<Number> Arithmetic::performOperation(const operation& op,
// Handle cases where both operands are integers for non-division operations
if (left->isInt() && right->isInt())
{
int result = static_cast<int>(op(left->getInt(), right->getInt()));
int result = static_cast<int>(op((left->getInt()),
right->getInt()));
// Return an int
auto out = std::make_shared<Number>(std::to_string(result), result);
out->setNegative(result < 0);

return out;


Expand All @@ -64,23 +67,24 @@ std::shared_ptr<Number> Arithmetic::performOperation(const operation& op,
double result;
if (left->isInt() && right->isDouble())
{
result = op(left->getInt(), right->getDouble());
result = op( left->getInt(), right->getDouble());
}
else if (left->isDouble() && right->isInt())
{
result = op(left->getDouble(), right->getInt());
result = op( left->getDouble(), right->getInt());
}
else if (left->isDouble() && right->isDouble())
{
result = op(left->getDouble(), right->getDouble());
result = op( left->getDouble(),
right->getDouble());
}
else
{
return nullptr;
}

double floatPart = std::modf(result, &floatPart);

if (floatPart == 0.0)
{
auto out = std::make_shared<Number>(std::to_string((int)result),
Expand Down Expand Up @@ -168,7 +172,7 @@ void Arithmetic::setNodeToOne(nodePtr& operatorNode) {
std::make_shared<Number>("0", 0)));
}

void Arithmetic::simplify(nodePtr node, numPtr left, numPtr right)
void Arithmetic::simplify(nodePtr node)
{
if (node->getStr() == "^")
{
Expand Down Expand Up @@ -201,13 +205,13 @@ void Arithmetic::simplifyExponent(nodePtr& operatorNode)
auto value = Arithmetic::power(operatorNode, leftNum, rightNum);
if (value)
{
//std::cout << leftNum->getStr() << "^" << rightNum->getStr() << " = " << value->getStr() << "\n";
std::cout << leftNum->getFullStr() << "^" << rightNum->getFullStr() << " = " << value->getFullStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
operatorNode->setDerivative(std::make_shared<ExpressionNode>(
std::make_shared<Number>("0", 0)));
std::make_shared<Number>("0", 0)));

return;
}
}
Expand Down Expand Up @@ -247,18 +251,18 @@ void Arithmetic::simplifyMultiplication(nodePtr& operatorNode)
{
auto leftNum = getNumberToken(operatorNode->getLeft());
auto rightNum = getNumberToken(operatorNode->getRight());

if (leftNum && rightNum)
{
auto value = Arithmetic::multiply(operatorNode, leftNum, rightNum);
if (value)
{
//std::cout << leftNum->getStr() << "*" << rightNum->getStr() << " = " << value->getStr() << "\n";
std::cout << leftNum->getFullStr() << "*" << rightNum->getFullStr() << " = " << value->getFullStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
operatorNode->setDerivative(std::make_shared<ExpressionNode>(
std::make_shared<Number>("0", 0)));
std::make_shared<Number>("0", 0)));
return;
}
}
Expand All @@ -270,21 +274,44 @@ void Arithmetic::simplifyMultiplication(nodePtr& operatorNode)
}
else if (leftNum->equals(1))
{

TreeModifier::replaceWithRightChild(operatorNode);
}
}
else if (rightNum)
{
if (rightNum->equals(0))
{
{
setNodeToZero(operatorNode);
}
else if (rightNum->equals(1))
{
TreeModifier::replaceWithLeftChild(operatorNode);
}
}

if (TreeFixer::treesEqual(operatorNode->getLeft(), operatorNode->getRight()))
{
operatorNode->setToken(std::make_shared<Operator>("^"));
operatorNode->setLeft(std::make_shared<ExpressionNode>(
std::make_shared<Number>("2", 2)));
}
if (operatorNode->getLeft()->getStr() == "^" &&
TreeFixer::treesEqual(operatorNode->getLeft()->getLeft(),
operatorNode->getRight()))
{
operatorNode->swapChildren();
operatorNode->setRight(operatorNode->getRight()->getLeft());
Arithmetic::simplify(operatorNode->getRight());
}
if (operatorNode->getRight()->getStr() == "^" &&
TreeFixer::treesEqual(operatorNode->getRight()->getLeft(),
operatorNode->getLeft()))
{

operatorNode->setRight(operatorNode->getRight()->getRight());
Arithmetic::simplify(operatorNode->getRight());
}
}


Expand All @@ -297,12 +324,12 @@ void Arithmetic::simplifyDivision(nodePtr& operatorNode)
auto value = Arithmetic::divide(operatorNode, leftNum, rightNum);
if (value)
{
//std::cout << leftNum->getStr() << "/" << rightNum->getStr() << " = " << value->getStr() << "\n";
std::cout << leftNum->getFullStr() << "/" << rightNum->getFullStr() << " = " << value->getFullStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
operatorNode->setDerivative(std::make_shared<ExpressionNode>(
std::make_shared<Number>("0", 0)));
std::make_shared<Number>("0", 0)));
return;
}
}
Expand Down Expand Up @@ -336,12 +363,12 @@ void Arithmetic::simplifyAddition(nodePtr& operatorNode)
auto value = Arithmetic::add(operatorNode, leftNum, rightNum);
if (value)
{
//std::cout << leftNum->getStr() << "+" << rightNum->getStr() << " = " << value->getStr() << "\n";
std::cout << leftNum->getFullStr() << "+" << rightNum->getFullStr() << " = " << value->getFullStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
operatorNode->setDerivative(std::make_shared<ExpressionNode>(
std::make_shared<Number>("0", 0)));
std::make_shared<Number>("0", 0)));
return;
}
}
Expand Down Expand Up @@ -371,12 +398,12 @@ void Arithmetic::simplifySubtraction(nodePtr& operatorNode)
auto value = Arithmetic::subtract(operatorNode, leftNum, rightNum);
if (value)
{
//std::cout << leftNum->getStr() << "-" << rightNum->getStr() << " = " << value->getStr() << "\n";
std::cout << leftNum->getFullStr() << "-" << rightNum->getFullStr() << " = " << value->getFullStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
operatorNode->setDerivative(std::make_shared<ExpressionNode>(
std::make_shared<Number>("0", 0)));
std::make_shared<Number>("0", 0)));
return;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/arithmetic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Arithmetic
static numPtr divide(nodePtr operatorNode, numPtr left, numPtr right);
static numPtr add(nodePtr operatorNode, numPtr left, numPtr right);
static numPtr subtract(nodePtr operatorNode, numPtr left, numPtr right);
static void simplify(nodePtr operatorNode, numPtr left, numPtr right);
static void simplify(nodePtr operatorNode);

static void simplifyExponent(nodePtr& operatorNode);
static void simplifyMultiplication(nodePtr& operatorNode);
Expand Down
5 changes: 2 additions & 3 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ int main(int argc, char const* argv[])
input = "x(ln(x)-1)";
wrt = "x";
test_expr = "";
values = {"1.0","-1.0"};
values = {1.0,-2.0,3.0};
#else
std::vector<std::string> args(argv, argv + argc);
Options options;
Expand All @@ -173,7 +173,7 @@ int main(int argc, char const* argv[])
Logger log(false);
auto derivative = getDerivative(log, input, wrt);

Approx approximator(input, wrt, value);
Approx approximator(input, wrt);

for (const double& v : values)
{
Expand All @@ -200,7 +200,6 @@ int main(int argc, char const* argv[])
if (expected != actual)
{
same = false;

}
}

Expand Down
8 changes: 5 additions & 3 deletions src/token.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "token_queue.hpp"

#include <stdexcept>
#include <cstdlib>

/**
* @brief Constructs a Token with specified type and string.
Expand Down Expand Up @@ -291,7 +292,7 @@ bool Number::isDouble() const
*/
int Number::getInt() const
{
int out = std::get<int>(value);
int out = std::abs(std::get<int>(value));
if (this->isNegative())
{
out *= -1;
Expand All @@ -306,7 +307,7 @@ int Number::getInt() const
*/
double Number::getDouble() const
{
double out = std::get<double>(value);
double out = std::abs(std::get<double>(value));
if (this->isNegative())
{
out *= -1.0;
Expand Down Expand Up @@ -369,7 +370,8 @@ bool Number::equals(int other)

if (this->isInt())
{
return other == this->getInt();
return (other == this->getInt()) &&
((other < 0) == this->isNegative());
}
else
{
Expand Down
Loading

0 comments on commit 9049992

Please sign in to comment.