Skip to content

Commit

Permalink
feat: add support for match-expressions to grind (#6521)
Browse files Browse the repository at this point in the history
This PR adds support for activating relevant `match`-equations as
E-matching theorems. It uses the `match`-equation lhs as the pattern.
  • Loading branch information
leodemoura authored Jan 4, 2025
1 parent 28a7098 commit ad593b3
Show file tree
Hide file tree
Showing 11 changed files with 224 additions and 28 deletions.
3 changes: 2 additions & 1 deletion src/Init/Grind/Norm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Init.SimpLemmas
import Init.PropLemmas
import Init.Classical
import Init.ByCases

Expand Down Expand Up @@ -64,7 +65,7 @@ attribute [grind_norm] forall_and

-- Exists
@[grind_norm↓] theorem not_exists (p : α → Prop) : (¬∃ x, p x) = ∀ x, ¬p x := by simp
attribute [grind_norm] exists_const exists_or
attribute [grind_norm] exists_const exists_or exists_prop exists_and_left exists_and_right

-- Bool cond
@[grind_norm] theorem cond_eq_ite (c : Bool) (a b : α) : cond c a b = ite c a b := by
Expand Down
2 changes: 2 additions & 0 deletions src/Init/Grind/Tactics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ structure Config where
gen : Nat := 5
/-- Maximum number of theorem instances generated using E-matching in a proof search tree branch. -/
instances : Nat := 1000
/-- If `matchEqs` is `true`, `grind` uses `match`-equations as E-matching theorems. -/
matchEqs : Bool := true
deriving Inhabited, BEq

end Lean.Grind
Expand Down
7 changes: 7 additions & 0 deletions src/Init/Grind/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ namespace Lean.Grind
/-- A helper gadget for annotating nested proofs in goals. -/
def nestedProof (p : Prop) (h : p) : p := h

/--
Gadget for marking terms that should not be normalized by `grind`s simplifier.
`grind` uses a simproc to implement this feature.
We use it when adding instances of `match`-equations to prevent them from being simplified to true.
-/
def doNotSimp {α : Sort u} (a : α) : α := a

set_option pp.proofs true

theorem nestedProof_congr (p q : Prop) (h : p = q) (hp : p) (hq : q) : HEq (nestedProof p hp) (nestedProof q hq) := by
Expand Down
35 changes: 35 additions & 0 deletions src/Lean/Meta/Tactic/Grind/DoNotSimp.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Grind.Util
import Init.Simproc
import Lean.Meta.Tactic.Simp.Simproc

namespace Lean.Meta.Grind

/--
Returns `Grind.doNotSimp e`.
Recall that `Grind.doNotSimp` is an identity function, but the following simproc is used to prevent the term `e` from being simplified.
-/
def markAsDoNotSimp (e : Expr) : MetaM Expr :=
mkAppM ``Grind.doNotSimp #[e]

builtin_dsimproc_decl reduceDoNotSimp (Grind.doNotSimp _) := fun e => do
let_expr Grind.doNotSimp _ _ ← e | return .continue
return .done e

/-- Adds `reduceDoNotSimp` to `s` -/
def addDoNotSimp (s : Simprocs) : CoreM Simprocs := do
s.add ``reduceDoNotSimp (post := false)

/-- Erases `Grind.doNotSimp` annotations. -/
def eraseDoNotSimp (e : Expr) : CoreM Expr := do
let pre (e : Expr) := do
let_expr Grind.doNotSimp _ a := e | return .continue e
return .continue a
Core.transform e (pre := pre)

end Lean.Meta.Grind
18 changes: 15 additions & 3 deletions src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.DoNotSimp

namespace Lean.Meta.Grind
namespace EMatch
Expand Down Expand Up @@ -146,6 +147,15 @@ private def processContinue (c : Choice) (p : Expr) : M Unit := do
let c := { c with gen := Nat.max gen c.gen }
modify fun s => { s with choiceStack := c :: s.choiceStack }

/-- Helper function for marking parts of `match`-equation theorem as "do-not-simplify" -/
private partial def annotateMatchEqnType (prop : Expr) : M Expr := do
if let .forallE n d b bi := prop then
withLocalDecl n bi (← markAsDoNotSimp d) fun x => do
mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x))
else
let_expr f@Eq α lhs rhs := prop | return prop
return mkApp3 f α (← markAsDoNotSimp lhs) rhs

