Skip to content

Commit

Permalink
fix: grind canonicalizer state management (#6649)
Browse files Browse the repository at this point in the history
This PR fixes a bug in the term canonicalizer used in the `grind`
tactic.
  • Loading branch information
leodemoura authored Jan 15, 2025
1 parent 0f7f80a commit a955708
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 25 deletions.
2 changes: 2 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Canon.lean
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def canonElemCore (f : Expr) (i : Nat) (e : Expr) (kind : CanonElemKind) : State
-- We used to check `c.fvarsSubset e` because it is not
-- in general safe to replace `e` with `c` if `c` has more free variables than `e`.
-- However, we don't revert previously canonicalized elements in the `grind` tactic.
-- Moreover, we store the canonicalizer state in the `Goal` because we case-split
-- and different locals are added in different branches.
modify fun s => { s with canon := s.canon.insert e c }
trace[grind.debug.canon] "found {e} ===> {c}"
return c
Expand Down
27 changes: 15 additions & 12 deletions src/Lean/Meta/Tactic/Grind/Intro.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ private inductive IntroResult where
private def introNext (goal : Goal) (generation : Nat) : GrindM IntroResult := do
let target ← goal.mvarId.getType
if target.isArrow then
goal.mvarId.withContext do
let (r, _) ← GoalM.run goal do
let mvarId := (← get).mvarId
let p := target.bindingDomain!
if !(← isProp p) then
let (fvarId, mvarId) ← goal.mvarId.intro1P
return .newLocal fvarId { goal with mvarId }
let (fvarId, mvarId) ← mvarId.intro1P
return .newLocal fvarId { (← get) with mvarId }
else
let tag ← goal.mvarId.getTag
let tag ← mvarId.getTag
let q := target.bindingBody!
-- TODO: keep applying simp/eraseIrrelevantMData/canon/shareCommon until no progress
let r ← simp p
Expand All @@ -44,12 +45,13 @@ private def introNext (goal : Goal) (generation : Nat) : GrindM IntroResult := d
match r.proof? with
| some he =>
let hNew := mkAppN (mkConst ``Lean.Grind.intro_with_eq) #[p, r.expr, q, he, h]
goal.mvarId.assign hNew
return .newHyp fvarId { goal with mvarId := mvarIdNew }
mvarId.assign hNew
return .newHyp fvarId { (← get) with mvarId := mvarIdNew }
| none =>
-- `p` and `p'` are definitionally equal
goal.mvarId.assign h
return .newHyp fvarId { goal with mvarId := mvarIdNew }
mvarId.assign h
return .newHyp fvarId { (← get) with mvarId := mvarIdNew }
return r
else if target.isLet || target.isForall || target.isLetFun then
let (fvarId, mvarId) ← goal.mvarId.intro1P
mvarId.withContext do
Expand All @@ -61,10 +63,11 @@ private def introNext (goal : Goal) (generation : Nat) : GrindM IntroResult := d
else
let goal := { goal with mvarId }
if target.isLet || target.isLetFun then
let v := (← fvarId.getDecl).value
let r ← simp v
let x ← shareCommon (mkFVar fvarId)
let goal ← GoalM.run' goal <| addNewEq x r.expr (← r.getProof) generation
let goal ← GoalM.run' goal do
let v := (← fvarId.getDecl).value
let r ← simp v
let x ← shareCommon (mkFVar fvarId)
addNewEq x r.expr (← r.getProof) generation
return .newLocal fvarId goal
else
return .newLocal fvarId goal
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Grind/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def simpCore (e : Expr) : GrindM Simp.Result := do
Simplifies `e` using `grind` normalization theorems and simprocs,
and then applies several other preprocessing steps.
-/
def simp (e : Expr) : GrindM Simp.Result := do
def simp (e : Expr) : GoalM Simp.Result := do
let e ← instantiateMVars e
let r ← simpCore e
let e' := r.expr
Expand Down
22 changes: 10 additions & 12 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ instance : Hashable CongrTheoremCacheKey where

/-- State for the `GrindM` monad. -/
structure State where
canon : Canon.State := {}
/-- `ShareCommon` (aka `Hashconsing`) state. -/
scState : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCommon.State.mk _
/-- Next index for creating auxiliary theorems. -/
Expand Down Expand Up @@ -133,18 +132,9 @@ Applies hash-consing to `e`. Recall that all expressions in a `grind` goal have
been hash-consed. We perform this step before we internalize expressions.
-/
def shareCommon (e : Expr) : GrindM Expr := do
modifyGet fun { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr, natZExpr, simpStats, lastTag } =>
modifyGet fun { scState, nextThmIdx, congrThms, trueExpr, falseExpr, natZExpr, simpStats, lastTag } =>
let (e, scState) := ShareCommon.State.shareCommon scState e
(e, { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr, natZExpr, simpStats, lastTag })

/--
Canonicalizes nested types, type formers, and instances in `e`.
-/
def canon (e : Expr) : GrindM Expr := do
let canonS ← modifyGet fun s => (s.canon, { s with canon := {} })
let (e, canonS) ← Canon.canon e |>.run canonS
modify fun s => { s with canon := canonS }
return e
(e, { scState, nextThmIdx, congrThms, trueExpr, falseExpr, natZExpr, simpStats, lastTag })

/-- Returns `true` if `e` is the internalized `True` expression. -/
def isTrueExpr (e : Expr) : GrindM Bool :=
Expand Down Expand Up @@ -345,6 +335,7 @@ structure NewFact where

structure Goal where
mvarId : MVarId
canon : Canon.State := {}
enodes : ENodeMap := {}
parents : ParentMap := {}
congrTable : CongrTable enodes := {}
Expand Down Expand Up @@ -406,6 +397,13 @@ abbrev GoalM := StateRefT Goal GrindM
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal

/-- Canonicalizes nested types, type formers, and instances in `e`. -/
def canon (e : Expr) : GoalM Expr := do
let canonS ← modifyGet fun s => (s.canon, { s with canon := {} })
let (e, canonS) ← Canon.canon e |>.run canonS
modify fun s => { s with canon := canonS }
return e

def updateLastTag : GoalM Unit := do
if (← isTracingEnabledFor `grind) then
let currTag ← (← get).mvarId.getTag
Expand Down

0 comments on commit a955708

Please sign in to comment.