Skip to content

Commit

Permalink
feat: added test and approximation in main
Browse files Browse the repository at this point in the history
  • Loading branch information
charliekush committed Oct 31, 2024
1 parent bc549a8 commit 7ff4f16
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 71 deletions.
44 changes: 14 additions & 30 deletions src/approx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,38 +27,33 @@ Approx::Approx(std::string raw_input, std::string diffVar, double value)
}
std::pair<double,double> Approx::approximate()
{
std::cout << "original: " << TextConverter::convertToText(this->root)
<< "\n";

double originalApprox = approximate(this->root);
std::cout << "\n\nDerivative: " <<
TextConverter::convertToText(this->derivative) << "\n";
double derivativeApprox = approximate(this->derivative);

double originalApprox = approximate(this->root, this->diffVar, this->value);

double derivativeApprox = approximate(this->derivative, this->diffVar,
this->value);
return std::make_pair(originalApprox, derivativeApprox);
}
std::shared_ptr<Token> Approx::replaceToken(nodePtr node)
std::shared_ptr<Token> Approx::replaceToken(nodePtr node,
std::shared_ptr<Variable> wrt,
double value)
{
if (!node)
{
return nullptr;
}
//std::cout << "\nToken: " << node->getStr() << "\n";
//std::cout << "Type: " << Lookup::getTokenType(node->getType()) << "\n";
std::shared_ptr<Token> outToken = node->getToken();
if (node->getType() == TokenType::VARIABLE)
{
std::shared_ptr<Number> num;
if (diffVar->equals(node->getToken()))
if (wrt->equals(node->getToken()))
{
//std::cout << "Token matches diffVar: " << node->getStr() << "\n";
//std::cout << "Updating token to: " << this->value << "\n";

outToken = std::make_shared<Number>(std::to_string(this->value),
value);
outToken = std::make_shared<Number>(std::to_string(value), value);
}
else
{
//std::cout << "Token does not match diffVar: " << node->getStr() << "\n";
outToken = std::make_shared<Number>("1.0",1.0);
}

Expand All @@ -69,7 +64,7 @@ std::shared_ptr<Token> Approx::replaceToken(nodePtr node)
auto subExpr = func->getSubExprTree();
auto funcIter = Lookup::functionLookup.find(node->getStr());

auto subApprox = approximate(subExpr);
auto subApprox = approximate(subExpr, wrt, value);
if (subApprox == DBL_MAX)
{
func->getSubExprTree()->printTree();
Expand All @@ -84,40 +79,29 @@ std::shared_ptr<Token> Approx::replaceToken(nodePtr node)
}
return outToken;
}
double Approx::approximate(nodePtr root)
double Approx::approximate(nodePtr root, std::shared_ptr<Variable> wrt,
double value)
{
auto rootCopy = root->copyTree();
auto leaves = ExpressionNode::getLeaves(rootCopy);
//std::cout << "Before replace: " << TextConverter::convertToText(rootCopy) << "\n";
std::stack<nodePtr> stack;
nodePtr current = rootCopy;
while (current != nullptr || !stack.empty()) {

// Reach the left most Node of the
// curr Node
while (current != nullptr) {

// Place pointer to a tree node on
// the stack before traversing
// the node's left subtree
stack.push(current);
current = current->getLeft();
}

// Current must be NULL at this point
current = stack.top();
stack.pop();
auto replacement = this->replaceToken(current);
auto replacement = replaceToken(current, wrt, value);
current->setToken(replacement);
//std::cout << "token: " << current->getStr() << "\n";
// we have visited the node and its
// left subtree. Now, it's right
// subtree's turn
current = current->getRight();

}

//std::cout << "After replace: " << TextConverter::convertToText(rootCopy) << "\n";
bool tempBool = Arithmetic::floatSimplification;
Arithmetic::floatSimplification = true;
TreeFixer::simplify(rootCopy);
Expand Down
10 changes: 7 additions & 3 deletions src/approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@ class Approx
nodePtr derivative;
double value;
std::shared_ptr<Variable> diffVar;
std::shared_ptr<Token> replaceToken(nodePtr node);
static std::shared_ptr<Token> replaceToken(nodePtr node,
std::shared_ptr<Variable> wrt,
double value);
public:
Approx(std::string raw_input, std::string diffVar, double value);

std::pair<double,double> approximate();
double approximate(nodePtr root);
static double approximate(nodePtr node,
std::shared_ptr<Variable> wrt,
double value);

};


#endif //
#endif
15 changes: 5 additions & 10 deletions src/arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ void Arithmetic::simplifyExponent(nodePtr& operatorNode)
auto value = Arithmetic::power(operatorNode, leftNum, rightNum);
if (value)
{
std::cout << leftNum->getStr() << "^" << rightNum->getStr()
<< " = " << value->getStr() << "\n";
//std::cout << leftNum->getStr() << "^" << rightNum->getStr() << " = " << value->getStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
Expand Down Expand Up @@ -254,8 +253,7 @@ void Arithmetic::simplifyMultiplication(nodePtr& operatorNode)
auto value = Arithmetic::multiply(operatorNode, leftNum, rightNum);
if (value)
{
std::cout << leftNum->getStr() << "*" << rightNum->getStr()
<< " = " << value->getStr() << "\n";
//std::cout << leftNum->getStr() << "*" << rightNum->getStr() << " = " << value->getStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
Expand Down Expand Up @@ -299,8 +297,7 @@ void Arithmetic::simplifyDivision(nodePtr& operatorNode)
auto value = Arithmetic::divide(operatorNode, leftNum, rightNum);
if (value)
{
std::cout << leftNum->getStr() << "/" << rightNum->getStr()
<< " = " << value->getStr() << "\n";
//std::cout << leftNum->getStr() << "/" << rightNum->getStr() << " = " << value->getStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
Expand Down Expand Up @@ -339,8 +336,7 @@ void Arithmetic::simplifyAddition(nodePtr& operatorNode)
auto value = Arithmetic::add(operatorNode, leftNum, rightNum);
if (value)
{
std::cout << leftNum->getStr() << "+" << rightNum->getStr()
<< " = " << value->getStr() << "\n";
//std::cout << leftNum->getStr() << "+" << rightNum->getStr() << " = " << value->getStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
Expand Down Expand Up @@ -375,8 +371,7 @@ void Arithmetic::simplifySubtraction(nodePtr& operatorNode)
auto value = Arithmetic::subtract(operatorNode, leftNum, rightNum);
if (value)
{
std::cout << leftNum->getStr() << "-" << rightNum->getStr()
<< " = " << value->getStr() << "\n";
//std::cout << leftNum->getStr() << "-" << rightNum->getStr() << " = " << value->getStr() << "\n";
operatorNode->removeLeftChild();
operatorNode->removeRightChild();
operatorNode->setToken(value);
Expand Down
45 changes: 44 additions & 1 deletion src/log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,14 @@ void Logger::addBrace(std::string c, bool endComma)

}


void Logger::logTest(std::string testStr, bool pass)
{
tests.emplace_back(std::make_pair(testStr,pass));
}
void Logger::logApprox(double sub, double out)
{
approximations.emplace_back(std::make_pair(sub, out));
}

std::string Logger::out()
{
Expand Down Expand Up @@ -214,6 +221,42 @@ std::string Logger::out()

this->addPair("input", this->input);
this->addPair("output", this->output);
if (this->tests.size() > 0 )
{
this->addLine("equality tests",false);
this->addBrace("[");
for (int i = 0; i < tests.size(); i++)
{
this->outStr += this->indent() + str(tests[i].first) + ": " +
(tests[i].second ? "true" : "false");
if ((i + 1) != tests.size())
{
this->outStr += ",";
}
this->outStr += "\n";

}
this->addBrace("]",true);
}
if (this->approximations.size() > 0 )
{
this->addLine("approximations",false);
this->addBrace("[");
for (int i = 0; i < tests.size(); i++)
{
this->outStr += this->indent() +
std::to_string(approximations[i].first) + ": " +
std::to_string(approximations[i].second) ;
if ((i + 1) != approximations.size())
{
this->outStr += ",";
}
this->outStr += "\n";

}
this->addBrace("]",true);
}

this->addPair("mode", this->mode,false);
this->addBrace("}");

Expand Down
7 changes: 6 additions & 1 deletion src/log.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <memory>
#include <vector>
#include <string>
#include <utility>

class Logger
{
Expand All @@ -24,7 +25,9 @@ class Logger
bool endComma=true);
std::string outStr;
void addBrace(std::string c, bool endComma=false);
std::string cleanBraces(const std::string& input);
std::vector<std::pair<std::string,bool>> tests;
std::vector<std::pair<double,double>> approximations;

public:
Logger(bool useLaTeX);
void setInput(std::string input);
Expand All @@ -36,6 +39,8 @@ class Logger
void logPowerRule(nodePtr node);
void logAddition(nodePtr node);
void logSubtraction(nodePtr node);
void logTest(std::string testStr, bool pass);
void logApprox(double sub, double out);
std::string out();

};
Expand Down
Loading

0 comments on commit 7ff4f16

Please sign in to comment.