Skip to content

Commit

Permalink
done refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
PratherConid committed May 19, 2024
1 parent a8ad757 commit 7b7768b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 109 deletions.
110 changes: 12 additions & 98 deletions Auto/Translation/Lam2DAtomAsFVar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ structure State where
varVal : Array (Expr × LamSort)
lamEVarTy : Array LamSort
-- Type atoms and term atoms to be abstracted
atomsToAbstract : Array (FVarId × Expr) := #[]
atomsToAbstract : Array (FVarId × Expr) := #[]
-- Etoms to be abstracted
etomsToAbstract : Array (FVarId × Nat) := #[]
etomsToAbstract : Array (FVarId × Nat) := #[]
-- Type atoms that are used in the expressions sent to external prover
typeAtomFVars : HashMap Nat FVarId := {}
typeAtomFVars : HashMap Nat Expr := {}
-- Term atoms that are used in the expressions sent to external prover
termAtomFVars : HashMap Nat FVarId := {}
termAtomFVars : HashMap Nat Expr := {}
-- Etoms that are used in the expression sent to external prover
etomFVars : HashMap Nat FVarId := {}
etomFVars : HashMap Nat Expr := {}

abbrev ExternM := StateRefT State MetaStateM

Expand All @@ -75,122 +75,36 @@ def withTypeAtomsAsFVar (atoms : Array Nat) : ExternM Unit :=
let name := (`_exTy).appendIndexAfter (← getTypeAtomFVars).size
let newFVarId ← withLocalDecl name .default (.sort lvl) .default
setAtomsToAbstract ((← getAtomsToAbstract).push (newFVarId, e))
setTypeAtomFVars ((← getTypeAtomFVars).insert atom newFVarId)

/--
Takes a `s : LamSort` and produces the `un-lifted` version of `s.interp`
(note that `s.interp` is lifted)
This function should be called after we've called
`withTypeAtomAsFVar` on all the type atoms occurring in `s`
-/
def interpLamSortAsUnlifted : LamSort → ExternM Expr
| .atom n => do
let .some fid := (← getTypeAtomFVars).find? n
| throwError "interpLamSortAsUnlifted :: Cannot find fvarId assigned to type atom {n}"
return .fvar fid
| .base b => return Lam2D.interpLamBaseSortAsUnlifted b
| .func s₁ s₂ => do
return .forallE `_ (← interpLamSortAsUnlifted s₁) (← interpLamSortAsUnlifted s₂) .default
setTypeAtomFVars ((← getTypeAtomFVars).insert atom (.fvar newFVarId))

def withTermAtomsAsFVar (atoms : Array Nat) : ExternM Unit :=
for atom in atoms do
if (← getTermAtomFVars).contains atom then
continue
let .some (e, s) := (← getVarVal)[atom]?
| throwError "withTermAtomAsFVar :: Unknown term atom {atom}"
let sinterp ← interpLamSortAsUnlifted s
let sinterp ← Lam2D.interpLamSortAsUnlifted (← getTypeAtomFVars) s
let name := (`e!).appendIndexAfter (← getTermAtomFVars).size
let newFVarId ← withLocalDecl name .default sinterp .default
setAtomsToAbstract ((← getAtomsToAbstract).push (newFVarId, e))
setTermAtomFVars ((← getTermAtomFVars).insert atom newFVarId)
setTermAtomFVars ((← getTermAtomFVars).insert atom (.fvar newFVarId))

def withEtomsAsFVar (etoms : Array Nat) : ExternM Unit :=
for etom in etoms do
if (← getEtomFVars).contains etom then
return
let .some s := (← getLamEVarTy)[etom]?
| throwError "withEtomAsFVar :: Unknown etom {etom}"
let sinterp ← interpLamSortAsUnlifted s
let sinterp ← Lam2D.interpLamSortAsUnlifted (← getTypeAtomFVars) s
let name := (`e?).appendIndexAfter (← getEtomFVars).size
let newFVarId ← withLocalDecl name .default sinterp .default
setEtomsToAbstract ((← getEtomsToAbstract).push (newFVarId, etom))
setEtomFVars ((← getEtomFVars).insert etom newFVarId)

open Embedding in
def interpOtherConstAsUnlifted (oc : OtherConst) : ExternM Expr := do
let .some (.defnInfo constIdVal) := (← getEnv).find? ``constId
| throwError "interpOtherConstAsUnlifted :: Unexpected error"
let constIdExpr := fun params => constIdVal.value.instantiateLevelParams constIdVal.levelParams params
match oc with
| .smtAttr1T _ sattr sterm => do
let tyattr ← interpLamSortAsUnlifted sattr
let sortattr ← runMetaM <| Expr.normalizeType (← MetaState.inferType tyattr)
let Expr.sort lvlattr := sortattr
| throwError "interpOtherConstAsUnlifted :: Unexpected sort {sortattr}"
let tyterm ← interpLamSortAsUnlifted sterm
let sortterm ← runMetaM <| Expr.normalizeType (← MetaState.inferType tyterm)
let Expr.sort lvlterm := sortterm
| throwError "interpOtherConstAsUnlifted :: Unexpected sort {sortterm}"
return Lean.mkApp2 (constIdExpr [lvlattr, lvlterm]) tyattr tyterm

open Embedding in
def interpLamBaseTermAsUnlifted : LamBaseTerm → ExternM Expr
| .pcst pc => Lam2D.interpPropConstAsUnlifted pc
| .bcst bc => return Lam2D.interpBoolConstAsUnlifted bc
| .ncst nc => return Lam2D.interpNatConstAsUnlifted nc
| .icst ic => return Lam2D.interpIntConstAsUnlifted ic
| .scst sc => return Lam2D.interpStringConstAsUnlifted sc
| .bvcst bvc => return Lam2D.interpBitVecConstAsUnlifted bvc
| .ocst oc => interpOtherConstAsUnlifted oc
| .eqI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog)
| .forallEI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog)
| .existEI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog)
| .iteI _ => throwError ("interpLamTermAsUnlifted :: " ++ exportError.ImpPolyLog)
| .eq s => do
return ← runMetaM <| Meta.mkAppOptM ``Eq #[← interpLamSortAsUnlifted s]
| .forallE s => do
let ty ← interpLamSortAsUnlifted s
let sort ← runMetaM <| Expr.normalizeType (← MetaState.inferType ty)
let Expr.sort lvl := sort
| throwError "interpLamBaseTermAsUnlifted :: Unexpected sort {sort}"
let .some (.defnInfo forallVal) := (← getEnv).find? ``forallF
| throwError "interpLamBaseTermAsUnlifted :: Unexpected error"
let forallFExpr := forallVal.value.instantiateLevelParams forallVal.levelParams [lvl, .zero]
return mkAppN forallFExpr #[← interpLamSortAsUnlifted s]
| .existE s => do
return ← runMetaM <| Meta.mkAppOptM ``Exists #[← interpLamSortAsUnlifted s]
| .ite s => do
return ← runMetaM <| Meta.mkAppOptM ``Bool.ite' #[← interpLamSortAsUnlifted s]