/--
Stores new theorem instance in the state.
Recall that new instances are internalized later, after a full round of ematching.
Expand All @@ -154,7 +164,9 @@ private def addNewInstance (origin : Origin) (proof : Expr) (generation : Nat) :
let proof ← instantiateMVars proof
if grind.debug.proofs.get (← getOptions) then
check proof
let prop ← inferType proof
let mut prop ← inferType proof
if Match.isMatchEqnTheorem (← getEnv) origin.key then
prop ← annotateMatchEqnType prop
trace[grind.ematch.instance] "{← origin.pp}: {prop}"
addTheoremInstance proof prop generation

Expand Down Expand Up @@ -189,10 +201,10 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
unless (← synthesizeInstance mvar type) do
trace[grind.issues] "failed to synthesize instance when instantiating {← thm.origin.pp}{indentExpr type}"
return ()
let proof := mkAppN proof mvars
if (← mvars.allM (·.mvarId!.isAssigned)) then
addNewInstance thm.origin (mkAppN proof mvars) c.gen
addNewInstance thm.origin proof c.gen
else
let proof := mkAppN proof mvars
let mvars ← mvars.filterM fun mvar => return !(← mvar.mvarId!.isAssigned)
if let some mvarBad ← mvars.findM? fun mvar => return !(← isProof mvar) then
trace[grind.issues] "failed to instantiate {← thm.origin.pp}, failed to instantiate non propositional argument with type{indentExpr (← inferType mvarBad)}"
Expand Down
88 changes: 74 additions & 14 deletions src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,47 @@ structure EMatchTheorem where
origin : Origin
deriving Inhabited

/-- The key is a symbol from `EMatchTheorem.symbols`. -/
abbrev EMatchTheorems := PHashMap Name (List EMatchTheorem)
/-- Set of E-matching theorems. -/
structure EMatchTheorems where
/-- The key is a symbol from `EMatchTheorem.symbols`. -/
private map : PHashMap Name (List EMatchTheorem) := {}
/-- Set of theorem names that have been inserted using `insert`. -/
private thmNames : PHashSet Name := {}
deriving Inhabited

/--
Inserts a `thm` with symbols `[s_1, ..., s_n]` to `s`.
We add `s_1 -> { thm with symbols := [s_2, ..., s_n] }`.
When `grind` internalizes a term containing symbol `s`, we
process all theorems `thm` associated with key `s`.
If their `thm.symbols` is empty, we say they are activated.
Otherwise, we reinsert into `map`.
-/
def EMatchTheorems.insert (s : EMatchTheorems) (thm : EMatchTheorem) : EMatchTheorems := Id.run do
let .const declName :: syms := thm.symbols
| unreachable!
let thm := { thm with symbols := syms }
if let some thms := s.find? declName then
return PersistentHashMap.insert s declName (thm::thms)
let { map, thmNames } := s
let thmNames := thmNames.insert thm.origin.key
if let some thms := map.find? declName then
return { map := map.insert declName (thm::thms), thmNames }
else
return PersistentHashMap.insert s declName [thm]
return { map := map.insert declName [thm], thmNames }

/--
Retrieves theorems from `s` associated with the given symbol. See `EMatchTheorem.insert`.
The theorems are removed from `s`.
-/
@[inline]
def EMatchTheorems.retrieve? (s : EMatchTheorems) (sym : Name) : Option (List EMatchTheorem × EMatchTheorems) :=
if let some thms := s.map.find? sym then
some (thms, { s with map := s.map.erase sym })
else
none

