Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add fine grained control over mem_omega rewriting [7/?] #238

Merged
merged 2 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 136 additions & 34 deletions Arm/Memory/MemOmega.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Lean
import Lean.Meta.Tactic.Rewrite
import Lean.Meta.Tactic.Rewrites
import Lean.Elab.Tactic.Conv
import Lean.Elab.Tactic.Simp
import Lean.Elab.Tactic.Conv.Basic
import Tactics.Simp
import Tactics.BvOmegaBench
Expand All @@ -27,6 +28,20 @@ open Lean Meta Elab Tactic Memory

namespace MemOmega

/--
A user given hypothesis for mem_omega, which we process as either a hypothesis (FVarId),
or a term that is added into the user's context.
-/
inductive UserHyp
| hyp : FVarId → UserHyp
| star : UserHyp
| excludeHyp : FVarId → UserHyp
| expr : Expr → UserHyp

namespace UserHyp
end UserHyp


structure Config where
/--
If true, then MemOmega will explode uses of pairwiseSeparate [mem₁, ... memₙ]
Expand All @@ -42,6 +57,11 @@ def Config.mkBang (c : Config) : Config :=
structure Context where
/-- User configurable options for `simp_mem`. -/
cfg : Config
/--
If we are using `mem_omega only [...]`, then we will have `some` plus the hyps.
If we are using `mem_omega`, then we will get `none`.
-/
userHyps? : Option (Array UserHyp)
/-- Cache of `bv_toNat` simp context. -/
bvToNatSimpCtx : Simp.Context
/-- Cache of `bv_toNat` simprocs. -/
Expand All @@ -50,7 +70,7 @@ structure Context where

namespace Context

