From 06437680bf899ced6490de93064b73399b659a02 Mon Sep 17 00:00:00 2001 From: Mojtaba Eshghie Date: Sat, 14 Sep 2024 19:07:20 +0200 Subject: [PATCH] Fix: resolved #41 by switching to Z3 and making assumptions about type and value of symbols (a > 0) --- src/predi/comparator.py | 96 ++++++++++++++++++++++++++++++++++------ tests/test_comparator.py | 5 ++- 2 files changed, 87 insertions(+), 14 deletions(-) diff --git a/src/predi/comparator.py b/src/predi/comparator.py index fbd70e2..beace8e 100644 --- a/src/predi/comparator.py +++ b/src/predi/comparator.py @@ -109,8 +109,10 @@ def _sympy_operator(self, op): '<=': 'Le' }[op] + def sympy_to_z3(self, expr): if isinstance(expr, sp.Symbol): + # Convert SymPy symbols to Z3 Real return z3.Real(str(expr)) elif isinstance(expr, sp.Number): return z3.RealVal(float(expr)) @@ -125,11 +127,35 @@ def sympy_to_z3(self, expr): elif isinstance(expr, sp.Le): return self.sympy_to_z3(expr.lhs) <= self.sympy_to_z3(expr.rhs) elif isinstance(expr, sp.And): - return z3.And(*[self.sympy_to_z3(arg) for arg in expr.args]) + return And(*[self.sympy_to_z3(arg) for arg in expr.args]) elif isinstance(expr, sp.Or): - return z3.Or(*[self.sympy_to_z3(arg) for arg in expr.args]) + return Or(*[self.sympy_to_z3(arg) for arg in expr.args]) elif isinstance(expr, sp.Not): - return z3.Not(self.sympy_to_z3(expr.args[0])) + return Not(self.sympy_to_z3(expr.args[0])) + elif isinstance(expr, sp.Ne): + return self.sympy_to_z3(expr.lhs) != self.sympy_to_z3(expr.rhs) + elif isinstance(expr, sp.Add): + return sum(self.sympy_to_z3(arg) for arg in expr.args) + elif isinstance(expr, sp.Mul): + result = self.sympy_to_z3(expr.args[0]) + for arg in expr.args[1:]: + result *= self.sympy_to_z3(arg) + return result + elif isinstance(expr, sp.Pow): + base = self.sympy_to_z3(expr.args[0]) + exponent = self.sympy_to_z3(expr.args[1]) + return base ** exponent + # elif isinstance(expr, sp.Function): + # # Handle function calls by converting SymPy functions to Z3 function applications + # func_name = str(expr.func) + # z3_func = z3.Function(func_name, *([z3.Real] * len(expr.args)), z3.Real) # Assuming it takes and returns reals + # z3_args = [self.sympy_to_z3(arg) for arg in expr.args] + # return z3_func(*z3_args) + elif isinstance(expr, sp.Function): + # Handle complex paths or function calls as single Real symbols + # Convert the full function call into a symbol-like string + func_name = str(expr).replace('[', '_').replace(']', '').replace('.', '_') + return z3.Real(func_name) else: raise ValueError(f"Unsupported expression type: {expr}") @@ -260,7 +286,8 @@ def _implies(self, expr1, expr2, level=0): printer(f'In relational base cases; expr1: {expr1}, expr2: {expr2}', level) # Check for Eq vs non-Eq comparisons; we don't handle this well, let's return False if (isinstance(expr1, sp.Eq) and not isinstance(expr2, sp.Eq)) or (not isinstance(expr1, sp.Eq) and isinstance(expr2, sp.Eq)): - printer(f'One of the expressions is equality and the other is not; expr1: {expr1}, expr2: {expr2}', level) + printer(f'One of the expressions is equality and the other is not; expr1: {expr1}, expr2: {expr2}', level) + # switch to z3 printer(f'Switching to Z3 ..... ') z3_expr1 = self.sympy_to_z3(expr1) @@ -292,13 +319,56 @@ def _implies(self, expr1, expr2, level=0): return False else: printer(f'Not all arguments are numbers, floats, or symbols in expr1 and expr2, however, we still try to use the same sympy satisfiability check', level) - try: - negation = sp.And(expr1, Not(expr2)) - printer(f"Negation of the implication {expr1} -> {expr2}: {satisfiable(negation)}; type of {type(satisfiable(negation))}", level) - result = not satisfiable(negation, use_lra_theory=True) - printer(f"Implication {expr1} -> {expr2} using satisfiable: {result}", level) - return result - except Exception as e: - printer(f"Error (satisfiability error): {e}", level) - return False + + # print type of all lhs and rhs's of both expressions + printer(f'type of expr1.lhs: {type(expr1.lhs)}') + printer(f'type of expr1.rhs: {type(expr1.rhs)}') + printer(f'type of expr2.lhs: {type(expr2.lhs)}') + printer(f'type of expr2.rhs: {type(expr2.rhs)}') + + + + # even if one of the above lhs and rhs's is sympy.core.mul.Mul and then one of its args is a number or a float bigger or lower than 1, we should switch to z3; we are not handling "1" case since it is working with sympy already, don't want to break a working prototype + if any(isinstance(arg, sp.Mul) and any(isinstance(a, (sp.Number, sp.Float)) and (a > 1 or a < 1) for a in arg.args) for arg in [expr1.lhs, expr1.rhs, expr2.lhs, expr2.rhs]): + printer(f'One of the arguments is a Mul, switching to z3 ...', level) + + + z3_expr1 = self.sympy_to_z3(expr1) + z3_expr2 = self.sympy_to_z3(expr2) + + variables = {str(sym) for sym in expr1.free_symbols.union(expr2.free_symbols)} + z3_vars = {var: z3.Real(var) for var in variables} # Convert to Z3 Reals + + + + solver = z3.Solver() + + # Add constraints to ensure all variables are greater than 0 + for var in z3_vars.values(): + solver.add(var > 0) + + solver.add(z3_expr1, z3.Not(z3_expr2)) + + + + + # Check satisfiability + if solver.check() == z3.sat: + # If satisfiable, implication does not hold + printer(f"Implies {expr1} to {expr2}: False", level=0) + return False + else: + # If unsatisfiable, implication holds + printer(f"Implies {expr1} to {expr2}: True", level=0) + return True + else: + try: + negation = sp.And(expr1, Not(expr2)) + printer(f"Negation of the implication {expr1} -> {expr2}: {satisfiable(negation)}; type of {type(satisfiable(negation))}", level) + result = not satisfiable(negation, use_lra_theory=True) + printer(f"Implication {expr1} -> {expr2} using satisfiable: {result}", level) + return result + except Exception as e: + printer(f"Error (satisfiability error): {e}", level) + return False return False \ No newline at end of file diff --git a/tests/test_comparator.py b/tests/test_comparator.py index e4bb21b..881fc49 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -8,12 +8,15 @@ ("a > b", "a >= b"), ("msg.sender == msg.origin && a >= b", "msg.sender == msg.origin"), ("msg.sender == msg.origin", "msg.sender == msg.origin || a < b"), - ("a == 1", "a >= 1") + ("a == 1", "a >= 1"), + ("a > b * 2", "a > b * 1") ], 'The second predicate is stronger.': [ ("msg.sender == msg.origin || a < b", "a < b"), ("a > 12", "a > 13"), ("a + 1 <= b", "a + 1 < b"), + ("a > b * 1/2", "a > b * 1"), + ("coinMap[_coin].coinContract.balanceOf(msg.sender)>=_amount", "coinMap[_coin].coinContract.balanceOf(msg.sender)>=_amount*1e18") ], 'The predicates are equivalent.': [ ("msg.sender == msg.origin", "msg.origin == msg.sender"),