Skip to content

Commit

Permalink
feat: split on match-expressions in the grind tactic (#6569)
Browse files Browse the repository at this point in the history
This PR adds support for case splitting on `match`-expressions in
`grind`.
We still need to add support for resolving the antecedents of
`match`-conditional equations.
  • Loading branch information
leodemoura authored Jan 8, 2025
1 parent 9040108 commit 00ef231
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 19 deletions.
7 changes: 7 additions & 0 deletions src/Init/Grind/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ def doNotSimp {α : Sort u} (a : α) : α := a
/-- Gadget for representing offsets `t+k` in patterns. -/
def offset (a b : Nat) : Nat := a + b

/--
Gadget for annotating the equalities in `match`-equations conclusions.
`_origin` is the term used to instantiate the `match`-equation using E-matching.
When `EqMatch a b origin` is `True`, we mark `origin` as a resolved case-split.
-/
def EqMatch (a b : α) {_origin : α} : Prop := a = b

theorem nestedProof_congr (p q : Prop) (h : p = q) (hp : p) (hq : q) : HEq (nestedProof p hp) (nestedProof q hq) := by
subst h; apply HEq.refl

Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Cases.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private def mkEqAndProof (lhs rhs : Expr) : MetaM (Expr × Expr) := do
else
pure (mkApp4 (mkConst ``HEq [u]) lhsType lhs rhsType rhs, mkApp2 (mkConst ``HEq.refl [u]) lhsType lhs)

private partial def withNewEqs (targets targetsNew : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α :=
partial def withNewEqs (targets targetsNew : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α :=
let rec loop (i : Nat) (newEqs : Array Expr) (newRefls : Array Expr) := do
if i < targets.size then
let (newEqType, newRefl) ← mkEqAndProof targets[i]! targetsNew[i]!
Expand Down
3 changes: 2 additions & 1 deletion src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import Lean.Meta.Tactic.Grind.Parser
import Lean.Meta.Tactic.Grind.EMatchTheorem
import Lean.Meta.Tactic.Grind.EMatch
import Lean.Meta.Tactic.Grind.Main

import Lean.Meta.Tactic.Grind.CasesMatch

namespace Lean

Expand Down Expand Up @@ -52,5 +52,6 @@ builtin_initialize registerTraceClass `grind.debug.proj
builtin_initialize registerTraceClass `grind.debug.parent
builtin_initialize registerTraceClass `grind.debug.final
builtin_initialize registerTraceClass `grind.debug.forallPropagator
builtin_initialize registerTraceClass `grind.debug.split

end Lean
53 changes: 53 additions & 0 deletions src/Lean/Meta/Tactic/Grind/CasesMatch.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Cases
import Lean.Meta.Match.MatcherApp

namespace Lean.Meta.Grind

def casesMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
let some app ← matchMatcherApp? e
| throwTacticEx `grind.casesMatch mvarId m!"`match`-expression expected{indentExpr e}"
let (motive, eqRefls) ← mkMotiveAndRefls app
let target ← mvarId.getType
let mut us := app.matcherLevels
if let some i := app.uElimPos? then
us := us.set! i (← getLevel target)
let splitterName := (← Match.getEquationsFor app.matcherName).splitterName
let splitterApp := mkConst splitterName us.toList
let splitterApp := mkAppN splitterApp app.params
let splitterApp := mkApp splitterApp motive
let splitterApp := mkAppN splitterApp app.discrs
let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType splitterApp) app.alts.size (kind := .syntheticOpaque)
let splitterApp := mkAppN splitterApp mvars
let val := mkAppN splitterApp eqRefls
mvarId.assign val
updateTags mvars
return mvars.toList.map (·.mvarId!)
where
mkMotiveAndRefls (app : MatcherApp) : MetaM (Expr × Array Expr) := do
let dummy := mkSort 0
let aux := mkApp (mkAppN e.getAppFn app.params) dummy
forallBoundedTelescope (← inferType aux) app.discrs.size fun xs _ => do
withNewEqs app.discrs xs fun eqs eqRefls => do
let type ← mvarId.getType
let type ← mkForallFVars eqs type
let motive ← mkLambdaFVars xs type
return (motive, eqRefls)

updateTags (mvars : Array Expr) : MetaM Unit := do
let tag ← mvarId.getTag
if mvars.size == 1 then
mvars[0]!.mvarId!.setTag tag
else
let mut idx := 1
for mvar in mvars do
mvar.mvarId!.setTag (Name.num tag idx)
idx := idx + 1

end Lean.Meta.Grind
16 changes: 9 additions & 7 deletions src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,18 @@ private def processContinue (c : Choice) (p : Expr) : M Unit := do
let c := { c with gen := Nat.max gen c.gen }
modify fun s => { s with choiceStack := c :: s.choiceStack }

/-- Helper function for marking parts of `match`-equation theorem as "do-not-simplify" -/
private partial def annotateMatchEqnType (prop : Expr) : M Expr := do
/--
Helper function for marking parts of `match`-equation theorem as "do-not-simplify"
`initApp` is the match-expression used to instantiate the `match`-equation.
-/
private partial def annotateMatchEqnType (prop : Expr) (initApp : Expr) : M Expr := do
if let .forallE n d b bi := prop then
withLocalDecl n bi (← markAsDoNotSimp d) fun x => do
mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x))
mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x) initApp)
else
let_expr f@Eq α lhs rhs := prop | return prop
return mkApp3 f α (← markAsDoNotSimp lhs) rhs
-- See comment at `Grind.EqMatch`
return mkApp4 (mkConst ``Grind.EqMatch f.constLevels!) α (← markAsDoNotSimp lhs) rhs initApp

