Skip to content

Commit

Permalink
Merge main into duper2
Browse files Browse the repository at this point in the history
  • Loading branch information
JOSHCLUNE committed Dec 10, 2024
2 parents 9158025 + d437b8d commit 3357e49
Show file tree
Hide file tree
Showing 50 changed files with 1,065 additions and 549 deletions.
3 changes: 2 additions & 1 deletion Auto.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Auto.Tactic
import Auto.EvaluateAuto.TestCode

def hello := "world"
def hello := "world"
12 changes: 6 additions & 6 deletions Auto/Embedding/Lam2Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def Lam₂Type.check_iff_interp
simp [check, interp]
cases Nat.decLt n (List.length lctx)
case isTrue h =>
simp [h]; simp [List.get?_eq_get h]
simp [h]
case isFalse h =>
simp [h]; simp at h; simp [List.get?_len_le h]
simp [h, List.getElem?_eq]
case func fn arg IHfn IHarg =>
revert IHfn IHarg
simp [check, interp]
Expand All @@ -134,9 +134,9 @@ def Lam₂Type.check_iff_interp
simp;
match cifn : interp val lctx fn, ciarg : interp val lctx arg with
| .some ⟨0, _⟩, .some ⟨0, _⟩ => simp
| .some ⟨0, _⟩, .some ⟨n + 1, _⟩ => simp; simp_arith
| .some ⟨0, _⟩, .some ⟨n + 1, _⟩ => simp
| .some ⟨0, _⟩, .none => simp
| .some ⟨n + 1, _⟩, _ => simp; simp_arith
| .some ⟨n + 1, _⟩, _ => simp
| .none , _ => simp
| .some 0, .some (n + 1) =>
simp;
Expand Down Expand Up @@ -168,7 +168,7 @@ def Lam₂Type.check_iff_interp
simp;
match cifn : interp val lctx fn, ciarg : interp val lctx arg with
| .some ⟨n + 1, _⟩, .some ⟨0, _⟩ => simp
| .some ⟨n + 1, _⟩, .some ⟨m + 1, _⟩ => simp; simp_arith
| .some ⟨n + 1, _⟩, .some ⟨m + 1, _⟩ => simp
| .some ⟨n + 1, _⟩, .none => simp
| .some ⟨0, _⟩, _ => simp
| .none , _ => simp
Expand All @@ -185,7 +185,7 @@ def Lam₂Type.check_iff_interp
| .some 0, _ =>
simp;
match cifn : interp val lctx fn with
| .some ⟨n + 1, _⟩ => simp; simp_arith
| .some ⟨n + 1, _⟩ => simp
| .some ⟨0, _⟩ => simp
| .none => simp
| .none, _ =>
Expand Down
2 changes: 1 addition & 1 deletion Auto/Embedding/LamBase.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3781,7 +3781,7 @@ def LamWF.bvarApps
rw [List.length_reverse]; apply Nat.le_refl
dsimp [lctxr]; rw [← List.reverseAux, List.reverseAux_eq]
rw [pushLCtxs_lt (by rw [List.length_append]; apply Nat.le_trans exlt (Nat.le_add_right _ _))]
rw [List.getD_eq_getElem?_getD]; rw [List.getElem?_append];
rw [List.getD_eq_getElem?_getD]; rw [List.getElem?_append_left exlt];
rw [List.getElem?_reverse (by dsimp [List.length]; apply Nat.le_refl _)]
dsimp [List.length]; simp
conv => enter [2, 3]; rw [tyeq]
Expand Down
14 changes: 6 additions & 8 deletions Auto/Embedding/LamBitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,10 @@ namespace BVLems
unfold BitVec.ushiftRight BitVec.toNat
dsimp; rw [Nat.shiftRight_eq_div_pow]

theorem toNat_setWidth {a : BitVec n} (le : n ≤ m) : (a.setWidth m).toNat = a.toNat := by
unfold setWidth
simp only [le, ↓reduceDIte, toNat_setWidth']
theorem toNat_zeroExtend' {a : BitVec n} (le : n ≤ m) : (a.setWidth' le).toNat = a.toNat := rfl

