Skip to content

Commit

Permalink
feat: improve cases tactic used in grind (#6516)
Browse files Browse the repository at this point in the history
This PR enhances the `cases` tactic used in the `grind` tactic and
ensures that it can be applied to arbitrary expressions.
  • Loading branch information
leodemoura authored Jan 3, 2025
1 parent 10b2f6b commit 7b496bf
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 40 deletions.
71 changes: 48 additions & 23 deletions src/Lean/Meta/Tactic/Cases.lean
Original file line number Diff line number Diff line change
Expand Up @@ -66,30 +66,31 @@ structure GeneralizeIndicesSubgoal where
numEqs : Nat

/--
Similar to `generalizeTargets` but customized for the `casesOn` motive.
Given a metavariable `mvarId` representing the
```
Ctx, h : I A j, D |- T
```
where `fvarId` is `h`s id, and the type `I A j` is an inductive datatype where `A` are parameters,
and `j` the indices. Generate the goal
```
Ctx, h : I A j, D, j' : J, h' : I A j' |- j == j' -> h == h' -> T
```
Remark: `(j == j' -> h == h')` is a "telescopic" equality.
Remark: `j` is sequence of terms, and `j'` a sequence of free variables.
The result contains the fields
- `mvarId`: the new goal
- `indicesFVarIds`: `j'` ids
- `fvarId`: `h'` id
- `numEqs`: number of equations in the target -/
def generalizeIndices (mvarId : MVarId) (fvarId : FVarId) : MetaM GeneralizeIndicesSubgoal :=
Given a metavariable `mvarId` representing the goal
```
Ctx |- T
```
and an expression `e : I A j`, where `I A j` is an inductive datatype where `A` are parameters,
and `j` the indices. Generate the goal
```
Ctx, j' : J, h' : I A j' |- j == j' -> e == h' -> T
```
Remark: `(j == j' -> e == h')` is a "telescopic" equality.
Remark: `j` is sequence of terms, and `j'` a sequence of free variables.
The result contains the fields
- `mvarId`: the new goal
- `indicesFVarIds`: `j'` ids
- `fvarId`: `h'` id
- `numEqs`: number of equations in the target
If `varName?` is not none, it is used to name `h'`.
-/
def generalizeIndices' (mvarId : MVarId) (e : Expr) (varName? : Option Name := none) : MetaM GeneralizeIndicesSubgoal :=
mvarId.withContext do
let lctx ← getLCtx
let localInsts ← getLocalInstances
mvarId.checkNotAssigned `generalizeIndices
let fvarDecl ← fvarId.getDecl
let type ← whnf fvarDecl.type
let type ← whnfD (← inferType e)
type.withApp fun f args => matchConstInduct f (fun _ => throwTacticEx `generalizeIndices mvarId "inductive type expected") fun val _ => do
unless val.numIndices > 0 do throwTacticEx `generalizeIndices mvarId "indexed inductive type expected"
unless args.size == val.numIndices + val.numParams do throwTacticEx `generalizeIndices mvarId "ill-formed inductive datatype"
Expand All @@ -98,9 +99,10 @@ def generalizeIndices (mvarId : MVarId) (fvarId : FVarId) : MetaM GeneralizeIndi
let IAType ← inferType IA
forallTelescopeReducing IAType fun newIndices _ => do
let newType := mkAppN IA newIndices
withLocalDeclD fvarDecl.userName newType fun h' =>
let varName ← if let some varName := varName? then pure varName else mkFreshUserName `x
withLocalDeclD varName newType fun h' =>
withNewEqs indices newIndices fun newEqs newRefls => do
let (newEqType, newRefl) ← mkEqAndProof fvarDecl.toExpr h'
let (newEqType, newRefl) ← mkEqAndProof e h'
let newRefls := newRefls.push newRefl
withLocalDeclD `h newEqType fun newEq => do
let newEqs := newEqs.push newEq
Expand All @@ -112,7 +114,7 @@ def generalizeIndices (mvarId : MVarId) (fvarId : FVarId) : MetaM GeneralizeIndi
let auxType ← mkForallFVars newIndices auxType
let newMVar ← mkFreshExprMVarAt lctx localInsts auxType MetavarKind.syntheticOpaque tag
/- assign mvarId := newMVar indices h refls -/
mvarId.assign (mkAppN (mkApp (mkAppN newMVar indices) fvarDecl.toExpr) newRefls)
mvarId.assign (mkAppN (mkApp (mkAppN newMVar indices) e) newRefls)
let (indicesFVarIds, newMVarId) ← newMVar.mvarId!.introNP newIndices.size
let (fvarId, newMVarId) ← newMVarId.intro1P
return {
Expand All @@ -122,6 +124,29 @@ def generalizeIndices (mvarId : MVarId) (fvarId : FVarId) : MetaM GeneralizeIndi
numEqs := newEqs.size
}