/--
Stores new theorem instance in the state.
Expand All @@ -218,9 +222,7 @@ private def addNewInstance (origin : Origin) (proof : Expr) (generation : Nat) :
check proof
let mut prop ← inferType proof
if Match.isMatchEqnTheorem (← getEnv) origin.key then
-- `initApp` is a match-application that we don't need to split at anymore.
markCaseSplitAsResolved (← read).initApp
prop ← annotateMatchEqnType prop
prop ← annotateMatchEqnType prop (← read).initApp
trace_goal[grind.ematch.instance] "{← origin.pp}: {prop}"
addTheoremInstance proof prop (generation+1)

Expand Down
7 changes: 7 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Propagate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do
let_expr Eq _ a b := e | return ()
pushEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e)

/-- Propagates `EqMatch` downwards -/
builtin_grind_propagator propagateEqMatchDown ↓Grind.EqMatch := fun e => do
if (← isEqTrue e) then
let_expr Grind.EqMatch _ a b origin := e | return ()
markCaseSplitAsResolved origin
pushEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e)

/-- Propagates `HEq` downwards -/
builtin_grind_propagator propagateHEqDown ↓HEq := fun e => do
if (← isEqTrue e) then
Expand Down
13 changes: 8 additions & 5 deletions src/Lean/Meta/Tactic/Grind/Split.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.CasesMatch

namespace Lean.Meta.Grind

Expand Down Expand Up @@ -50,10 +51,10 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
return .ready
| _ =>
if (← isResolvedCaseSplit e) then
trace[grind.debug.split] "split resolved: {e}"
return .resolved
if (← isMatcherApp e) then
return .notReady -- TODO: implement splitters for `match`
-- return .ready
return .ready
let .const declName .. := e.getAppFn | unreachable!
if (← isInductivePredicate declName <&&> isEqTrue e) then
return .ready
Expand Down Expand Up @@ -111,9 +112,11 @@ def splitNext : GrindTactic := fun goal => do
| return none
let gen ← getGeneration c
trace_goal[grind.split] "{c}, generation: {gen}"
-- TODO: `match`
let major ← mkCasesMajor c
let mvarIds ← cases (← get).mvarId major
let mvarIds ← if (← isMatcherApp c) then
casesMatch (← get).mvarId c
else
let major ← mkCasesMajor c
cases (← get).mvarId major
let goal ← get
let goals := mvarIds.map fun mvarId => { goal with mvarId }
let goals ← introNewHyp goals [] (gen+1)
Expand Down
11 changes: 6 additions & 5 deletions tests/lean/run/grind_match1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ info: [grind.assert] (match as, bs with
[grind.assert] a₁ :: f 0 = as
[grind.assert] f 0 = a₂ :: f 1
[grind.assert] ¬d = []
[grind.assert] Lean.Grind.EqMatch
(match a₁ :: a₂ :: f 1, [] with
| [], x => bs
| head :: head_1 :: tail, [] => []
| x :: xs, ys => x :: g xs ys)
[]
[grind.split.resolved] match as, bs with
| [], x => bs
| head :: head_1 :: tail, [] => []
| x :: xs, ys => x :: g xs ys
[grind.assert] (match a₁ :: a₂ :: f 1, [] with
| [], x => bs
| head :: head_1 :: tail, [] => []
| x :: xs, ys => x :: g xs ys) =
[]
-/
#guard_msgs (info) in
example (f : Nat → List Nat) : g as bs = d → bs = [] → a₁ :: f 0 = as → f 0 = a₂ :: f 1 → d = [] := by
Expand Down
28 changes: 28 additions & 0 deletions tests/lean/run/grind_match2.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
def g (a : α) (as : List α) : List α :=
match as with
| [] => [a]
| b::bs => a::a::b::bs

set_option trace.grind true in
set_option trace.grind.assert true in
example : ¬ (g a as).isEmpty := by
unfold List.isEmpty
unfold g
grind

def h (as : List Nat) :=
match as with
| [] => 1
| [_] => 2
| _::_::_ => 3

/--
info: [grind] closed `grind.1`
[grind] closed `grind.2`
[grind] closed `grind.3`
-/
#guard_msgs (info) in
set_option trace.grind true in
example : h as ≠ 0 := by
unfold h
grind

0 comments on commit 00ef231

Please sign in to comment.