Skip to content

Commit

Permalink
Merge branch 'main' into register-frame-condition-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
shigoel authored Oct 17, 2024
2 parents e90accd + 1b67f50 commit e3e31ad
Show file tree
Hide file tree
Showing 28 changed files with 695 additions and 317,445 deletions.
622 changes: 375 additions & 247 deletions Arm/BitVec.lean

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions Arm/Insts/DPSFP/Advanced_simd_three_different.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def polynomial_mult_aux (i : Nat) (result : BitVec (m+n))
polynomial_mult_aux (i+1) new_res op1 op2
termination_by (m - i)

/-
Ref.:
https://developer.arm.com/documentation/ddi0602/2024-09/Shared-Pseudocode/shared-functions-vector?lang=en#impl-shared.PolynomialMult.2
bits(M+N) PolynomialMult(bits(M) op1, bits(N) op2)
result = Zeros(M+N);
extended_op2 = ZeroExtend(op2, M+N);
for i=0 to M-1
if op1<i> == '1' then
result = result EOR LSL(extended_op2, i);
return result;
-/
def polynomial_mult (op1 : BitVec m) (op2 : BitVec n) : BitVec (m+n) :=
let result := 0#(m+n)
let extended_op2 := zeroExtend (m+n) op2
Expand Down
18 changes: 18 additions & 0 deletions Arm/Insts/DPSFP/Advanced_simd_three_same.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@ def binary_vector_op_aux (e : Nat) (elems : Nat) (esize : Nat)
binary_vector_op_aux (e + 1) elems esize op x y result
termination_by (elems - e)

theorem binary_vector_op_aux_of_lt {n} {e elems} (h : e < elems) (esize op)
(x y result : BitVec n) :
binary_vector_op_aux e elems esize op x y result
= let element1 := elem_get x e esize
let element2 := elem_get y e esize
let elem_result := op element1 element2
let result := elem_set result e esize elem_result
binary_vector_op_aux (e + 1) elems esize op x y result := by
conv => { lhs; unfold binary_vector_op_aux }
have : ¬(elems ≤ e) := by omega
simp only [this, ↓reduceIte]

theorem binary_vector_op_aux_of_not_lt {n} {e elems} (h : ¬(e < elems))
(esize op) (x y result : BitVec n) :
binary_vector_op_aux e elems esize op x y result = result := by
unfold binary_vector_op_aux
simp only [ite_eq_left_iff, Nat.not_le, h, false_implies]

