From 00ef231a6e03398c2ad3b577ab036f901ec88543 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 7 Jan 2025 19:10:11 -0800 Subject: [PATCH] feat: split on `match`-expressions in the `grind` tactic (#6569) This PR adds support for case splitting on `match`-expressions in `grind`. We still need to add support for resolving the antecedents of `match`-conditional equations. --- src/Init/Grind/Util.lean | 7 +++ src/Lean/Meta/Tactic/Cases.lean | 2 +- src/Lean/Meta/Tactic/Grind.lean | 3 +- src/Lean/Meta/Tactic/Grind/CasesMatch.lean | 53 ++++++++++++++++++++++ src/Lean/Meta/Tactic/Grind/EMatch.lean | 16 ++++--- src/Lean/Meta/Tactic/Grind/Propagate.lean | 7 +++ src/Lean/Meta/Tactic/Grind/Split.lean | 13 ++++-- tests/lean/run/grind_match1.lean | 11 +++-- tests/lean/run/grind_match2.lean | 28 ++++++++++++ 9 files changed, 121 insertions(+), 19 deletions(-) create mode 100644 src/Lean/Meta/Tactic/Grind/CasesMatch.lean create mode 100644 tests/lean/run/grind_match2.lean diff --git a/src/Init/Grind/Util.lean b/src/Init/Grind/Util.lean index 9ef005388f..f37dcc4248 100644 --- a/src/Init/Grind/Util.lean +++ b/src/Init/Grind/Util.lean @@ -21,6 +21,13 @@ def doNotSimp {α : Sort u} (a : α) : α := a /-- Gadget for representing offsets `t+k` in patterns. -/ def offset (a b : Nat) : Nat := a + b +/-- +Gadget for annotating the equalities in `match`-equations conclusions. +`_origin` is the term used to instantiate the `match`-equation using E-matching. +When `EqMatch a b origin` is `True`, we mark `origin` as a resolved case-split. +-/ +def EqMatch (a b : α) {_origin : α} : Prop := a = b + theorem nestedProof_congr (p q : Prop) (h : p = q) (hp : p) (hq : q) : HEq (nestedProof p hp) (nestedProof q hq) := by subst h; apply HEq.refl diff --git a/src/Lean/Meta/Tactic/Cases.lean b/src/Lean/Meta/Tactic/Cases.lean index 56e8b8514a..7e4e7cf34f 100644 --- a/src/Lean/Meta/Tactic/Cases.lean +++ b/src/Lean/Meta/Tactic/Cases.lean @@ -33,7 +33,7 @@ private def mkEqAndProof (lhs rhs : Expr) : MetaM (Expr × Expr) := do else pure (mkApp4 (mkConst ``HEq [u]) lhsType lhs rhsType rhs, mkApp2 (mkConst ``HEq.refl [u]) lhsType lhs) -private partial def withNewEqs (targets targetsNew : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α := +partial def withNewEqs (targets targetsNew : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α := let rec loop (i : Nat) (newEqs : Array Expr) (newRefls : Array Expr) := do if i < targets.size then let (newEqType, newRefl) ← mkEqAndProof targets[i]! targetsNew[i]! diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index 593bb147c6..80edeeef6c 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -23,7 +23,7 @@ import Lean.Meta.Tactic.Grind.Parser import Lean.Meta.Tactic.Grind.EMatchTheorem import Lean.Meta.Tactic.Grind.EMatch import Lean.Meta.Tactic.Grind.Main - +import Lean.Meta.Tactic.Grind.CasesMatch namespace Lean @@ -52,5 +52,6 @@ builtin_initialize registerTraceClass `grind.debug.proj builtin_initialize registerTraceClass `grind.debug.parent builtin_initialize registerTraceClass `grind.debug.final builtin_initialize registerTraceClass `grind.debug.forallPropagator +builtin_initialize registerTraceClass `grind.debug.split end Lean diff --git a/src/Lean/Meta/Tactic/Grind/CasesMatch.lean b/src/Lean/Meta/Tactic/Grind/CasesMatch.lean new file mode 100644 index 0000000000..1e5f07ed6a --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/CasesMatch.lean @@ -0,0 +1,53 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Util +import Lean.Meta.Tactic.Cases +import Lean.Meta.Match.MatcherApp + +namespace Lean.Meta.Grind + +def casesMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do + let some app ← matchMatcherApp? e + | throwTacticEx `grind.casesMatch mvarId m!"`match`-expression expected{indentExpr e}" + let (motive, eqRefls) ← mkMotiveAndRefls app + let target ← mvarId.getType + let mut us := app.matcherLevels + if let some i := app.uElimPos? then + us := us.set! i (← getLevel target) + let splitterName := (← Match.getEquationsFor app.matcherName).splitterName + let splitterApp := mkConst splitterName us.toList + let splitterApp := mkAppN splitterApp app.params + let splitterApp := mkApp splitterApp motive + let splitterApp := mkAppN splitterApp app.discrs + let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType splitterApp) app.alts.size (kind := .syntheticOpaque) + let splitterApp := mkAppN splitterApp mvars + let val := mkAppN splitterApp eqRefls + mvarId.assign val + updateTags mvars + return mvars.toList.map (·.mvarId!) +where + mkMotiveAndRefls (app : MatcherApp) : MetaM (Expr × Array Expr) := do + let dummy := mkSort 0 + let aux := mkApp (mkAppN e.getAppFn app.params) dummy + forallBoundedTelescope (← inferType aux) app.discrs.size fun xs _ => do + withNewEqs app.discrs xs fun eqs eqRefls => do + let type ← mvarId.getType + let type ← mkForallFVars eqs type + let motive ← mkLambdaFVars xs type + return (motive, eqRefls) + + updateTags (mvars : Array Expr) : MetaM Unit := do + let tag ← mvarId.getTag + if mvars.size == 1 then + mvars[0]!.mvarId!.setTag tag + else + let mut idx := 1 + for mvar in mvars do + mvar.mvarId!.setTag (Name.num tag idx) + idx := idx + 1 + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/EMatch.lean b/src/Lean/Meta/Tactic/Grind/EMatch.lean index bce265b738..fccb34b22f 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatch.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatch.lean @@ -199,14 +199,18 @@ private def processContinue (c : Choice) (p : Expr) : M Unit := do let c := { c with gen := Nat.max gen c.gen } modify fun s => { s with choiceStack := c :: s.choiceStack } -/-- Helper function for marking parts of `match`-equation theorem as "do-not-simplify" -/ -private partial def annotateMatchEqnType (prop : Expr) : M Expr := do +/-- +Helper function for marking parts of `match`-equation theorem as "do-not-simplify" +`initApp` is the match-expression used to instantiate the `match`-equation. +-/ +private partial def annotateMatchEqnType (prop : Expr) (initApp : Expr) : M Expr := do if let .forallE n d b bi := prop then withLocalDecl n bi (← markAsDoNotSimp d) fun x => do - mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x)) + mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x) initApp) else let_expr f@Eq α lhs rhs := prop | return prop - return mkApp3 f α (← markAsDoNotSimp lhs) rhs + -- See comment at `Grind.EqMatch` + return mkApp4 (mkConst ``Grind.EqMatch f.constLevels!) α (← markAsDoNotSimp lhs) rhs initApp /-- Stores new theorem instance in the state. @@ -218,9 +222,7 @@ private def addNewInstance (origin : Origin) (proof : Expr) (generation : Nat) : check proof let mut prop ← inferType proof if Match.isMatchEqnTheorem (← getEnv) origin.key then - -- `initApp` is a match-application that we don't need to split at anymore. - markCaseSplitAsResolved (← read).initApp - prop ← annotateMatchEqnType prop + prop ← annotateMatchEqnType prop (← read).initApp trace_goal[grind.ematch.instance] "{← origin.pp}: {prop}" addTheoremInstance proof prop (generation+1) diff --git a/src/Lean/Meta/Tactic/Grind/Propagate.lean b/src/Lean/Meta/Tactic/Grind/Propagate.lean index 42d1135a14..bfe1a3d014 100644 --- a/src/Lean/Meta/Tactic/Grind/Propagate.lean +++ b/src/Lean/Meta/Tactic/Grind/Propagate.lean @@ -134,6 +134,13 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do let_expr Eq _ a b := e | return () pushEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e) +/-- Propagates `EqMatch` downwards -/ +builtin_grind_propagator propagateEqMatchDown ↓Grind.EqMatch := fun e => do + if (← isEqTrue e) then + let_expr Grind.EqMatch _ a b origin := e | return () + markCaseSplitAsResolved origin + pushEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e) + /-- Propagates `HEq` downwards -/ builtin_grind_propagator propagateHEqDown ↓HEq := fun e => do if (← isEqTrue e) then diff --git a/src/Lean/Meta/Tactic/Grind/Split.lean b/src/Lean/Meta/Tactic/Grind/Split.lean index 96ae7ce416..e57a5270e4 100644 --- a/src/Lean/Meta/Tactic/Grind/Split.lean +++ b/src/Lean/Meta/Tactic/Grind/Split.lean @@ -7,6 +7,7 @@ prelude import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Intro import Lean.Meta.Tactic.Grind.Cases +import Lean.Meta.Tactic.Grind.CasesMatch namespace Lean.Meta.Grind @@ -50,10 +51,10 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do return .ready | _ => if (← isResolvedCaseSplit e) then + trace[grind.debug.split] "split resolved: {e}" return .resolved if (← isMatcherApp e) then - return .notReady -- TODO: implement splitters for `match` - -- return .ready + return .ready let .const declName .. := e.getAppFn | unreachable! if (← isInductivePredicate declName <&&> isEqTrue e) then return .ready @@ -111,9 +112,11 @@ def splitNext : GrindTactic := fun goal => do | return none let gen ← getGeneration c trace_goal[grind.split] "{c}, generation: {gen}" - -- TODO: `match` - let major ← mkCasesMajor c - let mvarIds ← cases (← get).mvarId major + let mvarIds ← if (← isMatcherApp c) then + casesMatch (← get).mvarId c + else + let major ← mkCasesMajor c + cases (← get).mvarId major let goal ← get let goals := mvarIds.map fun mvarId => { goal with mvarId } let goals ← introNewHyp goals [] (gen+1) diff --git a/tests/lean/run/grind_match1.lean b/tests/lean/run/grind_match1.lean index c50bfcf6e6..303b9f7172 100644 --- a/tests/lean/run/grind_match1.lean +++ b/tests/lean/run/grind_match1.lean @@ -24,15 +24,16 @@ info: [grind.assert] (match as, bs with [grind.assert] a₁ :: f 0 = as [grind.assert] f 0 = a₂ :: f 1 [grind.assert] ¬d = [] +[grind.assert] Lean.Grind.EqMatch + (match a₁ :: a₂ :: f 1, [] with + | [], x => bs + | head :: head_1 :: tail, [] => [] + | x :: xs, ys => x :: g xs ys) + [] [grind.split.resolved] match as, bs with | [], x => bs | head :: head_1 :: tail, [] => [] | x :: xs, ys => x :: g xs ys -[grind.assert] (match a₁ :: a₂ :: f 1, [] with - | [], x => bs - | head :: head_1 :: tail, [] => [] - | x :: xs, ys => x :: g xs ys) = - [] -/ #guard_msgs (info) in example (f : Nat → List Nat) : g as bs = d → bs = [] → a₁ :: f 0 = as → f 0 = a₂ :: f 1 → d = [] := by diff --git a/tests/lean/run/grind_match2.lean b/tests/lean/run/grind_match2.lean new file mode 100644 index 0000000000..24160dd195 --- /dev/null +++ b/tests/lean/run/grind_match2.lean @@ -0,0 +1,28 @@ +def g (a : α) (as : List α) : List α := + match as with + | [] => [a] + | b::bs => a::a::b::bs + +set_option trace.grind true in +set_option trace.grind.assert true in +example : ¬ (g a as).isEmpty := by + unfold List.isEmpty + unfold g + grind + +def h (as : List Nat) := + match as with + | [] => 1 + | [_] => 2 + | _::_::_ => 3 + +/-- +info: [grind] closed `grind.1` +[grind] closed `grind.2` +[grind] closed `grind.3` +-/ +#guard_msgs (info) in +set_option trace.grind true in +example : h as ≠ 0 := by + unfold h + grind