Skip to content

Commit

Permalink
feat: give preference to case-splits with fewer cases
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura committed Jan 12, 2025
1 parent 2d04171 commit 1b50af3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 31 deletions.
21 changes: 14 additions & 7 deletions src/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1964,15 +1964,22 @@ def sortFVarIds (fvarIds : Array FVarId) : MetaM (Array FVarId) := do

end Methods

/-- Return `true` if `declName` is an inductive predicate. That is, `inductive` type in `Prop`. -/
def isInductivePredicate (declName : Name) : MetaM Bool := do
/--
Return `some info` if `declName` is an inductive predicate where `info : InductiveVal`.
That is, `inductive` type in `Prop`.
-/
def isInductivePredicate? (declName : Name) : MetaM (Option InductiveVal) := do
match (← getEnv).find? declName with
| some (.inductInfo { type := type, ..}) =>
forallTelescopeReducing type fun _ type => do
| some (.inductInfo info) =>
forallTelescopeReducing info.type fun _ type => do
match (← whnfD type) with
| .sort u .. => return u == levelZero
| _ => return false
| _ => return false
| .sort u .. => if u == levelZero then return some info else return none
| _ => return none
| _ => return none

/-- Return `true` if `declName` is an inductive predicate. That is, `inductive` type in `Prop`. -/
def isInductivePredicate (declName : Name) : MetaM Bool := do
return (← isInductivePredicate? declName).isSome

def isListLevelDefEqAux : List Level → List Level → MetaM Bool
| [], [] => return true
Expand Down
63 changes: 39 additions & 24 deletions src/Lean/Meta/Tactic/Grind/Split.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Lean.Meta.Grind
inductive CaseSplitStatus where
| resolved
| notReady
| ready
| ready (numCases : Nat) (isRec := false)
deriving Inhabited, BEq

private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
Expand All @@ -24,7 +24,7 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
if (← isEqTrue a <||> isEqTrue b) then
return .resolved
else
return .ready
return .ready 2
else if (← isEqFalse e) then
return .resolved
else
Expand All @@ -36,60 +36,74 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
if (← isEqFalse a <||> isEqFalse b) then
return .resolved
else
return .ready
return .ready 2
else
return .notReady
| Eq _ _ _ =>
if (← isEqTrue e <||> isEqFalse e) then
return .ready
return .ready 2
else
return .notReady
| ite _ c _ _ _ =>
if (← isEqTrue c <||> isEqFalse c) then
return .resolved
else
return .ready
return .ready 2
| dite _ c _ _ _ =>
if (← isEqTrue c <||> isEqFalse c) then
return .resolved
else
return .ready
return .ready 2
| _ =>
if (← isResolvedCaseSplit e) then
trace[grind.debug.split] "split resolved: {e}"
return .resolved
if (← isMatcherApp e) then
return .ready
if let some info := isMatcherAppCore? (← getEnv) e then
return .ready info.numAlts
let .const declName .. := e.getAppFn | unreachable!
if (← isInductivePredicate declName <&&> isEqTrue e) then
return .ready
if let some info ← isInductivePredicate? declName then
if (← isEqTrue e) then
return .ready info.ctors.length info.isRec
return .notReady

private inductive SplitCandidate where
| none
| some (c : Expr) (numCases : Nat) (isRec : Bool)

/-- Returns the next case-split to be performed. It uses a very simple heuristic. -/
private def selectNextSplit? : GoalM (Option Expr) := do
if (← isInconsistent) then return none
if (← checkMaxCaseSplit) then return none
go (← get).splitCandidates none []
private def selectNextSplit? : GoalM SplitCandidate := do
if (← isInconsistent) then return .none
if (← checkMaxCaseSplit) then return .none
go (← get).splitCandidates .none []
where
go (cs : List Expr) (c? : Option Expr) (cs' : List Expr) : GoalM (Option Expr) := do
go (cs : List Expr) (c? : SplitCandidate) (cs' : List Expr) : GoalM SplitCandidate := do
match cs with
| [] =>
modify fun s => { s with splitCandidates := cs'.reverse }
if c?.isSome then
if let .some _ numCases isRec := c? then
let numSplits := (← get).numSplits
-- We only increase the number of splits if there is more than one case or it is recursive.
let numSplits := if numCases > 1 || isRec then numSplits + 1 else numSplits
-- Remark: we reset `numEmatch` after each case split.
-- We should consider other strategies in the future.
modify fun s => { s with numSplits := s.numSplits + 1, numEmatch := 0 }
modify fun s => { s with numSplits, numEmatch := 0 }
return c?
| c::cs =>
match (← checkCaseSplitStatus c) with
| .notReady => go cs c? (c::cs')
| .resolved => go cs c? cs'
| .ready =>
| .ready numCases isRec =>
match c? with
| none => go cs (some c) cs'
| some c' =>
if (← getGeneration c) < (← getGeneration c') then
go cs (some c) (c'::cs')
| .none => go cs (.some c numCases isRec) cs'
| .some c' numCases' _ =>
let isBetter : GoalM Bool := do
if numCases == 1 && !isRec && numCases' > 1 then
return true
if (← getGeneration c) < (← getGeneration c') then
return true
return numCases < numCases'
if (← isBetter) then
go cs (.some c numCases isRec) (c'::cs')
else
go cs c? (c::cs')

Expand Down Expand Up @@ -118,9 +132,10 @@ and returns a new list of goals if successful.
-/
def splitNext : GrindTactic := fun goal => do
let (goals?, _) ← GoalM.run goal do
let some c ← selectNextSplit?
let .some c numCases isRec ← selectNextSplit?
| return none
let gen ← getGeneration c
let genNew := if numCases > 1 || isRec then gen+1 else gen
trace_goal[grind.split] "{c}, generation: {gen}"
let mvarIds ← if (← isMatcherApp c) then
casesMatch (← get).mvarId c
Expand All @@ -129,7 +144,7 @@ def splitNext : GrindTactic := fun goal => do
cases (← get).mvarId major
let goal ← get
let goals := mvarIds.map fun mvarId => { goal with mvarId }
let goals ← introNewHyp goals [] (gen+1)
let goals ← introNewHyp goals [] genNew
return some goals
return goals?

Expand Down

0 comments on commit 1b50af3

Please sign in to comment.