Skip to content

Commit

Permalink
Progress towards GCMGmultV8Sym; some cleanup of bitvec lemmas
Browse files Browse the repository at this point in the history
  • Loading branch information
shigoel committed Oct 17, 2024
1 parent 4817d24 commit 12c2cdf
Show file tree
Hide file tree
Showing 12 changed files with 313 additions and 116 deletions.
34 changes: 26 additions & 8 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,24 @@ theorem extractLsb'_eq (x : BitVec n) :
unfold extractLsb'
simp only [Nat.shiftRight_zero, ofNat_toNat, setWidth_eq]

theorem extractLsb'_zero_extractLsb'_of_le (h : len1 ≤ len2) :
BitVec.extractLsb' 0 len1 (BitVec.extractLsb' start len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega

theorem extractLsb'_extractLsb'_zero_of_le (h : start + len1 ≤ len2):
BitVec.extractLsb' start len1 (BitVec.extractLsb' 0 len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega

-- TODO: upstream
theorem extractLsb'_or (x y : BitVec w₁) (n : Nat) :
(x ||| y).extractLsb' lo n = (x.extractLsb' lo n ||| y.extractLsb' lo n) := by
Expand All @@ -604,7 +622,7 @@ protected theorem extractLsb'_of_setWidth (x : BitVec n) (h : j ≤ i) :
have q : k < i := by omega
by_cases h : decide (k ≤ j) <;> simp [q, h]

theorem BitVec.extractLsb'_append (x : BitVec n) (y : BitVec m) :
theorem extractLsb'_append (x : BitVec n) (y : BitVec m) :
(x ++ y).extractLsb' start len
= let len' := min len (m - start)
(x.extractLsb' (start - m) (len - len')
Expand All @@ -628,15 +646,15 @@ theorem BitVec.extractLsb'_append (x : BitVec n) (y : BitVec m) :
have h₅ : start - m + (↑i - (m - start)) = start + ↑i - m := by omega
simp [h₂, h₃, h₄, h₅]

theorem BitVec.cast_eq_of_heq (x : BitVec n) (y : BitVec m) (h : n = m) :
theorem cast_eq_of_heq (x : BitVec n) (y : BitVec m) (h : n = m) :
HEq x y → x.cast h = y := by
cases h; simp

