From 1b50af3e95ae7fd698b68d8f550d19dadf43dbe2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 11 Jan 2025 19:26:21 -0800 Subject: [PATCH] feat: give preference to case-splits with fewer cases --- src/Lean/Meta/Basic.lean | 21 ++++++--- src/Lean/Meta/Tactic/Grind/Split.lean | 63 +++++++++++++++++---------- 2 files changed, 53 insertions(+), 31 deletions(-) diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index d395b62478a6..0da4e5fde124 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -1964,15 +1964,22 @@ def sortFVarIds (fvarIds : Array FVarId) : MetaM (Array FVarId) := do end Methods -/-- Return `true` if `declName` is an inductive predicate. That is, `inductive` type in `Prop`. -/ -def isInductivePredicate (declName : Name) : MetaM Bool := do +/-- +Return `some info` if `declName` is an inductive predicate where `info : InductiveVal`. +That is, `inductive` type in `Prop`. +-/ +def isInductivePredicate? (declName : Name) : MetaM (Option InductiveVal) := do match (← getEnv).find? declName with - | some (.inductInfo { type := type, ..}) => - forallTelescopeReducing type fun _ type => do + | some (.inductInfo info) => + forallTelescopeReducing info.type fun _ type => do match (← whnfD type) with - | .sort u .. => return u == levelZero - | _ => return false - | _ => return false + | .sort u .. => if u == levelZero then return some info else return none + | _ => return none + | _ => return none + +/-- Return `true` if `declName` is an inductive predicate. That is, `inductive` type in `Prop`. -/ +def isInductivePredicate (declName : Name) : MetaM Bool := do + return (← isInductivePredicate? declName).isSome def isListLevelDefEqAux : List Level → List Level → MetaM Bool | [], [] => return true diff --git a/src/Lean/Meta/Tactic/Grind/Split.lean b/src/Lean/Meta/Tactic/Grind/Split.lean index 4d242cf97486..b7c567f90224 100644 --- a/src/Lean/Meta/Tactic/Grind/Split.lean +++ b/src/Lean/Meta/Tactic/Grind/Split.lean @@ -14,7 +14,7 @@ namespace Lean.Meta.Grind inductive CaseSplitStatus where | resolved | notReady - | ready + | ready (numCases : Nat) (isRec := false) deriving Inhabited, BEq private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do @@ -24,7 +24,7 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do if (← isEqTrue a <||> isEqTrue b) then return .resolved else - return .ready + return .ready 2 else if (← isEqFalse e) then return .resolved else @@ -36,60 +36,74 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do if (← isEqFalse a <||> isEqFalse b) then return .resolved else - return .ready + return .ready 2 else return .notReady | Eq _ _ _ => if (← isEqTrue e <||> isEqFalse e) then - return .ready + return .ready 2 else return .notReady | ite _ c _ _ _ => if (← isEqTrue c <||> isEqFalse c) then return .resolved else - return .ready + return .ready 2 | dite _ c _ _ _ => if (← isEqTrue c <||> isEqFalse c) then return .resolved else - return .ready + return .ready 2 | _ => if (← isResolvedCaseSplit e) then trace[grind.debug.split] "split resolved: {e}" return .resolved - if (← isMatcherApp e) then - return .ready + if let some info := isMatcherAppCore? (← getEnv) e then + return .ready info.numAlts let .const declName .. := e.getAppFn | unreachable! - if (← isInductivePredicate declName <&&> isEqTrue e) then - return .ready + if let some info ← isInductivePredicate? declName then + if (← isEqTrue e) then + return .ready info.ctors.length info.isRec return .notReady +private inductive SplitCandidate where + | none + | some (c : Expr) (numCases : Nat) (isRec : Bool) + /-- Returns the next case-split to be performed. It uses a very simple heuristic. -/ -private def selectNextSplit? : GoalM (Option Expr) := do - if (← isInconsistent) then return none - if (← checkMaxCaseSplit) then return none - go (← get).splitCandidates none [] +private def selectNextSplit? : GoalM SplitCandidate := do + if (← isInconsistent) then return .none + if (← checkMaxCaseSplit) then return .none + go (← get).splitCandidates .none [] where - go (cs : List Expr) (c? : Option Expr) (cs' : List Expr) : GoalM (Option Expr) := do + go (cs : List Expr) (c? : SplitCandidate) (cs' : List Expr) : GoalM SplitCandidate := do match cs with | [] => modify fun s => { s with splitCandidates := cs'.reverse } - if c?.isSome then + if let .some _ numCases isRec := c? then + let numSplits := (← get).numSplits + -- We only increase the number of splits if there is more than one case or it is recursive. + let numSplits := if numCases > 1 || isRec then numSplits + 1 else numSplits -- Remark: we reset `numEmatch` after each case split. -- We should consider other strategies in the future. - modify fun s => { s with numSplits := s.numSplits + 1, numEmatch := 0 } + modify fun s => { s with numSplits, numEmatch := 0 } return c? | c::cs => match (← checkCaseSplitStatus c) with | .notReady => go cs c? (c::cs') | .resolved => go cs c? cs' - | .ready => + | .ready numCases isRec => match c? with - | none => go cs (some c) cs' - | some c' => - if (← getGeneration c) < (← getGeneration c') then - go cs (some c) (c'::cs') + | .none => go cs (.some c numCases isRec) cs' + | .some c' numCases' _ => + let isBetter : GoalM Bool := do + if numCases == 1 && !isRec && numCases' > 1 then + return true + if (← getGeneration c) < (← getGeneration c') then + return true + return numCases < numCases' + if (← isBetter) then + go cs (.some c numCases isRec) (c'::cs') else go cs c? (c::cs') @@ -118,9 +132,10 @@ and returns a new list of goals if successful. -/ def splitNext : GrindTactic := fun goal => do let (goals?, _) ← GoalM.run goal do - let some c ← selectNextSplit? + let .some c numCases isRec ← selectNextSplit? | return none let gen ← getGeneration c + let genNew := if numCases > 1 || isRec then gen+1 else gen trace_goal[grind.split] "{c}, generation: {gen}" let mvarIds ← if (← isMatcherApp c) then casesMatch (← get).mvarId c @@ -129,7 +144,7 @@ def splitNext : GrindTactic := fun goal => do cases (← get).mvarId major let goal ← get let goals := mvarIds.map fun mvarId => { goal with mvarId } - let goals ← introNewHyp goals [] (gen+1) + let goals ← introNewHyp goals [] genNew return some goals return goals?