diff --git a/src/smt/abduction_solver.cpp b/src/smt/abduction_solver.cpp index 6fb422596fd..066221ce65c 100644 --- a/src/smt/abduction_solver.cpp +++ b/src/smt/abduction_solver.cpp @@ -50,9 +50,15 @@ bool AbductionSolver::getAbduct(const std::vector& axioms, Trace("sygus-abduct") << "Axioms: " << axioms << std::endl; Trace("sygus-abduct") << "SolverEngine::getAbduct: goal " << goal << std::endl; - std::vector asserts(axioms.begin(), axioms.end()); + SubstitutionMap& tls = d_env.getTopLevelSubstitutions().get(); + std::vector axiomsn; + for (const Node& ax : axioms) + { + axiomsn.emplace_back(tls.apply(ax)); + } + std::vector asserts(axiomsn.begin(), axiomsn.end()); // must expand definitions - Node conjn = d_env.getTopLevelSubstitutions().apply(goal); + Node conjn = tls.apply(goal); conjn = rewrite(conjn); // now negate conjn = conjn.negate(); @@ -60,7 +66,7 @@ bool AbductionSolver::getAbduct(const std::vector& axioms, asserts.push_back(conjn); std::string name("__internal_abduct"); Node aconj = quantifiers::SygusAbduct::mkAbductionConjecture( - name, asserts, axioms, grammarType); + name, asserts, axiomsn, grammarType); // should be a quantified conjecture with one function-to-synthesize Assert(aconj.getKind() == Kind::FORALL && aconj[0].getNumChildren() == 1); // remember the abduct-to-synthesize diff --git a/src/smt/interpolation_solver.cpp b/src/smt/interpolation_solver.cpp index 8101de430c8..96c588d38ad 100644 --- a/src/smt/interpolation_solver.cpp +++ b/src/smt/interpolation_solver.cpp @@ -18,6 +18,7 @@ #include #include "base/modal_exception.h" +#include "expr/node_algorithm.h" #include "options/quantifiers_options.h" #include "options/smt_options.h" #include "smt/env.h" @@ -48,17 +49,78 @@ bool InterpolationSolver::getInterpolant(const std::vector& axioms, "Cannot get interpolation when produce-interpolants options is off."; throw ModalException(msg); } + // apply top-level substitutions Trace("sygus-interpol") << "SolverEngine::getInterpol: conjecture " << conj << std::endl; - // must expand definitions - Node conjn = d_env.getTopLevelSubstitutions().apply(conj); + // We can apply top-level substitutions x -> t that are implied by the + // assertions but only if all symbols in (= x t) are also contained in the + // goal (to satisfy the shared symbol requirement of get-interpolant). + // We construct a subset of the top-level substitutions (tlShared) here that + // can legally be applied, and conjoin these with our final solution when + // applicable below. + SubstitutionMap& tls = d_env.getTopLevelSubstitutions().get(); + SubstitutionMap tlsShared; + std::unordered_map subs = tls.getSubstitutions(); + std::unordered_set conjSyms; + expr::getSymbols(conj, conjSyms); + std::vector axiomsn; + for (const std::pair& s : subs) + { + // Furthermore note that if we have a target grammar, we cannot conjoin + // substitutions since this would violate the grammar from the user. + if (grammarType.isNull()) + { + bool isShared = true; + // legal substitution if all variables in (= x t) also appear in the goal + if (conjSyms.find(s.first) == conjSyms.end()) + { + // solved variable is not shared + isShared = false; + } + else + { + std::unordered_set ssyms; + expr::getSymbols(s.second, ssyms); + for (const Node& sym : ssyms) + { + if (conjSyms.find(sym) == conjSyms.end()) + { + // variable in right hand side is not shared + isShared = false; + break; + } + } + } + if (isShared) + { + // can apply as a substitution + tlsShared.addSubstitution(s.first, s.second); + continue; + } + } + // must treat the substitution as an assertion + axiomsn.emplace_back(s.first.eqNode(s.second)); + } + for (const Node& ax : axioms) + { + axiomsn.emplace_back(rewrite(tlsShared.apply(ax))); + } + Node conjn = tlsShared.apply(conj); conjn = rewrite(conjn); std::string name("__internal_interpol"); + d_tlsConj = Node::null(); d_subsolver = std::make_unique(d_env); if (d_subsolver->solveInterpolation( - name, axioms, conjn, grammarType, interpol)) + name, axiomsn, conjn, grammarType, interpol)) { + if (!tlsShared.empty()) + { + // must conjoin equalities from shared top-level substitutions + NodeManager* nm = nodeManager(); + d_tlsConj = tlsShared.toFormula(nm); + interpol = nm->mkNode(Kind::AND, d_tlsConj, interpol); + } if (options().smt.checkInterpolants) { checkInterpol(interpol, axioms, conj); @@ -73,7 +135,17 @@ bool InterpolationSolver::getInterpolantNext(Node& interpol) // should already have initialized a subsolver, since we are immediately // preceeded by a successful call to get-interpolant(-next). Assert(d_subsolver != nullptr); - return d_subsolver->solveInterpolationNext(interpol); + if (!d_subsolver->solveInterpolationNext(interpol)) + { + return false; + } + // conjoin the top-level substitutions, as computed in getInterpolant + if (!d_tlsConj.isNull()) + { + NodeManager* nm = nodeManager(); + interpol = nm->mkNode(Kind::AND, d_tlsConj, interpol); + } + return true; } void InterpolationSolver::checkInterpol(Node interpol, diff --git a/src/smt/interpolation_solver.h b/src/smt/interpolation_solver.h index a74a488580f..a59ecc59bd4 100644 --- a/src/smt/interpolation_solver.h +++ b/src/smt/interpolation_solver.h @@ -93,6 +93,12 @@ class InterpolationSolver : protected EnvObj /** The subsolver */ std::unique_ptr d_subsolver; + /** + * The conjunction of equalities corresponding to top-level substitutions that + * were applied to the goal, whose left hand sides are symbols that appeared + * in the goal. + */ + Node d_tlsConj; }; } // namespace smt diff --git a/src/smt/solver_engine.cpp b/src/smt/solver_engine.cpp index 989077189a1..c5c129be710 100644 --- a/src/smt/solver_engine.cpp +++ b/src/smt/solver_engine.cpp @@ -1942,12 +1942,10 @@ Node SolverEngine::getInterpolant(const Node& conj, const TypeNode& grammarType) beginCall(true); // Analogous to getAbduct, ensure that assertions are current. d_smtDriver->refreshAssertions(); - std::vector axioms = getSubstitutedAssertions(); - // expand definitions in the conjecture as well - Node conje = d_smtSolver->getPreprocessor()->applySubstitutions(conj); + std::vector axioms = getAssertions(); Node interpol; bool success = - d_interpolSolver->getInterpolant(axioms, conje, grammarType, interpol); + d_interpolSolver->getInterpolant(axioms, conj, grammarType, interpol); // notify the state of whether the get-interpolant call was successfuly, which // impacts the SMT mode. d_state->notifyGetInterpol(success); @@ -1980,11 +1978,10 @@ Node SolverEngine::getAbduct(const Node& conj, const TypeNode& grammarType) beginCall(true); // ensure that assertions are current d_smtDriver->refreshAssertions(); - std::vector axioms = getSubstitutedAssertions(); + std::vector axioms = getAssertions(); // expand definitions in the conjecture as well - Node conje = d_smtSolver->getPreprocessor()->applySubstitutions(conj); Node abd; - bool success = d_abductSolver->getAbduct(axioms, conje, grammarType, abd); + bool success = d_abductSolver->getAbduct(axioms, conj, grammarType, abd); // notify the state of whether the get-abduct call was successful, which // impacts the SMT mode. d_state->notifyGetAbduct(success); diff --git a/src/theory/quantifiers/sygus/sygus_interpol.cpp b/src/theory/quantifiers/sygus/sygus_interpol.cpp index 3a78555bddc..b0f14c36a96 100644 --- a/src/theory/quantifiers/sygus/sygus_interpol.cpp +++ b/src/theory/quantifiers/sygus/sygus_interpol.cpp @@ -374,7 +374,7 @@ bool SygusInterpol::solveInterpolation(const std::string& name, d_subSolver->declareSynthFun(d_itp, grammarType, false, vars_empty); Trace("sygus-interpol") << "SygusInterpol::solveInterpolation: made conjecture : " << d_sygusConj - << ", solving for " << d_sygusConj[0][0] << std::endl; + << std::endl; d_subSolver->assertSygusConstraint(d_sygusConj); Trace("sygus-interpol") diff --git a/src/theory/substitutions.cpp b/src/theory/substitutions.cpp index 6854cf65958..cf8a460678d 100644 --- a/src/theory/substitutions.cpp +++ b/src/theory/substitutions.cpp @@ -42,6 +42,16 @@ std::unordered_map SubstitutionMap::getSubstitutions() const return subs; } +Node SubstitutionMap::toFormula(NodeManager* nm) const +{ + std::vector conj; + for (const auto& sub : d_substitutions) + { + conj.emplace_back(nm->mkNode(Kind::EQUAL, sub.first, sub.second)); + } + return nm->mkAnd(conj); +} + struct substitution_stack_element { TNode d_node; bool d_children_added; diff --git a/src/theory/substitutions.h b/src/theory/substitutions.h index cf8349ff1e1..50d8ee6c37c 100644 --- a/src/theory/substitutions.h +++ b/src/theory/substitutions.h @@ -113,6 +113,11 @@ class SubstitutionMap /** Get substitutions in this object as a raw map */ std::unordered_map getSubstitutions() const; + /** + * Return a formula that is equivalent to this substitution, e.g. for + * [x -> t, y -> s], we return (and (= x t) (= y s)). + */ + Node toFormula(NodeManager* nm) const; /** * Adds a substitution from x to t. */ diff --git a/test/regress/cli/CMakeLists.txt b/test/regress/cli/CMakeLists.txt index 528957ac0f9..fbf83b7a98b 100644 --- a/test/regress/cli/CMakeLists.txt +++ b/test/regress/cli/CMakeLists.txt @@ -2483,6 +2483,7 @@ set(regress_1_tests regress1/ho/soundness_fmf_SYO362^5-delta.smt2 regress1/ho/store-ax-min.smt2 regress1/hole6.cvc.smt2 + regress1/interpolant-subs.smt2 regress1/interpolant-unk-570.smt2 regress1/ite5.smt2 regress1/issue10750-zll-repeat.smt2 diff --git a/test/regress/cli/regress1/interpolant-subs.smt2 b/test/regress/cli/regress1/interpolant-subs.smt2 new file mode 100644 index 00000000000..958e47b163c --- /dev/null +++ b/test/regress/cli/regress1/interpolant-subs.smt2 @@ -0,0 +1,7 @@ +; SCRUBBER: grep -v -E '(\(define-fun)' +; EXIT: 0 +(set-logic ALL) +(set-option :produce-interpolants true) +(declare-fun x () Int) +(assert (= 0 x)) +(get-interpolant A (= 0 x))