Skip to content

Commit

Permalink
refactor monomorphization
Browse files Browse the repository at this point in the history
  • Loading branch information
PratherConid committed Dec 19, 2024
1 parent bf88b6e commit e6d6330
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 39 deletions.
72 changes: 34 additions & 38 deletions Auto/Translation/Monomorphization.lean
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,8 @@ structure State where
-- During initialization, we supply an array `lemmas` of lemmas
-- `liArr[i]` are instances of `lemmas[i]`.
lisArr : Array LemmaInsts := #[]
-- Definitional equalities from instance relations between `ConstInst`s
ciInstDefEqs : Array LemmaInst := #[]
-- The `Nat` in `LemmaInst × Nat` indicates the `LemmaInst`'s
-- position in ``lisArr``
active : Std.Queue (ConstInst ⊕ (LemmaInst × Nat)) := Std.Queue.empty
Expand Down Expand Up @@ -567,6 +569,8 @@ def initializeMonoM (lemmas : Array Lemma) : MonoM Unit := do
let li ← LemmaInst.ofLemmaHOL lem
trace[auto.mono.printLemmaInst] "New {li}"
return li)
for (li, idx) in lemmaInsts.zipWithIndex do
setActive ((← getActive).enqueue (.inr (li, idx)))
let lemmaInsts := lemmaInsts.map (fun x => #[x])
setLisArr lemmaInsts
for lem in lemmas do
Expand Down Expand Up @@ -605,6 +609,7 @@ def saturate : MonoM Unit := do
return
match ← dequeueActive? with
| .some (.inl ci) =>
generateCiInstDefEq ci
let lisArr ← getLisArr
trace[auto.mono.match] "Matching against {ci}"
for (lis, idx) in lisArr.zipWithIndex do
Expand Down Expand Up @@ -657,48 +662,32 @@ where
-- A new instance of an assumption
if new? then
trace[auto.mono.printLemmaInst] "New {matchLi}"
-- Attempt to instantiate instance arguments and get monomoprhic lemma instance
let matchLi := (← LemmaInst.monomorphic? matchLi).getD matchLi
newLis := newLis.push matchLi
setActive ((← getActive).enqueue (.inr (matchLi, idx)))
let newCis ← collectConstInsts matchLi.params #[] matchLi.type
for newCi in newCis do
processConstInst newCi
collectAndProcessConstInsts matchLi
return (newLis, cnt)

/-- Remove non-monomorphic lemma instances -/
def postprocessSaturate : MonoM LemmaInsts := do
let lisArr ← getLisArr
let lisArr ← liftM <| lisArr.mapM (fun lis => lis.filterMapM LemmaInst.monomorphic?)
let lis := lisArr.flatMap id
-- Since typeclasses might have been instantiated during `LemmaInst.monomorphic?`,
-- we need to run ``collectConstInst`` again. Also, this must precede
-- collecting definitional equalities related to `ConstInst`s
refreshConstInsts lis
-- Collect definitional equalities related to `ConstInst`s
-- **TODO:** Collect definitional equalities during monomorphization
-- and make uses of the `active` field. This is because new `ConstInst`s
-- might be generated during collection of definitional equalities,
-- and they may produce more definitional equalities
let mut cieqs : Array LemmaInst := #[]
let cis := ((← getCiMap).toArray.map Prod.snd).flatMap id
for (ci₁, idx₁) in cis.zipWithIndex do
for (ci₂, idx₂) in cis.zipWithIndex do
if idx₁ < idx₂ && !(isTrigger ci₁.head) && !(isTrigger ci₂.head) then
if let .some (proof, eq) ← bidirectionalOfInstanceEq ci₁ ci₂ then
/-
This `ci` comes from `active`, so it is already canonicalized
-/
generateCiInstDefEq (ci : ConstInst) : MonoM Unit := do
if isTrigger ci.head then
return
let cis := ((← getCiMap).toArray.map Prod.snd).flatMap id
for (ci', _) in cis.zipWithIndex do
if (← ci.toExpr) != (← ci'.toExpr) && !(isTrigger ci'.head) then
if let .some (proof, eq) ← bidirectionalOfInstanceEq ci ci' then
let eq := Expr.eraseMData (← Core.betaReduce eq)
let eq ← Meta.transform eq (pre := fun e => do return .continue (← unfoldProj e))
let newLi ← LemmaInst.ofLemma ⟨⟨proof, eq, .leaf "ciInstDefEq"⟩, #[]⟩
cieqs := cieqs.push newLi
trace[auto.mono.ciInstDefEq] "{eq}"
-- Since new `ConstInst`s might be produced during definitional equality
-- generation, we need to ``collectConstInst`` again
refreshConstInsts cieqs
return lis ++ cieqs
where
refreshConstInsts (lis : LemmaInsts) : MonoM Unit :=
for li in lis do
let newCis ← collectConstInsts li.params #[] li.type
for newCi in newCis do
processConstInst newCi
let newLi ← LemmaInst.ofLemma ⟨⟨proof, eq, .leaf "ciInstDefEq"⟩, #[]⟩
setCiInstDefEqs ((← getCiInstDefEqs).push newLi)
collectAndProcessConstInsts newLi
collectAndProcessConstInsts (li : LemmaInst) : MonoM Unit := do
let newCis ← collectConstInsts li.params #[] li.type
for newCi in newCis do
processConstInst newCi
bidirectionalOfInstanceEq (ci₁ ci₂ : ConstInst) : MetaM (Option (Expr × Expr)) := do
let mode := auto.mono.ciInstDefEq.mode.get (← getOptions)
Meta.withNewMCtxDepth <| Meta.withTransparency mode <| do
Expand All @@ -709,6 +698,13 @@ where
| .const name _ => name == ``SMT.Attribute.trigger
| _ => false

/-- Remove non-monomorphic lemma instances -/
def getAllMonoLemmaInsts : MonoM LemmaInsts := do
let lisArr ← getLisArr
let lisArr ← liftM <| lisArr.mapM (fun lis => lis.filterMapM LemmaInst.monomorphic?)
let lis := lisArr.flatMap id
return lis ++ (← getCiInstDefEqs)

/-- Collect inductive types -/
def collectMonoMutInds : MonoM (Array (Array SimpleIndVal)) := do
let cis := (Array.mk ((← getCiMap).toList.map Prod.snd)).flatMap id
Expand Down Expand Up @@ -968,7 +964,7 @@ def intromono (lemmas : Array Lemma) (mvarId : MVarId) : MetaM MVarId := do
let monoMAction : MonoM LemmaInsts := (do
initializeMonoM lemmas
saturate
let monoLemmas ← postprocessSaturate
let monoLemmas ← getAllMonoLemmaInsts
trace[auto.mono] "Monomorphization took {(← IO.monoMsNow) - startTime}ms"
return monoLemmas)
let (monoLemmas, _) ← monoMAction.run {}
Expand Down Expand Up @@ -998,7 +994,7 @@ where
let startTime ← IO.monoMsNow
initializeMonoM lemmas
saturate
let monoLemmas ← postprocessSaturate
let monoLemmas ← getAllMonoLemmaInsts
let monoIndVals ← collectMonoMutInds
trace[auto.mono] "Monomorphization of lemmas took {(← IO.monoMsNow) - startTime}ms"
return (monoLemmas, monoIndVals)
Expand Down
1 change: 0 additions & 1 deletion Test/Test_Regression.lean
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ section UnfoldConst
example : c₂ = 2 := by auto u[c₂] d[c₁]
example (h : c₃ = c₁) : c₃ = 2 := by auto [h] u[c₁]
example : let c := 2; c = 2 := by
try auto u[c];
auto

example : True := by auto d[Nat.rec]
Expand Down

0 comments on commit e6d6330

Please sign in to comment.