Skip to content

Commit

Permalink
Refactor to use extractLsb' instead of extractLsb (#191)
Browse files Browse the repository at this point in the history
## Description:

This PR refactors to use `extractLsb'` instead of `extractLsb` in the
whole repository. Doing so, we are able to simplify the proofs and
remove unnecessary hypothesis inputs to several functions.

### Testing:

What tests have been run? Did `make all` succeed for your changes? Was
conformance testing successful on an Aarch64 machine? Yes.

### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.
  • Loading branch information
pennyannn authored Oct 3, 2024
1 parent 16e3f4f commit 9f2c4f5
Show file tree
Hide file tree
Showing 38 changed files with 516 additions and 613 deletions.
87 changes: 48 additions & 39 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,7 @@ abbrev ror (x : BitVec n) (r : Nat) : BitVec n :=
the `n`-bit bitvector `x`. -/
@[bitvec_rules]
abbrev lsb (x : BitVec n) (i : Nat) : BitVec 1 :=
-- TODO: We could use
-- BitVec.extractLsb' i 1 x
-- and avoid the cast here, but unfortunately, extractLsb' isn't supported
-- by LeanSAT.
(BitVec.extractLsb i i x).cast (by omega)
BitVec.extractLsb' i 1 x

abbrev partInstall (hi lo : Nat) (val : BitVec (hi - lo + 1)) (x : BitVec n): BitVec n :=
let mask := allOnes (hi - lo + 1)
Expand Down Expand Up @@ -552,16 +548,18 @@ theorem extractLsb_eq (x : BitVec n) (h : n = n - 1 + 1) :
ext1
simp [←h]

theorem extractLsb'_eq (x : BitVec n) :
BitVec.extractLsb' 0 n x = x := by
unfold extractLsb'
simp only [Nat.shiftRight_zero, ofNat_toNat, zeroExtend_eq]

@[bitvec_rules]
protected theorem extract_lsb_of_zeroExtend (x : BitVec n) (h : j < i) :
extractLsb j 0 (zeroExtend i x) = zeroExtend (j + 1) x := by
protected theorem extractLsb'_of_zeroExtend (x : BitVec n) (h : j i) :
extractLsb' 0 j (zeroExtend i x) = zeroExtend j x := by
apply BitVec.eq_of_getLsbD_eq
simp
intro k
have q : k < i := by omega
by_cases h : decide (k ≤ j) <;> simp [q, h]
simp_all
omega

@[bitvec_rules, simp]
theorem zero_append {w} (x : BitVec 0) (y : BitVec w) :
Expand Down Expand Up @@ -636,12 +634,11 @@ theorem append_of_extract_general_nat (high low n vn : Nat) (h : vn < 2 ^ n) :
done

theorem append_of_extract (n : Nat) (v : BitVec n)
(high0 : high = n - low) (low0 : 1 <= low)
(h : high + (low - 1 - 0 + 1) = n) :
BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb (low - 1) 0 v) = v := by
(high0 : high = n - low) (h : high + low = n) :
BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb' 0 low v) = v := by
ext
subst high
have vlt := v.isLt; simp_all only [Nat.sub_zero]
have vlt := v.isLt
have := append_of_extract_general_nat (n - low) low n (BitVec.toNat v) vlt
have low_le : low ≤ n := by omega
simp_all [toNat_zeroExtend, Nat.sub_add_cancel, low_le]
Expand All @@ -651,17 +648,14 @@ theorem append_of_extract (n : Nat) (v : BitVec n)
exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) vlt
done