/--
Takes a `t : LamTerm` and produces the `un-lifted` version of `t.interp`.
This function should be called after we've called
`withTermAtomAsFVar` on all the term atoms occurring in `t`
`lctx` is for pretty printing
-/
def interpLamTermAsUnlifted (lctx : Nat) : LamTerm → ExternM Expr
| .atom n => do
let .some fid := (← getTermAtomFVars).find? n
| throwError "interpLamTermAsUnlifted :: Cannot find fvarId assigned to term atom {n}"
return .fvar fid
| .etom n => do
let .some fid := (← getEtomFVars).find? n
| throwError "interpLamSortAsUnlifted :: Cannot find fvarId assigned to etom {n}"
return .fvar fid
| .base b => interpLamBaseTermAsUnlifted b
| .bvar n => return .bvar n
| .lam s t => do
let sinterp ← interpLamSortAsUnlifted s
let tinterp ← interpLamTermAsUnlifted lctx.succ t
let name := (`eb!).appendIndexAfter lctx
return .lam name sinterp tinterp .default
| .app _ fn arg => do
return .app (← interpLamTermAsUnlifted lctx fn) (← interpLamTermAsUnlifted lctx arg)
setEtomFVars ((← getEtomFVars).insert etom (.fvar newFVarId))

def withTranslatedLamSorts (ss : Array LamSort) : ExternM (Array Expr) := do
let typeHs := collectLamSortsAtoms ss
withTypeAtomsAsFVar typeHs.toArray
ss.mapM interpLamSortAsUnlifted
ss.mapM (m:=CoreM) (Lam2D.interpLamSortAsUnlifted (← getTypeAtomFVars))

/--
The external prover should only see the local context
Expand All @@ -204,7 +118,7 @@ def withTranslatedLamTerms (ts : Array LamTerm) : ExternM (Array Expr) := do
withTypeAtomsAsFVar typeHs.toArray
withTermAtomsAsFVar termHs.toArray
withEtomsAsFVar etomHs.toArray
ts.mapM (interpLamTermAsUnlifted 0)
MetaState.runMetaM <| ts.mapM (Lam2D.interpLamTermAsUnlifted (← getTypeAtomFVars) (← getTermAtomFVars) (← getEtomFVars) 0)

/--
Given a list of non-dependent types `ty₁, ty₂, ⋯, tyₙ`, add
Expand Down
26 changes: 15 additions & 11 deletions Auto/Translation/LamUtils.lean
Original file line number Diff line number Diff line change
Expand Up @@ -303,16 +303,16 @@ namespace Lam2D
Takes a `s : LamSort` and produces the `un-lifted` version of `s.interp`
(note that `s.interp` is lifted)
-/
def interpLamSortAsUnlifted (tyVal : Array (Expr × Level)) : LamSort → CoreM Expr
def interpLamSortAsUnlifted (tyVal : HashMap Nat Expr) : LamSort → CoreM Expr
| .atom n => do
let .some (e, _) := tyVal[n]?
let .some e := tyVal.find? n
| throwError "interpLamSortAsUnlifted :: Cannot find fvarId assigned to type atom {n}"
return e
| .base b => return Lam2D.interpLamBaseSortAsUnlifted b
| .func s₁ s₂ => do
return .forallE `_ (← interpLamSortAsUnlifted tyVal s₁) (← interpLamSortAsUnlifted tyVal s₂) .default

def interpOtherConstAsUnlifted (tyVal : Array (Expr × Level)) (oc : OtherConst) : MetaM Expr := do
def interpOtherConstAsUnlifted (tyVal : HashMap Nat Expr) (oc : OtherConst) : MetaM Expr := do
let .some (.defnInfo constIdVal) := (← getEnv).find? ``constId
| throwError "interpOtherConstAsUnlifted :: Unexpected error"
let constIdExpr := fun params => constIdVal.value.instantiateLevelParams constIdVal.levelParams params
Expand All @@ -328,7 +328,7 @@ namespace Lam2D
| throwError "interpOtherConstAsUnlifted :: Unexpected sort {sortterm}"
return Lean.mkApp2 (constIdExpr [lvlattr, lvlterm]) tyattr tyterm

def interpLamBaseTermAsUnlifted (tyVal : Array (Expr × Level)) : LamBaseTerm → MetaM Expr
def interpLamBaseTermAsUnlifted (tyVal : HashMap Nat Expr) : LamBaseTerm → MetaM Expr
| .pcst pc => Lam2D.interpPropConstAsUnlifted pc
| .bcst bc => return Lam2D.interpBoolConstAsUnlifted bc
| .ncst nc => return Lam2D.interpNatConstAsUnlifted nc
Expand Down Expand Up @@ -359,28 +359,32 @@ namespace Lam2D
/--
Takes a `t : LamTerm` and produces the `un-lifted` version of `t.interp`.
`lctx` is for pretty printing
Note that `etom`s generated by the verified checker do not directly correspond
to Lean expressions. Therefore, we need to introduce new free variables to
represent `etom`s.
-/
def interpLamTermAsUnlifted
(tyVal : Array (Expr × Level)) (varVal : Array (Expr × LamSort)) (etomFVars : HashMap Nat FVarId)
(tyVal : HashMap Nat Expr) (varVal : HashMap Nat Expr) (etomVal : HashMap Nat Expr)
(lctx : Nat) : LamTerm → MetaM Expr
| .atom n => do
let .some (e, _) := varVal[n]?
let .some e := varVal.find? n
| throwError "interpLamTermAsUnlifted :: Cannot find fvarId assigned to term atom {n}"
return e
| .etom n => do
let .some fid := etomFVars.find? n
let .some efvar := etomVal.find? n
| throwError "interpLamSortAsUnlifted :: Cannot find fvarId assigned to etom {n}"
return .fvar fid
return efvar
| .base b => interpLamBaseTermAsUnlifted tyVal b
| .bvar n => return .bvar n
| .lam s t => do
let sinterp ← interpLamSortAsUnlifted tyVal s
let tinterp ← interpLamTermAsUnlifted tyVal varVal etomFVars lctx.succ t
let tinterp ← interpLamTermAsUnlifted tyVal varVal etomVal lctx.succ t
let name := (`eb!).appendIndexAfter lctx
return .lam name sinterp tinterp .default
| .app _ fn arg => do
let fninterp ← interpLamTermAsUnlifted tyVal varVal etomFVars lctx fn
let arginterp ← interpLamTermAsUnlifted tyVal varVal etomFVars lctx arg
let fninterp ← interpLamTermAsUnlifted tyVal varVal etomVal lctx fn
let arginterp ← interpLamTermAsUnlifted tyVal varVal etomVal lctx arg
return .app fninterp arginterp

end Lam2D

0 comments on commit 7b7768b

Please sign in to comment.