Skip to content

Commit

Permalink
Update RARE reconstruction algorithm to allow free variables defined …
Browse files Browse the repository at this point in the history
…in conditions (cvc5#10914)

Allows RARE rules to contain free variables x not in the left hand side
of a rule, provided they occur defined in a condition e.g. an equality x
= t, where t contains only variables that are in the left hand side of
the rule.

This updates the README for RARE.

It also fixes an inconsistency between the evaluator and the rewriter
for `ints.log2` which was discovered while testing this change.
  • Loading branch information
ajreynol authored Aug 12, 2024
1 parent 0a48958 commit 989b077
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 13 deletions.
4 changes: 0 additions & 4 deletions src/rewriter/mkrewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,6 @@ def validate_rule(rule):
used_vars.add(curr)
to_visit.extend(curr.children)

unused_vars = set(rule.bvars) - used_vars
if unused_vars:
die(f'Variables {unused_vars} are not matched in {rule.name}')

# Check that list variables are always used within the same operators
var_to_op = dict()
to_visit = [rule.cond, rule.lhs, rule.rhs]
Expand Down
42 changes: 41 additions & 1 deletion src/rewriter/rewrite_db_proof_cons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@ bool RewriteDbProofCons::proveWithRule(RewriteProofStatus id,
: toString(id))
<< std::endl;
std::vector<Node> vcs;
// the implied substitution if we have a rule with free variables on RHS
std::vector<Node> impliedVs;
std::vector<Node> impliedSs;
Node transEq;
ProvenInfo pic;
if (id == RewriteProofStatus::CONG)
Expand Down Expand Up @@ -577,6 +580,25 @@ bool RewriteDbProofCons::proveWithRule(RewriteProofStatus id,
Trace("rpc-debug2") << "...fail (no construct conclusion)" << std::endl;
return false;
}
if (expr::hasBoundVar(stgt))
{
rpr.getConditionalDefinitions(vars, subs, impliedVs, impliedSs);
Trace("rpc-debug2") << " Implied definitions: " << impliedVs << " -> "
<< impliedSs << std::endl;
if (!impliedVs.empty())
{
// evaluate them
for (Node& s : impliedSs)
{
s = evaluate(s, {}, {});
}
stgt = expr::narySubstitute(stgt, impliedVs, impliedSs);
Trace("rpc-debug2") << " Implied definitions (post-eval): " << impliedVs
<< " -> " << impliedSs << std::endl;
Trace("rpc-debug2")
<< "Substituted RHS (post-eval): " << stgt << std::endl;
}
}
// inflection substitution, used if conclusion does not exactly match
std::unordered_map<Node, std::pair<Node, Node>> isubs;
if (stgt != target[1])
Expand All @@ -603,7 +625,20 @@ bool RewriteDbProofCons::proveWithRule(RewriteProofStatus id,
// do its conditions hold?
// Get the conditions, substituted { vars -> subs } and with side conditions
// evaluated.
if (!rpr.getObligations(vars, subs, vcs))
if (!impliedVs.empty())
{
std::vector<Node> vsall = vars;
std::vector<Node> subsall = subs;
vsall.insert(vsall.end(), impliedVs.begin(), impliedVs.end());
subsall.insert(subsall.end(), impliedSs.begin(), impliedSs.end());
if (!rpr.getObligations(vsall, subsall, vcs))
{
// cannot get conditions, likely due to failed side condition
Trace("rpc-debug2") << "...fail (obligations)" << std::endl;
return false;
}
}
else if (!rpr.getObligations(vars, subs, vcs))
{
// cannot get conditions, likely due to failed side condition
Trace("rpc-debug2") << "...fail (obligations)" << std::endl;
Expand Down Expand Up @@ -686,6 +721,11 @@ bool RewriteDbProofCons::proveWithRule(RewriteProofStatus id,
{
pi->d_vars = vars;
pi->d_subs = subs;
if (!impliedVs.empty())
{
pi->d_vars.insert(pi->d_vars.end(), impliedVs.begin(), impliedVs.end());
pi->d_subs.insert(pi->d_subs.end(), impliedSs.begin(), impliedSs.end());
}
}
Trace("rpc-debug2") << "...target proved by " << d_pcache[target].d_id
<< std::endl;
Expand Down
63 changes: 61 additions & 2 deletions src/rewriter/rewrite_proof_rule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void RewriteProofRule::init(ProofRewriteRule id,
Assert(d_cond.empty() && d_obGen.empty() && d_fvs.empty());
d_id = id;
d_userFvs = userFvs;
std::map<Node, Node> condDef;
for (const Node& c : cond)
{
if (!expr::getListVarContext(c, d_listVarCtx))
Expand All @@ -48,6 +49,10 @@ void RewriteProofRule::init(ProofRewriteRule id,
}
d_cond.push_back(c);
d_obGen.push_back(c);
if (c.getKind() == Kind::EQUAL && c[0].getKind() == Kind::BOUND_VARIABLE)
{
condDef[c[0]] = c[1];
}
}
d_conc = conc;
d_context = context;
Expand All @@ -57,13 +62,53 @@ void RewriteProofRule::init(ProofRewriteRule id,
<< id;
}

d_numFv = fvs.size();

std::unordered_set<Node> fvsCond;
for (const Node& c : d_cond)
{
expr::getFreeVariables(c, fvsCond);
}

// ensure free variables in conditions and right hand side are either matched
// or are in defined conditions.
std::unordered_set<Node> fvsLhs;
expr::getFreeVariables(d_conc[0], fvsLhs);
std::unordered_set<Node> fvsUnmatched;
expr::getFreeVariables(d_conc[1], fvsUnmatched);
fvsUnmatched.insert(fvsCond.begin(), fvsCond.end());
std::map<Node, Node>::iterator itc;
for (const Node& v : fvsUnmatched)
{
if (fvsLhs.find(v) != fvsLhs.end())
{
// variable on left hand side
continue;
}
itc = condDef.find(v);
if (itc == condDef.end())
{
Unhandled()
<< "Free variable " << v << " in rule " << id
<< " is not on the left hand side, nor is defined in a condition";
}
// variable defined in the condition
d_condDefinedVars[v] = itc->second;
// ensure the defining term does not itself contain free variables
std::unordered_set<Node> fvst;
expr::getFreeVariables(itc->second, fvst);
for (const Node& vt : fvst)
{
if (fvsLhs.find(vt) == fvsLhs.end())
{
Unhandled() << "Free variable " << vt << " in rule " << id
<< " is not on the left hand side of the rule, and it is "
"used to give a definition to the free variable "
<< v;
}
}
}

d_numFv = fvs.size();

for (const Node& v : fvs)
{
d_fvs.push_back(v);
Expand Down Expand Up @@ -201,5 +246,19 @@ bool RewriteProofRule::isFixedPoint() const
{
return d_context != Node::null();
}

void RewriteProofRule::getConditionalDefinitions(const std::vector<Node>& vs,
const std::vector<Node>& ss,
std::vector<Node>& dvs,
std::vector<Node>& dss) const
{
for (const std::pair<const Node, Node>& cv : d_condDefinedVars)
{
dvs.push_back(cv.first);
Node cvs = expr::narySubstitute(cv.second, vs, ss);
dss.push_back(cvs);
}
}

} // namespace rewriter
} // namespace cvc5::internal
17 changes: 15 additions & 2 deletions src/rewriter/rewrite_proof_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,19 @@ class RewriteProofRule
Kind getListContext(Node v) const;
/** Was this rule marked as being applied to fixed point? */
bool isFixedPoint() const;
/** Is this rule in flat form? */
bool isFlatForm() const;
/**
* Get condition definitions given an application vs -> ss of this rule.
* This is used to handle variables that do not occur in the left hand side
* of rewrite rules and are defined in conditions of this rule.
* @param vs The matched variables of this rule.
* @param ss The terms to substitute in this rule for each vs.
* @param dvs The variables for which a definition can now be inferred.
* @param dss The terms that each dvs are defined as, for each dvs.
*/
void getConditionalDefinitions(const std::vector<Node>& vs,
const std::vector<Node>& ss,
std::vector<Node>& dvs,
std::vector<Node>& dss) const;

private:
/** The id of the rule */
Expand All @@ -179,6 +190,8 @@ class RewriteProofRule
* "holes" in a proof.
*/
std::unordered_set<Node> d_noOccVars;
/** Maps variables to the term they are defined to be */
std::map<Node, Node> d_condDefinedVars;
/** The context for list variables (see expr::getListVarContext). */
std::map<Node, Node> d_listVarCtx;
/** The match trie (for fixed point matching) */
Expand Down
10 changes: 8 additions & 2 deletions src/theory/arith/arith_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1003,9 +1003,15 @@ RewriteResponse ArithRewriter::postRewriteIntsLog2(TNode t)
{
// pow2 is only supported for integers
Assert(t[0].getType().isInteger());
Integer i = t[0].getConst<Rational>().getNumerator();
const Rational& r = t[0].getConst<Rational>();
if (r.sgn() < 0)
{
return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
}
Integer i = r.getNumerator();
size_t const length = i.length();
return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(length)));
return RewriteResponse(REWRITE_DONE,
rewriter::mkConst(Integer(length - 1)));
}
return RewriteResponse(REWRITE_DONE, t);
}
Expand Down
11 changes: 9 additions & 2 deletions src/theory/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,15 @@ EvalResult Evaluator::evalInternal(
case Kind::INTS_LOG2:
{
const Rational& x = results[currNode[0]].d_rat;
results[currNode] =
EvalResult(Rational(x.getNumerator().length() - 1));
if (x.sgn() < 0)
{
results[currNode] = EvalResult(Rational(0));
}
else
{
results[currNode] =
EvalResult(Rational(x.getNumerator().length() - 1));
}
break;
}
case Kind::CONST_STRING:
Expand Down

0 comments on commit 989b077

Please sign in to comment.