Skip to content

Commit

Permalink
feat: [grind_eq] attribute for the grind tactic (#6539)
Browse files Browse the repository at this point in the history
This PR introduces the `[grind_eq]` attribute, designed to annotate
equational theorems and functions for heuristic instantiations in the
`grind` tactic.
When applied to an equational theorem, the `[grind_eq]` attribute
instructs the `grind` tactic to automatically use the annotated theorem
to instantiate patterns during proof search. If applied to a function,
it marks all equational theorems associated with that function.

```lean
@[grind_eq]
theorem foo_idempotent : foo (foo x) = foo x := ...

@[grind_eq] def f (a : Nat) :=
  match a with
  | 0 => 10
  | x+1 => g (f x)
```

In the example above, the `grind` tactic will add instances of the
theorem `foo_idempotent` to the local context whenever it encounters the
pattern `foo (foo x)`. Similarly, functions annotated with `[grind_eq]`
will propagate this annotation to their associated equational theorems.
  • Loading branch information
leodemoura authored Jan 5, 2025
1 parent fd091d1 commit 675244d
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 4 deletions.
8 changes: 6 additions & 2 deletions src/Lean/Meta/Tactic/Grind/Ctor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ import Lean.Meta.Tactic.Grind.Types
namespace Lean.Meta.Grind

private partial def propagateInjEqs (eqs : Expr) (proof : Expr) : GoalM Unit := do
-- Remark: we must use `shareCommon` before using `pushEq` and `pushHEq`.
-- This is needed because the result type of the injection theorem may allocate
match_expr eqs with
| And left right =>
propagateInjEqs left (.proj ``And 0 proof)
propagateInjEqs right (.proj ``And 1 proof)
| Eq _ lhs rhs => pushEq lhs rhs proof
| HEq _ lhs _ rhs => pushHEq lhs rhs proof
| Eq _ lhs rhs =>
pushEq (← shareCommon lhs) (← shareCommon rhs) proof
| HEq _ lhs _ rhs =>
pushHEq (← shareCommon lhs) (← shareCommon rhs) proof
| _ =>
trace[grind.issues] "unexpected injectivity theorem result type{indentExpr eqs}"
return ()
Expand Down
33 changes: 31 additions & 2 deletions src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Lean.Util.FoldConsts
import Lean.Util.CollectFVars
import Lean.Meta.Basic
import Lean.Meta.InferType
import Lean.Meta.Eqns
import Lean.Meta.Tactic.Grind.Util

namespace Lean.Meta.Grind
Expand Down Expand Up @@ -241,7 +242,7 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
else
return mkOffsetPattern e k
let some f := getPatternFn? pattern
| throwError "invalid pattern, (non-forbidden) application expected"
| throwError "invalid pattern, (non-forbidden) application expected{indentExpr pattern}"
assert! f.isConst || f.isFVar
saveSymbol f.toHeadIndex
let mut args := pattern.getAppArgs
Expand Down Expand Up @@ -432,7 +433,11 @@ pattern.
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) : MetaM EMatchTheorem := do
let info ← getConstInfo declName
let (numParams, patterns) ← forallTelescopeReducing info.type fun xs type => do
let_expr Eq _ lhs _ := type | throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}"
let lhs ← match_expr type with
| Eq _ lhs _ => pure lhs
| Iff lhs _ => pure lhs
| HEq _ lhs _ _ => pure lhs
| _ => throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}"
let lhs ← preprocessPattern lhs normalizePattern
return (xs.size, [lhs.abstract xs])
mkEMatchTheorem declName numParams patterns
Expand All @@ -455,4 +460,28 @@ def addEMatchEqTheorem (declName : Name) : MetaM Unit := do
def getEMatchTheorems : CoreM EMatchTheorems :=
return ematchTheoremsExt.getState (← getEnv)

private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do
if (← getConstInfo declName).isTheorem then
ematchTheoremsExt.add (← mkEMatchEqTheorem declName) attrKind
else if let some eqns ← getEqnsFor? declName then
for eqn in eqns do
ematchTheoremsExt.add (← mkEMatchEqTheorem eqn) attrKind
else
throwError "`[grind_eq]` attribute can only be applied to equational theorems or function definitions"

