Skip to content

Commit

Permalink
Switch to using extractLsb' to avoid casting in the goal
Browse files Browse the repository at this point in the history
  • Loading branch information
pennyannn committed Sep 20, 2024
1 parent 1bcf287 commit 0d3f545
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 46 deletions.
69 changes: 36 additions & 33 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -611,13 +611,10 @@ example : rev_vector 32 16 8 0xaabbccdd#32 (by decide)
/-- Divide bv `vector` into elements, each of size `size`. This function gets
the `e`'th element from the `vector`. -/
@[state_simp_rules]
def elem_get (vector : BitVec n) (e : Nat) (size : Nat)
(h: size > 0): BitVec size :=
def elem_get (vector : BitVec n) (e : Nat) (size : Nat): BitVec size :=
-- assert (e+1)*size <= n
let lo := e * size
let hi := lo + size - 1
have h : hi - lo + 1 = size := by simp only [hi, lo]; omega
BitVec.cast h $ extractLsb hi lo vector
extractLsb' lo size vector

/-- Divide bv `vector` into elements, each of size `size`. This function sets
the `e`'th element in the `vector`. -/
Expand Down Expand Up @@ -646,7 +643,7 @@ deriving DecidableEq, Repr
export ShiftInfo (esize elements shift unsigned round accumulate)

