Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
anutosh491 committed Oct 31, 2023
1 parent e2a273b commit 5455dd0
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
4 changes: 4 additions & 0 deletions integration_tests/symbolics_02.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def test_symbolic_operations():
else:
assert False
assert(z.func == Add)
assert(z.args[0] == x or z.args[0] == y)
assert(z.args[1] == y or z.args[1] == x)
print(z)

# Subtraction
Expand All @@ -43,6 +45,8 @@ def test_symbolic_operations():
else:
assert False
assert(u.func == Mul)
assert(u.args[0] == x)
assert(u.args[1] == y)
print(u)

# Division
Expand Down
7 changes: 7 additions & 0 deletions integration_tests/symbolics_05.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,12 @@ def test_operations():
assert((sin(x) + cos(x)).diff(x) == S(-1)*c + d)
assert((sin(x) + cos(x) + exp(x) + pi).diff(x).expand().diff(x) == exp(x) + S(-1)*c + S(-1)*d)

# test args
assert(a.args[0] == x + y)
assert(a.args[1] == S(2))
assert(b.args[0] == x + y + z)
assert(b.args[1] == S(3))
assert(c.args[0] == x)
assert(d.args[0] == x)

test_operations()
24 changes: 21 additions & 3 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1707,9 +1707,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
pass_result.push_back(al, assert_stmt);
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
SymbolTable* module_scope = current_scope->parent;
ASR::expr_t* left_tmp = nullptr;
ASR::expr_t* right_tmp = nullptr;

ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym);
Expand All @@ -1726,6 +1723,27 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
}
} else if (ASR::is_a<ASR::LogicalBinOp_t>(*x.m_test)) {
ASR::LogicalBinOp_t* binop = ASR::down_cast<ASR::LogicalBinOp_t>(x.m_test);
if (ASR::is_a<ASR::SymbolicCompare_t>(*binop->m_left) && ASR::is_a<ASR::SymbolicCompare_t>(*binop->m_right)) {
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
ASR::SymbolicCompare_t *s1 = ASR::down_cast<ASR::SymbolicCompare_t>(binop->m_left);
left_tmp = process_with_basic_str(al, x.base.base.loc, s1->m_left, basic_str_sym);
right_tmp = process_with_basic_str(al, x.base.base.loc, s1->m_right, basic_str_sym);
ASR::expr_t* test1 = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp,
s1->m_op, right_tmp, s1->m_type, s1->m_value));

ASR::SymbolicCompare_t *s2 = ASR::down_cast<ASR::SymbolicCompare_t>(binop->m_right);
left_tmp = process_with_basic_str(al, x.base.base.loc, s2->m_left, basic_str_sym);
right_tmp = process_with_basic_str(al, x.base.base.loc, s2->m_right, basic_str_sym);
ASR::expr_t* test2 = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp,
s2->m_op, right_tmp, s2->m_type, s2->m_value));

ASR::expr_t *cond = ASRUtils::EXPR(ASR::make_LogicalBinOp_t(al, x.base.base.loc,
test1, ASR::logicalbinopType::Or, test2, binop->m_type, binop->m_value));
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, cond, x.m_msg));
pass_result.push_back(al, assert_stmt);
}
}
}
};
Expand Down

0 comments on commit 5455dd0

Please sign in to comment.