diff --git a/Arm/Insts/Common.lean b/Arm/Insts/Common.lean index 71e877da..04948a16 100644 --- a/Arm/Insts/Common.lean +++ b/Arm/Insts/Common.lean @@ -611,13 +611,10 @@ example : rev_vector 32 16 8 0xaabbccdd#32 (by decide) /-- Divide bv `vector` into elements, each of size `size`. This function gets the `e`'th element from the `vector`. -/ @[state_simp_rules] -def elem_get (vector : BitVec n) (e : Nat) (size : Nat) - (h: size > 0): BitVec size := +def elem_get (vector : BitVec n) (e : Nat) (size : Nat): BitVec size := -- assert (e+1)*size <= n let lo := e * size - let hi := lo + size - 1 - have h : hi - lo + 1 = size := by simp only [hi, lo]; omega - BitVec.cast h $ extractLsb hi lo vector + extractLsb' lo size vector /-- Divide bv `vector` into elements, each of size `size`. This function sets the `e`'th element in the `vector`. -/ @@ -646,7 +643,7 @@ deriving DecidableEq, Repr export ShiftInfo (esize elements shift unsigned round accumulate) @[state_simp_rules] -def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) (h : n > 0) +def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) : BitVec n := -- assert shift > 0 let fn := if unsigned then ushiftRight else sshiftRight @@ -656,8 +653,7 @@ def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) (h : n > 0 BitVec.ofInt (n + 1) rounded else BitVec.ofInt (n + 1) value - have h₀ : n - 1 - 0 + 1 = n := by omega - BitVec.cast h₀ $ extractLsb (n-1) 0 (fn rounded_bv shift) + extractLsb' 0 n (fn rounded_bv shift) @[state_simp_rules] def Int_with_unsigned (unsigned : Bool) (value : BitVec n) : Int := @@ -669,9 +665,9 @@ def shift_right_common_aux if h : info.elements ≤ e then result else - let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize info.h - let shift_elem := RShr info.unsigned elem info.shift info.round info.h - let acc_elem := elem_get operand2 e info.esize info.h + shift_elem + let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize + let shift_elem := RShr info.unsigned elem info.shift info.round + let acc_elem := elem_get operand2 e info.esize + shift_elem let result := elem_set result e info.esize acc_elem info.h have _ : info.elements - (e + 1) < info.elements - e := by omega shift_right_common_aux (e + 1) info operand operand2 result @@ -683,6 +679,13 @@ theorem shift_le (x : Nat) (shift :Nat) : simp only [Nat.shiftRight_eq_div_pow] exact Nat.div_le_self x (2 ^ shift) +-- FIXME: should this be upstreamed? +theorem extractLsb'_ofNat (x n : Nat) (lo size : Nat) : + extractLsb' lo size (BitVec.ofNat n x) = .ofNat size ((x % 2^n) >>> lo) := by + apply eq_of_getLsbD_eq + intro ⟨i, _lt⟩ + simp [BitVec.ofNat] + @[state_simp_rules] theorem shift_right_common_aux_64_2_tff (operand : BitVec 128) (shift : Nat) (result : BitVec 128): @@ -691,8 +694,8 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128) unsigned := true, round := false, accumulate := false, h := (by omega)} operand 0#128 result = - (ushiftRight (extractLsb 127 64 operand) shift) - ++ (ushiftRight (extractLsb 63 0 operand) shift) := by + (ushiftRight (extractLsb' 64 64 operand) shift) + ++ (ushiftRight (extractLsb' 0 64 operand) shift) := by unfold shift_right_common_aux simp only [minimal_theory, bitvec_rules] unfold shift_right_common_aux @@ -728,16 +731,17 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128) -- Eliminating casting functions Int.ofNat_eq_coe, ofInt_natCast, ofNat_toNat ] - generalize (extractLsb 127 64 operand) = x; simp at x - generalize (extractLsb 63 0 operand) = y; simp at y - have h0 : ∀ (z : BitVec 64), extractLsb 63 0 ((zeroExtend 65 z).ushiftRight shift) + simp only [reduceExtracLsb', BitVec.zero_add] + generalize (extractLsb' 64 64 operand) = x + generalize (extractLsb' 0 64 operand) = y + have h0 : ∀ (z : BitVec 64), extractLsb' 0 64 ((zeroExtend 65 z).ushiftRight shift) = z.ushiftRight shift := by intro z simp only [ushiftRight, toNat_truncate] have h1: z.toNat % 2 ^ 65 = z.toNat := by omega simp only [h1] simp only [Std.Tactic.BVDecide.Normalize.BitVec.ofNatLt_reduce] - simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.extractLsb_ofNat, Nat.shiftRight_zero] + simp only [Nat.sub_zero, Nat.reduceAdd, extractLsb'_ofNat, Nat.shiftRight_zero] have h2 : z.toNat >>> shift % 2 ^ 65 = z.toNat >>> shift := by refine Nat.mod_eq_of_lt ?h3 have h4 : z.toNat >>> shift ≤ z.toNat := by exact shift_le z.toNat shift @@ -780,10 +784,10 @@ theorem shift_right_common_aux_32_4_fff (operand : BitVec 128) unsigned := false, round := false, accumulate := false, h := (by omega) } operand 0#128 result = - (sshiftRight (extractLsb 127 96 operand) shift) - ++ (sshiftRight (extractLsb 95 64 operand) shift) - ++ (sshiftRight (extractLsb 63 32 operand) shift) - ++ (sshiftRight (extractLsb 31 0 operand) shift) := by + (sshiftRight (extractLsb' 96 32 operand) shift) + ++ (sshiftRight (extractLsb' 64 32 operand) shift) + ++ (sshiftRight (extractLsb' 32 32 operand) shift) + ++ (sshiftRight (extractLsb' 0 32 operand) shift) := by unfold shift_right_common_aux simp only [minimal_theory, bitvec_rules] unfold shift_right_common_aux @@ -823,20 +827,19 @@ theorem shift_right_common_aux_32_4_fff (operand : BitVec 128) -- Eliminating casting functions ofInt_eq_signExtend ] - generalize extractLsb 31 0 operand = a; simp at a - generalize extractLsb 63 32 operand = b; simp at b - generalize extractLsb 95 64 operand = c; simp at c - generalize extractLsb 127 96 operand = d; simp at d + generalize extractLsb' 0 32 operand = a + generalize extractLsb' 32 32 operand = b + generalize extractLsb' 64 32 operand = c + generalize extractLsb' 96 32 operand = d have h : ∀ (x : BitVec 32), - extractLsb 31 0 ((signExtend 33 x).sshiftRight shift) + extractLsb' 0 32 ((signExtend 33 x).sshiftRight shift) = x.sshiftRight shift := by intros x apply eq_of_getLsbD_eq; intros i; simp at i simp only [getLsbD_sshiftRight] - simp only [Nat.sub_zero, Nat.reduceAdd, getLsbD_extract, Nat.zero_add, - getLsbD_sshiftRight, getLsbD_signExtend] - simp only [show (i : Nat) ≤ 31 by omega, - decide_True, Bool.true_and] + simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, + Nat.zero_add, getLsbD_sshiftRight, + getLsbD_signExtend, Bool.true_and] simp only [show ¬33 ≤ (i : Nat) by omega, decide_False, Bool.not_false, Bool.true_and] simp only [show ¬32 ≤ (i : Nat) by omega, @@ -876,7 +879,7 @@ def shift_left_common_aux if h : info.elements ≤ e then result else - let elem := elem_get operand e info.esize info.h + let elem := elem_get operand e info.esize let shift_elem := elem <<< info.shift let result := elem_set result e info.esize shift_elem info.h have _ : info.elements - (e + 1) < info.elements - e := by omega @@ -891,8 +894,8 @@ theorem shift_left_common_aux_64_2 (operand : BitVec 128) unsigned := unsigned, round := round, accumulate := accumulate, h := (by omega)} operand result = - (extractLsb 127 64 operand <<< shift) - ++ (extractLsb 63 0 operand <<< shift) := by + (extractLsb' 64 64 operand <<< shift) + ++ (extractLsb' 0 64 operand <<< shift) := by unfold shift_left_common_aux simp only [minimal_theory, bitvec_rules] unfold shift_left_common_aux diff --git a/Arm/Insts/DPSFP/Advanced_simd_copy.lean b/Arm/Insts/DPSFP/Advanced_simd_copy.lean index aaa094f1..a2146a9e 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_copy.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_copy.lean @@ -44,7 +44,7 @@ def exec_dup_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState : let elements := datasize / esize let operand := read_sfp idxdsize inst.Rn s have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide) - let element := elem_get operand index esize h₀ + let element := elem_get operand index esize let result := dup_aux 0 elements esize element (BitVec.zero datasize) h₀ -- State Updates let s := write_pc ((read_pc s) + 4#64) s @@ -81,7 +81,7 @@ def exec_ins_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState : let operand := read_sfp idxdsize inst.Rn s let result := read_sfp 128 inst.Rd s have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide) - let elem := elem_get operand src_index esize h₀ + let elem := elem_get operand src_index esize let result := elem_set result dst_index esize elem h₀ -- State Updates let s := write_pc ((read_pc s) + 4#64) s @@ -123,7 +123,7 @@ def exec_smov_umov (inst : Advanced_simd_copy_cls) (s : ArmState) (signed : Bool -- if index == 0 then CheckFPEnabled64 else CheckFPAdvSIMDEnabled64 let operand := read_sfp idxdsize inst.Rn s have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide) - let element := elem_get operand index esize h₀ + let element := elem_get operand index esize let result := if signed then signExtend datasize element else zeroExtend datasize element -- State Updates let s := write_pc ((read_pc s) + 4#64) s diff --git a/Arm/Insts/DPSFP/Advanced_simd_permute.lean b/Arm/Insts/DPSFP/Advanced_simd_permute.lean index 5ccc6bff..95a2920c 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_permute.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_permute.lean @@ -22,8 +22,8 @@ def trn_aux (p : Nat) (pairs : Nat) (esize : Nat) (part : Nat) result else let idx_from := 2 * p + part - let op1_part := elem_get operand1 idx_from esize h - let op2_part := elem_get operand2 idx_from esize h + let op1_part := elem_get operand1 idx_from esize + let op2_part := elem_get operand2 idx_from esize let result := elem_set result (2 * p) esize op1_part h let result := elem_set result (2 * p + 1) esize op2_part h have h₁ : pairs - (p + 1) < pairs - p := by omega diff --git a/Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean b/Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean index 3f2b25c2..effdd228 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean @@ -27,8 +27,7 @@ def exec_advanced_simd_scalar_copy let idxdsize := 64 <<< (lsb inst.imm5 4).toNat let esize := 8 <<< size let operand := read_sfp idxdsize inst.Rn s - have h : esize > 0 := by apply zero_lt_shift_left_pos (by decide) - let result := elem_get operand index.toNat esize h + let result := elem_get operand index.toNat esize -- State Updates let s := write_pc ((read_pc s) + 4#64) s let s := write_sfp esize inst.Rd result s diff --git a/Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean b/Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean index b032904c..16024f23 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean @@ -35,10 +35,10 @@ def tblx_aux (i : Nat) (elements : Nat) (indices : BitVec datasize) result else have h₁ : 8 > 0 := by decide - let index := (elem_get indices i 8 h₁).toNat + let index := (elem_get indices i 8).toNat let result := if index < 16 * regs then - let val := elem_get table index 8 h₁ + let val := elem_get table index 8 elem_set result i 8 val h₁ else result diff --git a/Arm/Insts/DPSFP/Advanced_simd_three_different.lean b/Arm/Insts/DPSFP/Advanced_simd_three_different.lean index 53800ea8..5a64fe3f 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_three_different.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_three_different.lean @@ -37,8 +37,8 @@ def pmull_op (e : Nat) (esize : Nat) (elements : Nat) (x : BitVec n) if h₀ : elements <= e then result else - let element1 := elem_get x e esize H - let element2 := elem_get y e esize H + let element1 := elem_get x e esize + let element2 := elem_get y e esize let elem_result := polynomial_mult element1 element2 have h₁ : esize + esize = 2 * esize := by omega have h₂ : 2 * esize > 0 := by omega diff --git a/Arm/Insts/DPSFP/Advanced_simd_three_same.lean b/Arm/Insts/DPSFP/Advanced_simd_three_same.lean index a532ac04..7da3cc42 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_three_same.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_three_same.lean @@ -25,8 +25,8 @@ def binary_vector_op_aux (e : Nat) (elems : Nat) (esize : Nat) result else have h₁ : e < elems := by omega - let element1 := elem_get x e esize H - let element2 := elem_get y e esize H + let element1 := elem_get x e esize + let element2 := elem_get y e esize let elem_result := op element1 element2 let result := elem_set result e esize elem_result H have ht1 : elems - (e + 1) < elems - e := by omega diff --git a/Proofs/AES-GCM/GCMInitV8Sym.lean b/Proofs/AES-GCM/GCMInitV8Sym.lean index b4d287b3..dcfe8edc 100644 --- a/Proofs/AES-GCM/GCMInitV8Sym.lean +++ b/Proofs/AES-GCM/GCMInitV8Sym.lean @@ -94,5 +94,6 @@ theorem gcm_init_v8_program_correct (s0 sf : ArmState) , shift_right_common_aux_32_4_fff , DPSFP.AdvSIMDExpandImm , DPSFP.dup_aux_0_4_32] + simp only [BitVec.extractLsb'_eq_extractLsb.symm] generalize read_mem_bytes 16 (r (StateField.GPR 1#5) s0) s0 = Hinit sorry