Skip to content

Commit

Permalink
feat: add user-defined fallback procedure for the grind tactic (#6512)
Browse files Browse the repository at this point in the history
This PR introduces support for user-defined fallback code in the `grind`
tactic. The fallback code can be utilized to inspect the state of
failing `grind` subgoals and/or invoke user-defined automation. Users
can now write `grind on_failure <code>`, where `<code>` should have the
type `GoalM Unit`. See the modified tests in this PR for examples.
  • Loading branch information
leodemoura authored Jan 2, 2025
1 parent 9d62227 commit 3e2f1fa
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 191 deletions.
2 changes: 1 addition & 1 deletion src/Init/Grind/Tactics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ namespace Lean.Parser.Tactic
-/

-- TODO: parameters
syntax (name := grind) "grind" optConfig : tactic
syntax (name := grind) "grind" optConfig ("on_failure " term)? : tactic

end Lean.Parser.Tactic
26 changes: 22 additions & 4 deletions src/Lean/Elab/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,36 @@ def elabGrindPattern : CommandElab := fun stx => do
Grind.addEMatchTheorem declName xs.size patterns.toList
| _ => throwUnsupportedSyntax

def grind (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) : MetaM Unit := do
let mvarIds ← Grind.main mvarId config mainDeclName
def grind (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) (fallback : Grind.Fallback) : MetaM Unit := do
let mvarIds ← Grind.main mvarId config mainDeclName fallback
unless mvarIds.isEmpty do
throwError "`grind` failed\n{goalsToMessageData mvarIds}"

private def elabFallback (fallback? : Option Term) : TermElabM (Grind.GoalM Unit) := do
let some fallback := fallback? | return (pure ())
let type := mkApp (mkConst ``Grind.GoalM) (mkConst ``Unit)
let value ← withLCtx {} {} do Term.elabTermAndSynthesize fallback type
let auxDeclName ← if let .const declName _ := value then
pure declName
else
let auxDeclName ← Term.mkAuxName `_grind_fallback
let decl := Declaration.defnDecl {
name := auxDeclName
levelParams := []
type, value, hints := .opaque, safety := .safe
}
addAndCompile decl
pure auxDeclName
unsafe evalConst (Grind.GoalM Unit) auxDeclName

@[builtin_tactic Lean.Parser.Tactic.grind] def evalApplyRfl : Tactic := fun stx => do
match stx with
| `(tactic| grind $config:optConfig) =>
| `(tactic| grind $config:optConfig $[on_failure $fallback?]?) =>
let fallback ← elabFallback fallback?
logWarningAt stx "The `grind` tactic is experimental and still under development. Avoid using it in production projects"
let declName := (← Term.getDeclName?).getD `_grind
let config ← elabGrindConfig config
withMainContext do liftMetaFinishingTactic (grind · config declName)
withMainContext do liftMetaFinishingTactic (grind · config declName fallback)
| _ => throwUnsupportedSyntax

end Lean.Elab.Tactic
28 changes: 12 additions & 16 deletions src/Lean/Meta/Tactic/Grind/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ import Lean.Meta.Tactic.Grind.EMatch

namespace Lean.Meta.Grind

def mkMethods : CoreM Methods := do
def mkMethods (fallback : Fallback) : CoreM Methods := do
let builtinPropagators ← builtinPropagatorsRef.get
return {
fallback
propagateUp := fun e => do
propagateForallProp e
let .const declName _ := e.getAppFn | return ()
Expand All @@ -32,7 +33,7 @@ def mkMethods : CoreM Methods := do
prop e
}

def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) : MetaM α := do
def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) (fallback : Fallback) : MetaM α := do
let scState := ShareCommon.State.mk _
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
Expand All @@ -42,7 +43,7 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) : M
(config := { arith := true })
(simpTheorems := #[thms])
(congrTheorems := (← getSimpCongrTheorems))
x (← mkMethods).toMethodsRef { mainDeclName, config, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }
x (← mkMethods fallback).toMethodsRef { mainDeclName, config, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }

private def mkGoal (mvarId : MVarId) : GrindM Goal := do
let trueExpr ← getTrueExpr
Expand Down Expand Up @@ -71,23 +72,18 @@ def all (goals : List Goal) (f : Goal → GrindM (List Goal)) : GrindM (List Goa
private def simple (goals : List Goal) : GrindM (List Goal) := do
all goals ematchStar

def main (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) : MetaM (List MVarId) := do
def main (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) (fallback : Fallback) : MetaM (List MVarId) := do
let go : GrindM (List MVarId) := do
let goals ← initCore mvarId
let goals ← simple goals
let goals ← goals.filterMapM fun goal => do
if goal.inconsistent then return none
let goal ← GoalM.run' goal fallback
if goal.inconsistent then return none
if (← goal.mvarId.isAssigned) then return none
return some goal
trace[grind.debug.final] "{← ppGoals goals}"
return goals.map (·.mvarId)
go.run mainDeclName config

/-- Helper function for debugging purposes -/
def preprocessAndProbe (mvarId : MVarId) (mainDeclName : Name) (p : GoalM Unit) : MetaM Unit :=
let go : GrindM Unit := do
let goals ← initCore mvarId
trace[grind.debug.final] "{← ppGoals goals}"
goals.forM fun goal =>
discard <| GoalM.run' goal p
return ()
withoutModifyingMCtx do
go.run mainDeclName {}
go.run mainDeclName config fallback

end Lean.Meta.Grind
10 changes: 8 additions & 2 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,6 @@ abbrev GoalM := StateRefT Goal GrindM
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal

abbrev Propagator := Expr → GoalM Unit

/--
A helper function used to mark a theorem instance found by the E-matching module.
It returns `true` if it is a new instance and `false` otherwise.
Expand Down Expand Up @@ -677,9 +675,13 @@ def forEachEqc (f : ENode → GoalM Unit) : GoalM Unit := do
if isSameExpr n.self n.root then
f n

abbrev Propagator := Expr → GoalM Unit
abbrev Fallback := GoalM Unit

structure Methods where
propagateUp : Propagator := fun _ => return ()
propagateDown : Propagator := fun _ => return ()
fallback : Fallback := pure ()
deriving Inhabited

def Methods.toMethodsRef (m : Methods) : MethodsRef :=
Expand All @@ -697,6 +699,10 @@ def propagateUp (e : Expr) : GoalM Unit := do
def propagateDown (e : Expr) : GoalM Unit := do
(← getMethods).propagateDown e

def applyFallback : GoalM Unit := do
let fallback : GoalM Unit := (← getMethods).fallback
fallback

/-- Returns expressions in the given expression equivalence class. -/
partial def getEqc (e : Expr) : GoalM (List Expr) :=
go e e []
Expand Down
29 changes: 11 additions & 18 deletions tests/lean/run/grind_canon_insts.lean
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
import Lean

open Lean Meta Elab Tactic Grind in
elab "grind_test" : tactic => withMainContext do
let declName := (← Term.getDeclName?).getD `_main
Meta.Grind.preprocessAndProbe (← getMainGoal) declName do
let nodes ← filterENodes fun e => return e.self.isAppOf ``HMul.hMul
logInfo (nodes.toList.map (·.self))
import Lean.Meta.Tactic.Grind

set_option grind.debug true

Expand Down Expand Up @@ -57,26 +50,27 @@ instance : CommMonoid Nat where
theorem left_comm [CommMonoid α] (a b c : α) : a * (b * c) = b * (a * c) := by
rw [← Semigroup.mul_assoc, CommMonoid.mul_comm a b, Semigroup.mul_assoc]

open Lean Meta Elab Tactic Grind in
def fallback : Fallback := do
let nodes ← filterENodes fun e => return e.self.isAppOf ``HMul.hMul
logInfo (nodes.toList.map (·.self))
(← get).mvarId.admit

/--
info: [b * c, a * (b * c), d * (b * c)]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
#guard_msgs (info) in
example (a b c d : Nat) : b * (a * c) = d * (b * c) → False := by
rw [left_comm] -- Introduces a new (non-canonical) instance for `Mul Nat`
grind_test -- State should have only 3 `*`-applications
sorry
grind on_failure fallback -- State should have only 3 `*`-applications


set_option pp.notation false in
set_option pp.explicit true in
/--
info: [@HMul.hMul Nat Nat Nat (@instHMul Nat instMulNat) b a, @HMul.hMul Nat Nat Nat (@instHMul Nat instMulNat) b d]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
#guard_msgs (info) in
example (a b c d : Nat) : b * a = d * b → False := by
rw [CommMonoid.mul_comm d b] -- Introduces a new (non-canonical) instance for `Mul Nat`
-- See target here
Expand All @@ -85,5 +79,4 @@ example (a b c d : Nat) : b * a = d * b → False := by
=
@HMul.hMul Nat Nat Nat (@instHMul Nat (@Semigroup.toMul Nat (@Monoid.toSemigroup Nat (@CommMonoid.toMonoid Nat instCommMonoidNat)))) b d
→ False
grind_test -- State should have only 2 `*`-applications, and they use the same instance
sorry
grind on_failure fallback -- State should have only 2 `*`-applications, and they use the same instance
21 changes: 8 additions & 13 deletions tests/lean/run/grind_canon_types.lean
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
import Lean
import Lean.Meta.Tactic.Grind

def g (s : Type) := s
def f (a : α) := a

open Lean Meta Elab Tactic Grind in
elab "grind_test" : tactic => withMainContext do
let declName := (← Term.getDeclName?).getD `_main
Meta.Grind.preprocessAndProbe (← getMainGoal) declName do
let nodes ← filterENodes fun e => return e.self.isAppOf ``f
logInfo (nodes.toList.map (·.self))

open Lean Meta Grind in
def fallback : Fallback := do
let nodes ← filterENodes fun e => return e.self.isAppOf ``f
logInfo (nodes.toList.map (·.self))
(← get).mvarId.admit

set_option pp.explicit true
/--
info: [@f Nat a, @f Nat b]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
#guard_msgs (info) in
example (a b c d : Nat) : @f Nat a = b → @f (g Nat) a = c → @f (g Nat) b = d → a = b → False := by
-- State should have only two `f`-applications: `@f Nat a`, `@f Nat b`
-- Note that `@f (g Nat) b` has been canonicalized to `@f Nat b`.
-- Thus, if `a` and `b` equivalence classes are merged, `grind` can still detect that
-- `@f Nat a` and `@f Nat b` are equal too.
grind_test
sorry
grind on_failure fallback
41 changes: 14 additions & 27 deletions tests/lean/run/grind_congr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,40 @@ def f (a : Nat) := a + a + a
def g (a : Nat) := a + a

-- Prints the equivalence class containing a `f` application
open Lean Meta Elab Tactic Grind in
elab "grind_test" : tactic => withMainContext do
let declName := (← Term.getDeclName?).getD `_main
Meta.Grind.preprocessAndProbe (← getMainGoal) declName do
let #[n, _] ← filterENodes fun e => return e.self.isAppOf ``f | unreachable!
let eqc ← getEqc n.self
logInfo eqc
open Lean Meta Grind in
def fallback : Fallback := do
let #[n, _] ← filterENodes fun e => return e.self.isAppOf ``f | unreachable!
let eqc ← getEqc n.self
logInfo eqc
(← get).mvarId.admit

set_option grind.debug true
set_option grind.debug.proofs true

/--
info: [d, f b, c, f a]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
#guard_msgs (info) in
example (a b c d : Nat) : a = b → f a = c → f b = d → False := by
grind_test
sorry
grind on_failure fallback

/--
info: [d, f b, c, f a]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
#guard_msgs (info) in
example (a b c d : Nat) : f a = c → f b = d → a = b → False := by
grind_test
sorry
grind on_failure fallback

/--
info: [d, f (g b), c, f (g a)]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
#guard_msgs (info) in
example (a b c d e : Nat) : f (g a) = c → f (g b) = d → a = e → b = e → False := by
grind_test
sorry
grind on_failure fallback

/--
info: [d, f (g b), c, f v]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
#guard_msgs (info) in
example (a b c d e v : Nat) : f v = c → f (g b) = d → a = e → b = e → v = g a → False := by
grind_test
sorry
grind on_failure fallback
25 changes: 25 additions & 0 deletions tests/lean/run/grind_ematch2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,28 @@ example (as bs cs : Array α) (v₁ v₂ : α)
(h₆ : j < as.size)
: cs[j] = as[j] := by
grind

example (as bs cs ds : Array α) (v₁ v₂ v₃ : α)
(i₁ i₂ i₃ j : Nat)
(h₁ : i₁ < as.size)
(h₂ : as.set i₁ v₁ = bs)
(h₃ : i₂ < bs.size)
(h₃ : bs.set i₂ v₂ = cs)
(h₄ : i₃ < cs.size)
(h₅ : ds = cs.set i₃ v₃)
(h₆ : j ≠ i₁ ∧ j ≠ i₂ ∧ i₃ ≠ j)
(h₇ : j < ds.size)
(h₈ : j < as.size)
: ds[j] = as[j] := by
grind

opaque f (a b : α) : α := a
theorem fx : f x (f x x) = x := sorry
grind_pattern fx => f x (f x x)

/--
info: [grind.ematch.instance] fx: f a (f a a) = a
-/
#guard_msgs (info) in
example : a = b₁ → c = f b₁ b₂ → f a c ≠ a → a = b₂ → False := by
grind
30 changes: 13 additions & 17 deletions tests/lean/run/grind_many_eqs.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Lean
import Lean.Meta.Tactic.Grind

def f (a : Nat) := a + a + a
def g (a : Nat) := a + a
Expand All @@ -8,27 +8,23 @@ def h (n : Nat) : Prop :=
| n+1 => f (n+1) = f n ∧ g (2*n + 1) = g (2*n) ∧ h n

-- Prints the equivalence class containing a `f` application
open Lean Meta Elab Tactic Grind in
elab "grind_test" n:num : tactic => withMainContext do
let n := n.getNat
let declName := (← Term.getDeclName?).getD `_main
Meta.Grind.preprocessAndProbe (← getMainGoal) declName do
let f0 ← Grind.shareCommon (mkApp (mkConst ``f) (mkNatLit 0))
-- The `f 0` equivalence class contains `n+1` elements
assert! (← getEqc f0).length == n + 1
forEachENode fun node => do
if node.self.isAppOf ``g then
-- Any equivalence class containing a `g`-application contains 2 elements
assert! (← getEqc (← getRoot node.self)).length == 2
open Lean Meta Grind in
def fallback (n : Nat) : Fallback := do
let f0 ← Grind.shareCommon (mkApp (mkConst ``f) (mkNatLit 0))
-- The `f 0` equivalence class contains `n+1` elements
assert! (← getEqc f0).length == n + 1
forEachENode fun node => do
if node.self.isAppOf ``g then
-- Any equivalence class containing a `g`-application contains 2 elements
assert! (← getEqc (← getRoot node.self)).length == 2
(← get).mvarId.admit

set_option grind.debug true in
example : h 5 → False := by
simp [h]
grind_test 5
sorry
grind on_failure fallback 5

set_option maxRecDepth 2048
example : h 100 → False := by
simp [h]
grind_test 100
sorry
grind on_failure fallback 100
Loading

0 comments on commit 3e2f1fa

Please sign in to comment.