theorem toNat_zeroExtend {a : BitVec n} (i : Nat) : (a.zeroExtend i).toNat = a.toNat % (2 ^ i) := by
unfold BitVec.zeroExtend setWidth; cases hdec : decide (n ≤ i)
unfold BitVec.zeroExtend BitVec.setWidth; cases hdec : decide (n ≤ i)
case false =>
have hnle := of_decide_eq_false hdec
rw [Bool.dite_eq_false (proof:=hnle)]; rfl
Expand Down Expand Up @@ -103,9 +101,9 @@ namespace BVLems
apply Nat.le_trans (toNat_le _) (Nat.pow_le_pow_of_le_right (.step .refl) h)

theorem msb_equiv_lt (a : BitVec n) : !a.msb ↔ a.toNat < 2 ^ (n - 1) := by
dsimp [BitVec.msb, getMsbD, getLsbD]
dsimp [BitVec.msb, BitVec.getMsbD, BitVec.getLsbD]
cases n
case zero => simp [BitVec.toNat]
case zero => cases a <;> simp
case succ n =>
have dtrue : decide (0 < n + 1) = true := by simp
rw [dtrue, Bool.not_eq_true', Bool.true_and, Nat.succ_sub_one, Nat.testBit_false_iff]
Expand All @@ -114,7 +112,7 @@ namespace BVLems
theorem msb_equiv_lt' (a : BitVec n) : !a.msb ↔ 2 * a.toNat < 2 ^ n := by
rw [msb_equiv_lt]
cases n
case zero => simp [BitVec.toNat]
case zero => cases a <;> simp
case succ n =>
rw [Nat.succ_sub_one, Nat.pow_succ, Nat.mul_comm (m:=2)]
apply Iff.symm; apply Nat.mul_lt_mul_left
Expand All @@ -135,7 +133,7 @@ namespace BVLems
rw [Nat.sub_one, Nat.pred_lt_iff_le (Nat.two_pow_pos _)]
apply Nat.le_trans (Nat.sub_le _ _) (Nat.pow_le_pow_of_le_right (.step .refl) h)
apply eq_of_val_eq; rw [toNat_ofNatLt, hzero]
rw [toNat_neg, Int.mod_def', Int.emod];
rw [toNat_neg, Int.mod_def', Int.emod]
rw [Nat.zero_mod, Int.natAbs_ofNat, Nat.succ_eq_add_one, Nat.zero_add]
rw [Int.subNatNat_of_sub_eq_zero ((Nat.sub_eq_zero_iff_le).mpr (Nat.two_pow_pos _))]
rw [Int.toNat_ofNat, BitVec.toNat_ofNat]
Expand Down
87 changes: 48 additions & 39 deletions Auto/Embedding/LamConv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1351,11 +1351,13 @@ def LamTerm.betaBounded (n : Nat) (t : LamTerm) :=
| .bvar _ => t
| .lam s t => .lam s (t.betaBounded n')
| .app .. =>
let tb := t.headBetaBounded n'
let fn := tb.getAppFn
let args := tb.getAppArgs
let argsb := args.map (fun ((s, arg) : LamSort × _) => (s, betaBounded n' arg))
LamTerm.mkAppN fn argsb
match t.isHeadBetaTarget with
| true => LamTerm.betaBounded n' (t.headBetaBounded n')
| false =>
let fn := t.getAppFn
let args := t.getAppArgs
let argsb := args.map (fun ((s, arg) : LamSort × _) => (s, betaBounded n' arg))
LamTerm.mkAppN fn argsb

theorem LamTerm.maxEVarSucc_betaBounded :
(LamTerm.betaBounded n t).maxEVarSucc ≤ t.maxEVarSucc := by
Expand All @@ -1366,16 +1368,20 @@ theorem LamTerm.maxEVarSucc_betaBounded :
case lam s t => apply IH
case app s fn arg =>
dsimp [betaBounded, maxEVarSucc]
apply LamTerm.maxEVarSucc_mkAppN
case hs =>
apply HList.toMapTy; dsimp [Function.comp]
apply HList.map _ LamTerm.maxEVarSucc_getAppArgs
intro a; cases a; dsimp; intro h
apply Nat.le_trans _ (Nat.le_trans h _)
apply IH; apply maxEVarSucc_headBetaBounded
case ht =>
apply Nat.le_trans maxEVarSucc_getAppFn
apply maxEVarSucc_headBetaBounded
cases (app s fn arg).isHeadBetaTarget
case true =>
apply Nat.le_trans IH
apply Nat.le_trans LamTerm.maxEVarSucc_headBetaBounded (Nat.le_refl _)
case false =>
apply LamTerm.maxEVarSucc_mkAppN
case hs =>
apply HList.toMapTy; dsimp [Function.comp]
apply HList.map _ LamTerm.maxEVarSucc_getAppArgs
intro a; cases a; dsimp; intro h
apply Nat.le_trans _ (Nat.le_trans h _)
apply IH; exact Nat.le_refl _
case ht =>
apply Nat.le_trans maxEVarSucc_getAppFn (Nat.le_refl _)

def LamTerm.betaReduced (t : LamTerm) :=
match t with
Expand All @@ -1395,30 +1401,33 @@ theorem LamEquiv.ofBetaBounded
match wf with
| .ofLam _ wf => apply LamEquiv.ofLam; apply IH wf
case app s fn arg =>
dsimp;
have ⟨_, ⟨wfhbb, _⟩⟩ := LamEquiv.ofHeadBetaBounded (n:=n) wf
apply LamEquiv.trans (LamEquiv.ofHeadBetaBounded (n:=n) wf)
apply LamEquiv.trans (LamEquiv.eq wfhbb (LamTerm.appFn_appArg_eq _))
let masterArr := (LamTerm.getAppArgs (LamTerm.headBetaBounded n (.app s fn arg))).map (fun (s, arg) => (s, arg, arg.betaBounded n))
have eq₁ : (LamTerm.getAppArgs (LamTerm.headBetaBounded n (.app s fn arg))) = masterArr.map (fun (s, arg₁, _) => (s, arg₁)) := by
dsimp; rw [List.map_map]; rw [List.map_equiv _ id, List.map_id]
intro x; cases x; rfl
have eq₂ : List.map
(fun x => (x.fst, LamTerm.betaBounded n x.snd))
(LamTerm.getAppArgs (LamTerm.headBetaBounded n (.app s fn arg))) = masterArr.map (fun (s, _, arg₂) => (s, arg₂)) := by
dsimp; rw [List.map_map]; apply List.map_equiv;
intro x; cases x; rfl
rw [eq₂, eq₁]; have ⟨fnTy, wfFn⟩ := wfhbb.getAppFn
apply LamEquiv.congrs (fnTy:=fnTy)
case wfApp => rw [← eq₁, ← LamTerm.appFn_appArg_eq]; exact wfhbb
case hFn => apply LamEquiv.refl wfFn
case hArgs =>
dsimp;
apply HList.toMapTy; dsimp [Function.comp]
apply HList.map
(β:=fun (s, t) => LamWF lval.toLamTyVal ⟨lctx, t, s⟩)
(fun (s, t) => @IH lctx t s)
apply LamWF.getAppArgs wfhbb
cases (LamTerm.app s fn arg).isHeadBetaTarget <;> dsimp
case true =>
apply LamEquiv.trans (LamEquiv.ofHeadBetaBounded (n:=n) wf)
have ⟨_, ⟨wfhbb, _⟩⟩ := LamEquiv.ofHeadBetaBounded (n:=n) wf
apply IH wfhbb
case false =>
apply LamEquiv.trans (LamEquiv.eq wf (LamTerm.appFn_appArg_eq _))
let masterArr := (LamTerm.getAppArgs (.app s fn arg)).map (fun (s, arg) => (s, arg, arg.betaBounded n))
have eq₁ : (LamTerm.getAppArgs (.app s fn arg)) = masterArr.map (fun (s, arg₁, _) => (s, arg₁)) := by
dsimp; rw [List.map_map]; rw [List.map_equiv _ id, List.map_id]
intro x; cases x; rfl
have eq₂ : List.map
(fun x => (x.fst, LamTerm.betaBounded n x.snd))
(LamTerm.getAppArgs (.app s fn arg)) = masterArr.map (fun (s, _, arg₂) => (s, arg₂)) := by
dsimp; rw [List.map_map]; apply List.map_equiv;
intro x; cases x; rfl
rw [eq₂, eq₁]; have ⟨fnTy, wfFn⟩ := wf.getAppFn
apply LamEquiv.congrs (fnTy:=fnTy)
case wfApp => rw [← eq₁, ← LamTerm.appFn_appArg_eq]; exact wf
case hFn => apply LamEquiv.refl wfFn
case hArgs =>
dsimp;
apply HList.toMapTy; dsimp [Function.comp]
apply HList.map
(β:=fun (s, t) => LamWF lval.toLamTyVal ⟨lctx, t, s⟩)
(fun (s, t) => @IH lctx t s)
apply LamWF.getAppArgs wf

theorem LamThmEquiv.ofBetaBounded (wf : LamThmWF lval lctx rty t) :
LamThmEquiv lval lctx rty t (t.betaBounded n) := fun lctx => LamEquiv.ofBetaBounded (wf lctx)
Expand Down
8 changes: 4 additions & 4 deletions Auto/Embedding/LamInhReasoning.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ namespace Auto.Embedding.Lam
def Inhabitation.subsumeQuick (s₁ s₂ : LamSort) : Bool :=
let s₁args := s₁.getArgTys
let s₁res := s₁.getResTy
let s₂args := HashSet.empty.insertMany s₂.getArgTys
let s₂args := Std.HashSet.empty.insertMany s₂.getArgTys
let s₂res := s₂.getResTy
s₂args.contains s₂res || (s₁res == s₂res && (
s₁args.all (fun arg => s₂args.contains arg)
))

/-- Run a quick test on whether the inhabitation of `s` is trivial -/
def Inhabitation.trivialQuick := go HashSet.empty
where go (argTys : HashSet LamSort) (s : LamSort) : Bool :=
def Inhabitation.trivialQuick := go Std.HashSet.empty
where go (argTys : Std.HashSet LamSort) (s : LamSort) : Bool :=
match s with
| .func argTy resTy =>
argTys.contains s || go (argTys.insert argTy) resTy
| _ => argTys.contains s

end Auto.Embedding.Lam
end Auto.Embedding.Lam
2 changes: 1 addition & 1 deletion Auto/Embedding/LamSystem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ theorem LamTerm.rwGenAllWith_lam : rwGenAllWith conv rty (.lam s body) =
| .none =>
match rty with
| .func _ resTy => (rwGenAllWith conv resTy body).bind (LamTerm.lam s ·)
| _ => .none := by delta rwGenAllWith; simp only
| _ => .none := by cases rty <;> simp [rwGenAllWith]

theorem LamTerm.rwGenAllWith_app : rwGenAllWith conv rty (.app s fn arg) =
match conv rty (.app s fn arg) with
Expand Down
24 changes: 12 additions & 12 deletions Auto/Embedding/LamTermInterp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ theorem LamTerm.lamCheck?Eq'_bvar
(h : lctx.get? n = .some ⟨s, val⟩) :
LamTerm.lamCheck?Eq' lval lctx (.bvar n) s := by
dsimp [lamCheck?Eq', lamCheck?]; have ⟨hlt, _⟩ := List.get?_eq_some.mp h
rw [pushLCtxs_lt hlt, List.getD_eq_get?, h]; rfl
rw [pushLCtxs_lt hlt, List.getD_eq_getElem?_getD, ← List.get?_eq_getElem?, h]; rfl

theorem LamTerm.lamCheck?Eq'_lam
(h : LamTerm.lamCheck?Eq' lval (⟨argTy, val⟩ :: lctx) body s) :
Expand Down Expand Up @@ -201,7 +201,7 @@ theorem LamTerm.interpEq_bvar
(s : LamSort) (val : s.interp lval.tyVal) (h : lctx.get? n = .some ⟨s, val⟩) :
LamTerm.interpEq lval lctx (.bvar n) val := by
dsimp [interpEq, interp]; have ⟨hlt, _⟩ := List.get?_eq_some.mp h
rw [pushLCtxs_lt hlt, List.getD_eq_get?, h]; rfl
rw [pushLCtxs_lt hlt, List.getD_eq_getElem?_getD, ← List.get?_eq_getElem?, h]; rfl

theorem LamTerm.interpEq_lam
(lval : LamValuation) (lctx : List ((s : LamSort) × s.interp lval.tyVal))
Expand Down Expand Up @@ -243,7 +243,7 @@ namespace Interp

structure State where
sortFVars : Array FVarId := #[]
sortMap : HashMap LamSort FVarId := {}
sortMap : Std.HashMap LamSort FVarId := {}
-- Let `n := lctxTyRev.size`
-- Reversed
lctxTyRev : Array LamSort := #[]
Expand All @@ -253,11 +253,11 @@ structure State where
-- Required : `lctxTyDrop[i] ≝ lctxTyRev[:(i+1)].data ≝ lctxTerm[(n-i-1):]`
lctxTyDrop : Array FVarId := #[]
-- Required : `tyEqFact[i][j] : lctxTy[i:].drop j = lctxTy[i+j:]`
typeEqFact : HashMap (Nat × Nat) FVarId := {}
typeEqFact : Std.HashMap (Nat × Nat) FVarId := {}
-- Required : `lctxTermDrop[i] ≝ lctxTermRev[:(i+1)].data ≝ lctxTerm[(n-i-1):]`
lctxTermDrop : Array FVarId := #[]
-- Required : `termEqFact[i][j] : lctxTerm[i:].drop j = lctxTerm[i+j:]`
termEqFact : HashMap (Nat × Nat) FVarId := {}
termEqFact : Std.HashMap (Nat × Nat) FVarId := {}
-- Required : `lctxCon[i] : lctxTerm[i].map Sigma.snd = lctxTy[i]`
lctxCon : Array FVarId := #[]

Expand All @@ -268,10 +268,10 @@ abbrev InterpM := StateRefT State MetaState.MetaStateM
def getLCtxTy! (idx : Nat) : InterpM LamSort := do
let lctxTyRev ← getLctxTyRev
if idx ≥ lctxTyRev.size then
throwError "getLCtxTy! :: Index out of bound"
throwError "{decl_name%} :: Index out of bound"
match lctxTyRev[idx]? with
| .some s => return s
| .none => throwError "getLCtxTy! :: Unexpected error"
| .none => throwError "{decl_name%} :: Unexpected error"

/--
Turning a sort into `fvar` in a hash-consing manner
Expand All @@ -282,7 +282,7 @@ def getLCtxTy! (idx : Nat) : InterpM LamSort := do
def sort2FVarId (s : LamSort) : InterpM FVarId := do
let sortMap ← getSortMap
let userName := (`interpsf).appendIndexAfter (← getSortMap).size
match sortMap.find? s with
match sortMap.get? s with
| .some id => return id
| .none =>
match s with
Expand All @@ -303,7 +303,7 @@ def collectSortFor (ltv : LamTyVal) : LamTerm → InterpM LamSort
| .atom n => do
let _ ← sort2FVarId (ltv.lamVarTy n)
return ltv.lamVarTy n
| .etom _ => throwError "collectSortFor :: etoms should not occur here"
| .etom _ => throwError "{decl_name%} :: etoms should not occur here"
| .base b => do
let s := b.lamCheck ltv
let _ ← sort2FVarId s
Expand All @@ -325,8 +325,8 @@ def collectSortFor (ltv : LamTyVal) : LamTerm → InterpM LamSort
if argTy' == argTy && argTy' == s then
return resTy
else
throwError "collectSortFor :: Application type mismatch"
| _ => throwError "collectSortFor :: Malformed application"
throwError "{decl_name%} :: Application type mismatch"
| _ => throwError "{decl_name%} :: Malformed application"
where withLCtxTy {α : Type} (s : LamSort) (k : InterpM α) : InterpM α := do
let lctxTyRev ← getLctxTyRev
setLctxTyRev (lctxTyRev.push s)
Expand All @@ -336,4 +336,4 @@ where withLCtxTy {α : Type} (s : LamSort) (k : InterpM α) : InterpM α := do

end Interp

end Auto.Embedding.Lam
end Auto.Embedding.Lam
4 changes: 2 additions & 2 deletions Auto/Embedding/Lift.lean
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def imulLift.{u} (m n : GLift.{1, u} Int) :=
GLift.up (Int.mul m.down n.down)

def idivLift.{u} (m n : GLift.{1, u} Int) :=
GLift.up (Int.div m.down n.down)
GLift.up (Int.tdiv m.down n.down)

def imodLift.{u} (m n : GLift.{1, u} Int) :=
GLift.up (Int.mod m.down n.down)
GLift.up (Int.tmod m.down n.down)

def iedivLift.{u} (m n : GLift.{1, u} Int) :=
GLift.up (Int.ediv m.down n.down)
Expand Down
Loading

0 comments on commit 3357e49

Please sign in to comment.