Skip to content

Commit

Permalink
WIP: refactor: rephrase the register frame-condition precondition as …
Browse files Browse the repository at this point in the history
…non-membership in a list of modified registers

Instead of a sequence of individual register non-equality pre-conditions
  • Loading branch information
alexkeizer committed Oct 16, 2024
1 parent 82b2ce1 commit 4405b76
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 49 deletions.
44 changes: 44 additions & 0 deletions Tactics/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,50 @@ def Lean.Expr.eqReadField? (e : Expr) : Option (Expr × Expr × Expr) := do
/-- Return the expression for `Memory` -/
def mkMemory : Expr := mkConst ``Memory

/-- Return `<x> ∈ <xs>`, given expressions `α : Type u`, `x : α`
and `xs : List α` -/
def mkListMem (u : Level) (α x xs : Expr) : Expr :=
let list := mkApp (.const ``List [u]) α
let inst := mkApp (.const ``List.instMembership [u]) α
mkApp5 (.const ``Membership.mem [u,u]) α list inst xs x

/-- Return `<f> ∈ <fs>`, given expressions `f : StateField`
and `fs : List StateField` -/
def mkStateFieldListMem (x xs : Expr) : Expr :=
mkListMem 0 (mkConst ``StateField) x xs

/-- Return a proof of type `<x> ∈ <xs>`, given expressions `α : Type u`, `x : α`
and `xs : List α`, assuming that `xs` is of the form `[x₁, x₂, ⋯, xₙ]` and that
`x` is *syntactically* equal to one of the elements `xᵢ`.
Returns `none` if the proof could not be constructed -/
partial def mkListMemProof (u : Level) (α x xs : Expr) : Option Expr := do
let_expr List.cons _α hd tl := xs | none
if hd == x then
mkApp3 (.const ``List.Mem.head [u]) α x tl
else
mkApp5 (.const ``List.Mem.tail [u]) α x hd tl (← mkListMemProof u α x tl)

/-- auxiliary lemma for use in `mkNeProofOfMemAndNotMem` -/
private theorem mkNeProofOfMemAndNotMem.aux.{u}
{α : Type u} {x y : α} {xs : List α}
(h_mem : x ∈ xs) (h_not_mem : y ∉ xs) :
x ≠ y := by
rintro rfl; contradiction

/-- Return a proof of type `<x> ≠ <y>`, given proofs
`memProof : <x> ∈ <xs>` and `nonMemProof : <y> ∉ <xs>`, assuming that
`α : Type u`, `x y : α`, and `xs : List α` -/
@[inline] def mkNeProofOfMemAndNotMem (u : Level) (α x y xs memProof nonMemProof : Expr) :
Expr :=
mkApp6 (.const ``mkNeProofOfMemAndNotMem.aux [u]) α x y xs
memProof nonMemProof

/-- Return a proof of `<x> ∉ <xs>`, given `notMemProof : <x> ∉ <y> :: xs`,
assuming that `α : Type u`, `x y : α`, and `xs : List α` -/
@[inline] def mkNotMemOfNotMemCons (u : Level) (α x y xs notMemProof : Expr) : Expr :=
mkApp5 (.const ``List.not_mem_of_not_mem_cons [u]) α x y xs notMemProof

/-! ## Expr Helpers -/