theorem BitVec.cast_heq_iff (x : BitVec n) (y : BitVec m) (h : n = n') :
theorem cast_heq_iff (x : BitVec n) (y : BitVec m) (h : n = n') :
HEq (x.cast h) y ↔ HEq x y := by
cases h; simp

theorem BitVec.extractLsb'_append_right_of_le (h : start + len ≤ m)
theorem extractLsb'_append_right_of_le (h : start + len ≤ m)
(x : BitVec n) (y : BitVec m) :
(x ++ y).extractLsb' start len = y.extractLsb' start len := by
have len'_eq : min len (m - start) = len := by omega
Expand All @@ -646,12 +664,12 @@ theorem BitVec.extractLsb'_append_right_of_le (h : start + len ≤ m)
simp only [zero_width_append, heq_eq_eq, cast_heq_iff]

@[bitvec_rules]
theorem BitVec.extractLsb'_append_right (x : BitVec n) (y : BitVec m) :
theorem extractLsb'_append_right (x : BitVec n) (y : BitVec m) :
(x ++ y).extractLsb' 0 m = y := by
rw [extractLsb'_append_right_of_le (by omega), extractLsb'_eq]

@[simp]
theorem BitVec.extractLsb'_append_left_of_le (h : m ≤ start)
theorem extractLsb'_append_left_of_le (h : m ≤ start)
(x : BitVec n) (y : BitVec m) :
(x ++ y).extractLsb' start len = x.extractLsb' (start - m) len := by
have len'_eq : min len (m - start) = m - start := by omega
Expand All @@ -661,11 +679,11 @@ theorem BitVec.extractLsb'_append_left_of_le (h : m ≤ start)
simp only [append_zero_width, heq_eq_eq, cast_heq_iff, Nat.sub_zero]

@[bitvec_rules]
theorem BitVec.extractLsb'_append_left (x : BitVec n) (y : BitVec m) :
theorem extractLsb'_append_left (x : BitVec n) (y : BitVec m) :
(x ++ y).extractLsb' m n = x := by
rw [extractLsb'_append_left_of_le (by omega), Nat.sub_self, extractLsb'_eq]

theorem BitVec.extractLsb'_extractLsb'_of_le {w : Nat} (start₁ len₁ start₂ len₂)
theorem extractLsb'_extractLsb'_of_le {w : Nat} (start₁ len₁ start₂ len₂)
(h : start₂ + len₂ ≤ len₁)
(x : BitVec w) :
(x.extractLsb' start₁ len₁).extractLsb' start₂ len₂
Expand Down
11 changes: 11 additions & 0 deletions Arm/Insts/DPSFP/Advanced_simd_three_different.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def polynomial_mult_aux (i : Nat) (result : BitVec (m+n))
polynomial_mult_aux (i+1) new_res op1 op2
termination_by (m - i)

/-
Ref.:
https://developer.arm.com/documentation/ddi0602/2024-09/Shared-Pseudocode/shared-functions-vector?lang=en#impl-shared.PolynomialMult.2
bits(M+N) PolynomialMult(bits(M) op1, bits(N) op2)
result = Zeros(M+N);
extended_op2 = ZeroExtend(op2, M+N);
for i=0 to M-1
if op1<i> == '1' then
result = result EOR LSL(extended_op2, i);
return result;
-/
def polynomial_mult (op1 : BitVec m) (op2 : BitVec n) : BitVec (m+n) :=
let result := 0#(m+n)
let extended_op2 := zeroExtend (m+n) op2
Expand Down
188 changes: 165 additions & 23 deletions Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,139 @@ import Tactics.CSE
import Tactics.ClearNamed
import Arm.Memory.SeparateAutomation
import Arm.Syntax
import Correctness.ArmSpec

namespace GCMGmultV8Program
open ArmStateNotation

#genStepEqTheorems gcm_gmult_v8_program

/-
theorem vrev128_64_8_in_terms_of_rev_elems (x : BitVec 128) :
DPSFP.vrev128_64_8 x =
rev_elems 128 8 ((BitVec.setWidth 64 x) ++ (BitVec.setWidth 64 (x >>> 64))) _p1 _p2 := by
simp only [DPSFP.vrev128_64_8]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
rw [rev_elems_64_8_append_eq_rev_elems_128_8]
done
-/

theorem vrev128_64_8_in_terms_of_rev_elems (x : BitVec 128) :
DPSFP.vrev128_64_8 x =
-- rev_elems 64 8 (BitVec.setWidth 64 (x >>> 64)) _p1 _p2 ++
-- rev_elems 64 8 (BitVec.setWidth 64 x) _p3 _p4 := by
rev_elems 64 8 (BitVec.extractLsb' 64 64 x) _p1 _p2 ++
rev_elems 64 8 (BitVec.extractLsb' 0 64 x) _p3 _p4 := by
simp only [DPSFP.vrev128_64_8]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
unfold rev_vector
simp (config := {decide := true}) only [bitvec_rules, minimal_theory]
exact rfl
done

-- (TODO) Should we simply replace one function by the other here?
theorem gcm_polyval_mul_eq_polynomial_mult {x y : BitVec 128} :
GCMV8.gcm_polyval_mul x y = DPSFP.polynomial_mult x y := by
sorry

theorem eq_of_rev_elems_eq (x y : BitVec 128) (h : x = y) :
(rev_elems 128 8 x _p1 _p2 = rev_elems 128 8 y _p1 _p2) := by
congr

theorem pmull_op_e_0_eize_64_elements_1_size_128_eq (x y : BitVec 64) :
DPSFP.pmull_op 0 64 1 x y 0#128 =
DPSFP.polynomial_mult x y := by
unfold DPSFP.pmull_op
simp (config := {ground := true}) only [minimal_theory]
unfold DPSFP.pmull_op
simp (config := {ground := true}) only [minimal_theory]
simp only [state_simp_rules, bitvec_rules]
done

theorem rev_elems_128_8_eq_rev_elems_64_8_extractLsb' (x : BitVec 128) :
rev_elems 128 8 x _p1 _p2 =
rev_elems 64 8 (BitVec.extractLsb' 0 64 x) _p3 _p4 ++ rev_elems 64 8 (BitVec.extractLsb' 64 64 x) _p5 _p6 := by
repeat unfold rev_elems
simp (config := {decide := true, ground := true}) only [minimal_theory, BitVec.cast_eq]
bv_check
"lrat_files/GCMGmultV8Sym.lean-GCMGmultV8Program.rev_elems_128_8_eq_rev_elems_64_8_extractLsb'-51-2.lrat"
done

theorem rev_elems_64_8_append_eq_rev_elems_128_8 (x y : BitVec 64) :
rev_elems 64 8 x _p1 _p2 ++ rev_elems 64 8 y _p3 _p4 =
rev_elems 128 8 (y ++ x) _p5 _p6 := by
repeat unfold rev_elems
simp (config := {decide := true, ground := true}) only [minimal_theory, BitVec.cast_eq]
bv_check
"lrat_files/GCMGmultV8Sym.lean-GCMGmultV8Program.rev_elems_64_8_append_eq_rev_elems_128_8-60-2.lrat"
done

private theorem lsb_from_extractLsb'_of_append_self (x : BitVec 128) :
BitVec.extractLsb' 64 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
BitVec.extractLsb' 0 64 x := by
bv_decide
rw [BitVec.extractLsb'_append]
simp_all (config := {ground := true}) only [bitvec_rules]
congr

private theorem msb_from_extractLsb'_of_append_self (x : BitVec 128) :
BitVec.extractLsb' 0 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
BitVec.extractLsb' 64 64 x := by
rw [BitVec.extractLsb'_append]
simp_all (config := {ground := true}) only [bitvec_rules]
congr

private theorem zeroExtend_allOnes_lsh_64 :
~~~(BitVec.zeroExtend 128 (BitVec.allOnes 64) <<< 64)
= 0x0000000000000000ffffffffffffffff#128 := by
decide

private theorem zeroExtend_allOnes_lsh_0 :
~~~(BitVec.zeroExtend 128 (BitVec.allOnes 64) <<< 0) =
0xffffffffffffffff0000000000000000#128 := by
decide

private theorem BitVec.extractLsb'_64_128_of_appends (x y w z : BitVec 64) :
BitVec.extractLsb' 64 128 (x ++ y ++ (w ++ z)) =
y ++ w := by
bv_decide

private theorem BitVec.and_high_to_extractLsb'_concat (x : BitVec 128) :
x &&& 0xffffffffffffffff0000000000000000#128 = (BitVec.extractLsb' 64 64 x) ++ 0#64 := by
bv_decide

theorem extractLsb'_zero_extractLsb'_of_le (h : len1 ≤ len2) :
BitVec.extractLsb' 0 len1 (BitVec.extractLsb' start len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega

theorem extractLsb'_extractLsb'_zero_of_le (h : start + len1 ≤ len2):
BitVec.extractLsb' start len1 (BitVec.extractLsb' 0 len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega
theorem BitVec.extractLsb'_append_eq (x : BitVec (n + n)) :
BitVec.extractLsb' n n x ++ BitVec.extractLsb' 0 n x = x := by
have h1 := @BitVec.append_of_extract_general (n + n) n n x
simp only [Nat.reduceAdd, BitVec.extractLsb'_eq] at h1
have h3 : BitVec.setWidth n (x >>> n) = BitVec.extractLsb' n n x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_setWidth, Fin.is_lt, decide_True, BitVec.getLsbD_ushiftRight,
Bool.true_and, BitVec.getLsbD_extractLsb']
simp_all only


/-
(TODO) Need a lemma like the following, which breaks up a polynomial
multiplication into four constituent ones, for normalization.
-/
example :
let p := 0b11#2
let q := 0b10#2
let w := 0b01#2
let z := 0b01#2
(DPSFP.polynomial_mult
(p ++ q)
(w ++ z))
=
((DPSFP.polynomial_mult p w) ++ 0#4) ^^^
(0#4 ++ (DPSFP.polynomial_mult q z)) ^^^
(0#2 ++ (DPSFP.polynomial_mult p z) ++ 0#2) ^^^
(0#2 ++ (DPSFP.polynomial_mult q w) ++ 0#2) := by native_decide


set_option pp.deepTerms false in
set_option pp.deepTerms.threshold 50 in
Expand Down Expand Up @@ -82,7 +182,12 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
(sf, s0) ∧
-- Memory frame condition.
MEM_UNCHANGED_EXCEPT [(r (.GPR 0) s0, 16)] (sf, s0) ∧
sf[r (.GPR 0) s0, 16] = GCMV8.GCMGmultV8_alt (HTable.extractLsb' 0 128) Xi := by
sf[r (.GPR 0) s0, 16] =
rev_elems 128 8
(GCMV8.GCMGmultV8_alt
(HTable.extractLsb' 0 128)
(rev_elems 128 8 Xi (by decide) (by decide)))
(by decide) (by decide) := by
-- Prelude
simp_all only [state_simp_rules, -h_run]
simp only [Nat.reduceMul] at Xi HTable
Expand Down Expand Up @@ -143,6 +248,7 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0 + 16#64) 16 _ h_HTable.symm]
repeat sorry
simp only [h_HTable_high, h_HTable_low, ←h_Xi]
clear h_mem_sep h_run
/-
simp/ground below to reduce
(BitVec.extractLsb' 0 64
Expand All @@ -158,12 +264,48 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
-- (FIXME @bollu) cse leaves the goal unchanged here, quietly, likely due to
-- subexpressions occurring in dep. contexts. Maybe a message here would be helpful.
generalize h_Xi_rev : DPSFP.vrev128_64_8 Xi = Xi_rev
rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)] at h_Xi_rev
generalize h_Xi_upper_rev : rev_elems 64 8 (BitVec.extractLsb' 64 64 Xi) (by decide) (by decide) = Xi_upper_rev
generalize h_Xi_lower_rev : rev_elems 64 8 (BitVec.extractLsb' 0 64 Xi) (by decide) (by decide) = Xi_lower_rev
-- Simplifying the RHS
simp only [←h_HTable, GCMV8.GCMGmultV8_alt,
simp only [GCMV8.GCMGmultV8_alt,
GCMV8.lo, GCMV8.hi,
GCMV8.gcm_polyval]
repeat rw [extractLsb'_zero_extractLsb'_of_le (by decide)]
repeat rw [extractLsb'_extractLsb'_zero_of_le (by decide)]
GCMV8.gcm_polyval,
←h_HTable, ←h_Xi_rev, h_Xi_lower_rev, h_Xi_upper_rev]
simp only [pmull_op_e_0_eize_64_elements_1_size_128_eq, gcm_polyval_mul_eq_polynomial_mult]
simp only [zeroExtend_allOnes_lsh_64, zeroExtend_allOnes_lsh_0]
rw [BitVec.extractLsb'_64_128_of_appends]
rw [BitVec.xor_append]
repeat rw [BitVec.extractLsb'_append_right]
repeat rw [BitVec.extractLsb'_append_left]
repeat rw [BitVec.extractLsb'_zero_extractLsb'_of_le (by decide)]
repeat rw [BitVec.extractLsb'_extractLsb'_zero_of_le (by decide)]
rw [BitVec.and_high_to_extractLsb'_concat]
generalize h_HTable_upper : (BitVec.extractLsb' 64 64 HTable) = HTable_upper
generalize h_HTable_lower : (BitVec.extractLsb' 0 64 HTable) = HTable_lower
generalize h_term_u0u1 : (DPSFP.polynomial_mult HTable_upper Xi_upper_rev) = u0u1 at *
generalize h_term_l0l1 : (DPSFP.polynomial_mult HTable_lower Xi_lower_rev) = l0l1 at *
generalize h_term_1 : (DPSFP.polynomial_mult (BitVec.extractLsb' 128 64 HTable) (Xi_lower_rev ^^^ Xi_upper_rev) ^^^
BitVec.extractLsb' 64 128 (l0l1 ++ u0u1) ^^^
(u0u1 ^^^ l0l1)) = term_1
generalize h_term_2 : ((term_1 &&& 0xffffffffffffffff#128 ||| BitVec.zeroExtend 128 (BitVec.setWidth 64 u0u1) <<< 64) ^^^
DPSFP.polynomial_mult (BitVec.extractLsb' 0 64 u0u1) 0xc200000000000000#64)
= term_2
generalize h_term_3 : (BitVec.extractLsb' 64 128 (term_2 ++ term_2) ^^^
(BitVec.extractLsb' 64 64 l0l1 ++ 0x0#64 |||
BitVec.zeroExtend 128 (BitVec.extractLsb' 64 64 term_1) <<< 0))
= term_3
rw [@vrev128_64_8_in_terms_of_rev_elems (by decide) (by decide) (by decide) (by decide)]
rw [BitVec.extractLsb'_64_128_of_appends]
rw [@rev_elems_64_8_append_eq_rev_elems_128_8 _ _ (by decide) (by decide) (by decide) (by decide)]
apply eq_of_rev_elems_eq
rw [@rev_elems_128_8_eq_rev_elems_64_8_extractLsb' _ (by decide) (by decide) (by decide) (by decide) (by decide)]
rw [h_Xi_upper_rev, h_Xi_lower_rev]
rw [BitVec.extractLsb'_append_eq]
simp [GCMV8.gcm_polyval_red]
-- have h_reduce : (GCMV8.reduce 0x100000000000000000000000000000087#129 0x1#129) = 1#129 := by native_decide
-- simp [GCMV8.gcm_polyval_red, GCMV8.irrepoly, GCMV8.pmod, h_reduce]
-- repeat (unfold GCMV8.pmod.pmodTR; simp)

sorry
done
Expand Down
2 changes: 1 addition & 1 deletion Proofs/AES-GCM/GCMInitV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ theorem gcm_init_v8_program_correct (s0 sf : ArmState)
Nat.zero_mod, Nat.zero_add, Nat.sub_zero, Nat.mul_one, Nat.zero_mul, Nat.one_mul,
Nat.reduceSub, BitVec.reduceMul, BitVec.reduceXOr, BitVec.mul_one, Nat.add_one_sub_one,
BitVec.one_mul]
-- bv_check "GCMInitV8Sym.lean-GCMInitV8Program.gcm_init_v8_program_correct-117-4.lrat"
-- bv_check "lrat_files/GCMInitV8Sym.lean-GCMInitV8Program.gcm_init_v8_program_correct-117-4.lrat"
-- TODO: proof works in vscode but timeout in the CI -- need to investigate further
-/

Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 12c2cdf

Please sign in to comment.