diff --git a/src/rewriter/mkrewrites.py b/src/rewriter/mkrewrites.py index bb429105624..1b0f7a2f1d8 100644 --- a/src/rewriter/mkrewrites.py +++ b/src/rewriter/mkrewrites.py @@ -163,10 +163,6 @@ def validate_rule(rule): used_vars.add(curr) to_visit.extend(curr.children) - unused_vars = set(rule.bvars) - used_vars - if unused_vars: - die(f'Variables {unused_vars} are not matched in {rule.name}') - # Check that list variables are always used within the same operators var_to_op = dict() to_visit = [rule.cond, rule.lhs, rule.rhs] diff --git a/src/rewriter/rewrite_db_proof_cons.cpp b/src/rewriter/rewrite_db_proof_cons.cpp index 0cf40f5269c..66d9dc139e8 100644 --- a/src/rewriter/rewrite_db_proof_cons.cpp +++ b/src/rewriter/rewrite_db_proof_cons.cpp @@ -406,6 +406,9 @@ bool RewriteDbProofCons::proveWithRule(RewriteProofStatus id, : toString(id)) << std::endl; std::vector vcs; + // the implied substitution if we have a rule with free variables on RHS + std::vector impliedVs; + std::vector impliedSs; Node transEq; ProvenInfo pic; if (id == RewriteProofStatus::CONG) @@ -577,6 +580,25 @@ bool RewriteDbProofCons::proveWithRule(RewriteProofStatus id, Trace("rpc-debug2") << "...fail (no construct conclusion)" << std::endl; return false; } + if (expr::hasBoundVar(stgt)) + { + rpr.getConditionalDefinitions(vars, subs, impliedVs, impliedSs); + Trace("rpc-debug2") << " Implied definitions: " << impliedVs << " -> " + << impliedSs << std::endl; + if (!impliedVs.empty()) + { + // evaluate them + for (Node& s : impliedSs) + { + s = evaluate(s, {}, {}); + } + stgt = expr::narySubstitute(stgt, impliedVs, impliedSs); + Trace("rpc-debug2") << " Implied definitions (post-eval): " << impliedVs + << " -> " << impliedSs << std::endl; + Trace("rpc-debug2") + << "Substituted RHS (post-eval): " << stgt << std::endl; + } + } // inflection substitution, used if conclusion does not exactly match std::unordered_map> isubs; if (stgt != target[1]) @@ -603,7 +625,20 @@ bool RewriteDbProofCons::proveWithRule(RewriteProofStatus id, // do its conditions hold? // Get the conditions, substituted { vars -> subs } and with side conditions // evaluated. - if (!rpr.getObligations(vars, subs, vcs)) + if (!impliedVs.empty()) + { + std::vector vsall = vars; + std::vector subsall = subs; + vsall.insert(vsall.end(), impliedVs.begin(), impliedVs.end()); + subsall.insert(subsall.end(), impliedSs.begin(), impliedSs.end()); + if (!rpr.getObligations(vsall, subsall, vcs)) + { + // cannot get conditions, likely due to failed side condition + Trace("rpc-debug2") << "...fail (obligations)" << std::endl; + return false; + } + } + else if (!rpr.getObligations(vars, subs, vcs)) { // cannot get conditions, likely due to failed side condition Trace("rpc-debug2") << "...fail (obligations)" << std::endl; @@ -686,6 +721,11 @@ bool RewriteDbProofCons::proveWithRule(RewriteProofStatus id, { pi->d_vars = vars; pi->d_subs = subs; + if (!impliedVs.empty()) + { + pi->d_vars.insert(pi->d_vars.end(), impliedVs.begin(), impliedVs.end()); + pi->d_subs.insert(pi->d_subs.end(), impliedSs.begin(), impliedSs.end()); + } } Trace("rpc-debug2") << "...target proved by " << d_pcache[target].d_id << std::endl; diff --git a/src/rewriter/rewrite_proof_rule.cpp b/src/rewriter/rewrite_proof_rule.cpp index ccaa2013b31..6469e39cf35 100644 --- a/src/rewriter/rewrite_proof_rule.cpp +++ b/src/rewriter/rewrite_proof_rule.cpp @@ -39,6 +39,7 @@ void RewriteProofRule::init(ProofRewriteRule id, Assert(d_cond.empty() && d_obGen.empty() && d_fvs.empty()); d_id = id; d_userFvs = userFvs; + std::map condDef; for (const Node& c : cond) { if (!expr::getListVarContext(c, d_listVarCtx)) @@ -48,6 +49,10 @@ void RewriteProofRule::init(ProofRewriteRule id, } d_cond.push_back(c); d_obGen.push_back(c); + if (c.getKind() == Kind::EQUAL && c[0].getKind() == Kind::BOUND_VARIABLE) + { + condDef[c[0]] = c[1]; + } } d_conc = conc; d_context = context; @@ -57,13 +62,53 @@ void RewriteProofRule::init(ProofRewriteRule id, << id; } - d_numFv = fvs.size(); - std::unordered_set fvsCond; for (const Node& c : d_cond) { expr::getFreeVariables(c, fvsCond); } + + // ensure free variables in conditions and right hand side are either matched + // or are in defined conditions. + std::unordered_set fvsLhs; + expr::getFreeVariables(d_conc[0], fvsLhs); + std::unordered_set fvsUnmatched; + expr::getFreeVariables(d_conc[1], fvsUnmatched); + fvsUnmatched.insert(fvsCond.begin(), fvsCond.end()); + std::map::iterator itc; + for (const Node& v : fvsUnmatched) + { + if (fvsLhs.find(v) != fvsLhs.end()) + { + // variable on left hand side + continue; + } + itc = condDef.find(v); + if (itc == condDef.end()) + { + Unhandled() + << "Free variable " << v << " in rule " << id + << " is not on the left hand side, nor is defined in a condition"; + } + // variable defined in the condition + d_condDefinedVars[v] = itc->second; + // ensure the defining term does not itself contain free variables + std::unordered_set fvst; + expr::getFreeVariables(itc->second, fvst); + for (const Node& vt : fvst) + { + if (fvsLhs.find(vt) == fvsLhs.end()) + { + Unhandled() << "Free variable " << vt << " in rule " << id + << " is not on the left hand side of the rule, and it is " + "used to give a definition to the free variable " + << v; + } + } + } + + d_numFv = fvs.size(); + for (const Node& v : fvs) { d_fvs.push_back(v); @@ -201,5 +246,19 @@ bool RewriteProofRule::isFixedPoint() const { return d_context != Node::null(); } + +void RewriteProofRule::getConditionalDefinitions(const std::vector& vs, + const std::vector& ss, + std::vector& dvs, + std::vector& dss) const +{ + for (const std::pair& cv : d_condDefinedVars) + { + dvs.push_back(cv.first); + Node cvs = expr::narySubstitute(cv.second, vs, ss); + dss.push_back(cvs); + } +} + } // namespace rewriter } // namespace cvc5::internal diff --git a/src/rewriter/rewrite_proof_rule.h b/src/rewriter/rewrite_proof_rule.h index 98bf1629490..cd882a42689 100644 --- a/src/rewriter/rewrite_proof_rule.h +++ b/src/rewriter/rewrite_proof_rule.h @@ -152,8 +152,19 @@ class RewriteProofRule Kind getListContext(Node v) const; /** Was this rule marked as being applied to fixed point? */ bool isFixedPoint() const; - /** Is this rule in flat form? */ - bool isFlatForm() const; + /** + * Get condition definitions given an application vs -> ss of this rule. + * This is used to handle variables that do not occur in the left hand side + * of rewrite rules and are defined in conditions of this rule. + * @param vs The matched variables of this rule. + * @param ss The terms to substitute in this rule for each vs. + * @param dvs The variables for which a definition can now be inferred. + * @param dss The terms that each dvs are defined as, for each dvs. + */ + void getConditionalDefinitions(const std::vector& vs, + const std::vector& ss, + std::vector& dvs, + std::vector& dss) const; private: /** The id of the rule */ @@ -179,6 +190,8 @@ class RewriteProofRule * "holes" in a proof. */ std::unordered_set d_noOccVars; + /** Maps variables to the term they are defined to be */ + std::map d_condDefinedVars; /** The context for list variables (see expr::getListVarContext). */ std::map d_listVarCtx; /** The match trie (for fixed point matching) */ diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index e22fce68e71..90644c665d5 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -1003,9 +1003,15 @@ RewriteResponse ArithRewriter::postRewriteIntsLog2(TNode t) { // pow2 is only supported for integers Assert(t[0].getType().isInteger()); - Integer i = t[0].getConst().getNumerator(); + const Rational& r = t[0].getConst(); + if (r.sgn() < 0) + { + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0))); + } + Integer i = r.getNumerator(); size_t const length = i.length(); - return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(length))); + return RewriteResponse(REWRITE_DONE, + rewriter::mkConst(Integer(length - 1))); } return RewriteResponse(REWRITE_DONE, t); } diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp index 4326471c13d..0bdae858f62 100644 --- a/src/theory/evaluator.cpp +++ b/src/theory/evaluator.cpp @@ -630,8 +630,15 @@ EvalResult Evaluator::evalInternal( case Kind::INTS_LOG2: { const Rational& x = results[currNode[0]].d_rat; - results[currNode] = - EvalResult(Rational(x.getNumerator().length() - 1)); + if (x.sgn() < 0) + { + results[currNode] = EvalResult(Rational(0)); + } + else + { + results[currNode] = + EvalResult(Rational(x.getNumerator().length() - 1)); + } break; } case Kind::CONST_STRING: