From dfd0cf65dc40ca74a04329ef1dce21d386bc13a3 Mon Sep 17 00:00:00 2001 From: Sonia Date: Thu, 11 Nov 2021 23:35:24 +0100 Subject: [PATCH 1/2] Z3 SMT interface --- include/bill/smt/z3.hpp | 368 ++++++++++++++++++++++++++++++++++++++++ test/smt/z3.cpp | 141 +++++++++++++++ 2 files changed, 509 insertions(+) create mode 100644 include/bill/smt/z3.hpp create mode 100644 test/smt/z3.cpp diff --git a/include/bill/smt/z3.hpp b/include/bill/smt/z3.hpp new file mode 100644 index 0000000..a0aee4f --- /dev/null +++ b/include/bill/smt/z3.hpp @@ -0,0 +1,368 @@ +/*------------------------------------------------------------------------------------------------- +| This file is distributed under the MIT License. +| See accompanying file /LICENSE for details. +*------------------------------------------------------------------------------------------------*/ +#pragma once + +#if defined(BILL_HAS_Z3) + +#include +#include +#include +#include +#include + +namespace bill { + +template +class z3_smt_solver { +public: + using var_t = uint32_t; + using lp_expr_t = std::vector>; + + enum class states : uint8_t { + satisfiable, + unsatisfiable, + undefined, + }; + + enum class var_types : uint8_t { + boolean, + integer, + real + }; + + enum class lp_types : uint8_t { + geq, + leq, + eq, + greater, + less, + }; + +#pragma region Constructors + z3_smt_solver() + : solver_(ctx_) + , variable_counter_(0u) + {} + + ~z3_smt_solver() + {} + + /* disallow copying */ + z3_smt_solver(z3_smt_solver const&) = delete; + z3_smt_solver& operator=(const z3_smt_solver&) = delete; +#pragma endregion + +#pragma region Modifiers + void restart() + { + solver_.reset(); + vars_.clear(); + variable_counter_ = 0u; + state_ = states::undefined; + } + + //void set_logic(std::string const& logic) + //{ + // solver_.set("logic", logic.c_str()); + //} + + var_t add_variable(var_types type) + { + switch (type) { + case var_types::boolean: + vars_.push_back(ctx_.bool_const(fmt::format("bool{}", variable_counter_).c_str())); + break; + case var_types::integer: + vars_.push_back(ctx_.int_const(fmt::format("int{}", variable_counter_).c_str())); + break; + case var_types::real: + vars_.push_back(ctx_.real_const(fmt::format("real{}", variable_counter_).c_str())); + break; + default: + assert(false && "Error: unknown variable type\n"); + return std::numeric_limits::max(); + } + + return variable_counter_++; + } + + void add_variables(var_types type, uint32_t num_variables = 1) + { + for (auto i = 0u; i < num_variables; ++i) { + add_variable(type); + } + } + + /* create an integer-type variable and set it to count how many boolean-type variables in `var_set` are true */ + var_t add_integer_cardinality(std::vector const& var_set) + { + var_t counter = add_variable(var_types::integer); + solver_.add(vars_[counter] == make_integer_sum(var_set)); + return counter; + } + + /* create an real-type variable and set it to count how many boolean-type variables in `var_set` are true */ + var_t add_real_cardinality(std::vector const& var_set) + { + var_t counter = add_variable(var_types::real); + solver_.add(vars_[counter] == make_real_sum(var_set)); + return counter; + } + + /* create a boolean-type variable and set it to if the LP condition holds */ + var_t add_lp_condition(lp_expr_t const& lhs, int32_t rhs, lp_types type) + { + assert( is_real_lp_expr(lhs) ); + z3::expr expr = make_lp_expr(lhs); + + var_t cond_var = add_variable(var_types::boolean); + z3::expr cond = vars_[cond_var]; + + switch (type) + { + case lp_types::geq: + solver_.add(z3::implies(cond, expr >= ctx_.real_val(rhs, 1))); + return cond_var; + case lp_types::leq: + solver_.add(z3::implies(cond, expr <= ctx_.real_val(rhs, 1))); + return cond_var; + case lp_types::eq: + solver_.add(z3::implies(cond, expr == ctx_.real_val(rhs, 1))); + return cond_var; + case lp_types::greater: + solver_.add(z3::implies(cond, expr > ctx_.real_val(rhs, 1))); + return cond_var; + case lp_types::less: + solver_.add(z3::implies(cond, expr < ctx_.real_val(rhs, 1))); + return cond_var; + default: + assert(false && "unknown LP constraint type"); + return cond_var; + } + } + + var_t add_ilp_condition(lp_expr_t const& lhs, int32_t rhs, lp_types type) + { + assert( is_integer_lp_expr(lhs) ); + z3::expr expr = make_lp_expr(lhs); + + var_t cond_var = add_variable(var_types::boolean); + z3::expr cond = vars_[cond_var]; + + switch (type) + { + case lp_types::geq: + solver_.add(z3::implies(cond, expr >= rhs)); + return cond_var; + case lp_types::leq: + solver_.add(z3::implies(cond, expr <= rhs)); + return cond_var; + case lp_types::eq: + solver_.add(z3::implies(cond, expr == rhs)); + return cond_var; + case lp_types::greater: + solver_.add(z3::implies(cond, expr > rhs)); + return cond_var; + case lp_types::less: + solver_.add(z3::implies(cond, expr < rhs)); + return cond_var; + default: + assert(false && "unknown LP constraint type"); + return cond_var; + } + } + + /* assert a LP constraint */ + void add_lp_constraint(std::vector> const& lhs, int32_t rhs, lp_types type) + { + var_t cond_var = add_lp_condition(lhs, rhs, type); + solver_.add(vars_[cond_var]); + } + + void add_ilp_constraint(std::vector> const& lhs, int32_t rhs, lp_types type) + { + var_t cond_var = add_ilp_condition(lhs, rhs, type); + solver_.add(vars_[cond_var]); + } + + void assert_true(var_t const& v) + { + assert(is_boolean_type(v)); + solver_.add(vars_[v]); + } + + void assert_false(var_t const& v) + { + assert(is_boolean_type(v)); + solver_.add(!vars_[v]); + } + + template> + void maximize(var_t const& var) + { + solver_.maximize(vars_[var]); + } + + template> + void maximize(lp_expr_t const& objective) + { + z3::expr expr = make_lp_expr(objective); + solver_.maximize(expr); + } + + template> + void minimize(var_t const& var) + { + solver_.minimize(vars_[var]); + } + + template> + void minimize(lp_expr_t const& objective) + { + z3::expr expr = make_lp_expr(objective); + solver_.minimize(expr); + } + + states solve() + { + z3::expr_vector vec(ctx_); + switch (solver_.check(vec)) { + case z3::sat: + state_ = states::satisfiable; + break; + case z3::unsat: + state_ = states::unsatisfiable; + break; + case z3::unknown: + default: + state_ = states::undefined; + break; + }; + z3::reset_params(); + return state_; + } +#pragma endregion + +#pragma region Properties + uint32_t num_variables() const + { + return variable_counter_; + } +#pragma endregion + +#pragma region Get Model + bool get_boolean_variable_value(var_t var) + { + assert(is_boolean_type(var)); + assert(state_ == states::satisfiable); + + return solver_.get_model().eval(vars_[var]).is_true(); + } + + int64_t get_numeral_variable_value_as_integer(var_t var) + { + assert(is_integer_type(var) || is_real_type(var)); + assert(state_ == states::satisfiable); + return solver_.get_model().eval(vars_[var]).get_numeral_int64(); + } +#pragma endregion + + template> + void print() + { + std::cout << solver_.to_smt2() << "\n"; + } + +private: + z3::expr make_lp_expr(lp_expr_t const& expr) + { + assert( expr.size() > 0 ); + z3::expr e = expr[0].first == 1 ? vars_[expr[0].second] : expr[0].first * vars_[expr[0].second]; + for ( auto i = 1u; i < expr.size(); ++i ) + { + e = e + ( expr[i].first == 1 ? vars_[expr[i].second] : expr[i].first * vars_[expr[i].second] ); + } + return e; + } + + z3::expr make_integer_sum(std::vector const& var_set) + { + assert(is_all_boolean(var_set)); + z3::expr_vector vec(ctx_); + for ( auto const& v : var_set ) + vec.push_back(z3::ite(vars_[v], ctx_.int_val(1), ctx_.int_val(0))); + return z3::sum(vec); + } + + z3::expr make_real_sum(std::vector const& var_set) + { + assert(is_all_boolean(var_set)); + z3::expr_vector vec(ctx_); + for ( auto const& v : var_set ) + vec.push_back(z3::ite(vars_[v], ctx_.real_val(1), ctx_.real_val(0))); + return z3::sum(vec); + } + +#pragma region Check Variable Type + bool is_boolean_type(var_t var) + { + return vars_[var].is_bool(); + } + + bool is_integer_type(var_t var) + { + return vars_[var].is_int(); + } + + bool is_real_type(var_t var) + { + return vars_[var].is_real(); + } + + bool is_all_boolean(std::vector const& var_set) + { + bool res = true; + for ( auto const& var : var_set ) + res &= is_boolean_type(var); + return res; + } + + bool is_integer_lp_expr(lp_expr_t const& expr) + { + bool res = true; + for ( auto const& term : expr ) + res &= is_integer_type(term.second); + return res; + } + + bool is_real_lp_expr(lp_expr_t const& expr) + { + bool res = true; + for ( auto const& term : expr ) + res &= is_real_type(term.second); + return res; + } +#pragma endregion + +private: + /*! \brief Backend solver context object */ + z3::context ctx_; + + /*! \brief Backend solver */ + std::conditional_t solver_; + + /*! \brief Current state of the solver */ + states state_ = states::undefined; + + /*! \brief Variables */ + std::vector vars_; + + /*! \brief Stacked counter for number of variables */ + uint32_t variable_counter_; +}; + +} // namespace bill + +#endif diff --git a/test/smt/z3.cpp b/test/smt/z3.cpp new file mode 100644 index 0000000..84c0000 --- /dev/null +++ b/test/smt/z3.cpp @@ -0,0 +1,141 @@ +/*------------------------------------------------------------------------------------------------- +| This file is distributed under the MIT License. +| See accompanying file /LICENSE for details. +*------------------------------------------------------------------------------------------------*/ +#if defined(BILL_HAS_Z3) + +#include "../catch2.hpp" + +#include +#include +#include + +using namespace bill; + +TEST_CASE("Simple SAT LP", "[smt/z3]") +{ + using solver_t = z3_smt_solver; + + solver_t solver; + auto const v1 = solver.add_variable( solver_t::var_types::real ); + solver.add_lp_constraint( {{1,v1}}, 3, solver_t::lp_types::geq ); // v1 >= 3 + solver.add_lp_constraint( {{2,v1}}, 8, solver_t::lp_types::leq ); // 2v1 <= 8 --> v1 <= 4 + + auto const res = solver.solve(); + CHECK( res == solver_t::states::satisfiable ); + auto const sol_v1 = solver.get_numeral_variable_value_as_integer( v1 ); + CHECK( sol_v1 >= 3 ); + CHECK( sol_v1 <= 4 ); +} + +TEST_CASE("Simple SAT ILP", "[smt/z3]") +{ + using solver_t = z3_smt_solver; + + solver_t solver; + auto const v1 = solver.add_variable( solver_t::var_types::integer ); + solver.add_ilp_constraint( {{1,v1}}, 3, solver_t::lp_types::geq ); // v1 >= 3 + solver.add_ilp_constraint( {{2,v1}}, 8, solver_t::lp_types::leq ); // 2v1 <= 8 --> v1 <= 4 + + auto const res = solver.solve(); + CHECK( res == solver_t::states::satisfiable ); + auto const sol_v1 = solver.get_numeral_variable_value_as_integer( v1 ); + CHECK( sol_v1 >= 3 ); + CHECK( sol_v1 <= 4 ); +} + +TEST_CASE("Simple UNSAT LP", "[smt/z3]") +{ + using solver_t = z3_smt_solver; + + solver_t solver; + auto const v1 = solver.add_variable( solver_t::var_types::real ); + solver.add_lp_constraint( {{1,v1}}, 3, solver_t::lp_types::geq ); // v1 >= 3 + solver.add_lp_constraint( {{2,v1}}, 2, solver_t::lp_types::leq ); // 2v1 <= 2 --> v1 <= 1 + + auto const res = solver.solve(); + CHECK( res == solver_t::states::unsatisfiable ); +} + +TEST_CASE("Simple UNSAT ILP", "[smt/z3]") +{ + using solver_t = z3_smt_solver; + + solver_t solver; + auto const v1 = solver.add_variable( solver_t::var_types::integer ); + solver.add_ilp_constraint( {{1,v1}}, 3, solver_t::lp_types::geq ); // v1 >= 3 + solver.add_ilp_constraint( {{1,v1}}, 3, solver_t::lp_types::less ); // v1 < 3 + + auto const res = solver.solve(); + CHECK( res == solver_t::states::unsatisfiable ); +} + +TEST_CASE("Simple optimization LP", "[smt/z3]") +{ + using solver_t = z3_smt_solver; + + solver_t solver; + auto const v1 = solver.add_variable( solver_t::var_types::real ); + solver.add_lp_constraint( {{2,v1}}, 8, solver_t::lp_types::leq ); // 2v1 <= 8 --> v1 <= 4 + solver.maximize( v1 ); + + auto const res = solver.solve(); + CHECK( res == solver_t::states::satisfiable ); + auto const sol_v1 = solver.get_numeral_variable_value_as_integer( v1 ); + CHECK( sol_v1 == 4 ); +} + +TEST_CASE("Simple optimization ILP", "[smt/z3]") +{ + using solver_t = z3_smt_solver; + + solver_t solver; + auto const v1 = solver.add_variable( solver_t::var_types::integer ); + solver.add_ilp_constraint( {{2,v1}}, 8, solver_t::lp_types::less ); // 2v1 < 8 --> v1 < 4 + solver.maximize( v1 ); + + auto const res = solver.solve(); + CHECK( res == solver_t::states::satisfiable ); + auto const sol_v1 = solver.get_numeral_variable_value_as_integer( v1 ); + CHECK( sol_v1 == 3 ); +} + +TEST_CASE("LP with cardinality constraints", "[smt/z3]") +{ + using solver_t = z3_smt_solver; + + solver_t solver; + auto const v1 = solver.add_variable( solver_t::var_types::real ); + auto const b1 = solver.add_lp_condition( {{1,v1}}, 3, solver_t::lp_types::geq ); // v1 >= 3 + auto const b2 = solver.add_lp_condition( {{2,v1}}, 8, solver_t::lp_types::leq ); // 2v1 <= 8 --> v1 <= 4 + auto const b3 = solver.add_lp_condition( {{1,v1}}, 5, solver_t::lp_types::greater ); // v1 > 5 + auto const sum = solver.add_real_cardinality( {b1, b2, b3} ); + solver.add_lp_constraint( {{1,sum}}, 3, solver_t::lp_types::eq ); + + auto const res = solver.solve(); + CHECK( res == solver_t::states::unsatisfiable ); +} + +TEST_CASE("ILP with cardinality constraints and objective", "[smt/z3]") +{ + using solver_t = z3_smt_solver; + + solver_t solver; + auto const v1 = solver.add_variable( solver_t::var_types::integer ); + auto const b1 = solver.add_ilp_condition( {{1,v1}}, -2, solver_t::lp_types::geq ); // v1 >= -2 + auto const b2 = solver.add_ilp_condition( {{1,v1}}, -1, solver_t::lp_types::leq ); // v1 <= -1 + auto const b3 = solver.add_ilp_condition( {{1,v1}}, 0, solver_t::lp_types::eq ); // v1 == 0 + auto const b4 = solver.add_ilp_condition( {{1,v1}}, 10, solver_t::lp_types::less ); // v1 < 10 + auto const b5 = solver.add_ilp_condition( {{1,v1}}, 0, solver_t::lp_types::geq ); // v1 >= 0 + auto const sum = solver.add_integer_cardinality( {b1, b2, b3, b4, b5} ); + solver.maximize( sum ); + + auto const res = solver.solve(); + CHECK( res == solver_t::states::satisfiable ); + auto const sol_v1 = solver.get_numeral_variable_value_as_integer( v1 ); + CHECK( sol_v1 == 0 ); + auto const sol_sum = solver.get_numeral_variable_value_as_integer( sum ); + CHECK( sol_sum == 4 ); +} + +#endif From c0a4622ec042f90f2b6579900c79632c9dac9f8d Mon Sep 17 00:00:00 2001 From: Sonia Date: Thu, 9 Dec 2021 10:49:05 -0800 Subject: [PATCH 2/2] Update z3.hpp --- include/bill/smt/z3.hpp | 47 ++++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/include/bill/smt/z3.hpp b/include/bill/smt/z3.hpp index a0aee4f..eefd29d 100644 --- a/include/bill/smt/z3.hpp +++ b/include/bill/smt/z3.hpp @@ -123,19 +123,19 @@ class z3_smt_solver { switch (type) { case lp_types::geq: - solver_.add(z3::implies(cond, expr >= ctx_.real_val(rhs, 1))); + solver_.add(cond == (expr >= ctx_.real_val(rhs, 1))); return cond_var; case lp_types::leq: - solver_.add(z3::implies(cond, expr <= ctx_.real_val(rhs, 1))); + solver_.add(cond == (expr <= ctx_.real_val(rhs, 1))); return cond_var; case lp_types::eq: - solver_.add(z3::implies(cond, expr == ctx_.real_val(rhs, 1))); + solver_.add(cond == (expr == ctx_.real_val(rhs, 1))); return cond_var; case lp_types::greater: - solver_.add(z3::implies(cond, expr > ctx_.real_val(rhs, 1))); + solver_.add(cond == (expr > ctx_.real_val(rhs, 1))); return cond_var; case lp_types::less: - solver_.add(z3::implies(cond, expr < ctx_.real_val(rhs, 1))); + solver_.add(cond == (expr < ctx_.real_val(rhs, 1))); return cond_var; default: assert(false && "unknown LP constraint type"); @@ -154,19 +154,19 @@ class z3_smt_solver { switch (type) { case lp_types::geq: - solver_.add(z3::implies(cond, expr >= rhs)); + solver_.add( cond == (expr >= rhs) ); return cond_var; case lp_types::leq: - solver_.add(z3::implies(cond, expr <= rhs)); + solver_.add( cond == (expr <= rhs) ); return cond_var; case lp_types::eq: - solver_.add(z3::implies(cond, expr == rhs)); + solver_.add( cond == (expr == rhs) ); return cond_var; case lp_types::greater: - solver_.add(z3::implies(cond, expr > rhs)); + solver_.add( cond == (expr > rhs) ); return cond_var; case lp_types::less: - solver_.add(z3::implies(cond, expr < rhs)); + solver_.add( cond == (expr < rhs) ); return cond_var; default: assert(false && "unknown LP constraint type"); @@ -199,6 +199,33 @@ class z3_smt_solver { solver_.add(!vars_[v]); } + void add_pseudo_boolean_constraint(std::vector const& var_set, int32_t rhs, lp_types type) + { + assert( is_all_boolean(var_set) ); + z3::expr_vector vec(ctx_); + + for ( auto i = 0u; i < var_set.size(); ++i ) + { + vec.push_back(vars_[var_set[i]]); + } + + switch (type) + { + case lp_types::geq: + solver_.add(z3::atleast(vec, rhs)); + break; + case lp_types::leq: + solver_.add(z3::atmost(vec, rhs)); + break; + case lp_types::eq: + solver_.add(z3::atmost(vec, rhs)); + solver_.add(z3::atleast(vec, rhs)); + break; + default: + assert(false && "unknown PB constraint type"); + } + } + template> void maximize(var_t const& var) {