Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: change memory-effects theorem to a quantifier-free statement #224

Merged
merged 17 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 89 additions & 66 deletions Arm/Memory/MemoryProofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@ section MemoryProofs

open BitVec

/-! ## One byte read/write lemmas-/
namespace Memory

theorem read_write_same :
read addr (write addr v mem) = v := by
simp [read, write, store_read_over_write_same]

theorem read_write_different (h : addr1 ≠ addr2) :
read addr1 (write addr2 v s) = read addr1 s := by
simp [read, write, store_read_over_write_different (h := h)]

theorem write_write_shadow :
write addr val2 (write addr val1 s) = write addr val2 s := by
unfold write write_store; simp_all

theorem write_irrelevant :
write addr (read addr s) s = s := by
simp [read, write, store_write_irrelevant]

end Memory

----------------------------------------------------------------------
-- Key theorem: read_mem_bytes_of_write_mem_bytes_same

Expand All @@ -34,32 +55,39 @@ theorem mem_separate_preserved_second_start_addr_add_one
apply BitVec.val_nat_le 1 m 64 h0 (_ : 1 < 2^64) h1
decide

theorem read_mem_of_write_mem_bytes_different (hn1 : n <= 2^64)
(h : mem_separate addr1 addr1 addr2 (addr2 + (BitVec.ofNat 64 (n - 1)))) :
read_mem addr1 (write_mem_bytes n addr2 v s) = read_mem addr1 s := by
by_cases hn0 : n = 0
case pos => -- n = 0
subst n; simp only [write_mem_bytes]
case neg => -- n ≠ 0
have hn0' : 0 < n := by omega
induction n, hn0' using Nat.le_induction generalizing addr2 s
case base =>
have h' : addr1 ≠ addr2 := by apply mem_separate_starting_addresses_neq h
simp only [write_mem_bytes]
apply read_mem_of_write_mem_different h'
case succ =>
have h' : addr1 ≠ addr2 := by refine mem_separate_starting_addresses_neq h
rename_i m hn n_ih
simp_all only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero,
Nat.succ_ne_zero, not_false_eq_true, ne_eq,
write_mem_bytes, Nat.add_eq, Nat.add_zero]
rw [n_ih]
· rw [read_mem_of_write_mem_different h']
· omega
· rw [addr_add_one_add_m_sub_one m addr2 hn hn1]
rw [mem_separate_preserved_second_start_addr_add_one hn hn1 h]
· omega
done
theorem Memory.read_write_bytes_different (hn1 : n ≤ 2^64)
alexkeizer marked this conversation as resolved.
Show resolved Hide resolved
(h : mem_separate addr1 addr1 addr2 (addr2 + (BitVec.ofNat 64 (n - 1)))) :
read addr1 (write_bytes n addr2 v mem) = read addr1 mem := by
induction n generalizing mem addr1 addr2
case zero => simp only [write_bytes]
case succ n ih =>
have h_neq : addr1 ≠ addr2 :=
mem_separate_starting_addresses_neq h
rw [Nat.add_one_sub_one] at h
cases n
case zero =>
simp [write_bytes, read_write_different h_neq]
case succ n =>
have h_sep : mem_separate addr1 addr1 (addr2 + 1#64)
(addr2 + 1#64 + BitVec.ofNat 64 n) := by
unfold mem_separate mem_overlap at h ⊢
simp only [BitVec.sub_self, ofNat_add, Bool.or_self_right, Bool.not_or,
Bool.and_eq_true, Bool.not_eq_eq_eq_not, Bool.not_true,
decide_eq_false_iff_not, BitVec.not_le] at h ⊢
generalize hn' : BitVec.ofNat 64 n = n' at *
have : n' ≠ -1 := by bv_omega
clear hn1 ih
bv_decide
have h_neq : addr1 ≠ addr2 :=
mem_separate_starting_addresses_neq h
rw [write_bytes, ih (by omega) h_sep, Memory.read_write_different h_neq]

theorem read_mem_of_write_mem_bytes_different (hn1 : n ≤ 2^64)
(h : mem_separate addr1 addr1 addr2 (addr2 + (BitVec.ofNat 64 (n - 1)))) :
read_mem addr1 (write_mem_bytes n addr2 v s) = read_mem addr1 s := by
simp only [ArmState.read_mem_eq_mem_read,
Memory.write_mem_bytes_eq_mem_write_bytes]
exact Memory.read_write_bytes_different hn1 h

theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8))
(hn0 : Nat.succ 0 ≤ n)
Expand All @@ -69,47 +97,42 @@ theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8
· omega
done

example (s : ArmState) :
read_mem_bytes n addr s = s.mem.read_bytes n addr := by
exact Memory.State.read_mem_bytes_eq_mem_read_bytes s

@[state_simp_rules]
theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n <= 2^64) :
read_mem_bytes n addr (write_mem_bytes n addr v s) = v := by
by_cases hn0 : n = 0
case pos =>
subst n
unfold read_mem_bytes
simp only [of_length_zero]
case neg => -- n ≠ 0
have hn0' : 0 < n := by omega
induction n, hn0' using Nat.le_induction generalizing addr s
case base =>
simp only [read_mem_bytes, write_mem_bytes,
read_mem_of_write_mem_same, BitVec.cast_eq]
have l1 := BitVec.extractLsb'_eq v
simp only [Nat.reduceSucc, Nat.one_mul, Nat.succ_sub_succ_eq_sub,
Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq,
forall_const] at l1
rw [l1]
have l2 := BitVec.empty_bitvector_append_left v
simp only [Nat.reduceSucc, Nat.one_mul, Nat.zero_add,
BitVec.cast_eq, forall_const] at l2
exact l2
case succ =>
rename_i n hn n_ih
simp only [read_mem_bytes, Nat.add_eq, Nat.add_zero, write_mem_bytes]
rw [n_ih]
rw [read_mem_of_write_mem_bytes_different]
· simp only [Nat.add_eq, Nat.add_zero, read_mem_of_write_mem_same]
rw [append_byte_of_extract_rest_same_cast n v hn]
· omega
· have := mem_separate_contiguous_regions addr 0#64 (BitVec.ofNat 64 (n - 1))
simp only [Nat.reducePow, Nat.succ_sub_succ_eq_sub, Nat.sub_zero,
BitVec.sub_zero, ofNat_lt_ofNat, Nat.reduceMod,
BitVec.add_zero] at this
apply this
simp only [Nat.reducePow] at hn1
omega
· omega
· omega
done
theorem Memory.read_bytes_write_bytes_same (hn1 : n ≤ 2^64) :
read_bytes n addr (write_bytes n addr v mem) = v := by
induction n generalizing addr mem
case zero =>
simp [read_bytes, of_length_zero]
case succ n ih =>
simp only [read_bytes, write_bytes]
rw [ih (by omega)]
have h_sep :
let m := BitVec.ofNat 64 (n - 1)
mem_separate addr addr (addr + 1#64) (addr + 1#64 + m) := by
rw [← mem_separate_contiguous_regions addr 0#64 _]
· simp; rfl
· bv_omega
rw [read_write_bytes_different (by omega) h_sep, read_write_same]
apply BitVec.eq_of_getLsbD_eq
intro i
simp only [getLsbD_cast, getLsbD_append]
by_cases hi : i.val < 8
· simp [hi]
· have h₁ : i.val - 8 < n * 8 := by omega
have h₂ : 8 + (i.val - 8) = i.val := by omega
simp [hi, h₁, h₂]

@[state_simp_rules, memory_rules]
theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n ≤ 2^64) :
read_mem_bytes n addr (write_mem_bytes n addr v s) = v := by
open Memory in
rw [State.read_mem_bytes_eq_mem_read_bytes,
write_mem_bytes_eq_mem_write_bytes,
Memory.read_bytes_write_bytes_same hn1]

----------------------------------------------------------------------
-- Key theorem: read_mem_bytes_of_write_mem_bytes_different
Expand Down
20 changes: 13 additions & 7 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,12 @@ theorem read_mem_bytes_w_of_read_mem_eq
= read_mem_bytes n₁ addr₁ s₂ := by
simp only [read_mem_bytes_of_w, h]

@[state_simp_rules]
theorem mem_w_of_mem_eq {s₁ s₂ : ArmState} (h : s₁.mem = s₂.mem) (fld val) :
(w fld val s₁).mem = s₂.mem := by
unfold w;
cases fld <;> exact h

@[state_simp_rules]
theorem write_mem_bytes_program {n : Nat} (addr : BitVec 64) (bytes : BitVec (n * 8)):
(write_mem_bytes n addr bytes s).program = s.program := by
Expand Down Expand Up @@ -838,6 +844,9 @@ def read_bytes (n : Nat) (addr : BitVec 64) (m : Memory) : BitVec (n * 8) :=
have h : n' * 8 + 8 = (n' + 1) * 8 := by simp_arith
BitVec.cast h (rest ++ byte)