/--
Similar to `generalizeTargets` but customized for the `casesOn` motive.
Given a metavariable `mvarId` representing the
```
Ctx, h : I A j, D |- T
```
where `fvarId` is `h`s id, and the type `I A j` is an inductive datatype where `A` are parameters,
and `j` the indices. Generate the goal
```
Ctx, h : I A j, D, j' : J, h' : I A j' |- j == j' -> h == h' -> T
```
Remark: `(j == j' -> h == h')` is a "telescopic" equality.
Remark: `j` is sequence of terms, and `j'` a sequence of free variables.
The result contains the fields
- `mvarId`: the new goal
- `indicesFVarIds`: `j'` ids
- `fvarId`: `h'` id
- `numEqs`: number of equations in the target -/
def generalizeIndices (mvarId : MVarId) (fvarId : FVarId) : MetaM GeneralizeIndicesSubgoal :=
mvarId.withContext do
let fvarDecl ← fvarId.getDecl
generalizeIndices' mvarId fvarDecl.toExpr fvarDecl.userName

structure CasesSubgoal extends InductionSubgoal where
ctorName : Name

Expand Down
33 changes: 17 additions & 16 deletions src/Lean/Meta/Tactic/Grind/Cases.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,29 @@ namespace Lean.Meta.Grind
The `grind` tactic includes an auxiliary `cases` tactic that is not intended for direct use by users.
This method implements it.
This tactic is automatically applied when introducing local declarations with a type tagged with `[grind_cases]`.
It is also used for "case-splitting" on terms during the search.
It differs from the user-facing Lean `cases` tactic in the following ways:
- It avoids unnecessary `revert` and `intro` operations.
- It does not introduce new local declarations for each minor premise. Instead, the `grind` tactic preprocessor is responsible for introducing them.
- It assumes that the major premise (i.e., the parameter `fvarId`) is the latest local declaration in the current goal.
- If the major premise type is an indexed family, auxiliary declarations and (heterogeneous) equalities are introduced.
However, these equalities are not resolved using `unifyEqs`. Instead, the `grind` tactic employs union-find and
congruence closure to process these auxiliary equalities. This approach avoids applying substitution to propositions
that have already been internalized by `grind`.
-/
def cases (mvarId : MVarId) (fvarId : FVarId) : MetaM (List MVarId) := mvarId.withContext do
def cases (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
let tag ← mvarId.getTag
let type ← whnf (← fvarId.getType)
let type ← whnf (← inferType e)
let .const declName _ := type.getAppFn | throwInductiveExpected type
let .inductInfo _ ← getConstInfo declName | throwInductiveExpected type
let recursorInfo ← mkRecursorInfo (mkCasesOnName declName)
let k (mvarId : MVarId) (fvarId : FVarId) (indices : Array Expr) (clearMajor : Bool) : MetaM (List MVarId) := do
let recursor ← mkRecursorAppPrefix mvarId `grind.cases fvarId recursorInfo indices
let mut recursor := mkApp (mkAppN recursor indices) (mkFVar fvarId)
let k (mvarId : MVarId) (fvarId : FVarId) (indices : Array FVarId) : MetaM (List MVarId) := do
let indicesExpr := indices.map mkFVar
let recursor ← mkRecursorAppPrefix mvarId `grind.cases fvarId recursorInfo indicesExpr
let mut recursor := mkApp (mkAppN recursor indicesExpr) (mkFVar fvarId)
let mut recursorType ← inferType recursor
let mut mvarIdsNew := #[]
for _ in [:recursorInfo.numMinors] do
Expand All @@ -41,22 +42,22 @@ def cases (mvarId : MVarId) (fvarId : FVarId) : MetaM (List MVarId) := mvarId.wi
recursorType := recursorTypeNew
let mvar ← mkFreshExprSyntheticOpaqueMVar targetNew tag
recursor := mkApp recursor mvar
let mvarIdNew ← if clearMajor then
mvar.mvarId!.clear fvarId
else
pure mvar.mvarId!
let mvarIdNew ← mvar.mvarId!.tryClearMany (indices.push fvarId)
mvarIdsNew := mvarIdsNew.push mvarIdNew
mvarId.assign recursor
return mvarIdsNew.toList
if recursorInfo.numIndices > 0 then
let s ← generalizeIndices mvarId fvarId
let s ← generalizeIndices' mvarId e
s.mvarId.withContext do
k s.mvarId s.fvarId (s.indicesFVarIds.map mkFVar) (clearMajor := false)
k s.mvarId s.fvarId s.indicesFVarIds
else if let .fvar fvarId := e then
k mvarId fvarId #[]
else
let indices ← getMajorTypeIndices mvarId `grind.cases recursorInfo type
k mvarId fvarId indices (clearMajor := true)
let mvarId ← mvarId.assert (← mkFreshUserName `x) type e
let (fvarId, mvarId) ← mvarId.intro1
mvarId.withContext do k mvarId fvarId #[]
where
throwInductiveExpected {α} (type : Expr) : MetaM α := do
throwTacticEx `grind.cases mvarId m!"(non-recursive) inductive type expected at {mkFVar fvarId}{indentExpr type}"
throwTacticEx `grind.cases mvarId m!"(non-recursive) inductive type expected at {e}{indentExpr type}"