/-- Throw an error if `e` is not of type `expectedType` -/
Expand Down
124 changes: 75 additions & 49 deletions Tactics/Sym/AxEffects.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ structure AxEffects where
fields : Std.HashMap StateField AxEffects.FieldEffect
/-- An expression that contains the proof of:
```lean
∀ (f : StateField), f ≠ <f₁> → ⋯ → f ≠ <fₙ>
∀ (f : StateField), f ∉ [f₁, ⋯, fₙ]
r f <currentState> = r f <initialState> `
```
where `f₁, ⋯, fₙ` are the keys of `fields`
Expand Down Expand Up @@ -125,9 +125,13 @@ def initial (state : Expr) : AxEffects where
currentState := state
fields := .empty
nonEffectProof :=
-- `fun f => rfl`
mkLambda `f .default (mkConst ``StateField) <|
mkEqReflArmState <| mkApp2 (mkConst ``r) (.bvar 0) state
-- `fun (f : StateField) (h : f ∉ []) => rfl`
let SF := mkConst ``StateField
mkLambda `f .default SF <|
let f_nin_nil := -- `f ∉ []`
mkNot <| mkStateFieldListMem (.bvar 0) (mkApp (.const ``List.nil [0]) SF)
mkLambda `h .default f_nin_nil <|
mkEqReflArmState <| mkApp2 (mkConst ``r) (.bvar 1) state
memoryEffects := .initial state
programProof :=
-- `rfl`
Expand Down Expand Up @@ -192,21 +196,18 @@ private def rewriteType (e eq : Expr) : MetaM Expr := do
by constructing an application of `eff.nonEffectProof` -/
partial def mkAppNonEffect (eff : AxEffects) (field : Expr) : MetaM Expr := do
let msg := m!"constructing application of non-effects proof"
withTraceNode msg (tag := "mkAppNonEffect") <| do
trace[Tactic.sym] "nonEffectProof: {eff.nonEffectProof}"
Sym.withTraceNode msg (tag := "mkAppNonEffect") <| do
Sym.traceLargeMsg "nonEffectProof" m!"{eff.nonEffectProof}"

let nonEffectProof := mkApp eff.nonEffectProof field
let typeOfNonEffects ← inferType nonEffectProof
forallTelescope typeOfNonEffects <| fun fvars _ => do
trace[Tactic.sym] "hypotheses of nonEffectProof: {fvars}"
let lctx ← getLCtx
let pre ← fvars.mapM fun expr => do
let ty := lctx.get! expr.fvarId! |>.type
mkDecideProof ty
let (Expr.forallE _name binderType _body _) := typeOfNonEffects
| throwError m!"internal error: expected a forall, found:\n {typeOfNonEffects}"
trace[Tactic.sym] "non-effect precondition: {binderType}"

let nonEffectProof := mkAppN nonEffectProof pre
trace[Tactic.sym] "constructed: {nonEffectProof}"
return nonEffectProof
let nonMemProof ← mkDecideProof binderType
Sym.traceLargeMsg "constructed proof of precondition" m!"{nonMemProof}"
return mkApp nonEffectProof nonMemProof

/-- Get the value for a field, if one is stored in `eff.fields`,
or assemble an instantiation of the non-effects proof otherwise -/
Expand Down Expand Up @@ -324,8 +325,8 @@ private def update_w (eff : AxEffects) (fld val : Expr) :
Sym.withTraceNode m!"processing: w {fld} {val} …" (tag := "updateWrite") <| do
let rField ← reflectStateField fld

-- Update all other fields
let fields
-- ### Update all other fields
let otherFields
eff.fields.toList.filterMapM fun ⟨fld', {value, proof}⟩ => do
if fld' ≠ rField then
let proof : Expr ← do
Expand All @@ -340,55 +341,80 @@ private def update_w (eff : AxEffects) (fld val : Expr) :
else
return none

-- Update the main field
-- ### Update the main field
let newField : FieldEffect := {
value := val
proof :=
-- `r_of_w_same <fld> <val> <currentState>`
mkApp3 (mkConst ``r_of_w_same) fld val eff.currentState
}
let fields := (rField, newField) :: fields
let fields := (rField, newField) :: otherFields

-- ### Update the non-effects proof
let nonEffectProof ← lambdaBoundedTelescope eff.nonEffectProof 2 fun args proof => do
let [f /- : StateField -/, nonMemHyp /- : f ∉ ?modifiedFields -/] := args.toList
| throwError "internal error: expected exactly two arguments, found:\
{args}\n\nIn non-effect proof:\n {eff.nonEffectProof}"

let modifiedFields : Expr /- : List StateField -/do
let ty ← inferType nonMemHyp
let_expr Not ty := ty
| let m ← mkFreshExprMVar none
throwError "interal error: expected f ∉ {m}, found:\n {ty}"
let_expr Membership.mem _α _γ _inst _f fields := ty
| let m ← mkFreshExprMVar none
throwError "interal error: expected f ∈ {m}, found:\n {ty}"
pure fields