theorem append_of_extract_general (v : BitVec n)
(low0 : 1 <= low)
(h1 : high = width)
(h2 : (high + low - 1 - 0 + 1) = (width + (low - 1 - 0 + 1))) :
BitVec.cast h1 (zeroExtend high (v >>> low)) ++ extractLsb (low - 1) 0 v =
BitVec.cast h2 (extractLsb (high + low - 1) 0 v) := by
@[bitvec_rules]
theorem append_of_extract_general (v : BitVec n) :
(zeroExtend high (v >>> low)) ++ extractLsb' 0 low v =
extractLsb' 0 (high + low) v := by
ext
have := append_of_extract_general_nat high low n (BitVec.toNat v)
have h_vlt := v.isLt; simp_all only [Nat.sub_zero, h1]
simp only [h_vlt, h1, forall_prop_of_true] at this
have low' : 1 ≤ width + low := Nat.le_trans low0 (Nat.le_add_left low width)
have h_vlt := v.isLt; simp_all only [Nat.sub_zero]
simp only [h_vlt, forall_prop_of_true] at this
simp_all [toNat_zeroExtend, Nat.sub_add_cancel]
rw [Nat.mod_eq_of_lt (b := 2 ^ n)] at this
· rw [this]
Expand Down Expand Up @@ -790,7 +784,7 @@ def genBVPatMatchTest (var : Term) (pat : BVPat) : MacroM Term := do
for c in pat.getComponents do
let len := c.length
if let some bv ← c.toBVLit? then
let test ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var == $bv)
let test ← `(extractLsb' $(quote shift) $(quote len) $var == $bv)
result ← `($result && $test)
shift := shift + len
return result
Expand All @@ -812,7 +806,7 @@ def declBVPatVars (var : Term) (pat : BVPat) (rhs : Term) : MacroM Term := do
for c in pat.getComponents do
let len := c.length
if let some y ← c.toBVVar? then
let rhs ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var)
let rhs ← `(extractLsb' $(quote shift) $(quote len) $var)
result ← `(let $y := $rhs; $result)
shift := shift + len
return result
Expand Down Expand Up @@ -936,10 +930,10 @@ Definition to extract the `n`th least significant *Byte* from a bitvector.
TODO: this should be named `getLsByte`, or `getLsbByte` (Shilpi prefers this).
-/
def extractLsByte (val : BitVec w₁) (n : Nat) : BitVec 8 :=
val.extractLsb ((n + 1) * 8 - 1) (n * 8) |> .cast (by omega)
val.extractLsb' (n * 8) 8

theorem extractLsByte_def (val : BitVec w₁) (n : Nat) :
val.extractLsByte n = (val.extractLsb ((n + 1)*8 - 1) (n * 8) |>.cast (by omega)) := rfl
val.extractLsByte n = val.extractLsb' (n * 8) 8 := rfl

-- TODO: upstream
theorem extractLsb_or (x y : BitVec w₁) (n : Nat) :
Expand All @@ -951,16 +945,31 @@ theorem extractLsb_or (x y : BitVec w₁) (n : Nat) :
· simp only [h, decide_True, Bool.true_and]
· simp only [h, decide_False, Bool.false_and, Bool.or_self]

-- TODO: upstream
theorem extractLsb'_or (x y : BitVec w₁) (n : Nat) :
(x ||| y).extractLsb' lo n = (x.extractLsb' lo n ||| y.extractLsb' lo n) := by
apply BitVec.eq_of_getLsbD_eq
simp only [getLsbD_extract, getLsbD_or]
intros i
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, getLsbD_or, Bool.true_and]

-- TODO: upstream
protected theorem extractLsb'_ofNat (x n : Nat) (l lo : Nat) :
extractLsb' lo l (BitVec.ofNat n x) = .ofNat l ((x % 2^n) >>> lo) := by
apply eq_of_getLsbD_eq
intro ⟨i, _lt⟩
simp [BitVec.ofNat]

theorem extractLsByte_zero {w : Nat} : (0#w).extractLsByte i = 0#8 := by
simp only [extractLsByte, BitVec.extractLsb_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat]
simp only [extractLsByte, BitVec.extractLsb'_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat]

theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) :
x.extractLsByte a = 0#8 := by
apply BitVec.eq_of_getLsbD_eq
intros i
simp only [getLsbD_zero, extractLsByte_def,
getLsbD_cast, getLsbD_extract, Bool.and_eq_false_imp, decide_eq_true_eq]
intros _
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, Bool.true_and]
apply BitVec.getLsbD_ge
omega

Expand All @@ -969,10 +978,13 @@ theorem getLsbD_extractLsByte (val : BitVec w₁) :
((BitVec.extractLsByte val n).getLsbD i) =
(decide (i ≤ 7) && val.getLsbD (n * 8 + i)) := by
simp only [extractLsByte, getLsbD_cast, getLsbD_extract]
rw [Nat.succ_mul]
simp only [Nat.add_one_sub_one,
Nat.add_sub_cancel_left]