end Lean.Meta.Grind
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Grind/Intro.lean
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ private def isCasesCandidate (type : Expr) : MetaM Bool := do

private def applyCases? (goal : Goal) (fvarId : FVarId) : MetaM (Option (List Goal)) := goal.mvarId.withContext do
if (← isCasesCandidate (← fvarId.getType)) then
let mvarIds ← cases goal.mvarId fvarId
let mvarIds ← cases goal.mvarId (mkFVar fvarId)
return mvarIds.map fun mvarId => { goal with mvarId }
else
return none
Expand Down
43 changes: 43 additions & 0 deletions tests/lean/run/grind_cases_tac.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import Lean

open Lean Meta Grind Elab Tactic in
elab "cases' " e:term : tactic => withMainContext do
let e ← elabTerm e none
setGoals (← Grind.cases (← getMainGoal) e)

inductive Vec (α : Type u) : Nat → Type u
| nil : Vec α 0
| cons : α → Vec α n → Vec α (n+1)

def f (v : Vec α n) : Bool :=
match v with
| .nil => true
| .cons .. => false

/--
info: n : Nat
v : Vec Nat n
h : f v ≠ false
⊢ n + 1 = 0 → HEq (Vec.cons 10 v) Vec.nil → False
---
info: n : Nat
v : Vec Nat n
h : f v ≠ false
⊢ ∀ {n_1 : Nat} (a : Nat) (a_1 : Vec Nat n_1), n + 1 = n_1 + 1 → HEq (Vec.cons 10 v) (Vec.cons a a_1) → False
-/
#guard_msgs (info) in
example (v : Vec Nat n) (h : f v ≠ false) : False := by
cases' (Vec.cons 10 v)
next => trace_state; sorry
next => trace_state; sorry

/--
info: ⊢ False → False
---
info: ⊢ True → False
-/
#guard_msgs (info) in
example : False := by
cases' (Or.inr (a := False) True.intro)
next => trace_state; sorry
next => trace_state; sorry

0 comments on commit 7b496bf

Please sign in to comment.