-- TODO (@bollu): we should drop the `State` namespace here, given that
-- this namespace is used nowhere else. Also, `ArmState.read_mem_eq_mem_read`
-- should probably live under the `Memory` namespace.
@[memory_rules]
theorem State.read_mem_bytes_eq_mem_read_bytes (s : ArmState) :
read_mem_bytes n addr s = s.mem.read_bytes n addr := by
Expand Down Expand Up @@ -1163,13 +1172,10 @@ theorem Memory.mem_eq_iff_read_mem_bytes_eq {s₁ s₂ : ArmState} :
· intro h _ _; rw[h]
· exact Memory.eq_of_read_mem_bytes_eq

theorem read_mem_bytes_write_mem_bytes_of_read_mem_eq
(h : ∀ n addr, read_mem_bytes n addr s₁ = read_mem_bytes n addr s₂)
(n₂ addr₂ val n₁ addr₁) :
read_mem_bytes n₁ addr₁ (write_mem_bytes n₂ addr₂ val s₁)
= read_mem_bytes n₁ addr₁ (write_mem_bytes n₂ addr₂ val s₂) := by
revert n₁ addr₁
simp only [← Memory.mem_eq_iff_read_mem_bytes_eq] at h ⊢
theorem mem_write_mem_bytes_of_mem_eq
(h : s₁.mem = s₂.mem) (n addr val) :
(write_mem_bytes n addr val s₁).mem
= (write_mem_bytes n addr val s₂).mem := by
simp only [memory_rules, h]