builtin_initialize
registerBuiltinAttribute {
name := `grind_eq
descr :=
"The `[grind_eq]` attribute is used to annotate equational theorems and functions.\
When applied to an equational theorem, it marks the theorem for use in heuristic instantiations by the `grind` tactic.\
When applied to a function, it automatically annotates the equational theorems associated with that function.\
The `grind` tactic utilizes annotated theorems to add instances of matching patterns into the local context during proof search.\
For example, if a theorem `@[grind_eq] theorem foo_idempotent : foo (foo x) = foo x` is annotated,\
`grind` will add an instance of this theorem to the local context whenever it encounters the pattern `foo (foo x)`."
applicationTime := .afterCompilation
add := fun declName _ attrKind =>
addGrindEqAttr declName attrKind |>.run' {}
}

end Lean.Meta.Grind
76 changes: 76 additions & 0 deletions tests/lean/run/grind_eq.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
opaque g : Nat → Nat

@[grind_eq] def f (a : Nat) :=
match a with
| 0 => 10
| x+1 => g (f x)

set_option grind.debug true
set_option grind.debug.proofs true

set_option trace.grind.ematch.instance true
set_option trace.grind.assert true

/--
info: [grind.assert] f (y + 1) = a
[grind.assert] ¬a = g (f y)
[grind.ematch.instance] f.eq_2: f y.succ = g (f y)
[grind.assert] f (y + 1) = g (f y)
-/
#guard_msgs (info) in
example : f (y + 1) = a → a = g (f y):= by
grind

@[grind_eq] def app (xs ys : List α) :=
match xs with
| [] => ys
| x::xs => x :: app xs ys

/--
info: [grind.assert] app [1, 2] ys = xs
[grind.assert] ¬xs = 1 :: 2 :: ys
[grind.ematch.instance] app.eq_2: app [1, 2] ys = 1 :: app [2] ys
[grind.assert] app [1, 2] ys = 1 :: app [2] ys
[grind.ematch.instance] app.eq_2: app [2] ys = 2 :: app [] ys
[grind.assert] app [2] ys = 2 :: app [] ys
[grind.ematch.instance] app.eq_1: app [] ys = ys
[grind.assert] app [] ys = ys
-/
#guard_msgs (info) in
example : app [1, 2] ys = xs → xs = 1::2::ys := by
grind

opaque p : Nat → Nat → Prop
opaque q : Nat → Prop

@[grind_eq] theorem pq : p x x ↔ q x := by sorry

/--
info: [grind.assert] p a a
[grind.assert] ¬q a
[grind.ematch.instance] pq: p a a ↔ q a
[grind.assert] p a a = q a
-/
#guard_msgs (info) in
example : p a a → q a := by
grind

opaque appV (xs : Vector α n) (ys : Vector α m) : Vector α (n + m) :=
Vector.append xs ys

@[grind_eq]
theorem appV_assoc (a : Vector α n) (b : Vector α m) (c : Vector α n') :
HEq (appV a (appV b c)) (appV (appV a b) c) := sorry

/--
info: [grind.assert] x1 = appV a b
[grind.assert] x2 = appV x1 c
[grind.assert] x3 = appV b c
[grind.assert] x4 = appV a x3
[grind.assert] ¬HEq x2 x4
[grind.ematch.instance] appV_assoc: HEq (appV a (appV b c)) (appV (appV a b) c)
[grind.assert] HEq (appV a (appV b c)) (appV (appV a b) c)
-/
#guard_msgs (info) in
example : x1 = appV a b → x2 = appV x1 c → x3 = appV b c → x4 = appV a x3 → HEq x2 x4 := by
grind
7 changes: 7 additions & 0 deletions tests/lean/run/grind_pattern1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,10 @@ grind_pattern hThm1 => plus a c
/-- info: [grind.ematch.pattern] hThm1: [plus #2 #1, plus #2 #3] -/
#guard_msgs in
grind_pattern hThm1 => plus a c, plus a b

/--
error: invalid pattern, (non-forbidden) application expected
#4 ∧ #3
-/
#guard_msgs in
grind_pattern And.imp_left => a ∧ b

0 comments on commit 675244d

Please sign in to comment.