def init (cfg : Config) : MetaM Context := do
def init (cfg : Config) (userHyps? : Option (Array UserHyp)) : MetaM Context := do
let (bvToNatSimpCtx, bvToNatSimprocs) ←
LNSymSimpContext
(config := {failIfUnchanged := false})
Expand All @@ -59,7 +79,7 @@ def init (cfg : Config) : MetaM Context := do
-- (thms := #[``mem_legal'.iff_omega, ``mem_subset'.iff_omega, ``mem_separate'.iff_omega])
(simp_attrs := #[`bv_toNat])
(useDefaultSimprocs := false)
return {cfg, bvToNatSimpCtx, bvToNatSimprocs}
return {cfg, bvToNatSimpCtx, bvToNatSimprocs, userHyps? }
end Context

abbrev MemOmegaM := (ReaderT Context MetaM)
Expand All @@ -68,67 +88,149 @@ namespace MemOmegaM
def run (ctx : Context) (x : MemOmegaM α) : MetaM α := ReaderT.run x ctx
end MemOmegaM

/-- Modify the set of hypotheses `hyp` based on the user hyp `hyp`. -/
bollu marked this conversation as resolved.
Show resolved Hide resolved
def mkKeepHypsOfUserHyp (g : MVarId) (set : Std.HashSet FVarId) (hyp : UserHyp) : MetaM <| Std.HashSet FVarId :=
match hyp with
| .hyp fvar => return set.insert fvar
| .excludeHyp fvar => return set.erase fvar
| .star => do
let allHyps ← g.getNondepPropHyps
return allHyps.foldl (init := set) (fun s fvar => s.insert fvar)
| .expr _e => return set

/--
Given the user hypotheses, build a more focusedd MVarId that contains only those hypotheses.
This makes `omega` focus only on those hypotheses, since omega by default crawls the entire goal state.

This is arguably a workaround to having to plumb the hypotheses through the full layers of code, but it works,
and should be a cheap solution.
-/
def mkGoalWithOnlyUserHyps (g : MVarId) (userHyps? : Option (Array UserHyp)) : MetaM <| MVarId :=
match userHyps? with
| none => pure g
| some userHyps => do
g.withContext do
let mut keepHyps : Std.HashSet FVarId ← userHyps.foldlM
(init := ∅)
(mkKeepHypsOfUserHyp g)
let hyps ← g.getNondepPropHyps
let mut g := g
for h in hyps do
if !keepHyps.contains h then
g ← g.withContext <| g.clear h
return g

def memOmega (g : MVarId) : MemOmegaM Unit := do
g.withContext do
/- We need to explode all pairwise separate hyps -/
let rawHyps ← getLocalHyps
let mut hyps := #[]
-- extract out structed values for all hyps.
for h in rawHyps do
hyps ← hypothesisOfExpr h hyps

-- only enable pairwise constraints if it is enabled.
let isPairwiseEnabled := (← readThe Context).cfg.explodePairwiseSeparate
hyps := hyps.filter (!·.isPairwiseSeparate || isPairwiseEnabled)

-- used specialized procedure that doesn't unfold everything for the easy case.
if ← closeMemSideCondition g (← readThe Context).bvToNatSimpCtx (← readThe Context).bvToNatSimprocs hyps then
return ()
else
-- in the bad case, just rip through everything.
let (_, g) ← Hypothesis.addOmegaFactsOfHyps g hyps.toList #[]

TacticM.withTraceNode' m!"Reducion to omega" do
try
TacticM.traceLargeMsg m!"goal (Note: can be large)" m!"{g}"
omega g (← readThe Context).bvToNatSimpCtx (← readThe Context).bvToNatSimprocs
trace[simp_mem.info] "{checkEmoji} `omega` succeeded."
catch e =>
trace[simp_mem.info] "{crossEmoji} `omega` failed with error:\n{e.toMessageData}"
let g ← mkGoalWithOnlyUserHyps g (← readThe Context).userHyps?
g.withContext do
let rawHyps ← getLocalHyps
let mut hyps := #[]
-- extract out structed values for all hyps.
for h in rawHyps do
hyps ← hypothesisOfExpr h hyps

-- only enable pairwise constraints if it is enabled.
let isPairwiseEnabled := (← readThe Context).cfg.explodePairwiseSeparate
hyps := hyps.filter (!·.isPairwiseSeparate || isPairwiseEnabled)

-- used specialized procedure that doesn't unfold everything for the easy case.
if ← closeMemSideCondition g (← readThe Context).bvToNatSimpCtx (← readThe Context).bvToNatSimprocs hyps then
return ()
else
-- in the bad case, just rip through everything.
let (_, g) ← Hypothesis.addOmegaFactsOfHyps g hyps.toList #[]

TacticM.withTraceNode' m!"Reducion to omega" do
try
TacticM.traceLargeMsg m!"goal (Note: can be large)" m!"{g}"
omega g (← readThe Context).bvToNatSimpCtx (← readThe Context).bvToNatSimprocs
trace[simp_mem.info] "{checkEmoji} `omega` succeeded."
catch e =>
trace[simp_mem.info] "{crossEmoji} `omega` failed with error:\n{e.toMessageData}"
throw e

/--
Allow elaboration of `MemOmegaConfig` arguments to tactics.
-/
declare_config_elab elabMemOmegaConfig MemOmega.Config

syntax userHyp := (&"-")? term <|> Parser.Tactic.locationWildcard

syntax memOmegaWith := &"with" "[" withoutPosition(userHyp,*,?) "]"

/--
The `mem_omega` tactic is a finishing tactic which is used to dispatch memory side conditions.
Broadly, the algorithm works as follows:
- It scans the set of hypotheses for `mem_separate`, `mem_subset`, and `mem_legal` hypotheses, and turns them into `omega` based information.
- It calls `omega` as a finishing tactic to close the current goal state.
- Cruicially, it **does not unfold** `pairwiseSeparate` constraints. We expect the user to do so. If they want `pairwiseSeparate` unfolded, then please use `mem_omega!`.
-/
syntax (name := mem_omega) "mem_omega" (Lean.Parser.Tactic.config)? : tactic
syntax (name := mem_omega) "mem_omega" (Lean.Parser.Tactic.config)? (memOmegaWith)? : tactic
bollu marked this conversation as resolved.
Show resolved Hide resolved

/--
The `mem_omega!` tactic is a finishing tactic, that is a more aggressive variant of `mem_omega`.
-/
syntax (name := mem_omega_bang) "mem_omega!" (Lean.Parser.Tactic.config)? : tactic
syntax (name := mem_omega_bang) "mem_omega!" (memOmegaWith)? : tactic

/--
build a `UserHyp` from the raw syntax.
This supports using fars, using CDot notation to partially apply theorems, and to use terms.

Adapted from Lean.Elab.Tactic.Rw.WithRWRulesSeq, Lean.Elab.Tactic.Simp.resolveSimpIdTheorem, Lean.Elab.Tactic.Simp.addSimpTheorem
-/
def UserHyp.ofTerm (t : TSyntax `term) : TacticM UserHyp := do
-- See if we can interpret `id` as a hypothesis first.
if let .some fvarId ← optional <| getFVarId t then
return .hyp fvarId
else if let some e ← Term.elabCDotFunctionAlias? t then
return .expr e
else
let e ← Term.elabTerm t none
Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true)
let e ← instantiateMVars e
let e := e.eta
if e.hasMVar then
throwErrorAt t "found metavariables when elaborating rule, giving up."
return .expr e

/- Make a UserHyp from the raw syntax -/
open memOmega in
def UserHyp.ofSyntax (stx : TSyntax `MemOmega.userHyp) : TacticM UserHyp :=
bollu marked this conversation as resolved.
Show resolved Hide resolved
let arg := stx.raw[1]
if arg.getKind == ``Parser.Tactic.locationWildcard then
return .star
else
match stx with
| `(userHyp| $t:term) => UserHyp.ofTerm t
| `(userHyp| -$t:term) => do
if let .some fvarId ← optional <| getFVarId t then
return .excludeHyp fvarId
throwError "Cannot exclude non-hypothesis '{t}'."
| stx => do
throwError "Cannot parse user hypothesis '{stx}'."

-- Adapted from WithRWRulesSeq.
def elabMemOmegaWith : TSyntax ``MemOmega.memOmegaWith → TacticM (Array UserHyp)
| `(memOmegaWith| with [ $[ $rules],* ]) => do
rules.mapM UserHyp.ofSyntax
| _ => throwUnsupportedSyntax

open Lean.Parser.Tactic in
@[tactic mem_omega]
def evalMemOmega : Tactic := fun
| `(tactic| mem_omega $[$cfg]?) => do
def evalMemOmega : Tactic := fun
| `(tactic| mem_omega $[$cfg:config]? $[$v:memOmegaWith]?) => do
let cfg ← elabMemOmegaConfig (mkOptionalNode cfg)
let memOmegaRules? := ← v.mapM elabMemOmegaWith
liftMetaFinishingTactic fun g => do
memOmega g |>.run (← Context.init cfg)
memOmega g |>.run (← Context.init cfg memOmegaRules?)
| _ => throwUnsupportedSyntax

@[tactic mem_omega_bang]
def evalMemOmegaBang : Tactic := fun
| `(tactic| mem_omega! $[$cfg]?) => do
let cfg ← elabMemOmegaConfig (mkOptionalNode cfg)
liftMetaFinishingTactic fun g => do
memOmega g |>.run (← Context.init cfg.mkBang)
memOmega g |>.run (← Context.init cfg.mkBang .none)
| _ => throwUnsupportedSyntax

end MemOmega
49 changes: 47 additions & 2 deletions Arm/Memory/SeparateAutomation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import Tactics.Simp
import Tactics.BvOmegaBench
import Arm.Memory.Common
import Arm.Memory.MemOmega
import Lean.Elab.Tactic.Location
import Init.Tactics

open Lean Meta Elab Tactic Memory

Expand Down Expand Up @@ -366,10 +368,53 @@ Allow elaboration of `SimpMemConfig` arguments to tactics.
-/
declare_config_elab elabSimpMemConfig SeparateAutomation.SimpMemConfig


/-
This allows users to supply a list of hypotheses that simp_mem should use.
Modeled after `rwRule`.
-/
syntax simpMemRule := term

/-
The kind of simplification that must be performed. If we are told
that we must simplify a separation, a subset, or a read of a write,
we perform this kind of simplification.
-/
syntax simpMemSimplificationKind := "⟂" <|> "⊂w" <|> "⊂r" (term)?


open Lean.Parser.Tactic in
/--
Implement the simp_mem tactic frontend.
The simp_mem tactic allows simplifying expressions of the form `Memory.read_bytes rbase rn (mem')`.
`simp_mem` attempts to discover the result of the expression by various heuristics,
which can be controlled by the end user:

- (a) If `mem' = Memory.write_bytes wbase wn mem` and we know that `(rbase, rn) ⟂ (wbase, wn)`, then we simplify to `mem.read (rbase, rn)`.
- (b) If `mem' = Memory.write_bytes wbase wn wval mem` and we kow that `(rbase, rn) ⊆ (wbase, wn)`, then we simplify to `wval.extract (rbase, rn) (wbase, wn)`.
- (c) If we have a hypothesis `hr' : mem'.read_bytes rbase' rn' = rval`, and we know that `(rbase, rn) ⊆ (rbase', rn')`, then we simplify to `rval.extract (rbase, rn) (rbase', rn')`.

These simplifications are performed by reducing the problem to a problem that can be solved by a decision procedure (`omega`) to establish
which hypotheses are at play. `simp_mem` can be controlled along multiple axes:

1. The hypotheses that `simp_mem` will pass along to the decision procedure to discover overlapping reads (like `hr'`),
and hypotheses to establish memory (non-)interference, such as `(rbase, rn) ⟂ (wbase, wn)`.
+ simp_mem using []: try to perform the rewrite using no hypotheses.
+ simp_mem using [h₁, h₂]: try to perform the rewrite using h₁, h₂, as hypotheses.

2. The kind of rewrite that simp_mem should apply. By default, it explores all possible choices, which might be expensive due to repeated calls to the decision
procedure. The user can describe which of (a), (b), (c) above happen:
+ `simp_mem ⟂` : Only simplify when read is disjoint from write.
+ `simp_mem ⊂w` : Only simplify when read overlaps the write.
+ `simp_mem ⊂r hr` : Simplify when read overlaps with a known read `hr : mem.read_bytes baseaddr' n' = val`.
This is useful for static information such as lookup tables that are at a fixed location and never modified.
+ `simp_mem ⊂r` : Simplify when read overlaps with known read from hypothesis list.

3. The targets where the rewrite must be applied. (This needs some thought: does this even make sense?)
+ `simp_mem at ⊢`
+ `simp_mem at h₁, h₂, ⊢`

-/
syntax (name := simp_mem) "simp_mem" (Lean.Parser.Tactic.config)? : tactic
syntax (name := simp_mem) "simp_mem" (Lean.Parser.Tactic.config)? (simpMemSimplificationKind)? ("using" "[" withoutPosition(simpMemRule,*,?) "]")? (location)? : tactic

@[tactic simp_mem]
def evalSimpMem : Tactic := fun
Expand Down
38 changes: 18 additions & 20 deletions Proofs/Experiments/Memcpy/MemCpyVCG.lean
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ section PartialCorrectness
-- set_option skip_proof.skip true in
-- set_option trace.profiler true in
-- set_option profiler true in
set_option maxHeartbeats 0 in
#time set_option maxHeartbeats 0 in
theorem Memcpy.extracted_2 (s0 si : ArmState)
(h_si_x0_nonzero : si.x0 ≠ 0)
(h_s0_x1 : s0.x1 + 0x10#64 * (s0.x0 - si.x0) + 0x10#64 = s0.x1 + 0x10#64 * (s0.x0 - (si.x0 - 0x1#64)))
Expand All @@ -466,26 +466,27 @@ theorem Memcpy.extracted_2 (s0 si : ArmState)
(Memory.write_bytes 16 (s0.x2 + 0x10#64 * (s0.x0 - si.x0))
(Memory.read_bytes 16 (s0.x1 + 0x10#64 * (s0.x0 - si.x0)) si.mem) si.mem) =
Memory.read_bytes n addr s0.mem := by
have h_le : (s0.x0 - (si.x0 - 0x1#64)).toNat ≤ s0.x0.toNat := by bv_omega_bench
have h_upper_bound := hsep.hb.omega_def
have h_upper_bound₂ := h_pre_1.hb.omega_def
have h_upper_bound₃ := hsep.ha.omega_def
have h_width_lt : (0x10#64).toNat * (s0.x0 - (si.x0 - 0x1#64)).toNat < 2 ^ 64 := by mem_omega
have h_width_lt : (0x10#64).toNat * (s0.x0 - (si.x0 - 0x1#64)).toNat < 2 ^ 64 := by
mem_omega with [h_assert_1, h_pre_1]
rw [Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate']
· rw [h_assert_6]
skip_proof mem_omega
skip_proof mem_omega with [h_assert_1, h_pre_1, hsep]
· -- @bollu: TODO: figure out why this is so slow!/
apply mem_separate'.symm
apply mem_separate'.of_subset'_of_subset' hsep
· apply mem_subset'.of_omega
skip_proof refine ⟨?_, ?_, ?_, ?_⟩ <;> skip_proof bv_omega_bench
skip_proof refine ⟨?_, ?_, ?_, ?_⟩
· mem_omega with [h_si_x0_nonzero, h_assert_1, h_pre_1] -- TODO: add support for patterns like *, -<hyp1>, ... -<hypN>
· mem_omega with [h_si_x0_nonzero, h_assert_1, h_pre_1]
· mem_omega with [h_si_x0_nonzero, h_assert_1, h_pre_1]
· mem_omega with [h_si_x0_nonzero, h_assert_1, h_pre_1] -- , hsep] -- adding `hsep` makes this way slower.
· apply mem_subset'_refl hsep.hb

-- set_option skip_proof.skip true in
set_option maxHeartbeats 0 in
-- set_option trace.profiler true in
-- set_option profiler true in
theorem Memcpy.extracted_0 (s0 si : ArmState)
#time theorem Memcpy.extracted_0 (s0 si : ArmState)
(h_si_x0_nonzero : si.x0 ≠ 0)
(h_s0_x1 : s0.x1 + 0x10#64 * (s0.x0 - si.x0) + 0x10#64 = s0.x1 + 0x10#64 * (s0.x0 - (si.x0 - 0x1#64)))
(h_s0_x2 : s0.x2 + 0x10#64 * (s0.x0 - si.x0) + 0x10#64 = s0.x2 + 0x10#64 * (s0.x0 - (si.x0 - 0x1#64)))
Expand Down Expand Up @@ -518,9 +519,10 @@ theorem Memcpy.extracted_0 (s0 si : ArmState)
apply And.intro
· intros i hi
have h_subset_2 : mem_subset' s0.x2 (0x10#64 * (s0.x0 - si.x0)).toNat s0.x2 (s0.x0.toNat * 16) := by
skip_proof mem_omega
-- skip_proof mem_omega with [*, -h_assert_4, -h_assert_3, -h_pre_6]
skip_proof mem_omega with [h_si_x0_nonzero, h_assert_1, h_pre_1]
have h_subset_1 : mem_subset' (s0.x1 + 0x10#64 * (s0.x0 - si.x0)) 16 s0.x1 (s0.x0.toNat * 16) := by
skip_proof mem_omega
skip_proof mem_omega with [h_si_x0_nonzero, h_assert_1, h_pre_1]
have icases : i = s0.x0 - si.x0 ∨ i < s0.x0 - si.x0 := by skip_proof bv_omega_bench
have s2_sum_inbounds := h_pre_1.hb.omega_def
have i_sub_x0_mul_16 : 16 * i.toNat < 16 * s0.x0.toNat := by skip_proof bv_omega_bench
Expand All @@ -531,24 +533,20 @@ theorem Memcpy.extracted_0 (s0 si : ArmState)
· simp only [Nat.reduceMul, BitVec.toNat_add, BitVec.toNat_mul, BitVec.toNat_ofNat,
Nat.reducePow, Nat.reduceMod, BitVec.toNat_sub, Nat.add_mod_mod, Nat.sub_self,
BitVec.extractLsBytes_eq_self, BitVec.cast_eq]
rw [h_assert_6 _ _ (by mem_omega)]
· skip_proof mem_omega
rw [h_assert_6 _ _ (by mem_omega with [h_si_x0_nonzero, h_assert_1, h_pre_1])]
· skip_proof mem_omega with [h_si_x0_nonzero, h_assert_1, h_pre_1]
· rw [Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate']
· apply h_assert_5 _ hi
· constructor
· skip_proof mem_omega
· skip_proof mem_omega
· skip_proof mem_omega with [*, -h_s0_x1, -h_s0_x2, -h_assert_1, -h_assert_6, -h_pre_1, -h_pre_6]
· skip_proof mem_omega with [h_si_x0_nonzero, h_assert_1, h_pre_1]
· left
-- @bollu: TODO, see if `simp_mem` can figure this out given less aggressive
-- proof states.
skip_proof {
have s2_sum_inbounds := h_pre_1.hb.omega_def
have i_sub_x0_mul_16 : 16 * i.toNat < 16 * s0.x0.toNat := by skip_proof bv_omega_bench
rw [BitVec.toNat_add_eq_toNat_add_toNat (by bv_omega_bench)]
rw [BitVec.toNat_add_eq_toNat_add_toNat (by bv_omega_bench)]
rw [BitVec.toNat_mul_of_lt (by bv_omega_bench)]
rw [BitVec.toNat_mul_of_lt (by bv_omega_bench)]
bv_omega_bench
skip_proof mem_omega with [*, -h_assert_3, -h_assert_4]
}
· intros n addr hsep
apply Memcpy.extracted_2 <;> assumption
Expand Down
Loading
Loading