/- Helper lemma for `state_eq_iff_components_eq` -/
Expand Down
4 changes: 2 additions & 2 deletions Arm/Syntax.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import Arm.Memory.Separate

namespace ArmStateNotation

/-! We build a notation for `read_mem_bytes $n $base $s` as `$s[$base, $n]` -/
/-! We build a notation for `$s.mem.read_bytes $n $base $s` as `$s[$base, $n]` -/
shigoel marked this conversation as resolved.
Show resolved Hide resolved
@[inherit_doc read_mem_bytes]
syntax:max term noWs "[" withoutPosition(term) "," withoutPosition(term) noWs "]" : term
macro_rules | `($s[$base,$n]) => `(read_mem_bytes $n $base $s)
macro_rules | `($s[$base,$n]) => `(Memory.read_bytes $n $base (ArmState.mem $s))


/-! Notation to specify the frame condition for non-memory state components. E.g.,
Expand Down
26 changes: 2 additions & 24 deletions Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -89,37 +89,15 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
simp (config := {ground := true}) only at h_s0_pc
-- ^^ Still needed, because `gcm_gmult_v8_program.min` is somehow
-- unable to be reflected

sym_n 27
-- Epilogue
simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at *
simp only [memory_rules] at *
sym_aggregate
-- Split conjunction
repeat' apply And.intro
· -- Aggregate the memory (non)effects.
-- (FIXME) This will be tackled by `sym_aggregate` when `sym_n` and `simp_mem`
-- are merged.
simp only [*]
/-
(FIXME @bollu) `simp_mem; rfl` creates a malformed proof here. The tactic produces
no goals, but we get the following error message:

application type mismatch
Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
(Eq.mp (congrArg (Eq HTable) (Memory.State.read_mem_bytes_eq_mem_read_bytes s0))
(Eq.mp (congrArg (fun x => HTable = read_mem_bytes 256 x s0) zeroExtend_eq_of_r_gpr) h_HTable))
argument has type
HTable = Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem
but function has type
Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem = HTable →
mem_subset' (r (StateField.GPR 1#5) s0) 256 (r (StateField.GPR 1#5) s0) 256 →
Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem =
HTable.extractLsBytes (BitVec.toNat (r (StateField.GPR 1#5) s0) - BitVec.toNat (r (StateField.GPR 1#5) s0)) 256

simp_mem; rfl
-/
rw [Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate']
simp_mem
· simp_mem; rfl
alexkeizer marked this conversation as resolved.
Show resolved Hide resolved
shigoel marked this conversation as resolved.
Show resolved Hide resolved
· simp only [List.mem_cons, List.mem_singleton, not_or, and_imp]
sym_aggregate
· intro n addr h_separate
Expand Down
Loading
Loading