simp only [getLsbD_extractLsb']
generalize val.getLsbD (n * 8 + i) = x
by_cases h : i < 8
· simp only [show (i : Nat) ≤ 7 by omega, decide_True, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq, h]
· simp only [show ¬(i : Nat) ≤ 7 by omega, decide_False, Bool.false_and,
Bool.and_eq_false_imp, decide_eq_true_eq, h]

/--
Two bitvectors of length `n*8` are equal if all their bytes are equal.
Expand All @@ -996,9 +1008,7 @@ theorem eq_of_extractLsByte_eq (x y : BitVec (n * 8))
@bollu: it's not clear if the definition for n=0 is desirable.
-/
def extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) : BitVec (n * 8) :=
match h : n with
| 0 => 0#0
| x + 1 => val.extractLsb (base * 8 + n * 8 - 1) (base * 8) |>.cast (by omega)
extractLsb' (base * 8) (n * 8) val

@[bitvec_rules]
theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) :
Expand All @@ -1011,10 +1021,9 @@ theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat)
simp only [show ¬i < 0 by omega, decide_False, Bool.false_and]
· simp only [extractLsBytes, getLsbD_cast, getLsbD_extract, Nat.zero_lt_succ, decide_True,
Bool.true_and]
simp only [show base * 8 + (n + 1) * 8 - 1 - base * 8 = (n + 1) * 8 - 1 by omega]
by_cases h : i < (n + 1) * 8
· simp only [show i ≤ (n + 1) * 8 - 1 by omega, decide_True, Bool.true_and, h]
· simp only [show ¬(i ≤ (n + 1) * 8 - 1) by omega, decide_False, Bool.false_and, h]
· simp only [getLsbD_extractLsb', h, decide_True, Bool.true_and]
· simp only [getLsbD_extractLsb', h, decide_False, Bool.false_and]

theorem extractLsByte_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) :
(BitVec.extractLsBytes val base n).extractLsByte i =
Expand Down
2 changes: 1 addition & 1 deletion Arm/Cosim.lean
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def sfp_list (s : ArmState) : List (BitVec 64) := Id.run do
let mut acc := []
for i in [0:32] do
let reg := read_sfp 128 (BitVec.ofNat 5 i) s
acc := acc ++ [(extractLsb 63 0 reg), (extractLsb 127 64 reg)]
acc := acc ++ [(extractLsb' 0 64 reg), (extractLsb' 64 64 reg)]
pure acc

/-- Get the flags in an ArmState as a 4-bit bitvector.-/
Expand Down
2 changes: 1 addition & 1 deletion Arm/Insts/BR/Cond_branch_imm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def Cond_branch_imm_inst.condition_holds
let N := read_flag PFlag.N s
let V := read_flag PFlag.V s
let result :=
match (extractLsb 3 1 inst.cond) with
match (extractLsb' 1 3 inst.cond) with
| 0b000#3 => Z = 1#1
| 0b001#3 => C = 1#1
| 0b010#3 => N = 1#1
Expand Down
52 changes: 22 additions & 30 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def ConditionHolds (cond : BitVec 4) (s : ArmState) : Bool :=
let C := read_flag C s
let V := read_flag V s
let result :=
match (extractLsb 3 1 cond) with
match (extractLsb' 1 3 cond) with
| 0b000#3 => Z = 1#1 -- EQ or NE
| 0b001#3 => C = 1#1 -- CS or CC
| 0b010#3 => N = 1#1 -- MI or PL
Expand All @@ -129,7 +129,7 @@ def ConditionHolds (cond : BitVec 4) (s : ArmState) : Bool :=
theorem sgt_iff_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) :
(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧
(AddWithCarry x (~~~y) 1#1).snd.z = 0#1) ↔ BitVec.slt y x := by
simp [AddWithCarry, make_pstate]
simp [AddWithCarry, make_pstate, lsb]
split
· bv_decide
· bv_decide
Expand All @@ -138,7 +138,7 @@ theorem sgt_iff_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) :
theorem sgt_iff_n_eq_v_and_z_eq_0_32 (x y : BitVec 32) :
(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧
(AddWithCarry x (~~~y) 1#1).snd.z = 0#1) ↔ BitVec.slt y x := by
simp [AddWithCarry, make_pstate]
simp [AddWithCarry, make_pstate, lsb]
split
· bv_decide
· bv_decide
Expand All @@ -147,7 +147,7 @@ theorem sgt_iff_n_eq_v_and_z_eq_0_32 (x y : BitVec 32) :
theorem sle_iff_not_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) :
(¬(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧
(AddWithCarry x (~~~y) 1#1).snd.z = 0#1)) ↔ BitVec.sle x y := by
simp [AddWithCarry, make_pstate]
simp [AddWithCarry, make_pstate, lsb]
split
· bv_decide
· bv_decide
Expand All @@ -156,7 +156,7 @@ theorem sle_iff_not_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) :
theorem sle_iff_not_n_eq_v_and_z_eq_0_32 (x y : BitVec 32) :
(¬(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧
(AddWithCarry x (~~~y) 1#1).snd.z = 0#1)) ↔ BitVec.sle x y := by
simp [AddWithCarry, make_pstate]
simp [AddWithCarry, make_pstate, lsb]
split
· bv_decide
· bv_decide
Expand All @@ -174,12 +174,10 @@ theorem zero_iff_z_eq_one (x : BitVec 64) :
· bv_decide
done


/-- `Aligned x a` witnesses that the bitvector `x` is `a`-bit aligned. -/
def Aligned (x : BitVec n) (a : Nat) : Prop :=
-- (TODO @alex) Switch to using extractLsb' to unify the two cases.
match a with
| 0 => True
| a' + 1 => extractLsb a' 0 x = BitVec.zero _
extractLsb' 0 a x = BitVec.zero _

/-- We need to prove why the Aligned predicate is Decidable. -/
instance : Decidable (Aligned x a) := by
Expand All @@ -201,7 +199,7 @@ theorem Aligned_BitVecAdd_64_4 {x : BitVec 64} {y : BitVec 64}

theorem Aligned_AddWithCarry_64_4 (x : BitVec 64) (y : BitVec 64) (carry_in : BitVec 1)
(x_aligned : Aligned x 4)
(y_carry_in_aligned : Aligned (BitVec.add (extractLsb 3 0 y) (zeroExtend 4 carry_in)) 4)
(y_carry_in_aligned : Aligned (BitVec.add (extractLsb' 0 4 y) (zeroExtend 4 carry_in)) 4)
: Aligned (AddWithCarry x y carry_in).fst 4 := by
unfold AddWithCarry Aligned at *
simp_all only [Nat.sub_zero, zero_eq, add_eq]
Expand Down Expand Up @@ -238,7 +236,7 @@ theorem CheckSPAligment_of_write_mem_bytes :
@[state_simp_rules]
theorem CheckSPAlignment_AddWithCarry_64_4 (st : ArmState) (y : BitVec 64) (carry_in : BitVec 1)
(x_aligned : CheckSPAlignment st)
(y_carry_in_aligned : Aligned (BitVec.add (extractLsb 3 0 y) (zeroExtend 4 carry_in)) 4)
(y_carry_in_aligned : Aligned (BitVec.add (extractLsb' 0 4 y) (zeroExtend 4 carry_in)) 4)
: Aligned (AddWithCarry (r (StateField.GPR 31#5) st) y carry_in).fst 4 := by
simp_all only [CheckSPAlignment, read_gpr, zeroExtend_eq, Nat.sub_zero, add_eq,
Aligned_AddWithCarry_64_4]
Expand Down Expand Up @@ -438,7 +436,7 @@ def decode_bit_masks (immN : BitVec 1) (imms : BitVec 6) (immr : BitVec 6)
let r := immr &&& levels
let diff := s - r
let esize := 1 <<< len
let d := extractLsb (len - 1) 0 diff
let d := extractLsb' 0 len diff
let welem := zeroExtend esize (allOnes (s.toNat + 1))
let telem := zeroExtend esize (allOnes (d.toNat + 1))
let wmask := replicate (M/esize) $ rotateRight welem r.toNat
Expand Down Expand Up @@ -498,18 +496,16 @@ instance : ToString SIMDThreeSameLogicalType where toString a := toString (repr
----------------------------------------------------------------------

@[state_simp_rules]
def Vpart_read (n : BitVec 5) (part width : Nat) (s : ArmState) (H : width > 0)
def Vpart_read (n : BitVec 5) (part width : Nat) (s : ArmState)
: BitVec width :=
-- assert n >= 0 && n <= 31;
-- assert part IN {0, 1};
have h1: width - 1 + 1 = width := by omega
have h2: (width * 2 - 1 - width + 1) = width := by omega
if part = 0 then
-- assert width < 128;
BitVec.cast h1 $ extractLsb (width-1) 0 $ read_sfp 128 n s
extractLsb' 0 width $ read_sfp 128 n s
else
-- assert width IN {32,64};
BitVec.cast h2 $ extractLsb (width*2-1) width $ read_sfp 128 n s
extractLsb' width width $ read_sfp 128 n s


@[state_simp_rules]
Expand All @@ -522,7 +518,7 @@ def Vpart_write (n : BitVec 5) (part width : Nat) (val : BitVec width) (s : ArmS
write_sfp width n val s
else
-- assert width == 64
let res := (extractLsb 63 0 val) ++ (read_sfp 64 n s)
let res := (extractLsb' 0 64 val) ++ (read_sfp 64 n s)
write_sfp 128 n res s

----------------------------------------------------------------------
Expand Down Expand Up @@ -628,13 +624,10 @@ example : rev_vector 32 16 8 0xaabbccdd#32 (by decide)
/-- Divide bv `vector` into elements, each of size `size`. This function gets
the `e`'th element from the `vector`. -/
@[state_simp_rules]
def elem_get (vector : BitVec n) (e : Nat) (size : Nat)
(h: size > 0): BitVec size :=
def elem_get (vector : BitVec n) (e : Nat) (size : Nat) : BitVec size :=
-- assert (e+1)*size <= n
let lo := e * size
let hi := lo + size - 1
have h : hi - lo + 1 = size := by simp only [hi, lo]; omega
BitVec.cast h $ extractLsb hi lo vector
extractLsb' lo size vector

/-- Divide bv `vector` into elements, each of size `size`. This function sets
the `e`'th element in the `vector`. -/
Expand Down Expand Up @@ -663,7 +656,7 @@ deriving DecidableEq, Repr
export ShiftInfo (esize elements shift unsigned round accumulate)

@[state_simp_rules]
def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) (h : n > 0)
def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool)
: BitVec n :=
-- assert shift > 0
let fn := if unsigned then ushiftRight else sshiftRight
Expand All @@ -673,8 +666,7 @@ def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) (h : n > 0
BitVec.ofInt (n + 1) rounded
else
BitVec.ofInt (n + 1) value
have h₀ : n - 1 - 0 + 1 = n := by omega
BitVec.cast h₀ $ extractLsb (n-1) 0 (fn rounded_bv shift)
extractLsb' 0 n (fn rounded_bv shift)

@[state_simp_rules]
def Int_with_unsigned (unsigned : Bool) (value : BitVec n) : Int :=
Expand All @@ -686,9 +678,9 @@ def shift_right_common_aux
if h : info.elements ≤ e then
result
else
let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize info.h
let shift_elem := RShr info.unsigned elem info.shift info.round info.h
let acc_elem := elem_get operand2 e info.esize info.h + shift_elem
let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize
let shift_elem := RShr info.unsigned elem info.shift info.round
let acc_elem := elem_get operand2 e info.esize + shift_elem
let result := elem_set result e info.esize acc_elem info.h
have _ : info.elements - (e + 1) < info.elements - e := by omega
shift_right_common_aux (e + 1) info operand operand2 result
Expand All @@ -709,7 +701,7 @@ def shift_left_common_aux
if h : info.elements ≤ e then
result
else
let elem := elem_get operand e info.esize info.h
let elem := elem_get operand e info.esize
let shift_elem := elem <<< info.shift
let result := elem_set result e info.esize shift_elem info.h
have _ : info.elements - (e + 1) < info.elements - e := by omega
Expand Down
2 changes: 1 addition & 1 deletion Arm/Insts/DPI/Logical_imm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def MoveWidePreferred (sf immN : BitVec 1) (imms immr : BitVec 6) : Bool :=
false
-- NOTE: the second conjunct below is semantically equivalent to the ASL code
-- !((immN:imms) IN {'00xxxxx'})
else if sf = 0#1 ∧ ¬(immN = 0#1 ∧ imms.extractLsb 5 5 = 0#1) then
else if sf = 0#1 ∧ ¬(immN = 0#1 ∧ imms.extractLsb' 5 1 = 0#1) then
false

-- for MOVZ must contain no more than 16 ones
Expand Down
Loading

0 comments on commit 9f2c4f5

Please sign in to comment.