Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: resolved #41 by switching to Z3 and making assumptions about typ… #46

Merged
merged 1 commit into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 83 additions & 13 deletions src/predi/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ def _sympy_operator(self, op):
'<=': 'Le'
}[op]


def sympy_to_z3(self, expr):
if isinstance(expr, sp.Symbol):
# Convert SymPy symbols to Z3 Real
return z3.Real(str(expr))
elif isinstance(expr, sp.Number):
return z3.RealVal(float(expr))
Expand All @@ -125,11 +127,35 @@ def sympy_to_z3(self, expr):
elif isinstance(expr, sp.Le):
return self.sympy_to_z3(expr.lhs) <= self.sympy_to_z3(expr.rhs)
elif isinstance(expr, sp.And):
return z3.And(*[self.sympy_to_z3(arg) for arg in expr.args])
return And(*[self.sympy_to_z3(arg) for arg in expr.args])
elif isinstance(expr, sp.Or):
return z3.Or(*[self.sympy_to_z3(arg) for arg in expr.args])
return Or(*[self.sympy_to_z3(arg) for arg in expr.args])
elif isinstance(expr, sp.Not):
return z3.Not(self.sympy_to_z3(expr.args[0]))
return Not(self.sympy_to_z3(expr.args[0]))
elif isinstance(expr, sp.Ne):
return self.sympy_to_z3(expr.lhs) != self.sympy_to_z3(expr.rhs)
elif isinstance(expr, sp.Add):
return sum(self.sympy_to_z3(arg) for arg in expr.args)
elif isinstance(expr, sp.Mul):
result = self.sympy_to_z3(expr.args[0])
for arg in expr.args[1:]:
result *= self.sympy_to_z3(arg)
return result
elif isinstance(expr, sp.Pow):
base = self.sympy_to_z3(expr.args[0])
exponent = self.sympy_to_z3(expr.args[1])
return base ** exponent
# elif isinstance(expr, sp.Function):
# # Handle function calls by converting SymPy functions to Z3 function applications
# func_name = str(expr.func)
# z3_func = z3.Function(func_name, *([z3.Real] * len(expr.args)), z3.Real) # Assuming it takes and returns reals
# z3_args = [self.sympy_to_z3(arg) for arg in expr.args]
# return z3_func(*z3_args)
elif isinstance(expr, sp.Function):
# Handle complex paths or function calls as single Real symbols
# Convert the full function call into a symbol-like string
func_name = str(expr).replace('[', '_').replace(']', '').replace('.', '_')
return z3.Real(func_name)
else:
raise ValueError(f"Unsupported expression type: {expr}")

Expand Down Expand Up @@ -260,7 +286,8 @@ def _implies(self, expr1, expr2, level=0):
printer(f'In relational base cases; expr1: {expr1}, expr2: {expr2}', level)
# Check for Eq vs non-Eq comparisons; we don't handle this well, let's return False
if (isinstance(expr1, sp.Eq) and not isinstance(expr2, sp.Eq)) or (not isinstance(expr1, sp.Eq) and isinstance(expr2, sp.Eq)):
printer(f'One of the expressions is equality and the other is not; expr1: {expr1}, expr2: {expr2}', level)
printer(f'One of the expressions is equality and the other is not; expr1: {expr1}, expr2: {expr2}', level)

# switch to z3
printer(f'Switching to Z3 ..... ')
z3_expr1 = self.sympy_to_z3(expr1)
Expand Down Expand Up @@ -292,13 +319,56 @@ def _implies(self, expr1, expr2, level=0):
return False
else:
printer(f'Not all arguments are numbers, floats, or symbols in expr1 and expr2, however, we still try to use the same sympy satisfiability check', level)
try:
negation = sp.And(expr1, Not(expr2))
printer(f"Negation of the implication {expr1} -> {expr2}: {satisfiable(negation)}; type of {type(satisfiable(negation))}", level)
result = not satisfiable(negation, use_lra_theory=True)
printer(f"Implication {expr1} -> {expr2} using satisfiable: {result}", level)
return result
except Exception as e:
printer(f"Error (satisfiability error): {e}", level)
return False

# print type of all lhs and rhs's of both expressions
printer(f'type of expr1.lhs: {type(expr1.lhs)}')
printer(f'type of expr1.rhs: {type(expr1.rhs)}')
printer(f'type of expr2.lhs: {type(expr2.lhs)}')
printer(f'type of expr2.rhs: {type(expr2.rhs)}')



# even if one of the above lhs and rhs's is sympy.core.mul.Mul and then one of its args is a number or a float bigger or lower than 1, we should switch to z3; we are not handling "1" case since it is working with sympy already, don't want to break a working prototype
if any(isinstance(arg, sp.Mul) and any(isinstance(a, (sp.Number, sp.Float)) and (a > 1 or a < 1) for a in arg.args) for arg in [expr1.lhs, expr1.rhs, expr2.lhs, expr2.rhs]):
printer(f'One of the arguments is a Mul, switching to z3 ...', level)


z3_expr1 = self.sympy_to_z3(expr1)
z3_expr2 = self.sympy_to_z3(expr2)

variables = {str(sym) for sym in expr1.free_symbols.union(expr2.free_symbols)}
z3_vars = {var: z3.Real(var) for var in variables} # Convert to Z3 Reals



solver = z3.Solver()

# Add constraints to ensure all variables are greater than 0
for var in z3_vars.values():
solver.add(var > 0)

solver.add(z3_expr1, z3.Not(z3_expr2))




# Check satisfiability
if solver.check() == z3.sat:
# If satisfiable, implication does not hold
printer(f"Implies {expr1} to {expr2}: False", level=0)
return False
else:
# If unsatisfiable, implication holds
printer(f"Implies {expr1} to {expr2}: True", level=0)
return True
else:
try:
negation = sp.And(expr1, Not(expr2))
printer(f"Negation of the implication {expr1} -> {expr2}: {satisfiable(negation)}; type of {type(satisfiable(negation))}", level)
result = not satisfiable(negation, use_lra_theory=True)
printer(f"Implication {expr1} -> {expr2} using satisfiable: {result}", level)
return result
except Exception as e:
printer(f"Error (satisfiability error): {e}", level)
return False
return False
5 changes: 4 additions & 1 deletion tests/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
("a > b", "a >= b"),
("msg.sender == msg.origin && a >= b", "msg.sender == msg.origin"),
("msg.sender == msg.origin", "msg.sender == msg.origin || a < b"),
("a == 1", "a >= 1")
("a == 1", "a >= 1"),
("a > b * 2", "a > b * 1")
],
'The second predicate is stronger.': [
("msg.sender == msg.origin || a < b", "a < b"),
("a > 12", "a > 13"),
("a + 1 <= b", "a + 1 < b"),
("a > b * 1/2", "a > b * 1"),
("coinMap[_coin].coinContract.balanceOf(msg.sender)>=_amount", "coinMap[_coin].coinContract.balanceOf(msg.sender)>=_amount*1e18")
],
'The predicates are equivalent.': [
("msg.sender == msg.origin", "msg.origin == msg.sender"),
Expand Down
Loading