@[state_simp_rules]
def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) (h : n > 0)
def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool)
: BitVec n :=
-- assert shift > 0
let fn := if unsigned then ushiftRight else sshiftRight
Expand All @@ -656,8 +653,7 @@ def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) (h : n > 0
BitVec.ofInt (n + 1) rounded
else
BitVec.ofInt (n + 1) value
have h₀ : n - 1 - 0 + 1 = n := by omega
BitVec.cast h₀ $ extractLsb (n-1) 0 (fn rounded_bv shift)
extractLsb' 0 n (fn rounded_bv shift)

@[state_simp_rules]
def Int_with_unsigned (unsigned : Bool) (value : BitVec n) : Int :=
Expand All @@ -669,9 +665,9 @@ def shift_right_common_aux
if h : info.elements ≤ e then
result
else
let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize info.h
let shift_elem := RShr info.unsigned elem info.shift info.round info.h
let acc_elem := elem_get operand2 e info.esize info.h + shift_elem
let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize
let shift_elem := RShr info.unsigned elem info.shift info.round
let acc_elem := elem_get operand2 e info.esize + shift_elem
let result := elem_set result e info.esize acc_elem info.h
have _ : info.elements - (e + 1) < info.elements - e := by omega
shift_right_common_aux (e + 1) info operand operand2 result
Expand All @@ -683,6 +679,13 @@ theorem shift_le (x : Nat) (shift :Nat) :
simp only [Nat.shiftRight_eq_div_pow]
exact Nat.div_le_self x (2 ^ shift)

-- FIXME: should this be upstreamed?
theorem extractLsb'_ofNat (x n : Nat) (lo size : Nat) :
extractLsb' lo size (BitVec.ofNat n x) = .ofNat size ((x % 2^n) >>> lo) := by
apply eq_of_getLsbD_eq
intro ⟨i, _lt⟩
simp [BitVec.ofNat]

@[state_simp_rules]
theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
(shift : Nat) (result : BitVec 128):
Expand All @@ -691,8 +694,8 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
unsigned := true, round := false, accumulate := false,
h := (by omega)}
operand 0#128 result =
(ushiftRight (extractLsb 127 64 operand) shift)
++ (ushiftRight (extractLsb 63 0 operand) shift) := by
(ushiftRight (extractLsb' 64 64 operand) shift)
++ (ushiftRight (extractLsb' 0 64 operand) shift) := by
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
Expand Down Expand Up @@ -728,16 +731,17 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
-- Eliminating casting functions
Int.ofNat_eq_coe, ofInt_natCast, ofNat_toNat
]
generalize (extractLsb 127 64 operand) = x; simp at x
generalize (extractLsb 63 0 operand) = y; simp at y
have h0 : ∀ (z : BitVec 64), extractLsb 63 0 ((zeroExtend 65 z).ushiftRight shift)
simp only [reduceExtracLsb', BitVec.zero_add]
generalize (extractLsb' 64 64 operand) = x
generalize (extractLsb' 0 64 operand) = y
have h0 : ∀ (z : BitVec 64), extractLsb' 0 64 ((zeroExtend 65 z).ushiftRight shift)
= z.ushiftRight shift := by
intro z
simp only [ushiftRight, toNat_truncate]
have h1: z.toNat % 2 ^ 65 = z.toNat := by omega
simp only [h1]
simp only [Std.Tactic.BVDecide.Normalize.BitVec.ofNatLt_reduce]
simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.extractLsb_ofNat, Nat.shiftRight_zero]
simp only [Nat.sub_zero, Nat.reduceAdd, extractLsb'_ofNat, Nat.shiftRight_zero]
have h2 : z.toNat >>> shift % 2 ^ 65 = z.toNat >>> shift := by
refine Nat.mod_eq_of_lt ?h3
have h4 : z.toNat >>> shift ≤ z.toNat := by exact shift_le z.toNat shift
Expand Down Expand Up @@ -780,10 +784,10 @@ theorem shift_right_common_aux_32_4_fff (operand : BitVec 128)
unsigned := false, round := false, accumulate := false,
h := (by omega) }
operand 0#128 result =
(sshiftRight (extractLsb 127 96 operand) shift)
++ (sshiftRight (extractLsb 95 64 operand) shift)
++ (sshiftRight (extractLsb 63 32 operand) shift)
++ (sshiftRight (extractLsb 31 0 operand) shift) := by
(sshiftRight (extractLsb' 96 32 operand) shift)
++ (sshiftRight (extractLsb' 64 32 operand) shift)
++ (sshiftRight (extractLsb' 32 32 operand) shift)
++ (sshiftRight (extractLsb' 0 32 operand) shift) := by
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
Expand Down Expand Up @@ -823,20 +827,19 @@ theorem shift_right_common_aux_32_4_fff (operand : BitVec 128)
-- Eliminating casting functions
ofInt_eq_signExtend
]
generalize extractLsb 31 0 operand = a; simp at a
generalize extractLsb 63 32 operand = b; simp at b
generalize extractLsb 95 64 operand = c; simp at c
generalize extractLsb 127 96 operand = d; simp at d
generalize extractLsb' 0 32 operand = a
generalize extractLsb' 32 32 operand = b
generalize extractLsb' 64 32 operand = c
generalize extractLsb' 96 32 operand = d
have h : ∀ (x : BitVec 32),
extractLsb 31 0 ((signExtend 33 x).sshiftRight shift)
extractLsb' 0 32 ((signExtend 33 x).sshiftRight shift)
= x.sshiftRight shift := by
intros x
apply eq_of_getLsbD_eq; intros i; simp at i
simp only [getLsbD_sshiftRight]
simp only [Nat.sub_zero, Nat.reduceAdd, getLsbD_extract, Nat.zero_add,
getLsbD_sshiftRight, getLsbD_signExtend]
simp only [show (i : Nat) ≤ 31 by omega,
decide_True, Bool.true_and]
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True,
Nat.zero_add, getLsbD_sshiftRight,
getLsbD_signExtend, Bool.true_and]
simp only [show ¬33 ≤ (i : Nat) by omega,
decide_False, Bool.not_false, Bool.true_and]
simp only [show ¬32 ≤ (i : Nat) by omega,
Expand Down Expand Up @@ -876,7 +879,7 @@ def shift_left_common_aux
if h : info.elements ≤ e then
result
else
let elem := elem_get operand e info.esize info.h
let elem := elem_get operand e info.esize
let shift_elem := elem <<< info.shift
let result := elem_set result e info.esize shift_elem info.h
have _ : info.elements - (e + 1) < info.elements - e := by omega
Expand All @@ -891,8 +894,8 @@ theorem shift_left_common_aux_64_2 (operand : BitVec 128)
unsigned := unsigned, round := round, accumulate := accumulate,
h := (by omega)}
operand result =
(extractLsb 127 64 operand <<< shift)
++ (extractLsb 63 0 operand <<< shift) := by
(extractLsb' 64 64 operand <<< shift)
++ (extractLsb' 0 64 operand <<< shift) := by
unfold shift_left_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_left_common_aux
Expand Down
6 changes: 3 additions & 3 deletions Arm/Insts/DPSFP/Advanced_simd_copy.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def exec_dup_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState :
let elements := datasize / esize
let operand := read_sfp idxdsize inst.Rn s
have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let element := elem_get operand index esize h₀
let element := elem_get operand index esize
let result := dup_aux 0 elements esize element (BitVec.zero datasize) h₀
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
Expand Down Expand Up @@ -81,7 +81,7 @@ def exec_ins_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState :
let operand := read_sfp idxdsize inst.Rn s
let result := read_sfp 128 inst.Rd s
have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let elem := elem_get operand src_index esize h₀
let elem := elem_get operand src_index esize
let result := elem_set result dst_index esize elem h₀
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
Expand Down Expand Up @@ -123,7 +123,7 @@ def exec_smov_umov (inst : Advanced_simd_copy_cls) (s : ArmState) (signed : Bool
-- if index == 0 then CheckFPEnabled64 else CheckFPAdvSIMDEnabled64
let operand := read_sfp idxdsize inst.Rn s
have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let element := elem_get operand index esize h₀
let element := elem_get operand index esize
let result := if signed then signExtend datasize element else zeroExtend datasize element
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
Expand Down
4 changes: 2 additions & 2 deletions Arm/Insts/DPSFP/Advanced_simd_permute.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def trn_aux (p : Nat) (pairs : Nat) (esize : Nat) (part : Nat)
result
else
let idx_from := 2 * p + part
let op1_part := elem_get operand1 idx_from esize h
let op2_part := elem_get operand2 idx_from esize h
let op1_part := elem_get operand1 idx_from esize
let op2_part := elem_get operand2 idx_from esize
let result := elem_set result (2 * p) esize op1_part h
let result := elem_set result (2 * p + 1) esize op2_part h
have h₁ : pairs - (p + 1) < pairs - p := by omega
Expand Down
3 changes: 1 addition & 2 deletions Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def exec_advanced_simd_scalar_copy
let idxdsize := 64 <<< (lsb inst.imm5 4).toNat
let esize := 8 <<< size
let operand := read_sfp idxdsize inst.Rn s
have h : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let result := elem_get operand index.toNat esize h
let result := elem_get operand index.toNat esize
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
let s := write_sfp esize inst.Rd result s
Expand Down
4 changes: 2 additions & 2 deletions Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def tblx_aux (i : Nat) (elements : Nat) (indices : BitVec datasize)
result
else
have h₁ : 8 > 0 := by decide
let index := (elem_get indices i 8 h₁).toNat
let index := (elem_get indices i 8).toNat
let result :=
if index < 16 * regs then
let val := elem_get table index 8 h₁
let val := elem_get table index 8
elem_set result i 8 val h₁
else
result
Expand Down
4 changes: 2 additions & 2 deletions Arm/Insts/DPSFP/Advanced_simd_three_different.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def pmull_op (e : Nat) (esize : Nat) (elements : Nat) (x : BitVec n)
if h₀ : elements <= e then
result
else
let element1 := elem_get x e esize H
let element2 := elem_get y e esize H
let element1 := elem_get x e esize
let element2 := elem_get y e esize
let elem_result := polynomial_mult element1 element2
have h₁ : esize + esize = 2 * esize := by omega
have h₂ : 2 * esize > 0 := by omega
Expand Down
4 changes: 2 additions & 2 deletions Arm/Insts/DPSFP/Advanced_simd_three_same.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def binary_vector_op_aux (e : Nat) (elems : Nat) (esize : Nat)
result
else
have h₁ : e < elems := by omega
let element1 := elem_get x e esize H
let element2 := elem_get y e esize H
let element1 := elem_get x e esize
let element2 := elem_get y e esize
let elem_result := op element1 element2
let result := elem_set result e esize elem_result H
have ht1 : elems - (e + 1) < elems - e := by omega
Expand Down
1 change: 1 addition & 0 deletions Proofs/AES-GCM/GCMInitV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,6 @@ theorem gcm_init_v8_program_correct (s0 sf : ArmState)
, shift_right_common_aux_32_4_fff
, DPSFP.AdvSIMDExpandImm
, DPSFP.dup_aux_0_4_32]
simp only [BitVec.extractLsb'_eq_extractLsb.symm]
generalize read_mem_bytes 16 (r (StateField.GPR 1#5) s0) s0 = Hinit
sorry

0 comments on commit 0d3f545

Please sign in to comment.