/--
Perform pairwise op on esize-bit slices of x and y
-/
Expand Down
4 changes: 2 additions & 2 deletions Arm/Memory/MemoryProofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ theorem read_mem_bytes_of_write_mem_bytes_subset_helper2
simp_all only [l1, decide_True, Bool.true_and, Nat.add_mod_mod]
rw [read_mem_bytes_of_write_mem_bytes_subset_helper1] <;> assumption
case neg =>
simp only [h₀, BitVec.bitvec_to_nat_of_nat, BitVec.toNat_append, Nat.testBit_or]
simp only [h₀, BitVec.toNat_ofNat, BitVec.toNat_append, Nat.testBit_or]
simp only [Nat.testBit_shiftLeft, Nat.testBit_mod_two_pow]
by_cases h₁ : (i < 8)
case pos => -- (i < 8)
Expand Down Expand Up @@ -948,7 +948,7 @@ private theorem write_mem_bytes_irrelevant_helper (h : n * 8 + 8 = (n + 1) * 8)
((BitVec.cast h (read_mem_bytes n (addr + 1#64) s ++ read_mem addr s)) >>> 8)) =
read_mem_bytes n (addr + 1#64) s := by
ext
simp [ushiftRight, ShiftRight.shiftRight, BitVec.bitvec_to_nat_of_nat]
simp [ushiftRight, ShiftRight.shiftRight, BitVec.toNat_ofNat]
have h_x_size := (read_mem_bytes n (addr + 1#64) s).isLt
have h_y_size := (read_mem addr s).isLt
generalize h_x : (BitVec.toNat (read_mem_bytes n (addr + 1#64) s)) = x
Expand Down
12 changes: 6 additions & 6 deletions Arm/Memory/SeparateProofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ theorem n_minus_1_lt_2_64_1 (n : Nat)
(h1 : Nat.succ 0 ≤ n) (h2 : n < 2 ^ 64) :
(BitVec.ofNat 64 (n - 1)) < (BitVec.ofNat 64 (2^64 - 1)) := by
refine BitVec.val_bitvec_lt.mp ?a
simp [BitVec.bitvec_to_nat_of_nat]
simp [BitVec.toNat_ofNat]
have : n - 1 < 2 ^ 64 := by omega
simp_all [Nat.mod_eq_of_lt]
exact Nat.sub_lt_left_of_lt_add h1 h2
Expand Down Expand Up @@ -175,26 +175,26 @@ theorem first_addresses_add_one_preserves_subset_same_addr
rw [h3]
apply first_addresses_add_one_preserves_subset_same_addr_helper
rw [←BitVec.val_bitvec_lt]
simp [BitVec.bitvec_to_nat_of_nat]
simp [BitVec.toNat_ofNat]
simp_all [Nat.mod_eq_of_lt]
case inr =>
rename_i h3
have ⟨h3_0, h3_1⟩ := h3
rw [BitVec.add_sub_self_left_64] at h3_0
rw [BitVec.add_sub_self_left_64] at h3_0
rw [←BitVec.nat_bitvec_le] at h3_0
simp_all [BitVec.bitvec_to_nat_of_nat, Nat.mod_eq_of_lt]
simp_all [BitVec.toNat_ofNat, Nat.mod_eq_of_lt]
apply (BitVec.nat_bitvec_le ((BitVec.ofNat 64 m) - 1#64) ((BitVec.ofNat 64 n) - 1#64)).mp
rw [nat_bitvec_sub1]; rw [nat_bitvec_sub1]
simp [BitVec.bitvec_to_nat_of_nat, Nat.mod_eq_of_lt]
simp [BitVec.toNat_ofNat, Nat.mod_eq_of_lt]
· rw [Nat.mod_eq_of_lt h1u]
rw [Nat.mod_eq_of_lt h2u]
rw [Nat.mod_eq_of_lt (by omega)]
rw [Nat.mod_eq_of_lt (by omega)]
exact Nat.sub_le_sub_right h3_0 1
· simp [BitVec.bitvec_to_nat_of_nat, Nat.mod_eq_of_lt, h2u]
· simp [BitVec.toNat_ofNat, Nat.mod_eq_of_lt, h2u]
exact h2l
· simp [BitVec.bitvec_to_nat_of_nat, Nat.mod_eq_of_lt, h1u]
· simp [BitVec.toNat_ofNat, Nat.mod_eq_of_lt, h1u]
exact h1l
case right =>
rw [BitVec.add_sub_add_left]
Expand Down
188 changes: 165 additions & 23 deletions Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,139 @@ import Tactics.CSE
import Tactics.ClearNamed
import Arm.Memory.SeparateAutomation
import Arm.Syntax
import Correctness.ArmSpec

namespace GCMGmultV8Program
open ArmStateNotation

#genStepEqTheorems gcm_gmult_v8_program

/-
theorem vrev128_64_8_in_terms_of_rev_elems (x : BitVec 128) :
DPSFP.vrev128_64_8 x =
rev_elems 128 8 ((BitVec.setWidth 64 x) ++ (BitVec.setWidth 64 (x >>> 64))) _p1 _p2 := by
simp only [DPSFP.vrev128_64_8]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
rw [rev_elems_64_8_append_eq_rev_elems_128_8]
done
-/

theorem vrev128_64_8_in_terms_of_rev_elems (x : BitVec 128) :
DPSFP.vrev128_64_8 x =
-- rev_elems 64 8 (BitVec.setWidth 64 (x >>> 64)) _p1 _p2 ++
-- rev_elems 64 8 (BitVec.setWidth 64 x) _p3 _p4 := by
rev_elems 64 8 (BitVec.extractLsb' 64 64 x) _p1 _p2 ++
rev_elems 64 8 (BitVec.extractLsb' 0 64 x) _p3 _p4 := by
simp only [DPSFP.vrev128_64_8]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
exact rfl
done

-- (TODO) Should we simply replace one function by the other here?
theorem gcm_polyval_mul_eq_polynomial_mult {x y : BitVec 128} :
GCMV8.gcm_polyval_mul x y = DPSFP.polynomial_mult x y := by
sorry

theorem eq_of_rev_elems_eq (x y : BitVec 128) (h : x = y) :
(rev_elems 128 8 x _p1 _p2 = rev_elems 128 8 y _p1 _p2) := by
congr

theorem pmull_op_e_0_eize_64_elements_1_size_128_eq (x y : BitVec 64) :
DPSFP.pmull_op 0 64 1 x y 0#128 =
DPSFP.polynomial_mult x y := by
unfold DPSFP.pmull_op
simp (config := {ground := true}) only [minimal_theory]
unfold DPSFP.pmull_op
simp (config := {ground := true}) only [minimal_theory]
simp only [state_simp_rules, bitvec_rules]
done

theorem rev_elems_128_8_eq_rev_elems_64_8_extractLsb' (x : BitVec 128) :
rev_elems 128 8 x _p1 _p2 =
rev_elems 64 8 (BitVec.extractLsb' 0 64 x) _p3 _p4 ++ rev_elems 64 8 (BitVec.extractLsb' 64 64 x) _p5 _p6 := by
repeat unfold rev_elems
simp (config := {decide := true, ground := true}) only [minimal_theory, BitVec.cast_eq]
bv_check
"lrat_files/GCMGmultV8Sym.lean-GCMGmultV8Program.rev_elems_128_8_eq_rev_elems_64_8_extractLsb'-51-2.lrat"
done

theorem rev_elems_64_8_append_eq_rev_elems_128_8 (x y : BitVec 64) :
rev_elems 64 8 x _p1 _p2 ++ rev_elems 64 8 y _p3 _p4 =
rev_elems 128 8 (y ++ x) _p5 _p6 := by
repeat unfold rev_elems
simp (config := {decide := true, ground := true}) only [minimal_theory, BitVec.cast_eq]
bv_check
"lrat_files/GCMGmultV8Sym.lean-GCMGmultV8Program.rev_elems_64_8_append_eq_rev_elems_128_8-60-2.lrat"
done

private theorem lsb_from_extractLsb'_of_append_self (x : BitVec 128) :
BitVec.extractLsb' 64 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
BitVec.extractLsb' 0 64 x := by
bv_decide
rw [BitVec.extractLsb'_append]
simp_all (config := {ground := true}) only [bitvec_rules]
congr

private theorem msb_from_extractLsb'_of_append_self (x : BitVec 128) :
BitVec.extractLsb' 0 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
BitVec.extractLsb' 64 64 x := by
rw [BitVec.extractLsb'_append]
simp_all (config := {ground := true}) only [bitvec_rules]
congr

private theorem zeroExtend_allOnes_lsh_64 :
~~~(BitVec.zeroExtend 128 (BitVec.allOnes 64) <<< 64)
= 0x0000000000000000ffffffffffffffff#128 := by
decide

private theorem zeroExtend_allOnes_lsh_0 :
~~~(BitVec.zeroExtend 128 (BitVec.allOnes 64) <<< 0) =
0xffffffffffffffff0000000000000000#128 := by
decide

private theorem BitVec.extractLsb'_64_128_of_appends (x y w z : BitVec 64) :
BitVec.extractLsb' 64 128 (x ++ y ++ (w ++ z)) =
y ++ w := by
bv_decide

private theorem BitVec.and_high_to_extractLsb'_concat (x : BitVec 128) :
x &&& 0xffffffffffffffff0000000000000000#128 = (BitVec.extractLsb' 64 64 x) ++ 0#64 := by
bv_decide

theorem extractLsb'_zero_extractLsb'_of_le (h : len1 ≤ len2) :
BitVec.extractLsb' 0 len1 (BitVec.extractLsb' start len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega

theorem extractLsb'_extractLsb'_zero_of_le (h : start + len1 ≤ len2):
BitVec.extractLsb' start len1 (BitVec.extractLsb' 0 len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega
theorem BitVec.extractLsb'_append_eq (x : BitVec (n + n)) :
BitVec.extractLsb' n n x ++ BitVec.extractLsb' 0 n x = x := by
have h1 := @BitVec.append_of_extract_general (n + n) n n x
simp only [Nat.reduceAdd, BitVec.extractLsb'_eq] at h1
have h3 : BitVec.setWidth n (x >>> n) = BitVec.extractLsb' n n x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_setWidth, Fin.is_lt, decide_True, BitVec.getLsbD_ushiftRight,
Bool.true_and, BitVec.getLsbD_extractLsb']
simp_all only


/-
(TODO) Need a lemma like the following, which breaks up a polynomial
multiplication into four constituent ones, for normalization.
-/
example :
let p := 0b11#2
let q := 0b10#2
let w := 0b01#2
let z := 0b01#2
(DPSFP.polynomial_mult
(p ++ q)
(w ++ z))
=
((DPSFP.polynomial_mult p w) ++ 0#4) ^^^
(0#4 ++ (DPSFP.polynomial_mult q z)) ^^^
(0#2 ++ (DPSFP.polynomial_mult p z) ++ 0#2) ^^^
(0#2 ++ (DPSFP.polynomial_mult q w) ++ 0#2) := by native_decide


set_option pp.deepTerms false in
set_option pp.deepTerms.threshold 50 in
Expand Down Expand Up @@ -82,7 +182,12 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
(sf, s0) ∧
-- Memory frame condition.
MEM_UNCHANGED_EXCEPT [(r (.GPR 0) s0, 16)] (sf, s0) ∧
sf[r (.GPR 0) s0, 16] = GCMV8.GCMGmultV8_alt (HTable.extractLsb' 0 128) Xi := by
sf[r (.GPR 0) s0, 16] =
rev_elems 128 8
(GCMV8.GCMGmultV8_alt
(HTable.extractLsb' 0 128)
(rev_elems 128 8 Xi (by decide) (by decide)))
(by decide) (by decide) := by
-- Prelude
simp_all only [state_simp_rules, -h_run]
simp only [Nat.reduceMul] at Xi HTable
Expand Down Expand Up @@ -121,6 +226,7 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0 + 16#64) 16 _ h_HTable.symm]
repeat sorry
simp only [h_HTable_high, h_HTable_low, ←h_Xi]
clear h_mem_sep h_run
/-
simp/ground below to reduce
(BitVec.extractLsb' 0 64
Expand All @@ -136,12 +242,48 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
-- (FIXME @bollu) cse leaves the goal unchanged here, quietly, likely due to
-- subexpressions occurring in dep. contexts. Maybe a message here would be helpful.
generalize h_Xi_rev : DPSFP.vrev128_64_8 Xi = Xi_rev
rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)] at h_Xi_rev
generalize h_Xi_upper_rev : rev_elems 64 8 (BitVec.extractLsb' 64 64 Xi) (by decide) (by decide) = Xi_upper_rev
generalize h_Xi_lower_rev : rev_elems 64 8 (BitVec.extractLsb' 0 64 Xi) (by decide) (by decide) = Xi_lower_rev
-- Simplifying the RHS
simp only [←h_HTable, GCMV8.GCMGmultV8_alt,
simp only [GCMV8.GCMGmultV8_alt,
GCMV8.lo, GCMV8.hi,
GCMV8.gcm_polyval]
repeat rw [extractLsb'_zero_extractLsb'_of_le (by decide)]
repeat rw [extractLsb'_extractLsb'_zero_of_le (by decide)]
GCMV8.gcm_polyval,
←h_HTable, ←h_Xi_rev, h_Xi_lower_rev, h_Xi_upper_rev]
simp only [pmull_op_e_0_eize_64_elements_1_size_128_eq, gcm_polyval_mul_eq_polynomial_mult]
simp only [zeroExtend_allOnes_lsh_64, zeroExtend_allOnes_lsh_0]
rw [BitVec.extractLsb'_64_128_of_appends]
rw [BitVec.xor_append]
repeat rw [BitVec.extractLsb'_append_right]
repeat rw [BitVec.extractLsb'_append_left]
repeat rw [BitVec.extractLsb'_zero_extractLsb'_of_le (by decide)]
repeat rw [BitVec.extractLsb'_extractLsb'_zero_of_le (by decide)]
rw [BitVec.and_high_to_extractLsb'_concat]
generalize h_HTable_upper : (BitVec.extractLsb' 64 64 HTable) = HTable_upper
generalize h_HTable_lower : (BitVec.extractLsb' 0 64 HTable) = HTable_lower
generalize h_term_u0u1 : (DPSFP.polynomial_mult HTable_upper Xi_upper_rev) = u0u1 at *
generalize h_term_l0l1 : (DPSFP.polynomial_mult HTable_lower Xi_lower_rev) = l0l1 at *
generalize h_term_1 : (DPSFP.polynomial_mult (BitVec.extractLsb' 128 64 HTable) (Xi_lower_rev ^^^ Xi_upper_rev) ^^^
BitVec.extractLsb' 64 128 (l0l1 ++ u0u1) ^^^
(u0u1 ^^^ l0l1)) = term_1
generalize h_term_2 : ((term_1 &&& 0xffffffffffffffff#128 ||| BitVec.zeroExtend 128 (BitVec.setWidth 64 u0u1) <<< 64) ^^^
DPSFP.polynomial_mult (BitVec.extractLsb' 0 64 u0u1) 0xc200000000000000#64)
= term_2
generalize h_term_3 : (BitVec.extractLsb' 64 128 (term_2 ++ term_2) ^^^
(BitVec.extractLsb' 64 64 l0l1 ++ 0x0#64 |||
BitVec.zeroExtend 128 (BitVec.extractLsb' 64 64 term_1) <<< 0))
= term_3
rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)]
rw [BitVec.extractLsb'_64_128_of_appends]
rw [@rev_elems_64_8_append_eq_rev_elems_128_8 _ _ (by decide) (by decide) (by decide) (by decide)]
apply eq_of_rev_elems_eq
rw [@rev_elems_128_8_eq_rev_elems_64_8_extractLsb' _ (by decide) (by decide) (by decide) (by decide) (by decide)]
rw [h_Xi_upper_rev, h_Xi_lower_rev]
rw [BitVec.extractLsb'_append_eq]
simp [GCMV8.gcm_polyval_red]
-- have h_reduce : (GCMV8.reduce 0x100000000000000000000000000000087#129 0x1#129) = 1#129 := by native_decide
-- simp [GCMV8.gcm_polyval_red, GCMV8.irrepoly, GCMV8.pmod, h_reduce]
-- repeat (unfold GCMV8.pmod.pmodTR; simp)

sorry
done
Expand Down
2 changes: 1 addition & 1 deletion Proofs/AES-GCM/GCMInitV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ theorem gcm_init_v8_program_correct (s0 sf : ArmState)
Nat.zero_mod, Nat.zero_add, Nat.sub_zero, Nat.mul_one, Nat.zero_mul, Nat.one_mul,
Nat.reduceSub, BitVec.reduceMul, BitVec.reduceXOr, BitVec.mul_one, Nat.add_one_sub_one,
BitVec.one_mul]
-- bv_check "GCMInitV8Sym.lean-GCMInitV8Program.gcm_init_v8_program_correct-117-4.lrat"
-- bv_check "lrat_files/GCMInitV8Sym.lean-GCMInitV8Program.gcm_init_v8_program_correct-117-4.lrat"
-- TODO: proof works in vscode but timeout in the CI -- need to investigate further
-/

Binary file not shown.
Loading

0 comments on commit e3e31ad

Please sign in to comment.