diff --git a/.gitignore b/.gitignore index 6c2506f9..1e3e3f4d 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ /lake-packages/* /.lake/* *.log +/data/* diff --git a/Arm/BitVec.lean b/Arm/BitVec.lean index a93da2de..24dad21f 100644 --- a/Arm/BitVec.lean +++ b/Arm/BitVec.lean @@ -43,12 +43,13 @@ attribute [bitvec_rules] BitVec.msb_zero attribute [bitvec_rules] BitVec.toNat_cast attribute [bitvec_rules] BitVec.getLsbD_cast attribute [bitvec_rules] BitVec.getMsbD_cast -attribute [bitvec_rules] BitVec.toNat_ofInt +-- attribute [bitvec_rules] BitVec.toNat_ofInt -- TODO: not tagged bv_toNat. +attribute [bv_toNat] BitVec.toNat_ofInt attribute [bitvec_rules] BitVec.toInt_ofInt attribute [bitvec_rules] BitVec.ofInt_natCast attribute [bitvec_rules] BitVec.toNat_zeroExtend' -attribute [bitvec_rules] BitVec.toNat_zeroExtend -attribute [bitvec_rules] BitVec.toNat_truncate +-- attribute [bitvec_rules] BitVec.toNat_zeroExtend +-- attribute [bitvec_rules] BitVec.toNat_truncate attribute [bitvec_rules] BitVec.zeroExtend_zero attribute [bitvec_rules] BitVec.ofNat_toNat attribute [bitvec_rules] BitVec.getLsbD_zeroExtend' @@ -82,7 +83,7 @@ attribute [bitvec_rules] BitVec.toNat_xor attribute [bitvec_rules] BitVec.toFin_xor attribute [bitvec_rules] BitVec.getLsbD_xor attribute [bitvec_rules] BitVec.truncate_xor -attribute [bitvec_rules] BitVec.toNat_not +-- attribute [bitvec_rules] BitVec.toNat_not attribute [bitvec_rules] BitVec.toFin_not attribute [bitvec_rules] BitVec.getLsbD_not attribute [bitvec_rules] BitVec.truncate_not @@ -90,7 +91,7 @@ attribute [bitvec_rules] BitVec.not_cast attribute [bitvec_rules] BitVec.and_cast attribute [bitvec_rules] BitVec.or_cast attribute [bitvec_rules] BitVec.xor_cast -attribute [bitvec_rules] BitVec.toNat_shiftLeft +-- attribute [bitvec_rules] BitVec.toNat_shiftLeft attribute [bitvec_rules] BitVec.toFin_shiftLeft attribute [bitvec_rules] BitVec.getLsbD_shiftLeft attribute [bitvec_rules] BitVec.getMsbD_shiftLeft @@ -124,23 +125,24 @@ attribute [bitvec_rules] BitVec.not_concat attribute [bitvec_rules] BitVec.concat_or_concat attribute [bitvec_rules] BitVec.concat_and_concat attribute [bitvec_rules] BitVec.concat_xor_concat -attribute [bitvec_rules] BitVec.toNat_add +-- attribute [bitvec_rules] BitVec.toNat_add attribute [bitvec_rules] BitVec.toFin_add attribute [bitvec_rules] BitVec.ofFin_add attribute [bitvec_rules] BitVec.add_ofFin attribute [bitvec_rules] BitVec.add_zero attribute [bitvec_rules] BitVec.zero_add attribute [bitvec_rules] BitVec.toInt_add -attribute [bitvec_rules] BitVec.toNat_sub +-- attribute [bitvec_rules] BitVec.toNat_sub +attribute [bv_toNat] toNat_sub attribute [bitvec_rules] BitVec.toFin_sub attribute [bitvec_rules] BitVec.ofFin_sub attribute [bitvec_rules] BitVec.sub_ofFin attribute [bitvec_rules] BitVec.sub_zero attribute [bitvec_rules] BitVec.sub_self -attribute [bitvec_rules] BitVec.toNat_neg +-- attribute [bitvec_rules] BitVec.toNat_neg attribute [bitvec_rules] BitVec.toFin_neg attribute [bitvec_rules] BitVec.neg_zero -attribute [bitvec_rules] BitVec.toNat_mul +-- attribute [bitvec_rules] BitVec.toNat_mul attribute [bitvec_rules] BitVec.toFin_mul attribute [bitvec_rules] BitVec.mul_zero attribute [bitvec_rules] BitVec.mul_one @@ -246,6 +248,10 @@ attribute [bitvec_rules] Nat.reduceLeDiff attribute [bitvec_rules] Nat.reduceSubDiff attribute [bitvec_rules] BitVec.toNat_ofNat +-- This might be a neccesary evil: it introduces a modulus, +-- but it's also really useful. +attribute [bitvec_rules] BitVec.toNat_ofNat + -- Some Fin lemmas useful for bitvector reasoning: attribute [bitvec_rules] Fin.eta attribute [bitvec_rules] Fin.isLt @@ -285,11 +291,7 @@ abbrev ror (x : BitVec n) (r : Nat) : BitVec n := the `n`-bit bitvector `x`. -/ @[bitvec_rules] abbrev lsb (x : BitVec n) (i : Nat) : BitVec 1 := - -- TODO: We could use - -- BitVec.extractLsb' i 1 x - -- and avoid the cast here, but unfortunately, extractLsb' isn't supported - -- by LeanSAT. - (BitVec.extractLsb i i x).cast (by omega) + BitVec.extractLsb' i 1 x abbrev partInstall (hi lo : Nat) (val : BitVec (hi - lo + 1)) (x : BitVec n): BitVec n := let mask := allOnes (hi - lo + 1) @@ -552,16 +554,18 @@ theorem extractLsb_eq (x : BitVec n) (h : n = n - 1 + 1) : ext1 simp [←h] +theorem extractLsb'_eq (x : BitVec n) : + BitVec.extractLsb' 0 n x = x := by + unfold extractLsb' + simp only [Nat.shiftRight_zero, ofNat_toNat, zeroExtend_eq] + @[bitvec_rules] -protected theorem extract_lsb_of_zeroExtend (x : BitVec n) (h : j < i) : - extractLsb j 0 (zeroExtend i x) = zeroExtend (j + 1) x := by +protected theorem extractLsb'_of_zeroExtend (x : BitVec n) (h : j ≤ i) : + extractLsb' 0 j (zeroExtend i x) = zeroExtend j x := by apply BitVec.eq_of_getLsbD_eq - simp intro k have q : k < i := by omega by_cases h : decide (k ≤ j) <;> simp [q, h] - simp_all - omega @[bitvec_rules, simp] theorem zero_append {w} (x : BitVec 0) (y : BitVec w) : @@ -636,12 +640,11 @@ theorem append_of_extract_general_nat (high low n vn : Nat) (h : vn < 2 ^ n) : done theorem append_of_extract (n : Nat) (v : BitVec n) - (high0 : high = n - low) (low0 : 1 <= low) - (h : high + (low - 1 - 0 + 1) = n) : - BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb (low - 1) 0 v) = v := by + (high0 : high = n - low) (h : high + low = n) : + BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb' 0 low v) = v := by ext subst high - have vlt := v.isLt; simp_all only [Nat.sub_zero] + have vlt := v.isLt have := append_of_extract_general_nat (n - low) low n (BitVec.toNat v) vlt have low_le : low ≤ n := by omega simp_all [toNat_zeroExtend, Nat.sub_add_cancel, low_le] @@ -651,17 +654,14 @@ theorem append_of_extract (n : Nat) (v : BitVec n) exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) vlt done -theorem append_of_extract_general (v : BitVec n) - (low0 : 1 <= low) - (h1 : high = width) - (h2 : (high + low - 1 - 0 + 1) = (width + (low - 1 - 0 + 1))) : - BitVec.cast h1 (zeroExtend high (v >>> low)) ++ extractLsb (low - 1) 0 v = - BitVec.cast h2 (extractLsb (high + low - 1) 0 v) := by +@[bitvec_rules] +theorem append_of_extract_general (v : BitVec n) : + (zeroExtend high (v >>> low)) ++ extractLsb' 0 low v = + extractLsb' 0 (high + low) v := by ext have := append_of_extract_general_nat high low n (BitVec.toNat v) - have h_vlt := v.isLt; simp_all only [Nat.sub_zero, h1] - simp only [h_vlt, h1, forall_prop_of_true] at this - have low' : 1 ≤ width + low := Nat.le_trans low0 (Nat.le_add_left low width) + have h_vlt := v.isLt; simp_all only [Nat.sub_zero] + simp only [h_vlt, forall_prop_of_true] at this simp_all [toNat_zeroExtend, Nat.sub_add_cancel] rw [Nat.mod_eq_of_lt (b := 2 ^ n)] at this · rw [this] @@ -790,7 +790,7 @@ def genBVPatMatchTest (var : Term) (pat : BVPat) : MacroM Term := do for c in pat.getComponents do let len := c.length if let some bv ← c.toBVLit? then - let test ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var == $bv) + let test ← `(extractLsb' $(quote shift) $(quote len) $var == $bv) result ← `($result && $test) shift := shift + len return result @@ -812,7 +812,7 @@ def declBVPatVars (var : Term) (pat : BVPat) (rhs : Term) : MacroM Term := do for c in pat.getComponents do let len := c.length if let some y ← c.toBVVar? then - let rhs ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var) + let rhs ← `(extractLsb' $(quote shift) $(quote len) $var) result ← `(let $y := $rhs; $result) shift := shift + len return result @@ -936,10 +936,10 @@ Definition to extract the `n`th least significant *Byte* from a bitvector. TODO: this should be named `getLsByte`, or `getLsbByte` (Shilpi prefers this). -/ def extractLsByte (val : BitVec w₁) (n : Nat) : BitVec 8 := - val.extractLsb ((n + 1) * 8 - 1) (n * 8) |> .cast (by omega) + val.extractLsb' (n * 8) 8 theorem extractLsByte_def (val : BitVec w₁) (n : Nat) : - val.extractLsByte n = (val.extractLsb ((n + 1)*8 - 1) (n * 8) |>.cast (by omega)) := rfl + val.extractLsByte n = val.extractLsb' (n * 8) 8 := rfl -- TODO: upstream theorem extractLsb_or (x y : BitVec w₁) (n : Nat) : @@ -951,8 +951,23 @@ theorem extractLsb_or (x y : BitVec w₁) (n : Nat) : · simp only [h, decide_True, Bool.true_and] · simp only [h, decide_False, Bool.false_and, Bool.or_self] +-- TODO: upstream +theorem extractLsb'_or (x y : BitVec w₁) (n : Nat) : + (x ||| y).extractLsb' lo n = (x.extractLsb' lo n ||| y.extractLsb' lo n) := by + apply BitVec.eq_of_getLsbD_eq + simp only [getLsbD_extract, getLsbD_or] + intros i + simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, getLsbD_or, Bool.true_and] + +-- TODO: upstream +protected theorem extractLsb'_ofNat (x n : Nat) (l lo : Nat) : + extractLsb' lo l (BitVec.ofNat n x) = .ofNat l ((x % 2^n) >>> lo) := by + apply eq_of_getLsbD_eq + intro ⟨i, _lt⟩ + simp [BitVec.ofNat] + theorem extractLsByte_zero {w : Nat} : (0#w).extractLsByte i = 0#8 := by - simp only [extractLsByte, BitVec.extractLsb_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat] + simp only [extractLsByte, BitVec.extractLsb'_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat] theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) : x.extractLsByte a = 0#8 := by @@ -960,7 +975,7 @@ theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) : intros i simp only [getLsbD_zero, extractLsByte_def, getLsbD_cast, getLsbD_extract, Bool.and_eq_false_imp, decide_eq_true_eq] - intros _ + simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, Bool.true_and] apply BitVec.getLsbD_ge omega @@ -969,10 +984,13 @@ theorem getLsbD_extractLsByte (val : BitVec w₁) : ((BitVec.extractLsByte val n).getLsbD i) = (decide (i ≤ 7) && val.getLsbD (n * 8 + i)) := by simp only [extractLsByte, getLsbD_cast, getLsbD_extract] - rw [Nat.succ_mul] - simp only [Nat.add_one_sub_one, - Nat.add_sub_cancel_left] - + simp only [getLsbD_extractLsb'] + generalize val.getLsbD (n * 8 + i) = x + by_cases h : i < 8 + · simp only [show (i : Nat) ≤ 7 by omega, decide_True, Bool.true_and, + Bool.and_iff_right_iff_imp, decide_eq_true_eq, h] + · simp only [show ¬(i : Nat) ≤ 7 by omega, decide_False, Bool.false_and, + Bool.and_eq_false_imp, decide_eq_true_eq, h] /-- Two bitvectors of length `n*8` are equal if all their bytes are equal. @@ -996,9 +1014,7 @@ theorem eq_of_extractLsByte_eq (x y : BitVec (n * 8)) @bollu: it's not clear if the definition for n=0 is desirable. -/ def extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) : BitVec (n * 8) := - match h : n with - | 0 => 0#0 - | x + 1 => val.extractLsb (base * 8 + n * 8 - 1) (base * 8) |>.cast (by omega) + extractLsb' (base * 8) (n * 8) val @[bitvec_rules] theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) : @@ -1011,10 +1027,9 @@ theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) simp only [show ¬i < 0 by omega, decide_False, Bool.false_and] · simp only [extractLsBytes, getLsbD_cast, getLsbD_extract, Nat.zero_lt_succ, decide_True, Bool.true_and] - simp only [show base * 8 + (n + 1) * 8 - 1 - base * 8 = (n + 1) * 8 - 1 by omega] by_cases h : i < (n + 1) * 8 - · simp only [show i ≤ (n + 1) * 8 - 1 by omega, decide_True, Bool.true_and, h] - · simp only [show ¬(i ≤ (n + 1) * 8 - 1) by omega, decide_False, Bool.false_and, h] + · simp only [getLsbD_extractLsb', h, decide_True, Bool.true_and] + · simp only [getLsbD_extractLsb', h, decide_False, Bool.false_and] theorem extractLsByte_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) : (BitVec.extractLsBytes val base n).extractLsByte i = diff --git a/Arm/Cosim.lean b/Arm/Cosim.lean index 42432346..4c98224c 100644 --- a/Arm/Cosim.lean +++ b/Arm/Cosim.lean @@ -195,7 +195,7 @@ def sfp_list (s : ArmState) : List (BitVec 64) := Id.run do let mut acc := [] for i in [0:32] do let reg := read_sfp 128 (BitVec.ofNat 5 i) s - acc := acc ++ [(extractLsb 63 0 reg), (extractLsb 127 64 reg)] + acc := acc ++ [(extractLsb' 0 64 reg), (extractLsb' 64 64 reg)] pure acc /-- Get the flags in an ArmState as a 4-bit bitvector.-/ diff --git a/Arm/Insts/BR/Cond_branch_imm.lean b/Arm/Insts/BR/Cond_branch_imm.lean index d15aa16b..33a0fdfb 100644 --- a/Arm/Insts/BR/Cond_branch_imm.lean +++ b/Arm/Insts/BR/Cond_branch_imm.lean @@ -29,7 +29,7 @@ def Cond_branch_imm_inst.condition_holds let N := read_flag PFlag.N s let V := read_flag PFlag.V s let result := - match (extractLsb 3 1 inst.cond) with + match (extractLsb' 1 3 inst.cond) with | 0b000#3 => Z = 1#1 | 0b001#3 => C = 1#1 | 0b010#3 => N = 1#1 diff --git a/Arm/Insts/Common.lean b/Arm/Insts/Common.lean index d15fff48..484c83fe 100644 --- a/Arm/Insts/Common.lean +++ b/Arm/Insts/Common.lean @@ -109,7 +109,7 @@ def ConditionHolds (cond : BitVec 4) (s : ArmState) : Bool := let C := read_flag C s let V := read_flag V s let result := - match (extractLsb 3 1 cond) with + match (extractLsb' 1 3 cond) with | 0b000#3 => Z = 1#1 -- EQ or NE | 0b001#3 => C = 1#1 -- CS or CC | 0b010#3 => N = 1#1 -- MI or PL @@ -129,7 +129,7 @@ def ConditionHolds (cond : BitVec 4) (s : ArmState) : Bool := theorem sgt_iff_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) : (((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧ (AddWithCarry x (~~~y) 1#1).snd.z = 0#1) ↔ BitVec.slt y x := by - simp [AddWithCarry, make_pstate] + simp [AddWithCarry, make_pstate, lsb] split · bv_decide · bv_decide @@ -138,7 +138,7 @@ theorem sgt_iff_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) : theorem sgt_iff_n_eq_v_and_z_eq_0_32 (x y : BitVec 32) : (((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧ (AddWithCarry x (~~~y) 1#1).snd.z = 0#1) ↔ BitVec.slt y x := by - simp [AddWithCarry, make_pstate] + simp [AddWithCarry, make_pstate, lsb] split · bv_decide · bv_decide @@ -147,7 +147,7 @@ theorem sgt_iff_n_eq_v_and_z_eq_0_32 (x y : BitVec 32) : theorem sle_iff_not_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) : (¬(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧ (AddWithCarry x (~~~y) 1#1).snd.z = 0#1)) ↔ BitVec.sle x y := by - simp [AddWithCarry, make_pstate] + simp [AddWithCarry, make_pstate, lsb] split · bv_decide · bv_decide @@ -156,7 +156,7 @@ theorem sle_iff_not_n_eq_v_and_z_eq_0_64 (x y : BitVec 64) : theorem sle_iff_not_n_eq_v_and_z_eq_0_32 (x y : BitVec 32) : (¬(((AddWithCarry x (~~~y) 1#1).snd.n = (AddWithCarry x (~~~y) 1#1).snd.v) ∧ (AddWithCarry x (~~~y) 1#1).snd.z = 0#1)) ↔ BitVec.sle x y := by - simp [AddWithCarry, make_pstate] + simp [AddWithCarry, make_pstate, lsb] split · bv_decide · bv_decide @@ -174,12 +174,10 @@ theorem zero_iff_z_eq_one (x : BitVec 64) : · bv_decide done + /-- `Aligned x a` witnesses that the bitvector `x` is `a`-bit aligned. -/ def Aligned (x : BitVec n) (a : Nat) : Prop := - -- (TODO @alex) Switch to using extractLsb' to unify the two cases. - match a with - | 0 => True - | a' + 1 => extractLsb a' 0 x = BitVec.zero _ + extractLsb' 0 a x = BitVec.zero _ /-- We need to prove why the Aligned predicate is Decidable. -/ instance : Decidable (Aligned x a) := by @@ -201,7 +199,7 @@ theorem Aligned_BitVecAdd_64_4 {x : BitVec 64} {y : BitVec 64} theorem Aligned_AddWithCarry_64_4 (x : BitVec 64) (y : BitVec 64) (carry_in : BitVec 1) (x_aligned : Aligned x 4) - (y_carry_in_aligned : Aligned (BitVec.add (extractLsb 3 0 y) (zeroExtend 4 carry_in)) 4) + (y_carry_in_aligned : Aligned (BitVec.add (extractLsb' 0 4 y) (zeroExtend 4 carry_in)) 4) : Aligned (AddWithCarry x y carry_in).fst 4 := by unfold AddWithCarry Aligned at * simp_all only [Nat.sub_zero, zero_eq, add_eq] @@ -221,24 +219,36 @@ def CheckSPAlignment (s : ArmState) : Prop := instance : Decidable (CheckSPAlignment s) := by unfold CheckSPAlignment; infer_instance @[state_simp_rules] -theorem CheckSPAligment_of_w_different (h : StateField.GPR 31#5 ≠ fld) : +theorem CheckSPAligment_w_different_eq (h : StateField.GPR 31#5 ≠ fld) : CheckSPAlignment (w fld v s) = CheckSPAlignment s := by simp_all only [CheckSPAlignment, state_simp_rules, minimal_theory, bitvec_rules] +theorem CheckSPAligment_w_of_ne_sp_of (h : StateField.GPR 31#5 ≠ fld) : + CheckSPAlignment s → CheckSPAlignment (w fld v s) := by + simp only [CheckSPAligment_w_different_eq h, imp_self] + @[state_simp_rules] theorem CheckSPAligment_of_w_sp : CheckSPAlignment (w (StateField.GPR 31#5) v s) = (Aligned v 4) := by simp_all only [CheckSPAlignment, state_simp_rules, minimal_theory, bitvec_rules] +theorem CheckSPAligment_w_sp_of (h : Aligned v 4) : + CheckSPAlignment (w (StateField.GPR 31#5) v s) := by + simp only [CheckSPAlignment, read_gpr, r_of_w_same, zeroExtend_eq, h] + @[state_simp_rules] -theorem CheckSPAligment_of_write_mem_bytes : +theorem CheckSPAligment_write_mem_bytes_eq : CheckSPAlignment (write_mem_bytes n addr v s) = CheckSPAlignment s := by simp_all only [CheckSPAlignment, state_simp_rules, minimal_theory, bitvec_rules] +theorem CheckSPAligment_write_mem_bytes_of : + CheckSPAlignment s → CheckSPAlignment (write_mem_bytes n addr v s) := by + simp only [CheckSPAligment_write_mem_bytes_eq, imp_self] + @[state_simp_rules] theorem CheckSPAlignment_AddWithCarry_64_4 (st : ArmState) (y : BitVec 64) (carry_in : BitVec 1) (x_aligned : CheckSPAlignment st) - (y_carry_in_aligned : Aligned (BitVec.add (extractLsb 3 0 y) (zeroExtend 4 carry_in)) 4) + (y_carry_in_aligned : Aligned (BitVec.add (extractLsb' 0 4 y) (zeroExtend 4 carry_in)) 4) : Aligned (AddWithCarry (r (StateField.GPR 31#5) st) y carry_in).fst 4 := by simp_all only [CheckSPAlignment, read_gpr, zeroExtend_eq, Nat.sub_zero, add_eq, Aligned_AddWithCarry_64_4] @@ -438,7 +448,7 @@ def decode_bit_masks (immN : BitVec 1) (imms : BitVec 6) (immr : BitVec 6) let r := immr &&& levels let diff := s - r let esize := 1 <<< len - let d := extractLsb (len - 1) 0 diff + let d := extractLsb' 0 len diff let welem := zeroExtend esize (allOnes (s.toNat + 1)) let telem := zeroExtend esize (allOnes (d.toNat + 1)) let wmask := replicate (M/esize) $ rotateRight welem r.toNat @@ -498,18 +508,16 @@ instance : ToString SIMDThreeSameLogicalType where toString a := toString (repr ---------------------------------------------------------------------- @[state_simp_rules] -def Vpart_read (n : BitVec 5) (part width : Nat) (s : ArmState) (H : width > 0) +def Vpart_read (n : BitVec 5) (part width : Nat) (s : ArmState) : BitVec width := -- assert n >= 0 && n <= 31; -- assert part IN {0, 1}; - have h1: width - 1 + 1 = width := by omega - have h2: (width * 2 - 1 - width + 1) = width := by omega if part = 0 then -- assert width < 128; - BitVec.cast h1 $ extractLsb (width-1) 0 $ read_sfp 128 n s + extractLsb' 0 width $ read_sfp 128 n s else -- assert width IN {32,64}; - BitVec.cast h2 $ extractLsb (width*2-1) width $ read_sfp 128 n s + extractLsb' width width $ read_sfp 128 n s @[state_simp_rules] @@ -522,7 +530,7 @@ def Vpart_write (n : BitVec 5) (part width : Nat) (val : BitVec width) (s : ArmS write_sfp width n val s else -- assert width == 64 - let res := (extractLsb 63 0 val) ++ (read_sfp 64 n s) + let res := (extractLsb' 0 64 val) ++ (read_sfp 64 n s) write_sfp 128 n res s ---------------------------------------------------------------------- @@ -628,13 +636,10 @@ example : rev_vector 32 16 8 0xaabbccdd#32 (by decide) /-- Divide bv `vector` into elements, each of size `size`. This function gets the `e`'th element from the `vector`. -/ @[state_simp_rules] -def elem_get (vector : BitVec n) (e : Nat) (size : Nat) - (h: size > 0): BitVec size := +def elem_get (vector : BitVec n) (e : Nat) (size : Nat) : BitVec size := -- assert (e+1)*size <= n let lo := e * size - let hi := lo + size - 1 - have h : hi - lo + 1 = size := by simp only [hi, lo]; omega - BitVec.cast h $ extractLsb hi lo vector + extractLsb' lo size vector /-- Divide bv `vector` into elements, each of size `size`. This function sets the `e`'th element in the `vector`. -/ @@ -663,7 +668,7 @@ deriving DecidableEq, Repr export ShiftInfo (esize elements shift unsigned round accumulate) @[state_simp_rules] -def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) (h : n > 0) +def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) : BitVec n := -- assert shift > 0 let fn := if unsigned then ushiftRight else sshiftRight @@ -673,8 +678,7 @@ def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) (h : n > 0 BitVec.ofInt (n + 1) rounded else BitVec.ofInt (n + 1) value - have h₀ : n - 1 - 0 + 1 = n := by omega - BitVec.cast h₀ $ extractLsb (n-1) 0 (fn rounded_bv shift) + extractLsb' 0 n (fn rounded_bv shift) @[state_simp_rules] def Int_with_unsigned (unsigned : Bool) (value : BitVec n) : Int := @@ -686,9 +690,9 @@ def shift_right_common_aux if h : info.elements ≤ e then result else - let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize info.h - let shift_elem := RShr info.unsigned elem info.shift info.round info.h - let acc_elem := elem_get operand2 e info.esize info.h + shift_elem + let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize + let shift_elem := RShr info.unsigned elem info.shift info.round + let acc_elem := elem_get operand2 e info.esize + shift_elem let result := elem_set result e info.esize acc_elem info.h have _ : info.elements - (e + 1) < info.elements - e := by omega shift_right_common_aux (e + 1) info operand operand2 result @@ -709,7 +713,7 @@ def shift_left_common_aux if h : info.elements ≤ e then result else - let elem := elem_get operand e info.esize info.h + let elem := elem_get operand e info.esize let shift_elem := elem <<< info.shift let result := elem_set result e info.esize shift_elem info.h have _ : info.elements - (e + 1) < info.elements - e := by omega diff --git a/Arm/Insts/DPI/Logical_imm.lean b/Arm/Insts/DPI/Logical_imm.lean index a2d4d8c8..0a1d7afb 100644 --- a/Arm/Insts/DPI/Logical_imm.lean +++ b/Arm/Insts/DPI/Logical_imm.lean @@ -63,7 +63,7 @@ def MoveWidePreferred (sf immN : BitVec 1) (imms immr : BitVec 6) : Bool := false -- NOTE: the second conjunct below is semantically equivalent to the ASL code -- !((immN:imms) IN {'00xxxxx'}) - else if sf = 0#1 ∧ ¬(immN = 0#1 ∧ imms.extractLsb 5 5 = 0#1) then + else if sf = 0#1 ∧ ¬(immN = 0#1 ∧ imms.extractLsb' 5 1 = 0#1) then false -- for MOVZ must contain no more than 16 ones diff --git a/Arm/Insts/DPR/Data_processing_one_source.lean b/Arm/Insts/DPR/Data_processing_one_source.lean index 36c6165c..145f004a 100644 --- a/Arm/Insts/DPR/Data_processing_one_source.lean +++ b/Arm/Insts/DPR/Data_processing_one_source.lean @@ -55,7 +55,7 @@ private theorem opc_and_sf_constraint (x : BitVec 2) (y : BitVec 1) @[state_simp_rules] def exec_data_processing_rev (inst : Data_processing_one_source_cls) (s : ArmState) : ArmState := - let opc : BitVec 2 := extractLsb 1 0 inst.opcode + let opc : BitVec 2 := extractLsb' 0 2 inst.opcode if H₁ : opc = 0b11#2 ∧ inst.sf = 0b0#1 then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s else @@ -65,16 +65,16 @@ def exec_data_processing_rev let esize := 8 have opc_h₁ : opc.toNat ≥ 0 := by simp only [ge_iff_le, Nat.zero_le] have opc_h₂ : opc.toNat < 4 := by - refine BitVec.isLt (extractLsb 1 0 inst.opcode) + refine BitVec.isLt (extractLsb' 0 2 inst.opcode) have opc_sf_h : ¬(opc.toNat = 3 ∧ inst.sf.toNat = 0) := by - apply opc_and_sf_constraint (extractLsb 1 0 inst.opcode) inst.sf H₁ + apply opc_and_sf_constraint (extractLsb' 0 2 inst.opcode) inst.sf H₁ have h₀ : 0 < esize := by decide have h₁ : esize ≤ container_size := by apply shiftLeft_ge have h₂ : container_size ≤ datasize := by apply container_size_le_datasize opc.toNat inst.sf.toNat opc_h₁ opc_h₂ opc_sf_h have h₃ : esize ∣ container_size := by simp only [esize, container_size] - generalize BitVec.toNat (extractLsb 1 0 inst.opcode) = x + generalize BitVec.toNat (extractLsb' 0 2 inst.opcode) = x simp only [Nat.shiftLeft_eq] generalize 2 ^ x = n simp only [Nat.dvd_mul_right] diff --git a/Arm/Insts/DPR/Data_processing_two_source.lean b/Arm/Insts/DPR/Data_processing_two_source.lean index 22798d4e..a48ac981 100644 --- a/Arm/Insts/DPR/Data_processing_two_source.lean +++ b/Arm/Insts/DPR/Data_processing_two_source.lean @@ -19,7 +19,7 @@ open BitVec def exec_data_processing_shift (inst : Data_processing_two_source_cls) (s : ArmState) : ArmState := let datasize := 32 <<< inst.sf.toNat - let shift_type := decode_shift $ extractLsb 1 0 inst.opcode + let shift_type := decode_shift $ extractLsb' 0 2 inst.opcode let operand2 := read_gpr_zr datasize inst.Rm s let amount := BitVec.ofInt 6 (operand2.toInt % datasize) let operand := read_gpr_zr datasize inst.Rn s diff --git a/Arm/Insts/DPSFP/Advanced_simd_copy.lean b/Arm/Insts/DPSFP/Advanced_simd_copy.lean index bd4cf712..ffd077a1 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_copy.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_copy.lean @@ -32,14 +32,14 @@ def exec_dup_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState : if size > 3 ∨ (size = 3 ∧ inst.Q = 0) then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s else - let index := (extractLsb 4 (size + 1) inst.imm5).toNat + let index := (extractLsb' (size + 1) (4 - size) inst.imm5).toNat let idxdsize := 64 <<< (lsb inst.imm5 4).toNat let esize := 8 <<< size let datasize := 64 <<< inst.Q.toNat let elements := datasize / esize let operand := read_sfp idxdsize inst.Rn s have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide) - let element := elem_get operand index esize h₀ + let element := elem_get operand index esize let result := dup_aux 0 elements esize element (BitVec.zero datasize) h₀ -- State Updates let s := write_pc ((read_pc s) + 4#64) s @@ -69,14 +69,14 @@ def exec_ins_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState : if size > 3 then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s else - let dst_index := (extractLsb 4 (size + 1) inst.imm5).toNat - let src_index := (extractLsb 3 size inst.imm4).toNat + let dst_index := (extractLsb' (size + 1) (4 - size) inst.imm5).toNat + let src_index := (extractLsb' size (4 - size) inst.imm4).toNat let idxdsize := 64 <<< (lsb inst.imm4 3).toNat let esize := 8 <<< size let operand := read_sfp idxdsize inst.Rn s let result := read_sfp 128 inst.Rd s have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide) - let elem := elem_get operand src_index esize h₀ + let elem := elem_get operand src_index esize let result := elem_set result dst_index esize elem h₀ -- State Updates let s := write_pc ((read_pc s) + 4#64) s @@ -89,7 +89,7 @@ def exec_ins_general (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState : if size > 3 then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s else - let index := (extractLsb 4 (size + 1) inst.imm5).toNat + let index := (extractLsb' (size + 1) (4 - size) inst.imm5).toNat let esize := 8 <<< size let element := read_gpr esize inst.Rn s let result := read_sfp 128 inst.Rd s @@ -113,12 +113,11 @@ def exec_smov_umov (inst : Advanced_simd_copy_cls) (s : ArmState) (signed : Bool (datasize = 32 ∧ esize >= 64)) then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s else - let index := (extractLsb 4 (size + 1) inst.imm5).toNat + let index := (extractLsb' (size + 1) (4 - size) inst.imm5).toNat let idxdsize := 64 <<< (lsb inst.imm5 4).toNat -- if index == 0 then CheckFPEnabled64 else CheckFPAdvSIMDEnabled64 let operand := read_sfp idxdsize inst.Rn s - have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide) - let element := elem_get operand index esize h₀ + let element := elem_get operand index esize let result := if signed then signExtend datasize element else zeroExtend datasize element -- State Updates let s := write_pc ((read_pc s) + 4#64) s diff --git a/Arm/Insts/DPSFP/Advanced_simd_extract.lean b/Arm/Insts/DPSFP/Advanced_simd_extract.lean index cc52b4bb..9ec571cd 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_extract.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_extract.lean @@ -31,12 +31,8 @@ def exec_advanced_simd_extract let hi := read_sfp datasize inst.Rm s let lo := read_sfp datasize inst.Rn s let concat := hi ++ lo - let result := extractLsb (position + datasize - 1) position concat - have h_datasize : 1 <= datasize := by simp_all! [datasize]; split <;> decide - have h : (position + datasize - 1 - position + 1) = datasize := by - rw [Nat.add_sub_assoc, Nat.add_sub_self_left] - exact Nat.sub_add_cancel h_datasize; trivial - let s := write_sfp datasize inst.Rd (BitVec.cast h result) s + let result := extractLsb' position datasize concat + let s := write_sfp datasize inst.Rd result s let s := write_pc ((read_pc s) + 4#64) s s diff --git a/Arm/Insts/DPSFP/Advanced_simd_modified_immediate.lean b/Arm/Insts/DPSFP/Advanced_simd_modified_immediate.lean index cfec1f8f..a46404a8 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_modified_immediate.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_modified_immediate.lean @@ -52,7 +52,7 @@ def decode_immediate_op (inst : Advanced_simd_modified_immediate_cls) else (some ImmediateOp.MOVI, s) def AdvSIMDExpandImm (op : BitVec 1) (cmode : BitVec 4) (imm8 : BitVec 8) : BitVec 64 := - let cmode_high3 := extractLsb 3 1 cmode + let cmode_high3 := extractLsb' 1 3 cmode let cmode_low1 := lsb cmode 0 match cmode_high3 with | 0b000#3 => replicate 2 $ BitVec.zero 24 ++ imm8 @@ -82,13 +82,13 @@ def AdvSIMDExpandImm (op : BitVec 1) (cmode : BitVec 4) (imm8 : BitVec 8) : BitV else if cmode_low1 = 1 ∧ op = 0 then let imm32 := lsb imm8 7 ++ ~~~(lsb imm8 6) ++ (replicate 5 $ lsb imm8 6) ++ - extractLsb 5 0 imm8 ++ BitVec.zero 19 + extractLsb' 0 6 imm8 ++ BitVec.zero 19 replicate 2 imm32 else -- Assume not UsingAArch32() -- if UsingAArch32() then ReservedEncoding(); lsb imm8 7 ++ ~~~(lsb imm8 6) ++ - (replicate 8 $ lsb imm8 6) ++ extractLsb 5 0 imm8 ++ BitVec.zero 48 + (replicate 8 $ lsb imm8 6) ++ extractLsb' 0 6 imm8 ++ BitVec.zero 48 private theorem mul_div_norm_form_lemma (n m : Nat) (_h1 : 0 < m) (h2 : n ∣ m) : @@ -107,9 +107,9 @@ def exec_advanced_simd_modified_immediate let datasize := 64 <<< inst.Q.toNat let imm8 := inst.a ++ inst.b ++ inst.c ++ inst.d ++ inst.e ++ inst.f ++ inst.g ++ inst.h let imm16 : BitVec 16 := - extractLsb 7 7 imm8 ++ ~~~ (extractLsb 6 6 imm8) ++ - (replicate 2 $ extractLsb 6 6 imm8) ++ extractLsb 5 0 imm8 ++ - BitVec.zero 6 + extractLsb' 7 1 imm8 ++ ~~~ (extractLsb' 6 1 imm8) ++ + (replicate 2 $ extractLsb' 6 1 imm8) ++ + extractLsb' 0 6 imm8 ++ BitVec.zero 6 let imm64 := AdvSIMDExpandImm inst.op inst.cmode imm8 have h₁ : 16 * (datasize / 16) = datasize := by omega have h₂ : 64 * (datasize / 64) = datasize := by omega diff --git a/Arm/Insts/DPSFP/Advanced_simd_permute.lean b/Arm/Insts/DPSFP/Advanced_simd_permute.lean index 5ccc6bff..95a2920c 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_permute.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_permute.lean @@ -22,8 +22,8 @@ def trn_aux (p : Nat) (pairs : Nat) (esize : Nat) (part : Nat) result else let idx_from := 2 * p + part - let op1_part := elem_get operand1 idx_from esize h - let op2_part := elem_get operand2 idx_from esize h + let op1_part := elem_get operand1 idx_from esize + let op2_part := elem_get operand2 idx_from esize let result := elem_set result (2 * p) esize op1_part h let result := elem_set result (2 * p + 1) esize op2_part h have h₁ : pairs - (p + 1) < pairs - p := by omega diff --git a/Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean b/Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean index 3f2b25c2..47a1c98d 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean @@ -23,12 +23,11 @@ def exec_advanced_simd_scalar_copy if size > 3 ∨ inst.imm4 ≠ 0b0000#4 ∨ inst.op ≠ 0 then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s else - let index := extractLsb 4 (size + 1) inst.imm5 + let index := extractLsb' (size + 1) (4 - size) inst.imm5 let idxdsize := 64 <<< (lsb inst.imm5 4).toNat let esize := 8 <<< size let operand := read_sfp idxdsize inst.Rn s - have h : esize > 0 := by apply zero_lt_shift_left_pos (by decide) - let result := elem_get operand index.toNat esize h + let result := elem_get operand index.toNat esize -- State Updates let s := write_pc ((read_pc s) + 4#64) s let s := write_sfp esize inst.Rd result s diff --git a/Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean b/Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean index b032904c..16024f23 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean @@ -35,10 +35,10 @@ def tblx_aux (i : Nat) (elements : Nat) (indices : BitVec datasize) result else have h₁ : 8 > 0 := by decide - let index := (elem_get indices i 8 h₁).toNat + let index := (elem_get indices i 8).toNat let result := if index < 16 * regs then - let val := elem_get table index 8 h₁ + let val := elem_get table index 8 elem_set result i 8 val h₁ else result diff --git a/Arm/Insts/DPSFP/Advanced_simd_three_different.lean b/Arm/Insts/DPSFP/Advanced_simd_three_different.lean index 53800ea8..87dede23 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_three_different.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_three_different.lean @@ -37,8 +37,8 @@ def pmull_op (e : Nat) (esize : Nat) (elements : Nat) (x : BitVec n) if h₀ : elements <= e then result else - let element1 := elem_get x e esize H - let element2 := elem_get y e esize H + let element1 := elem_get x e esize + let element2 := elem_get y e esize let elem_result := polynomial_mult element1 element2 have h₁ : esize + esize = 2 * esize := by omega have h₂ : 2 * esize > 0 := by omega @@ -58,9 +58,8 @@ def exec_pmull (inst : Advanced_simd_three_different_cls) (s : ArmState) : ArmSt let datasize := 64 let part := inst.Q.toNat let elements := datasize / esize - have h₁ : datasize > 0 := by decide - let operand1 := Vpart_read inst.Rn part datasize s h₁ - let operand2 := Vpart_read inst.Rm part datasize s h₁ + let operand1 := Vpart_read inst.Rn part datasize s + let operand2 := Vpart_read inst.Rm part datasize s let result := pmull_op 0 esize elements operand1 operand2 (BitVec.zero (2*datasize)) h₀ let s := write_sfp (datasize*2) inst.Rd result s diff --git a/Arm/Insts/DPSFP/Advanced_simd_three_same.lean b/Arm/Insts/DPSFP/Advanced_simd_three_same.lean index a532ac04..7da3cc42 100644 --- a/Arm/Insts/DPSFP/Advanced_simd_three_same.lean +++ b/Arm/Insts/DPSFP/Advanced_simd_three_same.lean @@ -25,8 +25,8 @@ def binary_vector_op_aux (e : Nat) (elems : Nat) (esize : Nat) result else have h₁ : e < elems := by omega - let element1 := elem_get x e esize H - let element2 := elem_get y e esize H + let element1 := elem_get x e esize + let element2 := elem_get y e esize let elem_result := op element1 element2 let result := elem_set result e esize elem_result H have ht1 : elems - (e + 1) < elems - e := by omega diff --git a/Arm/Insts/DPSFP/Conversion_between_FP_and_Int.lean b/Arm/Insts/DPSFP/Conversion_between_FP_and_Int.lean index b4a3642d..b1f60918 100644 --- a/Arm/Insts/DPSFP/Conversion_between_FP_and_Int.lean +++ b/Arm/Insts/DPSFP/Conversion_between_FP_and_Int.lean @@ -19,12 +19,11 @@ open BitVec @[state_simp_rules] def fmov_general_aux (intsize : Nat) (fltsize : Nat) (op : FPConvOp) (part : Nat) (inst : Conversion_between_FP_and_Int_cls) (s : ArmState) - (H : 0 < fltsize) : ArmState := -- Assume CheckFPEnabled64() match op with | FPConvOp.FPConvOp_MOV_FtoI => - let fltval := Vpart_read inst.Rn part fltsize s H + let fltval := Vpart_read inst.Rn part fltsize s let intval := zeroExtend intsize fltval -- State Update let s := write_gpr intsize inst.Rd intval s @@ -32,10 +31,9 @@ def fmov_general_aux (intsize : Nat) (fltsize : Nat) (op : FPConvOp) s | FPConvOp.FPConvOp_MOV_ItoF => let intval := read_gpr intsize inst.Rn s - let fltval := extractLsb (fltsize - 1) 0 intval + let fltval := extractLsb' 0 fltsize intval -- State Update - have h₀ : fltsize - 1 - 0 + 1 = fltsize := by omega - let s := Vpart_write inst.Rd part fltsize (BitVec.cast h₀ fltval) s + let s := Vpart_write inst.Rd part fltsize fltval s let s := write_pc ((read_pc s) + 4#64) s s | _ => write_err (StateError.Other s!"fmov_general_aux called with non-FMOV op!") s @@ -45,13 +43,7 @@ def exec_fmov_general (inst : Conversion_between_FP_and_Int_cls) (s : ArmState): ArmState := let intsize := 32 <<< inst.sf.toNat let decode_fltsize := if inst.ftype = 0b10#2 then 64 else (8 <<< (inst.ftype ^^^ 0b10#2).toNat) - have H: 0 < decode_fltsize := by - simp only [decode_fltsize, beq_iff_eq] - split - · decide - · generalize BitVec.toNat (inst.ftype ^^^ 2#2) = x - apply zero_lt_shift_left_pos (by decide) - match (extractLsb 2 1 inst.opcode) ++ inst.rmode with + match (extractLsb' 1 2 inst.opcode) ++ inst.rmode with | 1100 => -- FMOV if decode_fltsize ≠ 16 ∧ decode_fltsize ≠ intsize then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s @@ -60,7 +52,7 @@ def exec_fmov_general then FPConvOp.FPConvOp_MOV_ItoF else FPConvOp.FPConvOp_MOV_FtoI let part := 0 - fmov_general_aux intsize decode_fltsize op part inst s H + fmov_general_aux intsize decode_fltsize op part inst s | 1101 => -- FMOV D[1] if intsize ≠ 64 ∨ inst.ftype ≠ 0b10#2 then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s @@ -69,13 +61,13 @@ def exec_fmov_general then FPConvOp.FPConvOp_MOV_ItoF else FPConvOp.FPConvOp_MOV_FtoI let part := 1 - fmov_general_aux intsize decode_fltsize op part inst s H + fmov_general_aux intsize decode_fltsize op part inst s | _ => write_err (StateError.Other s!"exec_fmov_general called with non-FMOV instructions!") s @[state_simp_rules] def exec_conversion_between_FP_and_Int (inst : Conversion_between_FP_and_Int_cls) (s : ArmState) : ArmState := - if inst.ftype = 0b10#2 ∧ (extractLsb 2 1 inst.opcode) ++ inst.rmode ≠ 0b1101#4 then + if inst.ftype = 0b10#2 ∧ (extractLsb' 1 2 inst.opcode) ++ inst.rmode ≠ 0b1101#4 then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s -- Assume IsFeatureImplemented(FEAT_FP16) is true else @@ -92,7 +84,7 @@ partial def Conversion_between_FP_and_Int_cls.fmov_general.rand : Cosim.CosimM ( let sf := ← BitVec.rand 1 let intsize := 32 <<< sf.toNat let decode_fltsize := if ftype == 0b10#2 then 64 else (8 <<< (ftype ^^^ 0b10#2).toNat) - if ftype == 0b10#2 && ((extractLsb 2 1 opcode) ++ rmode) != 0b1101#4 || + if ftype == 0b10#2 && ((extractLsb' 1 2 opcode) ++ rmode) != 0b1101#4 || decode_fltsize != 16 && decode_fltsize != intsize || intsize != 64 || ftype != 0b10#2 then Conversion_between_FP_and_Int_cls.fmov_general.rand diff --git a/Arm/Insts/DPSFP/Crypto_aes.lean b/Arm/Insts/DPSFP/Crypto_aes.lean index 99d99027..e58e3f73 100644 --- a/Arm/Insts/DPSFP/Crypto_aes.lean +++ b/Arm/Insts/DPSFP/Crypto_aes.lean @@ -52,8 +52,7 @@ def FFmul02 (b : BitVec 8) : BitVec 8 := 0x1E1C1A18161412100E0C0A0806040200#128 -- 0 ] let lo := b.toNat * 8 - let hi := lo + 7 - BitVec.cast (by omega) $ extractLsb hi lo $ BitVec.flatten FFmul_02 + extractLsb' lo 8 $ BitVec.flatten FFmul_02 def FFmul03 (b : BitVec 8) : BitVec 8 := let FFmul_03 := @@ -76,8 +75,7 @@ def FFmul03 (b : BitVec 8) : BitVec 8 := 0x111217141D1E1B18090A0F0C05060300#128 -- 0 ] let lo := b.toNat * 8 - let hi := lo + 7 - BitVec.cast (by omega) $ extractLsb hi lo $ BitVec.flatten FFmul_03 + extractLsb' lo 8 $ BitVec.flatten FFmul_03 def AESMixColumns (op : BitVec 128) : BitVec 128 := AESCommon.MixColumns op FFmul02 FFmul03 diff --git a/Arm/Insts/DPSFP/Crypto_three_reg_sha512.lean b/Arm/Insts/DPSFP/Crypto_three_reg_sha512.lean index e0535d58..e62b4d8e 100644 --- a/Arm/Insts/DPSFP/Crypto_three_reg_sha512.lean +++ b/Arm/Insts/DPSFP/Crypto_three_reg_sha512.lean @@ -21,14 +21,14 @@ open BitVec def sha512h (x : BitVec 128) (y : BitVec 128) (w : BitVec 128) : BitVec 128 := open sha512_helpers in - let y_127_64 := extractLsb 127 64 y - let y_63_0 := extractLsb 63 0 y + let y_127_64 := extractLsb' 64 64 y + let y_63_0 := extractLsb' 0 64 y let msigma1 := sigma_big_1 y_127_64 - let x_63_0 := extractLsb 63 0 x - let x_127_64 := extractLsb 127 64 x + let x_63_0 := extractLsb' 0 64 x + let x_127_64 := extractLsb' 64 64 x let vtmp_127_64 := ch y_127_64 x_63_0 x_127_64 - let w_127_64 := extractLsb 127 64 w - let w_63_0 := extractLsb 63 0 w + let w_127_64 := extractLsb' 64 64 w + let w_63_0 := extractLsb' 0 64 w let vtmp_127_64 := vtmp_127_64 + msigma1 + w_127_64 let tmp := vtmp_127_64 + y_63_0 let msigma1 := sigma_big_1 tmp @@ -40,16 +40,16 @@ def sha512h (x : BitVec 128) (y : BitVec 128) (w : BitVec 128) def sha512h2 (x : BitVec 128) (y : BitVec 128) (w : BitVec 128) : BitVec 128 := open sha512_helpers in - let y_63_0 := extractLsb 63 0 y - let y_127_64 := extractLsb 127 64 y + let y_63_0 := extractLsb' 0 64 y + let y_127_64 := extractLsb' 64 64 y let nsigma0 := sigma_big_0 y_63_0 - let x_63_0 := extractLsb 63 0 x + let x_63_0 := extractLsb' 0 64 x let vtmp_127_64 := maj x_63_0 y_127_64 y_63_0 - let w_127_64 := extractLsb 127 64 w + let w_127_64 := extractLsb' 64 64 w let vtmp_127_64 := vtmp_127_64 + nsigma0 + w_127_64 let nsigma0 := sigma_big_0 vtmp_127_64 let vtmp_63_0 := maj vtmp_127_64 y_63_0 y_127_64 - let w_63_0 := extractLsb 63 0 w + let w_63_0 := extractLsb' 0 64 w let vtmp_63_0 := vtmp_63_0 + nsigma0 + w_63_0 let result := vtmp_127_64 ++ vtmp_63_0 result @@ -57,15 +57,15 @@ def sha512h2 (x : BitVec 128) (y : BitVec 128) (w : BitVec 128) : def sha512su1 (x : BitVec 128) (y : BitVec 128) (w : BitVec 128) : BitVec 128 := open sha512_helpers in - let x_127_64 := extractLsb 127 64 x + let x_127_64 := extractLsb' 64 64 x let sig1 := sigma_1 x_127_64 - let w_127_64 := extractLsb 127 64 w - let y_127_64 := extractLsb 127 64 y + let w_127_64 := extractLsb' 64 64 w + let y_127_64 := extractLsb' 64 64 y let vtmp_127_64 := w_127_64 + sig1 + y_127_64 - let x_63_0 := extractLsb 63 0 x + let x_63_0 := extractLsb' 0 64 x let sig1 := sigma_1 x_63_0 - let w_63_0 := extractLsb 63 0 w - let y_63_0 := extractLsb 63 0 y + let w_63_0 := extractLsb' 0 64 w + let y_63_0 := extractLsb' 0 64 y let vtmp_63_0 := w_63_0 + sig1 + y_63_0 let result := vtmp_127_64 ++ vtmp_63_0 result diff --git a/Arm/Insts/DPSFP/Crypto_two_reg_sha512.lean b/Arm/Insts/DPSFP/Crypto_two_reg_sha512.lean index 0fbe96e8..f3bbd1fa 100644 --- a/Arm/Insts/DPSFP/Crypto_two_reg_sha512.lean +++ b/Arm/Insts/DPSFP/Crypto_two_reg_sha512.lean @@ -19,10 +19,10 @@ open BitVec def sha512su0 (x : BitVec 128) (w : BitVec 128) : BitVec 128 := open sha512_helpers in - let w_127_64 := extractLsb 127 64 w - let w_63_0 := extractLsb 63 0 w + let w_127_64 := extractLsb' 64 64 w + let w_63_0 := extractLsb' 0 64 w let sig0 := sigma_0 w_127_64 - let x_63_0 := extractLsb 63 0 x + let x_63_0 := extractLsb' 0 64 x let vtmp_63_0 := w_63_0 + sig0 let sig0 := sigma_0 x_63_0 let vtmp_127_64 := w_127_64 + sig0 diff --git a/Arm/Insts/LDST/Reg_pair.lean b/Arm/Insts/LDST/Reg_pair.lean index 164a8b5d..5c0c6263 100644 --- a/Arm/Insts/LDST/Reg_pair.lean +++ b/Arm/Insts/LDST/Reg_pair.lean @@ -40,7 +40,7 @@ def reg_pair_constrain_unpredictable (wback : Bool) (inst : Reg_pair_cls) : Bool @[state_simp_rules] def reg_pair_operation (inst : Reg_pair_cls) (inst_str : String) (signed : Bool) (datasize : Nat) (offset : BitVec 64) (s : ArmState) - (H1 : 8 ∣ datasize) (H2 : 0 < datasize) : ArmState := + (H1 : 8 ∣ datasize): ArmState := -- Note: we do not need to model the ASL function -- "CreateAccDescGPR" here, given the simplicity of our memory -- model @@ -64,18 +64,15 @@ def reg_pair_operation (inst : Reg_pair_cls) (inst_str : String) (signed : Bool) let full_data := data2 ++ data1 write_mem_bytes (2 * (datasize / 8)) address (BitVec.cast h3 full_data) s | _ => -- LOAD - have h4 : datasize - 1 - 0 + 1 = datasize := by - simp; apply Nat.sub_add_cancel H2 - have h5 : 2 * datasize - 1 - datasize + 1 = datasize := by omega let full_data := read_mem_bytes (2 * (datasize / 8)) address s - let data1 := extractLsb (datasize - 1) 0 full_data - let data2 := extractLsb ((2 * datasize) - 1) datasize full_data + let data1 := extractLsb' 0 datasize full_data + let data2 := extractLsb' datasize datasize full_data if not inst.SIMD? ∧ signed then let s := write_gpr 64 inst.Rt (signExtend 64 data1) s write_gpr 64 inst.Rt2 (signExtend 64 data2) s else - let s:= ldst_write inst.SIMD? datasize inst.Rt (BitVec.cast h4 data1) s - ldst_write inst.SIMD? datasize inst.Rt2 (BitVec.cast h5 data2) s + let s:= ldst_write inst.SIMD? datasize inst.Rt data1 s + ldst_write inst.SIMD? datasize inst.Rt2 data2 s if inst.wback then let address := if inst.postindex then address + offset else address write_gpr 64 inst.Rn address s @@ -102,11 +99,8 @@ def exec_reg_pair_common (inst : Reg_pair_cls) (inst_str : String) (s : ArmState let offset := (signExtend 64 inst.imm7) <<< scale have H1 : 8 ∣ datasize := by simp_all! only [gt_iff_lt, Nat.shiftLeft_eq, Nat.dvd_mul_right, datasize] - have H2 : 0 < datasize := by - simp_all! only [datasize] - apply zero_lt_shift_left_pos (by decide) -- State Updates - let s' := reg_pair_operation inst inst_str signed datasize offset s H1 H2 + let s' := reg_pair_operation inst inst_str signed datasize offset s H1 let s' := write_pc ((read_pc s) + 4#64) s' s' diff --git a/Arm/Insts/LDST/Reg_unscaled_imm.lean b/Arm/Insts/LDST/Reg_unscaled_imm.lean index eca300cf..1e0db697 100644 --- a/Arm/Insts/LDST/Reg_unscaled_imm.lean +++ b/Arm/Insts/LDST/Reg_unscaled_imm.lean @@ -17,7 +17,7 @@ open BitVec @[state_simp_rules] def exec_ldstur (inst : Reg_unscaled_imm_cls) (s : ArmState) : ArmState := - let scale := (extractLsb 1 1 inst.opc ++ inst.size).toNat + let scale := (extractLsb' 1 1 inst.opc ++ inst.size).toNat if scale > 4 then write_err (StateError.Illegal s!"Illegal {inst} encountered!") s else @@ -44,9 +44,9 @@ def exec_ldstur @[state_simp_rules] def exec_reg_unscaled_imm (inst : Reg_unscaled_imm_cls) (s : ArmState) : ArmState := - if inst.VR = 0b1#1 then + if inst.VR = 0b1#1 then exec_ldstur inst s - else + else write_err (StateError.Unimplemented s!"Unsupported instruction {inst} encountered!") s end LDST diff --git a/Arm/Memory/AddressNormalization.lean b/Arm/Memory/AddressNormalization.lean new file mode 100644 index 00000000..c6153afe --- /dev/null +++ b/Arm/Memory/AddressNormalization.lean @@ -0,0 +1,207 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Siddharth Bhat, Tobias Grosser +-/ + +/- +This file implements bitvector expression simplification simprocs. +We perform the following additional changes: + +1. Canonicalizing bitvector expression to always have constants on the left. + Recall that the default associativity of addition is to the left: x + y + z = (x + y) + z. + If we thus normalize our expressions to have constants on the left, + and if we constant-fold constants, we will naturally perform canonicalization. + That is, the two rewrites: + + (a) (x + c) -> (c + x). + (b) x + (y + z) -> (x + y) + z. + + combine to achieve constant folding. Observe an example: + + ((x + 10) + 20) + -b-> (20 + (x + 10)) + -b-> (20 + (10 + x)) + -a-> (20 + 10) + x + -reduceAdd-> 30 + x + +2. Canonicalizing (a + b) % n → a % n + b % n by exploiting `bv_omega`, + and eventually, `simp_mem`. +-/ +import Lean +import Arm.Memory.Attr +import Arm.Attr +import Tactics.Common + +open Lean Meta Elab Simp + + +theorem Nat.mod_eq_sub {x y : Nat} (h : x ≥ y) (h' : x - y < y) : + x % y = x - y := by + rw [Nat.mod_eq_sub_mod h, Nat.mod_eq_of_lt h'] + +private def mkLTNat (x y : Expr) : Expr := + mkAppN (.const ``LT.lt [levelZero]) #[mkConst ``Nat, mkConst ``instLTNat, x, y] + +private def mkLENat (x y : Expr) : Expr := + mkAppN (.const ``LE.le [levelZero]) #[mkConst ``Nat, mkConst ``instLENat, x, y] + +private def mkGENat (x y : Expr) : Expr := mkLENat y x + +private def mkSubNat (x y : Expr) : Expr := + let lz := levelZero + let nat := mkConst ``Nat + let instSub := mkConst ``instSubNat + let instHSub := mkAppN (mkConst ``instHSub [lz]) #[nat, instSub] + mkAppN (mkConst ``HSub.hSub [lz, lz, lz]) #[nat, nat, nat, instHSub, x, y] + +/-- +Given an expression of the form `n#w`, return the value of `n` if it is a ground constant. + +Notice that this is different from `getBitVecValue?` in that here we allow `w` to be symbolic. +Hence, we might not know the width, explaining why we return a `Nat` rather than a `BitVec`. +-/ +def getBitVecOfNatValue? (e : Expr) : (Option (Expr × Expr)) := + match_expr e with + | BitVec.ofNat nExpr vExpr => some (nExpr, vExpr) + | _ => none + +/-- +Try to build a proof for `ty` by reduction to `omega`. +Return a proof of `ty` on success, or `none` if omega failed to prove the goal. + +This is to be used to automatically prove inbounds constraints to eliminate modulos +in a simproc, hence the use of `SimpM`. +We may eventually want to exploit our memory automation framework to bring in +more `omega` facts. +-/ +@[inline] def dischargeByOmega (ty : Expr) : SimpM (Option Expr) := do + let proof : Expr ← mkFreshExprMVar ty + let g := proof.mvarId! + let some g ← g.falseOrByContra + | return none + try + g.withContext (do Lean.Elab.Tactic.Omega.omega (← getLocalHyps).toList g {}) + catch _ => + return none + return some proof + +-- x % n = x if x < n +@[inline] def reduceModOfLt (x : Expr) (n : Expr) : SimpM Step := do + trace[Tactic.address_normalization] "{processingEmoji} reduceModOfLt '{x} % {n}'" + let ltTy := mkLTNat x n + let some p ← dischargeByOmega ltTy + | return .continue + let eqProof ← mkAppM ``Nat.mod_eq_of_lt #[p] + trace[Tactic.address_normalization] "{checkEmoji} reduceModOfLt '{x} % {n}'" + return .done { expr := x, proof? := eqProof : Result } + +-- x % n = x - n if x >= n and x - n < n +@[inline] def reduceModSub (x : Expr) (n : Expr) : SimpM Step := do + trace[Tactic.address_normalization] "{processingEmoji} reduceModSub '{x} % {n}'" + let geTy := mkGENat x n + let some geProof ← dischargeByOmega geTy + | return .continue + let subTy := mkSubNat x n + let ltTy := mkLTNat subTy n + let some ltProof ← dischargeByOmega ltTy + | return .continue + let eqProof ← mkAppM ``Nat.mod_eq_sub #[geProof, ltProof] + trace[Tactic.address_normalization] "{checkEmoji} reduceModSub '{x} % {n}'" + return .done { expr := subTy, proof? := eqProof : Result } + +@[inline, bv_toNat] def reduceMod (e : Expr) : SimpM Step := do + match_expr e with + | HMod.hMod xTy nTy outTy _inst x n => + let natTy := mkConst ``Nat + if (xTy != natTy) || (nTy != natTy) || (outTy != natTy) then + return .continue + if let .done res ← reduceModOfLt x n then + return .done res + if let .done res ← reduceModSub x n then + return .done res + return .continue + | _ => do + return .continue + +simproc↑ [address_normalization] reduce_mod_omega (_ % _) := fun e => reduceMod e + +/-- Canonicalize a commutative binary operation. + +1. If both arguments are constants, we perform constant folding. +2. If only one of the arguments is a constant, we move the constant to the left. +-/ +@[inline, bv_toNat] def canonicalizeBinConst (declName : Name) -- operator to constant fold, such as `HAdd.hAdd`. + (arity : Nat) + (commProofDecl : Name) -- commProof: `∀ (x y : Bitvec w), x op y = y op x`. + (reduceProofDecl : Name) -- reduce proof: `∀ (w : Nat), (n m : Nat) (BitVec.ofNat w n) op (BitVec.ofNat w m) = BitVec.ofNat w (n op' m)`. + (fxy : Expr) : SimpM Step := do + unless fxy.isAppOfArity declName arity do return .continue + let fx := fxy.appFn! + let x := fx.appArg! + let f := fx.appFn! + let y := fxy.appArg! + trace[Tactic.address_normalization] "{processingEmoji} canonicalizeBinConst '({f} {x} {y})'" + match getBitVecOfNatValue? x with + | some (xwExpr, xvalExpr) => + -- We have a constant on the left, check if we have a constant on the right + -- so we can fully constant fold the expression. + let .some (_, yvalExpr) := getBitVecOfNatValue? y + | return .continue + + let e' ← mkAppM reduceProofDecl #[xwExpr, xvalExpr, yvalExpr] + trace[Tactic.address_normalization] "{checkEmoji} canonicalizeBinConst '({f} {x} {y})'" + return .done { expr := e', proof? := ← mkAppM reduceProofDecl #[x, y] : Result } + + | none => + -- We don't have a constant on the left, check if we have a constant on the right + -- and try to move it to the left. + let .some _ := getBitVecOfNatValue? y + | return .continue -- no constants on either side, nothing to do. + + -- Nothing more to to do, except to move the right constant to the left. + let e' := mkAppN f #[y, x] + trace[Tactic.address_normalization] "{checkEmoji} canonicalizeBinConst '({f} {x} {y})'" + return .done { expr := e', proof? := ← mkAppM commProofDecl #[x, y] : Result } + +/- Change `100` to `100#64` so we can pattern match to `BitVec.ofNat` -/ +attribute [address_normalization] BitVec.ofNat_eq_ofNat + + +theorem BitVec.add_ofNat_eq_ofNat_add {w n} (x : BitVec w) : + x + BitVec.ofNat w n = BitVec.ofNat w n + x := by + apply BitVec.add_comm + + +theorem BitVec.mul_ofNat_eq_ofNat_mul {w n} (x : BitVec w) : + x * BitVec.ofNat w n = BitVec.ofNat w n * x := by + apply BitVec.mul_comm + +simproc [address_normalization] constFoldAdd ((_ + _ : BitVec _)) := + canonicalizeBinConst ``HAdd.hAdd 6 ``BitVec.add_comm ``BitVec.add_ofNat_eq_ofNat_add + +simproc [address_normalization] constFoldMul ((_ * _ : BitVec _)) := + canonicalizeBinConst ``HMul.hMul 6 ``BitVec.mul_comm ``BitVec.mul_ofNat_eq_ofNat_mul + +@[address_normalization] +theorem BitVec.ofNat_add_ofNat_eq_add_ofNat (w : Nat) (n m : Nat) : + BitVec.ofNat w n + BitVec.ofNat w m = BitVec.ofNat w (n + m) := by + apply BitVec.eq_of_toNat_eq + simp + +@[address_normalization] +theorem BitVec.ofNat_mul_ofNat_eq_mul_ofNat (w : Nat) (n m : Nat) : + BitVec.ofNat w n * BitVec.ofNat w m = BitVec.ofNat w (n * m) := by + apply BitVec.eq_of_toNat_eq + -- Note that omega cannot close the goal since it's symbolic multiplication. + simp only [toNat_mul, toNat_ofNat, ← Nat.mul_mod] + +/-- Reassociate addition to left. -/ +@[address_normalization] +theorem BitVec.add_assoc_symm {w} (x y z : BitVec w) : x + (y + z) = x + y + z := by + rw [BitVec.add_assoc] + +/-- Reassociate multiplication to left. -/ +@[address_normalization] +theorem BitVec.mul_assoc_symm {w} (x y z : BitVec w) : x * (y * z) = x * y * z := by + rw [BitVec.mul_assoc] diff --git a/Arm/Memory/Attr.lean b/Arm/Memory/Attr.lean index bb1b42dd..b95cd9c4 100644 --- a/Arm/Memory/Attr.lean +++ b/Arm/Memory/Attr.lean @@ -14,6 +14,12 @@ initialize Lean.registerTraceClass `simp_mem /-- Provides extremely verbose tracing for the `simp_mem` tactic. -/ initialize Lean.registerTraceClass `simp_mem.info +/-- Provides extremely verbose tracing for the `simp_mem` tactic. -/ +initialize Lean.registerTraceClass `Tactic.address_normalization + -- Rules for simprocs that mine the state to extract information for `omega` -- to run. register_simp_attr memory_omega + +-- Simprocs for address normalization +register_simp_attr address_normalization diff --git a/Arm/Memory/MemoryProofs.lean b/Arm/Memory/MemoryProofs.lean index c3b048ff..3d53b9d1 100644 --- a/Arm/Memory/MemoryProofs.lean +++ b/Arm/Memory/MemoryProofs.lean @@ -63,12 +63,10 @@ theorem read_mem_of_write_mem_bytes_different (hn1 : n <= 2^64) theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8)) (hn0 : Nat.succ 0 ≤ n) - (h : (n * 8 + (7 - 0 + 1)) = (n + 1) * 8) : - BitVec.cast h (zeroExtend (n * 8) (v >>> 8) ++ extractLsb 7 0 v) = v := by + (h : (n * 8 + 8) = (n + 1) * 8) : + BitVec.cast h (zeroExtend (n * 8) (v >>> 8) ++ extractLsb' 0 8 v) = v := by apply BitVec.append_of_extract · omega - · omega - · omega done @[state_simp_rules] @@ -85,7 +83,7 @@ theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n <= 2^64) : case base => simp only [read_mem_bytes, write_mem_bytes, read_mem_of_write_mem_same, BitVec.cast_eq] - have l1 := BitVec.extractLsb_eq v + have l1 := BitVec.extractLsb'_eq v simp only [Nat.reduceSucc, Nat.one_mul, Nat.succ_sub_succ_eq_sub, Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at l1 @@ -317,7 +315,7 @@ private theorem mem_subset_neq_first_addr_small_second_region cases h2 · rename_i h simp only [BitVec.add_sub_self_left_64] at h - have l1 : n' = 18446744073709551615 := by + have l1 : n' = 18446744073709551615 := by rw [BitVec.toNat_eq] at h simp only [toNat_ofNat, Nat.reducePow, Nat.reduceMod] at h omega @@ -431,7 +429,7 @@ private theorem write_mem_bytes_of_write_mem_bytes_shadow_general_n2_eq rename_i n n_ih conv in write_mem_bytes (Nat.succ n) .. => simp only [write_mem_bytes] have n_ih' := @n_ih (addr1 + 1#64) val2 (zeroExtend (n * 8) (val1 >>> 8)) - (write_mem addr1 (extractLsb 7 0 val1) s) + (write_mem addr1 (extractLsb' 0 8 val1) s) (by omega) simp only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero] at h3 by_cases h₁ : n = 0 @@ -483,7 +481,7 @@ theorem write_mem_bytes_of_write_mem_bytes_shadow_general theorem read_mem_of_write_mem_bytes_same_first_address (h0 : 0 < n) (h1 : n <= 2^64) (h : 7 - 0 + 1 = 8) : read_mem addr (write_mem_bytes n addr val s) = - BitVec.cast h (extractLsb 7 0 val) := by + BitVec.cast h (extractLsb' 0 8 val) := by unfold write_mem_bytes; simp only [Nat.sub_zero, BitVec.cast_eq] split · contradiction @@ -495,18 +493,16 @@ theorem read_mem_of_write_mem_bytes_same_first_address -- (FIXME) Argh, it's annoying to need this lemma, but using -- BitVec.cast_eq directly was cumbersome. theorem cast_of_extract_eq (v : BitVec p) - (h1 : hi1 = hi2) (h2 : lo1 = lo2) - (h : hi1 - lo1 + 1 = hi2 - lo2 + 1) : - BitVec.cast h (extractLsb hi1 lo1 v) = (extractLsb hi2 lo2 v) := by + (h1 : n1 = n2) (h2 : lo1 = lo2): + BitVec.cast h (extractLsb' lo1 n1 v) = (extractLsb' lo2 n2 v) := by subst_vars simp only [Nat.sub_zero, BitVec.cast_eq] theorem read_mem_bytes_of_write_mem_bytes_subset_same_first_address (h0 : 0 < n1) (h1 : n1 <= 2^64) (h2 : 0 < n2) (h3 : n2 <= 2^64) - (h4 : mem_subset addr (addr + (BitVec.ofNat 64 (n2 - 1))) addr (addr + (BitVec.ofNat 64 (n1 - 1)))) - (h : n2 * 8 - 1 - 0 + 1 = n2 * 8) : + (h4 : mem_subset addr (addr + (BitVec.ofNat 64 (n2 - 1))) addr (addr + (BitVec.ofNat 64 (n1 - 1)))): read_mem_bytes n2 addr (write_mem_bytes n1 addr val s) = - BitVec.cast h (extractLsb ((n2 * 8) - 1) 0 val) := by + extractLsb' 0 (n2 * 8) val := by have rm_lemma := @read_mem_of_write_mem_bytes_same_first_address n1 addr val s h0 h1 simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at rm_lemma induction n2, h2 using Nat.le_induction generalizing n1 addr val s @@ -543,21 +539,20 @@ theorem read_mem_bytes_of_write_mem_bytes_subset_same_first_address erw [Nat.mod_eq_of_lt h3] at hn erw [Nat.mod_eq_of_lt h1] at hn exact hn - rw [n_ih (by omega) (by omega) (by omega) _ (by omega)] - · rw [BitVec.extract_lsb_of_zeroExtend (v >>> 8)] - · have l1 := @BitVec.append_of_extract_general ((n1_1 + 1) * 8) 8 (n*8-1+1) (n*8) v + rw [n_ih (by omega) (by omega) (by omega) _] + · rw [BitVec.extractLsb'_of_zeroExtend (v >>> 8)] + · have l1 := @BitVec.append_of_extract_general ((n1_1 + 1) * 8) (n*8) 8 v simp (config := { decide := true }) only [Nat.zero_lt_succ, Nat.mul_pos_iff_of_pos_left, Nat.succ_sub_succ_eq_sub, Nat.sub_zero, Nat.reduceAdd, Nat.succ.injEq, forall_const] at l1 - rw [l1 (by omega) (by omega)] - · simp only [Nat.add_eq, Nat.sub_zero, BitVec.cast_cast] - apply @cast_of_extract_eq ((n1_1 + 1) * 8) (n * 8 - 1 + 1 + 7) ((n + 1) * 8 - 1) 0 0 <;> + rw [l1] + · apply @cast_of_extract_eq ((n1_1 + 1) * 8) (n * 8 + 8) ((n + 1) * 8) 0 0 <;> omega · omega · have rw_lemma2 := @read_mem_of_write_mem_bytes_same_first_address n1_1 (addr + 1#64) (zeroExtend (n1_1 * 8) (v >>> 8)) - (write_mem addr (extractLsb 7 0 v) s) + (write_mem addr (extractLsb' 0 8 v) s) simp only [Nat.reducePow, Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at rw_lemma2 rw [rw_lemma2 (by omega) (by simp only [Nat.reducePow] at h1; omega)] @@ -636,22 +631,15 @@ theorem BitVec.to_nat_zero_lt_sub_64 (x y : BitVec 64) (h : ¬x = y) : theorem read_mem_of_write_mem_bytes_subset (h0 : 0 < n) (h1 : n <= 2^64) - (h2 : mem_subset addr2 addr2 addr1 (addr1 + (BitVec.ofNat 64 (n - 1)))) - (h : ((BitVec.toNat (addr2 - addr1) + 1) * 8 - 1 - - BitVec.toNat (addr2 - addr1) * 8 + 1) = 8) : + (h2 : mem_subset addr2 addr2 addr1 (addr1 + (BitVec.ofNat 64 (n - 1)))): read_mem addr2 (write_mem_bytes n addr1 val s) = - BitVec.cast h - (extractLsb - ((BitVec.toNat (addr2 - addr1) + 1) * 8 - 1) - (BitVec.toNat (addr2 - addr1) * 8) - val) := by + extractLsb' (BitVec.toNat (addr2 - addr1) * 8) 8 val := by induction n generalizing addr1 addr2 s case zero => contradiction case succ => rename_i n' n_ih simp_all only [write_mem_bytes, Nat.succ.injEq, Nat.zero_lt_succ, Nat.succ_sub_succ_eq_sub, Nat.sub_zero] - have cast_lemma := @cast_of_extract_eq by_cases h₀ : n' = 0 case pos => simp_all only [Nat.lt_irrefl, Nat.zero_le, Nat.zero_sub, @@ -659,20 +647,22 @@ theorem read_mem_of_write_mem_bytes_subset false_implies, implies_true] subst_vars simp only [write_mem_bytes, read_mem_of_write_mem_same] - rw [←cast_lemma] <;> bv_omega + simp only [Nat.reduceAdd, Nat.reduceMul, BitVec.sub_self, + toNat_ofNat, Nat.reducePow, Nat.zero_mod, Nat.zero_mul] case neg => -- (n' ≠ 0) by_cases h₁ : addr2 = addr1 case pos => -- (n' ≠ 0) and (addr2 = addr1) subst_vars rw [read_mem_of_write_mem_bytes_different (by omega)] · simp only [read_mem_of_write_mem_same] - rw [←cast_lemma] <;> bv_omega + simp only [BitVec.sub_self, toNat_ofNat, Nat.reducePow, + Nat.zero_mod, Nat.zero_mul] · rw [mem_separate_contiguous_regions_one_address _ (by omega)] case neg => -- (addr2 ≠ addr1) rw [n_ih] · ext -- simp only [bv_toNat] - simp only [toNat_cast, extractLsb, extractLsb', toNat_zeroExtend] + simp only [toNat_cast, extractLsb', toNat_zeroExtend] simp only [toNat_ushiftRight] simp_all only [toNat_ofNat, toNat_ofNatLt] simp only [BitVec.sub_of_add_is_sub_sub, Nat.succ_sub_succ_eq_sub, @@ -700,7 +690,6 @@ theorem read_mem_of_write_mem_bytes_subset · omega · rw [addr_add_one_add_m_sub_one _ _ (by omega) (by omega)] rw [mem_subset_one_addr_neq h₁ h2] - · omega done theorem read_mem_bytes_of_write_mem_bytes_subset_helper1 (a i : Nat) @@ -717,7 +706,7 @@ theorem read_mem_bytes_of_write_mem_bytes_subset_helper2 (BitVec.toNat val >>> ((BitVec.toNat (addr2 - addr1) + 1) % 2 ^ 64 * 8) % 2 ^ (n * 8)) <<< 8 ||| BitVec.toNat val >>> (BitVec.toNat (addr2 - addr1) * 8) % 2 ^ 8 = BitVec.toNat val >>> (BitVec.toNat (addr2 - addr1) * 8) % - 2 ^ ((BitVec.toNat (addr2 - addr1) + (n + 1)) * 8 - 1 - BitVec.toNat (addr2 - addr1) * 8 + 1) := by + 2 ^ ((n + 1) * 8) := by have h_a_size := (addr2 - addr1).isLt have h_v_size := val.isLt -- (FIXME) whnf timeout? @@ -725,7 +714,7 @@ theorem read_mem_bytes_of_write_mem_bytes_subset_helper2 -- generalize ha : BitVec.toNat (addr2 - addr1) = a apply Nat.eq_of_testBit_eq; intro i simp only [Nat.testBit_mod_two_pow, Nat.testBit_shiftRight] - by_cases h₀ : (i < (BitVec.toNat (addr2 - addr1) + (n + 1)) * 8 - 1 - BitVec.toNat (addr2 - addr1) * 8 + 1) + by_cases h₀ : (i < ((n + 1) * 8)) case pos => simp only [h₀, decide_True, Bool.true_and, BitVec.toNat_ofNat, BitVec.toNat_append, Nat.testBit_or] @@ -785,35 +774,26 @@ private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_lt (h0 : 0 < n1) (h1 : n1 <= 2^64) (h2 : 0 < n2) (h3 : n2 < 2^64) (h4 : mem_subset addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1))) addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) (h5 : mem_legal addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1)))) - (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) - (h : ((BitVec.toNat (addr2 - addr1) + n2) * 8 - 1 - BitVec.toNat (addr2 - addr1) * 8 + 1) - = n2 * 8) : + (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) : read_mem_bytes n2 addr2 (write_mem_bytes n1 addr1 val s) = - BitVec.cast h - (extractLsb ((((addr2 - addr1).toNat + n2) * 8) - 1) ((addr2 - addr1).toNat * 8) val) := by + extractLsb' ((addr2 - addr1).toNat * 8) (n2 * 8) val := by induction n2, h2 using Nat.le_induction generalizing addr1 addr2 val s case base => simp only [Nat.reduceSucc, Nat.succ_sub_succ_eq_sub, Nat.sub_self, BitVec.add_zero] at h4 - simp_all only [read_mem_bytes, BitVec.cast_eq] - have h' : (BitVec.toNat (addr2 - addr1) + 1) * 8 - 1 - BitVec.toNat (addr2 - addr1) * 8 + 1 = 8 := by - omega - rw [read_mem_of_write_mem_bytes_subset h0 h1 h4 h'] + simp_all only [read_mem_bytes] + rw [read_mem_of_write_mem_bytes_subset h0 h1 h4] apply BitVec.empty_bitvector_append_left - decide case succ => rename_i n h2' n_ih by_cases h_addr : addr1 = addr2 case pos => -- (addr1 = addr2) subst addr2 - have h' : (n + 1) * 8 - 1 - 0 + 1 = (n + 1) * 8 := by omega have := @read_mem_bytes_of_write_mem_bytes_subset_same_first_address n1 (n + 1) addr1 val s - h0 h1 (by omega) (by omega) h4 h' + h0 h1 (by omega) (by omega) h4 rw [this] - ext - simp only [Nat.sub_zero, BitVec.cast_eq, extractLsb_toNat, - Nat.shiftRight_zero, toNat_cast, BitVec.sub_self, - toNat_ofNat, Nat.zero_mod, Nat.zero_mul, Nat.zero_add] + simp only [BitVec.sub_self, toNat_ofNat, Nat.reducePow, + Nat.zero_mod, Nat.zero_mul] case neg => -- (addr1 ≠ addr2) simp only [read_mem_bytes, Nat.add_eq, Nat.add_zero] simp only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero] at h4 @@ -832,19 +812,15 @@ private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_lt have l2 := @first_address_is_subset_of_region addr2 (BitVec.ofNat 64 n) have l3 := mem_subset_trans l2 h4 simp only [l3, forall_const] at l1 - rw [l1 (by omega)] + rw [l1] simp only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero] at h5 have n_ih' := @n_ih (addr2 + 1#64) addr1 val s (by omega) simp only [h_sub, forall_const] at n_ih' rw [mem_legal_lemma h2'] at n_ih' - · simp only [forall_const] at n_ih' - have h' : (BitVec.toNat (addr2 + 1#64 - addr1) + n) * 8 - 1 - - BitVec.toNat (addr2 + 1#64 - addr1) * 8 + 1 = - n * 8 := by - omega - rw [n_ih' h6 h'] + · simp only [h6, true_implies] at n_ih' + rw [n_ih'] ext - simp only [extractLsb, extractLsb', toNat_ofNat, toNat_cast, + simp only [extractLsb', toNat_ofNat, toNat_cast, BitVec.add_of_sub_sub_of_add] simp only [toNat_add (addr2 - addr1) 1#64, Nat.add_eq, Nat.add_zero, toNat_ofNat, Nat.add_mod_mod, cast_ofNat, toNat_append] @@ -886,45 +862,21 @@ theorem entire_memory_subset_legal_regions_eq_addr simp_all [mem_subset, mem_legal] bv_omega -private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt_helper (val : BitVec (x * 8)) - (h0 : 0 < x) - (h : (BitVec.toNat (addr2 - addr2) + x) * 8 - 1 - - BitVec.toNat (addr2 - addr2) * 8 + 1 - = - x * 8) : - val = - BitVec.cast h - (extractLsb ((BitVec.toNat (addr2 - addr2) + x) * 8 - 1) - (BitVec.toNat (addr2 - addr2) * 8) val) := by - ext - simp only [extractLsb, extractLsb', BitVec.sub_self, toNat_ofNat, - Nat.zero_mod, Nat.zero_mul, Nat.shiftRight_zero, - ofNat_toNat, toNat_cast, toNat_truncate, Nat.zero_add, - Nat.sub_zero] - rw [Nat.mod_eq_of_lt] - rw [Nat.sub_add_cancel] - · exact val.isLt - · omega - done - private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt (h0 : 0 < n1) (h1 : n1 <= my_pow 2 64) (h2 : 0 < n2) (h3 : n2 = my_pow 2 64) (h4 : mem_subset addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1))) addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) (h5 : mem_legal addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1)))) - (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) - (h : ((BitVec.toNat (addr2 - addr1) + n2) * 8 - 1 - BitVec.toNat (addr2 - addr1) * 8 + 1) - = n2 * 8) : + (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))): read_mem_bytes n2 addr2 (write_mem_bytes n1 addr1 val s) = - BitVec.cast h - (extractLsb ((((addr2 - addr1).toNat + n2) * 8) - 1) ((addr2 - addr1).toNat * 8) val) := by + extractLsb' ((addr2 - addr1).toNat * 8) (n2 * 8) val := by subst n2 have l0 := @entire_memory_subset_of_only_itself n1 addr2 addr1 h1 h4 subst n1 have l1 := @entire_memory_subset_legal_regions_eq_addr addr2 addr1 h4 h6 h5 subst addr1 rw [read_mem_bytes_of_write_mem_bytes_same] - · apply read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt_helper - simp [my_pow_2_gt_zero] + · simp only [BitVec.sub_self, toNat_ofNat, Nat.reducePow, Nat.zero_mod, Nat.zero_mul] + exact Eq.symm (extractLsb'_eq val) · unfold my_pow; decide @[state_simp_rules] @@ -932,19 +884,16 @@ theorem read_mem_bytes_of_write_mem_bytes_subset (h0 : 0 < n1) (h1 : n1 <= 2^64) (h2 : 0 < n2) (h3 : n2 <= 2^64) (h4 : mem_subset addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1))) addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) (h5 : mem_legal addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1)))) - (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) - (h : ((BitVec.toNat (addr2 - addr1) + n2) * 8 - 1 - - BitVec.toNat (addr2 - addr1) * 8 + 1) - = n2 * 8) : + (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))): read_mem_bytes n2 addr2 (write_mem_bytes n1 addr1 val s) = - BitVec.cast h - (extractLsb - ((((addr2 - addr1).toNat + n2) * 8) - 1) + (extractLsb' ((addr2 - addr1).toNat * 8) + (n2 * 8) val) := by by_cases h₀ : n2 = 2^64 case pos => - apply read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt h0 + apply read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt + · exact h0 · unfold my_pow; exact h1 · exact h2 · unfold my_pow; exact h₀ @@ -988,8 +937,8 @@ private theorem write_mem_bytes_irrelevant_helper (h : n * 8 + 8 = (n + 1) * 8) done private theorem extract_byte_of_read_mem_bytes_succ (n : Nat) : - extractLsb 7 0 (read_mem_bytes (n + 1) addr s) = read_mem addr s := by - simp only [read_mem_bytes, Nat.add_eq, Nat.add_zero, toNat_eq, extractLsb_toNat, + extractLsb' 0 8 (read_mem_bytes (n + 1) addr s) = read_mem addr s := by + simp only [read_mem_bytes, Nat.add_eq, Nat.add_zero, toNat_eq, extractLsb'_toNat, toNat_cast, toNat_append, Nat.shiftRight_zero, Nat.reduceAdd] generalize read_mem addr s = y generalize (read_mem_bytes n (addr + 1#64) s) = x diff --git a/Arm/Memory/SeparateAutomation.lean b/Arm/Memory/SeparateAutomation.lean index ffa4e274..29ac674c 100644 --- a/Arm/Memory/SeparateAutomation.lean +++ b/Arm/Memory/SeparateAutomation.lean @@ -13,6 +13,7 @@ import Arm import Arm.Memory.MemoryProofs import Arm.BitVec import Arm.Memory.Attr +import Arm.Memory.AddressNormalization import Lean import Lean.Meta.Tactic.Rewrite import Lean.Meta.Tactic.Rewrites @@ -404,7 +405,7 @@ def simpAndIntroDef (name : String) (hdefVal : Expr) : SimpMemM FVarId := do /- Simp to gain some more juice out of the defn.. -/ let mut simpTheorems : Array SimpTheorems := #[] - for a in #[`minimal_theory, `bitvec_rules] do + for a in #[`minimal_theory, `bitvec_rules, `bv_toNat] do let some ext ← (getSimpExtension? a) | throwError m!"[simp_mem] Internal error: simp attribute {a} not found!" simpTheorems := simpTheorems.push (← ext.getTheorems) diff --git a/Arm/State.lean b/Arm/State.lean index 73bcb17d..6b559099 100644 --- a/Arm/State.lean +++ b/Arm/State.lean @@ -342,6 +342,8 @@ These mnemonics make it much easier to read and write theorems about assembly pr -/ @[state_simp_rules] abbrev ArmState.x0 (s : ArmState) : BitVec 64 := r (StateField.GPR 0) s +@[state_simp_rules] abbrev ArmState.w0 (s : ArmState) : BitVec 32 := + (r (StateField.GPR 0) s).zeroExtend 32 @[state_simp_rules] abbrev ArmState.x1 (s : ArmState) : BitVec 64 := r (StateField.GPR 1) s @@ -675,7 +677,7 @@ def write_mem_bytes (n : Nat) (addr : BitVec 64) (val : BitVec (n * 8)) (s : Arm match n with | 0 => s | n' + 1 => - let byte := BitVec.extractLsb 7 0 val + let byte := BitVec.extractLsb' 0 8 val let s := write_mem addr byte s let val_rest := BitVec.zeroExtend (n' * 8) (val >>> 8) write_mem_bytes n' (addr + 1#64) val_rest s @@ -962,7 +964,7 @@ def write_bytes (n : Nat) (addr : BitVec 64) match n with | 0 => m | n' + 1 => - let byte := BitVec.extractLsb 7 0 val + let byte := BitVec.extractLsb' 0 8 val let m := m.write addr byte let val_rest := BitVec.zeroExtend (n' * 8) (val >>> 8) m.write_bytes n' (addr + 1#64) val_rest @@ -988,7 +990,7 @@ and then recursing to write the rest. -/ theorem write_bytes_succ {mem : Memory} : mem.write_bytes (n + 1) addr val = - let byte := BitVec.extractLsb 7 0 val + let byte := BitVec.extractLsb' 0 8 val let mem := mem.write addr byte let val_rest := BitVec.zeroExtend (n * 8) (val >>> 8) mem.write_bytes n (addr + 1#64) val_rest := rfl diff --git a/Arm/Syntax.lean b/Arm/Syntax.lean index 88127b36..0c1ca562 100644 --- a/Arm/Syntax.lean +++ b/Arm/Syntax.lean @@ -6,6 +6,7 @@ Author(s): Siddharth Bhat Provide convenient syntax for writing down state manipulation in Arm programs. -/ import Arm.State +import Arm.Memory.Separate namespace ArmStateNotation @@ -13,4 +14,28 @@ namespace ArmStateNotation @[inherit_doc read_mem_bytes] syntax:max term noWs "[" withoutPosition(term) "," withoutPosition(term) noWs "]" : term macro_rules | `($s[$base,$n]) => `(read_mem_bytes $n $base $s) + + +/-! Notation to specify the frame condition for non-memory state components. E.g., +`REGS_UNCHANGED_EXCEPT [.GPR 0, .PC] (sf, s0)` is sugar for +`∀ f, f ∉ [.GPR 0, .PC] → r f sf = r f s0` +-/ +syntax:max "REGS_UNCHANGED_EXCEPT" "[" term,* "]" + "(" withoutPosition(term) "," withoutPosition(term) ")" : term +macro_rules +| `(REGS_UNCHANGED_EXCEPT [$regs:term,*] ($sf, $s0)) => + `(∀ f, f ∉ [$regs,*] → r f $sf = r f $s0) + +/-! Notation to specify the frame condition for memory regions. E.g., +`MEM_UNCHANGED_EXCEPT [(x, m), (y, k)] (sf, s0)` is sugar for +`∀ n addr, Memory.Region.pairwiseSeparate [(addr, n), (x, m), (y, k)] → sf[addr, n] = s0[addr, n]` +-/ +syntax:max "MEM_UNCHANGED_EXCEPT" "[" term,* "]" + "(" withoutPosition(term) "," withoutPosition(term) ")" : term +macro_rules | + `(MEM_UNCHANGED_EXCEPT [$mem:term,*] ($sf, $s0)) => + `(∀ (n : Nat) (addr : BitVec 64), + Memory.Region.pairwiseSeparate (List.cons (addr, n) [$mem,*]) → + read_mem_bytes n addr $sf = read_mem_bytes n addr $s0) + end ArmStateNotation diff --git a/Benchmarks.lean b/Benchmarks.lean index 7b7fcc0a..cc1f1c6b 100644 --- a/Benchmarks.lean +++ b/Benchmarks.lean @@ -3,6 +3,11 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. Author(s): Alex Keizer -/ +import Benchmarks.SHA512_75 +import Benchmarks.SHA512_75_noKernel_noLint import Benchmarks.SHA512_150 +import Benchmarks.SHA512_150_noKernel_noLint import Benchmarks.SHA512_225 +import Benchmarks.SHA512_225_noKernel_noLint import Benchmarks.SHA512_400 +import Benchmarks.SHA512_400_noKernel_noLint diff --git a/Benchmarks/Command.lean b/Benchmarks/Command.lean index 60873189..3762f788 100644 --- a/Benchmarks/Command.lean +++ b/Benchmarks/Command.lean @@ -7,48 +7,247 @@ import Lean open Lean Parser.Command Elab.Command +initialize + registerOption `benchmark { + defValue := false + descr := "enables/disables benchmarking in `withBenchmark` combinator" + } + registerOption `benchmark.runs { + defValue := (5 : Nat) + descr := "controls how many runs the `benchmark` command does. \ + NOTE: this value is ignored when the `profiler` option is set to true" + } + /- Shouldn't be set directly, instead, use the `benchmark` command -/ + registerTraceClass `benchmark + +variable {m} [Monad m] [MonadLiftT BaseIO m] in +/-- Measure the heartbeats and time (in milliseconds) taken by `x` -/ +def withHeartbeatsAndMilliseconds (x : m α) : m (α × Nat × Nat) := do + let start ← IO.monoMsNow + let (a, heartbeats) ← withHeartbeats x + let endTime ← IO.monoMsNow + return ⟨a, heartbeats, endTime - start⟩ + +/-- Adds a trace node with the `benchmark` class, but only if the profiler +option is *not* set. + +We deliberately suppress benchmarking nodes when profiling, since it generally +only adds noise +-/ +def withBenchTraceNode (msg : MessageData) (x : CommandElabM α ) + : CommandElabM α := do + if (← getBoolOption `profiler) then + x + else + withTraceNode `benchmark (fun _ => pure msg) x (collapsed := false) + +/-- +Run a benchmark for a set number of times, and report the average runtime. + +If the `profiler` option is set true, we run the benchmark only once, with: +- `trace.profiler` to true, and +- `trace.profiler.output` set based on the `benchmark.profilerDir` and the + id of the benchmark +-/ elab "benchmark" id:ident declSig:optDeclSig val:declVal : command => do - let stx ← `(command| example $declSig:optDeclSig $val:declVal ) - - let n := 5 - let mut runTimes := #[] - let mut totalRunTime := 0 - -- geomean = exp(log((a₁ a₂ ... aₙ)^1/n)) = - -- exp(1/n * (log a₁ + log a₂ + log aₙ)). - let mut totalRunTimeLog := 0 - for _ in [0:n] do - let start ← IO.monoMsNow - elabCommand stx - let endTime ← IO.monoMsNow - let runTime := endTime - start - runTimes := runTimes.push runTime - totalRunTime := totalRunTime + runTime - totalRunTimeLog := totalRunTimeLog + Float.log runTime.toFloat - - let avg := totalRunTime.toFloat / n.toFloat / 1000 - let geomean := (Float.exp (totalRunTimeLog / n.toFloat)) / 1000.0 - logInfo m!"\ -{id}: - average runtime over {n} runs: - {avg}s - geomean over {n} runs: - {geomean}s - - indidividual runtimes: - {runTimes} -" + let originalOpts ← getOptions + let mut n := originalOpts.getNat `benchmark.runs 5 + let mut opts := originalOpts + opts := opts.setBool `benchmark true + let stx ← `(command| + example $declSig:optDeclSig $val:declVal + ) + + if (← getBoolOption `profiler) then + opts := opts.setBool `trace.profiler true + opts := opts.setNat `trace.profiler.threshold 1 + n := 1 -- only run once, if `profiler` is set to true + else + opts := opts.setBool `trace.benchmark true + + if n = 0 then + return () + + -- Set options + modifyScope fun scope => { scope with opts } + + withBenchTraceNode m!"Running {id} benchmark" <| do + let mut totalRunTime := 0 + -- geomean = exp(log((a₁ a₂ ... aₙ)^1/n)) = + -- exp(1/n * (log a₁ + log a₂ + log aₙ)). + let mut totalRunTimeLog : Float := 0 + for i in [0:n] do + let runTime ← withBenchTraceNode m!"Run {i+1} (out of {n}):" <| do + let ((), _, runTime) ← withHeartbeatsAndMilliseconds <| + elabCommand stx + + trace[benchmark] m!"Proof took {runTime / 1000}s in total" + pure runTime + totalRunTime := totalRunTime + runTime + totalRunTimeLog := totalRunTimeLog + Float.log runTime.toFloat + + let avg := totalRunTime.toFloat / n.toFloat / 1000 + let geomean := (Float.exp (totalRunTimeLog / n.toFloat)) / 1000.0 + trace[benchmark] m!"\ + {id}: + average runtime over {n} runs: + {avg}s + geomean over {n} runs: + {geomean}s + " + -- Restore options + modifyScope fun scope => { scope with opts := originalOpts } + +/-- Set various options to disable linters -/ +macro "disable_linters" "in" cmd:command : command => `(command| +set_option linter.constructorNameAsVariable false in +set_option linter.deprecated false in +set_option linter.missingDocs false in +set_option linter.omit false in +set_option linter.suspiciousUnexpanderPatterns false in +set_option linter.unnecessarySimpa false in +set_option linter.unusedRCasesPattern false in +set_option linter.unusedSectionVars false in +set_option linter.unusedVariables false in +$cmd +) /-- The default `maxHeartbeats` setting. NOTE: even if the actual default value changes at some point in the future, this value should *NOT* be updated, to ensure the percentages we've reported in previous versions remain comparable. -/ -def defaultMaxHeartbeats : Nat := 200000 +private def defaultMaxHeartbeats : Nat := 200000 + +private def percentOfDefaultMaxHeartbeats (heartbeats : Nat) : Nat := + heartbeats / (defaultMaxHeartbeats * 10) open Elab.Tactic in elab "logHeartbeats" tac:tactic : tactic => do let ((), heartbeats) ← withHeartbeats <| evalTactic tac - let percent := heartbeats / (defaultMaxHeartbeats * 10) + let percent := percentOfDefaultMaxHeartbeats heartbeats logInfo m!"used {heartbeats / 1000} heartbeats ({percent}% of the default maximum)" + +section withBenchmark +variable {m} [Monad m] [MonadLiftT BaseIO m] [MonadOptions m] [MonadLog m] + [AddMessageContext m] + +/-- if the `benchmark` option is true, execute `x` and call `f` with the amount +of heartbeats and milliseconds (in that order!) taken by `x`. + +Otherwise, just execute `x` without measurements. -/ +private def withBenchmarkAux (x : m α) (f : Nat → Nat → m Unit) : m α := do + if (← getBoolOption `benchmark) = false then + x + else + let (a, heartbeats, t) ← withHeartbeatsAndMilliseconds x + f heartbeats t + return a + + +/-- `withBenchmark header x` is a combinator that will, if the `benchmark` +option is set to `true`, log the time and heartbeats used by `x`, +in a message like: + `{header} took {x}s and {y} heartbeats ({z}% of the default maximum)` + +Otherwise, if `benchmark` is set to false, `x` is returned as-is. + +NOTE: the maximum reffered to in the message is `defaultMaxHeartbeats`, +deliberately *not* the currently confiugred `maxHeartbeats` option, to keep the +numbers comparable across different values of that option. It's thus entirely +possible to see more than 100% being reported here. -/ +def withBenchmark (header : String) (x : m α) : m α := do + withBenchmarkAux x fun heartbeats t => do + let percent := percentOfDefaultMaxHeartbeats heartbeats + logInfo m!"{header} took: {t}ms and {heartbeats} heartbeats \ + ({percent}% of the default maximum)" + +/-- Benchmark the time and heartbeats taken by a tactic, if the `benchmark` +option is set to `true` -/ +elab "with_benchmark" t:tactic : tactic => do + withBenchmark "{t}" <| Elab.Tactic.evalTactic t + +end withBenchmark + +/-! +## Aggregated benchmark statistics +We define `withAggregatedBenchmark`, which functions like `withBenchmark`, +except it will store a running average of the statistics in a `BenchmarkState` +which will be reported in one go when `reportAggregatedBenchmarks` is called. +-/ +section + +structure BenchmarkState.Stats where + totalHeartbeats : Nat := 0 + totalTimeInMs : Nat := 0 + samples : Nat := 0 + +structure BenchmarkState where + insertionOrder : List String := [] + stats : Std.HashMap String BenchmarkState.Stats := .empty + +variable {m} [Monad m] [MonadStateOf BenchmarkState m] [MonadLiftT BaseIO m] + [MonadOptions m] + +/-- `withAggregatedBenchmark header x` is a combinator that will, +if the `benchmark` option is set to `true`, +measure the time and heartbeats to the benchmark state in a way that aggregates +different measurements with the same `header`. + +See `reportAggregatedBenchmarks` to log the collected data. + +Otherwise, if `benchmark` is set to false, `x` is returned as-is. +-/ +def withAggregatedBenchmark (header : String) (x : m α) : m α := do + withBenchmarkAux x fun heartbeats t => do + modify fun state => + let s := state.stats.getD header {} + { insertionOrder := + if s.samples = 0 then + header :: state.insertionOrder + else + state.insertionOrder + stats := state.stats.insert header { + totalHeartbeats := s.totalHeartbeats + heartbeats + totalTimeInMs := s.totalTimeInMs + t + samples := s.samples + 1 + }} + +variable [MonadLog m] [AddMessageContext m] in +/-- +if the `benchmark` option is set to `true`, report the data collected by +`withAggregatedBenchmark`, and reset the state (so that the next call to +`reportAggregatedBenchmarks` will report only new data). +-/ +def reportAggregatedBenchmarks : m Unit := do + if (← getBoolOption `benchmark) = false then + return + + let { insertionOrder, stats } ← get + for header in insertionOrder do + let stats := stats.getD header {} + let heartbeats := stats.totalHeartbeats + let percent := percentOfDefaultMaxHeartbeats heartbeats + let t := stats.totalTimeInMs + let n := stats.samples + logInfo m!"{header} took: {t}ms and {heartbeats} heartbeats \ + ({percent}% of the default maximum) in total over {n} samples" + + set ({} : BenchmarkState) + +abbrev BenchT := StateT BenchmarkState + +variable [MonadLog m] [AddMessageContext m] in +/-- +Execute `x` with the default `BenchmarkState`, and report the benchmarks after +(see `reportAggregatedBenchmarks`). +-/ +def withBenchmarksReport (x : BenchT m α) : m α := + (Prod.fst <$> ·) <| StateT.run (s := {}) do + let a ← x + reportAggregatedBenchmarks + return a + +end diff --git a/Benchmarks/SHA512.lean b/Benchmarks/SHA512.lean index 1ef411bb..685a18b7 100644 --- a/Benchmarks/SHA512.lean +++ b/Benchmarks/SHA512.lean @@ -15,9 +15,15 @@ namespace Benchmarks def SHA512Bench (nSteps : Nat) : Prop := ∀ (s0 sf : ArmState) - (_h_s0_pc : read_pc s0 = 0x1264c4#64) + (_h_s0_num_blocks : r (.GPR 2#5) s0 = 10#64) + (_h_s0_pc : read_pc s0 = 0x1264c0#64) (_h_s0_err : read_err s0 = StateError.None) (_h_s0_sp_aligned : CheckSPAlignment s0) (_h_s0_program : s0.program = SHA512.program) (_h_run : sf = run nSteps s0), r StateField.ERR sf = StateError.None + ∧ r (.GPR 2#5) sf = BitVec.ofNat 64 (10 - ((nSteps + 467) / 485)) + -- / -------------------------------^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + -- | This computes the expected value of x2, taking into account that + -- | the loop body is 485 instructions long, and that x2 is first + -- | decremented after 18 instructions (485 - 18 = 467). diff --git a/Benchmarks/SHA512_150.lean b/Benchmarks/SHA512_150.lean index b54d8f83..0cf74cff 100644 --- a/Benchmarks/SHA512_150.lean +++ b/Benchmarks/SHA512_150.lean @@ -9,7 +9,9 @@ import Benchmarks.SHA512 open Benchmarks -benchmark sha512_150_instructions : SHA512Bench 150 := fun s0 => by +benchmark sha512_150_instructions : SHA512Bench 150 := fun s0 _ h => by intros sym_n 150 + simp (config := {failIfUnchanged := false}) only [h, bitvec_rules] + all_goals exact (sorry : Aligned ..) done diff --git a/Benchmarks/SHA512_150_noKernel_noLint.lean b/Benchmarks/SHA512_150_noKernel_noLint.lean new file mode 100644 index 00000000..faec9e9f --- /dev/null +++ b/Benchmarks/SHA512_150_noKernel_noLint.lean @@ -0,0 +1,19 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer +-/ +import Tactics.Sym +import Benchmarks.Command +import Benchmarks.SHA512 + +open Benchmarks + +disable_linters in +set_option debug.skipKernelTC true in +benchmark sha512_150_noKernel_noLint : SHA512Bench 150 := fun s0 _ h => by + intros + sym_n 150 + simp (config := {failIfUnchanged := false}) only [h, bitvec_rules] + all_goals exact (sorry : Aligned ..) + done diff --git a/Benchmarks/SHA512_225.lean b/Benchmarks/SHA512_225.lean index 26310030..a43f5dc5 100644 --- a/Benchmarks/SHA512_225.lean +++ b/Benchmarks/SHA512_225.lean @@ -9,7 +9,8 @@ import Benchmarks.SHA512 open Benchmarks -benchmark sha512_225_instructions : SHA512Bench 225 := fun s0 => by +benchmark sha512_225_instructions : SHA512Bench 225 := fun s0 _ h => by intros sym_n 225 + · exact (sorry : Aligned ..) done diff --git a/Benchmarks/SHA512_225_noKernel_noLint.lean b/Benchmarks/SHA512_225_noKernel_noLint.lean new file mode 100644 index 00000000..df2b7a88 --- /dev/null +++ b/Benchmarks/SHA512_225_noKernel_noLint.lean @@ -0,0 +1,19 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer +-/ +import Tactics.Sym +import Benchmarks.Command +import Benchmarks.SHA512 + +open Benchmarks + +disable_linters in +set_option debug.skipKernelTC true in +benchmark sha512_225_noKernel_noLint : SHA512Bench 225 := fun s0 _ h => by + intros + sym_n 225 + simp (config := {failIfUnchanged := false}) only [h, bitvec_rules] + all_goals exact (sorry : Aligned ..) + done diff --git a/Benchmarks/SHA512_400.lean b/Benchmarks/SHA512_400.lean index bd725f6f..ae26c4e9 100644 --- a/Benchmarks/SHA512_400.lean +++ b/Benchmarks/SHA512_400.lean @@ -9,7 +9,9 @@ import Benchmarks.Command open Benchmarks -benchmark sha512_400_instructions : SHA512Bench 400 := fun s0 => by +benchmark sha512_400_instructions : SHA512Bench 400 := fun s0 _ h => by intros sym_n 400 + simp (config := {failIfUnchanged := false}) only [h, bitvec_rules] + all_goals exact (sorry : Aligned ..) done diff --git a/Benchmarks/SHA512_400_noKernel_noLint.lean b/Benchmarks/SHA512_400_noKernel_noLint.lean new file mode 100644 index 00000000..cefce28c --- /dev/null +++ b/Benchmarks/SHA512_400_noKernel_noLint.lean @@ -0,0 +1,19 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer +-/ +import Tactics.Sym +import Benchmarks.Command +import Benchmarks.SHA512 + +open Benchmarks + +disable_linters in +set_option debug.skipKernelTC true in +benchmark sha512_400_noKernel_noLint : SHA512Bench 400 := fun s0 _ h => by + intros + sym_n 400 + simp (config := {failIfUnchanged := false}) only [h, bitvec_rules] + all_goals exact (sorry : Aligned ..) + done diff --git a/Benchmarks/SHA512_50.lean b/Benchmarks/SHA512_50.lean new file mode 100644 index 00000000..7e388869 --- /dev/null +++ b/Benchmarks/SHA512_50.lean @@ -0,0 +1,17 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer +-/ +import Tactics.Sym +import Benchmarks.Command +import Benchmarks.SHA512 + +open Benchmarks + +benchmark sha512_50 : SHA512Bench 50 := fun s0 _ h => by + intros + sym_n 50 + simp (config := {failIfUnchanged := false}) only [h, bitvec_rules] + all_goals exact (sorry : Aligned ..) + done diff --git a/Benchmarks/SHA512_50_noKernel_noLint.lean b/Benchmarks/SHA512_50_noKernel_noLint.lean new file mode 100644 index 00000000..f08a6868 --- /dev/null +++ b/Benchmarks/SHA512_50_noKernel_noLint.lean @@ -0,0 +1,19 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer +-/ +import Tactics.Sym +import Benchmarks.Command +import Benchmarks.SHA512 + +open Benchmarks + +disable_linters in +set_option debug.skipKernelTC true in +benchmark sha512_50_noKernel_noLint : SHA512Bench 50 := fun s0 _ h => by + intros + sym_n 50 + simp (config := {failIfUnchanged := false}) only [h, bitvec_rules] + all_goals exact (sorry : Aligned ..) + done diff --git a/Benchmarks/SHA512_75.lean b/Benchmarks/SHA512_75.lean new file mode 100644 index 00000000..068b06b3 --- /dev/null +++ b/Benchmarks/SHA512_75.lean @@ -0,0 +1,17 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer +-/ +import Tactics.Sym +import Benchmarks.Command +import Benchmarks.SHA512 + +open Benchmarks + +benchmark sha512_75 : SHA512Bench 75 := fun s0 _ h => by + intros + sym_n 75 + simp (config := {failIfUnchanged := false}) only [h, bitvec_rules] + all_goals exact (sorry : Aligned ..) + done diff --git a/Benchmarks/SHA512_75_noKernel_noLint.lean b/Benchmarks/SHA512_75_noKernel_noLint.lean new file mode 100644 index 00000000..516880c8 --- /dev/null +++ b/Benchmarks/SHA512_75_noKernel_noLint.lean @@ -0,0 +1,19 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer +-/ +import Tactics.Sym +import Benchmarks.Command +import Benchmarks.SHA512 + +open Benchmarks + +disable_linters in +set_option debug.skipKernelTC true in +benchmark sha512_75_noKernel_noLint : SHA512Bench 75 := fun s0 _ h => by + intros + sym_n 75 + simp (config := {failIfUnchanged := false}) only [h, bitvec_rules] + all_goals exact (sorry : Aligned ..) + done diff --git a/Makefile b/Makefile index 00cf37c6..a06a2d1c 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,8 @@ SHELL := /bin/bash LAKE = lake +LEAN = $(LAKE) env lean +GIT = git NUM_TESTS?=3 VERBOSE?=--verbose @@ -37,9 +39,25 @@ awslc_elf: cosim: time -p lake exe lnsym $(VERBOSE) --num-tests $(NUM_TESTS) +BENCHMARKS = \ + Benchmarks/SHA512_50.lean \ + Benchmarks/SHA512_50_noKernel_noLint.lean \ + Benchmarks/SHA512_75.lean \ + Benchmarks/SHA512_75_noKernel_noLint.lean \ + Benchmarks/SHA512_150.lean \ + Benchmarks/SHA512_150_noKernel_noLint.lean \ + Benchmarks/SHA512_225.lean \ + Benchmarks/SHA512_225_noKernel_noLint.lean \ + Benchmarks/SHA512_400.lean \ + Benchmarks/SHA512_400_noKernel_noLint.lean + .PHONY: benchmarks benchmarks: - $(LAKE) build Benchmarks + ./scripts/benchmark.sh $(BENCHMARKS) + +.PHONY: profile +profile: + ./scripts/profile.sh $(BENCHMARKS) .PHONY: clean clean_all clean: diff --git a/Proofs/AES-GCM/GCMGmultV8Sym.lean b/Proofs/AES-GCM/GCMGmultV8Sym.lean index 7114063f..6a5ecc7a 100644 --- a/Proofs/AES-GCM/GCMGmultV8Sym.lean +++ b/Proofs/AES-GCM/GCMGmultV8Sym.lean @@ -5,21 +5,93 @@ Author(s): Alex Keizer -/ import Tests.«AES-GCM».GCMGmultV8Program import Tactics.Sym +import Tactics.Aggregate import Tactics.StepThms +import Tactics.CSE +import Arm.Memory.SeparateAutomation +import Arm.Syntax namespace GCMGmultV8Program +open ArmStateNotation #genStepEqTheorems gcm_gmult_v8_program +/- +xxx: GCMGmultV8 Xi HTable +-/ + +set_option pp.deepTerms false in +set_option pp.deepTerms.threshold 50 in +-- set_option trace.simp_mem.info true in theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState) (h_s0_program : s0.program = gcm_gmult_v8_program) (h_s0_err : read_err s0 = .None) (h_s0_pc : read_pc s0 = gcm_gmult_v8_program.min) (h_s0_sp_aligned : CheckSPAlignment s0) + (h_Xi : Xi = s0[read_gpr 64 0#5 s0, 16]) + (h_HTable : HTable = s0[read_gpr 64 1#5 s0, 256]) + (h_mem_sep : Memory.Region.pairwiseSeparate + [(read_gpr 64 0#5 s0, 16), + (read_gpr 64 1#5 s0, 256)]) (h_run : sf = run gcm_gmult_v8_program.length s0) : - read_err sf = .None := by + -- The final state is error-free. + read_err sf = .None ∧ + -- The program is unmodified in `sf`. + sf.program = gcm_gmult_v8_program ∧ + -- The stack pointer is still aligned in `sf`. + CheckSPAlignment sf ∧ + -- The final state returns to the address in register `x30` in `s0`. + read_pc sf = r (StateField.GPR 30#5) s0 ∧ + -- HTable is unmodified. + sf[read_gpr 64 1#5 s0, 256] = HTable ∧ + -- Frame conditions. + -- Note that the following also covers that the Xi address in .GPR 0 + -- is unmodified. + REGS_UNCHANGED_EXCEPT [.SFP 0, .SFP 1, .SFP 2, .SFP 3, + .SFP 17, .SFP 18, .SFP 19, .SFP 20, + .SFP 21, .PC] + (sf, s0) ∧ + -- Memory frame condition. + MEM_UNCHANGED_EXCEPT [(r (.GPR 0) s0, 128)] (sf, s0) := by + simp_all only [state_simp_rules, -h_run] -- prelude simp (config := {ground := true}) only at h_s0_pc -- ^^ Still needed, because `gcm_gmult_v8_program.min` is somehow -- unable to be reflected sym_n 27 + -- Epilogue + simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at * + simp only [memory_rules] at * + sym_aggregate + -- Split conjunction + repeat' apply And.intro + · -- Aggregate the memory (non)effects. + -- (FIXME) This will be tackled by `sym_aggregate` when `sym_n` and `simp_mem` + -- are merged. + simp only [*] + /- + (FIXME @bollu) `simp_mem; rfl` creates a malformed proof here. The tactic produces + no goals, but we get the following error message: + + application type mismatch + Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset' + (Eq.mp (congrArg (Eq HTable) (Memory.State.read_mem_bytes_eq_mem_read_bytes s0)) + (Eq.mp (congrArg (fun x => HTable = read_mem_bytes 256 x s0) zeroExtend_eq_of_r_gpr) h_HTable)) + argument has type + HTable = Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem + but function has type + Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem = HTable → + mem_subset' (r (StateField.GPR 1#5) s0) 256 (r (StateField.GPR 1#5) s0) 256 → + Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem = + HTable.extractLsBytes (BitVec.toNat (r (StateField.GPR 1#5) s0) - BitVec.toNat (r (StateField.GPR 1#5) s0)) 256 + + simp_mem; rfl + -/ + rw [Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate'] + simp_mem + · simp only [List.mem_cons, List.mem_singleton, not_or, and_imp] + sym_aggregate + · intro n addr h_separate + simp_mem (config := { useOmegaToClose := false }) + -- Aggregate the memory (non)effects. + simp only [*] done diff --git a/Proofs/Bit_twiddling.lean b/Proofs/Bit_twiddling.lean index 02c1edab..9cd0d829 100644 --- a/Proofs/Bit_twiddling.lean +++ b/Proofs/Bit_twiddling.lean @@ -216,7 +216,7 @@ def popcount32_spec_rec (i : Nat) (x : BitVec 32) : (BitVec 32) := match i with | 0 => BitVec.zero 32 | i' + 1 => - let bit_idx := BitVec.extractLsb i' i' x + let bit_idx := BitVec.extractLsb' i' 1 x let bv_idx := (BitVec.zeroExtend 32 bit_idx) (bv_idx + (popcount32_spec_rec i' x)) @@ -253,7 +253,7 @@ def parity32_spec_rec (i : Nat) (x : BitVec 32) : Bool := match i with | 0 => false | i' + 1 => - let bit_idx := BitVec.getLsb x i' + let bit_idx := BitVec.getLsbD x i' -- let bv_idx := (BitVec.zeroExtend 32 (BitVec.ofBool bit_idx)) Bool.xor bit_idx (parity32_spec_rec i' x) @@ -268,7 +268,7 @@ def parity32_impl (x : BitVec 32) : BitVec 32 := (0x00006996#32 >>> x4) &&& 1#32 theorem parity32_correct (x : BitVec 32) : - (parity32_spec x) = ((parity32_impl x).getLsb 0) := by + (parity32_spec x) = ((parity32_impl x).getLsbD 0) := by unfold parity32_spec parity32_impl repeat (unfold parity32_spec_rec) bv_decide diff --git a/Proofs/Experiments/Abs/Abs.lean b/Proofs/Experiments/Abs/Abs.lean index eb7eed6f..b0e75bfc 100644 --- a/Proofs/Experiments/Abs/Abs.lean +++ b/Proofs/Experiments/Abs/Abs.lean @@ -24,7 +24,7 @@ def spec (x : BitVec 32) : BitVec 32 := -- BitVec.ofNat 32 x.toInt.natAbs -- because the above has functions like `toInt` that do not play well with -- bitblasting/LeanSAT. - let msb := BitVec.extractLsb 31 31 x + let msb := BitVec.extractLsb' 31 1 x if msb == 0#1 then x else diff --git a/Proofs/Experiments/Abs/AbsVCG.lean b/Proofs/Experiments/Abs/AbsVCG.lean index 0c1f109d..beaeffc5 100644 --- a/Proofs/Experiments/Abs/AbsVCG.lean +++ b/Proofs/Experiments/Abs/AbsVCG.lean @@ -39,7 +39,7 @@ def spec (x : BitVec 32) : BitVec 32 := -- BitVec.ofNat 32 x.toInt.natAbs -- because the above has functions like `toInt` that do not play well with -- bitblasting/LeanSAT. - let msb := BitVec.extractLsb 31 31 x + let msb := BitVec.extractLsb' 31 1 x if msb == 0#1 then x else @@ -205,12 +205,12 @@ theorem effects_of_nextc_from_0x4005d0 (h_pre : abs_pre s0) r (StateField.GPR 0#5) sn = BitVec.zeroExtend 64 (AddWithCarry (BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s0)) - (BitVec.replicate 32 (BitVec.extractLsb 31 31 (BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s0))) &&& + (BitVec.replicate 32 (BitVec.extractLsb' 31 1 (BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s0))) &&& 0xfffffffe#32 ||| (BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s0)).rotateRight 31 &&& 0xffffffff#32 &&& 0x1#32) 0x0#1).fst ^^^ (BitVec.zeroExtend 64 - (BitVec.replicate 32 (BitVec.extractLsb 31 31 (BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s0)))) &&& + (BitVec.replicate 32 (BitVec.extractLsb' 31 1 (BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s0)))) &&& 0xfffffffe#64 ||| BitVec.zeroExtend 64 ((BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s0)).rotateRight 31) &&& 0xffffffff#64 &&& 0x1#64) ∧ diff --git a/Proofs/Experiments/Abs/AbsVCGTandem.lean b/Proofs/Experiments/Abs/AbsVCGTandem.lean index e5101019..72271b76 100644 --- a/Proofs/Experiments/Abs/AbsVCGTandem.lean +++ b/Proofs/Experiments/Abs/AbsVCGTandem.lean @@ -37,7 +37,7 @@ def spec (x : BitVec 32) : BitVec 32 := -- BitVec.ofNat 32 x.toInt.natAbs -- because the above has functions like `toInt` that do not play well with -- bitblasting/LeanSAT. - let msb := BitVec.extractLsb 31 31 x + let msb := BitVec.extractLsb' 31 1 x if msb == 0#1 then x else @@ -124,7 +124,7 @@ theorem program.stepi_0x4005d4_cut (s sn : ArmState) abs_cut sn = false ∧ r (StateField.GPR 0#5) sn = (BitVec.zeroExtend 64 - (BitVec.replicate 32 (BitVec.extractLsb 31 31 (BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s)))) &&& + (BitVec.replicate 32 (BitVec.extractLsb' 31 1 (BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s)))) &&& 0xfffffffe#64 ||| BitVec.zeroExtend 64 ((BitVec.zeroExtend 32 (r (StateField.GPR 0x0#5) s)).rotateRight 31) &&& 0xffffffff#64 &&& 0x1#64) ∧ diff --git a/Proofs/Experiments/Memcpy/MemCpyVCG.lean b/Proofs/Experiments/Memcpy/MemCpyVCG.lean index 87af5e43..e12f2b2b 100644 --- a/Proofs/Experiments/Memcpy/MemCpyVCG.lean +++ b/Proofs/Experiments/Memcpy/MemCpyVCG.lean @@ -268,11 +268,10 @@ theorem program.step_8e4_8e8_of_wellformed_of_stepped (scur snext : ArmState) obtain h_sp_aligned := hscur.h_sp_aligned have := program.stepi_eq_0x8e4 h_program h_pc h_err - simp [BitVec.extractLsb] at this obtain ⟨h_step⟩ := hstep subst h_step constructor <;> simp only [*, cut, state_simp_rules, minimal_theory, bitvec_rules] - · constructor <;> simp [*, state_simp_rules, minimal_theory, BitVec.extractLsb] + · constructor <;> simp [*, state_simp_rules, minimal_theory] -- 3/7 (0x8e8#64, 0x3c810444#32), /- str q4, [x2], #16 -/ structure Step_8e8_8ec (scur : ArmState) (snext : ArmState) extends WellFormedAtPc snext 0x8ec : Prop where @@ -296,7 +295,6 @@ theorem program.step_8e8_8ec_of_wellformed (scur snext : ArmState) obtain h_sp_aligned := hscur.h_sp_aligned have := program.stepi_eq_0x8e8 h_program h_pc h_err - simp [BitVec.extractLsb] at this obtain ⟨h_step⟩ := hstep subst h_step constructor @@ -335,11 +333,10 @@ theorem program.step_8ec_8f0_of_wellformed (scur snext : ArmState) obtain h_sp_aligned := hs.h_sp_aligned have := program.stepi_eq_0x8ec h_program h_pc h_err - simp [BitVec.extractLsb] at this obtain ⟨h_step⟩ := hstep subst h_step constructor <;> simp (config := { ground := true, decide := true}) [*, - state_simp_rules, minimal_theory, BitVec.extractLsb, fst_AddWithCarry_eq_sub_neg, memory_rules] + state_simp_rules, minimal_theory, fst_AddWithCarry_eq_sub_neg, memory_rules] · constructor <;> simp [*, state_simp_rules, minimal_theory, bitvec_rules, memory_rules] -- 5/7 (0x8f0#64, 0xf100001f#32), /- cmp x0, #0x0 -/ @@ -364,7 +361,7 @@ theorem program.step_8f0_8f4_of_wellformed (scur snext : ArmState) obtain h_sp_aligned := hs.h_sp_aligned have := program.stepi_eq_0x8f0 h_program h_pc h_err - simp (config := { ground := true, decide := true}) [BitVec.extractLsb, + simp (config := { ground := true, decide := true}) [ fst_AddWithCarry_eq_sub_neg, fst_AddWithCarry_eq_add] at this obtain ⟨h_step⟩ := hstep @@ -396,14 +393,14 @@ theorem program.step_8f4_8e4_of_wellformed_of_z_eq_0 (scur snext : ArmState) obtain h_sp_aligned := hs.h_sp_aligned have := program.stepi_eq_0x8f4 h_program h_pc h_err - simp (config := { ground := true, decide := true}) [BitVec.extractLsb, + simp (config := { ground := true, decide := true}) [ fst_AddWithCarry_eq_sub_neg, fst_AddWithCarry_eq_add] at this obtain ⟨h_step⟩ := hstep subst h_step constructor <;> solve | simp (config := { ground := true, decide := true}) [*, - state_simp_rules, minimal_theory, BitVec.extractLsb, fst_AddWithCarry_eq_sub_neg] + state_simp_rules, minimal_theory, fst_AddWithCarry_eq_sub_neg] · constructor <;> simp [*, state_simp_rules, minimal_theory, bitvec_rules] -- 6/7 (0x8f4#64, 0x54ffff81#32), /- b.ne 8e4 -/ @@ -426,14 +423,14 @@ theorem program.step_8f4_8f8_of_wellformed_of_z_eq_1 (scur snext : ArmState) obtain h_sp_aligned := hs.h_sp_aligned have := program.stepi_eq_0x8f4 h_program h_pc h_err - simp (config := { ground := true, decide := true}) [BitVec.extractLsb, + simp (config := { ground := true, decide := true}) [ fst_AddWithCarry_eq_sub_neg, fst_AddWithCarry_eq_add] at this obtain ⟨h_step⟩ := hstep subst h_step constructor <;> simp (config := { ground := true, decide := true}) [*, state_simp_rules, h_z, - minimal_theory, BitVec.extractLsb, fst_AddWithCarry_eq_sub_neg, cut] + minimal_theory, fst_AddWithCarry_eq_sub_neg, cut] · constructor <;> simp [*, h_z, state_simp_rules, minimal_theory, bitvec_rules, cut] end CutTheorems @@ -441,6 +438,8 @@ end CutTheorems section PartialCorrectness -- set_option skip_proof.skip true in +-- set_option trace.profiler true in +-- set_option profiler true in set_option maxHeartbeats 0 in theorem Memcpy.extracted_2 (s0 si : ArmState) (h_si_x0_nonzero : si.x0 ≠ 0) @@ -482,6 +481,8 @@ theorem Memcpy.extracted_2 (s0 si : ArmState) -- set_option skip_proof.skip true in set_option maxHeartbeats 0 in +-- set_option trace.profiler true in +-- set_option profiler true in theorem Memcpy.extracted_0 (s0 si : ArmState) (h_si_x0_nonzero : si.x0 ≠ 0) (h_s0_x1 : s0.x1 + 0x10#64 * (s0.x0 - si.x0) + 0x10#64 = s0.x1 + 0x10#64 * (s0.x0 - (si.x0 - 0x1#64))) @@ -550,6 +551,9 @@ theorem Memcpy.extracted_0 (s0 si : ArmState) } · intros n addr hsep apply Memcpy.extracted_2 <;> assumption + +-- set_option trace.profiler true in +-- set_option profiler true in theorem partial_correctness : PartialCorrectness ArmState := by apply Correctness.partial_correctness_from_assertions diff --git a/Proofs/Popcount32.lean b/Proofs/Popcount32.lean index dbde86a2..699e0147 100644 --- a/Proofs/Popcount32.lean +++ b/Proofs/Popcount32.lean @@ -14,6 +14,7 @@ import Tactics.StepThms section popcount32 open BitVec +open ArmState /-! Source: https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel @@ -24,20 +25,19 @@ int popcount_32 (unsigned int v) { v = ((v + (v >> 4) & 0xF0F0F0F) * 0x1010101) >> 24; return(v); } - -/ -def popcount32_spec_rec (i : Nat) (x : BitVec 32) : (BitVec 32) := +def popcount32_spec_rec (i : Nat) (x : BitVec 32) : BitVec 32 := match i with | 0 => 0#32 | i' + 1 => let bit_idx := BitVec.getLsbD x i' - ((BitVec.zeroExtend 32 (BitVec.ofBool bit_idx)) + (popcount32_spec_rec i' x)) + (BitVec.ofBool bit_idx).zeroExtend 32 + + popcount32_spec_rec i' x def popcount32_spec (x : BitVec 32) : BitVec 32 := popcount32_spec_rec 32 x - def popcount32_program : Program := def_program [(0x4005b4#64 , 0xd10043ff#32), -- sub sp, sp, #0x10 @@ -68,37 +68,45 @@ def popcount32_program : Program := (0x400618#64 , 0x910043ff#32), -- add sp, sp, #0x10 (0x40061c#64 , 0xd65f03c0#32)] -- ret - #genStepEqTheorems popcount32_program -theorem popcount32_sym_meets_spec (s0 s_final : ArmState) - (h_s0_pc : read_pc s0 = 0x4005b4#64) - (h_s0_program : s0.program = popcount32_program) +set_option trace.simp_mem.info true in +theorem popcount32_sym_meets_spec (s0 sf : ArmState) + (h_s0_pc : read_pc s0 = 0x4005b4#64) + (h_s0_program : s0.program = popcount32_program) (h_s0_sp_aligned : CheckSPAlignment s0) - (h_s0_err : read_err s0 = StateError.None) - (h_run : s_final = run 27 s0) : - read_gpr 32 0#5 s_final = popcount32_spec (read_gpr 32 0#5 s0) ∧ - read_err s_final = StateError.None ∧ - (∀ f, f ≠ (.GPR 0#5) ∧ f ≠ (.GPR 1#5) ∧ f ≠ (.GPR 31#5) ∧ f ≠ .PC → - r f s_final = r f s0) ∧ - (∀ (n : Nat) (addr : BitVec 64), - mem_separate' addr n (r (.GPR 31) s0 - 16#64) 16 → - s_final[addr, n] = s0[addr, n]) := by - simp_all only [state_simp_rules, -h_run] -- prelude - sym_n 27 -- Symbolic simulation - repeat' apply And.intro -- split conjunction. - · simp only [popcount32_spec, popcount32_spec_rec] + (h_s0_err : read_err s0 = StateError.None) + (h_run : sf = run 27 s0) : + -- The final state `sf` is error-free. + read_err sf = StateError.None ∧ + -- Register `w0` in `sf` contains the correct value. + w0 sf = popcount32_spec (w0 s0) ∧ + -- The frame condition describes which state components are not affected by + -- this program's execution. + REGS_UNCHANGED_EXCEPT [(.GPR 0), (.GPR 1), .SP, .PC] (sf, s0) ∧ + MEM_UNCHANGED_EXCEPT [((r .SP s0 - 16#64), 16)] (sf, s0) := by + -- Prelude + simp_all only [state_simp_rules, -h_run] + -- Symbolic simulation + sym_n 27 + -- TODO(@bollu): automation for SP alignment + case h_s1_sp_aligned => + apply Aligned_BitVecSub_64_4 (by assumption) (by decide) + case h_s26_sp_aligned => + apply Aligned_BitVecAdd_64_4 (by assumption) (by decide) + -- Split the conclusion + repeat' apply And.intro + · -- Functional Correctness + simp only [popcount32_spec, popcount32_spec_rec] bv_decide - · sym_aggregate - · intro n addr h_separate + · -- Register Frame Condition + simp only [List.mem_cons, List.mem_singleton, not_or, and_imp]; sym_aggregate + · -- Memory Frame Condition + intro n addr h_separate simp only [memory_rules] at * repeat (simp_mem (config := { useOmegaToClose := false }); sym_aggregate) - · apply Aligned_BitVecSub_64_4 -- TODO(@bollu): automation - · assumption - · decide - · apply Aligned_BitVecAdd_64_4 - · assumption - · decide + + done /-- info: 'popcount32_sym_meets_spec' depends on axioms: @@ -108,20 +116,4 @@ info: 'popcount32_sym_meets_spec' depends on axioms: ------------------------------------------------------------------------------- -/-! ## Tests for step theorem generation -/ -section Tests - -/-- -info: popcount32_program.stepi_eq_0x4005c0 {s : ArmState} (h_program : s.program = popcount32_program) - (h_pc : r StateField.PC s = 4195776#64) (h_err : r StateField.ERR s = StateError.None) : - stepi s = - w StateField.PC (4195780#64) - (w (StateField.GPR 0#5) - (zeroExtend 64 ((zeroExtend 32 (r (StateField.GPR 0#5) s)).rotateRight 1) &&& 4294967295#64 &&& 2147483647#64) - s) --/ -#guard_msgs in #check popcount32_program.stepi_eq_0x4005c0 - -end Tests - end popcount32 diff --git a/Proofs/SHA512/SHA512Loop.lean b/Proofs/SHA512/SHA512Loop.lean index 75027137..2a9f715d 100644 --- a/Proofs/SHA512/SHA512Loop.lean +++ b/Proofs/SHA512/SHA512Loop.lean @@ -45,11 +45,12 @@ def loop_post (PC N SP CtxBase InputBase : BitVec 64) ctx_addr si = CtxBase ∧ stack_ptr si = SP - 16#64 ∧ si[KtblAddr, (SHA2.k_512.length * 8)] = BitVec.flatten SHA2.k_512 ∧ - Memory.Region.pairwiseSeparate - [(SP - 16#64, 16), - (CtxBase, 64), - (InputBase, (N.toNat * 128)), - (KtblAddr, (SHA2.k_512.length * 8))] ∧ + -- (TODO @alex @bollu Uncomment, please, for stress-testing) +-- Memory.Region.pairwiseSeparate +-- [(SP - 16#64, 16), +-- (CtxBase, 64), +-- (InputBase, (N.toNat * 128)), +-- (KtblAddr, (SHA2.k_512.length * 8))] ∧ r (.GPR 3#5) si = KtblAddr ∧ input_addr si = InputBase + (N * 128#64) ∧ -- Registers contain the last processed input block. @@ -75,13 +76,13 @@ set_option debug.skipKernelTC true in -- set_option profiler true in -- set_option profiler.threshold 1 in set_option maxHeartbeats 0 in -set_option maxRecDepth 8000 in +-- set_option maxRecDepth 8000 in theorem sha512_block_armv8_loop_1block (si sf : ArmState) (h_N : N = 1#64) (h_si_prelude : SHA512.prelude 0x126500#64 N SP CtxBase InputBase si) -- TODO: Ideally, nsteps ought to be 485 to be able to simulate the loop to -- completion. - (h_steps : nsteps = 200) + (h_steps : nsteps = 400) (h_run : sf = run nsteps si) : -- (FIXME) PC should be 0x126c94#64 i.e., we are poised to execute the first -- instruction following the loop. For now, we stop early on to remain in sync. @@ -94,12 +95,19 @@ theorem sha512_block_armv8_loop_1block (si sf : ArmState) h_si_input_base, h_si_ctx, h_si_ktbl, h_si_separate⟩ := h_si_prelude simp only [num_blocks, ctx_addr, stack_ptr, input_addr] at * simp only [loop_post] + simp at h_si_separate -- Symbolic Simulation /- TODO @alex: The following aggregation fails with "simp failed, maximum number of steps exceeded" -/ - -- sym_n 200 + sym_n 100 + sym_n 100 + sym_n 100 + sym_n 100 + -- sym_aggregate + + -- Epilogue -- cse (config := { processHyps := .allHyps }) -- simp (config := {ground := true}) only diff --git a/Proofs/SHA512/SHA512Prelude.lean b/Proofs/SHA512/SHA512Prelude.lean index 78a10f98..9013de0c 100644 --- a/Proofs/SHA512/SHA512Prelude.lean +++ b/Proofs/SHA512/SHA512Prelude.lean @@ -124,7 +124,7 @@ theorem sha512_block_armv8_prelude (s0 sf : ArmState) /- (FIXME) The `rw` below fails with: tactic 'rewrite' failed, did not find instance of the pattern in the target expression - extractLsb 3 0 (?m.1887 + ?m.1888) + extractLsb' 0 4 (?m.1887 + ?m.1888) Why is `Aligned` opened up here? -/ diff --git a/Proofs/SHA512/SHA512_block_armv8_rules.lean b/Proofs/SHA512/SHA512_block_armv8_rules.lean index 888bf583..6d2375b1 100644 --- a/Proofs/SHA512/SHA512_block_armv8_rules.lean +++ b/Proofs/SHA512/SHA512_block_armv8_rules.lean @@ -18,13 +18,13 @@ open SHA2 theorem sha512_message_schedule_rule (a b c d : BitVec 128) : sha512su1 a b (sha512su0 c d) = - let a1 := extractLsb 127 64 a - let a0 := extractLsb 63 0 a - let b1 := extractLsb 127 64 b - let b0 := extractLsb 63 0 b - let c0 := extractLsb 63 0 c - let d1 := extractLsb 127 64 d - let d0 := extractLsb 63 0 d + let a1 := extractLsb' 64 64 a + let a0 := extractLsb' 0 64 a + let b1 := extractLsb' 64 64 b + let b0 := extractLsb' 0 64 b + let c0 := extractLsb' 0 64 c + let d1 := extractLsb' 64 64 d + let d0 := extractLsb' 0 64 d message_schedule_word_aux a1 b1 c0 d1 ++ message_schedule_word_aux a0 b0 d1 d0 := by simp [sha512su1, sha512su0, message_schedule_word_aux] @@ -32,11 +32,11 @@ theorem sha512_message_schedule_rule (a b c d : BitVec 128) : theorem sha512h2_rule (a b c : BitVec 128) : sha512h2 a b c = - let a0 := extractLsb 63 0 a - let b1 := extractLsb 127 64 b - let b0 := extractLsb 63 0 b - let c0 := extractLsb 63 0 c - let c1 := extractLsb 127 64 c + let a0 := extractLsb' 0 64 a + let b1 := extractLsb' 64 64 b + let b0 := extractLsb' 0 64 b + let c0 := extractLsb' 0 64 c + let c1 := extractLsb' 64 64 c ((compression_update_t2 b0 a0 b1) + c1) ++ ((compression_update_t2 ((compression_update_t2 b0 a0 b1) + c1) b0 b1) + c0) := by simp [maj, compression_update_t2, sha512h2, sigma_big_0, ror] @@ -64,12 +64,12 @@ private theorem and_nop_lemma (x : BitVec 64) : (zeroExtend 128 x) &&& 0xffffffffffffffff#128 = (zeroExtend 128 x) := by bv_decide -private theorem extractLsb_low_64_from_zeroExtend_128_or (x y : BitVec 64) : - extractLsb 63 0 ((zeroExtend 128 x) ||| (zeroExtend 128 y <<< 64)) = x := by +private theorem extractLsb'_low_64_from_zeroExtend_128_or (x y : BitVec 64) : + extractLsb' 0 64 ((zeroExtend 128 x) ||| (zeroExtend 128 y <<< 64)) = x := by bv_decide -private theorem extractLsb_high_64_from_zeroExtend_128_or (x y : BitVec 64) : - extractLsb 127 64 ((zeroExtend 128 x) ||| (zeroExtend 128 y <<< 64)) = y := by +private theorem extractLsb'_high_64_from_zeroExtend_128_or (x y : BitVec 64) : + extractLsb' 64 64 ((zeroExtend 128 x) ||| (zeroExtend 128 y <<< 64)) = y := by bv_decide -- This lemma takes ~5min with bv_decide and the generated LRAT @@ -85,16 +85,16 @@ theorem sha512h_rule_1 (a b c d e : BitVec 128) : let esize := 64 let inner_sum := (binary_vector_op_aux 0 elements esize BitVec.add c d (BitVec.zero 128) H) let outer_sum := (binary_vector_op_aux 0 elements esize BitVec.add inner_sum e (BitVec.zero 128) H) - let a0 := extractLsb 63 0 a - let a1 := extractLsb 127 64 a - let b0 := extractLsb 63 0 b - let b1 := extractLsb 127 64 b - let c0 := extractLsb 63 0 c - let c1 := extractLsb 127 64 c - let d0 := extractLsb 63 0 d - let d1 := extractLsb 127 64 d - let e0 := extractLsb 63 0 e - let e1 := extractLsb 127 64 e + let a0 := extractLsb' 0 64 a + let a1 := extractLsb' 64 64 a + let b0 := extractLsb' 0 64 b + let b1 := extractLsb' 64 64 b + let c0 := extractLsb' 0 64 c + let c1 := extractLsb' 64 64 c + let d0 := extractLsb' 0 64 d + let d1 := extractLsb' 64 64 d + let e0 := extractLsb' 0 64 e + let e1 := extractLsb' 64 64 e let hi64_spec := compression_update_t1 b1 a0 a1 c1 d1 e1 let lo64_spec := compression_update_t1 (b0 + hi64_spec) b1 a0 c0 d0 e0 sha512h a b outer_sum = hi64_spec ++ lo64_spec := by @@ -104,19 +104,19 @@ theorem sha512h_rule_1 (a b c d e : BitVec 128) : unfold sha512h compression_update_t1 sigma_big_1 ch allOnes ror simp only [Nat.reduceAdd, Nat.reduceSub, Nat.sub_zero, Nat.reducePow, reduceZeroExtend, reduceHShiftLeft, reduceNot, reduceAnd, BitVec.zero_or, shiftLeft_zero_eq] - generalize extractLsb 63 0 a = a_lo - generalize extractLsb 127 64 a = a_hi - generalize extractLsb 63 0 b = b_lo - generalize extractLsb 127 64 b = b_hi - generalize extractLsb 63 0 c = c_lo - generalize extractLsb 127 64 c = c_hi - generalize extractLsb 63 0 d = d_lo - generalize extractLsb 127 64 d = d_hi - generalize extractLsb 63 0 e = e_lo - generalize extractLsb 127 64 e = e_hi - simp at a_lo a_hi b_lo b_hi c_lo c_hi d_lo d_hi e_lo e_hi + generalize extractLsb' 0 64 a = a_lo + generalize extractLsb' 64 64 a = a_hi + generalize extractLsb' 0 64 b = b_lo + generalize extractLsb' 64 64 b = b_hi + generalize extractLsb' 0 64 c = c_lo + generalize extractLsb' 64 64 c = c_hi + generalize extractLsb' 0 64 d = d_lo + generalize extractLsb' 64 64 d = d_hi + generalize extractLsb' 0 64 e = e_lo + generalize extractLsb' 64 64 e = e_hi + -- simp at a_lo a_hi b_lo b_hi c_lo c_hi d_lo d_hi e_lo e_hi clear a b c d e - simp only [and_nop_lemma, extractLsb_low_64_from_zeroExtend_128_or, extractLsb_high_64_from_zeroExtend_128_or] + simp only [and_nop_lemma, extractLsb'_low_64_from_zeroExtend_128_or, extractLsb'_high_64_from_zeroExtend_128_or] generalize (b_hi.rotateRight 14 ^^^ b_hi.rotateRight 18 ^^^ b_hi.rotateRight 41) = aux0 generalize (b_hi &&& a_lo ^^^ ~~~b_hi &&& a_hi) = aux1 ac_rfl @@ -165,8 +165,8 @@ theorem rev_vector_of_rev_vector_128_64_8 (x : BitVec 128) : done private theorem sha512h_rule_2_helper_1 (x y : BitVec 64) : - extractLsb 63 0 - (extractLsb 191 64 + extractLsb' 0 64 + (extractLsb' 64 128 ((zeroExtend 128 x ||| zeroExtend 128 y <<< 64) ++ (zeroExtend 128 x ||| zeroExtend 128 y <<< 64))) = @@ -174,8 +174,8 @@ private theorem sha512h_rule_2_helper_1 (x y : BitVec 64) : bv_decide private theorem sha512h_rule_2_helper_2 (x y : BitVec 64) : - extractLsb 127 64 - (extractLsb 191 64 + extractLsb' 64 64 + (extractLsb' 64 128 ((zeroExtend 128 x ||| zeroExtend 128 y <<< 64) ++ (zeroExtend 128 x ||| zeroExtend 128 y <<< 64))) = @@ -186,19 +186,19 @@ private theorem sha512h_rule_2_helper_2 (x y : BitVec 64) : -- file is ~120MB. As with sha512h_rule_1 above, we prefer to just simplify and -- normalize here instead of doing bit-blasting. theorem sha512h_rule_2 (a b c d e : BitVec 128) : - let a0 := extractLsb 63 0 a - let a1 := extractLsb 127 64 a - let b0 := extractLsb 63 0 b - let b1 := extractLsb 127 64 b - let c0 := extractLsb 63 0 c - let c1 := extractLsb 127 64 c - let d0 := extractLsb 63 0 d - let d1 := extractLsb 127 64 d - let e0 := extractLsb 63 0 e - let e1 := extractLsb 127 64 e + let a0 := extractLsb' 0 64 a + let a1 := extractLsb' 64 64 a + let b0 := extractLsb' 0 64 b + let b1 := extractLsb' 64 64 b + let c0 := extractLsb' 0 64 c + let c1 := extractLsb' 64 64 c + let d0 := extractLsb' 0 64 d + let d1 := extractLsb' 64 64 d + let e0 := extractLsb' 0 64 e + let e1 := extractLsb' 64 64 e let inner_sum := binary_vector_op_aux 0 2 64 BitVec.add d e (BitVec.zero 128) h1 let concat := inner_sum ++ inner_sum - let operand := extractLsb 191 64 concat + let operand := extractLsb' 64 128 concat let hi64_spec := compression_update_t1 b1 a0 a1 c1 d0 e0 let lo64_spec := compression_update_t1 (b0 + hi64_spec) b1 a0 c0 d1 e1 sha512h a b (binary_vector_op_aux 0 2 64 BitVec.add c operand (BitVec.zero 128) h2) = @@ -210,19 +210,18 @@ theorem sha512h_rule_2 (a b c d e : BitVec 128) : reduceZeroExtend, Nat.zero_mul, reduceHShiftLeft, reduceNot, reduceAnd, Nat.one_mul, BitVec.cast_eq] simp only [shiftLeft_zero_eq, BitVec.zero_or, and_nop_lemma] - generalize extractLsb 63 0 a = a_lo - generalize extractLsb 127 64 a = a_hi - generalize extractLsb 63 0 b = b_lo - generalize extractLsb 127 64 b = b_hi - generalize extractLsb 63 0 c = c_lo - generalize extractLsb 127 64 c = c_hi - generalize extractLsb 63 0 d = d_lo - generalize extractLsb 127 64 d = d_hi - generalize extractLsb 63 0 e = e_lo - generalize extractLsb 127 64 e = e_hi - simp at a_lo a_hi b_lo b_hi c_lo c_hi d_lo d_hi e_lo e_hi + generalize extractLsb' 0 64 a = a_lo + generalize extractLsb' 64 64 a = a_hi + generalize extractLsb' 0 64 b = b_lo + generalize extractLsb' 64 64 b = b_hi + generalize extractLsb' 0 64 c = c_lo + generalize extractLsb' 64 64 c = c_hi + generalize extractLsb' 0 64 d = d_lo + generalize extractLsb' 64 64 d = d_hi + generalize extractLsb' 0 64 e = e_lo + generalize extractLsb' 64 64 e = e_hi clear a b c d e - simp only [extractLsb_high_64_from_zeroExtend_128_or, extractLsb_low_64_from_zeroExtend_128_or] + simp only [extractLsb'_high_64_from_zeroExtend_128_or, extractLsb'_low_64_from_zeroExtend_128_or] simp only [sha512h_rule_2_helper_1, sha512h_rule_2_helper_2] generalize (b_hi.rotateRight 14 ^^^ b_hi.rotateRight 18 ^^^ b_hi.rotateRight 41) = aux1 ac_rfl diff --git a/README.md b/README.md index 47939c18..2054fdf4 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,8 @@ native-code programs of interest. `benchmarks`: run benchmarks for the symbolic simulator. +`profiler`: run a single round of each benchmark, with the profiler enabled + ### Makefile variables that can be passed in at the command line `VERBOSE`: Verbose mode; prints disassembly of the instructions being diff --git a/Specs/AESArm.lean b/Specs/AESArm.lean index c549f8a6..e486fff6 100644 --- a/Specs/AESArm.lean +++ b/Specs/AESArm.lean @@ -110,21 +110,18 @@ def SubWord (w : BitVec WordSize) : BitVec WordSize := protected def InitKey {Param : KBR} (i : Nat) (key : BitVec Param.key_len) (acc : KeySchedule) : KeySchedule := - if h₀ : Param.Nk ≤ i then acc - else - have h₁ : i * 32 + 32 - 1 - i * 32 + 1 = WordSize := by - simp only [WordSize]; omega - let wd := BitVec.cast h₁ $ extractLsb (i * 32 + 32 - 1) (i * 32) key - let (x:KeySchedule) := [wd] - have _ : Param.Nk - (i + 1) < Param.Nk - i := by omega - AESArm.InitKey (Param := Param) (i + 1) key (acc ++ x) - termination_by (Param.Nk - i) + match i with + | 0 => acc + | i' + 1 => + let wd := extractLsb' ((i - 1) * 32) 32 key + AESArm.InitKey (Param := Param) i' key (wd :: acc) protected def KeyExpansion_helper {Param : KBR} (i : Nat) (ks : KeySchedule) : KeySchedule := - if h : 4 * Param.Nr + 4 ≤ i then - ks - else + match i with + | 0 => ks + | i' + 1 => + let i := 4 * Param.Nr + 4 - i let tmp := List.get! ks (i - 1) let tmp := if i % Param.Nk == 0 then @@ -135,14 +132,12 @@ protected def KeyExpansion_helper {Param : KBR} (i : Nat) (ks : KeySchedule) tmp let res := (List.get! ks (i - Param.Nk)) ^^^ tmp let ks := List.append ks [ res ] - have _ : 4 * Param.Nr + 4 - (i + 1) < 4 * Param.Nr + 4 - i := by omega - AESArm.KeyExpansion_helper (Param := Param) (i + 1) ks - termination_by (4 * Param.Nr + 4 - i) + AESArm.KeyExpansion_helper (Param := Param) i' ks def KeyExpansion {Param : KBR} (key : BitVec Param.key_len) : KeySchedule := - let seeded := AESArm.InitKey (Param := Param) 0 key [] - AESArm.KeyExpansion_helper (Param := Param) Param.Nk seeded + let seeded := AESArm.InitKey (Param := Param) Param.Nk key [] + AESArm.KeyExpansion_helper (Param := Param) (4 * Param.Nr + 4 - Param.Nk) seeded def SubBytes {Param : KBR} (state : BitVec Param.block_size) : BitVec Param.block_size := @@ -215,22 +210,22 @@ protected def getKey {Param : KBR} (n : Nat) (w : KeySchedule) : BitVec Param.bl protected def AES_encrypt_with_ks_loop {Param : KBR} (round : Nat) (state : BitVec Param.block_size) (w : KeySchedule) : BitVec Param.block_size := - if Param.Nr ≤ round then - state - else + match round with + | 0 => state + | round' + 1 => + let round := Param.Nr - round let state := SubBytes state let state := ShiftRows state let state := MixColumns state let state := AddRoundKey state $ AESArm.getKey round w - AESArm.AES_encrypt_with_ks_loop (Param := Param) (round + 1) state w - termination_by (Param.Nr - round) + AESArm.AES_encrypt_with_ks_loop (Param := Param) round' state w def AES_encrypt_with_ks {Param : KBR} (input : BitVec Param.block_size) (w : KeySchedule) : BitVec Param.block_size := -- have h₀ : WordSize + WordSize + WordSize + WordSize = Param.block_size := by -- simp only [WordSize, BlockSize, Param.h] let state := AddRoundKey input $ (AESArm.getKey 0 w) - let state := AESArm.AES_encrypt_with_ks_loop (Param := Param) 1 state w + let state := AESArm.AES_encrypt_with_ks_loop (Param := Param) (Param.Nr - 1) state w let state := SubBytes (Param := Param) state let state := ShiftRows (Param := Param) state AddRoundKey state $ AESArm.getKey Param.Nr w diff --git a/Specs/AESCommon.lean b/Specs/AESCommon.lean index f36ae837..3c8a8ff0 100644 --- a/Specs/AESCommon.lean +++ b/Specs/AESCommon.lean @@ -33,84 +33,81 @@ def SBOX := ] def ShiftRows (op : BitVec 128) : BitVec 128 := - extractLsb 95 88 op ++ extractLsb 55 48 op ++ - extractLsb 15 8 op ++ extractLsb 103 96 op ++ - extractLsb 63 56 op ++ extractLsb 23 16 op ++ - extractLsb 111 104 op ++ extractLsb 71 64 op ++ - extractLsb 31 24 op ++ extractLsb 119 112 op ++ - extractLsb 79 72 op ++ extractLsb 39 32 op ++ - extractLsb 127 120 op ++ extractLsb 87 80 op ++ - extractLsb 47 40 op ++ extractLsb 7 0 op + extractLsb' 88 8 op ++ extractLsb' 48 8 op ++ + extractLsb' 8 8 op ++ extractLsb' 96 8 op ++ + extractLsb' 56 8 op ++ extractLsb' 16 8 op ++ + extractLsb' 104 8 op ++ extractLsb' 64 8 op ++ + extractLsb' 24 8 op ++ extractLsb' 112 8 op ++ + extractLsb' 72 8 op ++ extractLsb' 32 8 op ++ + extractLsb' 120 8 op ++ extractLsb' 80 8 op ++ + extractLsb' 40 8 op ++ extractLsb' 0 8 op def SubBytes_aux (i : Nat) (op : BitVec 128) (out : BitVec 128) : BitVec 128 := - if h₀ : 16 <= i then - out - else - let idx := (extractLsb (i * 8 + 7) (i * 8) op).toNat - let val := extractLsb (idx * 8 + 7) (idx * 8) $ BitVec.flatten SBOX - have h₁ : idx * 8 + 7 - idx * 8 + 1 = i * 8 + 7 - i * 8 + 1 := by omega + match i with + | 0 => out + | i' + 1 => + let i := 16 - i + let idx := (extractLsb' (i * 8) 8 op).toNat + let val := extractLsb' (idx * 8) 8 $ BitVec.flatten SBOX + have h₁ : 8 = i * 8 + 7 - i * 8 + 1 := by omega let out := BitVec.partInstall (i * 8 + 7) (i * 8) (BitVec.cast h₁ val) out - have _ : 15 - i < 16 - i := by omega - SubBytes_aux (i + 1) op out - termination_by (16 - i) + SubBytes_aux i' op out def SubBytes (op : BitVec 128) : BitVec 128 := - SubBytes_aux 0 op (BitVec.zero 128) + SubBytes_aux 16 op (BitVec.zero 128) def MixColumns_aux (c : Nat) (in0 : BitVec 32) (in1 : BitVec 32) (in2 : BitVec 32) (in3 : BitVec 32) (out0 : BitVec 32) (out1 : BitVec 32) (out2 : BitVec 32) (out3 : BitVec 32) (FFmul02 : BitVec 8 -> BitVec 8) (FFmul03 : BitVec 8 -> BitVec 8) : BitVec 32 × BitVec 32 × BitVec 32 × BitVec 32 := - if h₀ : 4 <= c then - (out0, out1, out2, out3) - else - let lo := c * 8 + match c with + | 0 => (out0, out1, out2, out3) + | c' + 1 => + let lo := (4 - c) * 8 let hi := lo + 7 - have h₁ : hi - lo + 1 = 8 := by omega - let in0_byte := BitVec.cast h₁ $ extractLsb hi lo in0 - let in1_byte := BitVec.cast h₁ $ extractLsb hi lo in1 - let in2_byte := BitVec.cast h₁ $ extractLsb hi lo in2 - let in3_byte := BitVec.cast h₁ $ extractLsb hi lo in3 - let val0 := BitVec.cast h₁.symm $ (FFmul02 in0_byte ^^^ FFmul03 in1_byte ^^^ in2_byte ^^^ in3_byte) + let in0_byte := extractLsb' lo 8 in0 + let in1_byte := extractLsb' lo 8 in1 + let in2_byte := extractLsb' lo 8 in2 + let in3_byte := extractLsb' lo 8 in3 + have h : 8 = hi - lo + 1 := by omega + let val0 := BitVec.cast h $ FFmul02 in0_byte ^^^ FFmul03 in1_byte ^^^ in2_byte ^^^ in3_byte let out0 := BitVec.partInstall hi lo val0 out0 - let val1 := BitVec.cast h₁.symm $ (FFmul02 in1_byte ^^^ FFmul03 in2_byte ^^^ in3_byte ^^^ in0_byte) + let val1 := BitVec.cast h $ FFmul02 in1_byte ^^^ FFmul03 in2_byte ^^^ in3_byte ^^^ in0_byte let out1 := BitVec.partInstall hi lo val1 out1 - let val2 := BitVec.cast h₁.symm $ (FFmul02 in2_byte ^^^ FFmul03 in3_byte ^^^ in0_byte ^^^ in1_byte) + let val2 := BitVec.cast h $ FFmul02 in2_byte ^^^ FFmul03 in3_byte ^^^ in0_byte ^^^ in1_byte let out2 := BitVec.partInstall hi lo val2 out2 - let val3 := BitVec.cast h₁.symm $ (FFmul02 in3_byte ^^^ FFmul03 in0_byte ^^^ in1_byte ^^^ in2_byte) + let val3 := BitVec.cast h $ FFmul02 in3_byte ^^^ FFmul03 in0_byte ^^^ in1_byte ^^^ in2_byte let out3 := BitVec.partInstall hi lo val3 out3 - have _ : 3 - c < 4 - c := by omega - MixColumns_aux (c + 1) in0 in1 in2 in3 out0 out1 out2 out3 FFmul02 FFmul03 - termination_by (4 - c) + MixColumns_aux c' in0 in1 in2 in3 out0 out1 out2 out3 FFmul02 FFmul03 def MixColumns (op : BitVec 128) (FFmul02 : BitVec 8 -> BitVec 8) (FFmul03 : BitVec 8 -> BitVec 8) : BitVec 128 := let in0 := - extractLsb 103 96 op ++ extractLsb 71 64 op ++ - extractLsb 39 32 op ++ extractLsb 7 0 op + extractLsb' 96 8 op ++ extractLsb' 64 8 op ++ + extractLsb' 32 8 op ++ extractLsb' 0 8 op let in1 := - extractLsb 111 104 op ++ extractLsb 79 72 op ++ - extractLsb 47 40 op ++ extractLsb 15 8 op + extractLsb' 104 8 op ++ extractLsb' 72 8 op ++ + extractLsb' 40 8 op ++ extractLsb' 8 8 op let in2 := - extractLsb 119 112 op ++ extractLsb 87 80 op ++ - extractLsb 55 48 op ++ extractLsb 23 16 op + extractLsb' 112 8 op ++ extractLsb' 80 8 op ++ + extractLsb' 48 8 op ++ extractLsb' 16 8 op let in3 := - extractLsb 127 120 op ++ extractLsb 95 88 op ++ - extractLsb 63 56 op ++ extractLsb 31 24 op + extractLsb' 120 8 op ++ extractLsb' 88 8 op ++ + extractLsb' 56 8 op ++ extractLsb' 24 8 op let (out0, out1, out2, out3) := (BitVec.zero 32, BitVec.zero 32, BitVec.zero 32, BitVec.zero 32) let (out0, out1, out2, out3) := - MixColumns_aux 0 in0 in1 in2 in3 out0 out1 out2 out3 FFmul02 FFmul03 - extractLsb 31 24 out3 ++ extractLsb 31 24 out2 ++ - extractLsb 31 24 out1 ++ extractLsb 31 24 out0 ++ - extractLsb 23 16 out3 ++ extractLsb 23 16 out2 ++ - extractLsb 23 16 out1 ++ extractLsb 23 16 out0 ++ - extractLsb 15 8 out3 ++ extractLsb 15 8 out2 ++ - extractLsb 15 8 out1 ++ extractLsb 15 8 out0 ++ - extractLsb 7 0 out3 ++ extractLsb 7 0 out2 ++ - extractLsb 7 0 out1 ++ extractLsb 7 0 out0 + MixColumns_aux 4 in0 in1 in2 in3 out0 out1 out2 out3 FFmul02 FFmul03 + extractLsb' 24 8 out3 ++ extractLsb' 24 8 out2 ++ + extractLsb' 24 8 out1 ++ extractLsb' 24 8 out0 ++ + extractLsb' 16 8 out3 ++ extractLsb' 16 8 out2 ++ + extractLsb' 16 8 out1 ++ extractLsb' 16 8 out0 ++ + extractLsb' 8 8 out3 ++ extractLsb' 8 8 out2 ++ + extractLsb' 8 8 out1 ++ extractLsb' 8 8 out0 ++ + extractLsb' 0 8 out3 ++ extractLsb' 0 8 out2 ++ + extractLsb' 0 8 out1 ++ extractLsb' 0 8 out0 end AESCommon diff --git a/Specs/AESV8.lean b/Specs/AESV8.lean index 75ecbf5d..756acae8 100644 --- a/Specs/AESV8.lean +++ b/Specs/AESV8.lean @@ -113,8 +113,7 @@ def AESHWCtr32EncryptBlocks_helper {Param : AESArm.KBR} (in_blocks : BitVec m) let lo := m - (i + 1) * 128 let hi := lo + 127 have h5 : hi - lo + 1 = 128 := by omega - let curr_block : BitVec 128 := - BitVec.cast h5 $ BitVec.extractLsb hi lo in_blocks + let curr_block : BitVec 128 := BitVec.extractLsb' lo 128 in_blocks have h4 : 128 = Param.block_size := by cases h3 · rename_i h; simp only [h, AESArm.AES128KBR, AESArm.BlockSize] diff --git a/Specs/GCM.lean b/Specs/GCM.lean index 04f50184..b7496b62 100644 --- a/Specs/GCM.lean +++ b/Specs/GCM.lean @@ -22,16 +22,11 @@ def R : (BitVec 128) := 0b11100001#8 ++ 0b0#120 abbrev Cipher {n : Nat} {m : Nat} := BitVec n → BitVec m → BitVec n /-- The s-bit incrementing function -/ -def inc_s (s : Nat) (X : BitVec l) (H₀ : 0 < s) (H₁ : s < l) : BitVec l := - let msb_hi := l - 1 - let msb_lo := s - let lsb_hi := s - 1 - let lsb_lo := 0 - have h₁ : lsb_hi - lsb_lo + 1 = s := by omega - let upper := extractLsb msb_hi msb_lo X - let lower := BitVec.cast h₁ (extractLsb lsb_hi lsb_lo X) + 0b1#s - have h₂ : msb_hi - msb_lo + 1 + s = l := by omega - BitVec.cast h₂ (upper ++ lower) +def inc_s (s : Nat) (X : BitVec l) (H₀ : s < l) : BitVec l := + let upper := extractLsb' s (l - s) X + let lower := (extractLsb' 0 s X) + 0b1#s + have h : l - s + s = l := by omega + (upper ++ lower).cast h def mul_aux (i : Nat) (X : BitVec 128) (Z : BitVec 128) (V : BitVec 128) : BitVec 128 := @@ -55,10 +50,8 @@ def GHASH_aux (i : Nat) (H : BitVec 128) (X : BitVec n) (Y : BitVec 128) Y else let lo := (n/128 - 1 - i) * 128 - let hi := lo + 127 - have h₀ : hi - lo + 1 = 128 := by omega - let Xi := extractLsb hi lo X - let res := Y ^^^ (BitVec.cast h₀ Xi) + let Xi := extractLsb' lo 128 X + let res := Y ^^^ Xi let Y := mul res H GHASH_aux (i + 1) H X Y h termination_by (n / 128 - i) @@ -75,11 +68,11 @@ def GCTR_aux (CIPH : Cipher (n := 128) (m := m)) else let lo := (n - i - 1) * 128 let hi := lo + 127 - have h : hi - lo + 1 = 128 := by omega - let Xi := extractLsb hi lo X - let Yi := BitVec.cast h Xi ^^^ CIPH ICB K - let Y := BitVec.partInstall hi lo (BitVec.cast h.symm Yi) Y - let ICB := inc_s 32 ICB (by omega) (by omega) + have h : 128 = hi - lo + 1 := by omega + let Xi := extractLsb' lo 128 X + let Yi := Xi ^^^ CIPH ICB K + let Y := BitVec.partInstall hi lo (BitVec.cast h Yi) Y + let ICB := inc_s 32 ICB (by omega) GCTR_aux CIPH (i + 1) n K ICB X Y termination_by (n - i) @@ -137,7 +130,7 @@ def GCM_AE (CIPH : Cipher (n := 128) (m := m)) : (BitVec p) × (BitVec t) := let H := CIPH (BitVec.zero 128) K let J0 : BitVec 128 := GCM.initialize_J0 H IV - let ICB := inc_s 32 J0 (by decide) (by decide) + let ICB := inc_s 32 J0 (by decide) let C := GCTR (m := m) CIPH K ICB P let u := GCM.ceiling_in_bits p - p let v := GCM.ceiling_in_bits a - a @@ -172,7 +165,7 @@ def GCM_AD (CIPH : Cipher (n := 128) (m := m)) else let H := CIPH (BitVec.zero 128) K let J0 := GCM.initialize_J0 H IV - let ICB := inc_s 32 J0 (by decide) (by decide) + let ICB := inc_s 32 J0 (by decide) let P := GCTR (m := m) CIPH K ICB C let u := GCM.ceiling_in_bits c - c let v := GCM.ceiling_in_bits a - a diff --git a/Specs/GCMV8.lean b/Specs/GCMV8.lean index 7483e332..8b51ae15 100644 --- a/Specs/GCMV8.lean +++ b/Specs/GCMV8.lean @@ -21,10 +21,10 @@ open BitVec ------------------------------------------------------------------------------ def hi (x : BitVec 128) : BitVec 64 := - extractLsb 127 64 x + extractLsb' 64 64 x def lo (x : BitVec 128) : BitVec 64 := - extractLsb 63 0 x + extractLsb' 0 64 x ------------------------------------------------------------------------------ -- Functions related to galois field operations diff --git a/Tactics/Aggregate.lean b/Tactics/Aggregate.lean index cdb2f6ca..19afad66 100644 --- a/Tactics/Aggregate.lean +++ b/Tactics/Aggregate.lean @@ -6,6 +6,7 @@ Author(s): Alex Keizer, Siddharth Bhat import Lean import Tactics.Common import Tactics.Simp +import Tactics.Sym.LCtxSearch open Lean Meta Elab.Tactic @@ -65,44 +66,46 @@ elab "sym_aggregate" simpConfig?:(config)? loc?:(location)? : tactic => withMain let simpConfig? ← simpConfig?.mapM fun cfg => elabSimpConfig (mkNullNode #[cfg]) (kind := .simp) - let lctx ← getLCtx - -- We keep `expectedRead`/`expectedAlign` as monadic values, - -- so that we get new metavariables for each localdecl we check - let expectedRead : MetaM Expr := do - let fld ← mkFreshExprMVar (mkConst ``StateField) - let state ← mkFreshExprMVar mkArmState - let rhs ← mkFreshExprMVar none - mkEq (mkApp2 (mkConst ``r) fld state) rhs - let expectedReadMem : MetaM Expr := do - let n ← mkFreshExprMVar (mkConst ``Nat) - let addr ← mkFreshExprMVar (mkApp (mkConst ``BitVec) (toExpr 64)) - let mem ← mkFreshExprMVar (mkConst ``Memory) - let rhs ← mkFreshExprMVar none - mkEq (mkApp3 (mkConst ``Memory.read_bytes) n addr mem) rhs - let expectedAlign : MetaM Expr := do - let state ← mkFreshExprMVar mkArmState - return mkApp (mkConst ``CheckSPAlignment) state + /- + We construct `axHyps` by running a `State` monad, which is + initialized with an empty array + -/ + let ((), axHyps) ← StateT.run (s := #[]) <| + searchLCtx <| do + let whenFound := fun decl _ => do + -- Whenever a match is found, we add the corresponding declaration + -- to the `axHyps` array in the monadic state + modify (·.push decl) + return .continu - let axHyps ← - withTraceNode `Tactic.sym (fun _ => pure m!"searching for effect hypotheses") <| - lctx.foldlM (init := #[]) fun axHyps decl => do - forallTelescope decl.type <| fun _ type => do - trace[Tactic.sym] "checking {decl.toExpr} with type {type}" - let expectedRead ← expectedRead - let expectedAlign ← expectedAlign - let expectedReadMem ← expectedReadMem - if ← isDefEq type expectedRead then - trace[Tactic.sym] "{Lean.checkEmoji} match for {expectedRead}" - return axHyps.push decl - else if ← isDefEq type expectedAlign then - trace[Tactic.sym] "{Lean.checkEmoji} match for {expectedAlign}" - return axHyps.push decl - else if ← isDefEq type expectedReadMem then - trace[Tactic.sym] "{Lean.checkEmoji} match for {expectedReadMem}" - return axHyps.push decl - else - trace[Tactic.sym] "{Lean.crossEmoji} no match" - return axHyps + -- `r ?field ?state = ?rhs` + searchLCtxFor (whenFound := whenFound) + /- By matching under binders, this also matches for non-effect + hypotheses, which look like: + `∀ f, f ≠ _ → r f ?state = ?rhs` + -/ + (matchUnderBinders := true) + (expectedType := do + let fld ← mkFreshExprMVar (mkConst ``StateField) + let state ← mkFreshExprMVar mkArmState + let rhs ← mkFreshExprMVar none + return mkEqReadField fld state rhs + ) + -- `Memory.read_bytes ?n ?addr ?mem = ?rhs` + searchLCtxFor (whenFound := whenFound) + (matchUnderBinders := true) + (expectedType := do + let n ← mkFreshExprMVar (mkConst ``Nat) + let addr ← mkFreshExprMVar (mkApp (mkConst ``BitVec) (toExpr 64)) + let mem ← mkFreshExprMVar (mkConst ``Memory) + let rhs ← mkFreshExprMVar none + mkEq (mkApp3 (mkConst ``Memory.read_bytes) n addr mem) rhs + ) + -- `CheckSpAlignment ?state` + searchLCtxFor (whenFound := whenFound) + (expectedType := do + let state ← mkFreshExprMVar mkArmState + return mkApp (mkConst ``CheckSPAlignment) state) let loc := (loc?.map expandLocation).getD (.targets #[] true) aggregate axHyps loc simpConfig? diff --git a/Tactics/Attr.lean b/Tactics/Attr.lean index e6b682e8..37026ee9 100644 --- a/Tactics/Attr.lean +++ b/Tactics/Attr.lean @@ -9,8 +9,11 @@ open Lean initialize -- CSE tactic's non-verbose summary logging. registerTraceClass `Tactic.cse.summary + -- enable tracing for `sym_n` tactic and related components registerTraceClass `Tactic.sym + -- enable verbose tracing + registerTraceClass `Tactic.sym.debug -- enable tracing for heartbeat usage of `sym_n` registerTraceClass `Tactic.sym.heartbeats diff --git a/Tactics/Common.lean b/Tactics/Common.lean index ac1432ee..50212ce9 100644 --- a/Tactics/Common.lean +++ b/Tactics/Common.lean @@ -150,7 +150,7 @@ def reflectPFLag (e : Expr) : MetaM PFlag := /-- Reflect a concrete `StateField` -/ def reflectStateField (e : Expr) : MetaM StateField := - match_expr e with + match_expr e.consumeMData with | StateField.GPR x => StateField.GPR <$> reflectBitVecLiteral _ x | StateField.SFP x => StateField.SFP <$> reflectBitVecLiteral _ x | StateField.PC => pure StateField.PC @@ -242,8 +242,7 @@ def findProgramHyp (state : Expr) : MetaM (LocalDecl × Name) := do -- Assert that `program` is a(n application of a) constant, and find its name let program := (← instantiateMVars program).getAppFn let .const program _ := program - | -- withErrorContext h_run h_run_type <| - throwError "Expected a constant, found:\n\t{program}" + | throwError "Expected a constant, found:\n\t{program}" return ⟨h_program, program⟩ @@ -260,6 +259,24 @@ def mkEqArmState (x y : Expr) : Expr := def mkEqReflArmState (x : Expr) : Expr := mkApp2 (.const ``Eq.refl [1]) mkArmState x +/-- Return `x = y` given expressions `x, y : StateField ` -/ +def mkEqStateValue (field x y : Expr) : Expr := + let ty := mkApp (mkConst ``state_value) field + mkApp3 (.const ``Eq [1]) ty x y + +/-- Return `r = ` -/ +def mkEqReadField (field state value : Expr) : Expr := + let r := mkApp2 (mkConst ``r) field state + mkEqStateValue field r value + +/-- If expression `e` is `r ?field ?state = ?value`, return +`some (field, state, value)`, else return `none` -/ +def Lean.Expr.eqReadField? (e : Expr) : Option (Expr × Expr × Expr) := do + let (_ty, lhs, value) ← e.eq? + let_expr r field state := lhs + | none + some (field, state, value) + /-! ## Tracing helpers -/ def traceHeartbeats (cls : Name) (header : Option String := none) : @@ -278,3 +295,6 @@ variable {m} [Monad m] [MonadLiftT TacticM m] [MonadControlT MetaM m] in Unlike the standard `withMainContext`, `x` may live in a generic monad `m`. -/ def withMainContext' (x : m α) : m α := do (← getMainGoal).withContext x + +/-- An emoji to show that a tactic is processing at an intermediate step. -/ +def processingEmoji : String := "⚙️" diff --git a/Tactics/StepThms.lean b/Tactics/StepThms.lean index 53dd2e56..2f034839 100644 --- a/Tactics/StepThms.lean +++ b/Tactics/StepThms.lean @@ -204,6 +204,7 @@ def genStepEqTheorems : StepThmsM Unit := do name, type, value, levelParams := [] } + trace[gen_step.print_names] "[genStepEqTheorems] Theorem added: {name}" trace[gen_step.debug.timing] "[genStepEqTheorems] added to environment in: {(← IO.monoMsNow) - startTime}ms" /-- `#genProgramInfo program` ensures the `ProgramInfo` for `program` diff --git a/Tactics/Sym.lean b/Tactics/Sym.lean index 6c972b6e..a476f0d5 100644 --- a/Tactics/Sym.lean +++ b/Tactics/Sym.lean @@ -11,15 +11,16 @@ import Tactics.Sym.Context import Lean open BitVec -open Lean Meta -open Lean.Elab.Tactic +open Lean +open Lean.Meta Lean.Elab.Tactic open AxEffects SymContext +open Sym (withTraceNode withVerboseTraceNode) /-- A wrapper around `evalTactic` that traces the passed tactic script, executes those tactics, and then traces the new goal state -/ private def evalTacticAndTrace (tactic : TSyntax `tactic) : TacticM Unit := - withTraceNode `Tactic.sym (fun _ => pure m!"running: {tactic}") <| do + withTraceNode m!"running: {tactic}" <| do evalTactic tactic trace[Tactic.sym] "new goal state:\n{← getGoals}" @@ -50,7 +51,8 @@ to add a new local hypothesis in terms of `w` and `write_mem` `h_step : ?s' = w _ _ (w _ _ (... ?s))` -/ def stepiTac (stepiEq : Expr) (hStep : Name) : SymReaderM Unit := fun ctx => - withMainContext' do + withMainContext' <| + withVerboseTraceNode m!"stepiTac: {stepiEq}" <| do let pc := (Nat.toDigits 16 ctx.pc.toNat).asString -- ^^ The PC in hex let stepLemma := Name.str ctx.program s!"stepi_eq_0x{pc}" @@ -94,8 +96,7 @@ Finally, we use this proof to change the type of `hRun` accordingly. -/ def unfoldRun (whileTac : Unit → TacticM Unit) : SymM Expr := do let c ← readThe SymContext - let msg := m!"unfoldRun (runSteps? := {c.runSteps?})" - withTraceNode `Tactic.sym (fun _ => pure msg) <| + Sym.withTraceNode m!"unfoldRun (runSteps? := {c.runSteps?})" (tag := "unfoldRun") <| match c.runSteps? with | some (n + 1) => do trace[Tactic.sym] "runSteps is statically known to be non-zero, \ @@ -125,9 +126,9 @@ def unfoldRun (whileTac : Unit → TacticM Unit) : SymM Expr := do runStepsEq.setType <| -- `?runSteps = ?runStepsPred + 1` mkApp3 (.const ``Eq [1]) (mkConst ``Nat) runSteps subGoalTyRhs - let msg := m!"runSteps is not statically known, so attempt to prove:\ - {runStepsEq}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| runStepsEq.withContext <| do + Sym.withTraceNode m!"runSteps is not statically known, so attempt to prove:\ + {runStepsEq}" <| + runStepsEq.withContext <| do setGoals [runStepsEq] whileTac () -- run `whileTac` to attempt to close `subGoal` @@ -168,7 +169,8 @@ In that order, it also modifies `hRun` to be of type: ` = hRun _ sn` -/ def initNextStep (whileTac : TSyntax `tactic) : SymM (Expr × Expr) := - withMainContext' do + withMainContext' <| + withTraceNode "initNextStep" (tag := "initNextStep") <| do let goal ← getMainGoal -- Add next state to local context @@ -208,9 +210,10 @@ add the relevant hypotheses to the local context, and store an `AxEffects` object with the newly added variables in the monad state -/ def explodeStep (hStep : Expr) : SymM Unit := - withMainContext' do + withMainContext' <| + withTraceNode m!"explodeStep {hStep}" (tag := "explodeStep") <| do let c ← getThe SymContext - let mut eff ← AxEffects.fromEq hStep + let mut eff ← AxEffects.fromEq hStep c.effects.stackAlignmentProof? let stateExpr ← getCurrentState /- Assert that the initial state of the obtained `AxEffects` is equal to @@ -224,47 +227,26 @@ def explodeStep (hStep : Expr) : SymM Unit := eff ← eff.withProgramEq c.effects.programProof eff ← eff.withField (← c.effects.getField .ERR).proof - if let some h_sp := c.h_sp? then - let hSp ← SymContext.findFromUserName h_sp - -- let effWithSp? - eff ← match ← eff.withStackAlignment? hSp.toExpr with - | some newEff => pure newEff - | none => do - trace[Tactic.sym] "failed to show stack alignment" - -- FIXME: in future, we'd like to detect when the `sp_aligned` - -- hypothesis is actually necessary, and add the proof obligation - -- on-demand. For now, however, we over-approximate, and say that - -- if the original state was known to be aligned, and something - -- writes to the SP, then we eagerly add the obligation to proof - -- that the result is aligned as well. - -- If you don't want this obligation, simply remove the hypothesis - -- that the original state is aligned - let spEff ← eff.getField .SP - let subGoal ← mkFreshMVarId - -- subGoal.setTag <| - let hAligned ← do - let name := Name.mkSimple s!"h_{← getNextStateName}_sp_aligned" - mkFreshExprMVarWithId subGoal (userName := name) <| - mkAppN (mkConst ``Aligned) #[toExpr 64, spEff.value, toExpr 4] - - trace[Tactic.sym] "created subgoal to show alignment:\n{subGoal}" - let subGoal? ← do - let (ctx, simprocs) ← - LNSymSimpContext - (config := {failIfUnchanged := false, decide := true}) - (decls := #[hSp]) - LNSymSimp subGoal ctx simprocs - - if let some subGoal := subGoal? then - trace[Tactic.sym] "subgoal got simplified to:\n{subGoal}" - appendGoals [subGoal] - else - trace[Tactic.sym] "subgoal got closed by simplification" - - let stackAlignmentProof? := some <| - mkAppN (mkConst ``CheckSPAlignment_of_r_sp_aligned) - #[eff.currentState, spEff.value, spEff.proof, hAligned] - pure { eff with stackAlignmentProof? } + if let some hSp := c.effects.stackAlignmentProof? then + withVerboseTraceNode m!"discharging side condiitions" <| do + for subGoal in eff.sideConditions do + trace[Tactic.sym] "attempting to discharge side-condition:\n {subGoal}" + let subGoal? ← do + let (ctx, simprocs) ← + LNSymSimpContext + (config := {failIfUnchanged := false, decide := true}) + (exprs := #[hSp]) + LNSymSimp subGoal ctx simprocs + + if let some subGoal := subGoal? then + trace[Tactic.sym] "subgoal got simplified to:\n{subGoal}" + subGoal.setTag (.mkSimple s!"h_{← getNextStateName}_sp_aligned") + appendGoals [subGoal] + else + trace[Tactic.sym] "subgoal got closed by simplification" + else + appendGoals eff.sideConditions + eff := { eff with sideConditions := [] } -- Add new (non-)effect hyps to the context, and to the aggregation simpset withMainContext' <| do @@ -286,21 +268,24 @@ elab "explode_step" h_step:term " at " state:term : tactic => withMainContext do let .fvar stateFVar := state | throwError "Expected fvar, found {state}" let stateDecl := (← getLCtx).get! stateFVar - let c ← SymContext.fromLocalContext (some stateDecl.userName) + let c ← SymContext.fromMainContext (some stateDecl.userName) let _ ← SymM.run c <| explodeStep hStep /-- Symbolically simulate a single step, according the the symbolic simulation context `c`, returning the context for the next step in simulation. -/ def sym1 (whileTac : TSyntax `tactic) : SymM Unit := do + /- `withCurHeartbeats` sets the initial heartbeats to the current heartbeats, + effectively resetting our heartbeat budget back to the maximum. -/ + withCurrHeartbeats <| do + let stateNumber ← getCurrentStateNumber - let msg := m!"(sym1): simulating step {stateNumber}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext' do - withTraceNode `Tactic.sym (fun _ => pure "verbose context") <| do + withTraceNode m!"(sym1): simulating step {stateNumber}" (tag:="sym1") <| + withMainContext' do + withVerboseTraceNode "verbose context" (tag := "infoDump") <| do traceSymContext trace[Tactic.sym] "Goal state:\n {← getMainGoal}" - let (_sn, stepiEq) ← initNextStep whileTac -- Apply relevant pre-generated `stepi` lemma @@ -313,23 +298,18 @@ def sym1 (whileTac : TSyntax `tactic) : SymM Unit := do -- `simp` here withMainContext' <| do let hStep ← SymContext.findFromUserName h_step.getId - let lctx ← getLCtx - let decls := (← getThe SymContext).h_sp?.bind lctx.findFromUserName? - let decls := decls.toArray - -- If we know SP is aligned, `simp` with that fact - if !decls.isEmpty then - trace[Tactic.sym] "simplifying {hStep.toExpr} \ - with {decls.map (·.toExpr)}" - -- If `decls` is empty, we have no more knowledge than before, so - -- everything that could've been `simp`ed, already should have been - let some goal ← do - let (ctx, simprocs) ← LNSymSimpContext - (config := {decide := false}) (decls := decls) - let goal ← getMainGoal - LNSymSimp goal ctx simprocs hStep.fvarId - | throwError "internal error: simp closed goal unexpectedly" - replaceMainGoal [goal] + -- If we know SP is aligned, `simp` with that fact + if let some hSp := (← getThe AxEffects).stackAlignmentProof? then + let msg := m!"simplifying {hStep.toExpr} with {hSp}" + withTraceNode msg (tag := "simplifyHStep") <| do + let some goal ← do + let (ctx, simprocs) ← LNSymSimpContext + (config := {decide := false}) (exprs := #[hSp]) + let goal ← getMainGoal + LNSymSimp goal ctx simprocs hStep.fvarId + | throwError "internal error: simp closed goal unexpectedly" + replaceMainGoal [goal] else trace[Tactic.sym] "we have no relevant local hypotheses, \ skipping simplification step" @@ -352,44 +332,46 @@ def sym1 (whileTac : TSyntax `tactic) : SymM Unit := do - log a warning and return `m`, if `runSteps? = some m` and `m < n`, or - return `n` unchanged, otherwise -/ def ensureAtMostRunSteps (n : Nat) : SymM Nat := do - let ctx ← getThe SymContext - match ctx.runSteps? with - | none => pure n - | some runSteps => - if n ≤ runSteps then - pure n - else - withMainContext <| do - let hRun := ctx.hRun - logWarning m!"Symbolic simulation is limited to at most {runSteps} \ - steps, because {hRun} is of type:\n {← inferType hRun}" - pure runSteps - return n + withVerboseTraceNode "" (tag := "ensureAtMostRunSteps") <| do + let ctx ← getThe SymContext + match ctx.runSteps? with + | none => pure n + | some runSteps => + if n ≤ runSteps then + pure n + else + withMainContext <| do + let hRun := ctx.hRun + logWarning m!"Symbolic simulation is limited to at most {runSteps} \ + steps, because {hRun} is of type:\n {← inferType hRun}" + pure runSteps + return n /-- Check that the step-thoerem corresponding to the current PC value exists, and throw a user-friendly error, pointing to `#genStepEqTheorems`, if it does not. -/ -def assertStepTheoremsGenerated : SymM Unit := do - let c ← getThe SymContext - let pc := c.pc.toHexWithoutLeadingZeroes - if !c.programInfo.instructions.contains c.pc then - let pcEff ← AxEffects.getFieldM .PC - throwError "\ - Program {c.program} has no instruction at address {c.pc}. - - We inferred this address as the program-counter from {pcEff.proof}, \ - which has type: - {← inferType pcEff.proof}" - - let step_thm := Name.str c.program ("stepi_eq_0x" ++ pc) - try - let _ ← getConstInfo step_thm - catch err => - throwErrorAt err.getRef "{err.toMessageData}\n -Did you remember to generate step theorems with: - #genStepEqTheorems {c.program}" --- TODO: can we make this error ^^ into a `Try this:` suggestion that --- automatically adds the right command just before the theorem? +def assertStepTheoremsGenerated : SymM Unit := + withVerboseTraceNode "" (tag := "assertStepTheoremsGenerated") <| do + let c ← getThe SymContext + let pc := c.pc.toHexWithoutLeadingZeroes + if !c.programInfo.instructions.contains c.pc then + let pcEff ← AxEffects.getFieldM .PC + throwError "\ + Program {c.program} has no instruction at address {c.pc}. + + We inferred this address as the program-counter from {pcEff.proof}, \ + which has type: + {← inferType pcEff.proof}" + + let step_thm := Name.str c.program ("stepi_eq_0x" ++ pc) + try + let _ ← getConstInfo step_thm + catch err => + throwErrorAt err.getRef "{err.toMessageData}\n + Did you remember to generate step theorems with: + #genStepEqTheorems {c.program}" + -- TODO: can we make this error ^^ into a `Try this:` suggestion that + -- automatically adds the right command just before the theorem? /- used in `sym_n` tactic to specify an initial state -/ syntax sym_at := "at" ident @@ -437,11 +419,8 @@ elab "sym_n" whileTac?:(sym_while)? n:num s:(sym_at)? : tactic => do omega; )) - let c ← withMainContext <| SymContext.fromLocalContext s - SymM.run' c <| do - -- Context preparation - canonicalizeHypothesisTypes - + let c ← SymContext.fromMainContext s + SymM.run' c <| withMainContext' <| do -- Check pre-conditions assertStepTheoremsGenerated let n ← ensureAtMostRunSteps n.getNat @@ -452,36 +431,32 @@ elab "sym_n" whileTac?:(sym_while)? n:num s:(sym_at)? : tactic => do sym1 whileTac traceHeartbeats "symbolic simulation total" - let c ← getThe SymContext - -- Check if we can substitute the final state - if c.runSteps? = some 0 then - let msg := pure m!"runSteps := 0, substituting along {c.hRun}" - withMainContext' <| withTraceNode `Tactic.sym (fun _ => msg) <| do - let sfEq ← mkEq (← getCurrentState) c.finalState - - let goal ← getMainGoal - trace[Tactic.sym] "original goal:\n{goal}" - let ⟨hEqId, goal⟩ ← do - goal.note `this (← mkEqSymm c.hRun) sfEq - goal.withContext <| do - trace[Tactic.sym] "added {← userNameToMessageData `this} of type \ - {sfEq} in:\n{goal}" - - let goal ← subst goal hEqId - trace[Tactic.sym] "performed subsitutition in:\n{goal}" - - replaceMainGoal [goal] - else -- Replace `h_run` in the local context - let goal ← getMainGoal - let res ← goal.replace c.hRunId c.hRun - replaceMainGoal [res.mvarId] - - - -- Rudimentary aggregation: we feed all the axiomatic effect hypotheses - -- added while symbolically evaluating to `simp` - let msg := m!"aggregating (non-)effects" - withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext' do - let goal? ← LNSymSimp (← getMainGoal) c.aggregateSimpCtx c.aggregateSimprocs - replaceMainGoal goal?.toList - - traceHeartbeats "final usage" + withCurrHeartbeats <| + Sym.withTraceNode "Post processing" (tag := "postProccessing") <| do + let c ← getThe SymContext + -- Check if we can substitute the final state + if c.runSteps? = some 0 then + withMainContext' <| + Sym.withTraceNode m!"runSteps := 0, substituting along {c.hRun}" <| do + let sfEq ← mkEq (← getCurrentState) c.finalState + + let goal ← getMainGoal + trace[Tactic.sym] "original goal:\n{goal}" + let ⟨hEqId, goal⟩ ← do + goal.note `this (← mkEqSymm c.hRun) sfEq + goal.withContext <| do + trace[Tactic.sym] "added {← userNameToMessageData `this} of type \ + {sfEq} in:\n{goal}" + + let goal ← subst goal hEqId + trace[Tactic.sym] "performed subsitutition in:\n{goal}" + replaceMainGoal [goal] + + -- Rudimentary aggregation: we feed all the axiomatic effect hypotheses + -- added while symbolically evaluating to `simp` + withMainContext' <| + withTraceNode m!"aggregating (non-)effects" (tag := "aggregateEffects") <| do + let goal? ← LNSymSimp (← getMainGoal) c.aggregateSimpCtx c.aggregateSimprocs + replaceMainGoal goal?.toList + + traceHeartbeats "aggregation" diff --git a/Tactics/Sym/AxEffects.lean b/Tactics/Sym/AxEffects.lean index 2b9a7dcc..36e71f2a 100644 --- a/Tactics/Sym/AxEffects.lean +++ b/Tactics/Sym/AxEffects.lean @@ -8,10 +8,12 @@ import Arm.State import Tactics.Common import Tactics.Attr import Tactics.Simp +import Tactics.Sym.Common import Std.Data.HashMap open Lean Meta +open Sym (withTraceNode withVerboseTraceNode) /-- A reflected `ArmState` field, see `AxEffects.fields` for more context -/ structure AxEffects.FieldEffect where @@ -78,6 +80,10 @@ structure AxEffects where However, if SP is written to, no effort is made to prove alignment of the new value; the field will be set to `none` -/ stackAlignmentProof? : Option Expr + + /-- `sideContitions` are proof obligations that come up during effect + characterization. -/ + sideConditions : List MVarId deriving Repr namespace AxEffects @@ -141,6 +147,7 @@ def initial (state : Expr) : AxEffects where mkConst ``Program, mkApp (mkConst ``ArmState.program) state] stackAlignmentProof? := none + sideConditions := [] /-! ## ToMessageData -/ @@ -165,9 +172,8 @@ instance : ToMessageData AxEffects where }" private def traceCurrentState (eff : AxEffects) - (header : MessageData := "current state") : - MetaM Unit := - withTraceNode `Tactic.sym (fun _ => pure header) do + (header : MessageData := "current state") : MetaM Unit := + withTraceNode header <| do trace[Tactic.sym] "{eff}" /-! ## Helpers -/ @@ -199,7 +205,7 @@ private def rewriteType (e eq : Expr) : MetaM Expr := do by constructing an application of `eff.nonEffectProof` -/ partial def mkAppNonEffect (eff : AxEffects) (field : Expr) : MetaM Expr := do let msg := m!"constructing application of non-effects proof" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do + withTraceNode msg (tag := "mkAppNonEffect") <| do trace[Tactic.sym] "nonEffectProof: {eff.nonEffectProof}" let nonEffectProof := mkApp eff.nonEffectProof field @@ -218,8 +224,7 @@ partial def mkAppNonEffect (eff : AxEffects) (field : Expr) : MetaM Expr := do /-- Get the value for a field, if one is stored in `eff.fields`, or assemble an instantiation of the non-effects proof otherwise -/ def getField (eff : AxEffects) (fld : StateField) : MetaM FieldEffect := - let msg := m!"getField {fld}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do + withTraceNode m!"getField {fld}" (tag := "getField") <| do eff.traceCurrentState if let some val := eff.fields.get? fld then @@ -231,11 +236,36 @@ def getField (eff : AxEffects) (fld : StateField) : MetaM FieldEffect := let proof ← eff.mkAppNonEffect (toExpr fld) pure { value, proof } -variable {m} [Monad m] [MonadReaderOf AxEffects m] [MonadLiftT MetaM m] in +section Monad +variable {m} [Monad m] [MonadLiftT MetaM m] + +variable [MonadReaderOf AxEffects m] in @[inherit_doc getField] def getFieldM (field : StateField) : m FieldEffect := do (← read).getField field +variable [MonadStateOf AxEffects m] + +/-- Set the effect of a specific field in the monad state, overwriting any +previous value for that field. + +NOTE: the proof in `effect` is assumed to be valid for the current state, +this is not eagerly checked (but the kernel will of course eventually reject +a proof if it used a malformed field-effect; a mallformed proof does not +compromise soundness, but it will cause obscure errors) -/ +def setFieldEffect (field : StateField) (effect : FieldEffect) : m Unit := + modify fun eff => { eff with + fields := eff.fields.insert field effect } + +/-- Given a proof that `r .ERR = None`, set the effect of the +`ERR` field accordingly. + +This is a specialization of `setFieldEffect`. -/ +def setErrorProof (proof : Expr) : m Unit := + setFieldEffect .ERR { value := mkConst ``StateError.None, proof } + +end Monad + /-! ## Update a Reflected State -/ /-- Execute `write_mem ` against the state stored in `eff` @@ -245,9 +275,8 @@ and all other fields are updated accordingly. Note that no effort is made to preserve `currentStateEq`; it is set to `none`! -/ private def update_write_mem (eff : AxEffects) (n addr val : Expr) : - MetaM AxEffects := do - trace[Tactic.sym] "adding write of {n} bytes of value {val} \ - to memory address {addr}" + MetaM AxEffects := + withTraceNode m!"processing: write_mem {n} {addr} {val} …" (tag := "updateWriteMem") <| do -- Update each field let fields ← eff.fields.toList.mapM fun ⟨fld, {value, proof}⟩ => do @@ -279,6 +308,11 @@ private def update_write_mem (eff : AxEffects) (n addr val : Expr) : #[eff.currentState, n, addr, val]) eff.programProof + -- Update the stack alignment proof + let stackAlignmentProof? := eff.stackAlignmentProof?.map fun proof => + mkAppN (mkConst ``CheckSPAligment_write_mem_bytes_of) + #[eff.currentState, n, addr, val, proof] + -- Assemble the result let addWrite (e : Expr) := -- `@write_mem_bytes ` @@ -290,9 +324,9 @@ private def update_write_mem (eff : AxEffects) (n addr val : Expr) : memoryEffect := addWrite eff.memoryEffect memoryEffectProof programProof + stackAlignmentProof? } - withTraceNode `Tactic.sym (fun _ => pure "new state") <| do - trace[Tactic.sym] "{eff}" + eff.traceCurrentState return eff /-- Execute `w ` against the state stored in `eff` @@ -303,8 +337,8 @@ Note that no effort is made to preserve `currentStateEq`; it is set to `none`! -/ private def update_w (eff : AxEffects) (fld val : Expr) : MetaM AxEffects := do + withTraceNode m!"processing: w {fld} {val} …" (tag := "updateWrite") <| do let rField ← reflectStateField fld - trace[Tactic.sym] "adding write of value {val} to register {rField}" -- Update all other fields let fields ← @@ -373,6 +407,24 @@ private def update_w (eff : AxEffects) (fld val : Expr) : (mkAppN (mkConst ``w_program) #[fld, val, eff.currentState]) eff.programProof + -- Update the stack alignment proof + let mut sideConditions := eff.sideConditions + let mut stackAlignmentProof? := eff.stackAlignmentProof? + if let some proof := stackAlignmentProof? then + if rField ≠ StateField.SP then + let hNeq ← mkDecideProof <| + mkApp3 (.const ``Ne [1]) + (mkConst ``StateField) (toExpr StateField.SP) fld + stackAlignmentProof? := mkAppN (mkConst ``CheckSPAligment_w_of_ne_sp_of) + #[fld, eff.currentState, val, hNeq, proof] + else + let hAligned ← mkFreshExprMVar (some <| + mkApp3 (mkConst ``Aligned) (toExpr 64) val (toExpr 4) + ) + sideConditions := hAligned.mvarId! :: sideConditions + stackAlignmentProof? := mkAppN (mkConst ``CheckSPAligment_w_sp_of) + #[val, eff.currentState, hAligned] + -- Assemble the result let eff := { eff with currentState := mkApp3 (mkConst ``w) fld val eff.currentState @@ -381,6 +433,8 @@ private def update_w (eff : AxEffects) (fld val : Expr) : -- memory effects are unchanged memoryEffectProof programProof + stackAlignmentProof? + sideConditions } eff.traceCurrentState "new state" return eff @@ -398,20 +452,27 @@ private def assertIsDefEq (e expected : Expr) : MetaM Unit := do /-- Given an expression `e : ArmState`, which is a sequence of `w`/`write_mem`s to `eff.currentState`, return an `AxEffects` where `e` is the new `currentState`. -/ -partial def updateWithExpr (eff : AxEffects) (e : Expr) : MetaM AxEffects := do - let msg := m!"Updating effects with writes from: {e}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do match_expr e with - | write_mem_bytes n addr val e => - let eff ← eff.updateWithExpr e - eff.update_write_mem n addr val +private partial def updateWithExprAux (eff : AxEffects) (e : Expr) : MetaM AxEffects := do + match_expr e with + | write_mem_bytes n addr val e => + let eff ← eff.updateWithExprAux e + eff.update_write_mem n addr val - | w field value e => - let eff ← eff.updateWithExpr e - eff.update_w field value + | w field value e => + let eff ← eff.updateWithExprAux e + eff.update_w field value - | _ => - assertIsDefEq e eff.currentState - return eff + | _ => + assertIsDefEq e eff.currentState + return eff + +/-- Given an expression `e : ArmState`, +which is a sequence of `w`/`write_mem`s to `eff.currentState`, +return an `AxEffects` where `e` is the new `currentState`. -/ +partial def updateWithExpr (eff : AxEffects) (e : Expr) : MetaM AxEffects := do + let msg := m!"Updating effects with writes from: {e}" + withTraceNode msg (tag := "updateWithExpr") <| + updateWithExprAux eff e /-- Given an expression `e : ArmState`, which is a sequence of `w`/`write_mem`s to the some state `s`, @@ -426,62 +487,69 @@ def fromExpr (e : Expr) : MetaM AxEffects := do let eff ← eff.updateWithExpr e return { eff with initialState := ← instantiateMVars eff.initialState} - /-- Given a proof `eq : s = `, set `s` to be the new `currentState`, and update all proofs accordingly -/ def adjustCurrentStateWithEq (eff : AxEffects) (s eq : Expr) : MetaM AxEffects := do - withTraceNode `Tactic.sym (fun _ => pure "adjusting `currenstState`") do - eff.traceCurrentState + withTraceNode m!"adjustCurrentStateWithEq" (tag := "adjustCurrentStateWithEq") do trace[Tactic.sym] "rewriting along {eq}" + eff.traceCurrentState + assertHasType eq <| mkEqArmState s eff.currentState let eq ← mkEqSymm eq let currentState := s let fields ← eff.fields.toList.mapM fun (field, fieldEff) => do - withTraceNode `Tactic.sym (fun _ => pure m!"rewriting field {field}") do + withTraceNode m!"rewriting field {field}" (tag := "rewriteField") do trace[Tactic.sym] "original proof: {fieldEff.proof}" let proof ← rewriteType fieldEff.proof eq trace[Tactic.sym] "new proof: {proof}" pure (field, {fieldEff with proof}) let fields := .ofList fields - let nonEffectProof ← rewriteType eff.nonEffectProof eq - let memoryEffectProof ← rewriteType eff.memoryEffectProof eq - -- ^^ TODO: what happens if `memoryEffect` is the same as `currentState`? - -- Presumably, we would *not* want to encapsulate `memoryEffect` here - let programProof ← rewriteType eff.programProof eq + withTraceNode m!"rewriting other proofs" (tag := "rewriteMisc") <| do + let nonEffectProof ← rewriteType eff.nonEffectProof eq + let memoryEffectProof ← rewriteType eff.memoryEffectProof eq + -- ^^ TODO: what happens if `memoryEffect` is the same as `currentState`? + -- Presumably, we would *not* want to encapsulate `memoryEffect` here + let programProof ← rewriteType eff.programProof eq + let stackAlignmentProof? ← eff.stackAlignmentProof?.mapM + (rewriteType · eq) - return { eff with - currentState, fields, nonEffectProof, memoryEffectProof, programProof - } + return { eff with + currentState, fields, nonEffectProof, memoryEffectProof, programProof, + stackAlignmentProof? + } /-- Given a proof `eq : ?s = `, where `?s` and `?s0` are arbitrary `ArmState`s, return an `AxEffect` with the rhs of the equality as the current state, and the (non-)effects updated accordingly -/ def updateWithEq (eff : AxEffects) (eq : Expr) : MetaM AxEffects := - let msg := m!"Building effects with equality: {eq}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do + withTraceNode m!"Building effects with equality: {eq}" + (tag := "updateWithEq") <| do let s ← mkFreshExprMVar mkArmState let rhs ← mkFreshExprMVar mkArmState assertHasType eq <| mkEqArmState s rhs let eff ← eff.updateWithExpr (← instantiateMVars rhs) let eff ← eff.adjustCurrentStateWithEq s eq - withTraceNode `Tactic.sym (fun _ => pure "new state") do - trace[Tactic.sym] "{eff}" + eff.traceCurrentState "new state" return eff /-- Given a proof `eq : ?s = `, where `?s` and `?s0` are arbitrary `ArmState`s, return an `AxEffect` with `?s0` as the initial state, the rhs of the equality as the current state, -and the (non-)effects updated accordingly -/ -def fromEq (eq : Expr) : MetaM AxEffects := do +and the (non-)effects updated accordingly + +One can optionally pass in a proof that `?s0` has a well-aligned stack pointer. +-/ +def fromEq (eq : Expr) (stackAlignmentProof? : Option Expr := none) : + MetaM AxEffects := do let s0 ← mkFreshExprMVar mkArmState - let eff := initial s0 + let eff := { initial s0 with stackAlignmentProof? } let eff ← eff.updateWithEq eq return { eff with initialState := ← instantiateMVars eff.initialState} @@ -509,8 +577,7 @@ Note: throws an error when `initialState = currentState` *and* the field already has a value stored, as the rewrite might produce expressions of unexpected types. -/ def withField (eff : AxEffects) (eq : Expr) : MetaM AxEffects := do - let msg := m!"withField {eq}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do + withTraceNode m!"withField {eq}" (tag := "withField") <| do eff.traceCurrentState let fieldE ← mkFreshExprMVar (mkConst ``StateField) let value ← mkFreshExprMVar none @@ -545,34 +612,6 @@ def withField (eff : AxEffects) (eq : Expr) : MetaM AxEffects := do let fields := eff.fields.insert field { value, proof } return { eff with fields } -/-- Given a proof of `CheckSPAlignment `, -attempt to transport it to a proof of `CheckSPAlignment ` -and store that proof in `stackAlignmentProof?`. - -Returns `none` if the proof failed to be transported, -i.e., if SP was written to. -/ -def withStackAlignment? (eff : AxEffects) (spAlignment : Expr) : - MetaM (Option AxEffects) := do - let msg := m!"withInitialStackAlignment? {spAlignment}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do - eff.traceCurrentState - - let { value, proof } ← eff.getField StateField.SP - let expected := - mkApp2 (mkConst ``r) (toExpr <| StateField.SP) eff.initialState - trace[Tactic.sym] "checking whether value:\n {value}\n\ - is syntactically equal to expected value\n {expected}" - if value != expected then - trace[Tactic.sym] "failed to transport proof: - expected value to be {expected}, but found {value}" - return none - - let stackAlignmentProof? := some <| - mkAppN (mkConst ``CheckSPAlignment_of_r_sp_eq) - #[eff.initialState, eff.currentState, proof, spAlignment] - trace[Tactic.sym] "constructed stackAlignmentProof: {stackAlignmentProof?}" - return some { eff with stackAlignmentProof? } - /-! ## Composition -/ /- TODO: write a function that combines two effects `left` and `right`, @@ -596,8 +635,8 @@ NOTE: does not necessarily validate *which* type an expression has, validation will still pass if types are different to those we claim in the docstrings -/ def validate (eff : AxEffects) : MetaM Unit := do - let msg := "validating that the axiomatic effects are well-formed" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do + withTraceNode "validating that the axiomatic effects are well-formed" + (tag := "validate") <| do eff.traceCurrentState assertHasType eff.initialState mkArmState @@ -632,8 +671,8 @@ that was just added to the local context -/ def addHypothesesToLContext (eff : AxEffects) (hypPrefix : String := "h_") (mvar : Option MVarId := none) : TacticM AxEffects := - let msg := m!"adding hypotheses to local context" - withTraceNode `Tactic.sym (fun _ => pure msg) do + withTraceNode m!"adding hypotheses to local context" + (tag := "addHypothesesToLContext") do eff.traceCurrentState let mut goal ← mvar.getDM getMainGoal @@ -696,8 +735,8 @@ where replaceOrNote (goal : MVarId) (h : Name) (v : Expr) (t? : Option Expr := none) : MetaM (FVarId × MVarId) := - let msg := m!"adding {h} to the local context" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do + withTraceNode m!"adding {h} to the local context" + (tag := "replaceOrNote") <| do trace[Tactic.sym] "with value {v} and type {t?}" if let some decl := (← getLCtx).findFromUserName? h then let ⟨fvar, goal, _⟩ ← goal.replace decl.fvarId v t? @@ -709,8 +748,8 @@ where /-- Return an array of `SimpTheorem`s of the proofs contained in the given `AxEffects` -/ def toSimpTheorems (eff : AxEffects) : MetaM (Array SimpTheorem) := do - let msg := m!"computing SimpTheorems for (non-)effect hypotheses" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do + withTraceNode m!"computing SimpTheorems for (non-)effect hypotheses" + (tag := "toSimpTheorems") <| do let lctx ← getLCtx let baseName? := if eff.currentState.isFVar then @@ -722,8 +761,7 @@ def toSimpTheorems (eff : AxEffects) : MetaM (Array SimpTheorem) := do let add (thms : Array SimpTheorem) (e : Expr) (name : String) (prio : Nat := 1000) := - let msg := m!"adding {e} with name {name}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do + withTraceNode m!"adding {e} with name {name}" <| do let origin : Origin := if e.isFVar then .fvar e.fvarId! diff --git a/Tactics/Sym/Common.lean b/Tactics/Sym/Common.lean new file mode 100644 index 00000000..78e7b823 --- /dev/null +++ b/Tactics/Sym/Common.lean @@ -0,0 +1,30 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer +-/ +import Lean + +open Lean + +namespace Sym + +/-! ## Trace Nodes -/ +section Tracing +variable {α : Type} {m : Type → Type} [Monad m] [MonadTrace m] [MonadLiftT IO m] + [MonadRef m] [AddMessageContext m] [MonadOptions m] {ε : Type} + [MonadAlwaysExcept ε m] [MonadLiftT BaseIO m] + +def withTraceNode (msg : MessageData) (k : m α) + (collapsed : Bool := true) + (tag : String := "") + : m α := do + Lean.withTraceNode `Tactic.sym (fun _ => pure msg) k collapsed tag + +def withVerboseTraceNode (msg : MessageData) (k : m α) + (collapsed : Bool := true) + (tag : String := "") + : m α := do + Lean.withTraceNode `Tactic.sym.verbose (fun _ => pure msg) k collapsed tag + +end Tracing diff --git a/Tactics/Sym/Context.lean b/Tactics/Sym/Context.lean index 9eecc325..37d844cf 100644 --- a/Tactics/Sym/Context.lean +++ b/Tactics/Sym/Context.lean @@ -9,10 +9,13 @@ import Lean.Meta import Arm.Exec import Tactics.Common import Tactics.Attr +import Tactics.Sym.Common import Tactics.Sym.ProgramInfo import Tactics.Sym.AxEffects +import Tactics.Sym.LCtxSearch import Tactics.Simp + /-! This files defines the `SymContext` structure, which collects the names of various @@ -31,6 +34,7 @@ and is likely to be deprecated and removed in the near future. -/ open Lean Meta Elab.Tactic open BitVec +open Sym (withTraceNode withVerboseTraceNode) /-- A `SymContext` collects the names of various variables/hypotheses in the local context required for symbolic evaluation -/ @@ -78,9 +82,6 @@ structure SymContext where and we assume that no overflow happens (i.e., `base - x` can never be equal to `base + y`) -/ pc : BitVec 64 - /-- `h_sp?`, if present, is a local hypothesis of the form - `CheckSPAlignment state` -/ - h_sp? : Option Name /-- The `simp` context used for effect aggregation. This collects references to all (non-)effect hypotheses of the intermediate @@ -160,10 +161,11 @@ def program : Name := c.programInfo.name /-- Find the local declaration that corresponds to a given name, or throw an error if no local variable of that name exists -/ -def findFromUserName (name : Name) : MetaM LocalDecl := do - let some decl := (← getLCtx).findFromUserName? name - | throwError "Unknown local variable `{name}`" - return decl +def findFromUserName (name : Name) : MetaM LocalDecl := + withVerboseTraceNode m!"[findFromUserName] {name}" <| do + let some decl := (← getLCtx).findFromUserName? name + | throwError "Unknown local variable `{name}`" + return decl section Monad variable {m} [Monad m] [MonadReaderOf SymContext m] @@ -213,29 +215,36 @@ end /-- Convert a `SymContext` to `MessageData` for tracing. This is not a `ToMessageData` instance because we need access to `MetaM` -/ def toMessageData (c : SymContext) : MetaM MessageData := do - let h_sp? ← c.h_sp?.mapM userNameToMessageData - return m!"\{ finalState := {c.finalState}, runSteps? := {c.runSteps?}, hRun := {c.hRun}, program := {c.program}, pc := {c.pc}, - h_sp? := {h_sp?}, state_prefix := {c.state_prefix}, curr_state_number := {c.currentStateNumber}, effects := {c.effects} }" -variable {α : Type} {m : Type → Type} [Monad m] [MonadTrace m] [MonadLiftT IO m] - [MonadRef m] [AddMessageContext m] [MonadOptions m] {ε : Type} - [MonadAlwaysExcept ε m] [MonadLiftT BaseIO m] in -def withSymTraceNode (msg : MessageData) (k : m α) : m α := do - withTraceNode `Tactic.sym (fun _ => pure msg) k - def traceSymContext : SymM Unit := - withTraceNode `Tactic.sym (fun _ => pure m!"SymContext: ") <| do + withTraceNode m!"SymContext: " <| do let m ← (← getThe SymContext).toMessageData trace[Tactic.sym] m +/-! ## Adding new simp theorems for aggregation -/ + +/-- Add a set of new simp-theorems to the simp-theorems used +for effect aggregation -/ +def addSimpTheorems (c : SymContext) (simpThms : Array SimpTheorem) : SymContext := + let addSimpThms := simpThms.foldl addSimpTheoremEntry + + let oldSimpTheorems := c.aggregateSimpCtx.simpTheorems + let simpTheorems := + if oldSimpTheorems.isEmpty then + oldSimpTheorems.push <| addSimpThms {} + else + oldSimpTheorems.modify (oldSimpTheorems.size - 1) addSimpThms + + { c with aggregateSimpCtx.simpTheorems := simpTheorems } + /-! ## Creating initial contexts -/ /-- Modify a `SymContext` with a monadic action `k : SymM Unit` -/ @@ -243,6 +252,34 @@ def modify (ctxt : SymContext) (k : SymM Unit) : TacticM SymContext := do let ((), ctxt) ← SymM.run ctxt k return ctxt +private def initial (state : Expr) : MetaM SymContext := do + /- Create an mvar for the final state -/ + let finalState ← mkFreshExprMVar mkArmState + /- Get the default simp lemmas & simprocs for aggregation -/ + let (aggregateSimpCtx, aggregateSimprocs) ← + LNSymSimpContext (config := {decide := true, failIfUnchanged := false}) + let aggregateSimpCtx := { aggregateSimpCtx with + -- Create a new discrtree for effect hypotheses to be added to. + -- TODO(@alexkeizer): I put this here, since the previous version kept + -- a seperate discrtree for lemmas. I should run benchmarks to see what + -- happens if we keep everything in one simpset. + simpTheorems := aggregateSimpCtx.simpTheorems.push {} + } + return { + finalState + runSteps? := none + hRun :=← mkFreshExprMVar none + hRunId := ⟨`dummyValue⟩ + programInfo := { + name := `dummyValue + instructions := ∅ + } + pc := 0 + aggregateSimpCtx, + aggregateSimprocs, + effects := AxEffects.initial state + } + /-- Infer `state_prefix` and `curr_state_number` from the `state` name as follows: if `state` is `s{i}` for some number `i` and a single character `s`, then `s` is the prefix and `i` the number, @@ -277,16 +314,140 @@ private def withErrorContext (name : Name) (type? : Option Expr) (k : MetaM α) | none => m!"" throwErrorAt e.getRef "{e.toMessageData}\n\nIn {h}{type}" -/-- Build a `SymContext` by searching the local context for hypotheses of the -required types (up-to defeq) . The local context is modified to unfold the types -to be syntactically equal to the expected type. + +/-- Build the lazy search structure (for use with `searchLCtx`) +to populate the `SymContext` state from the local context. + +NOTE: some search might be performed eagerly. The resulting search structure +is tied to the specific `SymM` state and local context, it's expected that +neither of these change between construction and execution of the search. -/ +protected def searchFor : SearchLCtxForM SymM Unit := do + let c ← getThe SymContext + let currentState ← AxEffects.getCurrentState + + /- We start by doing an eager search for `h_run`, outside the `SearchLCxtForM` + monad. This is needed to instantiate the initial state -/ + let runSteps ← mkFreshExprMVar (Expr.const ``Nat []) + let hRunType := h_run_type c.finalState runSteps currentState + let some hRun ← findLocalDeclOfType? hRunType + | throwNotFound hRunType + + let runSteps? ← reflectNatLiteral? runSteps + if runSteps?.isNone then + trace[Tactic.sym] "failed to reflect {runSteps} \ + (from {hRun.toExpr} : {hRun.type})" + + modifyThe SymContext ({ · with + hRun := hRun.toExpr + hRunId := hRun.fvarId + finalState := ←instantiateMVars c.finalState + runSteps? + }) + + /- + From here on out, we're building the lazy search patterns + -/ + -- Find `h_program : .program = ` + let program ← mkFreshExprMVar none + searchLCtxForOnce (h_program_type currentState program) + (whenNotFound := throwNotFound) + (whenFound := fun decl _ => do + -- Register the program proof + modifyThe AxEffects ({· with + programProof := decl.toExpr + }) + -- Assert that `program` is a(n application of a) constant + let program := (← instantiateMVars program).getAppFn + let .const program _ := program + | throwError "Expected a constant, found:\n\t{program}" + -- Retrieve the programInfo from the environment + let some programInfo ← ProgramInfo.lookup? program + | throwError "Could not find program info for `{program}`.\n\ + Did you remember to generate step theorems with:\n \ + #generateStepEqTheorems {program}" + modifyThe SymContext ({· with + programInfo + }) + ) + + -- Find `h_pc : r .PC = ` + let pc ← mkFreshExprMVar (← mkAppM ``BitVec #[toExpr 64]) + searchLCtxForOnce (h_pc_type currentState pc) + (changeType := true) + (whenNotFound := throwNotFound) + (whenFound := fun decl _ => do + let pc ← instantiateMVars pc + -- Set the field effect + AxEffects.setFieldEffect .PC { + value := pc + proof := decl.toExpr + } + -- Then, reflect the value + let pc ← withErrorContext decl.userName decl.type <| + reflectBitVecLiteral 64 pc + modifyThe SymContext ({ · with pc }) + ) + + -- Find `h_err : r .ERR = .None`, or add a new subgoal if it isn't found + searchLCtxForOnce (h_err_type currentState) + (changeType := true) + (whenFound := fun decl _ => + AxEffects.setErrorProof decl.toExpr + ) + (whenNotFound := fun _ => do + let errHyp ← mkFreshExprMVar (h_err_type currentState) + replaceMainGoal [← getMainGoal, errHyp.mvarId!] + AxEffects.setErrorProof errHyp + ) + + -- Find `h_sp : CheckSPAlignment `. + searchLCtxForOnce (h_sp_type currentState) + (changeType := true) + (whenNotFound := traceNotFound `Tactic.sym) + -- ^^ Note that `h_sp` is optional, so there's no need to throw an error, + -- we merely add a message to the trace and move on + (whenFound := fun decl _ => do + modifyThe AxEffects ({ · with + stackAlignmentProof? := some decl.toExpr + }) + ) + + -- Find `r ?field currentState = ?value` + -- NOTE: this HAS to come after the search for specific fields, like `h_pc`, + -- or `h_err`, to ensure those take priority and the special handling + -- of those fields gets applied. + searchLCtxFor + (expectedType := do + let field ← mkFreshExprMVar (mkConst ``StateField) + let value ← mkFreshExprMVar none + return mkEqReadField field currentState value + ) + (whenFound := fun decl ty => do + let some (field, _state, value) := ty.eqReadField? + | throwError "internal error: unexpected type:\n {ty}" + + let field ← reflectStateField (← instantiateMVars field) + AxEffects.setFieldEffect field { + value := ←instantiateMVars value, + proof := decl.toExpr + } + return .continu + ) + /- TODO(@alexkeizer): Should we search for memory as well? + Probably we can only do so after the memoryProof refactor -/ + return () + +/-- Build a `SymContext` by searching the local context of the main goal for +hypotheses of the required types (up-to defeq). +The local context is modified to unfold the types to be syntactically equal to +the expected types. If an hypothesis `h_err : r .ERR = None` is not found, -we create a new subgoal of this type +we create a new subgoal of this type. -/ -def fromLocalContext (state? : Option Name) : TacticM SymContext := do +def fromMainContext (state? : Option Name) : TacticM SymContext := do let msg := m!"Building a `SymContext` from the local context" - withTraceNode `Tactic.sym (fun _ => pure msg) do + withTraceNode msg (tag := "fromMainContext") <| withMainContext' do trace[Tactic.Sym] "state? := {state?}" let lctx ← getLCtx @@ -299,134 +460,16 @@ def fromLocalContext (state? : Option Name) : TacticM SymContext := do pure (Expr.fvar decl.fvarId) | none => mkFreshExprMVar (Expr.const ``ArmState []) - -- Find `h_run` - let finalState ← mkFreshExprMVar none - let runSteps ← mkFreshExprMVar (Expr.const ``Nat []) - let h_run ← - findLocalDeclOfTypeOrError <| h_run_type finalState runSteps stateExpr - - -- Unwrap and reflect `runSteps` - let runSteps? ← do - let msg := m!"Reflecting: {runSteps}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do - let runSteps? ← reflectNatLiteral? runSteps - trace[Tactic.sym] "got: {runSteps?}" - pure runSteps? - let finalState ← instantiateMVars finalState - - -- At this point, `stateExpr` should have been assigned (if it was an mvar), - -- so we can unwrap it to get the underlying name - let stateExpr ← instantiateMVars stateExpr - - -- Try to find `h_program`, and infer `program` from it - let ⟨h_program, program⟩ ← withErrorContext h_run.userName h_run.type <| - findProgramHyp stateExpr - - -- Then, try to find `h_pc` - let pcE ← mkFreshExprMVar (← mkAppM ``BitVec #[toExpr 64]) - let h_pc ← findLocalDeclOfTypeOrError <| h_pc_type stateExpr pcE - - -- Unwrap and reflect `pc` - let pcE ← instantiateMVars pcE - let pc ← withErrorContext h_pc.userName h_pc.type <| - reflectBitVecLiteral 64 pcE - - -- Attempt to find `h_err`, adding a new subgoal if it couldn't be found - let errHyp ← do - let h_err? ← findLocalDeclOfType? (h_err_type stateExpr) - match h_err? with - | some d => pure d.toExpr - | none => do - let errHyp ← mkFreshExprMVar (h_err_type stateExpr) - replaceMainGoal [← getMainGoal, errHyp.mvarId!] - pure errHyp - - let h_sp? ← findLocalDeclOfType? (h_sp_type stateExpr) - if h_sp?.isNone then - trace[Sym] "Could not find local hypothesis of type {h_sp_type stateExpr}" - - -- Finally, retrieve the programInfo from the environment - let some programInfo ← ProgramInfo.lookup? program - | throwError "Could not find program info for `{program}`. - Did you remember to generate step theorems with: - #generateStepEqTheorems {program}" - - -- Initialize the axiomatic hypotheses with hypotheses for the initial state - let axHyps := #[h_program, h_pc] ++ h_sp?.toArray - let (aggregateSimpCtx, aggregateSimprocs) ← - LNSymSimpContext - (config := {decide := true, failIfUnchanged := false}) - (decls := axHyps) - (exprs := #[errHyp]) - (noIndexAtArgs := false) - - -- Build an initial AxEffects - let effects := AxEffects.initial stateExpr - let effects := { effects with - fields := effects.fields - |>.insert .PC { value := pcE, proof := h_pc.toExpr} - |>.insert .ERR { value := mkConst ``StateError.None, proof := errHyp} - programProof := h_program.toExpr - stackAlignmentProof? := h_sp?.map (·.toExpr) - } - let c : SymContext := { - finalState, runSteps?, pc, - hRun := h_run.toExpr, - hRunId := h_run.fvarId, - h_sp? := (·.userName) <$> h_sp?, - programInfo, - effects, - aggregateSimpCtx, aggregateSimprocs - } - c.modify <| - inferStatePrefixAndNumber -where - findLocalDeclOfType? (expectedType : Expr) : MetaM (Option LocalDecl) := do - let msg := m!"Searching for hypothesis of type: {expectedType}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do - let decl? ← _root_.findLocalDeclOfType? expectedType - trace[Tactic.sym] "Found: {(·.toExpr) <$> decl?}" - return decl? - findLocalDeclOfTypeOrError (expectedType : Expr) : MetaM LocalDecl := do - let msg := m!"Searching for hypothesis of type: {expectedType}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do - let decl ← _root_.findLocalDeclOfTypeOrError expectedType - trace[Tactic.sym] "Found: {decl.toExpr}" - return decl - -/-! ## Massaging the local context -/ - -/-- change the type (in the local context of the main goal) -of the hypotheses tracked by the given `SymContext` to be *exactly* of the shape -described in the relevant docstrings. - -That is, (un)fold types which were definitionally, but not syntactically, -equal to the expected shape. -/ -def canonicalizeHypothesisTypes : SymReaderM Unit := withMainContext' do - let c ← readThe SymContext - let lctx ← getLCtx - let mut goal ← getMainGoal - let state := c.effects.currentState - - let mut hyps := #[] - if let some h_sp := c.h_sp? then - hyps := hyps.push (h_sp, h_sp_type state) - - let mut hypIds ← hyps.mapM fun ⟨name, type⟩ => do - let some decl := lctx.findFromUserName? name - | throwError "Unknown local hypothesis `{name}`" - pure (decl.fvarId, type) + -- We create a bogus initial context + let c ← SymContext.initial stateExpr + c.modify <| do + searchLCtx SymContext.searchFor + withMainContext' <| do + let thms ← (← readThe AxEffects).toSimpTheorems + modifyThe SymContext (·.addSimpTheorems thms) - if let some runSteps := c.runSteps? then - hypIds := hypIds.push - (c.hRun.fvarId!, h_run_type c.finalState (toExpr runSteps) state) - let errHyp ← AxEffects.getFieldM .ERR - if let Expr.fvar id := errHyp.proof then - hypIds := hypIds.push (id, h_err_type state) - for ⟨fvarId, type⟩ in hypIds do - goal ← goal.replaceLocalDeclDefEq fvarId type - replaceMainGoal [goal] + inferStatePrefixAndNumber /-! ## Incrementing the context to the next state -/ @@ -437,32 +480,17 @@ evaluation: * the `currentStateNumber` is incremented -/ def prepareForNextStep : SymM Unit := do - let s ← getNextStateName - let pc ← do - let { value, ..} ← AxEffects.getFieldM .PC - try - reflectBitVecLiteral 64 value - catch err => - trace[Tactic.sym] "failed to reflect PC: {err.toMessageData}" - pure <| (← getThe SymContext).pc + 4 - - modifyThe SymContext (fun c => { c with - pc - h_sp? := c.h_sp?.map (fun _ => .mkSimple s!"h_{s}_sp_aligned") - runSteps? := (· - 1) <$> c.runSteps? - currentStateNumber := c.currentStateNumber + 1 - }) - -/-- Add a set of new simp-theorems to the simp-theorems used -for effect aggregation -/ -def addSimpTheorems (c : SymContext) (simpThms : Array SimpTheorem) : SymContext := - let addSimpThms := simpThms.foldl addSimpTheoremEntry - - let oldSimpTheorems := c.aggregateSimpCtx.simpTheorems - let simpTheorems := - if oldSimpTheorems.isEmpty then - oldSimpTheorems.push <| addSimpThms {} - else - oldSimpTheorems.modify (oldSimpTheorems.size - 1) addSimpThms - - { c with aggregateSimpCtx.simpTheorems := simpTheorems } + withVerboseTraceNode "prepareForNextStep" (tag := "prepareForNextStep") <| do + let pc ← do + let { value, ..} ← AxEffects.getFieldM .PC + try + reflectBitVecLiteral 64 value + catch err => + trace[Tactic.sym] "failed to reflect PC: {err.toMessageData}" + pure <| (← getThe SymContext).pc + 4 + + modifyThe SymContext (fun c => { c with + pc + runSteps? := (· - 1) <$> c.runSteps? + currentStateNumber := c.currentStateNumber + 1 + }) diff --git a/Tactics/Sym/LCtxSearch.lean b/Tactics/Sym/LCtxSearch.lean new file mode 100644 index 00000000..bc59eb5d --- /dev/null +++ b/Tactics/Sym/LCtxSearch.lean @@ -0,0 +1,235 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer +-/ +import Lean + +open Lean + +/-! +## Local Context Search + +In this module we build an abstraction around searching the local context for +multiple local variables or hypotheses at the same time. + +The main entry point to search is `searchLocalContext`. +`searchFor` is the main way to register which patterns to search for, +and what actions to perform if the variable is found (or not found). +-/ + +variable (m) [Monad m] + +inductive LCtxSearchResult + /-- This occurence of the pattern should be ignored -/ + | skip + /-- This should be counted as a successful occurence, + and we should *continue* matching for more variables -/ + | continu + /-- This should be counted as a successful occurence, + and we can *stop* matching against this particular pattern -/ + | done + deriving DecidableEq + +structure LCtxSearchState.Pattern where + /-- The type to search for (up to def-eq!). + Notice that `expectedType` is stored as a monadic value, + so that we can create fresh metavariables for each search -/ + expectedType : m Expr + /-- A cached result of `expectedType`; + this should be regenerated after every match! -/ + cachedExpectedType : Expr + /-- `whenFound` will be called whenever a match for `pattern` is found, + with the instantiated pattern as an argument. + The returned `LCtxSearchResult` determines if we count this as a successful + occurence of the pattern, which is relevant for if `whenNotFound` is called. + + NOTE: We give the (instantiated) pattern as an arg, *not* the expression that + we matched against. This way, implementors can recover information through + syntactic destructuring. + + An alternative design would have `pattern : MetaM (List Expr × Expr)`, + where the list is intended to be a list of meta-variables, and + `whenFound : List Expr → Expr → m Unit`, where we would call + `whenFound` with the list returned by `pattern` (which has the metavariables + that should now have been instantiated with subexpressions of interest). + -/ + whenFound : LocalDecl → Expr → m LCtxSearchResult + /-- `whenNotFound` will be called if no successful occurence of the pattern + (as determined by the return value of `whenFound`) + could be found in the local context -/ + whenNotFound : Expr → m Unit + /-- Whether to change the type of successful matches -/ + changeType : Bool + /-- Whether to match under binders -/ + matchUnderBinders : Bool + /-- How many times have we (successfully) found the pattern -/ + occurences : Nat := 0 + /-- Whether the pattern is active; is `isActive = false`, + then no further matches are attempted -/ + isActive : Bool := true + +structure LCtxSearchState where + patterns : Array (LCtxSearchState.Pattern m) + +abbrev SearchLCtxForM := StateT (LCtxSearchState m) m + +variable {m} + +/-- register a new expression pattern to search for: +- `expectedType` should give an expression, with meta-variables, which is the + type to search for. + + Note that, once a match is found, any meta-variables in `expectedType` will be + assigned, and thus further matches will now need to match those same concrete + values. That is why `expectedType` is a monadic value, which is re-evaluated + after each successful match. + If multiple matches, with distinct instantiations of a meta-variable, are + desired, it's important that meta-variable is created *inside* the + `expectedType` action. + If, on the other hand, a single instantiation accross all variables is + desired, the meta-variable should be created *outside*. +- `whenFound` will be called whenever a local variable whose type is def-eq + to `expectedType` is found, with as argument the instantiated `expectedType`. + The return value of `whenFound` is used to determine if a match is considered + succesful. +- `whenNotFound` will be called if no local variable could be found with a type + def-eq to the pattern, or if `whenFound` returned `skip` for all variables + that were found. For convenience, we pass in the `expectedType` here as well. + See `throwNotFound` for a convenient way to throw an error here. +- If `changeType` (default `false`) is set to true, then we change the type of + every successful match (i.e., `whenFound` returns `continu` or `done`) + to be exactly the `expectedType` +- If `matchUnderBinders` (default `false`) is set to true, we will introduce + metavariable for any binders in a variable's type before matching. + For example, with `matchUnderBinders` set to true, we consider a variable + `h : ∀ f, r f s0 = r f s1` + as a match for expected type `r ?f s0 = r ?f s1`. + +WARNING: Once a pattern is found for which `whenFound` returns `done`, that +particular variable will not be matched for any other patterns. +In case of overlapping patterns, the pattern which was added first will be +tried first +-/ +def searchLCtxFor + (expectedType : m Expr) + (whenFound : LocalDecl → Expr → m LCtxSearchResult) + (whenNotFound : Expr → m Unit := fun _ => pure ()) + (changeType : Bool := false) + (matchUnderBinders : Bool := false) + : SearchLCtxForM m Unit := do + let pattern := { + -- Placeholder value, since we can't evaluate `m` inside of `LCtxSearchM` + cachedExpectedType :=← expectedType + expectedType, whenFound, whenNotFound, changeType, matchUnderBinders + } + modify fun state => { state with + patterns := state.patterns.push pattern + } + +/-- A wrapper around `searchLCtxFor`, which is simplified for matching at most +one occurence of `expectedType`. + +See `searchLCtxFor` for an explanation of the arguments -/ +def searchLCtxForOnce + (expectedType : Expr) + (whenFound : LocalDecl → Expr → m Unit) + (whenNotFound : Expr → m Unit := fun _ => pure ()) + (changeType : Bool := false) + (matchUnderBinders : Bool := false) + : SearchLCtxForM m Unit := do + searchLCtxFor (pure expectedType) + (fun d e => do whenFound d e; return .done) + whenNotFound changeType matchUnderBinders + +section Run +open Elab.Tactic +open Meta (isDefEq) +variable [MonadLCtx m] [MonadLiftT MetaM m] [MonadLiftT TacticM m] + +namespace LCtxSearchState + +/-- Return `true` if `e` matches the pattern -/ +def Pattern.matches (pat : Pattern m) (e : Expr) : m Bool := do + let mut e := e + if pat.matchUnderBinders then + let ⟨_, _, e'⟩ ← Meta.forallMetaTelescope e + e := e' + isDefEq e pat.cachedExpectedType + +/-- +Attempt to match `e` against the given pattern: +- if `e` is def-eq to `pat.cachedExpectedType`, then return + the updated pattern state (with a fresh `cachedExpectedType`), and + the result of `whenFound` +- Otherwise, if `e` is not def-eq, return `none` +-/ +def Pattern.match? (pat : Pattern m) (decl : LocalDecl) : + m (Option (Pattern m × LCtxSearchResult)) := do + if !pat.isActive then + return none + else if !(← pat.matches decl.type) then + return none + else + let cachedExpectedType ← pat.expectedType + let res ← pat.whenFound decl pat.cachedExpectedType + let mut occurences := pat.occurences + if res != .skip then + occurences := occurences + 1 + if pat.changeType = true then + let goal ← getMainGoal + let goal ← goal.replaceLocalDeclDefEq decl.fvarId pat.cachedExpectedType + replaceMainGoal [goal] + return some ({pat with cachedExpectedType, occurences}, res) + +end LCtxSearchState + +/-- Search the local context for variables of certain types, in a single pass. +`k` is a monadic continuation that determines the patterns to search for, +see `searchLCtxFor` to see how to register those patterns +-/ +def searchLCtx (k : SearchLCtxForM m Unit) : m Unit := do + let ((), { patterns }) ← StateT.run k ⟨#[]⟩ + -- We have to keep `patterns` in a Subtype to be able to prove our indexes + -- are valid even after mutation + -- TODO(@alexkeizer): consider using `Batteries.Data.Vector`, if we can + -- justify a batteries dependency + let n := patterns.size + let mut patterns : { as : Array _ // as.size = n} := + ⟨patterns, rfl⟩ + + -- The main search + for decl in ← getLCtx do + for hi : i in [0:n] do + have hi : i < patterns.val.size := by + rw [patterns.property]; get_elem_tactic + let pat := patterns.val[i] + if let some (pat, res) ← pat.match? decl then + patterns := ⟨ + patterns.val.set ⟨i, hi⟩ pat, + by simp[patterns.property] + ⟩ + if res = .done || res = .continu then + break -- break out of the inner loop + + -- Finally, check each pattern and call `whenNotFound` if appropriate + for pat in patterns.val do + if pat.occurences = 0 then + pat.whenNotFound pat.cachedExpectedType + return () + +variable [MonadError m] in +/-- Throw an error complaining that no variable of `expectedType` could +be found -/ +def throwNotFound (expectedType : Expr) : m Unit := + throwError "Expected a local variable of type:\n {expectedType}\n\ + but no such variable was found in the local context" + +/-- Add a message to the trace that we searched for, but couldn't find, +a variable of `expectedType`, and continue execution. -/ +def traceNotFound (cls : Name) (expectedType : Expr) : m Unit := + trace (m:=MetaM) cls fun _ => + m!"Unable to find a variable of type {expectedType} in the local context" + + +end Run diff --git a/Tests/AES-GCM/AESV8ProgramTests.lean b/Tests/AES-GCM/AESV8ProgramTests.lean index 4389d415..085fe7a2 100644 --- a/Tests/AES-GCM/AESV8ProgramTests.lean +++ b/Tests/AES-GCM/AESV8ProgramTests.lean @@ -360,28 +360,28 @@ def final_state1 : ArmState := aes_hw_ctr32_encrypt_blocks_test 88 1 in_block rounds key_schedule ivec def final_buf1 : BitVec 640 := read_mem_bytes 80 out_address final_state1 example : read_err final_state1 = StateError.None := by native_decide -example: final_buf1 = (BitVec.zero 512) ++ (extractLsb 127 0 (revflat buf_res_128)) := by native_decide +example: final_buf1 = (BitVec.zero 512) ++ (extractLsb' 0 128 (revflat buf_res_128)) := by native_decide -- -- len = 2 def final_state2 : ArmState := aes_hw_ctr32_encrypt_blocks_test 89 2 in_block rounds key_schedule ivec def final_buf2 : BitVec 640 := read_mem_bytes 80 out_address final_state2 example : read_err final_state2 = StateError.None := by native_decide -example: final_buf2 = (BitVec.zero 384) ++ (extractLsb 255 0 (revflat buf_res_128)) := by native_decide +example: final_buf2 = (BitVec.zero 384) ++ (extractLsb' 0 256 (revflat buf_res_128)) := by native_decide -- len = 3 def final_state3 : ArmState := aes_hw_ctr32_encrypt_blocks_test 128 3 in_block rounds key_schedule ivec def final_buf3 : BitVec 640 := read_mem_bytes 80 out_address final_state3 example : read_err final_state3 = StateError.None := by native_decide -example: final_buf3 = (BitVec.zero 256) ++ (extractLsb 383 0 (revflat buf_res_128)) := by native_decide +example: final_buf3 = (BitVec.zero 256) ++ (extractLsb' 0 384 (revflat buf_res_128)) := by native_decide -- len = 4 def final_state4 : ArmState := aes_hw_ctr32_encrypt_blocks_test 190 4 in_block rounds key_schedule ivec def final_buf4 : BitVec 640 := read_mem_bytes 80 out_address final_state4 example : read_err final_state4 = StateError.None := by native_decide -example: final_buf4 = (BitVec.zero 127) ++ (extractLsb 512 0 (revflat buf_res_128)) := by native_decide +example: final_buf4 = (BitVec.zero 128) ++ (extractLsb' 0 512 (revflat buf_res_128)) := by native_decide -- len = 5 def final_state5 : ArmState := diff --git a/Tests/AES-GCM/GCMSpecTests.lean b/Tests/AES-GCM/GCMSpecTests.lean index cf255810..2bd8598e 100644 --- a/Tests/AES-GCM/GCMSpecTests.lean +++ b/Tests/AES-GCM/GCMSpecTests.lean @@ -14,8 +14,8 @@ namespace GCMInitV8SpecTest def flatten_H := BitVec.flatten GCMProgramTestParams.H def spec_table := GCMV8.GCMInitV8 flatten_H -example : extractLsb (12 * 128) 0 (revflat spec_table) - = extractLsb (12 * 128) 0 (revflat GCMProgramTestParams.Htable) +example : extractLsb' 0 (12 * 128) (revflat spec_table) + = extractLsb' 0 (12 * 128) (revflat GCMProgramTestParams.Htable) := by native_decide end GCMInitV8SpecTest diff --git a/Tests/Tactics/AddressNormalization.lean b/Tests/Tactics/AddressNormalization.lean new file mode 100644 index 00000000..709275d8 --- /dev/null +++ b/Tests/Tactics/AddressNormalization.lean @@ -0,0 +1,82 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Siddharth Bhat, Tobias Grosser +-/ + +import Arm.Memory.AddressNormalization + +/-! ## Examples -/ + +set_option trace.Tactic.address_normalization true in +/-- +info: [Tactic.address_normalization] ⚙️ reduceModOfLt 'x.toNat % 2 ^ w' +[Tactic.address_normalization] ✅️ reduceModOfLt 'x.toNat % 2 ^ w' +-/ +#guard_msgs in theorem eg₁ {w} (x : BitVec w) : x.toNat % 2 ^ w = x.toNat + 0 := by + simp only [address_normalization] + rfl + +/-- info: 'eg₁' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms eg₁ + +theorem eg₂ {w} (x y : BitVec w) (h : x.toNat + y.toNat < 2 ^ w) : + (x + y).toNat = x.toNat + y.toNat := by + simp [address_normalization] + +/-- info: 'eg₂' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms eg₂ + +set_option trace.Tactic.address_normalization true in +/-- +info: [Tactic.address_normalization] ⚙️ canonicalizeBinConst '(HAdd.hAdd x y)' +[Tactic.address_normalization] ⚙️ reduceModOfLt 'x.toNat + y.toNat % 2 ^ w' +[Tactic.address_normalization] ⚙️ reduceModSub 'x.toNat + y.toNat % 2 ^ w' +-/ +#guard_msgs in theorem eg₃ {w} (x y : BitVec w) : + (x + y).toNat = (x.toNat + y.toNat) % 2 ^ w := by + simp [address_normalization] + +/-- info: 'eg₂' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms eg₂ + +theorem eg₄ {w} (x y z : BitVec w) + (h₂ : y.toNat + z.toNat < 2 ^ w) + (h : x.toNat * (y.toNat + z.toNat) < 2 ^ w) : + (x * (y + z)).toNat = x.toNat * (y.toNat + z.toNat) := by + simp [address_normalization] + +/-- info: 'eg₄' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms eg₄ + +theorem eg₅ {w} (x y : BitVec w) (h : x.toNat + y.toNat ≥ 2 ^ w) (h' : (x.toNat + y.toNat) - 2 ^ w < 2 ^ w) : + (x + y).toNat = x.toNat + y.toNat - 2 ^ w := by + simp [address_normalization] + +/-- info: 'eg₅' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms eg₅ + +set_option trace.Tactic.address_normalization true in +/-- +info: [Tactic.address_normalization] ⚙️ canonicalizeBinConst '(HAdd.hAdd x 100#w)' +[Tactic.address_normalization] ✅️ canonicalizeBinConst '(HAdd.hAdd x 100#w)' +[Tactic.address_normalization] ⚙️ canonicalizeBinConst '(HAdd.hAdd 100#w x)' +-/ +#guard_msgs in theorem eg₆ {w} (x : BitVec w) : x + 100#w = 100#w + x := by + simp only [address_normalization] + +/-- info: 'eg₆' depends on axioms: [propext] -/ +#guard_msgs in #print axioms eg₆ + + +theorem eg₇ {w} (x : BitVec w) : 100#w + (200#w + x) = 300#w + x := by + simp only [address_normalization] + +/-- info: 'eg₇' depends on axioms: [propext] -/ +#guard_msgs in #print axioms eg₇ + +theorem eg₈ {w} : 100#w + 200#w = 300#w := by + simp only [address_normalization] + +/-- info: 'eg₈' depends on axioms: [propext] -/ +#guard_msgs in #print axioms eg₈ diff --git a/Tests/Tactics/CSE.lean b/Tests/Tactics/CSE.lean index ce4222d1..bc74a9b2 100644 --- a/Tests/Tactics/CSE.lean +++ b/Tests/Tactics/CSE.lean @@ -234,19 +234,20 @@ warning: declaration uses 'sorry' info: a b c d : BitVec 64 x1 x2 : BitVec 128 hx1 : x2 <<< 64 = x1 -x3 : BitVec (127 - 64 + 1) -hx3 : BitVec.extractLsb 127 64 c = x3 +x3 : BitVec 64 +hx3 : BitVec.extractLsb' 64 64 c = x3 hx2 : BitVec.zeroExtend 128 x3 = x2 -⊢ BitVec.zeroExtend 128 (BitVec.extractLsb 63 0 x1) ||| BitVec.zeroExtend 128 (BitVec.extractLsb 127 64 x1) = +⊢ BitVec.zeroExtend 128 (BitVec.extractLsb' 0 64 x1) ||| BitVec.zeroExtend 128 (BitVec.extractLsb' 64 64 x1) = sorryAx (BitVec 128) -/ + #guard_msgs in theorem bitvec_subexpr (a b c d : BitVec 64) : (zeroExtend 128 - (extractLsb 63 0 + (extractLsb' 0 64 ( - zeroExtend 128 (extractLsb 127 64 c) <<< 64)) ||| + zeroExtend 128 (extractLsb' 64 64 c) <<< 64)) ||| zeroExtend 128 - (extractLsb 127 64 - (zeroExtend 128 (extractLsb 127 64 c) <<< 64))) = sorry := by + (extractLsb' 64 64 + (zeroExtend 128 (extractLsb' 64 64 c) <<< 64))) = sorry := by cse trace_state all_goals sorry @@ -260,107 +261,104 @@ warning: declaration uses 'sorry' --- info: H : 64 > 0 a b c d e : BitVec 128 -x1 x2 : BitVec 64 -x3 : BitVec (127 - 64 + 1) +x1 x2 x3 : BitVec 64 x4 : BitVec 128 -hx3 : BitVec.extractLsb 127 64 x4 = x3 +hx3 : BitVec.extractLsb' 64 64 x4 = x3 x5 x6 x7 x8 x9 : BitVec 64 hx2 : x9 + x3 = x2 x10 x11 : BitVec 128 hx4 : x10 ||| x11 = x4 x12 : BitVec 128 -hx10 : x12 &&& 18446744073709551615#128 = x10 +hx11 : x12 <<< 64 = x11 x13 : BitVec 128 -hx11 : x13 <<< 64 = x11 +hx10 : x13 &&& 18446744073709551615#128 = x10 x14 : BitVec 64 hx12 : BitVec.zeroExtend 128 x14 = x12 x15 : BitVec 64 hx13 : BitVec.zeroExtend 128 x15 = x13 -x16 : BitVec (63 - 0 + 1) -x17 : BitVec (127 - 64 + 1) +x16 x17 : BitVec 64 x18 : BitVec 128 -hx16 : BitVec.extractLsb 63 0 x18 = x16 -hx17 : BitVec.extractLsb 127 64 x18 = x17 +hx16 : BitVec.extractLsb' 0 64 x18 = x16 +hx17 : BitVec.extractLsb' 64 64 x18 = x17 x19 x20 : BitVec 64 hx8 : x19 + x20 = x8 x21 : BitVec 64 hx9 : x20 + x21 = x9 -x22 : BitVec 128 -x23 : BitVec 64 -x24 : BitVec 128 -hx18 : x22 ||| x24 = x18 -x25 : BitVec 64 +x22 x23 : BitVec 128 +hx18 : x22 ||| x23 = x18 +x24 : BitVec 64 +x25 : BitVec 128 +hx23 : x25 <<< 64 = x23 x26 : BitVec 128 hx22 : x26 &&& 18446744073709551615#128 = x22 -x27 : BitVec 128 -hx24 : x27 <<< 64 = x24 -x28 : BitVec 64 -hx27 : BitVec.zeroExtend 128 x28 = x27 +x27 x28 : BitVec 64 +hx25 : BitVec.zeroExtend 128 x28 = x25 x29 : BitVec 64 -hx20 : x29 ^^^ x25 = x20 +hx20 : x29 ^^^ x27 = x20 x30 : BitVec 64 hx26 : BitVec.zeroExtend 128 x30 = x26 -x31 x32 : BitVec 64 -hx23 : x31 ^^^ x32 = x23 -x33 x34 : BitVec 64 -hx21 : x23 ^^^ x34 = x21 -x35 : BitVec (127 - 64 + 1) -hx35 : BitVec.extractLsb 127 64 d = x35 -hx7 : x8 + x35 = x7 -x36 : BitVec (63 - 0 + 1) -hx36 : BitVec.extractLsb 63 0 a = x36 -x37 : BitVec (63 - 0 + 1) -hx37 : BitVec.extractLsb 63 0 c = x37 -x38 : BitVec (127 - 64 + 1) -hx38 : BitVec.extractLsb 127 64 c = x38 -hx19 : x38 + x21 = x19 -hx28 : x38 + x35 = x28 -x39 : BitVec (127 - 64 + 1) -hx39 : BitVec.extractLsb 127 64 e = x39 -hx6 : x7 + x39 = x6 -hx15 : x17 + x39 = x15 -x40 : BitVec (63 - 0 + 1) -hx40 : BitVec.extractLsb 63 0 b = x40 -hx1 : x2 + x40 = x1 -hx5 : x40 + x6 = x5 -x41 : BitVec (63 - 0 + 1) -hx41 : BitVec.extractLsb 63 0 d = x41 -hx30 : x37 + x41 = x30 -x42 : BitVec (127 - 64 + 1) -hx42 : BitVec.extractLsb 127 64 a = x42 -hx25 : x33 &&& x42 = x25 -x43 : BitVec (127 - 64 + 1) -hx43 : BitVec.extractLsb 127 64 b = x43 -hx31 : x43.rotateRight 14 = x31 -hx32 : x43.rotateRight 18 = x32 -hx33 : ~~~x43 = x33 -hx34 : x43.rotateRight 41 = x34 -hx29 : x43 &&& x36 = x29 -x44 : BitVec (63 - 0 + 1) -hx44 : BitVec.extractLsb 63 0 e = x44 -hx14 : x16 + x44 = x14 +x31 : BitVec 64 +hx21 : x24 ^^^ x31 = x21 +x32 x33 : BitVec 64 +hx24 : x33 ^^^ x32 = x24 +x34 x35 : BitVec 64 +hx35 : BitVec.extractLsb' 0 64 a = x35 +x36 : BitVec 64 +hx36 : BitVec.extractLsb' 64 64 a = x36 +hx27 : x34 &&& x36 = x27 +x37 : BitVec 64 +hx37 : BitVec.extractLsb' 0 64 b = x37 +hx1 : x2 + x37 = x1 +hx5 : x37 + x6 = x5 +x38 : BitVec 64 +hx38 : BitVec.extractLsb' 0 64 e = x38 +hx15 : x16 + x38 = x15 +x39 : BitVec 64 +hx39 : BitVec.extractLsb' 64 64 b = x39 +hx31 : x39.rotateRight 41 = x31 +hx32 : x39.rotateRight 18 = x32 +hx33 : x39.rotateRight 14 = x33 +hx34 : ~~~x39 = x34 +hx29 : x39 &&& x35 = x29 +x40 : BitVec 64 +hx40 : BitVec.extractLsb' 64 64 c = x40 +hx19 : x40 + x21 = x19 +x41 : BitVec 64 +hx41 : BitVec.extractLsb' 64 64 d = x41 +hx7 : x8 + x41 = x7 +hx28 : x40 + x41 = x28 +x42 : BitVec 64 +hx42 : BitVec.extractLsb' 64 64 e = x42 +hx6 : x7 + x42 = x6 +hx14 : x17 + x42 = x14 +x43 : BitVec 64 +hx43 : BitVec.extractLsb' 0 64 d = x43 +x44 : BitVec 64 +hx44 : BitVec.extractLsb' 0 64 c = x44 +hx30 : x44 + x43 = x30 ⊢ x2 ++ - ((x1 &&& x43 ^^^ ~~~x1 &&& x36) + (x1.rotateRight 14 ^^^ x1.rotateRight 18 ^^^ x1.rotateRight 41) + - BitVec.extractLsb 63 0 x4) = + ((x1 &&& x39 ^^^ ~~~x1 &&& x35) + (x1.rotateRight 14 ^^^ x1.rotateRight 18 ^^^ x1.rotateRight 41) + + BitVec.extractLsb' 0 64 x4) = x6 ++ - (x37 + (x5.rotateRight 14 ^^^ x5.rotateRight 18 ^^^ x5.rotateRight 41) + (x5 &&& x43 ^^^ ~~~x5 &&& x36) + x41 + - x44) + (x44 + (x5.rotateRight 14 ^^^ x5.rotateRight 18 ^^^ x5.rotateRight 41) + (x5 &&& x39 ^^^ ~~~x5 &&& x35) + x43 + + x38) -/ + #guard_msgs in theorem sha512h_rule_1 (a b c d e : BitVec 128) : let elements := 2 let esize := 64 let inner_sum := (binary_vector_op_aux 0 elements esize BitVec.add c d (BitVec.zero 128) H) let outer_sum := (binary_vector_op_aux 0 elements esize BitVec.add inner_sum e (BitVec.zero 128) H) - let a0 := extractLsb 63 0 a - let a1 := extractLsb 127 64 a - let b0 := extractLsb 63 0 b - let b1 := extractLsb 127 64 b - let c0 := extractLsb 63 0 c - let c1 := extractLsb 127 64 c - let d0 := extractLsb 63 0 d - let d1 := extractLsb 127 64 d - let e0 := extractLsb 63 0 e - let e1 := extractLsb 127 64 e + let a0 := extractLsb' 0 64 a + let a1 := extractLsb' 64 64 a + let b0 := extractLsb' 0 64 b + let b1 := extractLsb' 64 64 b + let c0 := extractLsb' 0 64 c + let c1 := extractLsb' 64 64 c + let d0 := extractLsb' 0 64 d + let d1 := extractLsb' 64 64 d + let e0 := extractLsb' 0 64 e + let e1 := extractLsb' 64 64 e let hi64_spec := compression_update_t1 b1 a0 a1 c1 d1 e1 let lo64_spec := compression_update_t1 (b0 + hi64_spec) b1 a0 c0 d0 e0 sha512h a b outer_sum = hi64_spec ++ lo64_spec := by @@ -385,28 +383,24 @@ warning: declaration uses 'sorry' --- info: h1 h2 : 64 > 0 a b c d e : BitVec 128 -x1 x2 : BitVec 64 -x3 : BitVec (127 - 64 + 1) +x1 x2 x3 : BitVec 64 x4 : BitVec 128 -hx3 : BitVec.extractLsb 127 64 x4 = x3 +hx3 : BitVec.extractLsb' 64 64 x4 = x3 x5 : BitVec 64 x6 x7 : BitVec 128 hx4 : x7 ||| x6 = x4 x8 : BitVec 128 hx6 : x8 <<< 64 = x6 x9 : BitVec 64 -hx7 : BitVec.zeroExtend 128 x9 = x7 +hx8 : BitVec.zeroExtend 128 x9 = x8 x10 : BitVec 64 -hx8 : BitVec.zeroExtend 128 x10 = x8 -x11 : BitVec 64 -x12 : BitVec (127 - 64 + 1) -x13 : BitVec (63 - 0 + 1) -x14 : BitVec 64 -x15 : BitVec (191 - 64 + 1) -hx12 : BitVec.extractLsb 127 64 x15 = x12 -hx13 : BitVec.extractLsb 63 0 x15 = x13 +hx7 : BitVec.zeroExtend 128 x10 = x7 +x11 x12 x13 x14 : BitVec 64 +x15 : BitVec 128 +hx12 : BitVec.extractLsb' 64 64 x15 = x12 +hx13 : BitVec.extractLsb' 0 64 x15 = x13 x16 : BitVec 256 -hx15 : BitVec.extractLsb 191 64 x16 = x15 +hx15 : BitVec.extractLsb' 64 128 x16 = x15 x17 x18 : BitVec 64 hx2 : x18 + x3 = x2 x19 : BitVec 128 @@ -415,78 +409,80 @@ x20 x21 : BitVec 64 hx17 : x20 + x21 = x17 x22 : BitVec 64 hx18 : x21 + x22 = x18 -x23 : BitVec 128 -x24 : BitVec 64 -x25 : BitVec 128 -hx19 : x25 ||| x23 = x19 +x23 : BitVec 64 +x24 : BitVec 128 +x25 : BitVec 64 x26 : BitVec 128 -hx23 : x26 <<< 64 = x23 -x27 x28 : BitVec 64 -hx25 : BitVec.zeroExtend 128 x28 = x25 +hx19 : x26 ||| x24 = x19 +x27 : BitVec 128 +hx24 : x27 <<< 64 = x24 +x28 : BitVec 64 +hx26 : BitVec.zeroExtend 128 x28 = x26 x29 : BitVec 64 -hx26 : BitVec.zeroExtend 128 x29 = x26 +hx21 : x29 ^^^ x25 = x21 x30 : BitVec 64 -hx21 : x30 ^^^ x27 = x21 -x31 : BitVec 64 -hx22 : x24 ^^^ x31 = x22 -x32 x33 x34 : BitVec 64 -hx24 : x33 ^^^ x34 = x24 -x35 : BitVec (63 - 0 + 1) -hx35 : BitVec.extractLsb 63 0 b = x35 -hx1 : x2 + x35 = x1 -hx5 : x35 + x11 = x5 -x36 : BitVec (63 - 0 + 1) -hx36 : BitVec.extractLsb 63 0 a = x36 -x37 : BitVec (127 - 64 + 1) -hx37 : BitVec.extractLsb 127 64 b = x37 -hx31 : x37.rotateRight 41 = x31 -hx32 : ~~~x37 = x32 -hx33 : x37.rotateRight 14 = x33 -hx34 : x37.rotateRight 18 = x34 -hx30 : x37 &&& x36 = x30 -x38 : BitVec (127 - 64 + 1) -hx38 : BitVec.extractLsb 127 64 a = x38 -hx27 : x32 &&& x38 = x27 -x39 : BitVec (127 - 64 + 1) -hx39 : BitVec.extractLsb 127 64 e = x39 -x40 : BitVec (63 - 0 + 1) -hx40 : BitVec.extractLsb 63 0 c = x40 -hx9 : x40 + x13 = x9 -x41 : BitVec (127 - 64 + 1) -hx41 : BitVec.extractLsb 127 64 d = x41 -hx29 : x41 + x39 = x29 -x42 : BitVec (63 - 0 + 1) -hx42 : BitVec.extractLsb 63 0 e = x42 -hx11 : x14 + x42 = x11 -x43 : BitVec (127 - 64 + 1) -hx43 : BitVec.extractLsb 127 64 c = x43 -hx10 : x43 + x12 = x10 -hx20 : x43 + x22 = x20 -x44 : BitVec (63 - 0 + 1) -hx44 : BitVec.extractLsb 63 0 d = x44 -hx14 : x17 + x44 = x14 -hx28 : x44 + x42 = x28 +hx27 : BitVec.zeroExtend 128 x30 = x27 +x31 x32 x33 : BitVec 64 +hx22 : x23 ^^^ x33 = x22 +x34 : BitVec 64 +hx23 : x34 ^^^ x31 = x23 +x35 : BitVec 64 +hx35 : BitVec.extractLsb' 64 64 e = x35 +x36 : BitVec 64 +hx36 : BitVec.extractLsb' 64 64 b = x36 +hx31 : x36.rotateRight 18 = x31 +hx32 : ~~~x36 = x32 +hx33 : x36.rotateRight 41 = x33 +hx34 : x36.rotateRight 14 = x34 +x37 : BitVec 64 +hx37 : BitVec.extractLsb' 0 64 c = x37 +hx10 : x37 + x13 = x10 +x38 : BitVec 64 +hx38 : BitVec.extractLsb' 0 64 d = x38 +hx14 : x17 + x38 = x14 +x39 : BitVec 64 +hx39 : BitVec.extractLsb' 64 64 c = x39 +hx9 : x39 + x12 = x9 +hx20 : x39 + x22 = x20 +x40 : BitVec 64 +hx40 : BitVec.extractLsb' 0 64 e = x40 +hx11 : x14 + x40 = x11 +hx28 : x38 + x40 = x28 +x41 : BitVec 64 +hx41 : BitVec.extractLsb' 0 64 b = x41 +hx1 : x2 + x41 = x1 +hx5 : x41 + x11 = x5 +x42 : BitVec 64 +hx42 : BitVec.extractLsb' 64 64 d = x42 +hx30 : x42 + x35 = x30 +x43 : BitVec 64 +hx43 : BitVec.extractLsb' 0 64 a = x43 +hx29 : x36 &&& x43 = x29 +x44 : BitVec 64 +hx44 : BitVec.extractLsb' 64 64 a = x44 +hx25 : x32 &&& x44 = x25 ⊢ x2 ++ - ((x1 &&& x37 ^^^ ~~~x1 &&& x36) + (x1.rotateRight 14 ^^^ x1.rotateRight 18 ^^^ x1.rotateRight 41) + - BitVec.extractLsb 63 0 x4) = + ((x1 &&& x36 ^^^ ~~~x1 &&& x43) + (x1.rotateRight 14 ^^^ x1.rotateRight 18 ^^^ x1.rotateRight 41) + + BitVec.extractLsb' 0 64 x4) = x11 ++ - (x40 + (x5.rotateRight 14 ^^^ x5.rotateRight 18 ^^^ x5.rotateRight 41) + (x5 &&& x37 ^^^ ~~~x5 &&& x36) + x41 + - x39) + (x37 + (x5.rotateRight 14 ^^^ x5.rotateRight 18 ^^^ x5.rotateRight 41) + (x5 &&& x36 ^^^ ~~~x5 &&& x43) + x42 + + x35) -/ + #guard_msgs in theorem sha512h_rule_2 (a b c d e : BitVec 128) : - let a0 := extractLsb 63 0 a - let a1 := extractLsb 127 64 a - let b0 := extractLsb 63 0 b - let b1 := extractLsb 127 64 b - let c0 := extractLsb 63 0 c - let c1 := extractLsb 127 64 c - let d0 := extractLsb 63 0 d - let d1 := extractLsb 127 64 d - let e0 := extractLsb 63 0 e - let e1 := extractLsb 127 64 e + let a0 := extractLsb' 0 64 a + let a1 := extractLsb' 64 64 a + let b0 := extractLsb' 0 64 b + let b1 := extractLsb' 64 64 b + let c0 := extractLsb' 0 64 c + let c1 := extractLsb' 64 64 c + let d0 := extractLsb' 0 64 d + let d1 := extractLsb' 64 64 d + let e0 := extractLsb' 0 64 e + let e1 := extractLsb' 64 64 e let inner_sum := binary_vector_op_aux 0 2 64 BitVec.add d e (BitVec.zero 128) h1 let concat := inner_sum ++ inner_sum - let operand := extractLsb 191 64 concat + let operand := extractLsb' 64 128 concat let hi64_spec := compression_update_t1 b1 a0 a1 c1 d0 e0 let lo64_spec := compression_update_t1 (b0 + hi64_spec) b1 a0 c0 d1 e1 sha512h a b (binary_vector_op_aux 0 2 64 BitVec.add c operand (BitVec.zero 128) h2) = diff --git a/Tests/Tests.lean b/Tests/Tests.lean index 7eeb81ad..64a33a77 100644 --- a/Tests/Tests.lean +++ b/Tests/Tests.lean @@ -20,3 +20,4 @@ import «Tests».«ELFParser».MiscTests import «Tests».Tactics.CSE import «Tests».Tactics.Sym import «Tests».Tactics.ReduceFetchInst +import «Tests».Tactics.AddressNormalization diff --git a/lakefile.lean b/lakefile.lean index 7175ad63..66000295 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -33,6 +33,7 @@ lean_lib «Doc» where -- add library configuration options here lean_lib «Benchmarks» where + leanOptions := #[⟨`weak.benchmark.runs, (0 : Nat)⟩] -- add library configuration options here @[default_target] diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh new file mode 100755 index 00000000..028b2bf4 --- /dev/null +++ b/scripts/benchmark.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +LAKE=lake +BENCH="$LAKE env lean -Dweak.benchmark.runs=5" +OUT="data/benchmarks" + +timestamp=$(date +"%Y-%m-%d_%H%M%S") +rev=$(git rev-parse --short HEAD) +echo "HEAD is on $rev" +out="$OUT/${timestamp}_${rev}" +mkdir -p "$out" + +$LAKE build Benchmarks +for file in "$@"; do + echo + echo + $file + echo + base="$(basename "$file" ".lean")" + $BENCH $file | tee "$out/$base" +done diff --git a/scripts/profile.sh b/scripts/profile.sh new file mode 100755 index 00000000..ccc1433d --- /dev/null +++ b/scripts/profile.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +LAKE=lake +PROF="$LAKE env lean -Dprofiler=true" +OUT="data/profiles" + +timestamp=$(date +"%Y-%m-%d_%H%M%S") +rev=$(git rev-parse --short HEAD) +echo "HEAD is on $rev" +out="$OUT/${timestamp}_${rev}" +mkdir -p "$out" + +$LAKE build Benchmarks +for file in "$@"; do + echo + echo + $file + echo + base="$(basename "$file" ".lean")" + $PROF -Dtrace.profiler.output="$out/$base.json" "$file" | tee "$base.log" +done