Skip to content

Commit

Permalink
booleans/builtin: Refactor to not use NodeManager::currentNM() (cvc5#…
Browse files Browse the repository at this point in the history
…11522)

This PR introduces some calls to NodeManager::currentNM(), which will be
removed in subsequent PRs.
  • Loading branch information
daniel-larraz authored Jan 14, 2025
1 parent ea7a520 commit 8dca52e
Show file tree
Hide file tree
Showing 19 changed files with 112 additions and 77 deletions.
3 changes: 2 additions & 1 deletion src/smt/proof_post_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,8 @@ Node ProofPostprocessCallback::expandMacros(ProofRule id,
{
// update to TRUST_THEORY_REWRITE with idr
Assert(args.size() >= 1);
Node tid = builtin::BuiltinProofRuleChecker::mkTheoryIdNode(theoryId);
Node tid = builtin::BuiltinProofRuleChecker::mkTheoryIdNode(
nodeManager(), theoryId);
cdp->addStep(
eq, ProofRule::TRUST_THEORY_REWRITE, {}, {eq, tid, args[1]});
}
Expand Down
3 changes: 2 additions & 1 deletion src/theory/arith/pp_rewrite_eq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ TrustNode PreprocessRewriteEq::ppRewriteEq(TNode atom)
// don't need to rewrite terms since rewritten is not a non-standard op
if (d_env.isTheoryProofProducing())
{
Node t = builtin::BuiltinProofRuleChecker::mkTheoryIdNode(THEORY_ARITH);
Node t = builtin::BuiltinProofRuleChecker::mkTheoryIdNode(nodeManager(),
THEORY_ARITH);
Node eq = atom.eqNode(rewritten);
return d_ppPfGen.mkTrustedRewrite(
atom,
Expand Down
18 changes: 12 additions & 6 deletions src/theory/booleans/circuit_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void CircuitPropagator::assertTrue(TNode assertion)
else if (assertion.getKind() == Kind::AND)
{
ProofCircuitPropagatorBackward prover{
d_env.getProofNodeManager(), assertion, true};
d_env.getNodeManager(), d_env.getProofNodeManager(), assertion, true};
if (isProofEnabled())
{
addProof(assertion, prover.assume(assertion));
Expand Down Expand Up @@ -167,7 +167,8 @@ void CircuitPropagator::makeConflict(Node n)
{
return;
}
ProofCircuitPropagator pcp(d_env.getProofNodeManager());
ProofCircuitPropagator pcp(d_env.getNodeManager(),
d_env.getProofNodeManager());
if (n == bfalse)
{
d_epg->setProofFor(bfalse, pcp.assume(bfalse));
Expand Down Expand Up @@ -239,8 +240,10 @@ void CircuitPropagator::propagateBackward(TNode parent, bool parentAssignment)
{
Trace("circuit-prop") << "CircuitPropagator::propagateBackward(" << parent
<< ", " << parentAssignment << ")" << endl;
ProofCircuitPropagatorBackward prover{
d_env.getProofNodeManager(), parent, parentAssignment};
ProofCircuitPropagatorBackward prover{d_env.getNodeManager(),
d_env.getProofNodeManager(),
parent,
parentAssignment};

// backward rules
switch (parent.getKind())
Expand Down Expand Up @@ -452,8 +455,11 @@ void CircuitPropagator::propagateForward(TNode child, bool childAssignment)
Trace("circuit-prop") << "Parent: " << parent << endl;
Assert(expr::hasSubterm(parent, child));

ProofCircuitPropagatorForward prover{
d_env.getProofNodeManager(), child, childAssignment, parent};
ProofCircuitPropagatorForward prover{d_env.getNodeManager(),
d_env.getProofNodeManager(),
child,
childAssignment,
parent};

// Forward rules
switch (parent.getKind())
Expand Down
66 changes: 36 additions & 30 deletions src/theory/booleans/proof_circuit_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ namespace {

/** Shorthand to create a Node from a constant number */
template <typename T>
Node mkInt(T val)
Node mkInt(NodeManager* nm, T val)
{
return NodeManager::currentNM()->mkConstInt(Rational(val));
return nm->mkConstInt(Rational(val));
}

/**
Expand All @@ -58,8 +58,9 @@ inline std::vector<Node> collectButHoldout(TNode parent,

} // namespace

ProofCircuitPropagator::ProofCircuitPropagator(ProofNodeManager* pnm)
: d_pnm(pnm)
ProofCircuitPropagator::ProofCircuitPropagator(NodeManager* nm,
ProofNodeManager* pnm)
: d_nm(nm), d_pnm(pnm)
{
}

Expand Down Expand Up @@ -270,7 +271,6 @@ std::shared_ptr<ProofNode> ProofCircuitPropagator::mkCResolution(
const std::vector<Node>& lits,
const std::vector<bool>& polarity)
{
auto* nm = NodeManager::currentNM();
std::vector<std::shared_ptr<ProofNode>> children = {clause};
std::vector<Node> cpols;
std::vector<Node> clits;
Expand All @@ -296,12 +296,12 @@ std::shared_ptr<ProofNode> ProofCircuitPropagator::mkCResolution(
{
children.emplace_back(assume(lit));
}
cpols.emplace_back(nm->mkConst(pol));
cpols.emplace_back(d_nm->mkConst(pol));
clits.emplace_back(lit);
}
std::vector<Node> args;
args.push_back(nm->mkNode(Kind::SEXPR, cpols));
args.push_back(nm->mkNode(Kind::SEXPR, clits));
args.push_back(d_nm->mkNode(Kind::SEXPR, cpols));
args.push_back(d_nm->mkNode(Kind::SEXPR, clits));
return mkProof(ProofRule::CHAIN_RESOLUTION, children, args);
}

Expand All @@ -316,21 +316,21 @@ std::shared_ptr<ProofNode> ProofCircuitPropagator::mkCResolution(
std::shared_ptr<ProofNode> ProofCircuitPropagator::mkResolution(
const std::shared_ptr<ProofNode>& clause, const Node& lit, bool polarity)
{
auto* nm = NodeManager::currentNM();
if (polarity)
{
if (lit.getKind() == Kind::NOT)
{
return mkProof(ProofRule::RESOLUTION,
{clause, assume(lit[0])},
{nm->mkConst(false), lit[0]});
{d_nm->mkConst(false), lit[0]});
}
return mkProof(ProofRule::RESOLUTION,
{clause, assume(lit.notNode())},
{nm->mkConst(true), lit});
{d_nm->mkConst(true), lit});
}
return mkProof(
ProofRule::RESOLUTION, {clause, assume(lit)}, {nm->mkConst(false), lit});
return mkProof(ProofRule::RESOLUTION,
{clause, assume(lit)},
{d_nm->mkConst(false), lit});
}

std::shared_ptr<ProofNode> ProofCircuitPropagator::mkNot(
Expand All @@ -345,8 +345,8 @@ std::shared_ptr<ProofNode> ProofCircuitPropagator::mkNot(
}

ProofCircuitPropagatorBackward::ProofCircuitPropagatorBackward(
ProofNodeManager* pnm, TNode parent, bool parentAssignment)
: ProofCircuitPropagator(pnm),
NodeManager* nm, ProofNodeManager* pnm, TNode parent, bool parentAssignment)
: ProofCircuitPropagator(nm, pnm),
d_parent(parent),
d_parentAssignment(parentAssignment)
{
Expand All @@ -359,8 +359,9 @@ std::shared_ptr<ProofNode> ProofCircuitPropagatorBackward::andTrue(
{
return nullptr;
}
return mkProof(
ProofRule::AND_ELIM, {assume(d_parent)}, {mkInt(i - d_parent.begin())});
return mkProof(ProofRule::AND_ELIM,
{assume(d_parent)},
{mkInt(d_nm, i - d_parent.begin())});
}

std::shared_ptr<ProofNode> ProofCircuitPropagatorBackward::orFalse(
Expand All @@ -372,7 +373,7 @@ std::shared_ptr<ProofNode> ProofCircuitPropagatorBackward::orFalse(
}
return mkNot(mkProof(ProofRule::NOT_OR_ELIM,
{assume(d_parent.notNode())},
{mkInt(i - d_parent.begin())}));
{mkInt(d_nm, i - d_parent.begin())}));
}

std::shared_ptr<ProofNode> ProofCircuitPropagatorBackward::iteC(bool c)
Expand Down Expand Up @@ -437,8 +438,12 @@ std::shared_ptr<ProofNode> ProofCircuitPropagatorBackward::impliesNegY()
}

ProofCircuitPropagatorForward::ProofCircuitPropagatorForward(
ProofNodeManager* pnm, Node child, bool childAssignment, Node parent)
: ProofCircuitPropagator{pnm},
NodeManager* nm,
ProofNodeManager* pnm,
Node child,
bool childAssignment,
Node parent)
: ProofCircuitPropagator{nm, pnm},
d_child(child),
d_childAssignment(childAssignment),
d_parent(parent)
Expand Down Expand Up @@ -466,11 +471,11 @@ std::shared_ptr<ProofNode> ProofCircuitPropagatorForward::andOneFalse()
return nullptr;
}
auto it = std::find(d_parent.begin(), d_parent.end(), d_child);
return mkResolution(
mkProof(
ProofRule::CNF_AND_POS, {}, {d_parent, mkInt(it - d_parent.begin())}),
d_child,
true);
return mkResolution(mkProof(ProofRule::CNF_AND_POS,
{},
{d_parent, mkInt(d_nm, it - d_parent.begin())}),
d_child,
true);
}

std::shared_ptr<ProofNode> ProofCircuitPropagatorForward::orOneTrue()
Expand All @@ -480,11 +485,12 @@ std::shared_ptr<ProofNode> ProofCircuitPropagatorForward::orOneTrue()
return nullptr;
}
auto it = std::find(d_parent.begin(), d_parent.end(), d_child);
return mkNot(mkResolution(
mkProof(
ProofRule::CNF_OR_NEG, {}, {d_parent, mkInt(it - d_parent.begin())}),
d_child,
false));
return mkNot(
mkResolution(mkProof(ProofRule::CNF_OR_NEG,
{},
{d_parent, mkInt(d_nm, it - d_parent.begin())}),
d_child,
false));
}

std::shared_ptr<ProofNode> ProofCircuitPropagatorForward::orFalse()
Expand Down
10 changes: 7 additions & 3 deletions src/theory/booleans/proof_circuit_propagator.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace booleans {
class ProofCircuitPropagator
{
public:
ProofCircuitPropagator(ProofNodeManager* pnm);
ProofCircuitPropagator(NodeManager* nm, ProofNodeManager* pnm);

/** Assuming the given node */
std::shared_ptr<ProofNode> assume(Node n);
Expand Down Expand Up @@ -117,6 +117,8 @@ class ProofCircuitPropagator
/** Apply NOT_NOT_ELIM rule if n.getResult() is a nested negation */
std::shared_ptr<ProofNode> mkNot(const std::shared_ptr<ProofNode>& n);

/** The associated node manager */
NodeManager* d_nm;
/** The proof node manager */
ProofNodeManager* d_pnm;
};
Expand All @@ -128,7 +130,8 @@ class ProofCircuitPropagator
class ProofCircuitPropagatorBackward : public ProofCircuitPropagator
{
public:
ProofCircuitPropagatorBackward(ProofNodeManager* pnm,
ProofCircuitPropagatorBackward(NodeManager* nm,
ProofNodeManager* pnm,
TNode parent,
bool parentAssignment);

Expand Down Expand Up @@ -172,7 +175,8 @@ class ProofCircuitPropagatorBackward : public ProofCircuitPropagator
class ProofCircuitPropagatorForward : public ProofCircuitPropagator
{
public:
ProofCircuitPropagatorForward(ProofNodeManager* pnm,
ProofCircuitPropagatorForward(NodeManager* nm,
ProofNodeManager* pnm,
Node child,
bool childAssignment,
Node parent);
Expand Down
10 changes: 6 additions & 4 deletions src/theory/builtin/generic_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,12 @@ bool convertToNumeralList(const std::vector<Node>& indices,
return true;
}

Node GenericOp::getOperatorForIndices(Kind k, const std::vector<Node>& indices)
Node GenericOp::getOperatorForIndices(NodeManager* nm,
Kind k,
const std::vector<Node>& indices)
{
// all indices should be constant!
Assert(isIndexedOperatorKind(k));
NodeManager* nm = NodeManager::currentNM();
if (isNumeralIndexedOperatorKind(k))
{
std::vector<uint32_t> numerals;
Expand Down Expand Up @@ -400,7 +401,8 @@ Node GenericOp::getConcreteApp(const Node& app)
// usually one, but we handle cases where it is >1.
size_t nargs = metakind::getMinArityForKind(okind);
std::vector<Node> indices(app.begin(), app.end() - nargs);
Node op = getOperatorForIndices(okind, indices);
NodeManager* nm = NodeManager::currentNM();
Node op = getOperatorForIndices(nm, okind, indices);
// could have a bad index, in which case we don't rewrite
if (op.isNull())
{
Expand All @@ -409,7 +411,7 @@ Node GenericOp::getConcreteApp(const Node& app)
std::vector<Node> args;
args.push_back(op);
args.insert(args.end(), app.end() - nargs, app.end());
Node ret = NodeManager::currentNM()->mkNode(okind, args);
Node ret = nm->mkNode(okind, args);
// could be ill typed, in which case we don't rewrite
if (ret.getTypeOrNull(true).isNull())
{
Expand Down
4 changes: 3 additions & 1 deletion src/theory/builtin/generic_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ class GenericOp
* Return the operator of kind k whose indices are the constants in the
* given vector.
*/
static Node getOperatorForIndices(Kind k, const std::vector<Node>& indices);
static Node getOperatorForIndices(NodeManager* nm,
Kind k,
const std::vector<Node>& indices);
/**
* Get the concrete term corresponding to the application of
* APPLY_INDEXED_SYMBOLIC. Requires all indices to be constant.
Expand Down
5 changes: 2 additions & 3 deletions src/theory/builtin/proof_checker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,10 +497,9 @@ bool BuiltinProofRuleChecker::getTheoryId(TNode n, TheoryId& tid)
return true;
}

Node BuiltinProofRuleChecker::mkTheoryIdNode(TheoryId tid)
Node BuiltinProofRuleChecker::mkTheoryIdNode(NodeManager* nm, TheoryId tid)
{
return NodeManager::currentNM()->mkConstInt(
Rational(static_cast<uint32_t>(tid)));
return nm->mkConstInt(Rational(static_cast<uint32_t>(tid)));
}

} // namespace builtin
Expand Down
2 changes: 1 addition & 1 deletion src/theory/builtin/proof_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class BuiltinProofRuleChecker : public ProofRuleChecker
/** get a TheoryId from a node, return false if we fail */
static bool getTheoryId(TNode n, TheoryId& tid);
/** Make a TheoryId into a node */
static Node mkTheoryIdNode(TheoryId tid);
static Node mkTheoryIdNode(NodeManager* nm, TheoryId tid);
/**
* @param nm The node manager.
* @param n The term to rewrite via ENCODE_EQ_INTRO.
Expand Down
5 changes: 4 additions & 1 deletion src/theory/ff/cocoa_encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ CoCoA::symbol cocoaSym(const std::string& varName, std::optional<size_t> index)
return index.has_value() ? CoCoA::symbol(s, *index) : CoCoA::symbol(s);
}

CocoaEncoder::CocoaEncoder(const FfSize& size) : FieldObj(size) {}
CocoaEncoder::CocoaEncoder(NodeManager* nm, const FfSize& size)
: FieldObj(nm, size)
{
}

CoCoA::symbol CocoaEncoder::freshSym(const std::string& varName,
std::optional<size_t> index)
Expand Down
2 changes: 1 addition & 1 deletion src/theory/ff/cocoa_encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class CocoaEncoder : public FieldObj
{
public:
/** Create a new encoder, for this field. */
CocoaEncoder(const FfSize& size);
CocoaEncoder(NodeManager* nm, const FfSize& size);
/** Add a fact (one must call this twice per fact, once per stage). */
void addFact(const Node& fact);
/** Start Stage::Encode. */
Expand Down
2 changes: 1 addition & 1 deletion src/theory/ff/split_gb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ std::optional<std::unordered_map<Node, FiniteFieldValue>> split(
const std::vector<Node>& facts, const FfSize& size, const Env& env)
{
std::unordered_set<Node> bits{};
CocoaEncoder enc(size);
CocoaEncoder enc(env.getNodeManager(), size);
for (const auto& fact : facts)
{
enc.addFact(fact);
Expand Down
7 changes: 5 additions & 2 deletions src/theory/ff/sub_theory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ namespace theory {
namespace ff {

SubTheory::SubTheory(Env& env, FfStatistics* stats, Integer modulus)
: EnvObj(env), FieldObj(modulus), d_facts(context()), d_stats(stats)
: EnvObj(env),
FieldObj(nodeManager(), modulus),
d_facts(context()),
d_stats(stats)
{
AlwaysAssert(modulus.isProbablePrime()) << "non-prime fields are unsupported";
// must be initialized before using CoCoA.
Expand Down Expand Up @@ -85,7 +88,7 @@ Result SubTheory::postCheck(Theory::Effort e)
}
else if (options().ff.ffSolver == options::FfSolver::GB)
{
CocoaEncoder enc(size());
CocoaEncoder enc(nodeManager(), size());
// collect leaves
for (const Node& node : d_facts)
{
Expand Down
4 changes: 2 additions & 2 deletions src/theory/ff/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ namespace cvc5::internal {
namespace theory {
namespace ff {

FieldObj::FieldObj(const FfSize& size)
FieldObj::FieldObj(NodeManager* nm, const FfSize& size)
: d_size(size),
d_nm(NodeManager::currentNM()),
d_nm(nm),
d_zero(d_nm->mkConst(FiniteFieldValue(0, d_size))),
d_one(d_nm->mkConst(FiniteFieldValue(1, d_size)))
#ifdef CVC5_USE_COCOA
Expand Down
2 changes: 1 addition & 1 deletion src/theory/ff/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ using FfModel = std::unordered_map<Node, FiniteFieldValue>;
class FieldObj
{
public:
FieldObj(const FfSize& size);
FieldObj(NodeManager* nm, const FfSize& size);
/** create a sum (with as few as 0 elements); accepts Nodes or TNodes */
template <bool ref_count>
Node mkAdd(const std::vector<NodeTemplate<ref_count>>& summands);
Expand Down
Loading

0 comments on commit 8dca52e

Please sign in to comment.