/-- Returns `true` if `declName` is the name of a theorem that was inserted using `insert`. -/
def EMatchTheorems.containsTheoremName (s : EMatchTheorems) (declName : Name) : Bool :=
s.thmNames.contains declName

def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do
if thm.proof.isConst && thm.levelParams.isEmpty then
Expand All @@ -85,7 +115,7 @@ def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr
private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTheorem EMatchTheorems ←
registerSimpleScopedEnvExtension {
addEntry := EMatchTheorems.insert
initial := .empty
initial := {}
}

-- TODO: create attribute?
Expand Down Expand Up @@ -320,8 +350,8 @@ private def checkCoverage (thmProof : Expr) (numParams : Nat) (bvarsFound : Std.
Given a theorem with proof `proof` and `numParams` parameters, returns a message
containing the parameters at positions `paramPos`.
-/
private def ppParamsAt (proof : Expr) (numParms : Nat) (paramPos : List Nat) : MetaM MessageData := do
forallBoundedTelescope (← inferType proof) numParms fun xs _ => do
private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) : MetaM MessageData := do
forallBoundedTelescope (← inferType proof) numParams fun xs _ => do
let mut msg := m!""
let mut first := true
for h : i in [:xs.size] do
Expand All @@ -331,23 +361,53 @@ private def ppParamsAt (proof : Expr) (numParms : Nat) (paramPos : List Nat) : M
msg := msg ++ m!"{x} : {← inferType x}"
addMessageContextFull msg

def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
/--
Creates an E-matching theorem for `declName` with `numParams` parameters, and the given set of patterns.
Pattern variables are represented using de Bruijn indices.
-/
def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM EMatchTheorem := do
let .thmInfo info ← getConstInfo declName
| throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic"
let us := info.levelParams.map mkLevelParam
let proof := mkConst declName us
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
assert! symbols.all fun s => s matches .const _
trace[grind.ematch.pattern] "{declName}: {patterns.map ppPattern}"
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
ematchTheoremsExt.add {
proof, patterns, numParams, symbols
levelParams := #[]
origin := .decl declName
return {
proof, patterns, numParams, symbols
levelParams := #[]
origin := .decl declName
}

/--
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
-/
def mkEMatchEqTheorem (declName : Name) : MetaM EMatchTheorem := do
let info ← getConstInfo declName
let (numParams, patterns) ← forallTelescopeReducing info.type fun xs type => do
let_expr Eq _ lhs _ := type | throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}"
return (xs.size, [lhs.abstract xs])
mkEMatchTheorem declName numParams patterns

/--
Adds an E-matching theorem to the environment.
See `mkEMatchTheorem`.
-/
def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
ematchTheoremsExt.add (← mkEMatchTheorem declName numParams patterns)

/--
Adds an E-matching equality theorem to the environment.
See `mkEMatchEqTheorem`.
-/
def addEMatchEqTheorem (declName : Name) : MetaM Unit := do
ematchTheoremsExt.add (← mkEMatchEqTheorem declName)

/-- Returns the E-matching theorems registered in the environment. -/
def getEMatchTheorems : CoreM EMatchTheorems :=
return ematchTheoremsExt.getState (← getEnv)

Expand Down
36 changes: 27 additions & 9 deletions src/Lean/Meta/Tactic/Grind/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Authors: Leonardo de Moura
prelude
import Init.Grind.Util
import Lean.Meta.LitValues
import Lean.Meta.Match.MatcherInfo
import Lean.Meta.Match.MatchEqsExt
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Util

Expand Down Expand Up @@ -50,21 +52,36 @@ private partial def internalizePattern (pattern : Expr) (generation : Nat) : Goa
else pattern.withApp fun f args => do
return mkAppN f (← args.mapM (internalizePattern · generation))

