Skip to content

Commit

Permalink
Merge branch 'axeffects-tracing' into replace-init_next_step
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkeizer committed Oct 4, 2024
2 parents 5bf77e7 + 366d482 commit 92c26e4
Show file tree
Hide file tree
Showing 75 changed files with 2,226 additions and 1,163 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
/lake-packages/*
/.lake/*
*.log
/data/*
111 changes: 63 additions & 48 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ attribute [bitvec_rules] BitVec.msb_zero
attribute [bitvec_rules] BitVec.toNat_cast
attribute [bitvec_rules] BitVec.getLsbD_cast
attribute [bitvec_rules] BitVec.getMsbD_cast
attribute [bitvec_rules] BitVec.toNat_ofInt
-- attribute [bitvec_rules] BitVec.toNat_ofInt -- TODO: not tagged bv_toNat.
attribute [bv_toNat] BitVec.toNat_ofInt
attribute [bitvec_rules] BitVec.toInt_ofInt
attribute [bitvec_rules] BitVec.ofInt_natCast
attribute [bitvec_rules] BitVec.toNat_zeroExtend'
attribute [bitvec_rules] BitVec.toNat_zeroExtend
attribute [bitvec_rules] BitVec.toNat_truncate
-- attribute [bitvec_rules] BitVec.toNat_zeroExtend
-- attribute [bitvec_rules] BitVec.toNat_truncate
attribute [bitvec_rules] BitVec.zeroExtend_zero
attribute [bitvec_rules] BitVec.ofNat_toNat
attribute [bitvec_rules] BitVec.getLsbD_zeroExtend'
Expand Down Expand Up @@ -82,15 +83,15 @@ attribute [bitvec_rules] BitVec.toNat_xor
attribute [bitvec_rules] BitVec.toFin_xor
attribute [bitvec_rules] BitVec.getLsbD_xor
attribute [bitvec_rules] BitVec.truncate_xor
attribute [bitvec_rules] BitVec.toNat_not
-- attribute [bitvec_rules] BitVec.toNat_not
attribute [bitvec_rules] BitVec.toFin_not
attribute [bitvec_rules] BitVec.getLsbD_not
attribute [bitvec_rules] BitVec.truncate_not
attribute [bitvec_rules] BitVec.not_cast
attribute [bitvec_rules] BitVec.and_cast
attribute [bitvec_rules] BitVec.or_cast
attribute [bitvec_rules] BitVec.xor_cast
attribute [bitvec_rules] BitVec.toNat_shiftLeft
-- attribute [bitvec_rules] BitVec.toNat_shiftLeft
attribute [bitvec_rules] BitVec.toFin_shiftLeft
attribute [bitvec_rules] BitVec.getLsbD_shiftLeft
attribute [bitvec_rules] BitVec.getMsbD_shiftLeft
Expand Down Expand Up @@ -124,23 +125,24 @@ attribute [bitvec_rules] BitVec.not_concat
attribute [bitvec_rules] BitVec.concat_or_concat
attribute [bitvec_rules] BitVec.concat_and_concat
attribute [bitvec_rules] BitVec.concat_xor_concat
attribute [bitvec_rules] BitVec.toNat_add
-- attribute [bitvec_rules] BitVec.toNat_add
attribute [bitvec_rules] BitVec.toFin_add
attribute [bitvec_rules] BitVec.ofFin_add
attribute [bitvec_rules] BitVec.add_ofFin
attribute [bitvec_rules] BitVec.add_zero
attribute [bitvec_rules] BitVec.zero_add
attribute [bitvec_rules] BitVec.toInt_add
attribute [bitvec_rules] BitVec.toNat_sub
-- attribute [bitvec_rules] BitVec.toNat_sub
attribute [bv_toNat] toNat_sub
attribute [bitvec_rules] BitVec.toFin_sub
attribute [bitvec_rules] BitVec.ofFin_sub
attribute [bitvec_rules] BitVec.sub_ofFin
attribute [bitvec_rules] BitVec.sub_zero
attribute [bitvec_rules] BitVec.sub_self
attribute [bitvec_rules] BitVec.toNat_neg
-- attribute [bitvec_rules] BitVec.toNat_neg
attribute [bitvec_rules] BitVec.toFin_neg
attribute [bitvec_rules] BitVec.neg_zero
attribute [bitvec_rules] BitVec.toNat_mul
-- attribute [bitvec_rules] BitVec.toNat_mul
attribute [bitvec_rules] BitVec.toFin_mul
attribute [bitvec_rules] BitVec.mul_zero
attribute [bitvec_rules] BitVec.mul_one
Expand Down Expand Up @@ -246,6 +248,10 @@ attribute [bitvec_rules] Nat.reduceLeDiff
attribute [bitvec_rules] Nat.reduceSubDiff
attribute [bitvec_rules] BitVec.toNat_ofNat

-- This might be a neccesary evil: it introduces a modulus,
-- but it's also really useful.
attribute [bitvec_rules] BitVec.toNat_ofNat

-- Some Fin lemmas useful for bitvector reasoning:
attribute [bitvec_rules] Fin.eta
attribute [bitvec_rules] Fin.isLt
Expand Down Expand Up @@ -285,11 +291,7 @@ abbrev ror (x : BitVec n) (r : Nat) : BitVec n :=
the `n`-bit bitvector `x`. -/
@[bitvec_rules]
abbrev lsb (x : BitVec n) (i : Nat) : BitVec 1 :=
-- TODO: We could use
-- BitVec.extractLsb' i 1 x
-- and avoid the cast here, but unfortunately, extractLsb' isn't supported
-- by LeanSAT.
(BitVec.extractLsb i i x).cast (by omega)
BitVec.extractLsb' i 1 x

abbrev partInstall (hi lo : Nat) (val : BitVec (hi - lo + 1)) (x : BitVec n): BitVec n :=
let mask := allOnes (hi - lo + 1)
Expand Down Expand Up @@ -552,16 +554,18 @@ theorem extractLsb_eq (x : BitVec n) (h : n = n - 1 + 1) :
ext1
simp [←h]

theorem extractLsb'_eq (x : BitVec n) :
BitVec.extractLsb' 0 n x = x := by
unfold extractLsb'
simp only [Nat.shiftRight_zero, ofNat_toNat, zeroExtend_eq]

@[bitvec_rules]
protected theorem extract_lsb_of_zeroExtend (x : BitVec n) (h : j < i) :
extractLsb j 0 (zeroExtend i x) = zeroExtend (j + 1) x := by
protected theorem extractLsb'_of_zeroExtend (x : BitVec n) (h : j i) :
extractLsb' 0 j (zeroExtend i x) = zeroExtend j x := by
apply BitVec.eq_of_getLsbD_eq
simp
intro k
have q : k < i := by omega
by_cases h : decide (k ≤ j) <;> simp [q, h]
simp_all
omega

@[bitvec_rules, simp]
theorem zero_append {w} (x : BitVec 0) (y : BitVec w) :
Expand Down Expand Up @@ -636,12 +640,11 @@ theorem append_of_extract_general_nat (high low n vn : Nat) (h : vn < 2 ^ n) :
done

theorem append_of_extract (n : Nat) (v : BitVec n)
(high0 : high = n - low) (low0 : 1 <= low)
(h : high + (low - 1 - 0 + 1) = n) :
BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb (low - 1) 0 v) = v := by
(high0 : high = n - low) (h : high + low = n) :
BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb' 0 low v) = v := by
ext
subst high
have vlt := v.isLt; simp_all only [Nat.sub_zero]
have vlt := v.isLt
have := append_of_extract_general_nat (n - low) low n (BitVec.toNat v) vlt
have low_le : low ≤ n := by omega
simp_all [toNat_zeroExtend, Nat.sub_add_cancel, low_le]
Expand All @@ -651,17 +654,14 @@ theorem append_of_extract (n : Nat) (v : BitVec n)
exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) vlt
done

theorem append_of_extract_general (v : BitVec n)
(low0 : 1 <= low)
(h1 : high = width)
(h2 : (high + low - 1 - 0 + 1) = (width + (low - 1 - 0 + 1))) :
BitVec.cast h1 (zeroExtend high (v >>> low)) ++ extractLsb (low - 1) 0 v =
BitVec.cast h2 (extractLsb (high + low - 1) 0 v) := by
@[bitvec_rules]
theorem append_of_extract_general (v : BitVec n) :
(zeroExtend high (v >>> low)) ++ extractLsb' 0 low v =
extractLsb' 0 (high + low) v := by
ext
have := append_of_extract_general_nat high low n (BitVec.toNat v)
have h_vlt := v.isLt; simp_all only [Nat.sub_zero, h1]
simp only [h_vlt, h1, forall_prop_of_true] at this
have low' : 1 ≤ width + low := Nat.le_trans low0 (Nat.le_add_left low width)
have h_vlt := v.isLt; simp_all only [Nat.sub_zero]
simp only [h_vlt, forall_prop_of_true] at this
simp_all [toNat_zeroExtend, Nat.sub_add_cancel]
rw [Nat.mod_eq_of_lt (b := 2 ^ n)] at this
· rw [this]
Expand Down Expand Up @@ -790,7 +790,7 @@ def genBVPatMatchTest (var : Term) (pat : BVPat) : MacroM Term := do
for c in pat.getComponents do
let len := c.length
if let some bv ← c.toBVLit? then
let test ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var == $bv)
let test ← `(extractLsb' $(quote shift) $(quote len) $var == $bv)
result ← `($result && $test)
shift := shift + len
return result
Expand All @@ -812,7 +812,7 @@ def declBVPatVars (var : Term) (pat : BVPat) (rhs : Term) : MacroM Term := do
for c in pat.getComponents do
let len := c.length
if let some y ← c.toBVVar? then
let rhs ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var)
let rhs ← `(extractLsb' $(quote shift) $(quote len) $var)
result ← `(let $y := $rhs; $result)
shift := shift + len
return result
Expand Down Expand Up @@ -936,10 +936,10 @@ Definition to extract the `n`th least significant *Byte* from a bitvector.
TODO: this should be named `getLsByte`, or `getLsbByte` (Shilpi prefers this).
-/
def extractLsByte (val : BitVec w₁) (n : Nat) : BitVec 8 :=
val.extractLsb ((n + 1) * 8 - 1) (n * 8) |> .cast (by omega)
val.extractLsb' (n * 8) 8

theorem extractLsByte_def (val : BitVec w₁) (n : Nat) :
val.extractLsByte n = (val.extractLsb ((n + 1)*8 - 1) (n * 8) |>.cast (by omega)) := rfl
val.extractLsByte n = val.extractLsb' (n * 8) 8 := rfl

-- TODO: upstream
theorem extractLsb_or (x y : BitVec w₁) (n : Nat) :
Expand All @@ -951,16 +951,31 @@ theorem extractLsb_or (x y : BitVec w₁) (n : Nat) :
· simp only [h, decide_True, Bool.true_and]
· simp only [h, decide_False, Bool.false_and, Bool.or_self]

-- TODO: upstream
theorem extractLsb'_or (x y : BitVec w₁) (n : Nat) :
(x ||| y).extractLsb' lo n = (x.extractLsb' lo n ||| y.extractLsb' lo n) := by
apply BitVec.eq_of_getLsbD_eq
simp only [getLsbD_extract, getLsbD_or]
intros i
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, getLsbD_or, Bool.true_and]

-- TODO: upstream
protected theorem extractLsb'_ofNat (x n : Nat) (l lo : Nat) :
extractLsb' lo l (BitVec.ofNat n x) = .ofNat l ((x % 2^n) >>> lo) := by
apply eq_of_getLsbD_eq
intro ⟨i, _lt⟩
simp [BitVec.ofNat]

theorem extractLsByte_zero {w : Nat} : (0#w).extractLsByte i = 0#8 := by
simp only [extractLsByte, BitVec.extractLsb_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat]
simp only [extractLsByte, BitVec.extractLsb'_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat]

theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) :
x.extractLsByte a = 0#8 := by
apply BitVec.eq_of_getLsbD_eq
intros i
simp only [getLsbD_zero, extractLsByte_def,
getLsbD_cast, getLsbD_extract, Bool.and_eq_false_imp, decide_eq_true_eq]
intros _
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, Bool.true_and]
apply BitVec.getLsbD_ge
omega

Expand All @@ -969,10 +984,13 @@ theorem getLsbD_extractLsByte (val : BitVec w₁) :
((BitVec.extractLsByte val n).getLsbD i) =
(decide (i ≤ 7) && val.getLsbD (n * 8 + i)) := by
simp only [extractLsByte, getLsbD_cast, getLsbD_extract]
rw [Nat.succ_mul]
simp only [Nat.add_one_sub_one,
Nat.add_sub_cancel_left]

simp only [getLsbD_extractLsb']
generalize val.getLsbD (n * 8 + i) = x
by_cases h : i < 8
· simp only [show (i : Nat) ≤ 7 by omega, decide_True, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq, h]
· simp only [show ¬(i : Nat) ≤ 7 by omega, decide_False, Bool.false_and,
Bool.and_eq_false_imp, decide_eq_true_eq, h]

/--
Two bitvectors of length `n*8` are equal if all their bytes are equal.
Expand All @@ -996,9 +1014,7 @@ theorem eq_of_extractLsByte_eq (x y : BitVec (n * 8))
@bollu: it's not clear if the definition for n=0 is desirable.
-/
def extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) : BitVec (n * 8) :=
match h : n with
| 0 => 0#0
| x + 1 => val.extractLsb (base * 8 + n * 8 - 1) (base * 8) |>.cast (by omega)
extractLsb' (base * 8) (n * 8) val

@[bitvec_rules]
theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) :
Expand All @@ -1011,10 +1027,9 @@ theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat)
simp only [show ¬i < 0 by omega, decide_False, Bool.false_and]
· simp only [extractLsBytes, getLsbD_cast, getLsbD_extract, Nat.zero_lt_succ, decide_True,
Bool.true_and]
simp only [show base * 8 + (n + 1) * 8 - 1 - base * 8 = (n + 1) * 8 - 1 by omega]
by_cases h : i < (n + 1) * 8
· simp only [show i ≤ (n + 1) * 8 - 1 by omega, decide_True, Bool.true_and, h]
· simp only [show ¬(i ≤ (n + 1) * 8 - 1) by omega, decide_False, Bool.false_and, h]
· simp only [getLsbD_extractLsb', h, decide_True, Bool.true_and]
· simp only [getLsbD_extractLsb', h, decide_False, Bool.false_and]

theorem extractLsByte_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) :
(BitVec.extractLsBytes val base n).extractLsByte i =
Expand Down
2 changes: 1 addition & 1 deletion Arm/Cosim.lean
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def sfp_list (s : ArmState) : List (BitVec 64) := Id.run do
let mut acc := []
for i in [0:32] do
let reg := read_sfp 128 (BitVec.ofNat 5 i) s
acc := acc ++ [(extractLsb 63 0 reg), (extractLsb 127 64 reg)]
acc := acc ++ [(extractLsb' 0 64 reg), (extractLsb' 64 64 reg)]
pure acc

/-- Get the flags in an ArmState as a 4-bit bitvector.-/
Expand Down
2 changes: 1 addition & 1 deletion Arm/Insts/BR/Cond_branch_imm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def Cond_branch_imm_inst.condition_holds
let N := read_flag PFlag.N s
let V := read_flag PFlag.V s
let result :=
match (extractLsb 3 1 inst.cond) with
match (extractLsb' 1 3 inst.cond) with
| 0b000#3 => Z = 1#1
| 0b001#3 => C = 1#1
| 0b010#3 => N = 1#1
Expand Down
Loading

0 comments on commit 92c26e4

Please sign in to comment.