Skip to content

Commit

Permalink
fix monomorphization saturation bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PratherConid committed Dec 16, 2024
1 parent 150c660 commit f7e1466
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 44 deletions.
120 changes: 76 additions & 44 deletions Auto/Translation/Monomorphization.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ open Lean
initialize
registerTraceClass `auto.mono
registerTraceClass `auto.mono.match
registerTraceClass `auto.mono.match.verbose
registerTraceClass `auto.mono.printLemmaInst
registerTraceClass `auto.mono.printConstInst
registerTraceClass `auto.mono.printResult
registerTraceClass `auto.mono.printInputLemmas

register_option auto.mono.saturationThreshold : Nat := {
defValue := 250
descr := "Threshold for number of potentially new lemma" ++
" instances generated during the saturation loop of monomorphization"
defValue := 1024
descr := "Threshold for number of matches performed" ++
" during the saturation loop of monomorphization"
}

register_option auto.mono.recordInstInst : Bool := {
Expand Down Expand Up @@ -418,7 +419,16 @@ def LemmaInst.matchConstInst (ci : ConstInst) (li : LemmaInst) : MetaM (Std.Hash
let (lmvars, mvars, mi) ← MLemmaInst.ofLemmaInst li
if lmvars.size == 0 && mvars.size == 0 then
return Std.HashSet.empty
MLemmaInst.matchConstInst ci mi mi.type
-- Match with `b` in `∀ (x₁ : α₁) ⋯ (xₙ : αₙ). b := li.type`
let mut ret ← MLemmaInst.matchConstInst ci mi mi.type
-- Match with `α₁ ⋯ αₙ` in `∀ (x₁ : α₁) ⋯ (xₙ : αₙ). b := li.type`
for mvar in mvars do
let .mvar m := mvar
| throwError "{decl_name%} :: Unexpected error"
let mtype ← m.getType
if ← Meta.isProp mtype then
ret := mergeHashSet ret (← MLemmaInst.matchConstInst ci mi (← m.getType))
return ret

/--
Check whether the leading `∀` quantifiers of expression `e`
Expand Down Expand Up @@ -484,7 +494,7 @@ def LemmaInst.monomorphic? (li : LemmaInst) : MetaM (Option LemmaInst) := do
valid instances of constants (dependent arguments fully
instantiated). They form the initial elements of `ciMap`
and `activeCi`
(3) Repeat:
(3) Repeat: **TODO**
· Dequeue an element `(name, n)` from `activeCi`
· For each element `ais : LemmaInsts` in `liArr`,
for each expression `e` in `ais`, traverse `e` to
Expand All @@ -501,12 +511,13 @@ def LemmaInst.monomorphic? (li : LemmaInst) : MetaM (Option LemmaInst) := do
-/
structure State where
-- The `Expr` is the fingerprint of the `ConstInst`
ciMap : Std.HashMap Expr ConstInsts := {}
-- The `Expr` is the fingerprint of the `ConstInst`
activeCi : Std.Queue (Expr × Nat) := Std.Queue.empty
ciMap : Std.HashMap Expr ConstInsts := {}
-- During initialization, we supply an array `lemmas` of lemmas
-- `liArr[i]` are instances of `lemmas[i]`.
lisArr : Array LemmaInsts := #[]
lisArr : Array LemmaInsts := #[]
-- The `Nat` in `LemmaInst × Nat` indicates the `LemmaInst`'s
-- position in ``lisArr``
active : Std.Queue (ConstInst ⊕ (LemmaInst × Nat)) := Std.Queue.empty

abbrev MonoM := StateRefT State MetaM

Expand Down Expand Up @@ -537,19 +548,8 @@ def processConstInst (ci : ConstInst) : MonoM Unit := do
return
trace[auto.mono.printConstInst] "New {ci}"
setCiMap ((← getCiMap).insert ci.fingerPrint (insts.push ci))
-- Do not match against ConstInsts that do not have dependent or
-- instance arguments
if ci.argsIdx.size == 0 then
return
-- Do not match against `=` and `∃`
-- If some polymorphic argument of the a theorem only occurs
-- as the first argument of `=` or `∃`, the theorem is probably
-- implied by the axioms of higher order logic, e.g.
-- `Eq.trans : ∀ {α} (x y z : α), x = y → y = z → x = z`
if ci.head.isNamedConst ``Exists || ci.head.isNamedConst ``Eq then
return
-- Insert `ci` into `activeCi` so that we can later match on it
setActiveCi ((← getActiveCi).enqueue (ci.fingerPrint, insts.size))
setActive ((← getActive).enqueue (.inl ci))

def initializeMonoM (lemmas : Array Lemma) : MonoM Unit := do
let lemmaInsts ← liftM <| lemmas.mapM (fun lem => do
Expand All @@ -563,10 +563,10 @@ def initializeMonoM (lemmas : Array Lemma) : MonoM Unit := do
for ci in cis do
processConstInst ci

def dequeueActiveCi? : MonoM (Option (Expr × Nat)) := do
match (← getActiveCi).dequeue? with
| .some (elem, ci') =>
setActiveCi ci'
def dequeueActive? : MonoM (Option (ConstInst ⊕ (LemmaInst × Nat))) := do
match (← getActive).dequeue? with
| .some (elem, ac') =>
setActive ac'
return .some elem
| .none => return .none

Expand All @@ -591,34 +591,66 @@ def saturate : MonoM Unit := do
cnt := cnt + 1
if (← saturationThresholdReached? cnt) then
return
match ← dequeueActiveCi? with
| .some (name, cisIdx) =>
let ci ← lookupActiveCi! name cisIdx
match ← dequeueActive? with
| .some (.inl ci) =>
let lisArr ← getLisArr
trace[auto.mono.match] "Matching against {ci}"
for (lis, idx) in lisArr.zipWithIndex do
cnt := cnt + 1
let mut newLis := lis
for li in lis do
cnt := cnt + 1
let matchLis := (← LemmaInst.matchConstInst ci li).toArray
for matchLi in matchLis do
-- `matchLi` is a result of matching a subterm of `li` against `ci`
cnt := cnt + 1
if (← saturationThresholdReached? cnt) then
return
let new? ← newLis.newInst? matchLi
-- A new instance of an assumption
if new? then
trace[auto.mono.printLemmaInst] "New {matchLi}"
newLis := newLis.push matchLi
let newCis ← collectConstInsts matchLi.params #[] matchLi.type
for newCi in newCis do
processConstInst newCi
let newLis_cnt ← matchCiAndLi ci li idx cnt
let newLis := newLis_cnt.fst
setLisArr ((← getLisArr).set! idx newLis)
cnt := newLis_cnt.snd
if (← saturationThresholdReached? cnt) then
return
| .some (.inr (li, idx)) =>
trace[auto.mono.match] "Matching against {li}"
let cis := ((← getCiMap).toArray.map Prod.snd).concatMap id
for ci in cis do
cnt := cnt + 1
let newLis_cnt ← matchCiAndLi ci li idx cnt
let newLis := newLis_cnt.fst
setLisArr ((← getLisArr).set! idx newLis)
cnt := newLis_cnt.snd
if (← saturationThresholdReached? cnt) then
return
| .none =>
trace[auto.mono] "Monomorphization Saturated after {cnt} small steps"
return
where
matchCiAndLi (ci : ConstInst) (li : LemmaInst) (idx : Nat) (cnt : Nat) :
MonoM (LemmaInsts × Nat) := do
let mut cnt := cnt
let mut newLis := (← getLisArr)[idx]!
-- Do not match against ConstInsts that have no dependent or instance arguments
if ci.argsIdx.size == 0 then
return (newLis, cnt)
-- Do not match against `=` and `∃`
-- If some polymorphic argument of the a theorem only occurs
-- as the first argument of `=` or `∃`, the theorem is probably
-- implied by the axioms of higher order logic, e.g.
-- `Eq.trans : ∀ {α} (x y z : α), x = y → y = z → x = z`
if ci.head.isNamedConst ``Exists || ci.head.isNamedConst ``Eq then
return (newLis, cnt)
trace[auto.mono.match.verbose] "Matching {ci} against {li}"
cnt := cnt + 1
let matchLis := (← LemmaInst.matchConstInst ci li).toArray
for matchLi in matchLis do
-- `matchLi` is a result of matching a subterm of `li` against `ci`
cnt := cnt + 1
if (← saturationThresholdReached? cnt) then
return (newLis, cnt)
let new? ← newLis.newInst? matchLi
-- A new instance of an assumption
if new? then
trace[auto.mono.printLemmaInst] "New {matchLi}"
newLis := newLis.push matchLi
setActive ((← getActive).enqueue (.inr (matchLi, idx)))
let newCis ← collectConstInsts matchLi.params #[] matchLi.type
for newCi in newCis do
processConstInst newCi
return (newLis, cnt)

/-- Remove non-monomorphic lemma instances -/
def postprocessSaturate : MonoM Unit := do
Expand Down
30 changes: 30 additions & 0 deletions Test/Test_Regression.lean
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,36 @@ example
@hap (List α) (List α) (List α) instHAppend as (@hap (List α) (List α) (List α) instHAppend bs (@hap (List α) (List α) (List α) instHAppend cs ds)) := by
auto [ap_assoc]

-- Matching with leading propositional ∀ quantifiers

example
(p : ∀ (α : Type), List α → Prop)
(h1 : ∀ α x, p α x → q)
(h2 : p A x) : q := by
auto

example
(p q : ∀ (α : Type), List α → Prop)
(h1 : ∀ α β x y, p α x → q β y → r)
(h2 : p A x)
(h3 : q B y) : r := by
auto

-- One LemmaInst match multiple ConstInst

example
(p1 p2 : ∀ (α : Type), List α → Prop)
(h1 : ∀ α β x y, p1 α x → p2 β y)
(h2 : p1 A x) : p2 B y := by
auto

example
(p1 p2 p3 : ∀ (α β : Type), List α → List β → Prop)
(h1 : ∀ α β γ δ ε π x y z t u v, p1 α β x y → p2 γ δ z t → p3 ε π u v)
(h2 : p1 A B x y)
(h3 : p2 C D z t) : p3 E F u v := by
auto

-- Metavariable

example (u : α) (h : ∀ (z : α), x = z ∧ z = y) : x = y := by
Expand Down

0 comments on commit f7e1466

Please sign in to comment.