private partial def activateTheorem (thm : EMatchTheorem) (generation : Nat) : GoalM Unit := do
-- Recall that we use the proof as part of the key for a set of instances found so far.
-- We don't want to use structural equality when comparing keys.
let proof ← shareCommon thm.proof
let thm := { thm with proof, patterns := (← thm.patterns.mapM (internalizePattern · generation)) }
trace[grind.ematch] "activated `{thm.origin.key}`, {thm.patterns.map ppPattern}"
modify fun s => { s with newThms := s.newThms.push thm }

/--
If `Config.matchEqs` is set to `true`, and `f` is `match`-auxiliary function,
adds its equations to `newThms`.
-/
private partial def addMatchEqns (f : Expr) (generation : Nat) : GoalM Unit := do
if !(← getConfig).matchEqs then return ()
let .const declName _ := f | return ()
if !(← isMatcher declName) then return ()
if (← get).matchEqNames.contains declName then return ()
modify fun s => { s with matchEqNames := s.matchEqNames.insert declName }
for eqn in (← Match.getEquationsFor declName).eqnNames do
activateTheorem (← mkEMatchEqTheorem eqn) generation

private partial def activateTheoremPatterns (fName : Name) (generation : Nat) : GoalM Unit := do
if let some thms := (← get).thmMap.find? fName then
modify fun s => { s with thmMap := s.thmMap.erase fName }
if let some (thms, thmMap) := (← get).thmMap.retrieve? fName then
modify fun s => { s with thmMap }
let appMap := (← get).appMap
for thm in thms do
let symbols := thm.symbols.filter fun sym => !appMap.contains sym
let thm := { thm with symbols }
match symbols with
| [] =>
-- Recall that we use the proof as part of the key for a set of instances found so far.
-- We don't want to use structural equality when comparing keys.
let proof ← shareCommon thm.proof
let thm := { thm with proof, patterns := (← thm.patterns.mapM (internalizePattern · generation)) }
trace[grind.ematch] "activated `{thm.origin.key}`, {thm.patterns.map ppPattern}"
modify fun s => { s with newThms := s.newThms.push thm }
| [] => activateTheorem thm generation
| _ =>
trace[grind.ematch] "reinsert `{thm.origin.key}`"
modify fun s => { s with thmMap := s.thmMap.insert thm }
Expand Down Expand Up @@ -95,6 +112,7 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
-- We do not want to internalize the components of a literal value.
mkENode e generation
else e.withApp fun f args => do
addMatchEqns f generation
if f.isConstOf ``Lean.Grind.nestedProof && args.size == 2 then
-- We only internalize the proposition. We can skip the proof because of
-- proof irrelevance
Expand Down
3 changes: 2 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Inv
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.EMatch
import Lean.Meta.Tactic.Grind.DoNotSimp

namespace Lean.Meta.Grind

Expand All @@ -38,7 +39,7 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) (fa
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
let thms ← grindNormExt.getTheorems
let simprocs := #[(← grindNormSimprocExt.getSimprocs)]
let simprocs := #[(← addDoNotSimp (← grindNormSimprocExt.getSimprocs))]
let simp ← Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])
Expand Down
2 changes: 2 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Lean.Meta.Tactic.Assert
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.DoNotSimp
import Lean.Meta.Tactic.Grind.MarkNestedProofs

namespace Lean.Meta.Grind
Expand All @@ -33,6 +34,7 @@ def simp (e : Expr) : GrindM Simp.Result := do
let e' ← eraseIrrelevantMData e'
let e' ← foldProjs e'
let e' ← normalizeLevels e'
let e' ← eraseDoNotSimp e'
let e' ← canon e'
let e' ← shareCommon e'
trace[grind.simp] "{e}\n===>\n{e'}"
Expand Down
2 changes: 2 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ structure Goal where
preInstances : PreInstanceSet := {}
/-- new facts to be processed. -/
newFacts : Std.Queue NewFact := ∅
/-- `match` auxiliary functions whose equations have already been created and activated. -/
matchEqNames : PHashSet Name := {}
deriving Inhabited

def Goal.admit (goal : Goal) : MetaM Unit :=
Expand Down
Loading

0 comments on commit ad593b3

Please sign in to comment.