let newProofOfNe := fun oldProof neProof /- : `<f> ≠ <fld>` -/ => do
let r_of_w :=
mkApp5 (mkConst ``r_of_w_different) f fld val eff.currentState neProof
mkEqTrans r_of_w oldProof

let h? := mkListMemProof 0 (mkConst ``StateField) fld modifiedFields
if let some h /- : `<fld> ∈ <modifiedFields>` -/ := h? then
-- `fld` was previously modified

let neProof := -- : `<f> ≠ <fld>`
mkNeProofOfMemAndNotMem 0 (mkConst ``StateField) f fld modifiedFields h nonMemHyp
-- Adjust the proof
let proof ← newProofOfNe proof neProof
-- And abstract `f` and `nonMemHyp` again, without changing their types
mkLambdaFVars #[f, nonMemHyp] proof

-- Update the non-effects proof
let nonEffectProof ← lambdaTelescope eff.nonEffectProof fun args proof => do
let f := args[0]!
else
-- `fld` was *not* previously modified, so we need to change the type of
-- the `nonMemHyp` precondition

/- First, assume we have a proof `h_neq : <f> ≠ <fld>`, and use that
to compute the new `nonEffectProof` -/
let k := fun args h_neq => do
let r_of_w := mkApp5 (mkConst ``r_of_w_different)
f fld val eff.currentState h_neq
let proof ← mkEqTrans r_of_w proof
mkLambdaFVars args proof
-- ^^ `fun f ... => Eq.trans (r_of_w_different ... <h_neq>) <proof>`

/- Then, determine `h_neq` so that we can pass it to `k`.
Notice how we have to modify the environment, to add `h_neq` as a new local
hypothesis if it wan't present yet, but only in some branches.
This is why we had to define `k` as a monadic continuation,
so we can nest `k` under a `withLocalDeclD` -/
let h_neq_type := mkApp3 (.const ``Ne [1]) (mkConst ``StateField) f fld
let h_neq? ← args.findM? fun h => do
let hTy ← inferType h
return hTy == h_neq_type
match h_neq? with
| some h_neq => k args h_neq
| none =>
let name := Name.mkSimple s!"h_neq_{rField}"
withLocalDeclD name h_neq_type fun h_neq =>
k (args.push h_neq) h_neq
let newModifiedFields := -- `<fld> :: <modifiedFields>`
mkApp3 (mkConst ``List.cons) (mkConst ``StateField) fld modifiedFields

-- Update the memory effects
withLocalDeclD `h (mkNot <| mkStateFieldListMem f newModifiedFields) fun newNonMemHyp => do
let proof := proof.replaceFVar nonMemHyp <|
mkNotMemOfNotMemCons 0 (mkConst ``StateField) f fld
modifiedFields newNonMemHyp

let h := -- : `<fld> ∈ <newModifiedFields>`
mkApp3 (.const ``List.Mem.head [0]) (mkConst ``StateField) fld modifiedFields
let neProof := -- : `<f> ≠ <fld>`
mkNeProofOfMemAndNotMem 0 (mkConst ``StateField) f fld modifiedFields h nonMemHyp

-- Adjust the proof
let proof ← newProofOfNe proof neProof
-- And abstract `f` and `newNonMemHyp`
mkLambdaFVars #[f, newNonMemHyp] proof

-- ### Update the memory effects
let memoryEffects ← eff.memoryEffects.updateWrite eff.currentState fld val

-- Update the program proof
-- ### Update the program proof
let programProof ←
-- `Eq.trans (w_program ...) <programProof>`
mkEqTrans
(mkAppN (mkConst ``w_program) #[fld, val, eff.currentState])
eff.programProof

-- Update the stack alignment proof
-- ### Update the stack alignment proof
let mut sideConditions := eff.sideConditions
let mut stackAlignmentProof? := eff.stackAlignmentProof?
if let some proof := stackAlignmentProof? then
Expand Down

0 comments on commit 4405b76

Please sign in to comment.