Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Rewrite simp_mem to build a new expression, thereby localizing the effects of rewrites [6/?] #237

Merged
merged 2 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Arm/Memory/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ initialize Lean.registerTraceClass `simp_mem
/-- Provides extremely verbose tracing for the `simp_mem` tactic. -/
initialize Lean.registerTraceClass `simp_mem.info

/-- Provides even more verbose tracing for the `simp_mem` tactic. -/
initialize Lean.registerTraceClass `simp_mem.expr_walk_trace

/-- Provides extremely verbose tracing for the `simp_mem` tactic. -/
initialize Lean.registerTraceClass `Tactic.address_normalization

Expand Down
39 changes: 24 additions & 15 deletions Arm/Memory/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -621,17 +621,23 @@ end Hypotheses

section Simplify

def rewriteWithEquality (rw : Expr) (msg : MessageData) : TacticM Unit := do
structure SimplifyResult where
eNew : Expr
eqProof : Expr
deriving Inhabited

/-- Rewrite expression `e` with rewrite `rw` -/
def rewriteWithEquality (rw : Expr) (e : Expr) (msg : MessageData) : TacticM SimplifyResult := do
TacticM.withTraceNode' msg do
withMainContext do
TacticM.traceLargeMsg m!"rewrite" m!"{← inferType rw}"
-- TacticM.traceLargeMsg m!"rewrite" m!"{← inferType rw}"
let goal ← getMainGoal
let result ← goal.rewrite (← getMainTarget) rw
let mvarId' ← (← getMainGoal).replaceTargetEq result.eNew result.eqProof
trace[simp_mem.info] "{checkEmoji} rewritten goal {mvarId'}"
unless result.mvarIds == [] do
throwError m!"{crossEmoji} internal error: expected rewrite to produce no side conditions. Produced {result.mvarIds}"
replaceMainGoal [mvarId']
let result ← goal.rewrite e rw
-- let mvarId' ← (← getMainGoal).replaceTargetEq result.eNew result.eqProof
trace[simp_mem.info] "{checkEmoji} rewritten goal {e}"
check result.eNew
check result.eqProof
return { eNew := result.eNew, eqProof := result.eqProof }

/--
info: Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate' {x : BitVec 64} {xn : Nat} {y : BitVec 64} {yn : Nat}
Expand All @@ -644,15 +650,16 @@ info: Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate' {x : BitVec 6
using `Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate'`. -/
def MemSeparateProof.rewriteReadOfSeparatedWrite
(er : ReadBytesExpr) (ew : WriteBytesExpr)
(separate : MemSeparateProof { sa := er.span, sb := ew.span }) : TacticM Unit := do
(separate : MemSeparateProof { sa := er.span, sb := ew.span })
(e : Expr) : TacticM SimplifyResult := do
let call :=
mkAppN (Expr.const ``Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate' [])
#[er.span.base, er.span.n,
ew.span.base, ew.span.n,
ew.mem,
separate.h,
ew.val]
rewriteWithEquality call m!"rewriting read({er})⟂write({ew})"
rewriteWithEquality call e m!"rewriting read({er})⟂write({ew})"

/--
info: Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset' {bn : Nat} {b : BitVec 64} {val : BitVec (bn * 8)}
Expand All @@ -665,15 +672,16 @@ def MemSubsetProof.rewriteReadOfSubsetRead
(er : ReadBytesExpr)
(hread : ReadBytesEqProof)
(hsubset : MemSubsetProof { sa := er.span, sb := hread.read.span })
: TacticM Unit := do
(e : Expr)
: TacticM SimplifyResult := do
let call := mkAppN (Expr.const ``Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset' [])
#[hread.read.span.n, hread.read.span.base,
hread.val,
er.span.base, er.span.n,
er.mem,
hread.h,
hsubset.h]
rewriteWithEquality call m!"rewriting read({er})⊆read({hread.read})"
rewriteWithEquality call e m!"rewriting read({er})⊆read({hread.read})"

/--
info: Memory.read_bytes_write_bytes_eq_of_mem_subset' {x : BitVec 64} {xn : Nat} {y : BitVec 64} {yn : Nat} {mem : Memory}
Expand All @@ -684,15 +692,16 @@ info: Memory.read_bytes_write_bytes_eq_of_mem_subset' {x : BitVec 64} {xn : Nat}

def MemSubsetProof.rewriteReadOfSubsetWrite
(er : ReadBytesExpr) (ew : WriteBytesExpr)
(hsubset : MemSubsetProof { sa := er.span, sb := ew.span }) :
TacticM Unit := do
(hsubset : MemSubsetProof { sa := er.span, sb := ew.span })
(e : Expr) :
TacticM SimplifyResult := do
let call := mkAppN (Expr.const ``Memory.read_bytes_write_bytes_eq_of_mem_subset' [])
#[er.span.base, er.span.n,
ew.span.base, ew.span.n,
ew.mem,
hsubset.h,
ew.val]
rewriteWithEquality call m!"rewriting read({er})⊆write({ew})"
rewriteWithEquality call e m!"rewriting read({er})⊆write({ew})"

end Simplify

Expand Down
149 changes: 89 additions & 60 deletions Arm/Memory/SeparateAutomation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -191,81 +191,104 @@ def getConfig : SimpMemM SimpMemConfig := do
/-- info: state_value (fld : StateField) : Type -/
#guard_msgs in #check state_value


def SimpMemM.findOverlappingReadHypAux (hyps : Array Memory.Hypothesis) (er : ReadBytesExpr) (hReadEq : ReadBytesEqProof) :
SimpMemM <| Option (MemSubsetProof { sa := er.span, sb := hReadEq.read.span }) := do
withTraceNode m!"{processingEmoji} ... ⊆ {hReadEq.read.span} ? " do
-- the read we are analyzing should be a subset of the hypothesis
let subset := (MemSubsetProp.mk er.span hReadEq.read.span)
let some hSubsetProof ← proveWithOmega? subset (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps
| return none
return some (hSubsetProof)

def SimpMemM.findOverlappingReadHyp (hyps : Array Memory.Hypothesis) (er : ReadBytesExpr) :
SimpMemM <| Option (Σ (hread : ReadBytesEqProof), MemSubsetProof { sa := er.span, sb := hread.read.span }) := do
for hyp in hyps do
let Hypothesis.read_eq hReadEq := hyp
| continue
let some subsetProof ← SimpMemM.findOverlappingReadHypAux hyps er hReadEq
| continue
return some ⟨hReadEq, subsetProof⟩
return none


mutual

/--
Pattern match for memory patterns, and simplify them.
Close memory side conditions with `simplifyGoal`.
Returns if progress was made.
-/
partial def SimpMemM.simplifyExpr (e : Expr) (hyps : Array Memory.Hypothesis) : SimpMemM Unit := do
partial def SimpMemM.simplifyExpr (e : Expr) (hyps : Array Memory.Hypothesis) : SimpMemM (Option SimplifyResult) := do
consumeRewriteFuel
if ← outofRewriteFuel? then
trace[simp_mem.info] "out of fuel for rewriting, stopping."

let e := e.consumeMData

if e.isSort then
trace[simp_mem.info] "skipping sort '{e}'."

if let .some er := ReadBytesExpr.ofExpr? e then
if let .some ew := WriteBytesExpr.ofExpr? er.mem then
trace[simp_mem.info] "{checkEmoji} Found read of write."
trace[simp_mem.info] "read: {er}"
trace[simp_mem.info] "write: {ew}"
trace[simp_mem.info] "{processingEmoji} read({er.span})⟂/⊆write({ew.span})"

let separate := MemSeparateProp.mk er.span ew.span
let subset := MemSubsetProp.mk er.span ew.span
if let .some separateProof ← proveWithOmega? separate (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps then do
trace[simp_mem.info] "{checkEmoji} {separate}"
MemSeparateProof.rewriteReadOfSeparatedWrite er ew separateProof
setChanged
else if let .some subsetProof ← proveWithOmega? subset (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps then do
trace[simp_mem.info] "{checkEmoji} {subset}"
MemSubsetProof.rewriteReadOfSubsetWrite er ew subsetProof
setChanged
else
trace[simp_mem.info] "{crossEmoji} Could not prove {er.span} ⟂/⊆ {ew.span}"
let .some er := ReadBytesExpr.ofExpr? e
| SimpMemM.walkExpr e hyps

if let .some ew := WriteBytesExpr.ofExpr? er.mem then
trace[simp_mem.info] "{checkEmoji} Found read of write."
trace[simp_mem.info] "read: {er}"
trace[simp_mem.info] "write: {ew}"
trace[simp_mem.info] "{processingEmoji} read({er.span})⟂/⊆write({ew.span})"

let separate := MemSeparateProp.mk er.span ew.span
let subset := MemSubsetProp.mk er.span ew.span
if let .some separateProof ← proveWithOmega? separate (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps then do
trace[simp_mem.info] "{checkEmoji} {separate}"
let result ← MemSeparateProof.rewriteReadOfSeparatedWrite er ew separateProof e
setChanged
return result
else if let .some subsetProof ← proveWithOmega? subset (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps then do
trace[simp_mem.info] "{checkEmoji} {subset}"
let result ← MemSubsetProof.rewriteReadOfSubsetWrite er ew subsetProof e
setChanged
return result
else
-- read
trace[simp_mem.info] "{checkEmoji} Found read {er}."
-- TODO: we don't need a separate `subset` branch for the writes: instead, for the write,
-- we can add the theorem that `(write region).read = write val`.
-- Then this generic theory will take care of it.
withTraceNode m!"Searching for overlapping read {er.span}." do
for hyp in hyps do
if let Hypothesis.read_eq hReadEq := hyp then do
withTraceNode m!"{processingEmoji} ... ⊆ {hReadEq.read.span} ? " do
-- the read we are analyzing should be a subset of the hypothesis
let subset := (MemSubsetProp.mk er.span hReadEq.read.span)
if let some hSubsetProof ← proveWithOmega? subset (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps then
trace[simp_mem.info] "{checkEmoji} ... ⊆ {hReadEq.read.span}"
MemSubsetProof.rewriteReadOfSubsetRead er hReadEq hSubsetProof
setChanged
else
trace[simp_mem.info] "{crossEmoji} ... ⊊ {hReadEq.read.span}"
trace[simp_mem.info] "{crossEmoji} Could not prove {er.span} ⟂/⊆ {ew.span}"
SimpMemM.walkExpr e hyps
else
if e.isForall then
Lean.Meta.forallTelescope e fun xs b => do
for x in xs do
SimpMemM.simplifyExpr x hyps
-- we may have a hypothesis like
-- ∀ (x : read_mem (read_mem_bytes ...) ... = out).
-- we want to simplify the *type* of x.
SimpMemM.simplifyExpr (← inferType x) hyps
SimpMemM.simplifyExpr b hyps
else if e.isLambda then
Lean.Meta.lambdaTelescope e fun xs b => do
for x in xs do
SimpMemM.simplifyExpr x hyps
SimpMemM.simplifyExpr (← inferType x) hyps
SimpMemM.simplifyExpr b hyps
else
-- check if we have expressions.
match e with
| .app f x =>
SimpMemM.simplifyExpr f hyps
SimpMemM.simplifyExpr x hyps
| _ => return ()
-- read
trace[simp_mem.info] "{checkEmoji} Found read {er}."
-- TODO: we don't need a separate `subset` branch for the writes: instead, for the write,
-- we can add the theorem that `(write region).read = write val`.
-- Then this generic theory will take care of it.
withTraceNode m!"Searching for overlapping read {er.span}." do
let some ⟨hReadEq, hSubsetProof⟩ ← findOverlappingReadHyp hyps er
| SimpMemM.walkExpr e hyps
let out ← MemSubsetProof.rewriteReadOfSubsetRead er hReadEq hSubsetProof e
setChanged
return out

partial def SimpMemM.walkExpr (e : Expr) (hyps : Array Memory.Hypothesis) : SimpMemM (Option SimplifyResult) := do
withTraceNode (traceClass := `simp_mem.expr_walk_trace) m!"🎯 {e} | kind:{Expr.ctorName e}" (collapsed := false) do
let e ← instantiateMVars e
match e.consumeMData with
| .app f x =>
let fResult ← SimpMemM.simplifyExpr f hyps
let xResult ← SimpMemM.simplifyExpr x hyps
-- return (← SimplifyResult.default e)
match (fResult, xResult) with
| (none, some xResult) =>
let outResult ← mkCongrArg f xResult.eqProof
return some ⟨e.updateApp! f xResult.eNew, outResult⟩
| (some fResult, none) =>
let outResult ← mkCongrFun fResult.eqProof x
return some ⟨e.updateApp! fResult.eNew x, outResult⟩
| (some fResult, some xResult) =>
let outResult ← mkCongr fResult.eqProof xResult.eqProof
return some ⟨e.updateApp! fResult.eNew xResult.eNew, outResult⟩
| _ => return none
-- let outResult ← mkCongr fResult.eqProof xResult.eqProof
-- -- | I think I see where the problem is. here, I should have updated with the other result.
-- return ⟨e.updateApp! f x, outResult⟩
bollu marked this conversation as resolved.
Show resolved Hide resolved
| _ => return none


/--
Expand All @@ -277,7 +300,13 @@ partial def SimpMemM.simplifyGoal (g : MVarId) (hyps : Array Memory.Hypothesis)
SimpMemM.withContext g do
let gt ← g.getType
withTraceNode m!"Simplifying goal." do
SimpMemM.simplifyExpr (← whnf gt) hyps
let some out ← SimpMemM.simplifyExpr gt hyps
| return ()
-- Note: this could impact performance, so delete this if it turns out to be a resource hog.
check out.eNew
check out.eqProof
bollu marked this conversation as resolved.
Show resolved Hide resolved
let newGoal ← (← getMainGoal).replaceTargetEq out.eNew out.eqProof
replaceMainGoal [newGoal]
end

/--
Expand Down
8 changes: 7 additions & 1 deletion Proofs/Experiments/MemoryAliasing.lean
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ theorem mem_automation_test_1
simp_mem
rfl


-- rfl

/-- info: 'mem_automation_test_1' depends on axioms: [propext, Classical.choice, Quot.sound] -/
#guard_msgs in #print axioms mem_automation_test_1

Expand Down Expand Up @@ -306,7 +309,7 @@ end ReadOverlappingWrite
/- We check that we correctly visit the expression tree, both for binders,
and for general walking. -/
namespace ExprVisitor

/-
/-- Check that we correctly go under binders -/
theorem test_quantified_1 {val : BitVec (16 * 8)}
(hlegal : mem_legal' 0 16) : ∀ (_irrelevant : Nat),
Expand All @@ -320,6 +323,7 @@ theorem test_quantified_1 {val : BitVec (16 * 8)}
info: 'ExprVisitor.test_quantified_1' depends on axioms: [propext, Classical.choice, Quot.sound]
-/
#guard_msgs in #print axioms test_quantified_1
-/

/-- Check that we correctly walk under applications. -/
theorem test_app_1 {val : BitVec (16 * 8)}
Expand All @@ -333,6 +337,7 @@ theorem test_app_1 {val : BitVec (16 * 8)}
/-- info: 'ExprVisitor.test_app_1' depends on axioms: [propext, Classical.choice, Quot.sound] -/
#guard_msgs in #print axioms test_app_1

/-
/--
Check that we correctly walk under applications (`f <walk inside>`)
and binders (`∀ f, <walk inside>`) simultaneously.
Expand Down Expand Up @@ -364,6 +369,7 @@ theorem test_quantified_app_2 {val : BitVec (16 * 8)}
rfl

end ExprVisitor
-/

namespace MathProperties

Expand Down
Loading