diff --git a/backends/lean/Aeneas.lean b/backends/lean/Aeneas.lean index 60f984c3..a54ea445 100644 --- a/backends/lean/Aeneas.lean +++ b/backends/lean/Aeneas.lean @@ -1,11 +1,17 @@ import Aeneas.Arith +import Aeneas.Bvify +import Aeneas.BvTac import Aeneas.Diverge +import Aeneas.FSimp import Aeneas.List -import Aeneas.Std +import Aeneas.Natify import Aeneas.Progress -import Aeneas.SimpLemmas -import Aeneas.Utils +import Aeneas.Range import Aeneas.Saturate import Aeneas.ScalarNF import Aeneas.ScalarTac +import Aeneas.SimpLemmas +import Aeneas.Std +import Aeneas.Utils import Aeneas.Termination +import Aeneas.Zify diff --git a/backends/lean/Aeneas/Arith/Lemmas.lean b/backends/lean/Aeneas/Arith/Lemmas.lean index 15d02d54..09324862 100644 --- a/backends/lean/Aeneas/Arith/Lemmas.lean +++ b/backends/lean/Aeneas/Arith/Lemmas.lean @@ -1,10 +1,40 @@ -import Aeneas.ScalarTac.IntTac import Aeneas.ScalarTac.ScalarTac +import Mathlib.Algebra.Algebra.ZMod +import Mathlib.RingTheory.Int.Basic +import Init.Data.Int.DivModLemmas -namespace Aeneas.ScalarTac +namespace Aeneas.Arith @[nonlin_scalar_tac n % m] -theorem Int.emod_of_pos_disj (n m : Int) : m ≤ 0 ∨ (0 ≤ n % m ∧ n % m < m) := by +theorem Nat.mod_zero_or_lt (n m : Nat) : m = 0 ∨ (n % m < m) := by + dcases h: m = 0 + . simp [h] + . right + apply Nat.mod_lt; omega + +@[nonlin_scalar_tac n / m] +theorem Nat.div_zero_or_le (n m : Nat) : m = 0 ∨ (n / m ≤ n) := by + dcases h: m = 0 <;> simp [*] + apply Nat.div_le_self + +theorem Int.self_le_ediv {x y : ℤ} (hx : x ≤ 0) (hy : 0 ≤ y) : + x ≤ x / y := by + dcases x <;> dcases y + . simp_all + . simp_all + . rename_i x y + rw [HDiv.hDiv, instHDiv] + simp only [Div.div] + rw [Int.ediv.eq_def] + dcases y <;> simp only + . omega + . rename_i y + have := @Nat.div_le_self x y.succ + omega + . simp_all + +@[nonlin_scalar_tac n % m] +theorem Int.emod_neg_or_pos_lt (n m : Int) : m ≤ 0 ∨ (0 ≤ n % m ∧ n % m < m) := by if h: 0 < m then right; constructor . apply Int.emod_nonneg; omega @@ -12,7 +42,7 @@ theorem Int.emod_of_pos_disj (n m : Int) : m ≤ 0 ∨ (0 ≤ n % m ∧ n % m < else left; omega @[nonlin_scalar_tac n / m] -theorem Int.div_of_pos_disj (n m : Int) : n < 0 ∨ m < 0 ∨ (0 ≤ n / m ∧ n / m ≤ n) := by +theorem Int.div_neg_or_pos_le (n m : Int) : n < 0 ∨ m < 0 ∨ (0 ≤ n / m ∧ n / m ≤ n) := by dcases hn: 0 ≤ n <;> dcases hm: 0 ≤ m <;> try simp_all right; right; constructor . 
apply Int.ediv_nonneg <;> omega @@ -29,15 +59,579 @@ theorem Int.pos_mul_pos_is_pos_disj (n m : Int) : m < 0 ∨ n < 0 ∨ 0 ≤ m * cases h: (n < 0 : Bool) <;> simp_all right; right; apply pos_mul_pos_is_pos <;> tauto +@[scalar_tac b.toNat] +theorem Bool.toNat_eq (b : Bool) : + (b = true ∧ b.toNat = 1) ∨ (b = false ∧ b.toNat = 0) := by + cases b <;> simp + -- Some tests section + example (x y : Int) (h : 0 ≤ x ∧ 0 ≤ y) : 0 ≤ x * y := by scalar_tac +nonLin + example (x y : Int) (h : 0 ≤ x ∧ 0 ≤ y) : 0 ≤ x / y := by scalar_tac +nonLin - -- Activate the rule set for non linear arithmetic - set_option scalarTac.nonLin true +end - example (x y : Int) (h : 0 ≤ x ∧ 0 ≤ y) : 0 ≤ x * y := by scalar_tac - example (x y : Int) (h : 0 ≤ x ∧ 0 ≤ y) : 0 ≤ x / y := by scalar_tac +theorem Int.le_div_eq_bound_imp_eq {x y bound : Int} + (hx : 0 < x) (hBound : x ≤ bound) (hy : 0 < y) (hEq : x / y = bound) : + x = bound ∧ y = 1 := by + have hx : x = bound := by + by_contra + have : x < bound := by omega + have := @Int.ediv_le_self x y (by omega) + omega + have hy : y = 1 := by + by_contra + have hLe := @Nat.div_le_div x.toNat bound.toNat y.toNat 2 (by omega) (by omega) (by simp) + zify at hLe + have : x.toNat = x := by omega + rw [this] at hLe + have : y.toNat = y := by omega + rw [this] at hLe + have : bound.toNat = bound := by omega + rw [this] at hLe + have : bound / 2 < bound := by + rw [Int.ediv_lt_iff_lt_mul] <;> omega + omega + simp [hx, hy] -end +/-! +We list here a few arithmetic facts that are not in Mathlib. + +TODO: PR for Mathlib? +-/ + +-- TODO: this should be in mathlib +theorem Int.gcd_add_mul_self (a b k : ℤ) : + Int.gcd a (b + k * a) = Int.gcd a b := by + apply Eq.symm + have h := @Int.gcd_greatest a (b + k * a) (Int.gcd a b) (by simp) + simp only [Nat.cast_inj] at h + apply h + . apply Int.gcd_dvd_left + . apply dvd_add + . apply Int.gcd_dvd_right + . rw [dvd_mul] + exists 1, gcd a b + simp only [isUnit_one, IsUnit.dvd, one_mul, true_and] + split_conjs + apply Int.gcd_dvd_left + rfl + . 
intro e div_a div_bk + have div_ka : e ∣ k * a := by + rw [dvd_mul] + exists 1, e + simp [*] + have div_b : e ∣ b := by + have h : e ∣ (b + k * a) + (- k * a) := by + apply dvd_add <;> simp [*] + simp only [neg_mul, add_neg_cancel_right] at h + apply h + apply Int.dvd_gcd <;> assumption + +-- TODO: this should be in mathlib +theorem Int.gcd_mod_same {a b : ℤ} : + Int.gcd (a % b) b = Int.gcd a b := by + have h1 : a % b = a - b * (a / b) := by + have heq := Int.ediv_add_emod a b + linarith + have h2 := Int.gcd_add_mul_self b a (- (a / b)) + rw [h1] + simp only [neg_mul] at h2 + conv at h2 => lhs; rw [Int.gcd_comm] + conv => rhs; rw [Int.gcd_comm] + convert h2 using 2 + ring_nf + +theorem cancel_right_div_gcd_pos {m a b : ℤ} + (c : Int) (hm : 0 < m) (hgcd : Int.gcd m c = 1) + (h : (a * c) % m = (b * c) % m) : + a % m = b % m := by + have heq := Int.ModEq.cancel_right_div_gcd hm h + simp only [hgcd, Nat.cast_one, EuclideanDomain.div_one] at heq + apply heq + +theorem cancel_right_div_gcd {m : ℤ} (a b c : Int) (hgcd : Int.gcd c m = 1) + (h : (a * c) % m = (b * c) % m) : + a % m = b % m := by + rw [Int.gcd_comm] at hgcd + if hm : m = 0 then + simp_all only [Int.gcd_zero_left, EuclideanDomain.mod_zero, mul_eq_mul_right_iff] + dcases hc : c = 0 <;> simp_all + else + if m ≤ 0 then + have hm' : 0 < -m := by scalar_tac + have hgcd' : Int.gcd (-m) c = 1 := by simp [hgcd] + have hf := @cancel_right_div_gcd_pos (-m) a b c hm' hgcd' + simp only [Int.emod_neg] at hf + apply hf + assumption + else + have hm : 0 < m := by simp_all + have heq := Int.ModEq.cancel_right_div_gcd hm h + simp only [hgcd, Nat.cast_one, EuclideanDomain.div_one] at heq + apply heq + +theorem cancel_left_div_gcd {m : ℤ} (a b c : Int) (hgcd : Int.gcd c m = 1) + (h : (c * a) % m = (c * b) % m) : + a % m = b % m := by + have heq := cancel_right_div_gcd a b c hgcd + apply heq + ring_nf at * + apply h + +theorem times_mod_imp_div_mod (n : ℕ) (a b c : ℤ) + (hdiv : a % b = 0) + (hgcd : Int.gcd b n = 1) + (heq : a % n = (c * b) % n) : + (a / b) % n = c % n := by + -- step 1: multiply by b on both sides + apply (cancel_right_div_gcd (a / b) c b (by assumption)) + -- step 2: simplify (... / b) * b + rw [Int.ediv_mul_cancel_of_emod_eq_zero hdiv] + -- End of the proof + apply heq + +/-! +Some theorems to reason about equalities of the shape: `a % n = b % n`. + +When encountering such an equality, a good proof strategy is to simply cast the integers into +`ZMod`, which, being a ring, is convenient to work in. Below, we list a few simp theorems which +are necessary for this to work. +-/ +theorem ZMod_int_cast_eq_int_cast_iff (n : ℕ) (a b : ℤ) : + ((a : ZMod n) = (b : ZMod n)) ↔ (a % n = b % n) := + ZMod.intCast_eq_intCast_iff a b n + +theorem eq_mod_iff_eq_ZMod (n : ℕ) (a b : ℤ) : + (a % n = b % n) ↔ ((a : ZMod n) = (b : ZMod n)) := by + rw [ZMod.intCast_eq_intCast_iff a b n] + tauto + +/-- The important theorem to convert a goal about equality modulo into a goal about the equality of two terms in `ZMod` -/ +theorem ZMod_eq_imp_mod_eq {n : ℕ} {a b : ℤ} + (h : (a : ZMod n) = (b : ZMod n)) : + a % n = b % n := + (@ZMod_int_cast_eq_int_cast_iff n a b).mp h + +-- TODO: restrict the set of theorems used by `simp`, do more things, etc.
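-- A minimal sketch of the intended workflow (the `zmodify` macro defined below packages the
-- same steps; the modulus 7 is only an illustrative value, not taken from the patch):
example (a b : ℤ) : (a + b) % (7 : ℕ) = (b + a) % (7 : ℕ) := by
  apply ZMod_eq_imp_mod_eq
  push_cast
  ring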
+macro "zmodify" : tactic => + `(tactic | + (apply ZMod_eq_imp_mod_eq; simp)) + +theorem mod_eq_imp_ZMod_eq {n : ℕ} {a b : ℤ} + (h : a % n = b % n) : + (a : ZMod n) = (b : ZMod n) := + (@ZMod_int_cast_eq_int_cast_iff n a b).mpr h + +theorem ZMod_val_injective {n : ℕ} [NeZero n] {a b : ZMod n} (h : a.val = b.val) : + a = b := + ZMod.val_injective n h + +theorem ZMod.mul_inv_eq_int_gcd {n : ℕ} (a : ℤ) : + (a : ZMod n) * (a : ZMod n)⁻¹ = Int.gcd a n := by + if hn : n = 0 then + simp only [hn, CharP.cast_eq_zero, Int.gcd_zero_right] + rw [ZMod.mul_inv_eq_gcd] + simp only [hn, Nat.gcd_zero_right] + + have h := @ZMod.intCast_eq_intCast_iff' (ZMod.val (a : ZMod n)) (Int.natAbs a) n + simp only [Int.cast_natCast, Int.natCast_natAbs, hn, CharP.cast_eq_zero, + EuclideanDomain.mod_zero] at h + unfold ZMod.val + rw [hn] + simp only [Nat.cast_inj] + rfl + else + have hn : 0 < n := by cases n <;> simp_all only [AddLeftCancelMonoid.add_eq_zero, one_ne_zero, + and_false, not_false_eq_true, lt_add_iff_pos_left, add_pos_iff, + zero_lt_one, or_true, not_true_eq_false] + rw [ZMod.mul_inv_eq_gcd] + rw [← Int.gcd_natCast_natCast] + have hnz : NeZero n := by simp [neZero_iff, *] + rw [← ZMod.cast_eq_val] + -- Simplify `↑↑a` + rw [ZMod.coe_intCast] + rw [Int.gcd_mod_same] + +/-- A theorem to work with division when converting integers to elements of `ZMod` -/ +theorem div_to_ZMod {n : ℕ} {a b : ℤ} [NeZero n] (hDiv : b ∣ a) (hgcd : Int.gcd b n = 1) : + ((a / b) : ZMod n) = (a : ZMod n) * (b : ZMod n)⁻¹ := by + have h : (a / b) % (n : Int) = ((a % (n : Int)) * (b : ZMod n)⁻¹.cast) % (n : Int) := by + apply times_mod_imp_div_mod + . rw [← Int.dvd_iff_emod_eq_zero] + assumption + . assumption + . apply ZMod_eq_imp_mod_eq + simp only [Int.cast_mul, ZMod.intCast_mod, ZMod.intCast_cast, ZMod.cast_id', id_eq] + rw [mul_assoc] + have := @ZMod.mul_inv_eq_int_gcd n b + rw [mul_comm] at this + rw [this] + rw [hgcd] + simp + have h1 := mod_eq_imp_ZMod_eq h + rw [h1] + simp + +theorem bmod_eq_emod_eq_iff (n: ℕ) (a b: ℤ) : + (a % n = b % n) ↔ (Int.bmod a n = Int.bmod b n) := by + simp only [Int.bmod] + apply Iff.intro <;> intro h + . rw [h] + . 
if h_a: a % n < (n + 1) / 2 then + if h_b: b % n < (n + 1) / 2 then + simp only [h_a, ↓reduceIte, h_b] at h + exact h + else + simp only [h_a, ↓reduceIte, h_b] at h + have ha' : 0 ≤ a % n := by apply Int.emod_nonneg; linarith + have hb' : b % n - n < 0 := by + have h : b % n < n := by apply Int.emod_lt_of_pos; linarith + linarith + linarith + else + if h_b: b % n < (n + 1) / 2 then + simp only [h_a, ↓reduceIte, h_b] at h + have ha' : 0 ≤ b % n := by apply Int.emod_nonneg; linarith + have hb' : a % n - n < 0 := by + have h : a % n < n := by apply Int.emod_lt_of_pos; linarith + linarith + linarith + else + simp only [h_a, ↓reduceIte, h_b, sub_left_inj] at h + exact h + +theorem ZMod_int_cast_eq_int_cast_bmod_iff (n : ℕ) (a b : ℤ) : + ((a : ZMod n) = (b : ZMod n)) ↔ (Int.bmod a n = Int.bmod b n) := by + apply Iff.trans + apply ZMod_int_cast_eq_int_cast_iff + apply bmod_eq_emod_eq_iff + +theorem ZMod_eq_imp_bmod_eq {n : ℕ} {a b : ℤ} + (h : (a : ZMod n) = (b : ZMod n)) : + Int.bmod a n = Int.bmod b n := + (@ZMod_int_cast_eq_int_cast_bmod_iff n a b).mp h + + +theorem ZMod.castInt_val_sub {n : ℕ} [inst: NeZero n] {a b : ZMod n} : + (a - b).val = (a.val - (b.val : Int)) % n := by + have : 0 ≤ ((a.val : Int) - (b.val : Int)) % n := by + apply Int.emod_nonneg + cases inst + omega + have : ((a.val : Int) - (b.val : Int)) % n < n := by + apply Int.emod_lt_of_pos + cases inst + omega + have := ZMod.val_add a (-b) + ring_nf at this + rw [this] + push_cast + rw [eq_mod_iff_eq_ZMod] + simp + ring_nf + +theorem ZMod.eq_iff_mod (p : ℕ) [NeZero p] (x y : ZMod p) : + x = y ↔ x.val = y.val := by + constructor + . simp +contextual + . apply ZMod_val_injective + +theorem BitVec.toNat_neq {n : ℕ} {x y : BitVec n} : x ≠ y ↔ x.toNat ≠ y.toNat := by + simp [BitVec.toNat_eq] + +/-! +Below we mark some theorems as `zify_simps` so that `zify` can convert (in-)equalities about +`BitVec` and `ZMod` to (in-)equalities about ℤ. +-/ +-- TODO: those theorems should rather be "natify" (introduce a tactic) +attribute [zify_simps] BitVec.toNat_eq BitVec.toNat_neq BitVec.lt_def BitVec.le_def + BitVec.toNat_umod BitVec.toNat_add BitVec.toNat_sub BitVec.toNat_ofNat + BitVec.toNat_and BitVec.toNat_or BitVec.toNat_xor +attribute [zify_simps] ZMod.eq_iff_mod ZMod.val_intCast ZMod.val_add ZMod.val_sub ZMod.val_mul + ZMod.castInt_val_sub + +theorem Int.bmod_pow2_eq_of_inBounds (n : ℕ) (x : Int) + (h0 : - 2 ^ n ≤ x) + (h1 : x < 2 ^ n) : + Int.bmod x (2 ^ (n + 1)) = x := by + rw [Int.bmod] + have hPowEq : (2^(n+1) : Int) = 2* (2^n) := by rw [Int.pow_succ'] + have : (2^(n + 1) + 1) / 2 = (2^n : Int) := by + have := @Int.add_ediv_of_dvd_left (2^(n+1)) 1 2 (by simp [hPowEq]) + simp [this] + rw [hPowEq] + simp + dcases hpos : 0 ≤ x + . have : x % (2^(n + 1) : Int) = x := by + apply Int.emod_eq_of_lt <;> omega + simp [this] + omega + . 
simp at hpos + have : x % (2^(n + 1) : Int) = x + 2^(n + 1) := by + have : 0 ≤ x + 2^(n+1) := by omega + have : x + 2^(n+1) < 2^(n+1) := by omega + have := @Int.emod_eq_of_lt (x + 2^(n + 1)) (2^(n+1)) (by omega) (by omega) + rw [← this] + have := Arith.eq_mod_iff_eq_ZMod (2^(n+1)) + simp at this + rw [this] + simp + norm_cast + simp + simp [this] + omega + +theorem Int.bmod_pow2_eq_of_inBounds' (n : ℕ) (x : Int) + (hn : n ≠ 0) + (h0 : - 2 ^ (n - 1) ≤ x) + (h1 : x < 2 ^ (n - 1)) : + Int.bmod x (2 ^ n) = x := by + have h := Int.bmod_pow2_eq_of_inBounds (n - 1) x + have : n - 1 + 1 = n := by omega + simp [this] at h + apply h <;> omega + +theorem Int.bmod_pow2_bounds (n : ℕ) (x : Int) : + - 2^(n-1) ≤ Int.bmod x (2^n) ∧ Int.bmod x (2^n) < 2^(n-1) := by + have h0 : 0 < 2^n := by simp + + have := @Int.le_bmod x (2^n) h0 + have := @Int.bmod_lt x (2^n) h0 + + have : -2^(n-1) ≤ -(((2^n) : Nat) : Int)/2 := by + dcases hn : n = 0 + . simp [hn] + . have : n - 1 + 1 = n := by omega + conv => rhs; rw [← this] + rw [Nat.pow_succ] + simp + + have : 2^(n-1) ≤ (2^n + 1) / 2 := by + dcases hn : n = 0 + . simp [hn] + . have : n - 1 + 1 = n := by omega + conv => rhs; rw [← this] + rw [Nat.pow_succ'] + have := @Nat.add_div_of_dvd_right (2 * 2 ^ (n - 1)) 1 2 (by simp) + rw [this] + simp + + omega + +theorem BitVec.bounds (n : ℕ) (x : BitVec n) : + - 2^(n-1) ≤ x.toInt ∧ x.toInt < 2^(n-1) := by + rw [BitVec.toInt_eq_toNat_bmod] + apply Int.bmod_pow2_bounds + + +theorem BitVec.toInt_neg_of_neg_eq_neg + {n} (x : BitVec n) (h : n ≠ 0) (h0 : - 2^(n - 1) < x.toInt) (h1 : x.toInt < 0) : + (-x).toInt = - x.toInt := by + simp only [BitVec.toInt_eq_toNat_bmod, BitVec.toNat_umod] + + have hmsb := @BitVec.msb_eq_toInt _ x + simp [h1] at hmsb + + have hmsb' := @BitVec.msb_eq_toNat _ x + simp [hmsb] at hmsb' + + have hx := @BitVec.toInt_eq_msb_cond _ x + simp [hmsb] at hx + + simp only [Int.bmod] + + have : x.toNat < 2^n := by cases x; simp + + have hPow : (2 ^ n + 1) / 2 = 2^(n - 1) := by + have : n = n - 1 + 1 := by omega + conv => lhs; rw [this] + rw [Nat.pow_succ'] + rw [Nat.add_div_of_dvd_right] <;> simp + + have : (2^n : Nat) = (2^n: Int) := by simp -- TODO: this is annoying! 
+ + have hxToNatModPow : (x.toNat : Int) % 2^n = x.toNat := by + apply Int.emod_eq_of_lt <;> omega + + have hPowMinusXMod : ↑(2 ^ n - x.toNat : Nat) % (2 ^ n : Int) = + (2 ^ n - x.toNat : Nat) := by + apply Int.emod_eq_of_lt <;> omega + + have : (((-x).toNat : Int) % (2 ^ n : Nat) < ((2 ^ n : Nat) + 1 : Int) / 2) := by + simp + zify at hPow + simp [hPow] + rw [hPowMinusXMod] + omega + simp only [this]; simp + + have : ¬ ((↑(x.toNat : Nat) % 2 ^ n : Int) < (2 ^ n + 1) / 2) := by + simp + zify at hPow + simp [hPow] + omega + simp only [this]; simp + + rw [hPowMinusXMod] + + rw [hxToNatModPow] + + omega + +theorem BitVec.toInt_neg_of_pos_eq_neg + {n} (x : BitVec n) (h : n ≠ 0) (h0 : 0 ≤ x.toInt) : + (-x).toInt = - x.toInt := by + simp only [BitVec.toInt_eq_toNat_bmod, BitVec.toNat_umod] + + have : -2^(n-1) ≤ x.toInt ∧ x.toInt < 2^(n-1) := by + apply BitVec.bounds + + have hNotNeg : ¬ x.toInt < 0 := by omega + + have hmsb := @BitVec.msb_eq_toInt _ x + simp [hNotNeg] at hmsb + + have hmsb' := @BitVec.msb_eq_toNat _ x + simp [hmsb] at hmsb' + + have hx := @BitVec.toInt_eq_msb_cond _ x + simp [hmsb] at hx + + simp + have h2n : (2^n: Int) = (2^n:Nat) := by simp -- TODO: this is annoying + rw [h2n] + rw [Int.emod_bmod] + + have : (2^n - x.toNat : Nat) = (2^n - x.toNat : Int) := by omega + rw [this]; clear this + + have hn : n - 1 + 1 = n := by omega + + have : (2^n - x.toNat : Int).bmod (2^n) = (-(x.toNat : Int)).bmod (2^n) := by + rw [h2n] + have : (2^n : Nat) - (x.toNat : Int) = -(x.toNat : Int) + (2^n : Nat) := by ring_nf + rw [this] + simp only [Int.bmod_add_cancel] + rw [this]; clear this + + have : (-(x.toNat : Int)).bmod (2^n) = -(x.toNat) := by + have := Int.bmod_pow2_eq_of_inBounds (n - 1) (-x.toNat) (by omega) (by omega) + rw [hn] at this + omega + rw [this]; clear this + + have : (x.toNat : Int).bmod (2^n) = x.toNat := by + have := Int.bmod_pow2_eq_of_inBounds (n - 1) x.toNat (by omega) (by omega) + rw [hn] at this + omega + rw [this] + +@[simp] +theorem Int.neg_tmod (x y : Int) : (- x).tmod y = - x.tmod y := by + unfold Int.tmod + dcases hx' : (-x) <;> dcases hx : x <;> dcases y <;> rename_i xn xn' yn <;> simp only + . dcases xn <;> simp_all + omega + . dcases xn <;> simp_all + omega + . simp + have : xn = xn' + 1 := by + have : x = - Int.ofNat xn := by omega + simp at this + omega + simp [this] + . simp + have : xn = xn' + 1 := by + have : x = - Int.ofNat xn := by omega + simp at this + omega + simp [this] + . simp + have : (xn + 1 : Int) = xn' := by + have : x = - Int.negSucc xn := by omega + simp at this + have : x = xn + 1 := by omega + simp at hx + omega + simp [this] + . simp + have : (xn + 1 : Int) = xn' := by + have : x = - Int.negSucc xn := by omega + simp at this + have : x = xn + 1 := by omega + simp at hx + omega + simp [this] + . dcases xn <;> simp_all + omega + . dcases xn <;> simp_all + omega + +theorem Int.tmod_ge_of_neg (x y : Int) (hx : x < 0) : + x ≤ x.tmod y := by + have : (-x).tmod y = (-x) % y := by + dcases hy : 0 ≤ y + . rw [Int.tmod_eq_emod] <;> try omega + . have : (-x).tmod y = (-x).tmod (-y) := by simp + rw [this] + rw [Int.tmod_eq_emod] <;> try omega + simp + have : x.tmod y = - ((-x) % y) := by + have h : x.tmod y = - (-x).tmod y := by simp + rw [h, this] + have : (-x) % y ≤ (-x) := by + dcases hy : 0 ≤ y + . have hIneq := Nat.mod_le (-x).toNat y.toNat + zify at hIneq + have : (-x).toNat = -x := by omega + rw [this] at hIneq + have : y.toNat = y := by omega + rw [this] at hIneq + omega + . 
have hIneq := Nat.mod_le (-x).toNat (-y).toNat + zify at hIneq + have : (-x).toNat = -x := by omega + rw [this] at hIneq + have : (-y).toNat = -y := by omega + rw [this] at hIneq + simp at hIneq + omega + omega + +-- TODO: move to Mathlib +theorem Int.bmod_bmod_of_dvd (n : Int) {m k : Nat} (hDiv : m ∣ k) : + (n.bmod k).bmod m = n.bmod m := by + conv => lhs; arg 1; simp only [Int.bmod] + rw [← bmod_eq_emod_eq_iff] + have h : n % (k : Int) % (m : Int) = n % (m : Int) := by + rw [Int.emod_emod_of_dvd] + simp [Int.ofNat_dvd_left] + apply hDiv + split_ifs + . apply h + . rw [Int.sub_emod] + have : (k : Int) % m = 0 := by + apply Int.emod_eq_zero_of_dvd + simp [Int.ofNat_dvd_left] + apply hDiv + simp [this, h] + +@[simp] +theorem Int.mod_toNat_val (n m : Int) (h : m ≠ 0) : + (n % m).toNat = n % m := by + simp only [Int.ofNat_toNat, ne_eq, h, not_false_eq_true, Int.emod_nonneg, sup_of_le_left] + +theorem Nat.lt_iff_BitVec_ofNat_lt (n : Nat) (x y : Nat) (hx : x < 2^n) (hy : y < 2^n) : + x < y ↔ BitVec.ofNat n x < BitVec.ofNat n y := by + have := Nat.mod_eq_of_lt hx + have := Nat.mod_eq_of_lt hy + simp [*] + +theorem Nat.le_iff_BitVec_ofNat_le (n : Nat) (x y : Nat) (hx : x < 2^n) (hy : y < 2^n) : + x ≤ y ↔ BitVec.ofNat n x ≤ BitVec.ofNat n y := by + have := Nat.mod_eq_of_lt hx + have := Nat.mod_eq_of_lt hy + simp [*] -end Aeneas.ScalarTac +end Aeneas.Arith diff --git a/backends/lean/Aeneas/BvTac.lean b/backends/lean/Aeneas/BvTac.lean new file mode 100644 index 00000000..69dae1da --- /dev/null +++ b/backends/lean/Aeneas/BvTac.lean @@ -0,0 +1 @@ +import Aeneas.BvTac.BvTac diff --git a/backends/lean/Aeneas/BvTac/BvTac.lean b/backends/lean/Aeneas/BvTac/BvTac.lean new file mode 100644 index 00000000..b0577357 --- /dev/null +++ b/backends/lean/Aeneas/BvTac/BvTac.lean @@ -0,0 +1,75 @@ +import Aeneas.Bvify + +namespace Aeneas.BvTac + +open Lean Lean.Meta Lean.Parser.Tactic Lean.Elab.Tactic +open Bvify Utils + +def disjConj : Std.HashSet Name := Std.HashSet.ofList [ + ``And, ``Or +] + +def arithConsts : Std.HashSet Name := Std.HashSet.ofList [ + ``BEq.beq, ``LT.lt, ``LE.le, ``GT.gt, ``GE.ge +] + +partial def getn : TacticM Expr := do + let mgoal ← getMainGoal + let goalTy ← mgoal.getType + let raiseError : TacticM Expr := + throwError "The goal doesn't have the proper shape: expected a proposition only involving (in-)equality between bitvectors" + let fromBitVecTy (ty : Expr) : TacticM Expr := + ty.consumeMData.withApp fun _ args => do + if args.size == 1 then + pure args[0]! + else + raiseError + let rec aux (ty : Expr) : TacticM Expr := do + ty.consumeMData.withApp fun f args => do + if f.isConst then + let f := f.constName! + if f == ``Eq ∧ args.size == 3 then + fromBitVecTy args[0]! + else if f ∈ disjConj ∧ args.size == 2 then + aux args[0]! + else if f ∈ arithConsts ∧ args.size == 4 then + fromBitVecTy args[0]! + else + raiseError + else + raiseError + aux goalTy + +partial def bvTacPreprocess : TacticM Unit := do + /- First figure out the bit width by looking at the goal -/ + let n ← getn + /- Then apply bvify -/ + bvifyTac n Utils.Location.wildcard + +elab "bv_tac_preprocess" : tactic => + bvTacPreprocess + +open Lean.Elab.Tactic.BVDecide.Frontend in +elab "bv_tac" cfg:Parser.Tactic.optConfig : tactic => + withMainContext do + -- Preprocess + bvTacPreprocess + -- Call bv_decide + let cfg ← elabBVDecideConfig cfg + IO.FS.withTempFile fun _ lratFile => do + let cfg ← BVDecide.Frontend.TacticContext.new lratFile cfg + liftMetaFinishingTactic fun g => do + discard <| bvDecide g cfg + +/-!
+# Tests +-/ +open Std + +example (x y : U8) (h : x.val ≤ y.val) : x.bv ≤ y.bv := by + bv_tac + +example (x : U32) (h : x.val < 3329) : x.bv % 3329#32 = x.bv := by + bv_tac + +end Aeneas.BvTac diff --git a/backends/lean/Aeneas/Bvify.lean b/backends/lean/Aeneas/Bvify.lean new file mode 100644 index 00000000..c5c2fbae --- /dev/null +++ b/backends/lean/Aeneas/Bvify.lean @@ -0,0 +1 @@ +import Aeneas.Bvify.Bvify diff --git a/backends/lean/Aeneas/Bvify/Bvify.lean b/backends/lean/Aeneas/Bvify/Bvify.lean new file mode 100644 index 00000000..7a7ab9c1 --- /dev/null +++ b/backends/lean/Aeneas/Bvify/Bvify.lean @@ -0,0 +1,97 @@ +import Mathlib.Tactic.Basic +import Mathlib.Tactic.Attr.Register +import Mathlib.Data.Int.Cast.Basic +import Mathlib.Order.Basic +import Aeneas.Bvify.Init +import Aeneas.Arith.Lemmas +import Aeneas.Std.Scalar + +/-! +# `bvify` tactic + +The `bvify` tactic is used to shift propositions about, e.g., `Nat`, to `BitVec`. +This tactic is adapted from `zify`. +-/ + +namespace Aeneas.Bvify + +open Lean Lean.Meta Lean.Parser.Tactic Lean.Elab.Tactic +open Arith Std + +attribute [bvify_simps] ge_iff_le gt_iff_lt UScalar.BitVec_ofNat_val +attribute [bvify_simps] UScalar.BitVec_ofNat_val_eq + U8.BitVec_ofNat_val_eq U16.BitVec_ofNat_val_eq U32.BitVec_ofNat_val_eq + U64.BitVec_ofNat_val_eq U128.BitVec_ofNat_val_eq Usize.BitVec_ofNat_val_eq + U8.lt_succ_max U16.lt_succ_max U32.lt_succ_max U64.lt_succ_max U128.lt_succ_max + U8.le_max U16.le_max U32.le_max U64.le_max U128.le_max + +syntax (name := bvify) "bvify" num (simpArgs)? (location)? : tactic + +macro_rules +| `(tactic| bvify $n $[[$simpArgs,*]]? $[at $location]?) => + let args := simpArgs.map (·.getElems) |>.getD #[] + `(tactic| + simp -decide (maxDischargeDepth := 1) only [ + Nat.reducePow, Nat.reduceLT, + Nat.lt_iff_BitVec_ofNat_lt $n, Nat.le_iff_BitVec_ofNat_le $n, + bvify_simps, push_cast, $args,*] $[at $location]?) + +def bvifyTac (n : Expr) (loc : Utils.Location) : TacticM Unit := do + let simpTheorems ← bvifySimpExt.getTheorems + let simprocs := [``Nat.reducePow, ``Nat.reduceLT] + let addThm (thName : Name) : TacticM FVarId := do + let thm ← mkAppM thName #[n] + let thm ← Utils.addDeclTac (← Utils.mkFreshAnonPropUserName) thm (← inferType thm) (asLet := false) + pure thm.fvarId! + let lt_iff ← addThm ``Nat.lt_iff_BitVec_ofNat_lt + let le_iff ← addThm ``Nat.le_iff_BitVec_ofNat_le + withMainContext do + Utils.simpAt true {maxDischargeDepth := 1} simprocs [simpTheorems] [] [] [lt_iff, le_iff] loc + Utils.clearFvarIds #[lt_iff, le_iff] + +elab "bvify'" n:term : tactic => do + bvifyTac (← Elab.Term.elabTerm n (Expr.const ``Nat [])) Utils.Location.wildcard + +/-- The `Simp.Context` generated by `bvify`. -/ +def mkBvifyContext (simpArgs : Option (Syntax.TSepArray `Lean.Parser.Tactic.simpStar ",")) : + TacticM MkSimpContextResult := do + let args := simpArgs.map (·.getElems) |>.getD #[] + mkSimpContext + (← `(tactic| simp -decide (maxDischargeDepth := 1) only [bvify_simps, push_cast, $args,*])) false + +/-- A variant of `applySimpResultToProp` that cannot close the goal, but does not need a meta +variable and returns a tuple of a proof and the corresponding simplified proposition. -/ +def applySimpResultToProp' (proof : Expr) (prop : Expr) (r : Simp.Result) : MetaM (Expr × Expr) := + do + match r.proof? 
with + | some eqProof => return (← mkExpectedTypeHint (← mkEqMP eqProof proof) r.expr, r.expr) + | none => + if r.expr != prop then + return (← mkExpectedTypeHint proof r.expr, r.expr) + else + return (proof, r.expr) + +/-- Translate a proof and the proposition into a natified form. -/ +def bvifyProof (simpArgs : Option (Syntax.TSepArray `Lean.Parser.Tactic.simpStar ",")) + (proof : Expr) (prop : Expr) : TacticM (Expr × Expr) := do + let ctx_result ← mkBvifyContext simpArgs + let (r, _) ← simp prop ctx_result.ctx + applySimpResultToProp' proof prop r + +example (x y : U8) (h : x.val < y.val) : x.bv < y.bv := by + bvify 8 at h + apply h + +example (x y : U8) (h : x.val < y.val) : x.bv < y.bv := by + bvify' 8 + apply h + +example (x : U8) (h : x.val < 32) : x.bv < 32#8 := by + bvify 8 at h + apply h + +example (x : U8) (h : x.val < 32) : x.bv < 32#8 := by + bvify' 8 + apply h + +end Aeneas.Bvify diff --git a/backends/lean/Aeneas/Bvify/Init.lean b/backends/lean/Aeneas/Bvify/Init.lean new file mode 100644 index 00000000..75477990 --- /dev/null +++ b/backends/lean/Aeneas/Bvify/Init.lean @@ -0,0 +1,11 @@ +import Aeneas.Extensions +open Lean Meta + +namespace Aeneas.Bvify + +/-- The `bvify_simps` simp attribute. -/ +initialize bvifySimpExt : SimpExtension ← + registerSimpAttr `bvify_simps "\ + The `bvify_simps` attribute registers simp lemmas to be used by `bvify`." + +end Aeneas.Bvify diff --git a/backends/lean/Aeneas/Diverge/Elab.lean b/backends/lean/Aeneas/Diverge/Elab.lean index 372e8602..8a9417ec 100644 --- a/backends/lean/Aeneas/Diverge/Elab.lean +++ b/backends/lean/Aeneas/Diverge/Elab.lean @@ -29,12 +29,6 @@ def appendToName (n : Name) (s : String) : Name := def UnitType := Expr.const ``PUnit [Level.succ .zero] def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero] -def mkProdType (x y : Expr) : MetaM Expr := - mkAppM ``Prod #[x, y] - -def mkProd (x y : Expr) : MetaM Expr := - mkAppM ``Prod.mk #[x, y] - def mkInOutTy (x y z : Expr) : MetaM Expr := do mkAppM ``FixII.mk_in_out_ty #[x, y, z] @@ -46,318 +40,6 @@ def getResultTy (ty : Expr) : MetaM Expr := else pure (args.get! 0) -/- Deconstruct a sigma type. - - For instance, deconstructs `(a : Type) × List a` into - `Type` and `λ a => List a`. - -/ -def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do - ty.withApp fun f args => do - if ¬ f.isConstOf ``Sigma ∨ args.size ≠ 2 then - throwError "Invalid argument to getSigmaTypes: {ty}" - else - pure (args.get! 0, args.get! 1) - -/- Make a sigma type. - - `x` should be a variable, and `ty` and type which (might) uses `x` - -/ -def mkSigmaType (x : Expr) (sty : Expr) : MetaM Expr := do - trace[Diverge.def.sigmas] "mkSigmaType: {x} {sty}" - let alpha ← inferType x - let beta ← mkLambdaFVars #[x] sty - trace[Diverge.def.sigmas] "mkSigmaType: ({alpha}) ({beta})" - mkAppOptM ``Sigma #[some alpha, some beta] - -/- Generate a Sigma type from a list of *variables* (all the expressions - must be variables). - - Example: - - xl = [(a:Type), (ls:List a), (i:Int)] - - Generates: - `(a:Type) × (ls:List a) × (i:Int)` - - -/ -def mkSigmasType (xl : List Expr) : MetaM Expr := - match xl with - | [] => do - trace[Diverge.def.sigmas] "mkSigmasType: []" - pure (Expr.const ``PUnit [Level.succ .zero]) - | [x] => do - trace[Diverge.def.sigmas] "mkSigmasType: [{x}]" - let ty ← inferType x - pure ty - | x :: xl => do - trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]" - let sty ← mkSigmasType xl - mkSigmaType x sty - -/- Generate a product type from a list of *variables* (this is similar to `mkSigmas`). 
- - Example: - - xl = [(ls:List a), (i:Int)] - - Generates: - `List a × Int` - -/ -def mkProdsType (xl : List Expr) : MetaM Expr := - match xl with - | [] => do - trace[Diverge.def.prods] "mkProdsType: []" - pure (Expr.const ``PUnit [Level.succ .zero]) - | [x] => do - trace[Diverge.def.prods] "mkProdsType: [{x}]" - let ty ← inferType x - pure ty - | x :: xl => do - trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]" - let ty ← inferType x - let xl_ty ← mkProdsType xl - mkAppM ``Prod #[ty, xl_ty] - -/- Split the input arguments between the types and the "regular" arguments. - - We do something simple: we treat an input argument as an - input type iff it appears in the type of the following arguments. - - Note that what really matters is that we find the arguments which appear - in the output type. - - Also, we stop at the first input that we treat as an - input type. - -/ -def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × Array Expr) := do - -- Look for the first parameter which appears in the subsequent parameters - let rec splitAux (in_tys : List Expr) : MetaM (Std.HashSet FVarId × List Expr × List Expr) := - match in_tys with - | [] => do - let fvars ← getFVarIds (← inferType out_ty) - pure (fvars, [], []) - | ty :: in_tys => do - let (fvars, in_tys, in_args) ← splitAux in_tys - -- Have we already found where to split between type variables/regular - -- variables? - if ¬ in_tys.isEmpty then - -- The fvars set is now useless: no need to update it anymore - pure (fvars, ty :: in_tys, in_args) - else - -- Check if ty appears in the set of free variables: - let ty_id := ty.fvarId! - if fvars.contains ty_id then - -- We must split here. Note that we don't need to update the fvars - -- set: it is not useful anymore - pure (fvars, [ty], in_args) - else - -- We must split later: update the fvars set - let fvars := fvars.insertMany (← getFVarIds (← inferType ty)) - pure (fvars, [], ty :: in_args) - let (_, in_tys, in_args) ← splitAux in_tys.toList - pure (Array.mk in_tys, Array.mk in_args) - -/- Apply a lambda expression to some arguments, simplifying the lambdas -/ -def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do - lambdaTelescopeN e xs.size fun vars body => - -- Create the substitution - let s : Std.HashMap FVarId Expr := Std.HashMap.ofList (List.zip (vars.toList.map Expr.fvarId!) xs.toList) - -- Substitute in the body - pure (body.replace fun e => - match e with - | Expr.fvar fvarId => match s.get? fvarId with - | none => e - | some v => v - | _ => none) - -/- Group a list of expressions into a dependent tuple. - - Example: - xl = [`a : Type`, `ls : List a`] - returns: - `⟨ (a:Type), (ls: List a) ⟩` - - We need the type argument because as the elements in the tuple are - "concrete", we can't in all generality figure out the type of the tuple. 
- - Example: - `⟨ True, 3 ⟩ : (x : Bool) × (if x then Int else Unit)` - -/ -def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := - match xl with - | [] => do - trace[Diverge.def.sigmas] "mkSigmasVal: []" - pure (Expr.const ``PUnit.unit [Level.succ .zero]) - | [x] => do - trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" - pure x - | fst :: xl => do - trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" - -- Deconstruct the type - let (alpha, beta) ← getSigmaTypes ty - -- Compute the "second" field - -- Specialize beta for fst - let nty ← applyLambdaToArgs beta #[fst] - -- Recursive call - let snd ← mkSigmasVal nty xl - -- Put everything together - trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" - mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] - -/- Group a list of expressions into a (non-dependent) tuple -/ -def mkProdsVal (xl : List Expr) : MetaM Expr := - match xl with - | [] => - pure (Expr.const ``PUnit.unit [Level.succ .zero]) - | [x] => do - pure x - | x :: xl => do - let xl ← mkProdsVal xl - mkAppM ``Prod.mk #[x, xl] - -def mkAnonymous (s : String) (i : Nat) : Name := - .num (.str .anonymous s) i - -/- Given a list of values `[x0:ty0, ..., xn:ty1]`, where every `xi` might use the previous - `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following - expression: - ``` - fun x:((x0:ty0) × ... × (xn:tyn) => -- **Dependent** tuple - match x with - | (x0, ..., xn) => out - ``` - - The `index` parameter is used for naming purposes: we use it to numerotate the - bound variables that we introduce. - - We use this function to currify functions (the function bodies given to the - fixed-point operator must be unary functions). - - Example: - ======== - - xl = `[a:Type, ls:List a, i:Int]` - - out = `a` - - index = 0 - - generates (getting rid of most of the syntactic sugar): - ``` - λ scrut0 => match scrut0 with - | Sigma.mk x scrut1 => - match scrut1 with - | Sigma.mk ls i => - a - ``` --/ -partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := - match xl with - | [] => do - -- This would be unexpected - throwError "mkSigmasMatch: empty list of input parameters" - | [x] => do - -- In the example given for the explanations: this is the inner match case - trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" - mkLambdaFVars #[x] out - | fst :: xl => do - /- In the example given for the explanations: this is the outer match case - Remark: for the naming purposes, we use the same convention as for the - fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at - those definitions might help) - - We want to build the match expression: - ``` - λ scrut => - match scrut with - | Sigma.mk x ... -- the hole is given by a recursive call on the tail - ``` -/ - trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" - let alpha ← inferType fst - let snd_ty ← mkSigmasType xl - let beta ← mkLambdaFVars #[fst] snd_ty - let snd ← mkSigmasMatch xl out (index + 1) - let mk ← mkLambdaFVars #[fst] snd - -- Introduce the "scrut" variable - let scrut_ty ← mkSigmaType fst snd_ty - withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do - trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" - -- TODO: make the computation of the motive more efficient - let motive ← do - let out_ty ← inferType out - match out_ty with - | .sort _ | .lit _ | .const .. 
=> - -- The type of the motive doesn't depend on the scrutinee - mkLambdaFVars #[scrut] out_ty - | _ => - -- The type of the motive *may* depend on the scrutinee - -- TODO: make this more efficient (we could change the output type of - -- mkSigmasMatch - mkSigmasMatch (fst :: xl) out_ty - -- The final expression: putting everything together - trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" - let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] - -- Abstracting the "scrut" variable - let sm ← mkLambdaFVars #[scrut] sm - trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" - pure sm - -/- This is similar to `mkSigmasMatch`, but with non-dependent tuples - - Remark: factor out with `mkSigmasMatch`? This is extremely similar. --/ -partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := - match xl with - | [] => do - -- This would be unexpected - throwError "mkProdsMatch: empty list of input parameters" - | [x] => do - -- In the example given for the explanations: this is the inner match case - trace[Diverge.def.prods] "mkProdsMatch: [{x}]" - mkLambdaFVars #[x] out - | fst :: xl => do - trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]" - let alpha ← inferType fst - let beta ← mkProdsType xl - let snd ← mkProdsMatch xl out (index + 1) - let mk ← mkLambdaFVars #[fst] snd - -- Introduce the "scrut" variable - let scrut_ty ← mkProdType alpha beta - withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do - trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})" - -- TODO: make the computation of the motive more efficient - let motive ← do - let out_ty ← inferType out - mkLambdaFVars #[scrut] out_ty - -- The final expression: putting everything together - trace[Diverge.def.prods] "mkProdsMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" - let sm ← mkAppOptM ``Prod.casesOn #[some alpha, some beta, some motive, some scrut, some mk] - -- Abstracting the "scrut" variable - let sm ← mkLambdaFVars #[scrut] sm - trace[Diverge.def.prods] "mkProdsMatch: sm: {sm}" - pure sm - -/- Same as `mkSigmasMatch` but also accepts an empty list of inputs, in which case - it generates the expression: - ``` - λ () => e - ``` -/ -def mkSigmasMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := - if xl.isEmpty then do - let scrut_ty := Expr.const ``PUnit [Level.succ .zero] - withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do - mkLambdaFVars #[scrut] out - else - mkSigmasMatch xl out - -/- Same as `mkProdsMatch` but also accepts an empty list of inputs, in which case - it generates the expression: - ``` - λ () => e - ``` -/ -def mkProdsMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := - if xl.isEmpty then do - let scrut_ty := Expr.const ``PUnit [Level.succ .zero] - withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do - mkLambdaFVars #[scrut] out - else - mkProdsMatch xl out - /- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) @@ -1598,6 +1280,18 @@ namespace Tests --set_option trace.Diverge false + /-- + info: Aeneas.Diverge.Tests.list_nth.unfold.{u} {a : Type u} (ls : List a) (i : ℤ) : + list_nth ls i = + match ls with + | [] => Result.fail Error.panic + | x :: ls => + if i = 0 then pure x + else do + let __do_lift ← list_nth ls (i - 1) + pure 
__do_lift + -/ + #guard_msgs in #check list_nth.unfold example {a: Type} (ls : List a) : @@ -1633,6 +1327,23 @@ namespace Tests let ls ← back ret return (x :: ls))) + /-- + info: Aeneas.Diverge.Tests.list_nth_with_back.unfold {a : Type} (ls : List a) (i : ℤ) : + list_nth_with_back ls i = + match ls with + | [] => Result.fail Error.panic + | x :: ls => + if i = 0 then pure (x, fun ret => pure (ret :: ls)) + else do + let __discr ← list_nth_with_back ls (i - 1) + match __discr with + | (x, back) => + pure + (x, fun ret => do + let ls ← back ret + pure (x :: ls)) + -/ + #guard_msgs in #check list_nth_with_back.unfold mutual @@ -1643,7 +1354,26 @@ namespace Tests if i = 0 then return false else return (← is_even (i - 1)) end + /-- + info: Aeneas.Diverge.Tests.is_even.unfold (i : ℤ) : + is_even i = + if i = 0 then pure true + else do + let __do_lift ← is_odd (i - 1) + pure __do_lift + -/ + #guard_msgs in #check is_even.unfold + + /-- + info: Aeneas.Diverge.Tests.is_odd.unfold (i : ℤ) : + is_odd i = + if i = 0 then pure false + else do + let __do_lift ← is_even (i - 1) + pure __do_lift + -/ + #guard_msgs in #check is_odd.unfold mutual @@ -1654,7 +1384,22 @@ namespace Tests if i > 20 then foo (i / 20) else .ok 42 end + /-- + info: Aeneas.Diverge.Tests.foo.unfold (i : ℤ) : + foo i = + if i > 10 then do + let __do_lift ← foo (i / 10) + let __do_lift_1 ← bar i + pure (__do_lift + __do_lift_1) + else bar 10 + -/ + #guard_msgs in #check foo.unfold + + /-- + info: Aeneas.Diverge.Tests.bar.unfold (i : ℤ) : bar i = if i > 20 then foo (i / 20) else Result.ok 42 + -/ + #guard_msgs in #check bar.unfold -- Testing dependent branching and let-bindings @@ -1664,6 +1409,15 @@ namespace Tests let b := true return b + /-- + info: Aeneas.Diverge.Tests.isNonZero.unfold (i : ℤ) : + isNonZero i = + if _h : i = 0 then pure false + else + let b := true; + pure b + -/ + #guard_msgs in #check isNonZero.unfold -- Testing let-bindings @@ -1673,6 +1427,13 @@ namespace Tests then Result.ok True else Result.ok False + /-- + info: Aeneas.Diverge.Tests.iInBounds.unfold {a : Type} (ls : List a) (i : ℤ) : + iInBounds ls i = + let i0 := ls.length; + if i < ↑i0 then Result.ok (decide True) else Result.ok (decide False) + -/ + #guard_msgs in #check iInBounds.unfold divergent def isCons @@ -1682,6 +1443,15 @@ namespace Tests | [] => Result.ok False | _ :: _ => Result.ok True + /-- + info: Aeneas.Diverge.Tests.isCons.unfold {a : Type} (ls : List a) : + isCons ls = + let ls1 := ls; + match ls1 with + | [] => Result.ok (decide False) + | head :: tail => Result.ok (decide True) + -/ + #guard_msgs in #check isCons.unfold -- Testing what happens when we use concrete arguments in dependent tuples @@ -1691,6 +1461,10 @@ namespace Tests := test1 Option.none () + /-- + info: Aeneas.Diverge.Tests.test1.unfold (x✝ : Option Bool) (x✝¹ : Unit) : test1 x✝ x✝¹ = test1 none () + -/ + #guard_msgs in #check test1.unfold -- Testing a degenerate case @@ -1699,6 +1473,14 @@ namespace Tests let _ ← infinite_loop Result.ok () + /-- + info: Aeneas.Diverge.Tests.infinite_loop.unfold : + infinite_loop = do + let __discr ← infinite_loop + let x : Unit := __discr + Result.ok () + -/ + #guard_msgs in #check infinite_loop.unfold -- Another degenerate case @@ -1708,6 +1490,13 @@ namespace Tests infinite_loop1_call infinite_loop1 + /-- + info: Aeneas.Diverge.Tests.infinite_loop1.unfold : + infinite_loop1 = do + infinite_loop1_call + infinite_loop1 + -/ + #guard_msgs in #check infinite_loop1.unfold /- Tests with higher-order functions -/ @@ -1726,6 +1515,16 @@ 
namespace Tests let tl ← map id tl .ok (.node tl) + /-- + info: Aeneas.Diverge.Tests.id.unfold.{u} {a : Type u} (t : Tree a) : + id t = + match t with + | Tree.leaf x => Result.ok (Tree.leaf x) + | Tree.node tl => do + let tl ← map id tl + Result.ok (Tree.node tl) + -/ + #guard_msgs in #check id.unfold divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) := @@ -1736,6 +1535,16 @@ namespace Tests let tl ← map (fun x => id1 x) tl .ok (.node tl) + /-- + info: Aeneas.Diverge.Tests.id1.unfold.{u} {a : Type u} (t : Tree a) : + id1 t = + match t with + | Tree.leaf x => Result.ok (Tree.leaf x) + | Tree.node tl => do + let tl ← map (fun x => id1 x) tl + Result.ok (Tree.node tl) + -/ + #guard_msgs in #check id1.unfold divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) := @@ -1746,6 +1555,22 @@ namespace Tests let tl ← map (fun x => do let _ ← id2 x; id2 x) tl .ok (.node tl) + /-- + info: Aeneas.Diverge.Tests.id2.unfold.{u} {a : Type u} (t : Tree a) : + id2 t = + match t with + | Tree.leaf x => Result.ok (Tree.leaf x) + | Tree.node tl => do + let tl ← + map + (fun x => do + let __discr ← id2 x + let x_1 : Tree a := __discr + id2 x) + tl + Result.ok (Tree.node tl) + -/ + #guard_msgs in #check id2.unfold divergent def incr (t : Tree Nat) : Result (Tree Nat) := @@ -1766,6 +1591,18 @@ namespace Tests let tl ← map f tl .ok (.node tl) + /-- + info: Aeneas.Diverge.Tests.id3.unfold (t : Tree ℕ) : + id3 t = + match t with + | Tree.leaf x => Result.ok (Tree.leaf (x + 1)) + | Tree.node tl => + let f := id3; + do + let tl ← map f tl + Result.ok (Tree.node tl) + -/ + #guard_msgs in #check id3.unfold /- diff --git a/backends/lean/Aeneas/FSimp.lean b/backends/lean/Aeneas/FSimp.lean new file mode 100644 index 00000000..5d2def0e --- /dev/null +++ b/backends/lean/Aeneas/FSimp.lean @@ -0,0 +1,12 @@ +import Lean + +/-! "Fast" simp. + +A version of simp with `maxDischargeDepth` set by default to 1: this is dramatically faster. +-/ + +declare_simp_like_tactic fsimp "fsimp " (maxDischargeDepth := 1) +declare_simp_like_tactic (all := true) fsimp_all "fsimp_all " (maxDischargeDepth := 1) + +example : True := by fsimp +example (x y z : Nat) (p : Nat → Prop) (h0 : p x) (h1 : p x → p y) (h2 : p y → p z) : p z := by fsimp_all diff --git a/backends/lean/Aeneas/List/List.lean b/backends/lean/Aeneas/List/List.lean index 46e874ce..56eae709 100644 --- a/backends/lean/Aeneas/List/List.lean +++ b/backends/lean/Aeneas/List/List.lean @@ -1,5 +1,6 @@ /- Complementary functions and lemmas for the `List` type -/ +import Mathlib.Data.List.GetD import Aeneas.ScalarTac import Aeneas.Utils import Aeneas.SimpLemmas @@ -10,54 +11,17 @@ open Aeneas open Aeneas.ScalarTac open Aeneas.Simp -def indexOpt (ls : List α) (i : Nat) : Option α := - match ls with - | [] => none - | hd :: tl => if i = 0 then some hd else indexOpt tl (i - 1) - -@[simp] theorem indexOpt_nil : indexOpt ([] : List α) i = none := by simp [indexOpt] -@[simp] theorem indexOpt_zero_cons : indexOpt ((x :: tl) : List α) 0 = some x := by simp [indexOpt] -@[simp] theorem indexOpt_nzero_cons (hne : Nat.not_eq i 0) : indexOpt ((x :: tl) : List α) i = indexOpt tl (i - 1) := by simp [indexOpt]; intro; simp_all +attribute [scalar_tac_simp] List.length_nil List.length_cons List.length_append List.length_reverse + List.get!_eq_getElem! List.get?_eq_getElem? 
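-- A minimal sketch of what registering these `scalar_tac_simp` lemmas is meant to enable,
-- assuming `scalar_tac` rewrites with this simp set before running its arithmetic solver
-- (illustrative example only, not taken from the patch):
example (x : Nat) (l l' : List Nat) (h : l.length ≤ l'.length) :
    (x :: l).length ≤ (l' ++ [x]).length := by
  scalar_tac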
-def index [Inhabited α] (ls : List α) (i : Nat) : α := - match ls with - | [] => Inhabited.default - | x :: tl => - if i = 0 then x else index tl (i - 1) +def set_opt (l : List α) (i : Nat) (x : Option α) : List α := + match l with + | [] => l + | hd :: tl => if i = 0 then Option.getD x hd :: tl else hd :: set_opt tl (i-1) x -@[simp] theorem index_nil [Inhabited α] : @index α _ [] i = Inhabited.default := by simp [index] -@[simp] theorem index_zero_cons (x : α) (tl : List α) [Inhabited α] : index ((x :: tl) : List α) 0 = x := by simp [index] -@[simp] theorem index_nzero_cons (x : α) (tl : List α) (i : Nat) [Inhabited α] (hne : Nat.not_eq i 0) : index ((x :: tl) : List α) i = index tl (i - 1) := by simp [index]; intro; simp_all - -theorem indexOpt_bounds (ls : List α) (i : Nat) : - ls.indexOpt i = none ↔ i < 0 ∨ ls.length ≤ i := by - match ls with - | [] => simp - | _ :: tl => - have := indexOpt_bounds tl (i - 1) - if h: i = 0 then simp [*] - else - simp [*] - omega - -theorem indexOpt_eq_index [Inhabited α] (ls : List α) (i : Nat) : - i < ls.length → - ls.indexOpt i = some (ls.index i) := - match ls with - | [] => by simp - | hd :: tl => - if h: i = 0 then - by simp [*] - else by - have hi := indexOpt_eq_index tl (i - 1) - simp [*]; intros - apply hi; int_tac - --- Remark: the list is unchanged if the index is not in bounds -def update (ls : List α) (i : Nat) (y : α) : List α := - match ls with - | [] => [] - | x :: tl => if i = 0 then y :: tl else x :: update tl (i - 1) y +attribute [simp] getElem?_cons_zero getElem!_cons_zero +@[simp] theorem getElem?_cons_nzero (hne : Nat.not_eq i 0) : getElem? ((x :: tl) : List α) i = getElem? tl (i - 1) := by cases i <;> simp_all +@[simp] theorem getElem!_cons_nzero (x : α) (tl : List α) (i : Nat) [Inhabited α] (hne : Nat.not_eq i 0) : getElem! ((x :: tl) : List α) i = getElem! 
tl (i - 1) := by cases i <;> simp_all def slice (start end_ : Nat) (ls : List α) : List α := (ls.drop start).take (end_ - start) @@ -85,36 +49,51 @@ def resize (l : List α) (new_len : Nat) (x : α) : List α := l.take new_len ++ replicate (new_len - l.length) x else [] -@[simp] theorem update_nil i y : update ([] : List α) i y = [] := by simp [update] -@[simp] theorem update_zero_cons x tl y : update ((x :: tl) : List α) 0 y = y :: tl := by simp [update] -@[simp] theorem update_nzero_cons x tl i y (hne : Nat.not_eq i 0) : update ((x :: tl) : List α) i y = x :: update tl (i - 1) y := by simp [update]; intro; simp_all +@[simp] theorem set_cons_nzero x tl i (hne : Nat.not_eq i 0) y : set ((x :: tl) : List α) i y = x :: set tl (i - 1) y := by cases i <;> simp_all -@[simp] theorem drop_nzero_cons i x tl (hne : Nat.not_eq i 0) : drop i ((x :: tl) : List α) = drop (i - 1) tl := by cases i <;> simp_all [drop] +@[simp] theorem drop_cons_nzero i x tl (hne : Nat.not_eq i 0) : drop i ((x :: tl) : List α) = drop (i - 1) tl := by cases i <;> simp_all [drop] -@[simp] theorem take_nzero_cons i x tl (hne : Nat.not_eq i 0) : take i ((x :: tl) : List α) = x :: take (i - 1) tl := by cases i <;> simp_all +@[simp] theorem take_cons_nzero i x tl (hne : Nat.not_eq i 0) : take i ((x :: tl) : List α) = x :: take (i - 1) tl := by cases i <;> simp_all @[simp] theorem slice_nil i j : slice i j ([] : List α) = [] := by simp [slice] @[simp] theorem slice_zero ls : slice 0 0 (ls : List α) = [] := by cases ls <;> simp [slice] -@[simp] theorem replicate_nzero_cons i (x : List α) (hne : Nat.not_eq i 0) : replicate i x = x :: replicate (i - 1) x := by +@[simp] theorem replicate_cons_nzero i (x : List α) (hne : Nat.not_eq i 0) : replicate i x = x :: replicate (i - 1) x := by cases i <;> simp_all [replicate] +@[simp] theorem set_opt_nil i y : set_opt ([] : List α) i y = [] := by simp [set_opt] +@[simp] theorem set_opt_zero_cons x tl y : set_opt ((x :: tl) : List α) 0 y = y.getD x :: tl := by simp [set_opt] +@[simp] theorem set_opt_cons_nzero x tl i y (hne : Nat.not_eq i 0) : set_opt ((x :: tl) : List α) i y = x :: set_opt tl (i - 1) y := by simp [set_opt]; intro; simp_all + @[simp] -theorem slice_nzero_cons (i j : Nat) (x : α) (tl : List α) (hne : Nat.not_eq i 0) : +theorem slice_cons_nzero (i j : Nat) (x : α) (tl : List α) (hne : Nat.not_eq i 0) : slice i j ((x :: tl) : List α) = slice (i - 1) (j - 1) tl := by apply Nat.not_eq_imp_not_eq at hne induction i <;> cases j <;> simp_all [slice] -@[simp, scalar_tac replicate l x] +@[simp, scalar_tac_simp] theorem replicate_length {α : Type u} (l : Nat) (x : α) : (replicate l x).length = l := by induction l <;> simp_all -@[simp, scalar_tac ls.update i x] -theorem length_update (ls : List α) (i : Nat) (x : α) : (ls.update i x).length = ls.length := by +@[simp] +theorem set_getElem! {α} [Inhabited α] (l : List α) (i : Nat) : + l.set i l[i]! 
= l := by + revert i; induction l <;> simp_all + rename_i hd tail hi + intro i + cases i <;> simp_all + +attribute [scalar_tac_simp] length_set + +@[simp, scalar_tac_simp] +theorem length_set_opt {α} (l : List α) (i : Nat) (x : Option α): + (l.set_opt i x).length = l.length := by revert i - induction ls <;> simp_all [length, update] - intro; split <;> simp [*] + induction l <;> simp_all + rename_i hd tl hi + intro i + cases i <;> simp_all theorem left_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.length = l1'.length) : l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by @@ -131,35 +110,24 @@ theorem right_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l2.length = l have : (l1 ++ l2).length = (l1' ++ l2').length := by simp [*] simp only [length_append] at this apply this - . simp [heq] at this + . simp only [heq, add_left_inj] at this tauto . tauto @[simp] -theorem index_append_beg [Inhabited α] (i : Nat) (l0 l1 : List α) (_ : i < l0.length) : - (l0 ++ l1).index i = l0.index i := by - match l0 with - | [] => simp_all - | hd :: tl => - if hi : i = 0 then simp_all - else - have := index_append_beg (i - 1) tl l1 (by simp_all; int_tac) - simp_all +theorem getElem!_append_left [Inhabited α] (l0 l1 : List α) (i : Nat) (h : i < l0.length) : + (l0 ++ l1)[i]! = l0[i]! := by + have := @getElem?_append_left _ l0 l1 i h + simp_all @[simp] -theorem index_append_end [Inhabited α] (i : Nat) (l0 l1 : List α) - (_ : l0.length ≤ i) : - (l0 ++ l1).index i = l1.index (i - l0.length) := - match l0 with - | [] => by simp_all - | hd :: tl => - have : ¬ i = 0 := by simp_all; int_tac - have := index_append_end (i - 1) tl l1 (by simp_all; int_tac) - -- TODO: canonize arith expressions - have : i - 1 - length tl = i - (1 + length tl) := by int_tac - by simp_all; ring_nf +theorem getElem!_append_right [Inhabited α] (l0 l1 : List α) (i : Nat) + (h : l0.length ≤ i) : + (l0 ++ l1)[i]! = l1[i - l0.length]! 
:= by + have := @getElem?_append_right _ l0 l1 i h + simp_all -@[scalar_tac ls.drop i] +@[scalar_tac_simp] theorem drop_length_is_le (i : Nat) (ls : List α) : (ls.drop i).length ≤ ls.length := match ls with | [] => by simp @@ -167,187 +135,199 @@ theorem drop_length_is_le (i : Nat) (ls : List α) : (ls.drop i).length ≤ ls.l if h: i = 0 then by simp [*] else have := drop_length_is_le (i - 1) tl - by simp [*]; omega + by simp only [Nat.not_eq, ne_eq, not_false_eq_true, neq_imp, not_lt_zero', false_or, true_or, + or_self, drop_cons_nzero, length_drop, length_cons, tsub_le_iff_right, h]; omega -@[simp, scalar_tac ls.drop i] -theorem length_drop_eq (i : Nat) (ls : List α) : - (ls.drop i).length = ls.length - i := by - induction ls <;> simp_all +attribute [scalar_tac_simp] length_drop -@[scalar_tac ls.take i] +@[scalar_tac_simp] theorem take_length_is_le (i : Nat) (ls : List α) : (ls.take i).length ≤ ls.length := by induction ls <;> simp_all -attribute [scalar_tac l.take i] length_take +attribute [scalar_tac_simp] length_take -@[simp, scalar_tac l.resize new_len x] +@[simp, scalar_tac_simp] theorem resize_length (l : List α) (new_len : Nat) (x : α) : (l.resize new_len x).length = new_len := by induction l <;> simp_all [resize] - int_tac + scalar_tac @[simp] theorem slice_zero_j (l : List α) : l.slice 0 j = l.take j := by simp [slice] theorem slice_length_le (i j : Nat) (ls : List α) : (ls.slice i j).length ≤ ls.length := by simp [slice] +@[scalar_tac_simp] theorem slice_length (i j : Nat) (ls : List α) : (ls.slice i j).length = min (ls.length - i) (j - i) := by - simp [slice]; int_tac - -@[simp] -theorem index_drop [Inhabited α] (i : Nat) (j : Nat) (ls : List α) : - (ls.drop i).index j = ls.index (i + j) := by - revert i - induction ls - . intro i; simp_all - . intro i; cases i <;> simp_all - -@[simp] -theorem index_take_same [Inhabited α] (i : Nat) (j : Nat) (ls : List α) - (_ : j < i) (_ : j < ls.length) : - (ls.take i).index j = ls.index j := by - revert i j - induction ls - . intro i j; simp_all - . intro i j h0 h1; cases i <;> simp_all - cases j <;> simp_all + simp [slice]; scalar_tac @[simp] -theorem index_slice [Inhabited α] (i j k : Nat) (ls : List α) +theorem getElem?_slice (i j k : Nat) (ls : List α) (_ : j ≤ ls.length) (_ : i + k < j) : - (ls.slice i j).index k = ls.index (i + k) := by + (ls.slice i j)[k]? = ls[i + k]? := by revert i j induction ls . intro i j; simp_all . intro i j h0 h1 simp_all [slice] - rw [index_take_same] <;> first | simp_all | int_tac - int_tac + have : k < j - i := by scalar_tac + simp [*] @[simp] -theorem index_take_append_beg [Inhabited α] (i j : Nat) (l0 l1 : List α) +theorem getElem!_slice [Inhabited α] (i j k : Nat) (ls : List α) + (_ : j ≤ ls.length) (_ : i + k < j) : + (ls.slice i j)[k]! = ls[i + k]! := by + have := getElem?_slice i j k ls + simp_all + +@[simp] +theorem getElem?_take_append_beg (i j : Nat) (l0 l1 : List α) (_ : j < i) (_ : i ≤ l0.length) : - ((l0 ++ l1).take i).index j = l0.index j := by + getElem? ((l0 ++ l1).take i) j = getElem? 
l0 j := by revert i j l1 induction l0 <;> simp_all intros i j l1 cases i <;> simp_all cases j <;> simp_all + rename_i tail hi n n1 + intros + have : n1 < tail.length := by scalar_tac + rw [hi n] <;> simp [*] @[simp] -theorem index_update_neq - {α : Type u} [Inhabited α] (l: List α) (i: Nat) (j: Nat) (x: α) : - Nat.not_eq i j → (l.update i x).index j = l.index j - := - λ _ => match l with - | [] => by simp at * - | hd :: tl => - if h: i = 0 then - have : j ≠ 0 := by scalar_tac - by simp [*] - else if h : j = 0 then - have : i ≠ 0 := by scalar_tac - by simp [*] - else - by - simp_all - apply index_update_neq; scalar_tac +theorem getElem!_take_append_beg [Inhabited α] (i j : Nat) (l0 l1 : List α) + (_ : j < i) (_ : i ≤ l0.length) : + getElem! ((l0 ++ l1).take i) j = getElem! l0 j := by + have := getElem?_take_append_beg i j l0 l1 + simp_all + +@[simp] +theorem getElem!_drop [Inhabited α] (i : Nat) (j : Nat) (ls : List α) : + getElem! (ls.drop i) j = getElem! ls (i + j) := by + have := @getElem?_drop _ ls i j + simp_all + +@[simp] +theorem getElem?_take_same (i : Nat) (j : Nat) (ls : List α) + (_ : j < i) (_ : j < ls.length) : + getElem? (ls.take i) j = getElem? ls j := by + simp [getElem?_take, *] + +@[simp] +theorem getElem!_take_same [Inhabited α] (i : Nat) (j : Nat) (ls : List α) + (_ : j < i) (_ : j < ls.length) : + getElem! (ls.take i) j = getElem! ls j := by + simp [getElem?_take_same, *] @[simp] -theorem index_update_eq - {α : Type u} [Inhabited α] (l: List α) (i: Nat) (x: α) : - i < l.length → (l.update i x).index i = x +theorem getElem?_set_neq + {α : Type u} (l: List α) (i: Nat) (j: Nat) (x: α) + (h : Nat.not_eq i j) : getElem? (l.set i x) j = getElem? l j := by - revert i - induction l <;> simp_all - intro i h - cases i <;> simp_all + simp [getElem?_set] + intro + simp_all @[simp] -theorem map_update_eq {α : Type u} {β : Type v} (ls : List α) (i : Nat) (x : α) (f : α → β) : - (ls.update i x).map f = (ls.map f).update i (f x) := - match ls with - | [] => by simp - | hd :: tl => - if h : i = 0 then by simp [*] - else - have hi := map_update_eq tl (i - 1) x f - by simp [*] +theorem getElem!_set_neq + {α : Type u} [Inhabited α] (l: List α) (i: Nat) (j: Nat) (x: α) + (h : Nat.not_eq i j) : getElem! (l.set i x) j = getElem! l j + := by + have := getElem?_set_neq l i j x h + simp_all + +@[simp] +theorem getElem!_set_self + {α : Type u} [Inhabited α] (l: List α) (i: Nat) (x: α) + (h : i < l.length) : getElem! (l.set i x) i = x + := by + simp [*] -- TODO: we need "composite" patterns for scalar_tac here -theorem length_index_le_length_flatten (ls : List (List α)) : - forall (i : Nat), (ls.index i).length ≤ ls.flatten.length := by +theorem length_getElem!_le_length_flatten (ls : List (List α)) : + forall (i : Nat), (getElem! ls i).length ≤ ls.flatten.length := by induction ls <;> intro i <;> simp_all [default] cases i <;> simp_all rename ∀ _, _ => ih rename Nat => i replace ih := ih i - int_tac + scalar_tac -theorem length_flatten_update_eq {α : Type u} (ls : List (List α)) (i : Nat) (x : List α) +theorem length_flatten_set_eq {α : Type u} (ls : List (List α)) (i : Nat) (x : List α) (h1 : i < ls.length) : - (ls.update i x).flatten.length + (ls.index i).length = ls.flatten.length + x.length := by + (ls.set i x).flatten.length + (ls[i]!).length = ls.flatten.length + x.length := by revert i induction ls <;> intro i <;> simp_all [default] cases i <;> simp_all - . int_tac + . scalar_tac . 
rename Nat => i rename ∀ _, _ => ih - replace ih := ih i - int_tac - -@[scalar_tac (ls.update i x).flatten] -theorem length_flatten_update_eq_disj {α : Type u} (ls : List (List α)) (i : Nat) (x : List α) : - i < 0 ∨ ls.length ≤ i ∨ - (ls.update i x).flatten.length + (ls.index i).length = ls.flatten.length + x.length := by - cases h: (i < 0 : Bool) <;> simp_all only [not_lt_zero', decide_false, Bool.false_eq_true, not_false_eq_true, neq_imp] + intro hi + replace ih := ih i hi + scalar_tac + +@[scalar_tac (ls.set i x).flatten] +theorem length_flatten_set_eq_disj {α : Type u} (ls : List (List α)) (i : Nat) (x : List α) : + ls.length ≤ i ∨ + (ls.set i x).flatten.length + (ls[i]!).length = ls.flatten.length + x.length := by cases h: (ls.length ≤ i : Bool) <;> simp_all only [decide_eq_false_iff_not, not_le, false_or, decide_eq_true_eq, true_or] - rw [length_flatten_update_eq] <;> simp [*] + rw [length_flatten_set_eq] <;> simp [*] -theorem length_flatten_update_as_int_eq {α : Type u} (ls : List (List α)) (i : Nat) (x : List α) +theorem length_flatten_set_as_int_eq {α : Type u} (ls : List (List α)) (i : Nat) (x : List α) (h1 : i < ls.length) : - ((ls.update i x).flatten.length : Nat) = ls.flatten.length + x.length - (ls.index i).length := by - int_tac + ((ls.set i x).flatten.length : Nat) = ls.flatten.length + x.length - (ls[i]!).length := by + scalar_tac @[simp] -theorem index_map_eq {α : Type u} {β : Type v} [Inhabited α] [Inhabited β] +theorem getElem!_map_eq {α : Type u} {β : Type v} [Inhabited α] [Inhabited β] (ls : List α) (i : Nat) (f : α → β) (h1 : i < ls.length) : -- We need the bound because otherwise we have to prove that: `(default : β) = f (default : α)` - (ls.map f).index i = f (ls.index i) := by - revert i; induction ls <;> simp_all - intro i h - cases i <;> simp_all + (ls.map f)[i]! = f (ls[i]!) := by + simp [*] -theorem replace_slice_index [Inhabited α] (start end_ : Nat) (l nl : List α) +theorem replace_slice_getElem? (start end_ : Nat) (l nl : List α) (_ : start < end_) (_ : end_ ≤ l.length) (_ : nl.length = end_ - start) : let l1 := l.replace_slice start end_ nl - (∀ i, i < start → l1.index i = l.index i) ∧ - (∀ i, start ≤ i → i < end_ → l1.index i = nl.index (i - start)) ∧ - (∀ i, end_ ≤ i → i < l.length → l1.index i = l.index i) + (∀ i, i < start → getElem? l1 i = getElem? l i) ∧ + (∀ i, start ≤ i → i < end_ → getElem? l1 i = getElem? nl (i - start)) ∧ + (∀ i, end_ ≤ i → i < l.length → getElem? l1 i = getElem? l i) := by -- We need those assumptions everywhere - have : start ≤ l.length := by int_tac + have : start ≤ l.length := by scalar_tac simp only [replace_slice] split_conjs . intro i _ -- Introducing exactly the assumptions we need to make the rewriting work - have : i < l.length := by int_tac - simp_all + have : i < l.length := by scalar_tac + simp_all only [append_assoc, length_take, inf_of_le_left, getElem?_append_left] + simp [*] . intro i _ _ have : (List.take start l).length ≤ i := by simp_all have : i < (List.take start l).length + (nl ++ List.drop end_ l).length := by - simp_all; int_tac - simp_all - have : i - start < nl.length := by int_tac - simp_all + simp_all; scalar_tac + simp_all only [length_take, inf_of_le_left, length_append, length_drop, append_assoc, + getElem?_append_right] + have : i - start < nl.length := by scalar_tac + simp_all only [getElem?_append_left] . 
intro i _ _ have : 0 ≤ end_ := by scalar_tac - have : end_ ≤ l.length := by int_tac - have : (List.take start l).length ≤ i := by int_tac - have := index_append_end i (take start l ++ nl) (drop end_ l) (by simp; int_tac) - simp_all + have : end_ ≤ l.length := by scalar_tac + have : (List.take start l).length ≤ i := by scalar_tac + have := @getElem?_append_right _ (take start l ++ nl) (drop end_ l) i (by simp; scalar_tac) + simp_all only [zero_le, length_take, inf_of_le_left, append_assoc, getElem?_append_right, + tsub_le_iff_right, Nat.sub_add_cancel, getElem?_drop, length_append] congr - int_tac + scalar_tac + +theorem replace_slice_getElem! [Inhabited α] (start end_ : Nat) (l nl : List α) + (_ : start < end_) (_ : end_ ≤ l.length) (_ : nl.length = end_ - start) : + let l1 := l.replace_slice start end_ nl + (∀ i, i < start → getElem! l1 i = getElem! l i) ∧ + (∀ i, start ≤ i → i < end_ → getElem! l1 i = getElem! nl (i - start)) ∧ + (∀ i, end_ ≤ i → i < l.length → getElem! l1 i = getElem! l i) + := by + have := replace_slice_getElem? start end_ l nl (by assumption) (by assumption) (by assumption) + simp_all @[simp] theorem allP_nil {α : Type u} (p: α → Prop) : allP [] p := @@ -375,13 +355,4 @@ theorem lookup_not_none_imp_length_pos [BEq α] (l : List (α × β)) (key : α) end -@[simp] -theorem list_update_index_eq α [Inhabited α] (x : List α) (i : ℕ) : - x.update i (x.index i) = x := by - revert i - induction x - . simp - . intro i - dcases hi: 0 < i <;> simp_all - end List diff --git a/backends/lean/Aeneas/Natify.lean b/backends/lean/Aeneas/Natify.lean new file mode 100644 index 00000000..3cf531fa --- /dev/null +++ b/backends/lean/Aeneas/Natify.lean @@ -0,0 +1 @@ +import Aeneas.Natify.Natify diff --git a/backends/lean/Aeneas/Natify/Init.lean b/backends/lean/Aeneas/Natify/Init.lean new file mode 100644 index 00000000..ca506631 --- /dev/null +++ b/backends/lean/Aeneas/Natify/Init.lean @@ -0,0 +1,11 @@ +import Aeneas.Extensions +open Lean Meta + +namespace Aeneas.Natify + +/-- The `natify_simps` simp attribute. -/ +initialize natifySimpExt : SimpExtension ← + registerSimpAttr `natify_simps "\ + The `natify_simps` attribute registers simp lemmas to be used by `natify`." + +end Aeneas.Natify diff --git a/backends/lean/Aeneas/Natify/Natify.lean b/backends/lean/Aeneas/Natify/Natify.lean new file mode 100644 index 00000000..2e6339d7 --- /dev/null +++ b/backends/lean/Aeneas/Natify/Natify.lean @@ -0,0 +1,66 @@ +import Mathlib.Tactic.Basic +import Mathlib.Tactic.Attr.Register +import Mathlib.Data.Int.Cast.Basic +import Mathlib.Order.Basic +import Aeneas.Natify.Init +import Aeneas.Arith.Lemmas +import Aeneas.Std.Scalar + +/-! +# `natify` tactic + +The `natify` tactic is used to shift propositions about, e.g., `ZMod` or `BitVec`, to `Nat`. +This tactic is adapted from `zify`. +-/ + +namespace Aeneas.Natify + +open Lean +open Lean.Meta +open Lean.Parser.Tactic +open Lean.Elab.Tactic +open Arith Std + +syntax (name := natify) "natify" (simpArgs)? (location)? : tactic + +macro_rules +| `(tactic| natify $[[$simpArgs,*]]? $[at $location]?) => + let args := simpArgs.map (·.getElems) |>.getD #[] + `(tactic| + simp -decide (maxDischargeDepth := 1) only [natify_simps, push_cast, $args,*] $[at $location]?) + +/-- The `Simp.Context` generated by `natify`. 
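It mirrors the `simp` call performed by the `natify` macro above: the `natify_simps` and `push_cast` simp sets, extended with any user-provided lemmas.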
-/ +def mkNatifyContext (simpArgs : Option (Syntax.TSepArray `Lean.Parser.Tactic.simpStar ",")) : + TacticM MkSimpContextResult := do + let args := simpArgs.map (·.getElems) |>.getD #[] + mkSimpContext + (← `(tactic| simp -decide (maxDischargeDepth := 1) only [natify_simps, push_cast, $args,*])) false + +/-- A variant of `applySimpResultToProp` that cannot close the goal, but does not need a meta +variable and returns a tuple of a proof and the corresponding simplified proposition. -/ +def applySimpResultToProp' (proof : Expr) (prop : Expr) (r : Simp.Result) : MetaM (Expr × Expr) := + do + match r.proof? with + | some eqProof => return (← mkExpectedTypeHint (← mkEqMP eqProof proof) r.expr, r.expr) + | none => + if r.expr != prop then + return (← mkExpectedTypeHint proof r.expr, r.expr) + else + return (proof, r.expr) + +/-- Translate a proof and the proposition into a natified form. -/ +def natifyProof (simpArgs : Option (Syntax.TSepArray `Lean.Parser.Tactic.simpStar ",")) + (proof : Expr) (prop : Expr) : TacticM (Expr × Expr) := do + let ctx_result ← mkNatifyContext simpArgs + let (r, _) ← simp prop ctx_result.ctx + applySimpResultToProp' proof prop r + +attribute [natify_simps] BitVec.toNat_eq BitVec.lt_def BitVec.le_def + BitVec.toNat_umod BitVec.toNat_add BitVec.toNat_sub BitVec.toNat_ofNat + BitVec.toNat_and BitVec.toNat_or BitVec.toNat_xor +attribute [natify_simps] ZMod.eq_iff_mod ZMod.val_add ZMod.val_sub ZMod.val_mul +attribute [natify_simps] U8.bv_toNat_eq U16.bv_toNat_eq U32.bv_toNat_eq U64.bv_toNat_eq U128.bv_toNat_eq Usize.bv_toNat_eq + +example (x y : BitVec 32) (h : x.toNat = y.toNat) : x = y := by natify [h] + +end Aeneas.Natify diff --git a/backends/lean/Aeneas/Progress/Core.lean b/backends/lean/Aeneas/Progress/Core.lean index 9af29ee2..207eed45 100644 --- a/backends/lean/Aeneas/Progress/Core.lean +++ b/backends/lean/Aeneas/Progress/Core.lean @@ -2,6 +2,7 @@ import Lean import Aeneas.Utils import Aeneas.Std.Core import Aeneas.Extensions +import Aeneas.Progress.Trace namespace Aeneas @@ -10,10 +11,17 @@ namespace Progress open Lean Elab Term Meta open Utils Extensions --- We can't define and use trace classes in the same file -initialize registerTraceClass `Progress +/-! +# Attribute: `progress_simp` +-/ -/- # Progress tactic -/ +/-- The `progress_simp` simp attribute. -/ +initialize progressSimpExt : SimpExtension ← + registerSimpAttr `progress_simp "\ + The `progress_simp` attribute registers simp lemmas to be used by `progress` + when solving preconditions by means of the simplifier." + +/-! # Attribute: `progress` -/ structure PSpecDesc where -- The universally quantified variables @@ -55,8 +63,6 @@ section Methods - function arguments - return - postconditions - - TODO: generalize for when we do inductive proofs -/ partial def withPSpec [Inhabited (m a)] [Nonempty (m a)] @@ -104,7 +110,7 @@ section Methods -- If we are registering a theorem, then the function must be a constant if ¬ f.isConst then if isGoal then pure [] - else throwError "Not a constant: {f}" + else throwError "{f} should be a constant" else pure f.constLevels! -- *Sanity check* (activated if we are analyzing a theorem to register it in a DB) -- Check if some existentially quantified variables @@ -132,64 +138,611 @@ section Methods } k thDesc + /- Auxiliary helper. + + Given type `α₀ × ... × αₙ`, introduce fresh variables + `x₀ : α₀, ..., xₙ : αₙ` and call the continuation with those. 
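For example (illustrative only), with basename `x` and type `Int × Bool`, the continuation `k` receives two fresh variables of types `Int` and `Bool`, with names derived from `x`.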
+ -/ + def withFreshTupleFieldFVars [Inhabited a] (basename : Name) (ty : Expr) (k : Array Expr → m a) : m a := do + let tys := destProdsType ty + let tys := List.map (fun (i, ty) => (Name.num basename i, fun _ => pure ty)) (List.enum tys) + withLocalDeclsD ⟨ tys ⟩ k end Methods def getPSpecFunArgsExpr (isGoal : Bool) (th : Expr) : MetaM Expr := withPSpec isGoal th (fun d => do pure d.fArgsExpr) --- pspec attribute +structure Rules where + rules : DiscrTree Name + /- We can't remove keys from a discrimination tree, so to support + local rules we keep a set of deactivated rules (rules which have + come out of scope) on the side -/ + deactivated : Std.HashSet Name +deriving Inhabited + +def Rules.empty : Rules := ⟨ DiscrTree.empty, Std.HashSet.empty ⟩ + +def Extension := SimpleScopedEnvExtension (DiscrTreeKey × Name) Rules +deriving Inhabited + +def Rules.insert (r : Rules) (kv : Array DiscrTree.Key × Name) : Rules := + { r with rules := r.rules.insertCore kv.fst kv.snd } + +def Rules.erase (r : Rules) (k : Name) : Rules := + { r with deactivated := r.deactivated.insert k } + +def mkExtension (name : Name := by exact decl_name%) : + IO Extension := + registerSimpleScopedEnvExtension { + name := name, + initial := Rules.empty, + addEntry := Rules.insert, + } + +/-- The progress attribute -/ structure PSpecAttr where attr : AttributeImpl - ext : DiscrTreeExtension Name + ext : Extension deriving Inhabited -/- The persistent map from expressions to pspec theorems. -/ +private def saveProgressSpecFromThm (ext : Extension) (attrKind : AttributeKind) (thName : Name) : + AttrM Unit := do + -- Lookup the theorem + let env ← getEnv + -- Ignore some auxiliary definitions (see the comments for attrIgnoreMutRec) + attrIgnoreAuxDef thName (pure ()) do + trace[Progress] "Registering `progress` theorem for {thName}" + let thDecl := env.constants.find! thName + let fKey ← MetaM.run' (do + trace[Progress] "Theorem: {thDecl.type}" + -- Normalize to eliminate the let-bindings + let ty ← normalizeLetBindings thDecl.type + trace[Progress] "Theorem after normalization (to eliminate the let bindings): {ty}" + let fExpr ← getPSpecFunArgsExpr false ty + trace[Progress] "Registering spec theorem for expr: {fExpr}" + -- Convert the function expression to a discrimination tree key + DiscrTree.mkPath fExpr) + -- Save the entry + ScopedEnvExtension.add ext (fKey, thName) attrKind + trace[Progress] "Saved the entry" + pure () + +/- Initiliaze the `progress` attribute. -/ initialize pspecAttr : PSpecAttr ← do - let ext ← mkDiscrTreeExtension `pspecMap + let ext ← mkExtension `pspecMap let attrImpl : AttributeImpl := { - name := `pspec - descr := "Marks theorems to use with the `progress` tactic" + name := `progress + descr := "Adds theorems to the `progress` database" add := fun thName stx attrKind => do Attribute.Builtin.ensureNoArgs stx - -- TODO: use the attribute kind - unless attrKind == AttributeKind.global do - throwError "invalid attribute 'pspec', must be global" - -- Lookup the theorem - let env ← getEnv - -- Ignore some auxiliary definitions (see the comments for attrIgnoreMutRec) - attrIgnoreAuxDef thName (pure ()) do - trace[Progress] "Registering spec theorem for {thName}" - let thDecl := env.constants.find! 
thName - let fKey ← MetaM.run' (do - trace[Progress] "Theorem: {thDecl.type}" - -- Normalize to eliminate the let-bindings - let ty ← normalizeLetBindings thDecl.type - trace[Progress] "Theorem after normalization (to eliminate the let bindings): {ty}" - let fExpr ← getPSpecFunArgsExpr false ty - trace[Progress] "Registering spec theorem for expr: {fExpr}" - -- Convert the function expression to a discrimination tree key - DiscrTree.mkPath fExpr) - let env := ext.addEntry env (fKey, thName) - setEnv env - trace[Progress] "Saved the environment" - pure () + saveProgressSpecFromThm ext attrKind thName + erase := fun thName => do + let s := ext.getState (← getEnv) + let s := s.erase thName + modifyEnv fun env => ext.modifyState env fun _ => s } registerBuiltinAttribute attrImpl pure { attr := attrImpl, ext := ext } def PSpecAttr.find? (s : PSpecAttr) (e : Expr) : MetaM (Array Name) := do - (s.ext.getState (← getEnv)).getMatch e + let state := s.ext.getState (← getEnv) + let rules ← state.rules.getMatch e + pure (rules.filter (fun th => th ∉ state.deactivated)) -def PSpecAttr.getState (s : PSpecAttr) : MetaM (DiscrTree Name) := do +def PSpecAttr.getState (s : PSpecAttr) : MetaM Rules := do pure (s.ext.getState (← getEnv)) def showStoredPSpec : MetaM Unit := do let st ← pspecAttr.getState -- TODO: how can we iterate over (at least) the values stored in the tree? --let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!"" - let s := f!"{st}" + let s := f!"{st.rules}, {st.deactivated.toArray}" IO.println s +/-! # Attribute: `progress_pure` -/ + +namespace Test + /-! + Making some tests here as models to guide the automation generation of proof terms when lifting theorems in `progress_pure` + -/ + open Std Result + def pos_pair : Int × Int := (0, 1) + + theorem pos_pair_is_pos : + let (x, y) := pos_pair + x ≥ 0 ∧ y ≥ 0 := by simp [pos_pair] + + theorem lifted_is_pos : + ∃ x y, toResult pos_pair = ok (x, y) ∧ + x ≥ 0 ∧ y ≥ 0 := + (match pos_pair with + | (x, y) => + fun (h : match (x, y) with | (x, y) => x ≥ 0 ∧ y ≥ 0) => + Exists.intro x (Exists.intro y (And.intro (Eq.refl (ok (x, y))) h)) + : ∀ (_ : match pos_pair with | (x, y) => x ≥ 0 ∧ y ≥ 0), + ∃ x y, toResult pos_pair = ok (x, y) ∧ + x ≥ 0 ∧ y ≥ 0) pos_pair_is_pos + + /- Same as `lifted_is_pos` but making the implicit parameters of the `Exists.intro` explicit: + this is the important part. 
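The `liftThm` automation below builds exactly these explicit `Exists.intro` applications, so this proof term serves as the model for the generated ones.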
-/ + theorem lifted_is_pos' : + ∃ x y, toResult pos_pair = ok (x, y) ∧ + x ≥ 0 ∧ y ≥ 0 := + (match pos_pair with + | (x, y) => + fun (h : match (x, y) with | (x, y) => x ≥ 0 ∧ y ≥ 0) => + @Exists.intro Int (fun x_1 => ∃ y_1, ok (x, y) = ok (x_1, y_1) ∧ x_1 ≥ 0 ∧ y_1 ≥ 0) + x (@Exists.intro Int (fun y_1 => ok (x, y) = ok (x, y_1) ∧ x ≥ 0 ∧ y_1 ≥ 0) + y (@And.intro (ok (x, y) = ok (x, y)) _ (Eq.refl (ok (x, y))) h)) + : ∀ (_ : match pos_pair with | (x, y) => x ≥ 0 ∧ y ≥ 0), + ∃ x y, toResult pos_pair = ok (x, y) ∧ + x ≥ 0 ∧ y ≥ 0) pos_pair_is_pos + + def pos_triple : Int × Int × Int := (0, 1, 2) + + theorem pos_triple_is_pos : + let (x, y, z) := pos_triple + x ≥ 0 ∧ y ≥ 0 ∧ z ≥ 0 := by simp [pos_triple] + + structure U8 where + val : Nat + + def overflowing_add (x y : U8) : U8 × Bool := (⟨ x.val + y.val ⟩, x.val + y.val > 255) + + theorem overflowing_add_eq (x y : U8) : + let z := overflowing_add x y + if x.val + y.val > 255 then z.snd = true + else z.snd = false + := + by simp [overflowing_add] + +end Test + +def reduceProdProjs (e : Expr) : MetaM Expr := do + let pre (e : Expr) : MetaM TransformStep := do + trace[Utils] "Attempting to reduce: {e}" + match ← reduceProj? e with + | none => + e.withApp fun fn args => + if fn.isConst ∧ (fn.constName! = ``Prod.fst ∨ fn.constName! = ``Prod.snd) ∧ args.size = 3 then + let pair := args[2]! + pair.withApp fun fn' args => + if fn'.isConst ∧ fn'.constName! = ``Prod.mk ∧ args.size = 4 then + if fn.constName! = ``Prod.fst then pure (.continue args[2]!) + else pure (.continue args[3]!) + else pure (.continue e) + else pure (.continue e) + | some e => + trace[Utils] "reduced: {e}" + pure (.continue e) + transform e (pre := pre) + +/-- Given a theorem of type `P x` and a pattern of the shape `∃ y₀ ... yₙ, x = (y₀, ..., yₙ)`, + introduce a lifted version of the theorem of the shape: + ``` + ∃ y₀ ... yₙ, toResult x = ok (y₀, ..., yₙ) ∧ P (y₀, ..., yₙ) + ``` + + The output of the function is the name of the new theorem. + + Note that if the pattern is simply `x` (not an existentially quantified equality), this function + decomposes the type of `x` for as long as it finds a tuple, and introduces one variable per field + in the tuple. + + For instance, given pattern `some_pair : Int × Int`, the following theorem: + ``` + P some_pair.fst ∧ Q some_pair.snd + ``` + gets lifted to: + ``` + ∃ x y, toResult some_pair = ok (x, y) ∧ P x ∧ Q y + ``` +-/ +def liftThm (pat : Syntax) (n : Name) (suffix : String := "progress_spec") : MetaM Name := do + trace[Progress] "Name: {n}" + let env ← getEnv + let decl := env.constants.find! n + /- Strip the quantifiers before elaborating the pattern -/ + forallTelescope decl.type.consumeMData fun fvars thm0 => do + let (pat, _) ← Elab.Term.elabTerm pat none |>.run + trace[Progress] "Elaborated pattern: {pat}" + /- -/ + existsTelescope pat.consumeMData fun _ eqMatch => do + existsTelescope pat.consumeMData fun _ eqExists => do + /- Destruct the equality. Note that there may not be a tuple, in which case + we see the type as a tuple and introduce one variable per field of the tuple + (and a single variable if it is actually not a tuple). 
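For example, the tests below use the bare pattern `pos_triple` (of type `Int × Int × Int`), which yields one variable per field, while the pattern `∃ z, pos_triple = z` keeps the single variable `z`.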
-/ + let tryDestEq basename (eq : Expr) (k : Expr → Expr → MetaM Name) : MetaM Name := do + match ← destEqOpt eq with + | some (x, y) => k x y + | none => + withFreshTupleFieldFVars (.str .anonymous basename) (← inferType pat) fun fvars => do + k pat (← mkProdsVal fvars.toList) + /- We need to introduce two sets of variables: + - one for variables which will be introduced by the outer match + - another for variables which will be bound by the ∃ quantifiers -/ + tryDestEq "x" eqMatch fun pat decompPatMatch => do + tryDestEq "y" eqExists fun _ decompPatExists => do + trace[Progress] "Decomposed equality: {pat}, {decompPatMatch}, {decompPatExists}" + /- The decomposed patterns should be tuple expressions: decompose them further into lists of variables -/ + let fvarsMatch : Array Expr := ⟨ destProdsVal decompPatMatch ⟩ + let fvarsExists : Array Expr := ⟨ destProdsVal decompPatExists ⟩ + trace[Progress] "Fvars: {fvarsMatch}, {fvarsExists}" + /- Small helper that we use to substitute the pattern in the original theorem -/ + let mkPureThmType (npat : Expr) : MetaM Expr := do + let thm ← mapVisit (fun _ e => do if e == pat then pure npat else pure e) thm0 + /- Reduce a bit the expression, but in a controlled manner, to make it cleaner -/ + let thm ← normalizeLetBindings thm + reduceProdProjs thm + /- Introduce the binder for the pure theorem - it will be bound outside of the ∃ but we need to use it + right now to generate an expression of type: + ``` + toResult ... = ok x ∧ + P x -- HERE + ``` + -/ + let pureThmType ← mkPureThmType decompPatMatch + let pureThmName ← mkFreshUserName (.str .anonymous "pureThm") + withLocalDeclD pureThmName pureThmType fun pureThm => do + /- Introduce the equality -/ + let okDecompPat ← mkAppM ``Std.Result.ok #[decompPatMatch] + let eqExpr ← mkEqRefl okDecompPat + let thm ← mkAppM ``And.intro #[eqExpr, pureThm] + trace[Progress] "Theorem after introducing the lifted equality: {thm}\n :\n{← inferType thm}" + /- Auxiliary helper which computes the type of the (intermediate) theorems when adding the existentials. 
+ + Given `toResultArg`, `xl0` and `xl1`, generates: + ``` + ∃ $xl1, + toResult $toResultArg = ($xl0 ++ $xl1) ∧ + P ($xl0 ++ $xl1) + ``` + -/ + let mkThmType (toResultArg : Expr) (xl0 : List Expr) (xl1 : List Expr) : MetaM Expr := do + trace[Progress] "mkThmType:\n- {toResultArg}\n- {xl0}\n- {xl1}" + let npatExists ← mkProdsVal (xl0 ++ xl1) + /- Update the theorem statement to replace the pattern with the decomposed pattern -/ + let thmType ← mkPureThmType npatExists + trace[Progress] "mkThmType: without lifted equality: {thmType}" + let toResultPat ← mkAppM ``Std.toResult #[toResultArg] + let okDecompPat ← mkAppM ``Std.Result.ok #[npatExists] + let eqType ← mkEq toResultPat okDecompPat + let thmType := mkAnd eqType thmType + trace[Progress] "mkThmType: after lifting equality: {thmType}" + /- Introduce the existentials, only for the suffix of the list of variables -/ + let thmType ← List.foldrM (fun fvar thmType => do + let p ← mkLambdaFVars #[fvar] thmType + mkAppM ``Exists #[p] + ) thmType xl1 + trace[Progress] "mkThmType: after adding the existentials: {thmType}" + pure thmType + /- Introduce the existentials -/ + let rec introExists (xl0 xl1 : List (Expr × Expr)) : MetaM Expr := do + match xl1 with + | [] => pure thm + | fvarPair :: xl1 => + let thm ← introExists (fvarPair :: xl0) xl1 + let (fvarMatch, fvarExists) := fvarPair + let α ← inferType fvarMatch + let thmType ← mkThmType decompPatMatch (fvarExists :: (List.unzip xl0).fst).reverse (List.unzip xl1).snd + let p ← mkLambdaFVars #[fvarExists] thmType + let x := fvarMatch + let h := thm + trace[Progress] "introExists: about to insert existential:\n- α: {α}\n- p: {p}\n- x: {x}\n- h: {h}" + let thm ← mkAppOptM ``Exists.intro #[α, p, x, h] + trace[Progress] "introExists: resulting theorem:\n{thm}\n :\n{← inferType thm}" + pure thm + let thm ← introExists [] (List.zip fvarsMatch.toList fvarsExists.toList) + trace[Progress] "Theorem after introducing the existentials: {thm} :\n{← inferType thm}" + /- Introduce the λ which binds the pure theorem -/ + let thm ← mkLambdaFVars #[pureThm] thm + trace[Progress] "Theorem after introducing the lambda: {thm} :\n{← inferType thm}" + /- Introduce the matches -/ + let thm ← mkProdsMatch fvarsMatch.toList thm + trace[Progress] "Theorem after introducing the matches: {thm} :\n{← inferType thm}" + /- Apply to the scrutinee (which is the pattern provided by the user): `mkProdsMatch` generates + a lambda expression, where the bound value is the scrutinee we should match over. 
-/ + let thm := mkApp thm pat + trace[Progress] "Theorem after introducing the scrutinee: {thm} :\n{← inferType thm}" + /- Apply to the pure theorem (the expression inside the match is a function which expects to receive this theorem) -/ + let pureThm := mkAppN (.const decl.name (List.map Level.param decl.levelParams)) fvars + let thm := mkAppN thm #[pureThm] + trace[Progress] "Theorem after introducing the matches and the app: {thm} :\n{← inferType thm}" + let thm ← mkLambdaFVars fvars thm + /- Prepare the theorem type -/ + let thmType ← do + let thmType ← mkThmType pat [] fvarsExists.toList + let thmType ← mkForallFVars fvars thmType + pure thmType + trace[Progress] "Final theorem: {thm}\n :\n{thmType}" + /- Save the auxiliary theorem -/ + let name := Name.str decl.name suffix + let auxDecl : TheoremVal := { + name + levelParams := decl.levelParams + type := thmType + value := thm + } + addDecl (.thmDecl auxDecl) + /- -/ + pure name + +local elab "#progress_pure_lift_thm" id:ident pat:term : command => do + Lean.Elab.Command.runTermElabM (fun _ => do + let some cs ← Term.resolveId? id | throwError m!"Unknown id: {id}" + let name := cs.constName! + let _ ← liftThm pat name) + +namespace Test + #progress_pure_lift_thm pos_pair_is_pos (∃ x y, pos_pair = (x, y)) + + #progress_pure_lift_thm pos_triple_is_pos pos_triple + + def pos_triple_is_pos' := pos_triple_is_pos + #progress_pure_lift_thm pos_triple_is_pos' (∃ z, pos_triple = z) + + #progress_pure_lift_thm overflowing_add_eq (overflowing_add x y) +end Test + + +/- The ident is the name of the saturation set, the term is the pattern. -/ +syntax (name := progress_pure) "progress_pure" term : attr + +def elabProgressPureAttribute (stx : Syntax) : AttrM (TSyntax `term) := + withRef stx do + match stx with + | `(attr| progress_pure $pat) => do + pure pat + | _ => throwUnsupportedSyntax + +/-- The progress pure attribute -/ +structure ProgressPureSpecAttr where + attr : AttributeImpl + deriving Inhabited + +/- Initialize the `progress_pure` attribute, which lifts lemmas about pure functions to + `progress` lemmas. + + For instance, if we annotate the following theorem with `progress_pure`: + ``` + @[progress_pure wrapping x y] + theorem U32.wrapping_add_eq (x y : U32) : + (wrapping_add x y).bv = x.bv + y.bv + ``` + `progress_pure` performs operations which are equivalent to introducing the following lemma: + ``` + @[progress] + theorem U32.wrapping_add_eq.progress_spec (x y : U32) : + ∃ z, ↑(wrapping_add x y) = ok z ∧ + z.bv = x.bv + y.bv + ``` + + Note that it is possible to control how existential variables are introduced in the generated lemma + by writing an equality in the pattern we want to abstract over. + For instance if we write: + ``` + @[progress_pure ∃ x y, pos_pair = (x, y)] + theorem pos_pair_is_pos : pos_pair.fst ≥ 0 ∧ pos_pair.snd ≥ 0 + ``` + we get: + ``` + @[progress] + theorem pos_pair_is_pos.progress_spec : + ∃ x y, ↑pos_pair = ok (x, y) ∧ + x ≥ 0 ∧ y ≥ 0 + ``` + + Similarly if we write: + ``` + @[progress_pure ∃ x, pos_pair = x] + theorem pos_pair_is_pos : pos_pair.fst ≥ 0 ∧ pos_pair.snd ≥ 0 + ``` + we get: + ``` + @[progress] + theorem pos_pair_is_pos.progress_spec : + ∃ x, ↑pos_pair = ok x ∧ + x.fst ≥ 0 ∧ y.fst ≥ 0 + ``` + + If we don't put an equality in the pattern, `progress_pure` will introduce one variable + per field in the type of the pattern, if it is a tuple. 
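+ + For instance (a sketch based on the tests below, not the exact output), annotating `pos_triple_is_pos`, which is about `pos_triple : Int × Int × Int`, with `@[progress_pure pos_triple]` generates a lemma roughly of the shape: + ``` + ∃ x y z, ↑pos_triple = ok (x, y, z) ∧ + x ≥ 0 ∧ y ≥ 0 ∧ z ≥ 0 + ```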
+ -/ +initialize pspecPureAttribute : ProgressPureSpecAttr ← do + let attrImpl : AttributeImpl := { + name := `progress_pure + descr := "Adds lifted version of pure theorems to the `progress_pure` database" + add := fun thName stx attrKind => do + -- Lookup the theorem + let env ← getEnv + -- Ignore some auxiliary definitions (see the comments for attrIgnoreMutRec) + attrIgnoreAuxDef thName (pure ()) do + -- Elaborate the pattern + let pat ← elabProgressPureAttribute stx + -- Introduce the lifted theorem + let liftedThmName ← MetaM.run' (liftThm pat thName) + -- Save the lifted theorem to the `progress` database + saveProgressSpecFromThm pspecAttr.ext attrKind liftedThmName + } + registerBuiltinAttribute attrImpl + pure { attr := attrImpl } + +/-! # Attribute: `progress_pure_def` -/ + +/- The ident is the name of the saturation set, the term is the pattern. -/ +syntax (name := progress_pure_def) "progress_pure_def" (term)? : attr + +def elabProgressPureDefAttribute (stx : Syntax) : AttrM (Option Syntax) := do + if stx[1].isNone then pure none + else pure (stx[1]) + +/-- The progress pure def attribute -/ +structure ProgressPureDefSpecAttr where + attr : AttributeImpl + deriving Inhabited + +def mkProgressPureDefThm (pat : Option Syntax) (n : Name) (suffix : String := "progress_spec") : MetaM Name := do + trace[Progress] "Name: {n}" + let env ← getEnv + let decl := env.constants.find! n + /- Strip the quantifiers before elaborating the pattern -/ + forallTelescope decl.type.consumeMData fun fvars _ => do + let declTerm := mkAppN (.const decl.name (List.map Level.param decl.levelParams)) fvars + /- Elaborate the pattern, if there is -/ + let elabDecomposePat (basename : String) (k : Expr → Array Expr → MetaM Name) : MetaM Name := do + match pat with + | none => + withFreshTupleFieldFVars (.str .anonymous basename) (← inferType declTerm) fun fvars => do + let pat ← mkProdsVal fvars.toList + k pat fvars + | some pat => + /- Elaborate the pattern -/ + let (pat, _) ← Elab.Term.elabTerm pat none |>.run + trace[Progress] "Elaborated pattern: {pat}" + /- Introduce the existentials -/ + existsTelescope pat.consumeMData fun fvarsExists eq => do + /- Destruct the equality -/ + let (lhs, rhs) ← destEq eq + /- Sanity check: the lhs should be equal to the definition -/ + assert! (lhs == declTerm) + /- -/ + k rhs fvarsExists + /- We need to introduce two sets of variables: + - one for the variables bound at the external case + - one for the variables bound in the existential quantifiers + -/ + elabDecomposePat "x" fun decompPatMatch fvarsMatch => do + elabDecomposePat "y" fun _ fvarsExists => do + /- Introduce the lifted and pure equalities -/ + let liftedEq ← do + let okDecompPat ← mkAppM ``Std.Result.ok #[decompPatMatch] + mkEqRefl okDecompPat + let pureEq ← mkEqRefl decompPatMatch + let thm ← mkAppM ``And.intro #[liftedEq, pureEq] + trace[Progress] "Theorem after introducing the lifted and pure equalities: {thm}\n :\n{← inferType thm}" + /- Auxiliary helper which computes the type of the (intermediate) theorems when adding the existentials. 
+ + Given `toResultArg`, `xl0` and `xl1`, generates: + ``` + ∃ $xl1, + toResult $toResultArg = ($xl0 ++ $xl1) ∧ + ($xl0 ++ $xl1) = $toResultArg + ``` + -/ + let mkThmType (toResultArg : Expr) (xl0 : List Expr) (xl1 : List Expr) : MetaM Expr := do + trace[Progress] "mkThmType:\n- {toResultArg}\n- {xl0}\n- {xl1}" + let npatExists ← mkProdsVal (xl0 ++ xl1) + let liftedEqTy ← + mkAppM ``Eq #[← mkAppM ``Std.toResult #[toResultArg], ← mkAppM ``Std.Result.ok #[npatExists]] + let pureEqTy ← mkAppM ``Eq #[npatExists, toResultArg] + let thmType := mkAnd liftedEqTy pureEqTy + trace[Progress] "mkThmType: conjunction: {thmType}" + /- Introduce the existentials, only for the suffix of the list of variables -/ + let thmType ← List.foldrM (fun fvar thmType => do + let p ← mkLambdaFVars #[fvar] thmType + mkAppM ``Exists #[p] + ) thmType xl1 + trace[Progress] "mkThmType: after adding the existentials: {thmType}" + pure thmType + /- Introduce the existentials -/ + let rec introExists (xl0 xl1 : List (Expr × Expr)) : MetaM Expr := do + match xl1 with + | [] => pure thm + | fvarPair :: xl1 => + let thm ← introExists (fvarPair :: xl0) xl1 + let (fvarMatch, fvarExists) := fvarPair + let α ← inferType fvarMatch + let thmType ← mkThmType decompPatMatch (fvarExists :: (List.unzip xl0).fst).reverse (List.unzip xl1).snd + let p ← mkLambdaFVars #[fvarExists] thmType + let x := fvarMatch + let h := thm + trace[Progress] "introExists: about to insert existential:\n- α: {α}\n- p: {p}\n- x: {x}\n- h: {h}" + let thm ← mkAppOptM ``Exists.intro #[α, p, x, h] + trace[Progress] "introExists: resulting theorem:\n{thm}\n :\n{← inferType thm}" + pure thm + let thm ← introExists [] (List.zip fvarsMatch.toList fvarsExists.toList) + trace[Progress] "Theorem after introducing the existentials: {thm} :\n{← inferType thm}" + /- Introduce the matches -/ + let thm ← mkProdsMatch fvarsMatch.toList thm + trace[Progress] "Theorem after introducing the matches: {thm} :\n{← inferType thm}" + /- Apply to the scrutinee (which is the pattern provided by the user): `mkProdsMatch` generates + a lambda expression, where the bound value is the scrutinee we should match over. -/ + let thm := mkApp thm declTerm + trace[Progress] "Theorem after introducing the scrutinee: {thm} :\n{← inferType thm}" + let thm ← mkLambdaFVars fvars thm + /- Prepare the theorem type -/ + let thmType ← do + let thmType ← mkThmType declTerm [] fvarsExists.toList + let thmType ← mkForallFVars fvars thmType + pure thmType + trace[Progress] "Final theorem: {thm}\n :\n{thmType}" + /- Save the auxiliary theorem -/ + let name := Name.str decl.name suffix + let auxDecl : TheoremVal := { + name + levelParams := decl.levelParams + type := thmType + value := thm + } + addDecl (.thmDecl auxDecl) + /- -/ + pure name + +local elab "#progress_pure_def" id:ident pat:(term)? : command => do + Lean.Elab.Command.runTermElabM (fun _ => do + let some cs ← Term.resolveId? id | throwError m!"Unknown id: {id}" + let name := cs.constName! + let _ ← mkProgressPureDefThm pat name) + +namespace Test + def wrapping_add (x y : U8) : U8 := ⟨ x.val + y.val ⟩ + + #progress_pure_def overflowing_add (∃ z, overflowing_add x y = z) + #elab overflowing_add.progress_spec + + #progress_pure_def wrapping_add + + #elab wrapping_add.progress_spec +end Test + +/- Initialize the `progress_lift_def` attribute, which automatically generates + progress lemams for pure definitions. 
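+ Like `progress_pure`, it optionally takes a pattern which controls how the existential variables of the generated lemma are introduced.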
+ + For instance, if we annotate the following definition with `progress_pure_def`: + ``` + @[progress_pure_def] + def wrapping_add (x y : U32) : U32 := ... + ``` + `progress_pure_def` performs operations which are equivalent to introducing the following lemma: + ``` + @[progress] + theorem wrapping_add.progress_spec (x y : U32) : + ∃ z, ↑(wrapping_add x y) = ok z ∧ + z = wrapping_add x y + ``` + + Note that `progress_pure_def` takes a,n + -/ +initialize pspecPureDefAttribute : ProgressPureDefSpecAttr ← do + let attrImpl : AttributeImpl := { + name := `progress_pure_def + descr := "Automatically generate `progress` theorems for pure definitions" + add := fun declName stx attrKind => do + -- Lookup the theorem + let env ← getEnv + -- Ignore some auxiliary definitions (see the comments for attrIgnoreMutRec) + attrIgnoreAuxDef declName (pure ()) do + -- Elaborate the pattern + trace[Saturate.attribute] "Syntax: {stx}" + let pat ← elabProgressPureDefAttribute stx + -- Introduce the lifted theorem + let thmName ← MetaM.run' (mkProgressPureDefThm pat declName) + -- Save the lifted theorem to the `progress` database + saveProgressSpecFromThm pspecAttr.ext attrKind thmName + } + registerBuiltinAttribute attrImpl + pure { attr := attrImpl } + end Progress end Aeneas diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index f1effe1c..f09d0a53 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -2,6 +2,7 @@ import Lean import Aeneas.ScalarTac import Aeneas.Progress.Core import Aeneas.Std -- TODO: remove? +import Aeneas.FSimp namespace Aeneas @@ -13,22 +14,23 @@ open Utils -- TODO: the scalar types annoyingly often get reduced when we use the progress -- tactic. We should find a way of controling reduction. For now we use rewriting -- lemmas to make sure the goal remains clean, but this complexifies proof terms. --- It seems there used to be a `fold` tactic. -theorem scalar_isize_eq : Std.Scalar .Isize = Std.Isize := by rfl -theorem scalar_i8_eq : Std.Scalar .I8 = Std.I8 := by rfl -theorem scalar_i16_eq : Std.Scalar .I16 = Std.I16 := by rfl -theorem scalar_i32_eq : Std.Scalar .I32 = Std.I32 := by rfl -theorem scalar_i64_eq : Std.Scalar .I64 = Std.I64 := by rfl -theorem scalar_i128_eq : Std.Scalar .I128 = Std.I128 := by rfl -theorem scalar_usize_eq : Std.Scalar .Usize = Std.Usize := by rfl -theorem scalar_u8_eq : Std.Scalar .U8 = Std.U8 := by rfl -theorem scalar_u16_eq : Std.Scalar .U16 = Std.U16 := by rfl -theorem scalar_u32_eq : Std.Scalar .U32 = Std.U32 := by rfl -theorem scalar_u64_eq : Std.Scalar .U64 = Std.U64 := by rfl -theorem scalar_u128_eq : Std.Scalar .U128 = Std.U128 := by rfl +-- It seems there used to be a `fold` tactic. 
Update: there is a `refold_let` in Mathlib +theorem uscalar_u8_eq : Std.UScalar .U8 = Std.U8 := by rfl +theorem uscalar_u16_eq : Std.UScalar .U16 = Std.U16 := by rfl +theorem uscalar_u32_eq : Std.UScalar .U32 = Std.U32 := by rfl +theorem uscalar_u64_eq : Std.UScalar .U64 = Std.U64 := by rfl +theorem uscalar_u128_eq : Std.UScalar .U128 = Std.U128 := by rfl +theorem uscalar_usize_eq : Std.UScalar .Usize = Std.Usize := by rfl + +theorem iscalar_i8_eq : Std.IScalar .I8 = Std.I8 := by rfl +theorem iscalar_i16_eq : Std.IScalar .I16 = Std.I16 := by rfl +theorem iscalar_i32_eq : Std.IScalar .I32 = Std.I32 := by rfl +theorem iscalar_i64_eq : Std.IScalar .I64 = Std.I64 := by rfl +theorem iscalar_i128_eq : Std.IScalar .I128 = Std.I128 := by rfl +theorem iscalar_isize_eq : Std.IScalar .Isize = Std.Isize := by rfl def scalar_eqs := [ - ``scalar_isize_eq, ``scalar_i8_eq, ``scalar_i16_eq, ``scalar_i32_eq, ``scalar_i64_eq, ``scalar_i128_eq, - ``scalar_usize_eq, ``scalar_u8_eq, ``scalar_u16_eq, ``scalar_u32_eq, ``scalar_u64_eq, ``scalar_u128_eq + ``uscalar_usize_eq, ``uscalar_u8_eq, ``uscalar_u16_eq, ``uscalar_u32_eq, ``uscalar_u64_eq, ``uscalar_u128_eq, + ``iscalar_isize_eq, ``iscalar_i8_eq, ``iscalar_i16_eq, ``iscalar_i32_eq, ``iscalar_i64_eq, ``iscalar_i128_eq ] inductive TheoremOrLocal where @@ -36,7 +38,7 @@ inductive TheoremOrLocal where | Local (asm : LocalDecl) structure Stats where - usedTheorem : TheoremOrLocal + usedTheorem : Syntax instance : ToMessageData TheoremOrLocal where toMessageData := λ x => match x with | .Theorem thName => m!"{thName}" | .Local asm => m!"{asm.userName}" @@ -53,7 +55,7 @@ inductive ProgressError | Error (msg : MessageData) deriving Inhabited -def progressWith (fExpr : Expr) (th : TheoremOrLocal) +def progressWith (fExpr : Expr) (th : Expr) (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM ProgressError := do /- Apply the theorem @@ -66,21 +68,16 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) We also make sure that all the meta variables which appear in the function arguments have been instantiated -/ - let thTy ← do - match th with - | .Theorem thName => - -- Lookup the theorem and introduce fresh meta-variables for the universes - let th ← mkConstWithFreshMVarLevels thName - -- Retrieve the type - inferType th - | .Local asmDecl => pure asmDecl.type + /- There might be meta-variables in the type if the theorem comes from a local declaration, + especially if this declaration was introduced by a tactic -/ + let thTy ← instantiateMVars (← inferType th) trace[Progress] "Looked up theorem/assumption type: {thTy}" -- Normalize to inline the let-bindings let thTy ← normalizeLetBindings thTy trace[Progress] "After normalizing the let-bindings: {thTy}" -- TODO: the tactic fails if we uncomment withNewMCtxDepth -- withNewMCtxDepth do - let (mvars, binders, thExBody) ← forallMetaTelescope thTy + let (mvars, binders, thExBody) ← forallMetaTelescope thTy.consumeMData trace[Progress] "After stripping foralls: {thExBody}" -- Introduce the existentially quantified variables and the post-condition -- in the context @@ -103,16 +100,14 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) let thBody ← instantiateMVars thBody trace[Progress] "thBody (after instantiation): {thBody}" -- Add the instantiated theorem to the assumptions (we apply it on the metavariables). 
- let th ← do - match th with - | .Theorem thName => mkAppOptM thName (mvars.map some) - | .Local decl => mkAppOptM' (mkFVar decl.fvarId) (mvars.map some) + let th ← mkAppOptM' th (mvars.map some) + trace[Progress] "Instantiated theorem reusing the metavariables: {th}" let asmName ← do match keep with | none => mkFreshAnonPropUserName | some n => do pure n let thTy ← inferType th trace[Progress] "thTy (after application): {thTy}" - -- Normalize the let-bindings (note that we already inlined the let bindings once above when analizing - -- the theorem, now we do it again on the instantiated theorem - there is probably a smarter way to do, - -- but it doesn't really matter). + /- Normalize the let-bindings (note that we already inlined the let bindings once above when analizing + the theorem, now we do it again on the instantiated theorem - there is probably a smarter way to do, + but it doesn't really matter). -/ -- TODO: actually we might want to let the user insert them in the context let thTy ← normalizeLetBindings thTy trace[Progress] "thTy (after normalizing let-bindings): {thTy}" @@ -149,13 +144,14 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) Tactic.focus do let _ ← tryTac - (simpAt true {} [] [] + (simpAt true {} [] [] [] [``Std.bind_tc_ok, ``Std.bind_tc_fail, ``Std.bind_tc_div, -- Those ones are quite useful to simplify the goal further by eliminating -- existential quantifiers, for instance. ``and_assoc, ``Std.Result.ok.injEq, ``exists_eq_left, ``exists_eq_left', ``exists_eq_right, ``exists_eq_right', - ``exists_eq, ``exists_eq', ``true_and, ``and_true] + ``exists_eq, ``exists_eq', ``true_and, ``and_true, + ``Prod.mk.injEq] [hEq.fvarId!] (.targets #[] true)) -- It may happen that at this point the goal is already solved (though this is rare) -- TODO: not sure this is the best way of checking it @@ -163,7 +159,7 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) else trace[Progress] "goal after applying the eq and simplifying the binds: {← getMainGoal}" -- TODO: remove this (some types get unfolded too much: we "fold" them back) - let _ ← tryTac (simpAt true {} [] [] scalar_eqs [] .wildcard_dep) + let _ ← tryTac (simpAt true {} [] [] [] scalar_eqs [] .wildcard_dep) trace[Progress] "goal after folding back scalar types: {← getMainGoal}" -- Clear the equality, unless the user requests not to do so let mgoal ← do @@ -244,11 +240,12 @@ def getFirstArg (args : Array Expr) : Option Expr := do if args.size = 0 then none else some (args.get! 0) -/- Helper: try to lookup a theorem and apply it. - Return true if it succeeded. -/ -def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool) +/-- Helper: try to apply a theorem. + + Return true if it succeeded. 
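The candidate theorem `th` is given directly as an expression (for instance a constant looked up in the `progress` database).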
-/ +def tryApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) (fExpr : Expr) - (kind : String) (th : Option TheoremOrLocal) : TacticM Bool := do + (kind : String) (th : Option Expr) : TacticM Bool := do let res ← do match th with | none => @@ -268,8 +265,8 @@ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : | none => pure false -- The array of ids are identifiers to use when introducing fresh variables -def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal) - (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM TheoremOrLocal := do +def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option Expr) + (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM Syntax := do withMainContext do -- Retrieve the goal let mgoal ← Tactic.getMainGoal @@ -297,7 +294,9 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL match withTh with | some th => do match ← progressWith fExpr th keep ids splitPost asmTac with - | .Ok => return th + | .Ok => + -- Remark: exprToSyntax doesn't give the expected result + return ← Lean.Meta.Tactic.TryThis.delabToRefinableSyntax th | .Error msg => throwError msg | none => -- Try all the assumptions one by one and if it fails try to lookup a theorem. @@ -305,9 +304,9 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL let decls ← ctx.getDecls for decl in decls.reverse do trace[Progress] "Trying assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith fExpr (.Local decl) keep ids splitPost asmTac catch _ => continue + let res ← do try progressWith fExpr decl.toExpr keep ids splitPost asmTac catch _ => continue match res with - | .Ok => return (.Local decl) + | .Ok => return (mkIdent decl.userName) | .Error msg => throwError msg -- It failed: lookup the pspec theorems which match the expression *only -- if the function is a constant* @@ -317,7 +316,7 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL if ¬ fIsConst then throwError "Progress failed" else do trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem" - let pspecs : Array TheoremOrLocal ← do + let pspecs : Array Name ← do let thNames ← pspecAttr.find? fExpr -- TODO: because of reduction, there may be several valid theorems (for -- instance for the scalars). We need to sort them from most specific to @@ -325,10 +324,12 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL -- the end. 
let thNames := thNames.reverse trace[Progress] "Looked up pspec theorems: {thNames}" - pure (thNames.map fun th => TheoremOrLocal.Theorem th) + pure thNames -- Try the theorems one by one for pspec in pspecs do - if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return pspec + let pspecExpr ← Term.mkConst pspec + if ← tryApply keep ids splitPost asmTac fExpr "pspec theorem" pspecExpr + then return (mkIdent pspec) else pure () -- It failed: try to use the recursive assumptions trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption" @@ -339,16 +340,17 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL | .default | .implDetail => false | .auxDecl => true) for decl in decls.reverse do trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith fExpr (.Local decl) keep ids splitPost asmTac catch _ => continue + let res ← do try progressWith fExpr decl.toExpr keep ids splitPost asmTac catch _ => continue match res with - | .Ok => return (.Local decl) + | .Ok => return (mkIdent decl.userName) | .Error msg => throwError msg -- Nothing worked: failed throwError "Progress failed" -syntax progressArgs := ("keep" (ident <|> "_"))? ("with" ident)? ("as" " ⟨ " (ident <|> "_"),* " ⟩")? +syntax progressArgs := ("keep" (ident <|> "_"))? ("with" term)? ("as" " ⟨ " (ident <|> "_"),* " ⟩")? def evalProgress (args : TSyntax `Aeneas.Progress.progressArgs) : TacticM Stats := do + withMainContext do let args := args.raw -- Process the arguments to retrieve the identifiers to use trace[Progress] "Progress arguments: {args}" @@ -370,19 +372,32 @@ def evalProgress (args : TSyntax `Aeneas.Progress.progressArgs) : TacticM Stats let withArg ← do let withArg := withArg.getArgs if withArg.size > 0 then - let id := withArg.get! 1 - trace[Progress] "With arg: {id}" - -- Attempt to lookup a local declaration - match (← getLCtx).findFromUserName? id.getId with - | some decl => do - trace[Progress] "With arg: local decl" - pure (some (.Local decl)) - | none => do - -- Not a local declaration: should be a theorem - trace[Progress] "With arg: theorem" - addCompletionInfo <| CompletionInfo.id id id.getId (danglingDot := false) {} none - let some (.const name _) ← Term.resolveId? id | throwError m!"Could not find theorem: {id}" - pure (some (.Theorem name)) + let pspec := withArg.get! 1 + trace[Progress] "With arg: {pspec}" + /- The theorem with which to make progress is either: + - the identifier of a local declaration or a theroem + - a term + We have to make a case disjunction, because if we treat identifiers like + terms, then Lean will not succeed in infering their implicit parameters + (`progress` does that by matching against the goal). + -/ + if pspec.isIdent then + -- Attempt to lookup a local declaration + match (← getLCtx).findFromUserName? pspec.getId with + | some decl => do + trace[Progress] "With arg: local decl" + pure (some decl.toExpr) + | none => do + -- Not a local declaration: should be a theorem + trace[Progress] "With arg: theorem" + addCompletionInfo <| CompletionInfo.id pspec pspec.getId (danglingDot := false) {} none + let some e ← Term.resolveId? 
pspec (withInfo := true) | throwError m!"Could not find theorem: {pspec}" + pure (some e) + else + trace[Progress] "With arg: is term" + let pspec ← Tactic.elabTerm pspec none + trace[Progress] m!"With arg: elaborated expression {pspec}" + pure (some pspec) else pure none let ids := let args := asArgs.getArgs @@ -399,12 +414,13 @@ def evalProgress (args : TSyntax `Aeneas.Progress.progressArgs) : TacticM Stats if ← ScalarTac.goalIsLinearInt then -- Also: we don't try to split the goal if it is a conjunction -- (it shouldn't be), but we split the disjunctions. - ScalarTac.scalarTac true false + ScalarTac.scalarTac { split := false, fastSaturate := true } else throwError "Not a linear arithmetic goal" + let simpLemmas ← Aeneas.ScalarTac.scalarTacSimpExt.getTheorems let simpTac : TacticM Unit := do -- Simplify the goal - Utils.simpAt false {} [] [] [] [] (.targets #[] true) + Utils.simpAt false {} [] [simpLemmas] [] [] [] (.targets #[] true) -- Raise an error if the goal is not proved allGoalsNoRecover (throwError "Goal not proved") -- We use our custom assumption tactic, which instantiates meta-variables only if there is a single @@ -427,7 +443,7 @@ elab tk:"progress?" args:progressArgs : tactic => do let stats ← evalProgress args let mut stxArgs := args.raw if stxArgs[1].isNone then - let withArg := mkNullNode #[mkAtom "with", mkIdent stats.usedTheorem] + let withArg := mkNullNode #[mkAtom "with", stats.usedTheorem] stxArgs := stxArgs.setArg 1 withArg let tac := mkNode `Aeneas.Progress.progress #[mkAtom "progress", stxArgs] Meta.Tactic.TryThis.addSuggestion tk tac (origSpan? := ← getRef) @@ -435,42 +451,60 @@ elab tk:"progress?" args:progressArgs : tactic => do namespace Test open Std Result - -- Show the traces + -- Show the traces: -- set_option trace.Progress true -- set_option pp.rawOnError true set_option says.verify true - -- The following commands display the databases of theorems + -- The following command displays the database of theorems: -- #eval showStoredPSpec open alloc.vec - example {ty} {x y : Scalar ty} - (hmin : Scalar.min ty ≤ x.val + y.val) - (hmax : x.val + y.val ≤ Scalar.max ty) : + example {ty} {x y : UScalar ty} + (hmax : x.val + y.val ≤ UScalar.max ty) : + ∃ z, x + y = ok z ∧ z.val = x.val + y.val := by + progress keep _ as ⟨ z, h1 ⟩ + simp [*, h1] + + example {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ x.val + y.val) + (hmax : x.val + y.val ≤ IScalar.max ty) : ∃ z, x + y = ok z ∧ z.val = x.val + y.val := by progress keep _ as ⟨ z, h1 ⟩ simp [*, h1] - example {ty} {x y : Scalar ty} - (hmin : Scalar.min ty ≤ x.val + y.val) - (hmax : x.val + y.val ≤ Scalar.max ty) : + example {ty} {x y : UScalar ty} + (hmax : x.val + y.val ≤ UScalar.max ty) : + ∃ z, x + y = ok z ∧ z.val = x.val + y.val := by + progress? keep _ as ⟨ z, h1 ⟩ says progress keep _ with Aeneas.Std.UScalar.add_spec as ⟨ z, h1 ⟩ + simp [*, h1] + + example {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ x.val + y.val) + (hmax : x.val + y.val ≤ IScalar.max ty) : ∃ z, x + y = ok z ∧ z.val = x.val + y.val := by - progress? keep _ as ⟨ z, h1 ⟩ says progress keep _ with Aeneas.Std.Scalar.add_spec as ⟨ z, h1 ⟩ + progress? 
keep _ as ⟨ z, h1 ⟩ says progress keep _ with Aeneas.Std.IScalar.add_spec as ⟨ z, h1 ⟩ simp [*, h1] - example {ty} {x y : Scalar ty} - (hmin : Scalar.min ty ≤ x.val + y.val) - (hmax : x.val + y.val ≤ Scalar.max ty) : + example {ty} {x y : UScalar ty} + (hmax : x.val + y.val ≤ UScalar.max ty) : + ∃ z, x + y = ok z ∧ z.val = x.val + y.val := by + progress keep h with UScalar.add_spec as ⟨ z ⟩ + simp [*, h] + + example {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ x.val + y.val) + (hmax : x.val + y.val ≤ IScalar.max ty) : ∃ z, x + y = ok z ∧ z.val = x.val + y.val := by - progress keep h with Scalar.add_spec as ⟨ z ⟩ + progress keep h with IScalar.add_spec as ⟨ z ⟩ simp [*, h] example {x y : U32} (hmax : x.val + y.val ≤ U32.max) : ∃ z, x + y = ok z ∧ z.val = x.val + y.val := by -- This spec theorem is suboptimal, but it is good to check that it works - progress with Scalar.add_spec as ⟨ z, h1 ⟩ + progress with UScalar.add_spec as ⟨ z, h1 ⟩ simp [*, h1] example {x y : U32} @@ -489,16 +523,16 @@ namespace Test `α : Type u` where u is quantified, while here we use `α : Type 0` -/ example {α : Type} (v: Vec α) (i: Usize) (x : α) (hbounds : i.val < v.length) : - ∃ nv, v.update_usize i x = ok nv ∧ - nv.val = v.val.update i.toNat x := by + ∃ nv, v.update i x = ok nv ∧ + nv.val = v.val.set i.val x := by progress simp [*] example {α : Type} (v: Vec α) (i: Usize) (x : α) (hbounds : i.val < v.length) : - ∃ nv, v.update_usize i x = ok nv ∧ - nv.val = v.val.update i.toNat x := by - progress? says progress with Aeneas.Std.alloc.vec.Vec.update_usize_spec + ∃ nv, v.update i x = ok nv ∧ + nv.val = v.val.set i.val x := by + progress? says progress with Aeneas.Std.alloc.vec.Vec.update_spec simp [*] /- Checking that progress can handle nested blocks -/ @@ -507,7 +541,7 @@ namespace Test ∃ nv, (do (do - let _ ← v.update_usize i x + let _ ← v.update i x .ok ()) .ok ()) = ok nv := by @@ -531,14 +565,21 @@ namespace Test /- The use of `right` introduces a meta-variable in the goal, that we need to instantiate (otherwise `progress` gets stuck) -/ - example {ty} {x y : Scalar ty} - (hmin : Scalar.min ty ≤ x.val + y.val) - (hmax : x.val + y.val ≤ Scalar.max ty) : + example {ty} {x y : UScalar ty} + (hmax : x.val + y.val ≤ UScalar.max ty) : False ∨ (∃ z, x + y = ok z ∧ z.val = x.val + y.val) := by right progress keep _ as ⟨ z, h1 ⟩ simp [*, h1] + example {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ x.val + y.val) + (hmax : x.val + y.val ≤ IScalar.max ty) : + False ∨ (∃ z, x + y = ok z ∧ z.val = x.val + y.val) := by + right + progress? keep _ as ⟨ z, h1 ⟩ says progress keep _ with Aeneas.Std.IScalar.add_spec as ⟨ z, h1 ⟩ + simp [*, h1] + -- Testing with mutually recursive definitions mutual inductive Tree @@ -563,45 +604,56 @@ namespace Test ok (s + s') end - mutual - @[pspec] - theorem Tree.size_spec (t : Tree) : - ∃ i, t.size = ok i ∧ i ≥ 0 := by - cases t - simp [Tree.size] - progress - omega - - @[pspec] - theorem Trees.size_spec (t : Trees) : - ∃ i, t.size = ok i ∧ i ≥ 0 := by - cases t <;> simp [Trees.size] - progress - progress - omega + section + mutual + @[local progress] + theorem Tree.size_spec (t : Tree) : + ∃ i, t.size = ok i ∧ i ≥ 0 := by + cases t + simp [Tree.size] + progress + omega + + @[local progress] + theorem Trees.size_spec (t : Trees) : + ∃ i, t.size = ok i ∧ i ≥ 0 := by + cases t <;> simp [Trees.size] + progress + progress? 
says progress with Trees.size_spec + omega + end end -- Testing progress on theorems containing local let-bindings def add (x y : U32) : Result U32 := x + y - @[pspec] -- TODO: give the possibility of using pspec as a local attribute - theorem add_spec (x y : U32) (h : x.val + y.val ≤ U32.max) : - let tot := x.val + y.val - ∃ z, add x y = ok z ∧ z.val = tot := by - rw [add] - intro tot - progress - simp [*, tot] - - def add1 (x y : U32) : Result U32 := do - let z ← add x y - add z z + section + /- Testing progress on theorems containing local let-bindings as well as + the `local` attribute kind -/ + @[local progress] theorem add_spec' (x y : U32) (h : x.val + y.val ≤ U32.max) : + let tot := x.val + y.val + ∃ z, x + y = ok z ∧ z.val = tot := by + simp + progress with U32.add_spec + scalar_tac + + def add1 (x y : U32) : Result U32 := do + let z ← x + y + z + z + + example (x y : U32) (h : 2 * x.val + 2 * y.val ≤ U32.max) : + ∃ z, add1 x y = ok z := by + rw [add1] + progress? as ⟨ z1, h ⟩ says progress with Aeneas.Progress.Test.add_spec' as ⟨ z1, h ⟩ + progress? as ⟨ z2, h ⟩ says progress with Aeneas.Progress.Test.add_spec' as ⟨ z2, h ⟩ + end + /- Checking that `add_spec'` went out of scope -/ example (x y : U32) (h : 2 * x.val + 2 * y.val ≤ U32.max) : ∃ z, add1 x y = ok z := by rw [add1] - progress as ⟨ z1, h ⟩ - progress as ⟨ z2, h ⟩ + progress? as ⟨ z1, h ⟩ says progress with Aeneas.Std.U32.add_spec as ⟨ z1, h ⟩ + progress? as ⟨ z2, h ⟩ says progress with Aeneas.Std.U32.add_spec as ⟨ z2, h ⟩ variable (P : ℕ → List α → Prop) variable (f : List α → Result Bool) @@ -609,7 +661,205 @@ namespace Test example (l : List α) (h : P i l) : ∃ b, f l = ok b := by - progress as ⟨ b ⟩ + progress? as ⟨ b ⟩ says progress with f_spec as ⟨ b ⟩ + + /- Progress using a term -/ + example {x: U32} + (f : U32 → Result Unit) + (h : ∀ x, f x = .ok ()): + f x = ok () := by + progress? 
with (show ∀ x, f x = .ok () by exact h) says progress with(show ∀ x, f x = .ok () by exact h) + + /- Progress using a term -/ + example (x y : U32) (h : 2 * x.val + 2 * y.val ≤ U32.max) : + ∃ z, add1 x y = ok z := by + rw [add1] + have h1 := add_spec' + progress with h1 as ⟨ z1, h ⟩ + progress with add_spec' z1 as ⟨ z2, h ⟩ + + namespace Ntt + def wfArray (_ : Array U16 256#usize) : Prop := True + + def nttLayer (a : Array U16 256#usize) (_k : Usize) (_len : Usize) : Result (Array U16 256#usize) := ok a + + def toPoly (a : Array U16 256#usize) : List U16 := a.val + + def Spec.nttLayer (a : List U16) (_ : Nat) (len : Nat) (_ : Nat) (_ : 0 < len) : List U16 := a + + @[local progress] + theorem nttLayer_spec + (peSrc : Array U16 256#usize) + (k : Usize) (len : Usize) + (_ : wfArray peSrc) + (_ : k.val = 2^(k.val.log2) ∧ k.val.log2 < 7) + (_ : len.val = 128 / k.val) + (hLenPos : 0 < len.val) : + ∃ peSrc', nttLayer peSrc k len = ok peSrc' ∧ + toPoly peSrc' = Spec.nttLayer (toPoly peSrc) k.val len.val 0 hLenPos ∧ + wfArray peSrc' := by + simp [wfArray, nttLayer, toPoly, Spec.nttLayer] + + def ntt (x : Array U16 256#usize) : Result (Array U16 256#usize) := do + let x ← nttLayer x 1#usize 128#usize + let x ← nttLayer x 2#usize 64#usize + let x ← nttLayer x 4#usize 32#usize + let x ← nttLayer x 8#usize 16#usize + let x ← nttLayer x 16#usize 8#usize + let x ← nttLayer x 32#usize 4#usize + let x ← nttLayer x 64#usize 2#usize + let x ← nttLayer x 64#usize 2#usize + let x ← nttLayer x 64#usize 2#usize + let x ← nttLayer x 64#usize 2#usize + let x ← nttLayer x 64#usize 2#usize + let x ← nttLayer x 64#usize 2#usize + let x ← nttLayer x 64#usize 2#usize + ok x + + set_option maxHeartbeats 800000 + + /- + simp took 24.6ms + simp took 18.3ms + tactic execution of Aeneas.Progress.progress took 43.1ms + simp took 13.8ms + simp took 21.1ms + simp took 17ms + tactic execution of Aeneas.Progress.progress took 115ms + simp took 18.2ms + simp took 20.7ms + simp took 17.4ms + tactic execution of Aeneas.Progress.progress took 189ms + simp took 22.8ms + simp took 21.8ms + simp took 17.1ms + tactic execution of Aeneas.Progress.progress took 259ms + simp took 28.9ms + simp took 21.4ms + simp took 17.7ms + tactic execution of Aeneas.Progress.progress took 324ms + simp took 33.9ms + simp took 21.7ms + simp took 17.7ms + tactic execution of Aeneas.Progress.progress took 407ms + simp took 39.1ms + simp took 21.5ms + simp took 17.8ms + tactic execution of Aeneas.Progress.progress took 483ms + simp took 44ms + simp took 21ms + simp took 17.7ms + tactic execution of Aeneas.Progress.progress took 563ms + simp took 44.6ms + simp took 21.7ms + simp took 17.7ms + tactic execution of Aeneas.Progress.progress took 631ms + simp took 45.1ms + simp took 21.7ms + simp took 17.5ms + tactic execution of Aeneas.Progress.progress took 706ms + simp took 44.6ms + simp took 21.9ms + simp took 18.2ms + tactic execution of Aeneas.Progress.progress took 789ms + simp took 45.5ms + simp took 21.1ms + simp took 18.7ms + tactic execution of Aeneas.Progress.progress took 864ms + simp took 45.4ms + simp took 22.6ms + dsimp took 11.3ms + simp took 18.5ms + tactic execution of Aeneas.Progress.progress took 951ms + simp took 46.5ms + tactic execution of Lean.Parser.Tactic.tacticSeq1Indented took 19ms + type checking took 81.3ms + + After using `saturateFast` in `scalar_tac`: + simp took 26.2ms + simp took 20.6ms + simp took 10ms + tactic execution of Aeneas.Progress.progress took 20.9ms + simp took 21.9ms + simp took 18.5ms + tactic execution of 
Aeneas.Progress.progress took 23.1ms + simp took 18.1ms + simp took 21.6ms + simp took 18ms + tactic execution of Aeneas.Progress.progress took 38.8ms + simp took 18ms + simp took 23.1ms + simp took 17.6ms + simp took 10.3ms + tactic execution of Aeneas.Progress.progress took 31.8ms + simp took 19.8ms + simp took 22.1ms + simp took 18ms + tactic execution of Aeneas.Progress.progress took 34.9ms + simp took 22.9ms + simp took 22.8ms + simp took 17.8ms + tactic execution of Aeneas.Progress.progress took 40.9ms + simp took 26.5ms + simp took 23.1ms + simp took 20.1ms + simp took 19.4ms + simp took 10.1ms + tactic execution of Aeneas.Progress.progress took 48.2ms + simp took 29ms + simp took 22.9ms + simp took 18.6ms + tactic execution of Aeneas.Progress.progress took 51.1ms + simp took 29.1ms + simp took 22.5ms + simp took 19.2ms + tactic execution of Aeneas.Progress.progress took 56.1ms + simp took 29.2ms + simp took 22.5ms + simp took 19.5ms + simp took 10.3ms + tactic execution of Aeneas.Progress.progress took 60ms + simp took 29.7ms + simp took 23.4ms + simp took 19.6ms + simp took 10.5ms + tactic execution of Aeneas.Progress.progress took 67.4ms + simp took 29.1ms + simp took 23.6ms + simp took 18.6ms + simp took 10.7ms + tactic execution of Aeneas.Progress.progress took 70.1ms + simp took 30ms + simp took 24ms + simp took 20.1ms + simp took 10.5ms + tactic execution of Aeneas.Progress.progress took 76.5ms + simp took 28.7ms + tactic execution of Lean.Parser.Tactic.tacticSeq1Indented took 17.4ms + type checking took 86.5ms + -/ + theorem ntt_spec (peSrc : Std.Array U16 256#usize) + (hWf : wfArray peSrc) : + ∃ peSrc1, ntt peSrc = ok peSrc1 ∧ + wfArray peSrc1 + := by + unfold ntt + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + progress; fsimp [Nat.log2] + assumption + + end Ntt end Test diff --git a/backends/lean/Aeneas/Progress/Trace.lean b/backends/lean/Aeneas/Progress/Trace.lean new file mode 100644 index 00000000..246fccd8 --- /dev/null +++ b/backends/lean/Aeneas/Progress/Trace.lean @@ -0,0 +1,9 @@ +import Lean +open Lean Elab Term Meta + +namespace Aeneas.Progress + +-- We can't define and use trace classes in the same file +initialize registerTraceClass `Progress + +end Aeneas.Progress diff --git a/backends/lean/Aeneas/Range.lean b/backends/lean/Aeneas/Range.lean new file mode 100644 index 00000000..ea97e5bc --- /dev/null +++ b/backends/lean/Aeneas/Range.lean @@ -0,0 +1,4 @@ +import Aeneas.Range.DivRange +import Aeneas.Range.MulRange +import Aeneas.Range.Notations +import Aeneas.Range.SRRange diff --git a/backends/lean/Aeneas/Range/DivRange.lean b/backends/lean/Aeneas/Range/DivRange.lean new file mode 100644 index 00000000..e1d6e3e1 --- /dev/null +++ b/backends/lean/Aeneas/Range/DivRange.lean @@ -0,0 +1,3 @@ +import Aeneas.Range.DivRange.Basic +import Aeneas.Range.DivRange.Lemmas +import Aeneas.Range.DivRange.Notations diff --git a/backends/lean/Aeneas/Range/DivRange/Basic.lean b/backends/lean/Aeneas/Range/DivRange/Basic.lean new file mode 100644 index 00000000..90d7c581 --- /dev/null +++ b/backends/lean/Aeneas/Range/DivRange/Basic.lean @@ -0,0 +1,107 @@ +import Mathlib.Data.Nat.Defs +import Mathlib.Algebra.Group.Basic +import Aeneas.Utils + +namespace Aeneas + +-- TODO: move +/-- A 
"structural recursion" range type, that we use to implement for + loops with structural induction. + + This is the same as `Std.Range`, but with a slighly different implementation + of the loop inside the `forIn'` function, for which we introduce a fuel parameter. + + We do this because of issues with the kernel reducing definitions eagerly, leading + to explosions in the presence of well-founded recursion. This this: + https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/simp.20taking.20a.20long.20time.20on.20a.20small.20definition/near/495050322 + -/ +structure DivRange where + start : Nat := 0 + stop : Nat + divisor : Nat + divisor_pos : 1 < divisor + +instance : Membership Nat DivRange where + mem r i := r.stop < i ∧ i ≤ r.start ∧ ∃ k, i = r.start / r.divisor ^ k + +namespace DivRange +universe u v + +@[inline] protected def forIn' [Monad m] (range : DivRange) (init : β) + (f : (i : Nat) → i ∈ range → β → m (ForInStep β)) : m β := + let rec @[specialize] loop (maxSteps : Nat) (b : β) (i : Nat) + (hs : ∃ k, i = range.start / range.divisor ^ k) + (hl : i ≤ range.start) : m β := do + -- Introduce structural induction + match maxSteps with + | 0 => pure b + | maxSteps+1 => + if h : range.stop < i then + match (← f i ⟨h, hl, hs⟩ b) with + | .done b => pure b + | .yield b => + have := range.divisor_pos + loop maxSteps b (i / range.divisor) + (by + have ⟨ k, hk ⟩ := hs + exists (k + 1) + simp only [hk] + simp [Nat.div_div_eq_div_mul, Nat.pow_add_one]) + (by + have := @Nat.div_le_self i range.divisor + omega) + else + pure b + loop (range.start + 1) init range.start (by exists 0; simp) (by omega) + +instance : ForIn' m DivRange Nat inferInstance where + forIn' := DivRange.forIn' + +-- No separate `ForIn` instance is required because it can be derived from `ForIn'`. + +end DivRange + +/-! +We now introduce a convenient `DivRange` definition +-/ + +-- TODO: don't use a fuel +def divRange (start stop div : Nat) : List Nat := + let rec loop (fuel i : Nat) := + match fuel with + | 0 => [] + | fuel + 1 => + if i > stop then + i :: loop fuel (i / div) + else [] + loop (start + 1) start + +namespace DivRange + +/-- A convenient utility for the proofs -/ +def foldWhile' {α : Type u} (r : DivRange) (f : α → (a : Nat) → (a ∈ r) → α) (i : Nat) (init : α) + (hi : i ≤ r.start ∧ ∃ k, i = r.start / r.divisor ^ k) : α := + if h: r.stop < i then + foldWhile' r f (i / r.divisor) + (f init i (by simp [Membership.mem]; split_conjs <;> simp [*])) + (by split_conjs + . have := Nat.div_le_self i r.divisor; omega + . have ⟨ k, hk ⟩ := hi.right + exists k + 1 + simp [hk, Nat.div_div_eq_div_mul, ← Nat.pow_add_one]) + else init +termination_by i +decreasing_by apply Nat.div_lt_self; omega; apply r.divisor_pos + +/-- A convenient utility for the proofs -/ +def foldWhile {α : Type u} (stop divisor : Nat) (hDiv : 1 < divisor) + (f : α → (a : Nat) → α) (i : Nat) (init : α) : α := +if stop < i then + foldWhile stop divisor hDiv f (i / divisor) (f init i) + else init +termination_by i +decreasing_by apply Nat.div_lt_self; omega; apply hDiv + +end DivRange + +end Aeneas diff --git a/backends/lean/Aeneas/Range/DivRange/Lemmas.lean b/backends/lean/Aeneas/Range/DivRange/Lemmas.lean new file mode 100644 index 00000000..260aa635 --- /dev/null +++ b/backends/lean/Aeneas/Range/DivRange/Lemmas.lean @@ -0,0 +1,237 @@ +import Mathlib.Data.Nat.Log +import Mathlib.Algebra.Order.Sub.Defs +import Aeneas.Range.DivRange.Basic + +namespace Aeneas + +-- Auxiliary lemma - TODO: move? 
+theorem pow_ineq (start divisor : Nat) (hDiv : 1 < divisor) : + start ≤ divisor ^ (start + 1) := by + have h0 := Nat.log_le_self divisor start + have h1 : start < divisor ^ (Nat.log divisor start + 1) := + Nat.lt_pow_succ_log_self hDiv start + have h3 : Nat.log divisor start + 1 ≤ start + 1 := by omega + have := @Nat.pow_le_pow_of_le_right divisor (by omega) _ _ h3 + omega + +namespace DivRange + +/-! +# Lemmas about `DivRange` + +We provide lemmas rewriting for loops over `DivRange` in terms of `List.range'`. +-/ + +@[simp] +private theorem divRange_loop_zero (stop divisor fuel : Nat) : + divRange.loop stop divisor fuel 0 = [] := by + cases fuel <;> simp [divRange.loop] + +private theorem mem_of_mem_divRange_loop_aux + (fuel : Nat) : + ∀ (start stop divisor a : Nat), + 1 < divisor → + start ≤ divisor ^ fuel → + a ∈ divRange.loop stop divisor fuel start → + stop < a ∧ a ≤ start ∧ ∃ k, a = start / divisor ^ k + := by + induction fuel <;> intros start stop divisor a hDiv hStartLe hMem + . simp only [Nat.pow_zero] at hStartLe + unfold divRange.loop at hMem + cases hMem + . rename_i fuel hInd + simp only [divRange.loop, gt_iff_lt, List.mem_ite_nil_right, List.mem_cons] at hMem + replace ⟨ hIneq, hMem ⟩ := hMem + cases hMem + . simp_all only [le_refl, true_and] + exists 0 + simp + . rename_i hMem + have hPowIneq : start / divisor ≤ divisor ^ fuel := by + have h := @Nat.div_le_div_right start (divisor ^ (fuel + 1)) divisor hStartLe + simp only [Nat.pow_add_one'] at h + have := @Nat.mul_div_cancel_left (divisor ^ fuel) divisor (by omega) + simp_all + replace hInd := hInd (start / divisor) stop divisor a (by omega) hPowIneq hMem + have : a ≤ start := by + have := Nat.div_le_self start divisor + omega + simp only [true_and, hInd, this] + have ⟨ k, hkEq ⟩ := hInd.right.right + exists (k + 1) + simp [hkEq, Nat.div_div_eq_div_mul, Nat.pow_add_one'] + +private theorem mem_of_mem_divRange (r : DivRange) (a : Nat) + (h : a ∈ divRange r.start r.stop r.divisor) : a ∈ r := by + have hDiv := r.divisor_pos + have h0 := Nat.log_le_self r.divisor r.start + have h1 : r.start < r.divisor ^ (Nat.log r.divisor r.start + 1) := + Nat.lt_pow_succ_log_self hDiv r.start + have h2 : r.start + 1 ≤ r.divisor ^ (Nat.log r.divisor r.start + 1) := by omega + have h3 : Nat.log r.divisor r.start + 1 ≤ r.start + 1 := by omega + have := @Nat.pow_le_pow_of_le_right r.divisor (by omega) _ _ h3 + have hStartLe : r.start ≤ r.divisor ^ (r.start + 1) := by omega + + have := mem_of_mem_divRange_loop_aux (r.start + 1) r.start r.stop r.divisor a hDiv + hStartLe (by simp_all [divRange]) + simp [Membership.mem, this] + +private theorem mem_of_mem_divRange_loop (r : DivRange) (i : Nat) (fuel a : Nat) + (hStart : i ≤ r.start) + (hFuel : i ≤ r.divisor ^ fuel) + (hᵢ : ∃ k, i = r.start / r.divisor ^ k) + (hMem : a ∈ divRange.loop r.stop r.divisor fuel i) : + r.stop < a ∧ a ≤ r.start ∧ ∃ k, a = r.start / r.divisor ^ k + := by + have h := mem_of_mem_divRange_loop_aux fuel i r.stop r.divisor a r.divisor_pos hFuel hMem + split_conjs + . omega + . omega + . 
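    -- combining `i = r.start / r.divisor ^ k` (from `hᵢ`) with `a = i / r.divisor ^ k'`
    -- (from the auxiliary lemma) gives `a = r.start / r.divisor ^ (k + k')`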
have ⟨ k, hk ⟩ := hᵢ + have ⟨ k', hk' ⟩ := h.right.right + exists (k + k') + simp [*, Nat.div_div_eq_div_mul, Nat.pow_add] + +private theorem forIn'_loop_eq_forIn'_divRange [Monad m] (r : DivRange) + (fuel : Nat) (init : β) (f : (a : Nat) → a ∈ r → β → m (ForInStep β)) (i) (hk : ∃ k, i = r.start / r.divisor ^ k) + (hStart : i ≤ r.start) (hFuel : i ≤ r.divisor ^ fuel) : + forIn'.loop r f fuel init i hk hStart = + forIn' (divRange.loop r.stop r.divisor fuel i) init + fun a h => + f a (mem_of_mem_divRange_loop r i fuel a hStart hFuel hk h) := by + cases fuel + . rw [forIn'.loop] + simp [divRange.loop] + . rename_i fuel + simp only [forIn'.loop, divRange.loop, gt_iff_lt] + dcases hStop : r.stop < i <;> simp only [hStop, ↓reduceDIte, ↓reduceIte, List.not_mem_nil, + IsEmpty.forall_iff, implies_true, List.forIn'_nil, List.forIn'_cons] + apply letFun_val_congr + apply funext + intro x + cases x + . simp + . rename_i x + simp only + replace ⟨ k, hk ⟩ := hk + have hiDiv : ∃ k, i / r.divisor = r.start / r.divisor ^ k := by + exists (k + 1) + simp [hk, Nat.div_div_eq_div_mul, Nat.pow_add] + have hiLe : i / r.divisor ≤ r.start := by + have := Nat.div_le_self i r.divisor + omega + have hiDivLe : i / r.divisor ≤ r.divisor ^ fuel := by + have h1 := @Nat.div_le_div_right _ _ r.divisor hFuel + have h2 : r.divisor = r.divisor ^ 1 := by simp + conv at h1 => rhs; rhs; rw [h2] + rw [Nat.pow_div] at h1 <;> try omega + . simp only [Nat.add_one_sub_one] at h1 + apply h1 + . have := r.divisor_pos + omega + have hEq := forIn'_loop_eq_forIn'_divRange r fuel x f (i / r.divisor) hiDiv hiLe hiDivLe + simp [hEq] + +-- Auxiliary lemma +private theorem pow_ineq (r: DivRange) : + r.start ≤ r.divisor ^ (r.start + 1) := by + cases r; apply Aeneas.pow_ineq; simp; assumption + +@[simp] theorem forIn_eq_forIn_divRange [Monad m] (r : DivRange) + (init : β) (f : Nat → β → m (ForInStep β)) : + forIn r init f = forIn (divRange r.start r.stop r.divisor ) init f := by + simp only [forIn, forIn', divRange, DivRange.forIn'] + rw [forIn'_loop_eq_forIn'_divRange] + . simp + . apply pow_ineq + +@[simp] theorem forIn'_eq_forIn_divRange [Monad m] (r : DivRange) + (init : β) (f : (a:Nat) → (a ∈ r) → β → m (ForInStep β)) : + forIn' r init f = + forIn' (divRange r.start r.stop r.divisor ) init + (fun a h => f a (mem_of_mem_divRange r a h)) := by + simp only [forIn, forIn', divRange, DivRange.forIn'] + rw [forIn'_loop_eq_forIn'_divRange] + . simp + . apply pow_ineq + +@[simp] +def foldWhile'_step {α : Type u} (r : DivRange) (f : α → (a : Nat) → a ∈ r → α) (i : Nat) (init : α) + (hi : i ≤ r.start ∧ ∃ k, i = r.start / r.divisor ^ k) + (h : r.stop < i) : + foldWhile' r f i init hi = + foldWhile' r f (i / r.divisor) + (f init i (by simp only [Membership.mem]; split_conjs <;> simp [*])) + (by split_conjs + . have := Nat.div_le_self i r.divisor; omega + . 
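      -- the next index `i / r.divisor` is again of the form `r.start / r.divisor ^ _`,
      -- with the exponent bumped by one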
have ⟨ k, hk ⟩ := hi.right + exists k + 1 + simp [hk, Nat.div_div_eq_div_mul, ← Nat.pow_add_one]) + := by + conv => lhs; unfold foldWhile' + simp [*] + +@[simp] +def foldWhile'_id {α : Type u} (r : DivRange) (f : α → (a : Nat) → a ∈ r → α) (i : Nat) (init : α) + (hi : i ≤ r.start ∧ ∃ k, i = r.start / r.divisor ^ k) + (h : ¬ r.stop < i) : + foldWhile' r f i init hi = init + := by + conv => lhs; unfold foldWhile' + simp [*] + +@[simp] +def foldWhile_step {α : Type u} (stop divisor : Nat) (hDiv : 1 < divisor) + (f : α → Nat → α) (i : Nat) (init : α) (h : stop < i) : + foldWhile stop divisor hDiv f i init = foldWhile stop divisor hDiv f (i / divisor) (f init i) := by + conv => lhs; unfold foldWhile + simp [*] + +@[simp] +def foldWhile_id {α : Type u} (stop divisor : Nat) (hDiv : 1 < divisor) + (f : α → Nat → α) (i : Nat) (init : α) (h : ¬ stop < i) : + foldWhile stop divisor hDiv f i init = init := by + conv => lhs; unfold foldWhile + simp [*] + +private theorem divRange.loop_le_maxSteps_eq (stop div maxSteps start : Nat) (hDiv : 1 < div) (hMaxSteps : start + 1 ≤ maxSteps) : + divRange.loop stop div maxSteps start = divRange.loop stop div (start + 1) start := by + dcases maxSteps + . omega + . rename_i maxSteps + unfold divRange.loop + dcases h: stop < start + . simp only [gt_iff_lt, h, ↓reduceIte, List.cons.injEq, true_and] + have : start / div < start := by apply Nat.div_lt_self <;> omega + have h1 : start / div + 1 ≤ maxSteps := by omega + have := divRange.loop_le_maxSteps_eq stop div maxSteps (start / div) hDiv h1 + rw [this] + have h2 : start / div + 1 ≤ start := by omega + have := divRange.loop_le_maxSteps_eq stop div start (start / div) hDiv h2 + rw [this] + .simp [h] + +private theorem foldl_divRange_loop_foldWhile (start stop div maxSteps : Nat) (hMaxSteps : start + 1 ≤ maxSteps) + (hDiv : 1 < div) (f : α → Nat → α) (init : α) : + List.foldl f init (divRange.loop stop div maxSteps start) = foldWhile stop div hDiv f start init := by + dcases maxSteps + . omega + . rename_i maxSteps + unfold divRange.loop foldWhile + dcases h: stop < start + . simp only [gt_iff_lt, h, ↓reduceIte, List.foldl_cons] + rw [foldl_divRange_loop_foldWhile] + have : start / div < start := by apply Nat.div_lt_self <;> omega + omega + . 
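      -- the loop does not run in this case: both sides reduce to `init`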
simp [h] + +@[simp] +theorem foldl_divRange_foldWhile (start stop div : Nat) (hDiv : 1 < div) (f : α → Nat → α) (init : α) : + List.foldl f init (divRange start stop div) = foldWhile stop div hDiv f start init := by + unfold divRange + rw [foldl_divRange_loop_foldWhile] + simp + +end DivRange + +end Aeneas diff --git a/backends/lean/Aeneas/Range/DivRange/Notations.lean b/backends/lean/Aeneas/Range/DivRange/Notations.lean new file mode 100644 index 00000000..a1e5196e --- /dev/null +++ b/backends/lean/Aeneas/Range/DivRange/Notations.lean @@ -0,0 +1,21 @@ +import Aeneas.Range.Notations +import Aeneas.Range.DivRange.Basic + +namespace Aeneas.Notations + +namespace DivRange + +open Range -- activates the aeneas_range_tactic notation + + scoped syntax:max "[" withoutPosition(term ":" ">" term ":" "/=" term) "]" : term + + scoped macro_rules + | `([ $start : > $stop : /= $step ]) => + `({ start := $start, stop := $stop, divisor := $step, + divisor_pos := by aeneas_range_tactic : DivRange }) + + example : DivRange := [256:>1:/= 2] + +end DivRange + +end Aeneas.Notations diff --git a/backends/lean/Aeneas/Range/MulRange.lean b/backends/lean/Aeneas/Range/MulRange.lean new file mode 100644 index 00000000..a7426e41 --- /dev/null +++ b/backends/lean/Aeneas/Range/MulRange.lean @@ -0,0 +1,3 @@ +import Aeneas.Range.MulRange.Basic +import Aeneas.Range.MulRange.Lemmas +import Aeneas.Range.MulRange.Notations diff --git a/backends/lean/Aeneas/Range/MulRange/Basic.lean b/backends/lean/Aeneas/Range/MulRange/Basic.lean new file mode 100644 index 00000000..68a4fe7f --- /dev/null +++ b/backends/lean/Aeneas/Range/MulRange/Basic.lean @@ -0,0 +1,115 @@ +import Mathlib.Data.Nat.Defs +import Mathlib.Algebra.Group.Basic +import Aeneas.Utils + +namespace Aeneas + +-- TODO: move +/-- A "structural recursion" range type, that we use to implement for + loops with structural induction. + + This is the same as `Std.Range`, but with a slighly different implementation + of the loop inside the `forIn'` function, for which we introduce a fuel parameter. + + We do this because of issues with the kernel reducing definitions eagerly, leading + to explosions in the presence of well-founded recursion. 
This this: + https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/simp.20taking.20a.20long.20time.20on.20a.20small.20definition/near/495050322 + -/ +structure MulRange where + start : Nat := 0 + start_pos : 0 < start + stop : Nat + mul : Nat + mul_pos : 1 < mul + +instance : Membership Nat MulRange where + mem r i := r.start ≤ i ∧ i < r.stop ∧ ∃ k, i = r.start * r.mul ^ k + +namespace MulRange +universe u v + +@[inline] protected def forIn' [Monad m] (range : MulRange) (init : β) + (f : (i : Nat) → i ∈ range → β → m (ForInStep β)) : m β := + let rec @[specialize] loop (maxSteps : Nat) (b : β) (i : Nat) + (hs : ∃ k, i = range.start * range.mul ^ k) + (hl : range.start ≤ i) : m β := do + -- Introduce structural induction + match maxSteps with + | 0 => pure b + | maxSteps+1 => + if h : i < range.stop then + match (← f i ⟨hl, h, hs⟩ b) with + | .done b => pure b + | .yield b => + have := range.mul_pos + loop maxSteps b (i * range.mul) + (by + have ⟨ k, hk ⟩ := hs + exists (k + 1) + simp only [hk] + simp only [Nat.mul_assoc, Nat.pow_add_one]) + (by + have := @Nat.le_mul_of_pos_right range.mul i (by omega) + omega) + else + pure b + loop (range.stop + 1) init range.start (by exists 0; simp) (by omega) + +instance : ForIn' m MulRange Nat inferInstance where + forIn' := MulRange.forIn' + +-- No separate `ForIn` instance is required because it can be derived from `ForIn'`. + +end MulRange + +/-! +We now introduce a convenient `mulRange` definition +-/ + +def mulRange (stop mul : Nat) (hMul : 1 < mul) (i : Nat) (hi : 0 < i) : List Nat := + if i < stop then + i :: mulRange stop mul hMul (i * mul) (by rw [Nat.mul_pos_iff_of_pos_left] <;> omega) + else [] +termination_by stop - i +decreasing_by + have : i < i * mul := by rw [Nat.lt_mul_iff_one_lt_right] <;> assumption + omega + +namespace MulRange + +/-- A convenient utility for the proofs -/ +def foldWhile' {α : Type u} (r : MulRange) (f : α → (a : Nat) → (a ∈ r) → α) (i : Nat) (init : α) + (hi : r.start ≤ i ∧ ∃ k, i = r.start * r.mul ^ k) : α := + if h: i < r.stop then + foldWhile' r f (i * r.mul) + (f init i (by simp only [Membership.mem]; split_conjs <;> simp [*])) + (by + have := r.mul_pos + split_conjs + . have := @Nat.le_mul_of_pos_right r.mul i (by omega) + omega + . 
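      -- the next index `i * r.mul` is again of the form `r.start * r.mul ^ _`,
      -- with the exponent bumped by one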
have ⟨ k, hk ⟩ := hi.right + exists k + 1 + simp only [hk, Nat.mul_assoc, ← Nat.pow_add_one]) + else init +termination_by r.stop - i +decreasing_by + have := r.mul_pos + have := r.start_pos + have : i < i * r.mul := by rw [Nat.lt_mul_iff_one_lt_right] <;> omega + omega + +/-- A convenient utility for the proofs -/ +def foldWhile {α : Type u} (stop mul : Nat) (hMul : 1 < mul) + (f : α → (a : Nat) → α) (i : Nat) (hi : 0 < i) (init : α) : α := +if i < stop then + foldWhile stop mul hMul f (i * mul) (by simp only [hi, Nat.mul_pos_iff_of_pos_left]; omega) (f init i) + else init +termination_by stop - i +decreasing_by + have : i < i * mul := by rw [Nat.lt_mul_iff_one_lt_right] <;> omega + omega + +end MulRange + +end Aeneas diff --git a/backends/lean/Aeneas/Range/MulRange/Lemmas.lean b/backends/lean/Aeneas/Range/MulRange/Lemmas.lean new file mode 100644 index 00000000..aacb542a --- /dev/null +++ b/backends/lean/Aeneas/Range/MulRange/Lemmas.lean @@ -0,0 +1,216 @@ +import Mathlib.Data.Nat.Log +import Mathlib.Algebra.Order.Ring.Canonical +import Mathlib.Tactic.Ring.RingNF +import Aeneas.Range.MulRange.Basic +import Aeneas.Range.DivRange.Lemmas + +namespace Aeneas + +-- Auxiliary lemma +theorem pow_ineq' (stop mul start : Nat) (hMul : 1 < mul) (hStart : 0 < start) : + stop ≤ start * mul ^ (stop + 1):= by + have := pow_ineq stop mul hMul + have := @Nat.le_mul_of_pos_right start (mul ^ (stop + 1)) hStart + rw [Nat.mul_comm] at this + omega + +namespace MulRange + +/-! +# Lemmas about `MulRange` + +We provide lemmas rewriting for loops over `MulRange` in terms of `List.range'`. +-/ + +private theorem mem_of_mem_MulRange_aux + (stop mul start : Nat) (hMul : 1 < mul) (i : Nat) (hi : 0 < i ∧ ∃ k, i = start * mul ^ k) + (a : Nat) : + a ∈ mulRange stop mul hMul i hi.left → + start ≤ a ∧ a < stop ∧ ∃ k, a = start * mul ^ k + := by + unfold mulRange + dcases h: i < stop + . simp only [h, ↓reduceIte, List.mem_cons] + intro hMem + cases hMem + . simp_all only [and_self, and_true] + have ⟨ k, hk ⟩ := hi.right + have := Nat.one_le_pow k mul (by omega) + have := @Nat.le_mul_of_pos_right (mul ^ k) start (by omega) + omega + . rename_i hMem + apply mem_of_mem_MulRange_aux stop mul start hMul (i * mul) + (by + split_conjs + . apply Nat.mul_pos <;> omega + . have ⟨ k, hk ⟩ := hi.right + exists k + 1 + simp [hk, mul_assoc, ← Nat.pow_add_one]) + a hMem + . simp_all +termination_by stop - i +decreasing_by + have : i < i * mul := by rw [Nat.lt_mul_iff_one_lt_right] <;> omega + simp only [not_lt, decide_eq_false_iff_not, not_le] at h + omega + +private theorem mem_of_mem_MulRange (r : MulRange) (a : Nat) + (h : a ∈ mulRange r.stop r.mul r.mul_pos r.start r.start_pos) : a ∈ r := by + apply mem_of_mem_MulRange_aux r.stop r.mul r.start r.mul_pos r.start + (by + simp [r.start_pos] + exists 0; simp) + a h + +@[simp] +private theorem i_of_MulRange_start_pos (r : MulRange) (i : Nat) (hi : ∃ k, i = r.start * r.mul ^ k): + 0 < i := by + have ⟨ k, hk ⟩ := hi + simp only [hk] + apply Nat.mul_pos + . apply r.start_pos + . apply Nat.pos_pow_of_pos + have := r.mul_pos + omega + +private theorem mem_of_mem_MulRange_i (r : MulRange) + (i : Nat) (h1 : ∃ k, i = r.start * r.mul ^ k) (a : Nat) + (h : a ∈ mulRange r.stop r.mul r.mul_pos i + (by apply i_of_MulRange_start_pos; assumption)) : a ∈ r := by + apply mem_of_mem_MulRange_aux r.stop r.mul r.start r.mul_pos i + (by + split_conjs + . apply i_of_MulRange_start_pos; assumption + . 
apply h1) + a h + +private theorem forIn'_loop_eq_forIn'_MulRange [Monad m] (r : MulRange) + (fuel : Nat) (init : β) (f : (a : Nat) → a ∈ r → β → m (ForInStep β)) (i) + (hk : ∃ k, i = r.start * r.mul ^ k) + (hStart : r.start ≤ i) + (hFuel : r.stop ≤ i * r.mul ^ fuel) : + forIn'.loop r f fuel init i hk hStart = + forIn' (mulRange r.stop r.mul r.mul_pos i (by have := r.start_pos; omega)) init + fun a h => + f a (mem_of_mem_MulRange_i r i hk a h) := by + cases fuel + . rw [forIn'.loop] + simp only [pow_zero, mul_one] at hFuel + unfold mulRange + have : ¬ i < r.stop := by omega + simp [*] + . rename_i fuel + simp only [forIn'.loop, gt_iff_lt] + unfold mulRange + dcases hStop : i < r.stop <;> simp only [hStop, ↓reduceDIte, ↓reduceIte, List.forIn'_cons, + id_eq, Int.reduceNeg, Int.Nat.cast_ofNat_Int, Int.reduceAdd, List.not_mem_nil, + IsEmpty.forall_iff, implies_true, List.forIn'_nil] + apply letFun_val_congr + apply funext + intro x + cases x + . simp + . rename_i x + simp only + replace ⟨ k, hk ⟩ := hk + have := r.mul_pos + have h0 : ∃ k, i * r.mul = r.start * r.mul ^ k := by + exists (k + 1) + simp [hk, Nat.mul_assoc, Nat.pow_add] + have h1 : r.start ≤ i * r.mul := by + have := @Nat.le_mul_of_pos_right r.mul i (by omega) + omega + have h2 : r.stop ≤ i * r.mul * r.mul ^ fuel := by + simp only [Nat.pow_add_one] at hFuel + ring_nf at hFuel + apply hFuel + have hEq := forIn'_loop_eq_forIn'_MulRange r fuel x f (i * r.mul) h0 h1 h2 + simp [hEq] + +-- Auxiliary lemma +private theorem pow_ineq (r: MulRange) : + r.stop ≤ r.start * r.mul ^ (r.stop + 1) := by + apply pow_ineq' r.stop r.mul r.start r.mul_pos r.start_pos + +@[simp] theorem forIn_eq_forIn_MulRange [Monad m] (r : MulRange) + (init : β) (f : Nat → β → m (ForInStep β)) : + forIn r init f = forIn (mulRange r.stop r.mul r.mul_pos r.start r.start_pos) init f := by + simp only [forIn, forIn', MulRange, MulRange.forIn'] + rw [forIn'_loop_eq_forIn'_MulRange] + . simp + . apply pow_ineq + +@[simp] theorem forIn'_eq_forIn_MulRange [Monad m] (r : MulRange) + (init : β) (f : (a:Nat) → (a ∈ r) → β → m (ForInStep β)) : + forIn' r init f = + forIn' (mulRange r.stop r.mul r.mul_pos r.start r.start_pos) init + (fun a h => f a (mem_of_mem_MulRange r a h)) := by + simp only [forIn, forIn', MulRange, MulRange.forIn'] + rw [forIn'_loop_eq_forIn'_MulRange] + . simp + . apply pow_ineq + +private theorem MulRange_imp_pred (r : MulRange) (i : Nat) + (h0: r.start ≤ i ∧ ∃ k, i = r.start * r.mul ^ k) : + r.start ≤ i * r.mul ∧ ∃ k, i * r.mul = r.start * r.mul ^ k := by + have := r.mul_pos + split_conjs + . have := @Nat.le_mul_of_pos_right r.mul i (by omega) + omega + . 
have ⟨ k, hk ⟩ := h0.right + exists k + 1 + simp [hk, Nat.mul_assoc, ← Nat.pow_add_one] + +@[simp] +def foldWhile'_step {α : Type u} (r : MulRange) (f : α → (a : Nat) → a ∈ r → α) (i : Nat) (init : α) + (hi : r.start ≤ i ∧ ∃ k, i = r.start * r.mul ^ k) + (h : i < r.stop) : + foldWhile' r f i init hi = + foldWhile' r f (i * r.mul) + (f init i (by simp only [Membership.mem]; split_conjs <;> simp [*])) + (by apply MulRange_imp_pred; assumption) + := by + conv => lhs; unfold foldWhile' + simp [*] + +@[simp] +def foldWhile'_id {α : Type u} (r : MulRange) (f : α → (a : Nat) → a ∈ r → α) (i : Nat) (init : α) + (hi : r.start ≤ i ∧ ∃ k, i = r.start * r.mul ^ k) (h : ¬ i < r.stop) : + foldWhile' r f i init hi = init + := by + conv => lhs; unfold foldWhile' + simp [*] + +@[simp] +def foldWhile_step {α : Type u} (stop mul : Nat) (f : α → Nat → α) (i : Nat) + (init : α) (hMul) (hi : 0 < i) (h : i < stop) : + foldWhile stop mul hMul f i hi init = + foldWhile stop mul hMul f (i * mul) + (by apply Nat.mul_pos <;> omega) (f init i) + := by + conv => lhs; unfold foldWhile + simp [*] + +@[simp] +def foldWhile_id {α : Type u} (stop mul : Nat) (f : α → Nat → α) (i : Nat) + (init : α) (hMul) (hi : 0 < i) (h : ¬ i < stop) : + foldWhile stop mul hMul f i hi init = init := by + conv => lhs; unfold foldWhile + simp [*] + +@[simp] +theorem foldl_MulRange_foldWhile (stop mul i : Nat) (hMul) (hi) + (f : α → Nat → α) (init : α) : + List.foldl f init (mulRange stop mul hMul i hi) = foldWhile stop mul hMul f i hi init := by + unfold mulRange foldWhile + dcases h: i < stop <;> simp only [h, ↓reduceIte, List.foldl_cons, List.foldl_nil] + rw [foldl_MulRange_foldWhile] +termination_by stop - i +decreasing_by + have : i < i * mul := by rw [Nat.lt_mul_iff_one_lt_right] <;> assumption + simp only [not_lt, decide_eq_false_iff_not, not_le] at h + omega + +end MulRange + +end Aeneas diff --git a/backends/lean/Aeneas/Range/MulRange/Notations.lean b/backends/lean/Aeneas/Range/MulRange/Notations.lean new file mode 100644 index 00000000..f873f8aa --- /dev/null +++ b/backends/lean/Aeneas/Range/MulRange/Notations.lean @@ -0,0 +1,22 @@ +import Aeneas.Range.Notations +import Aeneas.Range.MulRange.Basic + +namespace Aeneas.Notations + +namespace MulRange + +open Range -- activates the aeneas_range_tactic notation + + scoped syntax:max "[" withoutPosition(term ":" "<" term ":" "*=" term) "]" : term + + scoped macro_rules + | `([ $start : < $stop : *= $step ]) => + `({ start := $start, start_pos := by aeneas_range_tactic + stop := $stop, + mul := $step, mul_pos := by aeneas_range_tactic : MulRange }) + + example : MulRange := [1:<256:*=2] + +end MulRange + +end Aeneas.Notations diff --git a/backends/lean/Aeneas/Range/Notations.lean b/backends/lean/Aeneas/Range/Notations.lean new file mode 100644 index 00000000..9fd66e57 --- /dev/null +++ b/backends/lean/Aeneas/Range/Notations.lean @@ -0,0 +1,9 @@ +namespace Aeneas.Notations.Range + +scoped syntax "aeneas_range_tactic" : tactic + +-- The default tactic to discharge proof obligations related to ranges +macro_rules +| `(tactic| aeneas_range_tactic) => `(tactic| decide) + +end Aeneas.Notations.Range diff --git a/backends/lean/Aeneas/Range/SRRange.lean b/backends/lean/Aeneas/Range/SRRange.lean new file mode 100644 index 00000000..b8621ea0 --- /dev/null +++ b/backends/lean/Aeneas/Range/SRRange.lean @@ -0,0 +1,3 @@ +import Aeneas.Range.SRRange.Basic +import Aeneas.Range.SRRange.Lemmas +import Aeneas.Range.SRRange.Notations diff --git a/backends/lean/Aeneas/Range/SRRange/Basic.lean 
b/backends/lean/Aeneas/Range/SRRange/Basic.lean new file mode 100644 index 00000000..37403d77 --- /dev/null +++ b/backends/lean/Aeneas/Range/SRRange/Basic.lean @@ -0,0 +1,83 @@ +import Mathlib.Data.Nat.Defs +import Aeneas.Utils + +namespace Aeneas + +/-- A "structural recursion" range type, that we use to implement for + loops with structural induction. + + This is the same as `Std.Range`, but with a slighly different implementation + of the loop inside the `forIn'` function, for which we introduce a fuel parameter. + + We do this because of issues with the kernel reducing definitions eagerly, leading + to explosions in the presence of well-founded recursion. See this: + https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/simp.20taking.20a.20long.20time.20on.20a.20small.20definition/near/495050322 + Also, we don't need to extract this code, meaning we are not concerned with its efficiency. + -/ +structure SRRange where + start : Nat := 0 + stop : Nat + step : Nat := 1 + step_pos : 0 < step + +instance : Membership Nat SRRange where + mem r i := r.start ≤ i ∧ i < r.stop ∧ (i - r.start) % r.step = 0 + +namespace SRRange +universe u v + +/-- The number of elements in the range. -/ +@[simp] def size (r : SRRange) : Nat := (r.stop - r.start + r.step - 1) / r.step + +/-- A bound of the number of elements in the range -/ +@[simp] def sizeBound (r : SRRange) : Nat := r.stop - r.start + +@[inline] protected def forIn' [Monad m] (range : SRRange) (init : β) + (f : (i : Nat) → i ∈ range → β → m (ForInStep β)) : m β := + let rec @[specialize] loop (maxSteps : Nat) (b : β) (i : Nat) + (hs : (i - range.start) % range.step = 0) (hl : range.start ≤ i := by omega) : m β := do + -- Introduce structural induction + match maxSteps with + | 0 => pure b + | maxSteps+1 => + if h : i < range.stop then + match (← f i ⟨hl, by omega, hs⟩ b) with + | .done b => pure b + | .yield b => + have := range.step_pos + loop maxSteps b (i + range.step) (by rwa [Nat.add_comm, Nat.add_sub_assoc hl, Nat.add_mod_left]) + else + pure b + have := range.step_pos + loop range.sizeBound init range.start (by simp) + +instance : ForIn' m SRRange Nat inferInstance where + forIn' := SRRange.forIn' + +-- No separate `ForIn` instance is required because it can be derived from `ForIn'`. + +/-- A convenient utility for the proofs, which uses well-founded recursion -/ +def foldWhile' {α : Type u} (r : SRRange) (f : α → (a : Nat) → (a ∈ r) → α) (i : Nat) (init : α) + (hi : r.start ≤ i ∧ (i - r.start) % r.step = 0) : α := + if h: i < r.stop then + foldWhile' r f (i + r.step) + (f init i (by simp [Membership.mem]; split_conjs <;> simp [*])) + (by split_conjs + . omega + . 
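      -- taking one more step preserves the congruence: `(i + r.step - r.start) % r.step = 0`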
have := @Nat.add_mod_left r.step (i - r.start) + have : r.step + (i - r.start) = i + r.step - r.start := by omega + simp_all only) + else init +termination_by r.stop - i +decreasing_by have:= r.step_pos; omega + +/-- A convenient utility for the proofs, which uses well-founded recursion -/ +def foldWhile {α : Type u} (max step : Nat) (hStep : 0 < step) (f : α → Nat → α) (i : Nat) (init : α) : α := + if i < max then foldWhile max step hStep f (i + step) (f init i) + else init +termination_by max - i +decreasing_by omega + +end SRRange + +end Aeneas diff --git a/backends/lean/Aeneas/Range/SRRange/Lemmas.lean b/backends/lean/Aeneas/Range/SRRange/Lemmas.lean new file mode 100644 index 00000000..4e150731 --- /dev/null +++ b/backends/lean/Aeneas/Range/SRRange/Lemmas.lean @@ -0,0 +1,150 @@ +import Mathlib.Tactic.Ring.RingNF +import Aeneas.Range.SRRange.Basic + +namespace Aeneas + +namespace SRRange + +/-! +# Lemmas about `SRRange` + +We provide lemmas rewriting for loops over `SRRange` in terms of `List.range'`. +Remark: the lemmas below are adapted from `Std.Range`. +-/ + +/-- Generalization of `mem_of_mem_range'` used in `forIn'_loop_eq_forIn'_range'` below. -/ +private theorem mem_of_mem_range'_aux {r : SRRange} {a : Nat} (w₁ : (i - r.start) % r.step = 0) + (w₂ : r.start ≤ i) + (h : a ∈ List.range' i ((r.stop - i + r.step - 1) / r.step) r.step) : a ∈ r := by + obtain ⟨j, h', rfl⟩ := List.mem_range'.1 h + refine ⟨by omega, ?_⟩ + rw [Nat.lt_div_iff_mul_lt r.step_pos, Nat.mul_comm] at h' + constructor + · omega + · rwa [Nat.add_comm, Nat.add_sub_assoc w₂, Nat.mul_add_mod_self_left] + +theorem mem_of_mem_range' {r : SRRange} (h : x ∈ List.range' r.start r.size r.step) : x ∈ r := by + unfold size at h + apply mem_of_mem_range'_aux (by simp) (by simp) h + +private theorem size_eq (r : SRRange) (h : i < r.stop) : + (r.stop - i + r.step - 1) / r.step = + (r.stop - (i + r.step) + r.step - 1) / r.step + 1 := by + have w := r.step_pos + if i + r.step < r.stop then -- Not sure this case split is strictly necessary. + rw [Nat.div_eq_iff w, Nat.add_one_mul] + have : (r.stop - (i + r.step) + r.step - 1) / r.step * r.step ≤ + (r.stop - (i + r.step) + r.step - 1) := Nat.div_mul_le_self _ _ + have : r.stop - (i + r.step) + r.step - 1 - r.step < + (r.stop - (i + r.step) + r.step - 1) / r.step * r.step := + Nat.lt_div_mul_self w (by omega) + omega + else + have : (r.stop - i + r.step - 1) / r.step = 1 := by + rw [Nat.div_eq_iff w, Nat.one_mul] + omega + have : (r.stop - (i + r.step) + r.step - 1) / r.step = 0 := by + rw [Nat.div_eq_iff] <;> omega + omega + +private theorem forIn'_loop_eq_forIn'_range' [Monad m] (r : SRRange) + (maxSteps : Nat) (init : β) (f : (a : Nat) → a ∈ r → β → m (ForInStep β)) (i) (w₁) (w₂) + (hMaxSteps : r.stop - i ≤ maxSteps) : + forIn'.loop r f maxSteps init i w₁ w₂ = + forIn' (List.range' i ((r.stop - i + r.step - 1) / r.step) r.step) init + fun a h => f a (mem_of_mem_range'_aux w₁ w₂ h) := by + have w := r.step_pos + revert init i + induction maxSteps <;> intros init i w₁ w₂ hMaxSteps + . rw [forIn'.loop] + simp only [forIn'] + have hEq : (r.stop - i + r.step - 1) / r.step = 0 := by + have : r.stop - i + r.step - 1 < r.step := by omega + simp [this] + simp [hEq] + . 
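    -- successor case: unfold one iteration of the loop, split on whether `i < r.stop`,
    -- and use the induction hypothesis on the remaining iterations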
rename_i maxSteps hInd + rw [forIn'.loop] + split <;> rename_i h + · simp only [size_eq r h, List.range'_succ, List.forIn'_cons] + congr 1 + funext step + split + · simp + · rw [hInd] + omega + · have : (r.stop - i + r.step - 1) / r.step = 0 := by + rw [Nat.div_eq_iff] <;> omega + simp [this] + +@[simp] +theorem forIn'_eq_forIn'_range' [Monad m] (r : SRRange) + (init : β) (f : (a : Nat) → a ∈ r → β → m (ForInStep β)) : + forIn' r init f = + forIn' (List.range' r.start r.size r.step) init (fun a h => f a (mem_of_mem_range' h)) := by + conv => lhs; simp only [forIn', SRRange.forIn'] + simp only [size] + rw [forIn'_loop_eq_forIn'_range'] + simp [SRRange.sizeBound] + +@[simp] +theorem forIn_eq_forIn_range' [Monad m] (r : SRRange) + (init : β) (f : Nat → β → m (ForInStep β)) : + forIn r init f = forIn (List.range' r.start r.size r.step) init f := by + simp only [forIn, forIn'_eq_forIn'_range'] + +@[simp] +def foldWhile'_step {α : Type u} (r : SRRange) (f : α → (a : Nat) → a ∈ r → α) (i : Nat) (init : α) + (hi : r.start ≤ i ∧ (i - r.start) % r.step = 0) + (h : i < r.stop) : + foldWhile' r f i init hi = + foldWhile' r f (i + r.step) + (f init i (by simp [Membership.mem]; split_conjs <;> tauto)) + (by split_conjs + . omega + . have := @Nat.add_mod_left r.step (i - r.start) + have : r.step + (i - r.start) = i + r.step - r.start := by omega + simp_all only) + := by + conv => lhs; unfold foldWhile' + simp [*] + +@[simp] +def foldWhile'_id {α : Type u} (r : SRRange) (f : α → (a : Nat) → a ∈ r → α) (i : Nat) (init : α) + (hi : r.start ≤ i ∧ (i - r.start) % r.step = 0) + (h : ¬ i < r.stop) : + foldWhile' r f i init hi = init + := by + conv => lhs; unfold foldWhile' + simp [*] + +@[simp] +def foldWhile_step {α : Type u} (max step : Nat) (hStep : 0 < step) (f : α → Nat → α) (i : Nat) (init : α) + (h : i < max) : foldWhile max step hStep f i init = foldWhile max step hStep f (i + step) (f init i) := by + conv => lhs; unfold foldWhile + simp [*] + +@[simp] +def foldWhile_id {α : Type u} (max step : Nat) (hStep : 0 < step) (f : α → Nat → α) (i : Nat) (init : α) + (h : ¬ i < max) : foldWhile max step hStep f i init = init := by + conv => lhs; unfold foldWhile + simp [*] + +@[simp] +theorem foldl_range' (start len step : Nat) (hStep : 0 < step) (f : α → Nat → α) (init : α) : + List.foldl f init (List.range' start len step) = foldWhile (start + len * step) step hStep f start init := by + cases len + . simp only [List.range'_zero, List.foldl_nil, Nat.zero_mul, Nat.add_zero] + unfold foldWhile + simp + . 
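    -- cons case: the head `start` is folded into `init`, recursing on `len` handles the
    -- tail `List.range' (start + step) len step`, and `foldWhile` is unfolded once to match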
rename_i len + simp only [List.range', List.foldl_cons] + have := foldl_range' (start + step) len step hStep f (f init start) + simp only [this] + conv => rhs; unfold foldWhile + have : start < start + (len + 1) * step := by simp [*] + simp only [this, ↓reduceIte] + ring_nf + +end SRRange + +end Aeneas diff --git a/backends/lean/Aeneas/Range/SRRange/Notations.lean b/backends/lean/Aeneas/Range/SRRange/Notations.lean new file mode 100644 index 00000000..d56e3416 --- /dev/null +++ b/backends/lean/Aeneas/Range/SRRange/Notations.lean @@ -0,0 +1,20 @@ +import Aeneas.Range.Notations +import Aeneas.Range.SRRange.Basic + +namespace Aeneas.Notations + +namespace SRRange + +open Range -- activates the aeneas_range_tactic notation + + scoped macro_rules + | `([ $start : $stop ]) => + `({ start := $start, stop := $stop, step := 1, + step_pos := by aeneas_range_tactic : SRRange }) + | `([ $start : $stop : $step ]) => + `({ start := $start, stop := $stop, step := $step, + step_pos := by aeneas_range_tactic : SRRange }) + +end SRRange + +end Aeneas.Notations diff --git a/backends/lean/Aeneas/Saturate/Attribute.lean b/backends/lean/Aeneas/Saturate/Attribute.lean index fb881546..9fa5a242 100644 --- a/backends/lean/Aeneas/Saturate/Attribute.lean +++ b/backends/lean/Aeneas/Saturate/Attribute.lean @@ -107,6 +107,9 @@ private def Rules.insert (s : Rules) (kv : Key × Rule) : Rules := private def Rules.erase (s : Rules) (thName : Name) : Rules := let ⟨ nameToRule, rules ⟩ := s + /- Note that we can't remove a key from a discrimination tree, so we + remove the rule from the `nameToRule` map instead: when instantiating rules + we check that they are still active (i.e., they are still in `nameToRule`) -/ let nameToRule := nameToRule.erase thName ⟨ nameToRule, rules ⟩ @@ -127,7 +130,7 @@ structure SaturateAttribute where deriving Inhabited -- The ident is the name of the saturation set, the term is the pattern. -syntax (name := aeneas_saturate) "aeneas_saturate" " (" &"set" " := " ident ")" " (" &"pattern" " := " term ")" : attr +syntax (name := aeneas_saturate) "aeneas_saturate" "(" &"set" " := " ident ")" " (" &"pattern" " := " term ")" : attr def elabSaturateAttribute (stx : Syntax) : MetaM (Name × Syntax) := withRef stx do @@ -155,7 +158,10 @@ initialize saturateAttr : SaturateAttribute ← do -- Analyze the theorem let (key, rule) ← MetaM.run' do let ty := thDecl.type - -- Strip the quantifiers + /- Strip the quantifiers. + We do this before elaborating the pattern because we need the universally quantified variables + to be in the context. + -/ forallTelescope ty.consumeMData fun fvars _ => do let numFVars := fvars.size -- Elaborate the pattern diff --git a/backends/lean/Aeneas/Saturate/Tactic.lean b/backends/lean/Aeneas/Saturate/Tactic.lean index 1b13a7c3..6467a59f 100644 --- a/backends/lean/Aeneas/Saturate/Tactic.lean +++ b/backends/lean/Aeneas/Saturate/Tactic.lean @@ -29,6 +29,7 @@ def matchExpr (nameToRule : NameMap Rule) (dtrees : Array (DiscrTree Rule)) trace[Saturate] "Potential matches: {exprs}" -- Check each expression (exprs.foldlM fun matched rule => do + trace[Saturate] "Checking potential match: {rule}" -- Check if the theorem is still active if let some activeRule := nameToRule.find? 
rule.thName then do -- Check that the patterns are the same @@ -39,25 +40,48 @@ def matchExpr (nameToRule : NameMap Rule) (dtrees : Array (DiscrTree Rule)) let pat := rule.pattern.instantiateLevelParams info.levelParams mvarLevels -- Strip the binders, introduce meta-variables at the same time, and match let (mvars, _, pat) ← lambdaMetaTelescope pat (some rule.numBinders) - if ← isDefEq pat e then - -- It matched! Check the variables which appear in the arguments - let (args, allFVars) ← mvars.foldrM (fun arg (args, hs) => do - let arg ← instantiateMVars arg - let hs ← getFVarIds arg hs - pure (arg :: args, hs) - ) ([], Std.HashSet.empty) - if boundVars.all (fun fvar => ¬ allFVars.contains fvar) then - -- Ok: save the theorem - trace[Saturate] "Matched with: {rule.thName} {args}" - pure (matched.insert (rule.thName, args)) + trace[Saturate] "Checking if defEq:\n- pat: {pat}\n- expression: {e}" + let pat_ty ← inferType pat + let e_ty ← inferType e + /- Small issue here: we use big integer constants and we have several patterns which + are just a variable (for instance: `UScalar.bounds`). Because `isDefEq` first + starts by unifying the expressions themselves (without looking at their type) we + often end up attempting to unify every expression in the context with variables + of type, e.g., `UScalar _`. The issue is that, if we attempt to unify an expression + like `1000` with `?x : UScalar ?ty`, Lean will lanch a "max recursion depth" exception + when attempting to reduce `1000` to `succ succ ...`. The current workaround is to + first check whether the types are definitionally equal, then compare the expressions + themselves. This way, in the case above we would not even compare `1000` with `?x` + because `ℕ` wouldn't match `UScalar ?ty`. + + TODO: it would probably be more efficient to have a specific treatment of degenerate + patterns, for instance by using the types as the keys in the discrimination trees. + -/ + if ← isDefEq pat_ty e_ty then + if ← isDefEq pat e then + trace[Saturate] "defEq" + -- It matched! Check the variables which appear in the arguments + let (args, allFVars) ← mvars.foldrM (fun arg (args, hs) => do + let arg ← instantiateMVars arg + let hs ← getFVarIds arg hs + pure (arg :: args, hs) + ) ([], Std.HashSet.empty) + if boundVars.all (fun fvar => ¬ allFVars.contains fvar) then + -- Ok: save the theorem + trace[Saturate] "Matched with: {rule.thName} {args}" + pure (matched.insert (rule.thName, args)) + else + -- Ignore + trace[Saturate] "Didn't match" + pure matched else - -- Ignore + -- Didn't match, leave the set of matches unchanged trace[Saturate] "Didn't match" pure matched else -- Didn't match, leave the set of matches unchanged - trace[Saturate] "Didn't match" - pure matched + trace[Saturate] "Types didn't match" + pure matched else -- The rule is not active trace[Saturate] "The rule is not active" @@ -132,8 +156,74 @@ private partial def visit (depth : Nat) (nameToRule : NameMap Rule) trace[Saturate] ".proj" visit (depth + 1) nameToRule dtrees boundVars matched b +def binaryConsts : Std.HashSet Name := Std.HashSet.ofList [ + ``And, ``Or +] + +def arithConsts : Std.HashSet Name := Std.HashSet.ofList [ + ``LT.lt, ``LE.le, ``GT.gt, ``GE.ge +] + +/- Fast version of `visit`: we do not explore everything. 
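   Concretely, we only recurse through the connectives and comparisons that the saturation
   patterns are expected to match on: the two sides of an `=`, the arguments of the binary
   connectives in `binaryConsts` (`∧`, `∨`), and the two operands of the comparisons in
   `arithConsts` (`<`, `≤`, `>`, `≥`); we never enter binders (`fun`, `∀`, `let`).
   For instance, on a hypothesis `h : x.val ≤ y.val ∨ x.val = 0`, the traversal calls
   `matchExpr` on the disjunction, on both disjuncts, and on the subterms `x.val`, `y.val`
   and `0`, but on nothing under a quantifier.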
-/ +private partial def fastVisit (depth : Nat) (nameToRule : NameMap Rule) + (dtrees : Array (DiscrTree Rule)) + (matched : Std.HashSet (Name × List Expr)) + (e : Expr) : MetaM (Std.HashSet (Name × List Expr)) := do + trace[Saturate] "Visiting {e}" + -- Match + let matched ← matchExpr nameToRule dtrees Std.HashSet.empty matched e + -- Recurse + let e := e.consumeMData + match e with + | .bvar _ + | .fvar _ + | .mvar _ + | .sort _ + | .lit _ + | .const _ _ => + trace[Saturate] "Stop: bvar, fvar, etc." + pure matched + | .app .. => do e.withApp fun f args => do + trace[Saturate] ".app" + let visitRec := fastVisit (depth + 1) nameToRule dtrees + if f.isConst then + -- + let constName := f.constName! + if constName == ``Eq ∧ args.size == 3 then + trace[Saturate] "Found `=`" + let matched ← visitRec matched args[1]! + let matched ← visitRec matched args[2]! + pure matched + else if constName ∈ binaryConsts ∧ args.size == 2 then + trace[Saturate] "Found binary const: {f}" + let matched ← visitRec matched args[0]! + let matched ← visitRec matched args[1]! + pure matched + else if constName ∈ arithConsts ∧ args.size == 4 then + trace[Saturate] "Found arith const: {f}" + let matched ← visitRec matched args[2]! + let matched ← visitRec matched args[3]! + pure matched + else + -- Stop there + pure matched + else + -- Stop there + pure matched + | .lam .. + | .forallE .. + | .letE .. => do + -- Do not go inside the foralls, the lambdas and the let expressions + pure matched + | .mdata _ b => do + trace[Saturate] ".mdata" + fastVisit (depth + 1) nameToRule dtrees matched b + | .proj _ _ _ => do + trace[Saturate] ".proj" + pure matched + /- The saturation tactic itself -/ -def evalSaturate (sets : List Name) : TacticM Unit := do +def evalSaturate (fast : Bool) (sets : List Name) : TacticM Unit := do Tactic.withMainContext do trace[Saturate] "sets: {sets}" -- Retrieve the rule sets @@ -143,17 +233,19 @@ def evalSaturate (sets : List Name) : TacticM Unit := do let ctx ← Lean.MonadLCtx.getLCtx -- Explore the declarations let decls ← ctx.getDecls + let visit := if fast then fastVisit 0 s.nameToRule dtrees else visit 0 s.nameToRule dtrees Std.HashSet.empty + let matched ← decls.foldlM (fun matched (decl : LocalDecl) => do trace[Saturate] "Exploring local decl: {decl.userName}" /- We explore both the type, the expresion and the body (if there is) -/ - let matched ← visit 0 s.nameToRule dtrees Std.HashSet.empty matched decl.type - let matched ← visit 0 s.nameToRule dtrees Std.HashSet.empty matched decl.toExpr + let matched ← visit matched decl.type + let matched ← visit matched decl.toExpr match decl.value? 
with | none => pure matched - | some value => visit 0 s.nameToRule dtrees Std.HashSet.empty matched value) Std.HashSet.empty + | some value => visit matched value) Std.HashSet.empty -- Explore the goal trace[Saturate] "Exploring the goal" - let matched ← visit 0 s.nameToRule dtrees Std.HashSet.empty matched (← Tactic.getMainTarget) + let matched ← visit matched (← Tactic.getMainTarget) -- Introduce the theorems in the context for (thName, args) in matched do let th ← mkAppOptM thName (args.map some).toArray @@ -161,11 +253,11 @@ def evalSaturate (sets : List Name) : TacticM Unit := do let _ ← Utils.addDeclTac (.str .anonymous "_") th thTy (asLet := false) elab "aeneas_saturate" : tactic => - evalSaturate [`Aeneas.ScalarTac] + evalSaturate false [`Aeneas.ScalarTac] section Test local elab "aeneas_saturate_test" : tactic => - evalSaturate [`Aeneas.Test] + evalSaturate false [`Aeneas.Test] set_option trace.Saturate.attribute false @[aeneas_saturate (set := Aeneas.Test) (pattern := l.length)] diff --git a/backends/lean/Aeneas/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac.lean index a27abb62..46f29596 100644 --- a/backends/lean/Aeneas/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac.lean @@ -1,3 +1,3 @@ -import Aeneas.ScalarTac.IntTac import Aeneas.ScalarTac.ScalarTac +import Aeneas.ScalarTac.Lemmas import Aeneas.Arith.Lemmas diff --git a/backends/lean/Aeneas/ScalarTac/Core.lean b/backends/lean/Aeneas/ScalarTac/Core.lean index b5e426bd..f2403c71 100644 --- a/backends/lean/Aeneas/ScalarTac/Core.lean +++ b/backends/lean/Aeneas/ScalarTac/Core.lean @@ -7,9 +7,6 @@ namespace ScalarTac open Lean Elab Term Meta --- We can't define and use trace classes in the same file -initialize registerTraceClass `ScalarTac - -- TODO: move? theorem ne_zero_is_lt_or_gt {x : Int} (hne : x ≠ 0) : x < 0 ∨ x > 0 := by cases h: x <;> simp_all diff --git a/backends/lean/Aeneas/ScalarTac/Init.lean b/backends/lean/Aeneas/ScalarTac/Init.lean index a8f8f5ed..4a4e0044 100644 --- a/backends/lean/Aeneas/ScalarTac/Init.lean +++ b/backends/lean/Aeneas/ScalarTac/Init.lean @@ -1,27 +1,32 @@ import Aeneas.Extensions import Aesop -open Lean +open Lean Meta + +namespace Aeneas.ScalarTac /-! -# Scalar tac rules sets +# Tracing +-/ -This module defines several Aesop rule sets and options which are used by the -`scalar_tac` tactic. Aesop rule sets only become visible once the file in which -they're declared is imported, so we must put this declaration into its own file. +-- We can't define and use trace classes in the same file +initialize registerTraceClass `ScalarTac + +/-! +# Simp Sets -/ -namespace Aeneas +/-- The `scalar_tac_simp` simp attribute. -/ +initialize scalarTacSimpExt : SimpExtension ← + registerSimpAttr `scalar_tac_simp "\ + The `scalar_tac_simp` attribute registers simp lemmas to be used by `scalar_tac` + during its preprocessing phase." -namespace ScalarTac +/-! 
+# Saturation Rules Sets +-/ declare_aesop_rule_sets [Aeneas.ScalarTac, Aeneas.ScalarTacNonLin] -register_option scalarTac.nonLin : Bool := { - defValue := false - group := "" - descr := "Activate the use of a set of lemmas to reason about non-linear arithmetic by `scalar_tac`" -} - -- The sets of rules that `scalar_tac` should use open Extensions in initialize scalarTacRuleSets : ListDeclarationExtension Name ← do @@ -38,6 +43,4 @@ def scalarTacRuleSets.set (names : List Name) : MetaM Unit := do def scalarTacRuleSets.add (name : Name) : MetaM Unit := do let _ := scalarTacRuleSets.modifyState (← getEnv) (fun ls => name :: ls) -end ScalarTac - -end Aeneas +end Aeneas.ScalarTac diff --git a/backends/lean/Aeneas/ScalarTac/IntTac.lean b/backends/lean/Aeneas/ScalarTac/IntTac.lean deleted file mode 100644 index b02a0fc5..00000000 --- a/backends/lean/Aeneas/ScalarTac/IntTac.lean +++ /dev/null @@ -1,274 +0,0 @@ -/- This file contains tactics to solve arithmetic goals -/ - -import Lean -import Lean.Meta.Tactic.Simp -import Init.Data.List.Basic -import Mathlib.Tactic.Ring.RingNF -import Aeneas.Utils -import Aeneas.ScalarTac.Core -import Aeneas.ScalarTac.Init -import Aeneas.Saturate - -namespace Aeneas - -namespace ScalarTac - -open Utils -open Lean Lean.Elab Lean.Meta Lean.Elab.Tactic - -/- Defining a custom attribute for Aesop - we use Aesop tactic in the arithmetic tactics -/ - -attribute [aesop (rule_sets := [Aeneas.ScalarTac]) unfold norm] Function.comp - -/- --- DEPRECATED: `int_tac` and `scalar_tac` used to rely on `aesop`. As there are performance issues --- with the saturation tactic for now we use our own tactic. We will revert once the performance --- is improved. -/-- The `int_tac` attribute used to tag forward theorems for the `int_tac` and `scalar_tac` tactics. -/ -macro "int_tac" pat:term : attr => - `(attr|aesop safe forward (rule_sets := [$(Lean.mkIdent `Aeneas.ScalarTac):ident]) (pattern := $pat)) - -/-- The `scalar_tac` attribute used to tag forward theorems for the `int_tac` and `scalar_tac` tactics. -/ -macro "scalar_tac" pat:term : attr => - `(attr|aesop safe forward (rule_sets := [$(Lean.mkIdent `Aeneas.ScalarTac):ident]) (pattern := $pat)) - -/-- The `nonlin_scalar_tac` attribute used to tag forward theorems for the `int_tac` and `scalar_tac` tactics. -/ -macro "nonlin_scalar_tac" pat:term : attr => - `(attr|aesop safe forward (rule_sets := [$(Lean.mkIdent `Aeneas.ScalarTacNonLin):ident]) (pattern := $pat)) --/ - -/-- The `int_tac` attribute used to tag forward theorems for the `int_tac` and `scalar_tac` tactics. -/ -macro "int_tac" pat:term : attr => - `(attr|aeneas_saturate (set := $(Lean.mkIdent `Aeneas.ScalarTac)) (pattern := $pat)) - -/-- The `scalar_tac` attribute used to tag forward theorems for the `int_tac` and `scalar_tac` tactics. -/ -macro "scalar_tac" pat:term : attr => - `(attr|aeneas_saturate (set := $(Lean.mkIdent `Aeneas.ScalarTac)) (pattern := $pat)) - -/-- The `nonlin_scalar_tac` attribute used to tag forward theorems for the `int_tac` and `scalar_tac` tactics. -/ -macro "nonlin_scalar_tac" pat:term : attr => - `(attr|aeneas_saturate (set := $(Lean.mkIdent `Aeneas.ScalarTacNonLin)) (pattern := $pat)) - --- This is useful especially in the termination proofs -attribute [scalar_tac a.toNat] Int.toNat_eq_max - -/- Check if a proposition is a linear integer proposition. - We notably use this to check the goals: this is useful to filter goals that - are unlikely to be solvable with arithmetic tactics. 
-/ -class IsLinearIntProp (x : Prop) where - -instance (x y : Int) : IsLinearIntProp (x < y) where -instance (x y : Int) : IsLinearIntProp (x > y) where -instance (x y : Int) : IsLinearIntProp (x ≤ y) where -instance (x y : Int) : IsLinearIntProp (x ≥ y) where -instance (x y : Int) : IsLinearIntProp (x ≥ y) where -instance (x y : Int) : IsLinearIntProp (x = y) where - -instance (x y : Nat) : IsLinearIntProp (x < y) where -instance (x y : Nat) : IsLinearIntProp (x > y) where -instance (x y : Nat) : IsLinearIntProp (x ≤ y) where -instance (x y : Nat) : IsLinearIntProp (x ≥ y) where -instance (x y : Nat) : IsLinearIntProp (x ≥ y) where -instance (x y : Nat) : IsLinearIntProp (x = y) where - -instance : IsLinearIntProp False where -instance (p : Prop) [IsLinearIntProp p] : IsLinearIntProp (¬ p) where -instance (p q : Prop) [IsLinearIntProp p] [IsLinearIntProp q] : IsLinearIntProp (p ∨ q) where -instance (p q : Prop) [IsLinearIntProp p] [IsLinearIntProp q] : IsLinearIntProp (p ∧ q) where --- We use the one below for goals -instance (p q : Prop) [IsLinearIntProp p] [IsLinearIntProp q] : IsLinearIntProp (p → q) where - --- Check if the goal is a linear arithmetic goal -def goalIsLinearInt : Tactic.TacticM Bool := do - Tactic.withMainContext do - let gty ← Tactic.getMainTarget - match ← trySynthInstance (← mkAppM ``IsLinearIntProp #[gty]) with - | .some _ => pure true - | _ => pure false - -example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by - omega - -def intTacSimpRocs : List Name := [``Int.reduceNegSucc, ``Int.reduceNeg] - -/-- Apply the scalar_tac forward rules -/ -def intTacSaturateForward : Tactic.TacticM Unit := do - /- - let options : Aesop.Options := {} - -- Use a forward max depth of 0 to prevent recursively applying forward rules on the assumptions - -- introduced by the forward rules themselves. - let options ← options.toOptions' (some 0)-/ - -- We always use the rule set `Aeneas.ScalarTac`, but also need to add other rule sets locally - -- activated by the user. The `Aeneas.ScalarTacNonLin` rule set has a special treatment as - -- it is activated through an option. - let ruleSets := - let ruleSets := `Aeneas.ScalarTac :: (← scalarTacRuleSets.get) - if scalarTac.nonLin.get (← getOptions) then `Aeneas.ScalarTacNonLin :: ruleSets - else ruleSets - -- TODO - -- evalAesopSaturate options ruleSets.toArray - Saturate.evalSaturate ruleSets - --- For debugging -elab "int_tac_saturate" : tactic => - intTacSaturateForward - -/- Boosting a bit the `omega` tac. - - - `extraPrePreprocess`: extra-preprocessing to be done *before* this preprocessing - - `extraPreprocess`: extra-preprocessing to be done *after* this preprocessing - -/ -def intTacPreprocess (extraPrePreprocess extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do - Tactic.withMainContext do - -- Pre-preprocessing - extraPrePreprocess - -- Apply the forward rules - allGoalsNoRecover intTacSaturateForward - -- Extra preprocessing - allGoalsNoRecover extraPreprocess - -- Reduce all the terms in the goal - note that the extra preprocessing step - -- might have proven the goal, hence the `allGoals` - let dsimp := - allGoalsNoRecover do tryTac ( - -- We set `simpOnly` at false on purpose. 
- -- Also, we need `zetaDelta` to inline the let-bindings (otherwise, omega doesn't always manages - -- to deal with them) - dsimpAt false {zetaDelta := true} intTacSimpRocs - -- Declarations to unfold - [] - -- Theorems - [] - [] Tactic.Location.wildcard) - dsimp - -- More preprocessing: apply norm_cast to the whole context - allGoalsNoRecover (Utils.tryTac (Utils.normCastAtAll)) - -- norm_cast does weird things with negative numbers so we reapply simp - dsimp - allGoalsNoRecover do Utils.tryTac ( - Utils.simpAt true {} - -- Simprocs - [] - -- Unfoldings - [] - -- Simp lemmas - [-- Int.subNatNat is very annoying - TODO: there is probably something more general thing to do - ``Int.subNatNat_eq_coe, - -- We also need this, in case the goal is: ¬ False - ``not_false_eq_true] - -- Hypotheses - [] .wildcard) - -elab "int_tac_preprocess" : tactic => - intTacPreprocess (do pure ()) (do pure ()) - -/-- - `splitAllDisjs`: if true, also split all the matches/if then else in the context (note that - `omega` splits the *disjunctions*) - - `splitGoalConjs`: if true, split the goal if it is a conjunction so as to introduce one - subgoal per conjunction. --/ -def intTac (tacName : String) (splitAllDisjs splitGoalConjs : Bool) - (extraPrePreprocess extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do - Tactic.withMainContext do - Tactic.focus do - let g ← Tactic.getMainGoal - trace[ScalarTac] "Original goal: {g}" - -- Introduce all the universally quantified variables (includes the assumptions) - let (_, g) ← g.intros - Tactic.setGoals [g] - -- Preprocess - wondering if we should do this before or after splitting - -- the goal. I think before leads to a smaller proof term? - allGoalsNoRecover (intTacPreprocess extraPrePreprocess extraPreprocess) - -- Split the conjunctions in the goal - if splitGoalConjs then allGoalsNoRecover (Utils.repeatTac Utils.splitConjTarget) - /- If we split the disjunctions, split then use simp_all. Otherwise only use simp_all. - Note that simp_all is very useful here as a "congruence" procedure. Note however that we - only activate a very restricted set of simp lemmas (otherwise it can be very expensive, - and have surprising behaviors). -/ - allGoalsNoRecover do - try do - let simpThenOmega := do - /- IMPORTANT: we put a quite low number for `maxSteps`. - There are two reasons. - First, `simp_all` seems to loop at times, so by controling the maximum number of steps - we make sure it doesn't exceed the maximum number of heart beats (or worse, overflows the - stack). - Second, this makes the tactic very snappy. - -/ - Utils.tryTac ( - -- TODO: is there a simproc to simplify propositional logic? - Utils.simpAll {failIfUnchanged := false, maxSteps := 1000} true [``reduceIte] [] - [``and_self, ``false_implies, ``true_implies, ``Prod.mk.injEq, - ``not_false_eq_true, ``not_true_eq_false, - ``true_and, ``and_true, ``false_and, ``and_false, - ``true_or, ``or_true,``false_or, ``or_false, - ``Bool.true_eq_false, ``Bool.false_eq_true] []) - allGoalsNoRecover (do - trace[ScalarTac] "Goal after simplification: {← getMainGoal}" - trace[ScalarTac] "Calling omega" - Tactic.Omega.omegaTactic {} - trace[ScalarTac] "Omega solved the goal") - if splitAllDisjs then do - /- In order to improve performance, we first try to prove the goal without splitting. If it - fails, we split. 
-/ - try - trace[ScalarTac] "First trying to solve the goal without splitting" - simpThenOmega - catch _ => - trace[ScalarTac] "First attempt failed: splitting the goal and retrying" - splitAll (allGoalsNoRecover simpThenOmega) - else - simpThenOmega - catch _ => - let g ← Tactic.getMainGoal - throwError "{tacName} failed to prove the goal below.\n\nNote that {tacName} is almost equivalent to:\n {tacName}_preprocess; split_all <;> simp_all only <;> omega\n\nGoal: \n{g}" - -elab "int_tac" args:(" split_goal"?): tactic => - let splitConjs := args.raw.getArgs.size > 0 - intTac "int_tac" true splitConjs (do pure ()) (do pure ()) - --- For termination proofs -syntax "int_decr_tac" : tactic -macro_rules - | `(tactic| int_decr_tac) => - `(tactic| - simp_wf; - -- TODO: don't use a macro (namespace problems) - (first | apply ScalarTac.to_int_to_nat_lt - | apply ScalarTac.to_int_sub_to_nat_lt) <;> - simp_all <;> int_tac) - --- Checking that things happen correctly when there are several disjunctions -example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by - int_tac split_goal - ---example (x y : Int) : x + y ≥ 2 := by --- int_tac split_goal - --- Checking that things happen correctly when there are several disjunctions -example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by - int_tac split_goal - --- Checking that we can prove exfalso -example (a : Prop) (x : Int) (h0: 0 < x) (h1: x < 0) : a := by - int_tac - --- Intermediate cast through natural numbers -example (a : Prop) (x : Int) (h0: (0 : Nat) < x) (h1: x < 0) : a := by - int_tac - -example (x : Int) (h : x ≤ -3) : x ≤ -2 := by - int_tac - -example (x y : Int) (h : x + y = 3) : - let z := x + y - z = 3 := by - intro z - omega - --- Checking that we manage to split the cases/if then else -example (x : Int) (b : Bool) (h : if b then x ≤ 0 else x ≤ 0) : x ≤ 0 := by - int_tac - -end ScalarTac - -end Aeneas diff --git a/backends/lean/Aeneas/ScalarTac/Lemmas.lean b/backends/lean/Aeneas/ScalarTac/Lemmas.lean new file mode 100644 index 00000000..309dc034 --- /dev/null +++ b/backends/lean/Aeneas/ScalarTac/Lemmas.lean @@ -0,0 +1,377 @@ +import Aeneas.ScalarTac.ScalarTac +import Aeneas.Std.ScalarCore + +namespace Aeneas + +namespace Std + +set_option maxRecDepth 1024 + +attribute [scalar_tac_simp] + and_self false_implies true_implies Prod.mk.injEq + not_false_eq_true not_true_eq_false + true_and and_true false_and and_false + true_or or_true false_or or_false + Bool.true_eq_false Bool.false_eq_true + decide_eq_true_eq Bool.or_eq_true Bool.and_eq_true + +attribute [scalar_tac_simp] zero_add + +local syntax "simp_scalar_consts" : tactic +local macro_rules +| `(tactic|simp_scalar_consts) => + `(tactic| + simp [ + UScalar.rMax, UScalar.max, + Usize.rMax, Usize.rMax, Usize.max, + U8.rMax, U8.max, U16.rMax, U16.max, U32.rMax, U32.max, + U64.rMax, U64.max, U128.rMax, U128.max, + U8.numBits, U16.numBits, U32.numBits, U64.numBits, U128.numBits, Usize.numBits, + U8.size, U16.size, U32.size, U64.size, U128.size, Usize.size, + IScalar.rMax, IScalar.max, + IScalar.rMin, IScalar.min, + Isize.rMax, Isize.rMax, Isize.max, + I8.rMax, I8.max, I16.rMax, I16.max, I32.rMax, I32.max, + I64.rMax, I64.max, I128.rMax, I128.max, + Isize.rMin, Isize.rMin, Isize.min, + I8.rMin, I8.min, I16.rMin, I16.min, I32.rMin, I32.min, + I64.rMin, I64.min, I128.rMin, I128.min, + I8.numBits, I16.numBits, I32.numBits, I64.numBits, I128.numBits, Isize.numBits, + I8.size, I16.size, I32.size, I64.size, I128.size, 
Isize.size, + UScalar.size, IScalar.size, + UScalar.cMax, IScalar.cMin, IScalar.cMax]) + +@[scalar_tac_simp] theorem UScalar.max_USize_eq : UScalar.max .Usize = Usize.max := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.min_ISize_eq : IScalar.min .Isize = Isize.min := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.max_ISize_eq : IScalar.max .Isize = Isize.max := by simp_scalar_consts + +theorem Usize.max_succ_eq_pow : Usize.max + 1 = 2^System.Platform.numBits := by + simp [Usize.max, Usize.numBits] + have : 0 < 2^System.Platform.numBits := by simp + omega + +@[scalar_tac Usize.max] +theorem Usize.cMax_bound : UScalar.cMax .Usize ≤ Usize.max ∧ Usize.max + 1 = 2^System.Platform.numBits := by + simp [Usize.max, UScalar.cMax, UScalar.rMax, U32.rMax, Usize.numBits] + have := System.Platform.numBits_eq; cases this <;> simp [*] + +@[scalar_tac Usize.size] +theorem Usize.size_scalarTac_eq : Usize.size = Usize.max + 1 ∧ Usize.size = 2^System.Platform.numBits := by + simp [Usize.max, UScalar.cMax, UScalar.rMax, U32.rMax, Usize.numBits, Usize.size] + have := System.Platform.numBits_eq; cases this <;> simp [*] + +abbrev Usize.maxAbbrevPow := 2^System.Platform.numBits +@[scalar_tac Usize.maxAbbrevPow] +theorem Usize.cMax_bound' : UScalar.cMax .Usize ≤ Usize.max ∧ Usize.max + 1 = 2^System.Platform.numBits := Usize.cMax_bound + +@[scalar_tac Isize.min] +theorem Isize.cMin_bound : Isize.min ≤ IScalar.cMin .Isize ∧ Isize.min = - 2^(System.Platform.numBits - 1) := by + simp [Isize.min, IScalar.cMin, IScalar.rMin, I32.rMin, Isize.numBits, + Isize.max, IScalar.cMax, IScalar.rMax, I32.rMax] + have := System.Platform.numBits_eq; cases this <;> simp [*] + +abbrev Isize.minAbbrevPow : Int := -2^(System.Platform.numBits-1) +@[scalar_tac Isize.minAbbrevPow] +theorem Isize.cMin_bound' : Isize.min ≤ IScalar.cMin .Isize ∧ Isize.min = - 2^(System.Platform.numBits - 1) := Isize.cMin_bound + +@[scalar_tac Isize.max] +theorem Isize.cMax_bound : IScalar.cMax .Isize ≤ Isize.max ∧ Isize.max + 1 = 2^(System.Platform.numBits - 1) := by + simp [Isize.min, IScalar.cMin, IScalar.rMin, I32.rMin, Isize.numBits, + Isize.max, IScalar.cMax, IScalar.rMax, I32.rMax] + have := System.Platform.numBits_eq; cases this <;> simp [*] + +@[scalar_tac Isize.size] +theorem Isize.size_scalarTac_eq : Isize.size = 2^System.Platform.numBits := by + simp [Isize.max, Isize.numBits, Isize.size] + +abbrev Isize.maxAbbrevPow : Int := 2^(System.Platform.numBits-1) +@[scalar_tac Isize.maxAbbrevPow] +theorem Isize.cMax_bound' : IScalar.cMax .Isize ≤ Isize.max ∧ Isize.max + 1 = 2^(System.Platform.numBits - 1) := Isize.cMax_bound + +@[scalar_tac_simp] theorem U8.numBits_eq : U8.numBits = 8 := by simp_scalar_consts +@[scalar_tac_simp] theorem U16.numBits_eq : U16.numBits = 16 := by simp_scalar_consts +@[scalar_tac_simp] theorem U32.numBits_eq : U32.numBits = 32 := by simp_scalar_consts +@[scalar_tac_simp] theorem U64.numBits_eq : U64.numBits = 64 := by simp_scalar_consts +@[scalar_tac_simp] theorem U128.numBits_eq : U128.numBits = 128 := by simp_scalar_consts +@[scalar_tac_simp] theorem Usize.numBits_eq : Usize.numBits = System.Platform.numBits := by simp_scalar_consts + +@[scalar_tac_simp] theorem I8.numBits_eq : I8.numBits = 8 := by simp_scalar_consts +@[scalar_tac_simp] theorem I16.numBits_eq : I16.numBits = 16 := by simp_scalar_consts +@[scalar_tac_simp] theorem I32.numBits_eq : I32.numBits = 32 := by simp_scalar_consts +@[scalar_tac_simp] theorem I64.numBits_eq : I64.numBits = 64 := by simp_scalar_consts +@[scalar_tac_simp] theorem
I128.numBits_eq : I128.numBits = 128 := by simp_scalar_consts +@[scalar_tac_simp] theorem Isize.numBits_eq : Isize.numBits = System.Platform.numBits := by simp_scalar_consts + +@[scalar_tac_simp] theorem U8.max_eq : U8.max = 255 := by simp_scalar_consts +@[scalar_tac_simp] theorem U16.max_eq : U16.max = 65535 := by simp_scalar_consts +@[scalar_tac_simp] theorem U32.max_eq : U32.max = 4294967295 := by simp_scalar_consts +@[scalar_tac_simp] theorem U64.max_eq : U64.max = 18446744073709551615 := by simp_scalar_consts +@[scalar_tac_simp] theorem U128.max_eq : U128.max = 340282366920938463463374607431768211455 := by simp_scalar_consts + +@[scalar_tac_simp] theorem UScalar.max_U8_eq : UScalar.max .U8 = 255 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.max_U16_eq : UScalar.max .U16 = 65535 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.max_U32_eq : UScalar.max .U32 = 4294967295 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.max_U64_eq : UScalar.max .U64 = 18446744073709551615 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.max_U128_eq : UScalar.max .U128 = 340282366920938463463374607431768211455 := by simp_scalar_consts + +@[scalar_tac_simp] theorem I8.min_eq : I8.min = -128 := by simp_scalar_consts +@[scalar_tac_simp] theorem I8.max_eq : I8.max = 127 := by simp_scalar_consts +@[scalar_tac_simp] theorem I16.min_eq : I16.min = -32768 := by simp_scalar_consts +@[scalar_tac_simp] theorem I16.max_eq : I16.max = 32767 := by simp_scalar_consts +@[scalar_tac_simp] theorem I32.min_eq : I32.min = -2147483648 := by simp_scalar_consts +@[scalar_tac_simp] theorem I32.max_eq : I32.max = 2147483647 := by simp_scalar_consts +@[scalar_tac_simp] theorem I64.min_eq : I64.min = -9223372036854775808 := by simp_scalar_consts +@[scalar_tac_simp] theorem I64.max_eq : I64.max = 9223372036854775807 := by simp_scalar_consts +@[scalar_tac_simp] theorem I128.min_eq : I128.min = -170141183460469231731687303715884105728 := by simp_scalar_consts +@[scalar_tac_simp] theorem I128.max_eq : I128.max = 170141183460469231731687303715884105727 := by simp_scalar_consts + +@[scalar_tac_simp] theorem IScalar.min_I8_eq : IScalar.min .I8 = -128 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.max_I8_eq : IScalar.max .I8 = 127 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.min_I16_eq : IScalar.min .I16 = -32768 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.max_I16_eq : IScalar.max .I16 = 32767 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.min_I32_eq : IScalar.min .I32 = -2147483648 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.max_I32_eq : IScalar.max .I32 = 2147483647 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.min_I64_eq : IScalar.min .I64 = -9223372036854775808 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.max_I64_eq : IScalar.max .I64 = 9223372036854775807 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.min_I128_eq : IScalar.min .I128 = -170141183460469231731687303715884105728 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.max_I128_eq : IScalar.max .I128 = 170141183460469231731687303715884105727 := by simp_scalar_consts + +@[scalar_tac_simp] theorem U8.size_eq : U8.size = 256 := by simp_scalar_consts +@[scalar_tac_simp] theorem U16.size_eq : U16.size = 65536 := by simp_scalar_consts +@[scalar_tac_simp] theorem U32.size_eq : U32.size = 4294967296 := by simp_scalar_consts +@[scalar_tac_simp] theorem U64.size_eq : U64.size = 18446744073709551616 := by 
simp_scalar_consts +@[scalar_tac_simp] theorem U128.size_eq : U128.size = 340282366920938463463374607431768211456 := by simp_scalar_consts + +@[scalar_tac_simp] theorem I8.size_eq : I8.size = 256 := by simp_scalar_consts +@[scalar_tac_simp] theorem I16.size_eq : I16.size = 65536 := by simp_scalar_consts +@[scalar_tac_simp] theorem I32.size_eq : I32.size = 4294967296 := by simp_scalar_consts +@[scalar_tac_simp] theorem I64.size_eq : I64.size = 18446744073709551616 := by simp_scalar_consts +@[scalar_tac_simp] theorem I128.size_eq : I128.size = 340282366920938463463374607431768211456 := by simp_scalar_consts + +@[scalar_tac_simp] theorem UScalar.size_U8_eq : UScalar.size .U8 = 256 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.size_U16_eq : U16.size = 65536 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.size_U32_eq : UScalar.size .U32 = 4294967296 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.size_U64_eq : UScalar.size .U64 = 18446744073709551616 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.size_U128_eq : UScalar.size .U128 = 340282366920938463463374607431768211456 := by simp_scalar_consts + +@[scalar_tac_simp] theorem IScalar.size_I8_eq : IScalar.size .I8 = 256 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.size_I16_eq : IScalar.size .I16 = 65536 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.size_I32_eq : IScalar.size .I32 = 4294967296 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.size_I64_eq : IScalar.size .I64 = 18446744073709551616 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.size_I128_eq : IScalar.size .I128 = 340282366920938463463374607431768211456 := by simp_scalar_consts + +@[scalar_tac_simp] theorem UScalar.cMax_U8_eq : UScalar.cMax .U8 = 255 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.cMax_U16_eq : UScalar.cMax .U16 = 65535 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.cMax_U32_eq : UScalar.cMax .U32 = 4294967295 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.cMax_U64_eq : UScalar.cMax .U64 = 18446744073709551615 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.cMax_U128_eq : UScalar.cMax .U128 = 340282366920938463463374607431768211455 := by simp_scalar_consts +@[scalar_tac_simp] theorem UScalar.cMax_Usize_eq : UScalar.cMax .Usize = 4294967295 := by simp_scalar_consts + +@[scalar_tac_simp] theorem IScalar.cMin_I8_eq : IScalar.cMin .I8 = -128 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMax_I8_eq : IScalar.cMax .I8 = 127 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMin_I16_eq : IScalar.cMin .I16 = -32768 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMax_I16_eq : IScalar.cMax .I16 = 32767 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMin_I32_eq : IScalar.cMin .I32 = -2147483648 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMax_I32_eq : IScalar.cMax .I32 = 2147483647 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMin_I64_eq : IScalar.cMin .I64 = -9223372036854775808 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMax_I64_eq : IScalar.cMax .I64 = 9223372036854775807 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMin_I128_eq : IScalar.cMin .I128 = -170141183460469231731687303715884105728 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMax_I128_eq : IScalar.cMax .I128 = 170141183460469231731687303715884105727 := by simp_scalar_consts +@[scalar_tac_simp] theorem 
IScalar.cMin_Isize_eq : IScalar.cMin .Isize = -2147483648 := by simp_scalar_consts +@[scalar_tac_simp] theorem IScalar.cMax_Isize_eq : IScalar.cMax .Isize = 2147483647 := by simp_scalar_consts + + +@[scalar_tac_simp] +theorem UScalarTy.USize.numBits_eq : UScalarTy.Usize.numBits = System.Platform.numBits := by simp_scalar_consts + +@[scalar_tac_simp] +theorem IScalarTy.ISize.numBits_eq : IScalarTy.Isize.numBits = System.Platform.numBits := by simp_scalar_consts + +attribute [scalar_tac_simp] Bool.toNat_false Bool.toNat_true + +end Std + +namespace ScalarTac + +open Std + +@[scalar_tac x] +theorem UScalar.bounds {ty : UScalarTy} (x : UScalar ty) : + x.val ≤ UScalar.max ty := by + simp [UScalar.max] + have := x.hBounds + omega + +@[scalar_tac x] +theorem IScalar.bounds {ty : IScalarTy} (x : IScalar ty) : + IScalar.min ty ≤ x.val ∧ x.val ≤ IScalar.max ty := by + simp [IScalar.max, IScalar.min] + have := x.hBounds + omega + +/-! +# Tests and Additional Simp Theorems +-/ + +example (x _y : U32) : x.val ≤ UScalar.max .U32 := by + scalar_tac_preprocess + +example (x _y : U32) : x.val ≤ UScalar.max .U32 := by + scalar_tac + +-- Checking that we explore the goal *and* projectors correctly +example (x : U32 × U32) : 0 ≤ x.fst.val := by + scalar_tac + +-- Checking that we properly handle [ofInt] +example : (U32.ofNat 1).val ≤ U32.max := by + scalar_tac + +example (x : Nat) (h1 : x ≤ U32.max) : + (U32.ofNat x (by scalar_tac)).val ≤ U32.max := by + scalar_tac + +-- Not equal +example (x : U32) (h0 : ¬ x = U32.ofNat 0) : 0 < x.val := by + scalar_tac + +/- See this: https://aeneas-verif.zulipchat.com/#narrow/stream/349819-general/topic/U64.20trouble/near/444049757 + + We solved it by removing the instance `OfNat` for `Scalar`. + Note however that we could also solve it with a simplification lemma. + However, after testing, we noticed we could only apply such a lemma with + the rewriting tactic (not the simplifier), probably because of the use + of typeclasses. -/ +example {u: U64} (h1: (u : Nat) < 2): (u : Nat) = 0 ∨ (u : Nat) = 1 := by + scalar_tac + +example (x : I32) : -100000000000 < x.val := by + scalar_tac + +example : (Usize.ofNat 2).val ≠ 0 := by + scalar_tac + +example (x : U32) : x.val ≤ Usize.max := by scalar_tac +example (x : I32) : x.val ≤ Isize.max := by scalar_tac +example (x : I32) : Isize.min ≤ x.val := by scalar_tac + +example (x y : Nat) (z : Int) (h : Int.subNatNat x y + z = 0) : (x : Int) - (y : Int) + z = 0 := by + scalar_tac_preprocess + omega + +example (x : U32) (h : 16 * x.val ≤ U32.max) : + 4 * (U32.ofNat (4 * x.val) (by scalar_tac)).val ≤ U32.max := by + scalar_tac + +example (b : Bool) (x y : Int) (h : if b then P ∧ x + y < 3 else x + y < 4) : x + y < 5 := by + scalar_tac +split + +open Utils + +/- Some tests which introduce big constants (those sometimes cause issues when reducing expressions). 
+ + See for instance: https://github.com/leanprover/lean4/issues/6955 + -/ +example (x y : Nat) (h : x = y + 2^32) : 0 ≤ x := by + scalar_tac + +example (x y : Nat) (h : x = y - 2^32) : 0 ≤ x := by + scalar_tac + +example + (xi yi : U32) + (c0 : U8) + (hCarryLe : c0.val ≤ 1) + (c0u : U32) + (_ : c0u.val = c0.val) + (s1 : U32) + (c1 : Bool) + (hConv1 : if xi.val + c0u.val > U32.max then s1.val = ↑xi + ↑c0u - U32.max - 1 ∧ c1 = true else s1 = xi.val + c0u ∧ c1 = false) + (s2 : U32) + (c2 : Bool) + (hConv2 : if s1.val + yi.val > U32.max then s2.val = ↑s1 + ↑yi - U32.max - 1 ∧ c2 = true else s2 = s1.val + yi ∧ c2 = false) + (c1u : U8) + (_ : c1u.val = if c1 = true then 1 else 0) + (c2u : U8) + (_ : c2u.val = if c2 = true then 1 else 0) + (c3 : U8) + (_ : c3.val = c1u.val + c2u.val): + c3.val ≤ 1 := by + scalar_tac +split + +example (x y : Nat) (h : x = y - 2^32) : 0 ≤ x := by + scalar_tac + +example (v : { l : List α // l.length ≤ Usize.max }) : + v.val.length < 2 ^ UScalarTy.Usize.numBits := by + scalar_tac + +example (i : I8) : - 2^(Isize.numBits - 1) ≤ i.val ∧ i.val ≤ 2^(Isize.numBits - 1) := by scalar_tac +example (x : I8) : -2 ^ (System.Platform.numBits - 1) ≤ x.val := by scalar_tac + +example + (α : Type u) + (v : { l : List α // l.length ≤ Usize.max }) + (nlen : ℕ) + (h : nlen ≤ U32.max ∨ nlen ≤ 2 ^ Usize.numBits - 1) : + nlen ≤ 2 ^ Usize.numBits - 1 + := by + scalar_tac + +example + (α : Type u) + (v : { l : List α // l.length ≤ Usize.max }) + (nlen : ℕ) + (h : (decide (nlen ≤ U32.max) || decide (nlen ≤ Usize.max)) = true) : + nlen ≤ Usize.max + := by + scalar_tac + +example (x : I8) : x.toNat = x.val.toNat := by scalar_tac + +/- `assumption` triggers a "max recursion depth" error if `U32.max` is reducible -/ +example (x y : U32) + (h : 2 * x.val + 1 + y.val ≤ (U32.max : Int)) : + 2 * x.val + 1 ≤ (U32.max : Int) := by + try assumption + scalar_tac + +/-! +## Min, Max +-/ + +@[scalar_tac_simp] theorem Nat.max_eq_Max_max (x y : Nat) : Nat.max x y = x ⊔ y := by simp +@[scalar_tac_simp] theorem Nat.min_eq_Min_min (x y : Nat) : Nat.min x y = x ⊓ y := by simp + +example (x y : Nat) : x ≤ x ⊔ y := by scalar_tac +example (x y : Nat) : x ≤ Nat.max x y := by scalar_tac +example (x y : Nat) : x ⊓ y ≤ x := by scalar_tac +example (x y : Nat) : Nat.min x y ≤ x := by scalar_tac + +example (x y : Int) : x ≤ x ⊔ y := by scalar_tac +example (x y : Int) : x ≤ max x y := by scalar_tac +example (x y : Int) : x ⊓ y ≤ x := by scalar_tac +example (x y : Int) : min x y ≤ x := by scalar_tac + +/-! +## Abs +-/ + +@[scalar_tac_simp] +theorem Int.natAbs_eq_abs (x : Int) : |x| = ↑x.natAbs := by simp + +example (x y z : Int) (h0 : x.natAbs ≤ y.natAbs) (h1 : y.natAbs ≤ z.natAbs) : x ≤ z.natAbs := by + scalar_tac +example (x y : Int) (h : |x| ≤ |y|) : x ≤ |y| := by scalar_tac +example (x y : Int) (h : |x| ≤ |y|) : x ≤ |y| := by scalar_tac + +/-! 
+## Fast Saturate +-/ +example : + 128 ≤ Usize.max ∧ 128 ≥ 5 := by + scalar_tac +fastSaturate + +end ScalarTac + +end Aeneas diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index bdf1b67c..955f3671 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -1,151 +1,316 @@ -import Aeneas.ScalarTac.IntTac -import Aeneas.Std.ScalarCore +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.Ring.RingNF +import Aeneas.Utils +import Aeneas.ScalarTac.Core +import Aeneas.ScalarTac.Init +import Aeneas.Saturate namespace Aeneas -/- Automation for scalars - TODO: not sure it is worth having two files (Int.lean and Scalar.lean) -/ namespace ScalarTac -open Lean Lean.Elab Lean.Meta -open Std +open Utils +open Lean Lean.Elab Lean.Meta Lean.Elab.Tactic -def scalarTacSimpLemmas := - [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val, - ``Scalar.lt_equiv, ``Scalar.le_equiv, ``Scalar.eq_equiv] +/- +-- DEPRECATED: `scalar_tac` used to rely on `aesop`. As there are performance issues +-- with the saturation tactic for now we use our own tactic. We will revert once the performance +-- is improved. -def scalarTacExtraPrePreprocess : Tactic.TacticM Unit := +/- Defining a custom attribute for Aesop - we use the Aesop tactic in the arithmetic tactics -/ +attribute [aesop (rule_sets := [Aeneas.ScalarTac]) unfold norm] Function.comp + +/-- The `scalar_tac` attribute used to tag forward theorems for the `scalar_tac` tactic. -/ +macro "scalar_tac" pat:term : attr => + `(attr|aesop safe forward (rule_sets := [$(Lean.mkIdent `Aeneas.ScalarTac):ident]) (pattern := $pat)) + +/-- The `nonlin_scalar_tac` attribute used to tag forward theorems for the `scalar_tac` tactics. -/ +macro "nonlin_scalar_tac" pat:term : attr => + `(attr|aesop safe forward (rule_sets := [$(Lean.mkIdent `Aeneas.ScalarTacNonLin):ident]) (pattern := $pat)) +-/ + +/-- The `scalar_tac` attribute used to tag forward theorems for the `scalar_tac` tactics. -/ +macro "scalar_tac" pat:term : attr => + `(attr|aeneas_saturate (set := $(Lean.mkIdent `Aeneas.ScalarTac)) (pattern := $pat)) + +/-- The `nonlin_scalar_tac` attribute used to tag forward theorems for the `scalar_tac` tactics. -/ +macro "nonlin_scalar_tac" pat:term : attr => + `(attr|aeneas_saturate (set := $(Lean.mkIdent `Aeneas.ScalarTacNonLin)) (pattern := $pat)) + +-- This is useful especially in the termination proofs +attribute [scalar_tac a.toNat] Int.toNat_eq_max + +/- Check if a proposition is a linear integer proposition. + We notably use this to check the goals: this is useful to filter goals that + are unlikely to be solvable with arithmetic tactics. 
-/ +class IsLinearIntProp (x : Prop) where + +instance (x y : Int) : IsLinearIntProp (x < y) where +instance (x y : Int) : IsLinearIntProp (x > y) where +instance (x y : Int) : IsLinearIntProp (x ≤ y) where +instance (x y : Int) : IsLinearIntProp (x ≥ y) where +instance (x y : Int) : IsLinearIntProp (x ≥ y) where +instance (x y : Int) : IsLinearIntProp (x = y) where + +instance (x y : Nat) : IsLinearIntProp (x < y) where +instance (x y : Nat) : IsLinearIntProp (x > y) where +instance (x y : Nat) : IsLinearIntProp (x ≤ y) where +instance (x y : Nat) : IsLinearIntProp (x ≥ y) where +instance (x y : Nat) : IsLinearIntProp (x ≥ y) where +instance (x y : Nat) : IsLinearIntProp (x = y) where + +instance : IsLinearIntProp False where +instance (p : Prop) [IsLinearIntProp p] : IsLinearIntProp (¬ p) where +instance (p q : Prop) [IsLinearIntProp p] [IsLinearIntProp q] : IsLinearIntProp (p ∨ q) where +instance (p q : Prop) [IsLinearIntProp p] [IsLinearIntProp q] : IsLinearIntProp (p ∧ q) where +-- We use the one below for goals +instance (p q : Prop) [IsLinearIntProp p] [IsLinearIntProp q] : IsLinearIntProp (p → q) where + +-- Check if the goal is a linear arithmetic goal +def goalIsLinearInt : Tactic.TacticM Bool := do Tactic.withMainContext do - -- First get rid of [ofInt] (if there are dependent arguments, we may not - -- manage to simplify the context) - Utils.simpAt true {dsimp := false, failIfUnchanged := false} - -- Simprocs - intTacSimpRocs - -- Unfoldings - [] + let gty ← Tactic.getMainTarget + match ← trySynthInstance (← mkAppM ``IsLinearIntProp #[gty]) with + | .some _ => pure true + | _ => pure false + +example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by + omega + +def scalarTacSimpRocs : List Name := [ + ``reduceIte, + ``Nat.reduceLeDiff, + ``Nat.reduceLT, ``Nat.reduceGT, ``Nat.reduceBEq, ``Nat.reduceBNe, + ``Nat.reducePow, ``Nat.reduceAdd, ``Nat.reduceSub, ``Nat.reduceMul, ``Nat.reduceDiv, ``Nat.reduceMod, + ``Int.reduceLT, ``Int.reduceLE, ``Int.reduceGT, ``Int.reduceGE, ``Int.reduceEq, ``Int.reduceNe, ``Int.reduceBEq, ``Int.reduceBNe, + ``Int.reducePow, ``Int.reduceAdd, ``Int.reduceSub, ``Int.reduceMul, ``Int.reduceDiv, ``Int.reduceMod, + ``Int.reduceNegSucc, ``Int.reduceNeg,] + +/- Small trick to prevent `simp_all` from simplifying an assumption `h1 : P v` when we have + `h0 : ∀ x, P x` in the context: we replace the forall quantifiers with our own definition + of `forall`. -/ +def forall' {α : Type u} (p : α → Prop) : Prop := ∀ (x: α), p x + +theorem forall_eq_forall' {α : Type u} (p : α → Prop) : (∀ (x: α), p x) = forall' p := by + simp [forall'] + +@[app_unexpander forall'] +def unexpForall' : Lean.PrettyPrinter.Unexpander | `($_ $_) => `(∀ _, __) | _ => throw () + +structure Config where + /- Should we use non-linear arithmetic reasoning? -/ + nonLin : Bool := false + /- If `true`, split all the matches/if then else in the context (note that `omega` + splits the *disjunctions*) -/ + split: Bool := false + /- Maximum number of steps to take with `simpAll` during the preprocessing phase. + If equal to 0, we do not call `simpAll` at all. + -/ + simpAllMaxSteps : Nat := 2000 + fastSaturate : Bool := false + +declare_config_elab elabConfig Config + +/-- Apply the scalar_tac forward rules -/ +def scalarTacSaturateForward (fast : Bool) (nonLin : Bool): Tactic.TacticM Unit := do + /- + let options : Aesop.Options := {} + -- Use a forward max depth of 0 to prevent recursively applying forward rules on the assumptions + -- introduced by the forward rules themselves. 
+ let options ← options.toOptions' (some 0)-/ + -- We always use the rule set `Aeneas.ScalarTac`, but also need to add other rule sets locally + -- activated by the user. The `Aeneas.ScalarTacNonLin` rule set has a special treatment as + -- it is activated through an option. + let ruleSets := + let ruleSets := `Aeneas.ScalarTac :: (← scalarTacRuleSets.get) + if nonLin then `Aeneas.ScalarTacNonLin :: ruleSets + else ruleSets + -- TODO + -- evalAesopSaturate options ruleSets.toArray + Saturate.evalSaturate fast ruleSets + +-- For debugging +elab "scalar_tac_saturate" config:Parser.Tactic.optConfig : tactic => do + let config ← elabConfig config + scalarTacSaturateForward config.fastSaturate config.nonLin + +/- Propositional logic simp lemmas -/ +attribute [scalar_tac_simp] + and_self false_implies true_implies Prod.mk.injEq + not_false_eq_true not_true_eq_false + true_and and_true false_and and_false + true_or or_true false_or or_false + Bool.true_eq_false Bool.false_eq_true + decide_eq_true_eq Bool.or_eq_true Bool.and_eq_true + +/- Boosting a bit the `omega` tac. + + - `extraPrePreprocess`: extra-preprocessing to be done *before* this preprocessing + - `extraPreprocess`: extra-preprocessing to be done *after* this preprocessing + -/ +def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do + Tactic.withMainContext do + let simpLemmas ← scalarTacSimpExt.getTheorems + -- Pre-preprocessing + /- First get rid of [ofInt] (if there are dependent arguments, we may not + manage to simplify the context) -/ + trace[ScalarTac] "Original goal before preprocessing: {← getMainGoal}" + Utils.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 0} + -- Simprocs + scalarTacSimpRocs + -- Simp theorems + [simpLemmas] + -- Unfoldings + [] -- Simp lemmas - scalarTacSimpLemmas + [] -- Hypotheses [] .wildcard - -def scalarTacExtraPreprocess : Tactic.TacticM Unit := do - Tactic.withMainContext do - -- Inroduce the bounds for the isize/usize types - let add (e : Expr) : Tactic.TacticM Unit := do - let ty ← inferType e - let _ ← Utils.addDeclTac (← Utils.mkFreshAnonPropUserName) e ty (asLet := false) - add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) - add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) - add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) - Utils.simpAt true {failIfUnchanged := false} + trace[ScalarTac] "Goal after first simplification: {← getMainGoal}" + -- Apply the forward rules + allGoalsNoRecover (scalarTacSaturateForward config.fastSaturate config.nonLin) + trace[ScalarTac] "Goal after saturation: {← getMainGoal}" + -- Apply `simpAll` + if config.simpAllMaxSteps ≠ 0 then + allGoalsNoRecover + (Utils.simpAll {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} true + scalarTacSimpRocs [simpLemmas] [] [] []) + trace[ScalarTac] "Goal after simpAll: {← getMainGoal}" + -- Reduce all the terms in the goal - note that the extra preprocessing step + -- might have proven the goal, hence the `allGoals` + let dsimp := + allGoalsNoRecover do tryTac ( + -- We set `simpOnly` at false on purpose. 
+ -- Also, we need `zetaDelta` to inline the let-bindings (otherwise, omega doesn't always manages + -- to deal with them) + dsimpAt false {zetaDelta := true} scalarTacSimpRocs + -- Simp theorems + [] + -- Declarations to unfold + [] + -- Theorems + [] + [] Tactic.Location.wildcard) + dsimp + trace[ScalarTac] "Goal after first dsimp: {← getMainGoal}" + -- More preprocessing: apply norm_cast to the whole context + allGoalsNoRecover (Utils.tryTac (Utils.normCastAtAll)) + trace[ScalarTac] "Goal after first normCast: {← getMainGoal}" + -- norm_cast does weird things with negative numbers so we reapply simp + dsimp + trace[ScalarTac] "Goal after 2nd dsimp: {← getMainGoal}" + allGoalsNoRecover do Utils.tryTac ( + Utils.simpAt true {} -- Simprocs - intTacSimpRocs + [] + -- Simp theorems + [simpLemmas] -- Unfoldings - [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, - ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, - ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, - ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, - ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max, - ``Usize.min, - ``Scalar.in_bounds, - ``Scalar.toNat, ``Isize.toNat, ``USize.toNat, - ``I8.toNat, ``I16.toNat, ``I32.toNat, ``I64.toNat, ``I128.toNat, - ``U8.toNat, ``U16.toNat, ``U32.toNat, ``U64.toNat, ``U128.toNat, - ``U8.unsigned_ofNat_toNat, ``U16.unsigned_ofNat_toNat, - ``U32.unsigned_ofNat_toNat, ``U64.unsigned_ofNat_toNat, - ``U128.unsigned_ofNat_toNat, ``Usize.unsigned_ofNat_toNat, - ] + [] -- Simp lemmas - scalarTacSimpLemmas + [-- Int.subNatNat is very annoying - TODO: there is probably something more general thing to do + ``Int.subNatNat_eq_coe, + -- We also need this, in case the goal is: ¬ False + ``not_false_eq_true, + -- Remove the forall quantifiers to prepare for the call of `simp_all` (we + -- don't want `simp_all` to use assumptions of the shape `∀ x, P x`)) + ``forall_eq_forall' + ] -- Hypotheses - [] .wildcard - trace[ScalarTac] "scalarTacExtraPreprocess: after simp: {(← Tactic.getMainGoal)}" - -elab "scalar_tac_preprocess" : tactic => - intTacPreprocess scalarTacExtraPrePreprocess scalarTacExtraPreprocess + [] .wildcard) + trace[ScalarTac] "Goal after simpAt following dsimp: {← getMainGoal}" --- A tactic to solve linear arithmetic goals in the presence of scalars -def scalarTac (splitAllDisjs splitGoalConjs : Bool) : Tactic.TacticM Unit := do - intTac "scalar_tac" splitAllDisjs splitGoalConjs scalarTacExtraPrePreprocess scalarTacExtraPreprocess +elab "scalar_tac_preprocess" config:Parser.Tactic.optConfig : tactic => do + let config ← elabConfig config + scalarTacPreprocess config -elab "scalar_tac" : tactic => - scalarTac true false - -@[scalar_tac x] -theorem Scalar.bounds {ty : ScalarTy} (x : Scalar ty) : - Scalar.min ty ≤ x.val ∧ x.val ≤ Scalar.max ty := - And.intro x.hmin x.hmax +/-- - `splitAllDisjs`: if true, also split all the matches/if then else in the context (note that + `omega` splits the *disjunctions*) + - `splitGoalConjs`: if true, split the goal if it is a conjunction so as to introduce one + subgoal per conjunction. +-/ +def scalarTac (config : Config) : Tactic.TacticM Unit := do + Tactic.withMainContext do + Tactic.focus do + let simpLemmas ← scalarTacSimpExt.getTheorems + let g ← Tactic.getMainGoal + trace[ScalarTac] "Original goal: {g}" + -- Introduce all the universally quantified variables (includes the assumptions) + let (_, g) ← g.intros + Tactic.setGoals [g] + -- Preprocess - wondering if we should do this before or after splitting + -- the goal. 
I think before leads to a smaller proof term? + allGoalsNoRecover (scalarTacPreprocess config) + allGoalsNoRecover do + try do + if config.split then do + trace[ScalarTac] "Splitting the goal" + /- If we split the `if then else` call `simp_all` again -/ + splitAll do + allGoalsNoRecover + (Utils.simpAll {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} true + scalarTacSimpRocs [simpLemmas] [] [] []) + trace[ScalarTac] "Calling omega" + allGoalsNoRecover (Tactic.Omega.omegaTactic {}) + else + trace[ScalarTac] "Calling omega" + Tactic.Omega.omegaTactic {} + catch _ => + let g ← Tactic.getMainGoal + throwError "scalar_tac failed to prove the goal below.\n\nNote that scalar_tac is almost equivalent to:\n scalar_tac_preprocess; split_all <;> simp_all only <;> omega\n\nGoal: \n{g}" -example (x _y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by - scalar_tac_preprocess - simp [*] +example : True := by simp -example (x _y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by - scalar_tac - --- Checking that we explore the goal *and* projectors correctly -example (x : U32 × U32) : 0 ≤ x.fst.val := by - scalar_tac +elab "scalar_tac" config:Parser.Tactic.optConfig : tactic => do + let config ← elabConfig config + scalarTac config --- Checking that we properly handle [ofInt] -example : (U32.ofInt 1).val ≤ U32.max := by - scalar_tac +-- For termination proofs +syntax "int_decr_tac" : tactic +macro_rules + | `(tactic| int_decr_tac) => + `(tactic| + simp_wf; + -- TODO: don't use a macro (namespace problems) + (first | apply ScalarTac.to_int_to_nat_lt + | apply ScalarTac.to_int_sub_to_nat_lt) <;> + simp_all <;> scalar_tac) -example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) : - (U32.ofIntCore x (by constructor <;> scalar_tac)).val ≤ U32.max := by - scalar_tac_preprocess +-- Checking that things happen correctly when there are several conjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by scalar_tac --- Not equal -example (x : U32) (h0 : ¬ x = U32.ofInt 0) : 0 < x.val := by +-- Checking that things happen correctly when there are several conjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by scalar_tac -/- See this: https://aeneas-verif.zulipchat.com/#narrow/stream/349819-general/topic/U64.20trouble/near/444049757 - - We solved it by removing the instance `OfNat` for `Scalar`. - Note however that we could also solve it with a simplification lemma. - However, after testing, we noticed we could only apply such a lemma with - the rewriting tactic (not the simplifier), probably because of the use - of typeclasses. 
-/ -example {u: U64} (h1: (u : Int) < 2): (u : Int) = 0 ∨ (u : Int) = 1 := by +-- Checking that we can prove exfalso +example (a : Prop) (x : Int) (h0: 0 < x) (h1: x < 0) : a := by scalar_tac -example (x : I32) : -100000000000 < x.val := by +-- Intermediate cast through natural numbers +example (a : Prop) (x : Int) (h0: (0 : Nat) < x) (h1: x < 0) : a := by scalar_tac -example : (Usize.ofInt 2).val ≠ 0 := by +example (x : Int) (h : x ≤ -3) : x ≤ -2 := by scalar_tac -example (x y : Nat) (z : Int) (h : Int.subNatNat x y + z = 0) : (x : Int) - (y : Int) + z = 0 := by - scalar_tac_preprocess +example (x y : Int) (h : x + y = 3) : + let z := x + y + z = 3 := by + intro z omega -example (x : U32) (h : 16 * x.val ≤ U32.max) : - 4 * (U32.ofInt (4 * x.val) (by scalar_tac)).val ≤ U32.max := by +example (P : Nat → Prop) (z : Nat) (h : ∀ x, P x → x ≤ z) (y : Nat) (hy : P y) : + y + 2 ≤ z + 2 := by + have := h y hy scalar_tac -example (b : Bool) (x y : Int) (h : if b then P ∧ x + y < 3 else x + y < 4) : x + y < 5 := by - scalar_tac - -example - (xi yi : U32) - (c0 : U8) - (hCarryLe : c0.val ≤ 1) - (c0u : U32) - (_ : c0u.val = c0.val) - (s1 : U32) - (c1 : Bool) - (hConv1 : if xi.val + c0u.val > U32.max then s1.val = ↑xi + ↑c0u - U32.max - 1 ∧ c1 = true else s1 = xi.val + c0u ∧ c1 = false) - (s2 : U32) - (c2 : Bool) - (hConv2 : if s1.val + yi.val > U32.max then s2.val = ↑s1 + ↑yi - U32.max - 1 ∧ c2 = true else s2 = s1.val + yi ∧ c2 = false) - (c1u : U8) - (_ : c1u.val = if c1 = true then 1 else 0) - (c2u : U8) - (_ : c2u.val = if c2 = true then 1 else 0) - (c3 : U8) - (_ : c3.val = c1u.val + c2u.val): - c3.val ≤ 1 := by - scalar_tac +-- Checking that we manage to split the cases/if then else +example (x : Int) (b : Bool) (h : if b then x ≤ 0 else x ≤ 0) : x ≤ 0 := by + scalar_tac +split end ScalarTac diff --git a/backends/lean/Aeneas/Std/Alloc.lean b/backends/lean/Aeneas/Std/Alloc.lean index e995d030..d643b3f7 100644 --- a/backends/lean/Aeneas/Std/Alloc.lean +++ b/backends/lean/Aeneas/Std/Alloc.lean @@ -1,6 +1,5 @@ import Lean import Aeneas.Std.Core -import Aeneas.Std.Core namespace Aeneas diff --git a/backends/lean/Aeneas/Std/ArraySlice.lean b/backends/lean/Aeneas/Std/ArraySlice.lean index 2fda553a..6aee754d 100644 --- a/backends/lean/Aeneas/Std/ArraySlice.lean +++ b/backends/lean/Aeneas/Std/ArraySlice.lean @@ -16,6 +16,19 @@ namespace Std open Result Error core.ops.range +/-! +# Notations for `List` +-/ +instance {α : Type u} : GetElem (List α) Usize α (fun l i => i.val < l.length) where + getElem l i h := getElem l i.val h + +instance {α : Type u} : GetElem? (List α) Usize α (fun l i => i < l.length) where + getElem? l i := getElem? l i.val + +/-! +# Array +-/ + def Array (α : Type u) (n : Usize) := { l : List α // l.length = n.val } /-- We need this to coerce arrays to lists without marking `Array` as reducible. 
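-- Illustrative sketch (an editor's aside, not part of the patch), referring to the
-- `GetElem`/`GetElem?` instances introduced above under `# Notations for List`:
-- indexing a plain `List` by a `Usize` is, by definition, indexing by its underlying `Nat`.
example {α : Type u} (l : List α) (i : Usize) : l[i]? = l[i.val]? := rfl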
@@ -43,44 +56,61 @@ abbrev Array.length {α : Type u} {n : Usize} (v : Array α n) : Nat := v.val.le @[simp] abbrev Array.v {α : Type u} {n : Usize} (v : Array α n) : List α := v.val -example {α: Type u} {n : Usize} (v : Array α n) : v.length ≤ Scalar.max ScalarTy.Usize := by +example {α: Type u} {n : Usize} (v : Array α n) : v.length ≤ Usize.max := by scalar_tac -def Array.make {α : Type u} (n : Usize) (init : List α) (hl : init.length = n.val := by rfl) : +def Array.make {α : Type u} (n : Usize) (init : List α) (hl : init.length = n.val := by simp) : Array α n := ⟨ init, by apply hl ⟩ -example : Array Int (Usize.ofInt 2) := Array.make (Usize.ofInt 2) [0, 1] +example : Array Int (Usize.ofNat 2) := Array.make (Usize.ofNat 2) [0, 1] (by simp) -example : Array Int (Usize.ofInt 2) := +example : Array Int (Usize.ofNat 2) := let x := 0 let y := 1 - Array.make (Usize.ofInt 2) [x, y] + Array.make (Usize.ofNat 2) [x, y] -example : Result (Array Int (Usize.ofInt 2)) := do +example : Result (Array Int (Usize.ofNat 2)) := do let x ← ok 0 let y ← ok 1 - ok (Array.make (Usize.ofInt 2) [x, y]) + ok (Array.make (Usize.ofNat 2) [x, y]) -@[simp] -abbrev Array.index_s {α : Type u} {n : Usize} [Inhabited α] (v : Array α n) (i : Nat) : α := - v.val.index i +@[reducible] instance {α : Type u} {n : Usize} : GetElem (Array α n) Nat α (fun a i => i < a.val.length) where + getElem a i h := getElem a.val i h + +@[reducible] instance {α : Type u} {n : Usize} : GetElem? (Array α n) Nat α (fun a i => i < a.val.length) where + getElem? a i := getElem? a.val i + +@[simp, scalar_tac_simp] theorem Array.getElem?_Nat_eq {α : Type u} {n : Usize} (v : Array α n) (i : Nat) : v[i]? = v.val[i]? := by rfl +@[simp, scalar_tac_simp] theorem Array.getElem!_Nat_eq {α : Type u} [Inhabited α] {n : Usize} (v : Array α n) (i : Nat) : v[i]! = v.val[i]! := by rfl + +@[reducible] instance {α : Type u} {n : Usize} : GetElem (Array α n) Usize α (fun a i => i.val < a.val.length) where + getElem a i h := getElem a.val i.val h + +@[reducible] instance {α : Type u} {n : Usize} : GetElem? (Array α n) Usize α (fun a i => i.val < a.val.length) where + getElem? a i := getElem? a.val i.val + +@[simp, scalar_tac_simp] theorem Array.getElem?_Usize_eq {α : Type u} {n : Usize} (v : Array α n) (i : Usize) : v[i]? = v.val[i.val]? := by rfl +@[simp, scalar_tac_simp] theorem Array.getElem!_Usize_eq {α : Type u} [Inhabited α] {n : Usize} (v : Array α n) (i : Usize) : v[i]! = v.val[i.val]! := by rfl + +@[simp, scalar_tac_simp] abbrev Array.get? {α : Type u} {n : Usize} (v : Array α n) (i : Nat) : Option α := getElem? v i +@[simp, scalar_tac_simp] abbrev Array.get! {α : Type u} {n : Usize} [Inhabited α] (v : Array α n) (i : Nat) : α := getElem! v i @[simp] abbrev Array.slice {α : Type u} {n : Usize} [Inhabited α] (v : Array α n) (i j : Nat) : List α := v.val.slice i j def Array.index_usize {α : Type u} {n : Usize} (v: Array α n) (i: Usize) : Result α := - match v.val.indexOpt i.toNat with + match v[i]? 
with | none => fail .arrayOutOfBounds | some x => ok x -- For initialization def Array.repeat {α : Type u} (n : Usize) (x : α) : Array α n := - ⟨ List.replicate n.toNat x, by have h := n.hmin; simp_all [Scalar.min] ⟩ + ⟨ List.replicate n.val x, by simp_all ⟩ -@[pspec] +@[progress] theorem Array.repeat_spec {α : Type u} (n : Usize) (x : α) : - ∃ a, Array.repeat n x = a ∧ a.val = List.replicate n.toNat x := by + ∃ a, Array.repeat n x = a ∧ a.val = List.replicate n.val x := by simp [Array.repeat] /- In the theorems below: we don't always need the `∃ ..`, but we use one @@ -88,60 +118,68 @@ theorem Array.repeat_spec {α : Type u} (n : Usize) (x : α) : helps control the context. -/ -@[pspec] +@[progress] theorem Array.index_usize_spec {α : Type u} {n : Usize} [Inhabited α] (v: Array α n) (i: Usize) - (hbound : i.toNat < v.length) : - ∃ x, v.index_usize i = ok x ∧ x = v.val.index i.toNat := by + (hbound : i.val < v.length) : + ∃ x, v.index_usize i = ok x ∧ x = v.val[i.val]! := by simp only [index_usize] - -- TODO: dependent rewrite - have h := List.indexOpt_eq_index v.val i.toNat (by scalar_tac) - simp [*] + simp at * + split <;> simp_all -def Array.update_usize {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : Result (Array α n) := - match v.val.indexOpt i.toNat with - | none => fail .arrayOutOfBounds - | some _ => - ok ⟨ v.val.update i.toNat x, by have := v.property; simp [*] ⟩ +def Array.set {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : Array α n := + ⟨ v.val.set i.val x, by have := v.property; simp [*] ⟩ -def Array.update {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : Array α n := - ⟨ v.val.update i.toNat x, by have := v.property; simp [*] ⟩ +def Array.set_opt {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: Option α) : Array α n := + ⟨ v.val.set_opt i.val x, by have := v.property; simp [*] ⟩ @[simp] -theorem Array.update_val_eq {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : - (v.update i x).val = v.val.update i.toNat x := by - simp [update] - -@[scalar_tac v.update i x] -theorem Array.update_length {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : - (v.update i x).length = v.length := by simp - -@[pspec] -theorem Array.update_usize_spec {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x : α) - (hbound : i.toNat < v.length) : - ∃ nv, v.update_usize i x = ok nv ∧ - nv = v.update i x +theorem Array.set_val_eq {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : + (v.set i x).val = v.val.set i.val x := by + simp [set] + +@[simp] +theorem Array.set_opt_val_eq {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: Option α) : + (v.set_opt i x).val = v.val.set_opt i.val x := by + simp [set_opt] + +@[scalar_tac_simp] +theorem Array.set_length {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : + (v.set i x).length = v.length := by simp + +def Array.update {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : Result (Array α n) := + match v[i]? with + | none => fail .arrayOutOfBounds + | some _ => + ok ⟨ v.val.set i.val x, by have := v.property; simp [*] ⟩ + +@[progress] +theorem Array.update_spec {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x : α) + (hbound : i.val < v.length) : + ∃ nv, v.update i x = ok nv ∧ + nv = v.set i x := by - simp only [update_usize] - have h := List.indexOpt_bounds v.val i.toNat - split - . simp_all [length] - scalar_tac - . 
simp [Array.update] + simp only [update, set] + simp at * + split <;> simp_all def Array.index_mut_usize {α : Type u} {n : Usize} (v: Array α n) (i: Usize) : Result (α × (α -> Array α n)) := do let x ← index_usize v i - ok (x, update v i) + ok (x, set v i) -@[pspec] +@[progress] theorem Array.index_mut_usize_spec {α : Type u} {n : Usize} [Inhabited α] (v: Array α n) (i: Usize) - (hbound : i.toNat < v.length) : - ∃ x, v.index_mut_usize i = ok (x, update v i) ∧ - x = v.val.index i.toNat := by + (hbound : i.val < v.length) : + ∃ x, v.index_mut_usize i = ok (x, set v i) ∧ + x = v.val.get! i.val := by simp only [index_mut_usize, Bind.bind, bind] have ⟨ x, h ⟩ := index_usize_spec v i hbound simp [h] +/-! +# Slice +-/ + def Slice (α : Type u) := { l : List α // l.length ≤ Usize.max } /-- We need this to coerce slices to lists without marking `Slice` as reducible. @@ -155,8 +193,8 @@ instance [BEq α] : BEq (Slice α) := SubtypeBEq _ instance [BEq α] [LawfulBEq α] : LawfulBEq (Slice α) := SubtypeLawfulBEq _ @[scalar_tac s] -theorem Slice.length_ineq {α : Type u} (s : Slice α) : s.val.length ≤ Scalar.max ScalarTy.Usize := by - cases s; simp[Scalar.max, *] +theorem Slice.length_ineq {α : Type u} (s : Slice α) : s.val.length ≤ Usize.max := by + cases s; simp[*] -- TODO: move/remove? @[scalar_tac s] @@ -168,29 +206,55 @@ abbrev Slice.length {α : Type u} (v : Slice α) : Nat := v.val.length @[simp] abbrev Slice.v {α : Type u} (v : Slice α) : List α := v.val -example {a: Type u} (v : Slice a) : v.length ≤ Scalar.max ScalarTy.Usize := by +example {a: Type u} (v : Slice a) : v.length ≤ Usize.max := by scalar_tac def Slice.new (α : Type u) : Slice α := ⟨ [], by simp ⟩ -- TODO: very annoying that the α is an explicit parameter abbrev Slice.len {α : Type u} (v : Slice α) : Usize := - Usize.ofIntCore v.val.length (by constructor <;> scalar_tac) + Usize.ofNatCore v.val.length (by scalar_tac) -@[simp] +@[simp, scalar_tac_simp] theorem Slice.len_val {α : Type u} (v : Slice α) : (Slice.len v).val = v.length := - by rfl + by simp -@[simp] -abbrev Slice.index_s {α : Type u} [Inhabited α] (v: Slice α) (i: Nat) : α := - v.val.index i +@[reducible] instance {α : Type u} : GetElem (Slice α) Nat α (fun a i => i < a.val.length) where + getElem a i h := getElem a.val i h + +@[reducible] instance {α : Type u} : GetElem? (Slice α) Nat α (fun a i => i < a.val.length) where + getElem? a i := getElem? a.val i + +@[simp, scalar_tac_simp] theorem Slice.getElem?_Nat_eq {α : Type u} (v : Slice α) (i : Nat) : v[i]? = v.val[i]? := by rfl +@[simp, scalar_tac_simp] theorem Slice.getElem!_Nat_eq {α : Type u} [Inhabited α] (v : Slice α) (i : Nat) : v[i]! = v.val[i]! := by rfl + +@[reducible] instance {α : Type u} : GetElem (Slice α) Usize α (fun a i => i.val < a.val.length) where + getElem a i h := getElem a.val i.val h + +@[reducible] instance {α : Type u} : GetElem? (Slice α) Usize α (fun a i => i < a.val.length) where + getElem? a i := getElem? a.val i.val + +@[simp, scalar_tac_simp] theorem Slice.getElem?_Usize_eq {α : Type u} (v : Slice α) (i : Usize) : v[i]? = v.val[i.val]? := by rfl +@[simp, scalar_tac_simp] theorem Slice.getElem!_Usize_eq {α : Type u} [Inhabited α] (v : Slice α) (i : Usize) : v[i]! = v.val[i.val]! := by rfl + +@[simp, scalar_tac_simp] abbrev Slice.get? {α : Type u} (v : Slice α) (i : Nat) : Option α := getElem? v i +@[simp, scalar_tac_simp] abbrev Slice.get! {α : Type u} [Inhabited α] (v : Slice α) (i : Nat) : α := getElem! 
v i + +def Slice.set {α : Type u} (v: Slice α) (i: Usize) (x: α) : Slice α := + ⟨ v.val.set i.val x, by have := v.property; simp [*] ⟩ + +def Slice.set_opt {α : Type u} (v: Slice α) (i: Usize) (x: Option α) : Slice α := + ⟨ v.val.set_opt i.val x, by have := v.property; simp [*] ⟩ + +def Slice.drop {α} (s : Slice α) (i : Usize) : Slice α := + ⟨ s.val.drop i.val, by scalar_tac ⟩ @[simp] abbrev Slice.slice {α : Type u} [Inhabited α] (s : Slice α) (i j : Nat) : List α := s.val.slice i j def Slice.index_usize {α : Type u} (v: Slice α) (i: Usize) : Result α := - match v.val.indexOpt i.toNat with + match v[i]? with | none => fail .arrayOutOfBounds | some x => ok x @@ -199,55 +263,54 @@ def Slice.index_usize {α : Type u} (v: Slice α) (i: Usize) : Result α := helps control the context. -/ -@[pspec] +@[progress] theorem Slice.index_usize_spec {α : Type u} [Inhabited α] (v: Slice α) (i: Usize) - (hbound : i.toNat < v.length) : - ∃ x, v.index_usize i = ok x ∧ x = v.val.index i.toNat := by + (hbound : i.val < v.length) : + ∃ x, v.index_usize i = ok x ∧ x = v.val[i.val]! := by simp only [index_usize] - -- TODO: dependent rewrite - have h := List.indexOpt_eq_index v.val i.toNat (by scalar_tac) + simp at * simp [*] -def Slice.update_usize {α : Type u} (v: Slice α) (i: Usize) (x: α) : Result (Slice α) := - match v.val.indexOpt i.toNat with +@[simp] +theorem Slice.set_val_eq {α : Type u} (v: Slice α) (i: Usize) (x: α) : + (v.set i x) = v.val.set i.val x := by + simp [set] + +@[simp] +theorem Slice.set_opt_val_eq {α : Type u} (v: Slice α) (i: Usize) (x: Option α) : + (v.set_opt i x) = v.val.set_opt i.val x := by + simp [set_opt] + +@[scalar_tac_simp] +theorem Slice.set_length {α : Type u} (v: Slice α) (i: Usize) (x: α) : + (v.set i x).length = v.length := by simp + +def Slice.update {α : Type u} (v: Slice α) (i: Usize) (x: α) : Result (Slice α) := + match v.val[i.val]? with | none => fail .arrayOutOfBounds | some _ => - ok ⟨ v.val.update i.toNat x, by have := v.property; simp [*] ⟩ + ok ⟨ v.val.set i.val x, by have := v.property; simp [*] ⟩ -def Slice.update {α : Type u} (v: Slice α) (i: Usize) (x: α) : Slice α := - ⟨ v.val.update i.toNat x, by have := v.property; simp [*] ⟩ - -@[simp] -theorem Slice.update_val_eq {α : Type u} (v: Slice α) (i: Usize) (x: α) : - (v.update i x) = v.val.update i.toNat x := by - simp [update] - -@[scalar_tac v.update i x] -theorem Slice.update_length {α : Type u} (v: Slice α) (i: Usize) (x: α) : - (v.update i x).length = v.length := by simp - -@[pspec] -theorem Slice.update_usize_spec {α : Type u} (v: Slice α) (i: Usize) (x : α) - (hbound : i.toNat < v.length) : - ∃ nv, v.update_usize i x = ok nv ∧ - nv = v.update i x +@[progress] +theorem Slice.update_spec {α : Type u} (v: Slice α) (i: Usize) (x : α) + (hbound : i.val < v.length) : + ∃ nv, v.update i x = ok nv ∧ + nv = v.set i x := by - simp only [update_usize] - have h := List.indexOpt_bounds v.val i.toNat - split - . simp_all [length]; scalar_tac - . simp [Slice.update] + simp only [update, set] + simp at * + simp [*] def Slice.index_mut_usize {α : Type u} (v: Slice α) (i: Usize) : Result (α × (α → Slice α)) := do let x ← Slice.index_usize v i - ok (x, Slice.update v i) + ok (x, Slice.set v i) -@[pspec] +@[progress] theorem Slice.index_mut_usize_spec {α : Type u} [Inhabited α] (v: Slice α) (i: Usize) - (hbound : i.toNat < v.length) : - ∃ x, v.index_mut_usize i = ok (x, Slice.update v i) ∧ - x = v.val.index i.toNat := by + (hbound : i.val < v.length) : + ∃ x, v.index_mut_usize i = ok (x, Slice.set v i) ∧ + x = v.val[i.val]! 
:= by simp only [index_mut_usize, Bind.bind, bind] have ⟨ x, h ⟩ := Slice.index_usize_spec v i hbound simp [h] @@ -261,7 +324,7 @@ theorem Slice.index_mut_usize_spec {α : Type u} [Inhabited α] (v: Slice α) (i def Array.to_slice {α : Type u} {n : Usize} (v : Array α n) : Result (Slice α) := ok ⟨ v.val, by scalar_tac ⟩ -@[pspec] +@[progress] theorem Array.to_slice_spec {α : Type u} {n : Usize} (v : Array α n) : ∃ s, to_slice v = ok s ∧ v.val = s.val := by simp [to_slice] @@ -280,7 +343,7 @@ def Array.to_slice_mut {α : Type u} {n : Usize} (a : Array α n) : let s ← Array.to_slice a ok (s, Array.from_slice a) -@[pspec] +@[progress] theorem Array.to_slice_mut_spec {α : Type u} {n : Usize} (v : Array α n) : ∃ s, to_slice_mut v = ok (s, Array.from_slice v) ∧ v.val = s.val @@ -289,40 +352,40 @@ theorem Array.to_slice_mut_spec {α : Type u} {n : Usize} (v : Array α n) : def Array.subslice {α : Type u} {n : Usize} (a : Array α n) (r : Range Usize) : Result (Slice α) := -- TODO: not completely sure here if r.start.val < r.end_.val ∧ r.end_.val ≤ a.val.length then - ok ⟨ a.val.slice r.start.toNat r.end_.toNat, + ok ⟨ a.val.slice r.start.val r.end_.val, by - have := a.val.slice_length_le r.start.toNat r.end_.toNat + have := a.val.slice_length_le r.start.val r.end_.val scalar_tac ⟩ else fail panic -@[pspec] +@[progress] theorem Array.subslice_spec {α : Type u} {n : Usize} [Inhabited α] (a : Array α n) (r : Range Usize) (h0 : r.start.val < r.end_.val) (h1 : r.end_.val ≤ a.val.length) : ∃ s, subslice a r = ok s ∧ - s.val = a.val.slice r.start.toNat r.end_.toNat ∧ - (∀ i, 0 ≤ i → i + r.start.val < r.end_.val → s.val.index i.toNat = a.val.index (r.start.toNat + i.toNat)) + s.val = a.val.slice r.start.val r.end_.val ∧ + (∀ i, i + r.start.val < r.end_.val → s.val[i]! = a.val[r.start.val + i]!) := by - simp [subslice, *] - intro i _ _ - have := List.index_slice r.start.toNat r.end_.toNat i.toNat a.val (by scalar_tac) (by scalar_tac) - simp [*] + simp only [subslice, true_and, h0, h1, ↓reduceIte, ok.injEq, exists_eq_left', true_and] + intro i _ + have := List.getElem!_slice r.start.val r.end_.val i a.val (by scalar_tac) (by scalar_tac) + simp only [this] set_option maxHeartbeats 500000 def Array.update_subslice {α : Type u} {n : Usize} (a : Array α n) (r : Range Usize) (s : Slice α) : Result (Array α n) := -- TODO: not completely sure here - if h: r.start.toNat < r.end_.toNat ∧ r.end_.toNat ≤ a.length ∧ s.val.length = r.end_.toNat - r.start.toNat then - let s_beg := a.val.take r.start.toNat - let s_end := a.val.drop r.end_.toNat - have : s_beg.length = r.start.toNat := by + if h: r.start.val < r.end_.val ∧ r.end_.val ≤ a.length ∧ s.val.length = r.end_.val - r.start.val then + let s_beg := a.val.take r.start.val + let s_end := a.val.drop r.end_.val + have : s_beg.length = r.start.val := by scalar_tac - have : s_end.length = a.val.length - r.end_.toNat := by + have : s_end.length = a.val.length - r.end_.val := by scalar_tac let na := s_beg.append (s.val.append s_end) have : na.length = a.val.length:= by simp [na]; scalar_tac - ok ⟨ na, by simp_all; scalar_tac ⟩ + ok ⟨ na, by simp_all ⟩ else fail panic @@ -331,58 +394,60 @@ def Array.update_subslice {α : Type u} {n : Usize} (a : Array α n) (r : Range -- operations/ -- We should introduce special symbols for the monadic arithmetic operations -- (the user will never write those symbols directly). 
-@[pspec] +@[progress] theorem Array.update_subslice_spec {α : Type u} {n : Usize} [Inhabited α] (a : Array α n) (r : Range Usize) (s : Slice α) - (_ : r.start.toNat < r.end_.toNat) (_ : r.end_.toNat ≤ a.length) (_ : s.length = r.end_.toNat - r.start.toNat) : + (_ : r.start.val < r.end_.val) (_ : r.end_.val ≤ a.length) (_ : s.length = r.end_.val - r.start.val) : ∃ na, update_subslice a r s = ok na ∧ - (∀ i, 0 ≤ i → i < r.start.toNat → na.index_s i = a.index_s i) ∧ - (∀ i, r.start.toNat ≤ i → i < r.end_.toNat → na.index_s i = s.index_s (i - r.start.toNat)) ∧ - (∀ i, r.end_.toNat ≤ i → i < n.toNat → na.index_s i = a.index_s i) := by - simp [update_subslice, *] - have h := List.replace_slice_index r.start.toNat r.end_.toNat a.val s.val + (∀ i, i < r.start.val → na[i]! = a[i]!) ∧ + (∀ i, r.start.val ≤ i → i < r.end_.val → na[i]! = s[i - r.start.val]!) ∧ + (∀ i, r.end_.val ≤ i → i < n.val → na[i]! = a[i]!) := by + simp only [update_subslice, length, and_true, true_and, List.append_eq, + get!, Slice.get!, ↓reduceDIte, ok.injEq, exists_eq_left', *] + have h := List.replace_slice_getElem! r.start.val r.end_.val a.val s.val (by scalar_tac) (by scalar_tac) (by scalar_tac) simp [List.replace_slice] at h have ⟨ h0, h1, h2 ⟩ := h clear h split_conjs . intro i _ - have := h0 i (by int_tac) + have := h0 i (by scalar_tac) simp_all . intro i _ _ - have := h1 i (by int_tac) (by int_tac) + have := h1 i (by scalar_tac) (by scalar_tac) simp [*] . intro i _ _ - have := h2 i (by int_tac) (by int_tac) + have := h2 i (by scalar_tac) (by scalar_tac) simp [*] def Slice.subslice {α : Type u} (s : Slice α) (r : Range Usize) : Result (Slice α) := -- TODO: not completely sure here if r.start.val < r.end_.val ∧ r.end_.val ≤ s.length then - ok ⟨ s.val.slice r.start.toNat r.end_.toNat, + ok ⟨ s.val.slice r.start.val r.end_.val, by - have := s.val.slice_length_le r.start.toNat r.end_.toNat + have := s.val.slice_length_le r.start.val r.end_.val scalar_tac ⟩ else fail panic -@[pspec] +@[progress] theorem Slice.subslice_spec {α : Type u} [Inhabited α] (s : Slice α) (r : Range Usize) - (h0 : r.start.toNat < r.end_.toNat) (h1 : r.end_.toNat ≤ s.val.length) : + (h0 : r.start.val < r.end_.val) (h1 : r.end_.val ≤ s.val.length) : ∃ ns, subslice s r = ok ns ∧ - ns.val = s.slice r.start.toNat r.end_.toNat ∧ - (∀ i, i + r.start.toNat < r.end_.toNat → ns.index_s i = s.index_s (r.start.toNat + i)) + ns.val = s.slice r.start.val r.end_.val ∧ + (∀ i, i + r.start.val < r.end_.val → ns[i]! = s[r.start.val + i]!) 
:= by - simp_all [subslice] + simp_all only [subslice, length, and_self, ite_true, ok.injEq, slice, get!, exists_eq_left', + true_and] intro i _ - have := List.index_slice r.start.toNat r.end_.toNat i s.val (by scalar_tac) (by scalar_tac) - simp [*] + have := List.getElem!_slice r.start.val r.end_.val i s.val (by scalar_tac) (by scalar_tac) + apply this def Slice.update_subslice {α : Type u} (s : Slice α) (r : Range Usize) (ss : Slice α) : Result (Slice α) := -- TODO: not completely sure here - if h: r.start.toNat < r.end_.toNat ∧ r.end_.toNat ≤ s.length ∧ ss.val.length = r.end_.toNat - r.start.toNat then - let s_beg := s.val.take r.start.toNat - let s_end := s.val.drop r.end_.toNat - have : s_beg.length = r.start.toNat := by scalar_tac + if h: r.start.val < r.end_.val ∧ r.end_.val ≤ s.length ∧ ss.val.length = r.end_.val - r.start.val then + let s_beg := s.val.take r.start.val + let s_end := s.val.drop r.end_.val + have : s_beg.length = r.start.val := by scalar_tac have : s_end.length = s.val.length - r.end_.val := by scalar_tac let ns := s_beg.append (ss.val.append s_end) have : ns.length = s.val.length := by simp [ns, *]; scalar_tac @@ -390,39 +455,40 @@ def Slice.update_subslice {α : Type u} (s : Slice α) (r : Range Usize) (ss : S else fail panic -@[pspec] +@[progress] theorem Slice.update_subslice_spec {α : Type u} [Inhabited α] (a : Slice α) (r : Range Usize) (ss : Slice α) - (_ : r.start.toNat < r.end_.toNat) (_ : r.end_.toNat ≤ a.length) (_ : ss.length = r.end_.toNat - r.start.toNat) : + (_ : r.start.val < r.end_.val) (_ : r.end_.val ≤ a.length) (_ : ss.length = r.end_.val - r.start.val) : ∃ na, update_subslice a r ss = ok na ∧ - (∀ i, i < r.start.toNat → na.index_s i = a.index_s i) ∧ - (∀ i, r.start.toNat ≤ i → i < r.end_.toNat → na.index_s i = ss.index_s (i - r.start.toNat)) ∧ - (∀ i, r.end_.toNat ≤ i → i < a.length → na.index_s i = a.index_s i) := by + (∀ i, i < r.start.val → na[i]! = a[i]!) ∧ + (∀ i, r.start.val ≤ i → i < r.end_.val → na[i]! = ss[i - r.start.val]!) ∧ + (∀ i, r.end_.val ≤ i → i < a.length → na[i]! = a[i]!) := by simp [update_subslice, *] - have h := List.replace_slice_index r.start.toNat r.end_.toNat a.val ss.val + have h := List.replace_slice_getElem! r.start.val r.end_.val a.val ss.val (by scalar_tac) (by scalar_tac) (by scalar_tac) simp [List.replace_slice, *] at h have ⟨ h0, h1, h2 ⟩ := h clear h split_conjs . intro i _ - have := h0 i (by int_tac) + have := h0 i (by scalar_tac) simp [*] . intro i _ _ - have := h1 i (by int_tac) (by int_tac) + have := h1 i (by scalar_tac) (by scalar_tac) simp [*] . intro i _ _ - have := h2 i (by int_tac) (by int_tac) + have := h2 i (by scalar_tac) (by scalar_tac) simp [*] @[simp] -theorem Array.update_index_eq α n [Inhabited α] (x : Array α n) (i : Usize) : - x.update i (x.val.index i.toNat) = x := by - simp [Array, Subtype.eq_iff] +theorem Array.set_getElem!_eq α n [Inhabited α] (x : Array α n) (i : Usize) : + x.set i (x.val[i.val]!) = x := by + have := @List.set_getElem_self _ x.val i.val + simp only [Array, Subtype.eq_iff, set_val_eq, List.set_getElem!] @[simp] theorem Slice.update_index_eq α [Inhabited α] (x : Slice α) (i : Usize) : - x.update i (x.val.index i.toNat) = x := by - simp [Slice, Subtype.eq_iff] + x.set i (x.val[i.val]!) = x := by + simp only [Slice, Subtype.eq_iff, set_val_eq, List.set_getElem!] 
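/- Illustrative sketch (not part of the original development): a specification registered with
   `@[progress]`, such as `Array.to_slice_spec` above, can also be consumed directly as an
   existential when not going through the `progress` tactic; the variable names below are ours. -/
example {α : Type u} {n : Usize} (a : Array α n) :
    ∃ s, Array.to_slice a = ok s ∧ s.val = a.val := by
  have ⟨ s, hEq, hVal ⟩ := Array.to_slice_spec a
  exact ⟨ s, hEq, hVal.symm ⟩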
/- Trait declaration: [core::slice::index::private_slice_index::Sealed] -/ structure core.slice.index.private_slice_index.Sealed (Self : Type) where @@ -441,47 +507,60 @@ structure core.slice.index.SliceIndex (Self T : Type) where /- [core::slice::index::[T]::index]: forward function -/ def core.slice.index.Slice.index {T I : Type} (inst : core.slice.index.SliceIndex I (Slice T)) - (slice : Slice T) (i : I) : Result inst.Output := do - let x ← inst.get i slice - match x with - | none => fail panic - | some x => ok x + (slice : Slice T) (i : I) : Result inst.Output := + inst.index i slice /- [core::slice::index::Range:::get]: forward function -/ -def core.slice.index.RangeUsize.get {T : Type} (i : Range Usize) (slice : Slice T) : +def core.slice.index.RangeUsize.get {T : Type} (r : Range Usize) (s : Slice T) : Result (Option (Slice T)) := - sorry -- TODO + if r.start ≤ r.end_ ∧ r.end_ ≤ s.length then + ok (some ⟨ s.val.slice r.start r.end_, by scalar_tac⟩) + else ok none /- [core::slice::index::Range::get_mut]: forward function -/ def core.slice.index.RangeUsize.get_mut - {T : Type} : Range Usize → Slice T → Result (Option (Slice T) × (Option (Slice T) → Slice T)) := - sorry -- TODO + {T : Type} (r : Range Usize) (s : Slice T) : Result (Option (Slice T) × (Option (Slice T) → Slice T)) := + if r.start ≤ r.end_ ∧ r.end_ ≤ s.length then + ok (some ⟨ s.val.slice r.start r.end_, by scalar_tac⟩, + fun s' => + match s' with + | none => s + | some s' => + if h: (List.replace_slice r.start r.end_ s.val s'.val).length ≤ Usize.max then + ⟨ List.replace_slice r.start r.end_ s.val s'.val, by scalar_tac ⟩ + else s ) + else ok (none, fun _ => s) /- [core::slice::index::Range::get_unchecked]: forward function -/ def core.slice.index.RangeUsize.get_unchecked {T : Type} : Range Usize → ConstRawPtr (Slice T) → Result (ConstRawPtr (Slice T)) := - -- Don't know what the model should be - for now we always fail to make - -- sure code which uses it fails - fun _ _ => fail panic + -- Don't know what the model should be - for now we always fail + fun _ _ => fail .undef /- [core::slice::index::Range::get_unchecked_mut]: forward function -/ def core.slice.index.RangeUsize.get_unchecked_mut {T : Type} : Range Usize → MutRawPtr (Slice T) → Result (MutRawPtr (Slice T)) := - -- Don't know what the model should be - for now we always fail to make - -- sure code which uses it fails - fun _ _ => fail panic + -- Don't know what the model should be - for now we always fail + fun _ _ => fail .undef /- [core::slice::index::Range::index]: forward function -/ -def core.slice.index.RangeUsize.index - {T : Type} : Range Usize → Slice T → Result (Slice T) := - sorry -- TODO +def core.slice.index.RangeUsize.index {T : Type} (r : Range Usize) (s : Slice T) : Result (Slice T) := + if r.start ≤ r.end_ ∧ r.end_ ≤ s.length then + ok (⟨ s.val.slice r.start r.end_, by scalar_tac⟩) + else fail .panic /- [core::slice::index::Range::index_mut]: forward function -/ -def core.slice.index.RangeUsize.index_mut - {T : Type} : Range Usize → Slice T → Result (Slice T × (Slice T → Slice T)) := - sorry -- TODO +def core.slice.index.RangeUsize.index_mut {T : Type} (r : Range Usize) (s : Slice T) : + Result (Slice T × (Slice T → Slice T)) := + if r.start ≤ r.end_ ∧ r.end_ ≤ s.length then + ok (⟨ s.val.slice r.start r.end_, by scalar_tac⟩, + fun s' => + if h: (List.replace_slice r.start r.end_ s.val s'.val).length ≤ Usize.max then + ⟨ List.replace_slice r.start r.end_ s.val s'.val, by scalar_tac ⟩ + else s ) + else fail .panic /- 
[core::slice::index::[T]::index_mut]: forward function -/ def core.slice.index.Slice.index_mut @@ -553,33 +632,35 @@ def core.ops.index.IndexMutArrayIInst {T I : Type} {N : Usize} } /- [core::slice::index::usize::get]: forward function -/ -def core.slice.index.Usize.get - {T : Type} : Usize → Slice T → Result (Option T) := - sorry -- TODO +@[simp] abbrev core.slice.index.Usize.get + {T : Type} (i : Usize) (s : Slice T) : Result (Option T) := + ok s[i]? /- [core::slice::index::usize::get_mut]: forward function -/ -def core.slice.index.Usize.get_mut - {T : Type} : Usize → Slice T → Result (Option T × (Option T → Slice T)) := - sorry -- TODO +@[simp] abbrev core.slice.index.Usize.get_mut + {T : Type} (i : Usize) (s : Slice T) : Result (Option T × (Option T → Slice T)) := + ok (s[i]?, s.set_opt i) /- [core::slice::index::usize::get_unchecked]: forward function -/ def core.slice.index.Usize.get_unchecked {T : Type} : Usize → ConstRawPtr (Slice T) → Result (ConstRawPtr T) := - sorry -- TODO + -- We don't have a model for now + fun _ _ => fail .undef /- [core::slice::index::usize::get_unchecked_mut]: forward function -/ def core.slice.index.Usize.get_unchecked_mut {T : Type} : Usize → MutRawPtr (Slice T) → Result (MutRawPtr T) := - sorry -- TODO + -- We don't have a model for now + fun _ _ => fail .undef /- [core::slice::index::usize::index]: forward function -/ -def core.slice.index.Usize.index {T : Type} : Usize → Slice T → Result T := - sorry -- TODO +@[simp] abbrev core.slice.index.Usize.index {T : Type} (i : Usize) (s : Slice T) : Result T := + Slice.index_usize s i /- [core::slice::index::usize::index_mut]: forward function -/ -def core.slice.index.Usize.index_mut {T : Type} : - Usize → Slice T → Result (T × (T → (Slice T))) := - sorry -- TODO +@[simp] abbrev core.slice.index.Usize.index_mut {T : Type} + (i : Usize) (s : Slice T) : Result (T × (T → (Slice T))) := + Slice.index_mut_usize s i /- Trait implementation: [core::slice::index::private_slice_index::usize] -/ def core.slice.index.private_slice_index.SealedUsizeInst @@ -608,31 +689,53 @@ def core.slice.Slice.copy_from_slice {T : Type} (_ : core.marker.Copy T) def core.array.TryFromSliceError := () /- [core::slice::index::{core::slice::index::SliceIndex<@Slice> for core::ops::range::RangeFrom}::get] -/ -def core.slice.index.SliceIndexcoreopsrangeRangeFromUsizeSlice.get {T : Type} : - core.ops.range.RangeFrom Usize → Slice T → Result (Option (Slice T)) := sorry +def core.slice.index.SliceIndexcoreopsrangeRangeFromUsizeSlice.get {T : Type} (r : core.ops.range.RangeFrom Usize) (s : Slice T) : Result (Option (Slice T)) := + if r.start ≤ s.length then + ok (some (s.drop r.start)) + else ok none /- [core::slice::index::{core::slice::index::SliceIndex<@Slice> for core::ops::range::RangeFrom}::get_mut] -/ def core.slice.index.SliceIndexcoreopsrangeRangeFromUsizeSlice.get_mut - {T : Type} : - core.ops.range.RangeFrom Usize → Slice T → Result ((Option (Slice T)) × - (Option (Slice T) → Slice T)) := sorry + {T : Type} (r : core.ops.range.RangeFrom Usize) (s : Slice T) : + Result ((Option (Slice T)) × (Option (Slice T) → Slice T)) := + if r.start ≤ s.length then + ok (some (s.drop r.start), + fun s' => match s' with + | none => s + | some s' => + if h: s'.length + s.length - r.start.val ≤ Usize.max then + ⟨ s'.val ++ s.val.drop r.start.val, by scalar_tac ⟩ + else s) + else ok (none, fun _ => s) /- [core::slice::index::{core::slice::index::SliceIndex<@Slice> for core::ops::range::RangeFrom}::get_unchecked] -/ def 
core.slice.index.SliceIndexcoreopsrangeRangeFromUsizeSlice.get_unchecked {T : Type} : - core.ops.range.RangeFrom Usize → ConstRawPtr (Slice T) → Result - (ConstRawPtr (Slice T)) := sorry + core.ops.range.RangeFrom Usize → ConstRawPtr (Slice T) → Result (ConstRawPtr (Slice T)) := + -- We don't have a model for now + fun _ _ => fail .undef /- [core::slice::index::{core::slice::index::SliceIndex<@Slice> for core::ops::range::RangeFrom}::get_unchecked_mut] -/ def core.slice.index.SliceIndexcoreopsrangeRangeFromUsizeSlice.get_unchecked_mut {T : Type} : - core.ops.range.RangeFrom Usize → MutRawPtr (Slice T) → Result (MutRawPtr (Slice T)) := sorry + core.ops.range.RangeFrom Usize → MutRawPtr (Slice T) → Result (MutRawPtr (Slice T)) := + -- We don't have a model for now + fun _ _ => fail .undef /- [core::slice::index::{core::slice::index::SliceIndex<@Slice> for core::ops::range::RangeFrom}::index] -/ -def core.slice.index.SliceIndexcoreopsrangeRangeFromUsizeSlice.index {T : Type} : - core.ops.range.RangeFrom Usize → Slice T → Result (Slice T) := sorry +def core.slice.index.SliceIndexcoreopsrangeRangeFromUsizeSlice.index {T : Type} + (r : core.ops.range.RangeFrom Usize) (s : Slice T) : Result (Slice T) := + if r.start.val ≤ s.length then + ok (s.drop r.start) + else fail .undef /- [core::slice::index::{core::slice::index::SliceIndex<@Slice> for core::ops::range::RangeFrom}::index_mut] -/ -def core.slice.index.SliceIndexcoreopsrangeRangeFromUsizeSlice.index_mut {T : Type} : - core.ops.range.RangeFrom Usize → Slice T → Result ((Slice T) × (Slice T → Slice T)) := sorry +def core.slice.index.SliceIndexcoreopsrangeRangeFromUsizeSlice.index_mut {T : Type} + (r : core.ops.range.RangeFrom Usize) (s : Slice T) : Result ((Slice T) × (Slice T → Slice T)) := + if r.start ≤ s.length then + ok ( s.drop r.start, fun s' => + if h: s'.length + s.length - r.start.val ≤ Usize.max then + ⟨ s'.val ++ s.val.drop r.start.val, by scalar_tac ⟩ + else s ) + else fail .panic /- Trait implementation: [core::slice::index::private_slice_index::{core::slice::index::private_slice_index::Sealed for core::ops::range::RangeFrom}] -/ @[reducible] diff --git a/backends/lean/Aeneas/Std/Core.lean b/backends/lean/Aeneas/Std/Core.lean index ff766dbe..f3d35836 100644 --- a/backends/lean/Aeneas/Std/Core.lean +++ b/backends/lean/Aeneas/Std/Core.lean @@ -7,6 +7,10 @@ namespace Std open Result +/- [alloc::boxed::{core::convert::AsMut for alloc::boxed::Box}::as_mut] -/ +def alloc.boxed.AsMutBoxT.as_mut {T : Type} (x : T) : T × (T → T) := + (x, fun x => x) + namespace core /- Trait declaration: [core::convert::From] -/ @@ -143,6 +147,15 @@ def core.convert.TryIntoFrom {T U : Type} (fromInst : core.convert.TryFrom U T) try_into := core.convert.TryIntoFrom.try_into fromInst } +structure core.convert.AsMut (Self : Type) (T : Type) where + as_mut : Self → Result (T × (T → Self)) + +/- [alloc::boxed::{core::convert::AsMut for alloc::boxed::Box}] -/ +@[reducible] +def core.convert.AsMutBoxT (T : Type) : core.convert.AsMut T T := { + as_mut := fun x => ok (alloc.boxed.AsMutBoxT.as_mut x) +} + /- TODO: -/ axiom Formatter : Type @@ -159,7 +172,6 @@ def core.result.Result.unwrap {T E : Type} structure core.ops.range.RangeFrom (Idx : Type) where start : Idx - end Std end Aeneas diff --git a/backends/lean/Aeneas/Std/CoreConvertNum.lean b/backends/lean/Aeneas/Std/CoreConvertNum.lean index ad1dcbc1..882c7c82 100644 --- a/backends/lean/Aeneas/Std/CoreConvertNum.lean +++ b/backends/lean/Aeneas/Std/CoreConvertNum.lean @@ -7,6 +7,7 @@ import 
Aeneas.Std.ScalarNotations import Aeneas.Std.ArraySlice import Aeneas.ScalarTac import Aeneas.Progress.Core +import Aeneas.Arith.Lemmas namespace Aeneas @@ -55,56 +56,213 @@ def FromI64Bool.from (b : Bool) : I64 := def FromI128Bool.from (b : Bool) : I128 := if b then 1#i128 else 0#i128 -def FromUsizeU8.from (x : U8) : Usize := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromUsizeU8.from (x : U8) : Usize := ⟨ x.val ⟩ +def FromUsizeU16.from (x : U16) : Usize := ⟨ x.val ⟩ +def FromUsizeU32.from (x : U32) : Usize := ⟨ x.val ⟩ +def FromUsizeUsize.from (x : Usize) : Usize := ⟨ x.val ⟩ -def FromUsizeU16.from (x : U16) : Usize := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromUsizeU32.from (x : U32) : Usize := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromUsizeUsize.from (x : Usize) : Usize := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromU8U8.from (x : U8) : U8 := ⟨ x.val ⟩ -def FromU8U8.from (x : U8) : U8 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromU16U8.from (x : U8) : U16 := ⟨ x.val ⟩ +def FromU16U16.from (x : U16) : U16 := ⟨ x.val ⟩ -def FromU16U8.from (x : U8) : U16 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU16U16.from (x : U16) : U16 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromU32U8.from (x : U8) : U32 := ⟨ x.val ⟩ +def FromU32U16.from (x : U16) : U32 := ⟨ x.val ⟩ +def FromU32U32.from (x : U32) : U32 := ⟨ x.val ⟩ -def FromU32U8.from (x : U8) : U32 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU32U16.from (x : U16) : U32 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU32U32.from (x : U32) : U32 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromU64U8.from (x : U8) : U64 := ⟨ x.val ⟩ +def FromU64U16.from (x : U16) : U64 := ⟨ x.val ⟩ +def FromU64U32.from (x : U32) : U64 := ⟨ x.val ⟩ +def FromU64U64.from (x : U64) : U64 := ⟨ x.val ⟩ -def FromU64U8.from (x : U8) : U64 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU64U16.from (x : U16) : U64 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU64U32.from (x : U32) : U64 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU64U64.from (x : U64) : U64 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromU128U8.from (x : U8) : U128 := ⟨ x.val ⟩ +def FromU128U16.from (x : U16) : U128 := ⟨ x.val ⟩ +def FromU128U32.from (x : U32) : U128 := ⟨ x.val ⟩ +def FromU128U64.from (x : U64) : U128 := ⟨ x.val ⟩ +def FromU128U128.from (x : U128) : U128 := ⟨ x.val ⟩ -def FromU128U8.from (x : U8) : U128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU128U16.from (x : U16) : U128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU128U32.from (x : U32) : U128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU128U64.from (x : U64) : U128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromU128U128.from (x : U128) : U128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromIsizeI8.from (x : I8) : Isize := ⟨ x.val ⟩ +def FromIsizeI16.from (x : I16) : Isize := ⟨ x.val ⟩ +def FromIsizeI32.from (x : I32) : Isize := ⟨ x.val ⟩ +def FromIsizeIsize.from (x : Isize) : Isize := ⟨ x.val ⟩ -def FromIsizeI8.from (x : I8) : Isize := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromIsizeI16.from (x : I16) : Isize := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromIsizeI32.from (x : I32) : Isize := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromIsizeIsize.from (x : Isize) : Isize := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromI8I8.from (x : I8) : I8 := ⟨ x.val ⟩ -def FromI8I8.from (x : I8) : I8 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromI16I8.from (x : I8) : I16 := ⟨ x.val ⟩ +def FromI16I16.from (x : I16) : 
I16 := ⟨ x.val ⟩ -def FromI16I8.from (x : I8) : I16 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI16I16.from (x : I16) : I16 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromI32I8.from (x : I8) : I32 := ⟨ x.val ⟩ +def FromI32I16.from (x : I16) : I32 := ⟨ x.val ⟩ +def FromI32I32.from (x : I32) : I32 := ⟨ x.val ⟩ -def FromI32I8.from (x : I8) : I32 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI32I16.from (x : I16) : I32 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI32I32.from (x : I32) : I32 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromI64I8.from (x : I8) : I64 := ⟨ x.val ⟩ +def FromI64I16.from (x : I16) : I64 := ⟨ x.val ⟩ +def FromI64I32.from (x : I32) : I64 := ⟨ x.val ⟩ +def FromI64I64.from (x : I64) : I64 := ⟨ x.val ⟩ -def FromI64I8.from (x : I8) : I64 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI64I16.from (x : I16) : I64 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI64I32.from (x : I32) : I64 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI64I64.from (x : I64) : I64 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +def FromI128I8.from (x : I8) : I128 := ⟨ x.val ⟩ +def FromI128I16.from (x : I16) : I128 := ⟨ x.val ⟩ +def FromI128I32.from (x : I32) : I128 := ⟨ x.val ⟩ +def FromI128I64.from (x : I64) : I128 := ⟨ x.val ⟩ +def FromI128I128.from (x : I128) : I128 := ⟨ x.val ⟩ -def FromI128I8.from (x : I8) : I128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI128I16.from (x : I16) : I128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI128I32.from (x : I32) : I128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI128I64.from (x : I64) : I128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ -def FromI128I128.from (x : I128) : I128 := ⟨ x.val, by scalar_tac, by scalar_tac ⟩ +@[simp] def FromUsizeU8.from_val_eq (x : U8) : (FromUsizeU8.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromUsizeU16.from_val_eq (x : U16) : (FromUsizeU16.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromUsizeU32.from_val_eq (x : U32) : (FromUsizeU32.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromUsizeUsize.from_val_eq (x : Usize) : (FromUsizeUsize.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU8U8.from_val_eq (x : U8) : (FromU8U8.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU16U8.from_val_eq (x : U8) : (FromU16U8.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU16U16.from_val_eq (x : U16) : (FromU16U16.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU32U8.from_val_eq (x : U8) : (FromU32U8.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU32U16.from_val_eq (x : U16) : (FromU32U16.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU32U32.from_val_eq (x : U32) : (FromU32U32.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU64U8.from_val_eq (x : U8) : (FromU64U8.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU64U16.from_val_eq (x : U16) : (FromU64U16.from x).val = x.val := by 
+ simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU64U32.from_val_eq (x : U32) : (FromU64U32.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU64U64.from_val_eq (x : U64) : (FromU64U64.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU128U8.from_val_eq (x : U8) : (FromU128U8.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU128U16.from_val_eq (x : U16) : (FromU128U16.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU128U32.from_val_eq (x : U32) : (FromU128U32.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU128U64.from_val_eq (x : U64) : (FromU128U64.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +@[simp] def FromU128U128.from_val_eq (x : U128) : (FromU128U128.from x).val = x.val := by + simp only [UScalar.val]; simp; apply Nat.mod_eq_of_lt; scalar_tac + +#check Arith.Int.bmod_pow2_eq_of_inBounds + +@[local simp] private theorem zero_lt_size_num_bits : 0 < System.Platform.numBits := by + cases System.Platform.numBits_eq <;> simp [*] + +private theorem bmod_pow2_eq_of_inBounds' (n : ℕ) (x : ℤ) (h : 0 < n) (h0 : -2 ^ (n - 1) ≤ x) (h1 : x < 2 ^ (n - 1)) : + x.bmod (2 ^ n) = x := by + have hn : n - 1 + 1 = n := by omega + have := Arith.Int.bmod_pow2_eq_of_inBounds (n - 1) x (by omega) (by assumption) + simp [hn] at this + apply this + +@[simp] def FromIsizeI8.from_val_eq (x : I8) : (FromIsizeI8.from x).val = x.val := by + simp only [FromIsizeI8.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromIsizeI16.from_val_eq (x : I16) : (FromIsizeI16.from x).val = x.val := by + simp only [FromIsizeI16.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromIsizeI32.from_val_eq (x : I32) : (FromIsizeI32.from x).val = x.val := by + simp only [FromIsizeI32.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromIsizeIsize.from_val_eq (x : Isize) : (FromIsizeIsize.from x).val = x.val := by + simp only [FromIsizeIsize.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + +@[simp] def FromI8I8.from_val_eq (x : I8) : (FromI8I8.from x).val = x.val := by + simp only [FromI8I8.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + +@[simp] def FromI16I8.from_val_eq (x : I8) : (FromI16I8.from x).val = x.val := by + simp only [FromI16I8.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI16I16.from_val_eq (x : I16) : (FromI16I16.from x).val = x.val := by + simp only [FromI16I16.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + +@[simp] def FromI32I8.from_val_eq (x : I8) : (FromI32I8.from x).val = x.val := by + simp only [FromI32I8.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply 
bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI32I16.from_val_eq (x : I16) : (FromI32I16.from x).val = x.val := by + simp only [FromI32I16.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI32I32.from_val_eq (x : I32) : (FromI32I32.from x).val = x.val := by + simp only [FromI32I32.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + +@[simp] def FromI64I8.from_val_eq (x : I8) : (FromI64I8.from x).val = x.val := by + simp only [FromI64I8.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI64I16.from_val_eq (x : I16) : (FromI64I16.from x).val = x.val := by + simp only [FromI64I16.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI64I32.from_val_eq (x : I32) : (FromI64I32.from x).val = x.val := by + simp only [FromI64I32.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI64I64.from_val_eq (x : I64) : (FromI64I64.from x).val = x.val := by + simp only [FromI64I64.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + +@[simp] def FromI128I8.from_val_eq (x : I8) : (FromI128I8.from x).val = x.val := by + simp only [FromI128I8.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI128I16.from_val_eq (x : I16) : (FromI128I16.from x).val = x.val := by + simp only [FromI128I16.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI128I32.from_val_eq (x : I32) : (FromI128I32.from x).val = x.val := by + simp only [FromI128I32.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI128I64.from_val_eq (x : I64) : (FromI128I64.from x).val = x.val := by + simp only [FromI128I64.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] + apply bmod_pow2_eq_of_inBounds' <;> simp <;> scalar_tac + +@[simp] def FromI128I128.from_val_eq (x : I128) : (FromI128I128.from x).val = x.val := by + simp only [FromI128I128.from, IScalar.val] + simp [-IScalar.bv_toInt_eq, Int.cast, IntCast.intCast, -Nat.reducePow] end num -- core.convert.num @@ -301,32 +459,32 @@ def FromI128I128 : core.convert.From I128 I128 := { end core.convert -- to_le_bytes -def core.num.U8.to_le_bytes (x : U8) : Array U8 1#usize := sorry +def core.num.U8.to_le_bytes (x : U8) : Array U8 1#usize := ⟨ [x], by simp ⟩ def core.num.U16.to_le_bytes (x : U16) : Array U8 2#usize := sorry def core.num.U32.to_le_bytes (x : U32) : Array U8 4#usize := sorry def core.num.U64.to_le_bytes (x : U64) : Array U8 8#usize := sorry def core.num.U128.to_le_bytes (x : U128) : Array U8 128#usize := sorry -- to_be_bytes -def core.num.U8.to_be_bytes (x : U8) : Array U8 1#usize := sorry +def core.num.U8.to_be_bytes (x : U8) : Array U8 1#usize := ⟨ [x], by simp ⟩ def core.num.U16.to_be_bytes (x : U16) : Array U8 2#usize := sorry def 
core.num.U32.to_be_bytes (x : U32) : Array U8 4#usize := sorry def core.num.U64.to_be_bytes (x : U64) : Array U8 8#usize := sorry def core.num.U128.to_be_bytes (x : U128) : Array U8 128#usize := sorry -- from_le_bytes -def core.num.U8.from_le_bytes (a : Array U8 1#usize) : U8 := sorry +def core.num.U8.from_le_bytes (a : Array U8 1#usize) : U8 := a.val[0] def core.num.U16.from_le_bytes (a : Array U8 2#usize) : U16 := sorry def core.num.U32.from_le_bytes (a : Array U8 4#usize) : U32 := sorry def core.num.U64.from_le_bytes (a : Array U8 8#usize) : U64 := sorry def core.num.U128.from_le_bytes (a : Array U8 128#usize) : U128 := sorry -- from_be_bytes -def core.num.U8.from_be_bytes (a : Array U8 1#usize) : I8 := sorry -def core.num.U16.from_be_bytes (a : Array U8 2#usize) : I16 := sorry -def core.num.U32.from_be_bytes (a : Array U8 4#usize) : I32 := sorry -def core.num.U64.from_be_bytes (a : Array U8 8#usize) : I64 := sorry -def core.num.U128.from_be_bytes (a : Array U8 128#usize) : I128 := sorry +def core.num.U8.from_be_bytes (a : Array U8 1#usize) : U8 := a.val[0] +def core.num.U16.from_be_bytes (a : Array U8 2#usize) : U16 := sorry +def core.num.U32.from_be_bytes (a : Array U8 4#usize) : U32 := sorry +def core.num.U64.from_be_bytes (a : Array U8 8#usize) : U64 := sorry +def core.num.U128.from_be_bytes (a : Array U8 128#usize) : U128 := sorry end Std diff --git a/backends/lean/Aeneas/Std/Primitives.lean b/backends/lean/Aeneas/Std/Primitives.lean index 1cdd161f..e11b3a20 100644 --- a/backends/lean/Aeneas/Std/Primitives.lean +++ b/backends/lean/Aeneas/Std/Primitives.lean @@ -4,9 +4,9 @@ namespace Aeneas namespace Std --------------------- --- ASSERT COMMAND --Std. --------------------- +/-! +# Assert Command +-/ open Lean Elab Command Term Meta @@ -14,22 +14,32 @@ syntax (name := assert) "#assert" term: command @[command_elab assert] unsafe -def assertImpl : CommandElab := fun (_stx: Syntax) => do +def assertImpl : CommandElab := fun (stx: Syntax) => do runTermElabM (fun _ => do - let r ← evalTerm Bool (mkConst ``Bool) _stx[1] + let r ← evalTerm Bool (mkConst ``Bool) stx[1] if not r then - logInfo ("Assertion failed for:\n" ++ _stx[1]) - throwError ("Expression reduced to false:\n" ++ _stx[1]) + logInfo ("Assertion failed for:\n" ++ stx[1]) + throwError ("Expression reduced to false:\n" ++ stx[1]) pure ()) #eval 2 == 2 #assert (2 == 2) -------------- --- PRELUDE -- -------------- +syntax (name := elabSyntax) "#elab" term: command --- Results & monadic combinators +@[command_elab elabSyntax] +unsafe +def elabImpl : CommandElab := fun (stx: Syntax) => do + runTermElabM (fun _ => do + /- Simply elaborate the syntax to check that it is correct -/ + let (_, _) ← Elab.Term.elabTerm stx[1] none |>.run + pure ()) + +#elab 3 + +/-! +# Results and Monadic Combinators +-/ inductive Error where | assertionFailure: Error @@ -38,6 +48,7 @@ inductive Error where | arrayOutOfBounds: Error | maximumSizeExceeded: Error | panic: Error + | undef: Error deriving Repr, BEq open Error @@ -56,7 +67,9 @@ instance Result_Inhabited (α : Type u) : Inhabited (Result α) := instance Result_Nonempty (α : Type u) : Nonempty (Result α) := Nonempty.intro div -/- HELPERS -/ +/-! +# Helpers +-/ def ok? {α: Type u} (r: Result α): Bool := match r with @@ -84,7 +97,9 @@ def Result.ofOption {a : Type u} (x : Option a) (e : Error) : Result a := | some x => ok x | none => fail e -/- DO-DSL SUPPORT -/ +/-! 
+# Do-DSL Support +-/ def bind {α : Type u} {β : Type v} (x: Result α) (f: α → Result β) : Result β := match x with @@ -104,19 +119,6 @@ instance : Pure Result where @[simp] theorem bind_fail (x : Error) (f : α → Result β) : bind (.fail x) f = .fail x := by simp [bind] @[simp] theorem bind_div (f : α → Result β) : bind .div f = .div := by simp [bind] -/- CUSTOM-DSL SUPPORT -/ - --- Let-binding the Result of a monadic operation is oftentimes not sufficient, --- because we may need a hypothesis for equational reasoning in the scope. We --- rely on subtype, and a custom let-binding operator, in effect recreating our --- own variant of the do-dsl - -def Result.attach {α: Type} (o : Result α): Result { x : α // o = ok x } := - match o with - | ok x => ok ⟨x, rfl⟩ - | fail e => fail e - | div => div - @[simp] theorem bind_tc_ok (x : α) (f : α → Result β) : (do let y ← .ok x; f y) = f x := by simp [Bind.bind, bind] @@ -133,9 +135,50 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ok x } := simp [Bind.bind] cases e <;> simp ----------- --- MISC -- ----------- +/-! +# Lift +-/ + +/-- We use this to lift pure function calls to monadic calls. + We don't mark this as reducible so that let-bindings don't get simplified away. + + In the generated code it regularly happens that we want to lift pure function calls so + that `progress` can reason about them. For instance, `U32.wrapping_add` has type `U32 → U32 → U32`, + but we provide a `progress` theorem with an informative post-condition, which matches the pattern + `toResult (wrapping_add x y)`. This theorem can only be looked up and applied if the code is of the + following shape: + ``` + let z ← U32.wrapping_add x y + ... + ``` + -/ +def toResult {α : Type u} (x : α) : Result α := Result.ok x + +instance {α : Type u} : Coe α (Result α) where + coe := toResult + +attribute [coe] toResult + +namespace Test + /- Testing that our coercion from `α` to `Result α` works. -/ + example : Result Int := do + let x0 ← ↑(0 : Int) + let x1 ← ↑(x0 + 1 : Int) + x1 + + /- Testing that our coercion from `α` to `Result α` doesn't break other coercions. -/ + example (n : Nat) (i : Int) (_ : n < i) : True := by simp + + example : Result (BitVec 32) := do + let x : BitVec 32 ← ↑(0#32) + let y ← ↑(1#32) + let z ← ↑(x + y) + ok z +end Test + +/-! +# Misc +-/ instance SubtypeBEq [BEq α] (p : α → Prop) : BEq (Subtype p) where beq v0 v1 := v0.val == v1.val @@ -144,9 +187,17 @@ instance SubtypeLawfulBEq [BEq α] (p : α → Prop) [LawfulBEq α] : LawfulBEq eq_of_beq {a b} h := by cases a; cases b; simp_all [BEq.beq] rfl := by intro a; cases a; simp [BEq.beq] ------------------------------- ----- Misc Primitives Types --- ------------------------------- +/- A helper function that converts failure to none and success to some + TODO: move up to Core module? -/ +def Option.ofResult {a : Type u} (x : Result a) : + Option a := + match x with + | ok x => some x + | _ => none + +/-!
+# Misc Primitive Types +-/ -- We don't really use raw pointers for now structure MutRawPtr (T : Type) where diff --git a/backends/lean/Aeneas/Std/Scalar.lean b/backends/lean/Aeneas/Std/Scalar.lean index 6031c91c..cf06ea39 100644 --- a/backends/lean/Aeneas/Std/Scalar.lean +++ b/backends/lean/Aeneas/Std/Scalar.lean @@ -1,190 +1,633 @@ import Aeneas.Std.ScalarCore import Aeneas.ScalarTac +import Aeneas.Arith.Lemmas +import MathLib.Data.BitVec namespace Aeneas namespace Std open Result Error +open Arith -@[simp] theorem Scalar.unsigned_neq_zero_equiv (x : Scalar ty) (h : ¬ ty.isSigned := by decide): x.val ≠ 0 ↔ 0 < x.val := by - cases ty <;> simp_all <;> scalar_tac +/-! +# Misc Theorems +-/ -def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val) +@[simp] theorem UScalar.exists_eq_left {p : UScalar ty → Prop} {a' : UScalar ty} : + (∃ (a : UScalar ty), a.val = a'.val ∧ p a) ↔ p a' := by + constructor <;> intro h + . replace ⟨ a, h, hp ⟩ := h + cases a' + simp_all only [val] + have := @BitVec.toNat_injective ty.numBits + have := this h + simp [← this] + apply hp + . exists a' --- Our custom remainder operation, which satisfies the semantics of Rust --- TODO: is there a better way? -def scalar_rem (x y : Int) : Int := - if 0 ≤ x then x % y - else - (|x| % |y|) +@[simp] theorem IScalar.exists_eq_left {p : IScalar ty → Prop} {a' : IScalar ty} : + (∃ (a : IScalar ty), a.val = a'.val ∧ p a) ↔ p a' := by + constructor <;> intro h + . replace ⟨ a, h, hp ⟩ := h + cases a' + simp_all only [val, eq_comm] + rw [BitVec.toInt_inj] at h + simp [h] + apply hp + . exists a' -@[simp] -def scalar_rem_nonneg {x y : Int} (hx : 0 ≤ x) : scalar_rem x y = x % y := by - simp [*, scalar_rem] +@[simp] theorem UScalar.exists_eq_left' {p : UScalar ty → Prop} {a' : UScalar ty} : + (∃ (a : UScalar ty), a'.val = a.val ∧ p a) ↔ p a' := by + constructor <;> intro h + . replace ⟨ a, h, hp ⟩ := h + cases a' + simp_all only [val] + have := @BitVec.toNat_injective ty.numBits + have := this h + simp [this] + apply hp + . exists a' --- Our custom division operation, which satisfies the semantics of Rust --- TODO: is there a better way? -def scalar_div (x y : Int) : Int := - if 0 ≤ x && 0 ≤ y then x / y - else if 0 ≤ x && y < 0 then - (|x| / |y|) - else if x < 0 && 0 ≤ y then - (|x| / |y|) - else |x| / |y| +@[simp] theorem IScalar.exists_eq_left' {p : IScalar ty → Prop} {a' : IScalar ty} : + (∃ (a : IScalar ty), a'.val = a.val ∧ p a) ↔ p a' := by + constructor <;> intro h + . replace ⟨ a, h, hp ⟩ := h + cases a' + simp_all only [val] + rw [BitVec.toInt_inj] at h + simp [h] + apply hp + . 
exists a' -@[simp] -def scalar_div_nonneg {x y : Int} (hx : 0 ≤ x) (hy : 0 ≤ y) : scalar_div x y = x / y := by - simp [*, scalar_div] - --- Checking that the remainder operation is correct -#assert scalar_rem 1 2 = 1 -#assert scalar_rem (-1) 2 = -1 -#assert scalar_rem 1 (-2) = 1 -#assert scalar_rem (-1) (-2) = -1 -#assert scalar_rem 7 3 = (1:Int) -#assert scalar_rem (-7) 3 = -1 -#assert scalar_rem 7 (-3) = 1 -#assert scalar_rem (-7) (-3) = -1 - --- Checking that the division operation is correct -#assert scalar_div 3 2 = 1 -#assert scalar_div (-3) 2 = -1 -#assert scalar_div 3 (-2) = -1 -#assert scalar_div (-3) (-2) = 1 -#assert scalar_div 7 3 = 2 -#assert scalar_div (-7) 3 = -2 -#assert scalar_div 7 (-3) = -2 -#assert scalar_div (-7) (-3) = 2 - -def Scalar.div {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - if y.val != 0 then Scalar.tryMk ty (scalar_div x.val y.val) else fail divisionByZero - -def Scalar.rem {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - if y.val != 0 then Scalar.tryMk ty (scalar_rem x.val y.val) else fail divisionByZero - -def Scalar.add {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - Scalar.tryMk ty (x.val + y.val) - -def Scalar.sub {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - Scalar.tryMk ty (x.val - y.val) - -def Scalar.mul {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := - Scalar.tryMk ty (x.val * y.val) - --- TODO: shift left -def Scalar.shiftl {ty0 ty1 : ScalarTy} (x : Scalar ty0) (y : Scalar ty1) : Result (Scalar ty0) := - sorry - --- TODO: shift right -def Scalar.shiftr {ty0 ty1 : ScalarTy} (x : Scalar ty0) (y : Scalar ty1) : Result (Scalar ty0) := - sorry - --- TODO: xor -def Scalar.xor {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Scalar ty := - sorry - --- TODO: and -def Scalar.and {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Scalar ty := - sorry - --- TODO: or -def Scalar.or {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Scalar ty := - sorry - -/- ¬ x reverses the bits of x - - It has the following effect: - - if x is unsigned, then it evaluates to Scalar.max - x - - otherwise, it evalutes to -1 - x --/ -def Scalar.not {ty : ScalarTy} (x : Scalar ty) : Scalar ty := - match ty with - -- Unsigned cases - | .U8 => @Scalar.mk ScalarTy.U8 (U8.max - x.val) (by scalar_tac) (by scalar_tac) - | .U16 => @Scalar.mk ScalarTy.U16 (U16.max - x.val) (by scalar_tac) (by scalar_tac) - | .U32 => @Scalar.mk ScalarTy.U32 (U32.max - x.val) (by scalar_tac) (by scalar_tac) - | .U64 => @Scalar.mk ScalarTy.U64 (U64.max - x.val) (by scalar_tac) (by scalar_tac) - | .U128 => @Scalar.mk ScalarTy.U128 (U128.max - x.val) (by scalar_tac) (by scalar_tac) - | .Usize => @Scalar.mk ScalarTy.Usize (Usize.max - x.val) (by scalar_tac) (by scalar_tac) - -- Signed cases - | .I8 => @Scalar.mk ScalarTy.I8 (-1 - x.val) (by scalar_tac) (by scalar_tac) - | .I16 => @Scalar.mk ScalarTy.I16 (-1 - x.val) (by scalar_tac) (by scalar_tac) - | .I32 => @Scalar.mk ScalarTy.I32 (-1 - x.val) (by scalar_tac) (by scalar_tac) - | .I64 => @Scalar.mk ScalarTy.I64 (-1 - x.val) (by scalar_tac) (by scalar_tac) - | .I128 => @Scalar.mk ScalarTy.I128 (-1 - x.val) (by scalar_tac) (by scalar_tac) - | .Isize => @Scalar.mk ScalarTy.Isize (-1 - x.val) - (by have := Isize.bounds_eq; scalar_tac) - (by have := Isize.bounds_eq; scalar_tac) - --- Cast an integer from a [src_ty] to a [tgt_ty] --- TODO: double-check the semantics of casts in Rust -def Scalar.cast {src_ty : ScalarTy} (tgt_ty : 
ScalarTy) (x : Scalar src_ty) : Result (Scalar tgt_ty) := - Scalar.tryMk tgt_ty x.val - --- This can't fail, but for now we make all casts faillible (easier for the translation) -def Scalar.cast_bool (tgt_ty : ScalarTy) (x : Bool) : Result (Scalar tgt_ty) := - Scalar.tryMk tgt_ty (if x then 1 else 0) - -@[pspec] -theorem Scalar.cast_in_bounds_eq {src_ty tgt_ty : ScalarTy} (x : Scalar src_ty) (h_bounds: Scalar.in_bounds tgt_ty x): ∃ x', Scalar.cast tgt_ty x = .ok x' ∧ x'.val = x.val := by - simp at h_bounds - simp [cast, tryMk, tryMkOpt] - split_ifs with h_nbounds - . use (Scalar.ofIntCore x h_bounds); simp [ofOption, ofIntCore] - . omega - -@[simp] theorem Scalar.exists_eq_left {p : Scalar ty → Prop} {a' : Scalar ty} : - (∃ (a : Scalar ty), a.val = a'.val ∧ p a) ↔ p a' := by +@[simp] theorem UScalar.exists_eq_right {p : UScalar ty → Prop} {a' : UScalar ty} : + (∃ (a : UScalar ty), p a ∧ a.val = a'.val) ↔ p a' := by constructor <;> intro h - . cases h + . replace ⟨ a, hp, h ⟩ := h cases a' - simp_all [eq_comm] + simp_all only [val] + have := @BitVec.toNat_injective ty.numBits + have := this h + simp [← this] + apply hp . exists a' -@[simp] theorem Scalar.exists_eq_left' {p : Scalar ty → Prop} {a' : Scalar ty} : - (∃ (a : Scalar ty), a'.val = a.val ∧ p a) ↔ p a' := by +@[simp] theorem IScalar.exists_eq_right {p : IScalar ty → Prop} {a' : IScalar ty} : + (∃ (a : IScalar ty), p a ∧ a.val = a'.val) ↔ p a' := by constructor <;> intro h - . cases h + . replace ⟨ a, hp, h ⟩ := h cases a' - simp_all [eq_comm] + simp_all only [val, eq_comm] + rw [BitVec.toInt_inj] at h + simp [h] + apply hp . exists a' -@[simp] theorem Scalar.exists_eq_right {p : Scalar ty → Prop} {a' : Scalar ty} : - (∃ (a : Scalar ty), p a ∧ a.val = a'.val) ↔ p a' := by +@[simp] theorem UScalar.exists_eq_right' {p : UScalar ty → Prop} {a' : UScalar ty} : + (∃ (a : UScalar ty), p a ∧ a'.val = a.val) ↔ p a' := by constructor <;> intro h - . cases h + . replace ⟨ a, hp, h ⟩ := h cases a' - simp_all [eq_comm] + simp_all only [val] + have := @BitVec.toNat_injective ty.numBits + have := this h + simp [this] + apply hp . exists a' -@[simp] theorem Scalar.exists_eq_right' {p : Scalar ty → Prop} {a' : Scalar ty} : - (∃ (a : Scalar ty), p a ∧ a'.val = a.val) ↔ p a' := by +@[simp] theorem IScalar.exists_eq_right' {p : IScalar ty → Prop} {a' : IScalar ty} : + (∃ (a : IScalar ty), p a ∧ a'.val = a.val) ↔ p a' := by constructor <;> intro h - . cases h + . replace ⟨ a, hp, h ⟩ := h cases a' - simp_all [eq_comm] + simp_all only [val, eq_comm] + rw [BitVec.toInt_inj] at h + simp [h] + apply hp . exists a' -@[simp] theorem Scalar.exists_eq {a' : Scalar ty} : ∃ (a : Scalar ty), a.val = a'.val := by exists a' -@[simp] theorem Scalar.exists_eq' {a' : Scalar ty} : ∃ (a : Scalar ty), a'.val = a.val := by exists a' +@[simp] theorem UScalar.exists_eq {a' : UScalar ty} : ∃ (a : UScalar ty), a.val = a'.val := by exists a' +@[simp] theorem UScalar.exists_eq' {a' : UScalar ty} : ∃ (a : UScalar ty), a'.val = a.val := by exists a' +@[simp] theorem IScalar.exists_eq {a' : IScalar ty} : ∃ (a : IScalar ty), a.val = a'.val := by exists a' +@[simp] theorem IScalar.exists_eq' {a' : IScalar ty} : ∃ (a : IScalar ty), a'.val = a.val := by exists a' -@[pspec] -theorem Scalar.cast_bool_spec ty (b : Bool) : - ∃ s, Scalar.cast_bool ty b = ok s ∧ s.val = if b then 1 else 0 := by - simp [Scalar.cast_bool, tryMk, tryMkOpt] - split <;> split <;> simp_all <;> scalar_tac +/-! +# Equalities and simplification lemmas +-/ --- TODO: below: not sure this is the best way. 
--- Should we rather overload operations like +, -, etc.? --- Also, it is possible to automate the generation of those definitions --- with macros (but would it be a good idea? It would be less easy to --- read the file, which is not supposed to change a lot) +theorem UScalar.ofNatCore_bv_lt_equiv {ty} (x y : Nat) (hx) (hy) : + (@UScalar.ofNatCore ty x hx).bv < (@UScalar.ofNatCore ty y hy).bv ↔ x < y := by + simp only [ofNatCore, BitVec.ofNat_lt_ofNat] + have := Nat.mod_eq_of_lt hx + have := Nat.mod_eq_of_lt hy + simp only [*] --- Negation +@[simp, scalar_tac_simp] theorem U8.val_mod_size_eq (x : U8) : x.val % U8.size = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac -/-- -Remark: there is no heterogeneous negation in the Lean prelude: we thus introduce -one here. +@[simp, scalar_tac_simp] theorem U8.val_mod_size_eq' (x : U8) : x.val % 256 = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U16.val_mod_size_eq (x : U16) : x.val % U16.size = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U16.val_mod_size_eq' (x : U16) : x.val % 65536 = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U32.val_mod_size_eq (x : U32) : x.val % U32.size = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U32.val_mod_size_eq' (x : U32) : x.val % 4294967296 = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U64.val_mod_size_eq (x : U64) : x.val % U64.size = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U64.val_mod_size_eq' (x : U64) : x.val % 18446744073709551616 = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U128.val_mod_size_eq (x : U128) : x.val % U128.size = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U128.val_mod_size_eq' (x : U128) : x.val % 340282366920938463463374607431768211456 = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem Usize.val_mod_size_eq (x : Usize) : x.val % Usize.size = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U8.val_mod_max_eq (x : U8) : x.val % (U8.max + 1) = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U16.val_mod_max_eq (x : U16) : x.val % (U16.max + 1) = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U32.val_mod_max_eq (x : U32) : x.val % (U32.max + 1) = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U64.val_mod_max_eq (x : U64) : x.val % (U64.max + 1) = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem U128.val_mod_max_eq (x : U128) : x.val % (U128.max + 1) = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem Usize.val_mod_max_eq (x : Usize) : x.val % (Usize.max + 1) = x.val := by + apply Nat.mod_eq_of_lt; scalar_tac + +@[simp, scalar_tac_simp] theorem I8.val_mod_size_eq (x : I8) : Int.bmod x.val I8.size = x.val := by + simp [size]; apply Int.bmod_pow2_eq_of_inBounds' <;> scalar_tac + +@[simp, scalar_tac_simp] theorem I8.val_mod_size_eq' (x : I8) : Int.bmod x.val 256 = x.val := by + have := val_mod_size_eq x; simp [size, numBits] at this; assumption + +@[simp, scalar_tac_simp] theorem I16.val_mod_size_eq (x : I16) : Int.bmod x.val I16.size = x.val := by + simp [size]; apply Int.bmod_pow2_eq_of_inBounds' <;> scalar_tac + +@[simp, 
scalar_tac_simp] theorem I16.val_mod_size_eq' (x : I16) : Int.bmod x.val 65536 = x.val := by + have := val_mod_size_eq x; simp [size, numBits] at this; assumption + +@[simp, scalar_tac_simp] theorem I32.val_mod_size_eq (x : I32) : Int.bmod x.val I32.size = x.val := by + simp [size]; apply Int.bmod_pow2_eq_of_inBounds' <;> scalar_tac + +@[simp, scalar_tac_simp] theorem I32.val_mod_size_eq' (x : I32) : Int.bmod x.val 4294967296 = x.val := by + have := val_mod_size_eq x; simp [size, numBits] at this; assumption + +@[simp, scalar_tac_simp] theorem I64.val_mod_size_eq (x : I64) : Int.bmod x.val I64.size = x.val := by + simp [size]; apply Int.bmod_pow2_eq_of_inBounds' <;> scalar_tac + +@[simp, scalar_tac_simp] theorem I64.val_mod_size_eq' (x : I64) : Int.bmod x.val 18446744073709551616 = x.val := by + have := val_mod_size_eq x; simp [size, numBits] at this; assumption + +@[simp, scalar_tac_simp] theorem I128.val_mod_size_eq (x : I128) : Int.bmod x.val I128.size = x.val := by + simp [size]; apply Int.bmod_pow2_eq_of_inBounds' <;> scalar_tac + +@[simp, scalar_tac_simp] theorem I128.val_mod_size_eq' (x : I128) : Int.bmod x.val 340282366920938463463374607431768211456 = x.val := by + have := val_mod_size_eq x; simp [size, numBits] at this; assumption + +@[simp, scalar_tac_simp] theorem Isize.val_mod_size_eq (x : Isize) : Int.bmod x.val Isize.size = x.val := by + simp [size]; apply Int.bmod_pow2_eq_of_inBounds' <;> try scalar_tac + simp [numBits]; dcases System.Platform.numBits_eq <;> simp [*] + +@[simp] theorem U8.val_max_zero_eq (x : U8) : x.val ⊔ 0 = x.val := by scalar_tac +@[simp] theorem U16.val_max_zero_eq (x : U16) : x.val ⊔ 0 = x.val := by scalar_tac +@[simp] theorem U32.val_max_zero_eq (x : U32) : x.val ⊔ 0 = x.val := by scalar_tac +@[simp] theorem U64.val_max_zero_eq (x : U64) : x.val ⊔ 0 = x.val := by scalar_tac +@[simp] theorem U128.val_max_zero_eq (x : U128) : x.val ⊔ 0 = x.val := by scalar_tac +@[simp] theorem Usize.val_max_zero_eq (x : Usize) : x.val ⊔ 0 = x.val := by scalar_tac + +/-! + +# Primitive Operations +## Primitive Operations: Definitions + +-/ + +/-! +The scalar division/modulo on signed machine integers 't'runcates towards 0, meaning it is +implemented by the `Int.tdiv`, `Int.tmod`, etc. definitions. 
+-/ + +namespace Tests + -- Checking that the division over signed integers agrees with Rust + #assert Int.tdiv 3 2 = 1 + #assert Int.tdiv (-3) 2 = -1 + #assert Int.tdiv 3 (-2) = -1 + #assert Int.tdiv (-3) (-2) = 1 + #assert Int.tdiv 7 3 = 2 + #assert Int.tdiv (-7) 3 = -2 + #assert Int.tdiv 7 (-3) = -2 + #assert Int.tdiv (-7) (-3) = 2 + + -- Checking that the signed division over bit-vectors agrees with Rust + private def bv_sdiv (x y : Int) : Int := + (BitVec.sdiv (BitVec.ofInt 32 x) (BitVec.ofInt 32 y)).toInt + + #assert bv_sdiv 3 2 = 1 + #assert bv_sdiv (-3) 2 = -1 + #assert bv_sdiv 3 (-2) = -1 + #assert bv_sdiv (-3) (-2) = 1 + #assert bv_sdiv 7 3 = 2 + #assert bv_sdiv (-7) 3 = -2 + #assert bv_sdiv 7 (-3) = -2 + #assert bv_sdiv (-7) (-3) = 2 + + -- Checking that the remainder over signed integers agrees with Rust + #assert Int.tmod 1 2 = 1 + #assert Int.tmod (-1) 2 = -1 + #assert Int.tmod 1 (-2) = 1 + #assert Int.tmod (-1) (-2) = -1 + #assert Int.tmod 7 3 = (1:Int) + #assert Int.tmod (-7) 3 = -1 + #assert Int.tmod 7 (-3) = 1 + #assert Int.tmod (-7) (-3) = -1 + + -- Checking that the signed operation over bit-vectors agrees with Rust + private def bv_srem (x y : Int) : Int := + (BitVec.srem (BitVec.ofInt 32 x) (BitVec.ofInt 32 y)).toInt + + #assert bv_srem 1 2 = 1 + #assert bv_srem (-1) 2 = -1 + #assert bv_srem 1 (-2) = 1 + #assert bv_srem (-1) (-2) = -1 + #assert bv_srem 7 3 = (1:Int) + #assert bv_srem (-7) 3 = -1 + #assert bv_srem 7 (-3) = 1 + #assert bv_srem (-7) (-3) = -1 +end Tests + +/-! +Addition +-/ +def UScalar.add {ty : UScalarTy} (x y : UScalar ty) : Result (UScalar ty) := + UScalar.tryMk ty (x.val + y.val) + +def IScalar.add {ty : IScalarTy} (x y : IScalar ty) : Result (IScalar ty) := + IScalar.tryMk ty (x.val + y.val) + +def UScalar.try_add {ty : UScalarTy} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (add x y) + +def IScalar.try_add {ty : IScalarTy} (x y : IScalar ty) : Option (IScalar ty) := + Option.ofResult (add x y) + +instance {ty} : HAdd (UScalar ty) (UScalar ty) (Result (UScalar ty)) where + hAdd x y := UScalar.add x y + +instance {ty} : HAdd (IScalar ty) (IScalar ty) (Result (IScalar ty)) where + hAdd x y := IScalar.add x y + +/-! +Subtraction +-/ +def UScalar.sub {ty : UScalarTy} (x y : UScalar ty) : Result (UScalar ty) := + if x.val < y.val then fail .integerOverflow + else ok ⟨ BitVec.ofNat _ (x.val - y.val) ⟩ + +def IScalar.sub {ty : IScalarTy} (x y : IScalar ty) : Result (IScalar ty) := + IScalar.tryMk ty (x.val - y.val) + +def UScalar.try_sub {ty : UScalarTy} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (sub x y) + +def IScalar.try_sub {ty : IScalarTy} (x y : IScalar ty) : Option (IScalar ty) := + Option.ofResult (sub x y) + +instance {ty} : HSub (UScalar ty) (UScalar ty) (Result (UScalar ty)) where + hSub x y := UScalar.sub x y + +instance {ty} : HSub (IScalar ty) (IScalar ty) (Result (IScalar ty)) where + hSub x y := IScalar.sub x y + +/-! 
+Multiplication +-/ +def UScalar.mul {ty : UScalarTy} (x y : UScalar ty) : Result (UScalar ty) := + UScalar.tryMk ty (x.val * y.val) + +def IScalar.mul {ty : IScalarTy} (x y : IScalar ty) : Result (IScalar ty) := + IScalar.tryMk ty (x.val * y.val) + +def UScalar.try_mul {ty : UScalarTy} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (mul x y) + +def IScalar.try_mul {ty : IScalarTy} (x y : IScalar ty) : Option (IScalar ty) := + Option.ofResult (mul x y) + +instance {ty} : HMul (UScalar ty) (UScalar ty) (Result (UScalar ty)) where + hMul x y := UScalar.mul x y + +instance {ty} : HMul (IScalar ty) (IScalar ty) (Result (IScalar ty)) where + hMul x y := IScalar.mul x y + +/-! +Division +-/ + +def UScalar.div {ty : UScalarTy} (x y : UScalar ty) : Result (UScalar ty) := + if y.bv != 0 then ok ⟨ BitVec.udiv x.bv y.bv ⟩ else fail divisionByZero + +def IScalar.div {ty : IScalarTy} (x y : IScalar ty): Result (IScalar ty) := + if y.val != 0 then + -- There can be an overflow if `x` is equal to the lower bound and `y` to `-1` + if ¬ (x.val = IScalar.min ty && y.val = -1) then ok ⟨ BitVec.sdiv x.bv y.bv ⟩ + else fail integerOverflow + else fail divisionByZero + +def UScalar.try_div {ty : UScalarTy} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (div x y) + +def IScalar.try_div {ty : IScalarTy} (x y : IScalar ty): Option (IScalar ty) := + Option.ofResult (div x y) + +instance {ty} : HDiv (UScalar ty) (UScalar ty) (Result (UScalar ty)) where + hDiv x y := UScalar.div x y + +instance {ty} : HDiv (IScalar ty) (IScalar ty) (Result (IScalar ty)) where + hDiv x y := IScalar.div x y + +/-! +Remainder +-/ +def UScalar.rem {ty : UScalarTy} (x y : UScalar ty) : Result (UScalar ty) := + if y.val != 0 then ok ⟨ BitVec.umod x.bv y.bv ⟩ else fail divisionByZero + +def IScalar.rem {ty : IScalarTy} (x y : IScalar ty) : Result (IScalar ty) := + if y.val != 0 then ok ⟨ BitVec.srem x.bv y.bv ⟩ + else fail divisionByZero + +def UScalar.try_rem {ty : UScalarTy} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (rem x y) + +def IScalar.try_rem {ty : IScalarTy} (x y : IScalar ty) : Option (IScalar ty) := + Option.ofResult (rem x y) + +instance {ty} : HMod (UScalar ty) (UScalar ty) (Result (UScalar ty)) where + hMod x y := UScalar.rem x y + +instance {ty} : HMod (IScalar ty) (IScalar ty) (Result (IScalar ty)) where + hMod x y := IScalar.rem x y + +/-! 
+Bit shifts +-/ +def UScalar.shiftLeft {ty : UScalarTy} (x : UScalar ty) (s : Nat) : + Result (UScalar ty) := + if s < ty.numBits then + ok ⟨ x.bv.shiftLeft s ⟩ + else fail .integerOverflow + +def UScalar.shiftRight {ty : UScalarTy} (x : UScalar ty) (s : Nat) : + Result (UScalar ty) := + if s < ty.numBits then + ok ⟨ x.bv.ushiftRight s ⟩ + else fail .integerOverflow + +def UScalar.shiftLeft_UScalar {ty tys} (x : UScalar ty) (s : UScalar tys) : + Result (UScalar ty) := + x.shiftLeft s.val + +def UScalar.shiftRight_UScalar {ty tys} (x : UScalar ty) (s : UScalar tys) : + Result (UScalar ty) := + x.shiftRight s.val + +def UScalar.shiftLeft_IScalar {ty tys} (x : UScalar ty) (s : IScalar tys) : + Result (UScalar ty) := + x.shiftLeft s.toNat + +def UScalar.shiftRight_IScalar {ty tys} (x : UScalar ty) (s : IScalar tys) : + Result (UScalar ty) := + x.shiftRight s.toNat + +def IScalar.shiftLeft {ty : IScalarTy} (x : IScalar ty) (s : Nat) : + Result (IScalar ty) := + if s < ty.numBits then + ok ⟨ x.bv.shiftLeft s ⟩ + else fail .integerOverflow + +def IScalar.shiftRight {ty : IScalarTy} (x : IScalar ty) (s : Nat) : + Result (IScalar ty) := + if s < ty.numBits then + ok ⟨ x.bv.sshiftRight s ⟩ + else fail .integerOverflow + +def IScalar.shiftLeft_UScalar {ty tys} (x : IScalar ty) (s : UScalar tys) : + Result (IScalar ty) := + x.shiftLeft s.val + +def IScalar.shiftRight_UScalar {ty tys} (x : IScalar ty) (s : UScalar tys) : + Result (IScalar ty) := + x.shiftRight s.val + +def IScalar.shiftLeft_IScalar {ty tys} (x : IScalar ty) (s : IScalar tys) : + Result (IScalar ty) := + if s.val ≥ 0 then + x.shiftLeft s.toNat + else fail .integerOverflow + +def IScalar.shiftRight_IScalar {ty tys} (x : IScalar ty) (s : IScalar tys) : + Result (IScalar ty) := + if s.val ≥ 0 then + x.shiftRight s.toNat + else fail .integerOverflow + +instance {ty0 ty1} : HShiftLeft (UScalar ty0) (UScalar ty1) (Result (UScalar ty0)) where + hShiftLeft x y := UScalar.shiftLeft_UScalar x y + +instance {ty0 ty1} : HShiftLeft (UScalar ty0) (IScalar ty1) (Result (UScalar ty0)) where + hShiftLeft x y := UScalar.shiftLeft_IScalar x y + +instance {ty0 ty1} : HShiftLeft (IScalar ty0) (UScalar ty1) (Result (IScalar ty0)) where + hShiftLeft x y := IScalar.shiftLeft_UScalar x y + +instance {ty0 ty1} : HShiftLeft (IScalar ty0) (IScalar ty1) (Result (IScalar ty0)) where + hShiftLeft x y := IScalar.shiftLeft_IScalar x y + +instance {ty0 ty1} : HShiftRight (UScalar ty0) (UScalar ty1) (Result (UScalar ty0)) where + hShiftRight x y := UScalar.shiftRight_UScalar x y + +instance {ty0 ty1} : HShiftRight (UScalar ty0) (IScalar ty1) (Result (UScalar ty0)) where + hShiftRight x y := UScalar.shiftRight_IScalar x y + +instance {ty0 ty1} : HShiftRight (IScalar ty0) (UScalar ty1) (Result (IScalar ty0)) where + hShiftRight x y := IScalar.shiftRight_UScalar x y + +instance {ty0 ty1} : HShiftRight (IScalar ty0) (IScalar ty1) (Result (IScalar ty0)) where + hShiftRight x y := IScalar.shiftRight_IScalar x y + +/-! +Bitwise and +-/ +def UScalar.and {ty} (x y : UScalar ty) : UScalar ty := ⟨ x.bv &&& y.bv ⟩ + +def IScalar.and {ty} (x y : IScalar ty) : IScalar ty := ⟨ x.bv &&& y.bv ⟩ + +instance {ty} : HAnd (UScalar ty) (UScalar ty) (UScalar ty) where + hAnd x y := UScalar.and x y + +instance {ty} : HAnd (IScalar ty) (IScalar ty) (IScalar ty) where + hAnd x y := IScalar.and x y + +/-! 
+Bitwise or
+-/
+def UScalar.or {ty} (x y : UScalar ty) : UScalar ty := ⟨ x.bv ||| y.bv ⟩
+
+def IScalar.or {ty} (x y : IScalar ty) : IScalar ty := ⟨ x.bv ||| y.bv ⟩
+
+instance {ty} : HOr (UScalar ty) (UScalar ty) (UScalar ty) where
+  hOr x y := UScalar.or x y
+
+instance {ty} : HOr (IScalar ty) (IScalar ty) (IScalar ty) where
+  hOr x y := IScalar.or x y
+
+/-!
+Xor
+-/
+def UScalar.xor {ty} (x y : UScalar ty) : UScalar ty := ⟨ x.bv ^^^ y.bv ⟩
+
+def IScalar.xor {ty} (x y : IScalar ty) : IScalar ty := ⟨ x.bv ^^^ y.bv ⟩
+
+instance {ty} : HXor (UScalar ty) (UScalar ty) (UScalar ty) where
+  hXor x y := UScalar.xor x y
+
+instance {ty} : HXor (IScalar ty) (IScalar ty) (IScalar ty) where
+  hXor x y := IScalar.xor x y
+
+/-!
+Not
+-/
+def UScalar.not {ty} (x : UScalar ty) : UScalar ty := ⟨ ~~~x.bv ⟩
+
+def IScalar.not {ty} (x : IScalar ty) : IScalar ty := ⟨ ~~~x.bv ⟩
+
+instance {ty} : Complement (UScalar ty) where
+  complement x := UScalar.not x
+
+instance {ty} : Complement (IScalar ty) where
+  complement x := IScalar.not x
+
+/-!
+Casts
+
+The reference semantics are here: https://doc.rust-lang.org/reference/expressions/operator-expr.html#semantics
+-/
+
+/-- When casting between unsigned integers, we truncate or **zero**-extend the integer. -/
+@[progress_pure_def]
+def UScalar.cast {src_ty : UScalarTy} (tgt_ty : UScalarTy) (x : UScalar src_ty) : UScalar tgt_ty :=
+  -- This truncates the integer if the target `numBits` is smaller
+  ⟨ x.bv.zeroExtend tgt_ty.numBits ⟩
+
+/- Heterogeneous cast
+
+   When casting from an unsigned integer to a signed integer, we truncate or **zero**-extend.
+-/
+@[progress_pure_def]
+def UScalar.hcast {src_ty : UScalarTy} (tgt_ty : IScalarTy) (x : UScalar src_ty) : IScalar tgt_ty :=
+  -- This truncates the integer if the target `numBits` is smaller
+  ⟨ x.bv.zeroExtend tgt_ty.numBits ⟩
+
+/-- When casting between signed integers, we truncate or **sign**-extend. -/
+@[progress_pure_def]
+def IScalar.cast {src_ty : IScalarTy} (tgt_ty : IScalarTy) (x : IScalar src_ty) : IScalar tgt_ty :=
+  ⟨ x.bv.signExtend tgt_ty.numBits ⟩
+
+/- Heterogeneous cast
+
+   When casting from a signed integer to an unsigned integer, we truncate or **sign**-extend.
+-/
+@[progress_pure_def]
+def IScalar.hcast {src_ty : IScalarTy} (tgt_ty : UScalarTy) (x : IScalar src_ty) : UScalar tgt_ty :=
+  ⟨ x.bv.signExtend tgt_ty.numBits ⟩
+
+section
+  /-! Checking that the semantics of casts are correct by using the examples given by the Rust reference.
-/ + + private def check_cast_i_to_u (src : Int) (src_ty : IScalarTy) (tgt : Nat) (tgt_ty : UScalarTy) + (hSrc : IScalar.cMin src_ty ≤ src ∧ src ≤ IScalar.cMax src_ty := by decide) + (hTgt : tgt ≤ UScalar.cMax tgt_ty := by decide): Bool := + IScalar.hcast tgt_ty (@IScalar.ofInt src_ty src hSrc) = @UScalar.ofNat tgt_ty tgt hTgt + + private def check_cast_u_to_i (src : Nat) (src_ty : UScalarTy) (tgt : Int) (tgt_ty : IScalarTy) + (hSrc : src ≤ UScalar.cMax src_ty := by decide) + (hTgt : IScalar.cMin tgt_ty ≤ tgt ∧ tgt ≤ IScalar.cMax tgt_ty := by decide) : Bool := + UScalar.hcast tgt_ty (@UScalar.ofNat src_ty src hSrc) = @IScalar.ofInt tgt_ty tgt hTgt + + private def check_cast_u_to_u (src : Nat) (src_ty : UScalarTy) (tgt : Nat) (tgt_ty : UScalarTy) + (hSrc : src ≤ UScalar.cMax src_ty := by decide) + (hTgt : tgt ≤ UScalar.cMax tgt_ty := by decide) : Bool := + UScalar.cast tgt_ty (@UScalar.ofNat src_ty src hSrc) = @UScalar.ofNat tgt_ty tgt hTgt + + private def check_cast_i_to_i (src : Int) (src_ty : IScalarTy) (tgt : Int) (tgt_ty : IScalarTy) + (hSrc : IScalar.cMin src_ty ≤ src ∧ src ≤ IScalar.cMax src_ty := by decide) + (hTgt : IScalar.cMin tgt_ty ≤ tgt ∧ tgt ≤ IScalar.cMax tgt_ty := by decide) : Bool := + IScalar.cast tgt_ty (@IScalar.ofInt src_ty src hSrc) = @IScalar.ofInt tgt_ty tgt hTgt + + local macro:max x:term:max noWs "i8" : term => `(I8.ofInt $x (by decide)) + local macro:max x:term:max noWs "i16" : term => `(I16.ofInt $x (by decide)) + local macro:max x:term:max noWs "i32" : term => `(I32.ofInt $x (by decide)) + local macro:max x:term:max noWs "u8" : term => `(U8.ofNat $x (by decide)) + local macro:max x:term:max noWs "u16" : term => `(U16.ofNat $x (by decide)) + + /- Cast between integers of same size -/ + #assert IScalar.hcast _ 42i8 = 42u8 -- assert_eq!(42i8 as u8, 42u8); + #assert IScalar.hcast _ (-1)i8 = 255u8 -- assert_eq!(-1i8 as u8, 255u8); + #assert UScalar.hcast _ 255u8 = (-1)i8 -- assert_eq!(255u8 as i8, -1i8); + #assert IScalar.hcast _ (-1)i16 = 65535u16 -- assert_eq!(-1i16 as u16, 65535u16); + + /- Cast from larger integer to smaller integer -/ + #assert UScalar.cast _ 42u16 = 42u8 -- assert_eq!(42u16 as u8, 42u8); + #assert UScalar.cast _ 1234u16 = 210u8 -- assert_eq!(1234u16 as u8, 210u8); + #assert UScalar.cast _ 0xabcdu16 = 0xcdu8 -- assert_eq!(0xabcdu16 as u8, 0xcdu8); + + #assert IScalar.cast _ (-42)i16 = (-42)i8 -- assert_eq!(-42i16 as i8, -42i8); + #assert UScalar.hcast _ 1234u16 = (-46)i8 -- assert_eq!(1234u16 as i8, -46i8); + #assert IScalar.cast _ 0xabcdi32 = (-51)i8 -- assert_eq!(0xabcdi32 as i8, -51i8); + + /- Cast from a smaller integer to a larger integer -/ + #assert IScalar.cast _ 42i8 = 42i16 -- assert_eq!(42i8 as i16, 42i16); + #assert IScalar.cast _ (-17)i8 = (-17)i16 -- assert_eq!(-17i8 as i16, -17i16); + #assert UScalar.cast _ 0b1000_1010u8 = 0b0000_0000_1000_1010u16 -- assert_eq!(0b1000_1010u8 as u16, 0b0000_0000_1000_1010u16, "Zero-extend"); + #assert IScalar.cast _ 0b0000_1010i8 = 0b0000_0000_0000_1010i16 -- assert_eq!(0b0000_1010i8 as i16, 0b0000_0000_0000_1010i16, "Sign-extend 0"); + #assert (IScalar.cast .I16 (UScalar.hcast .I8 0b1000_1010u8)) = UScalar.hcast .I16 0b1111_1111_1000_1010u16 -- assert_eq!(0b1000_1010u8 as i8 as i16, 0b1111_1111_1000_1010u16 as i16, "Sign-extend 1"); + +end + +def UScalar.cast_fromBool (ty : UScalarTy) (x : Bool) : UScalar ty := + if x then ⟨ 1#ty.numBits ⟩ else ⟨ 0#ty.numBits ⟩ + +def IScalar.cast_fromBool (ty : IScalarTy) (x : Bool) : IScalar ty := + if x then ⟨ 1#ty.numBits ⟩ else ⟨ 0#ty.numBits ⟩ + +/-! 
+Negation +-/ +@[progress_pure_def] +def IScalar.neg {ty : IScalarTy} (x : IScalar ty) : Result (IScalar ty) := IScalar.tryMk ty (- x.val) + +/-- The notation typeclass for heterogeneous negation. + +There is no heterogenous negation in the Lean prelude: we thus introduce one here. -/ class HNeg (α : Type u) (β : outParam (Type v)) where /-- `- a` computes the negation of `a`. @@ -199,7 +642,7 @@ class HNeg (α : Type u) (β : outParam (Type v)) where like arrays of constants to take an unreasonable time to get elaborated and type-checked. - TODO: PR to replace Neg with HNeg in Lean? + TODO: PR to introduce HNeg in Lean? -/ prefix:75 "-." => HNeg.hNeg @@ -214,676 +657,2575 @@ prefix:75 "-." => HNeg.hNeg -/ attribute [match_pattern] HNeg.hNeg -instance : HNeg Isize (Result Isize) where hNeg x := Scalar.neg x -instance : HNeg I8 (Result I8) where hNeg x := Scalar.neg x -instance : HNeg I16 (Result I16) where hNeg x := Scalar.neg x -instance : HNeg I32 (Result I32) where hNeg x := Scalar.neg x -instance : HNeg I64 (Result I64) where hNeg x := Scalar.neg x -instance : HNeg I128 (Result I128) where hNeg x := Scalar.neg x +instance {ty} : HNeg (IScalar ty) (Result (IScalar ty)) where hNeg x := IScalar.neg x --- Addition -instance {ty} : HAdd (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hAdd x y := Scalar.add x y +/-! --- Substraction -instance {ty} : HSub (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hSub x y := Scalar.sub x y +## Primitive Operations: Theorems --- Multiplication -instance {ty} : HMul (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hMul x y := Scalar.mul x y +-/ --- Division -instance {ty} : HDiv (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hDiv x y := Scalar.div x y +/-- Important theorem to reason with `Int.bmod` in the proofs about `IScalar` -/ +private theorem bmod_pow_numBits_eq_of_lt (ty : IScalarTy) (x : Int) + (h0 : - 2 ^ (ty.numBits-1) ≤ x) (h1 : x < 2 ^ (ty.numBits -1)) : + Int.bmod x (2^ty.numBits) = x := by + have := ty.numBits_nonzero + have hEq : ty.numBits - 1 + 1 = ty.numBits := by omega + have := Int.bmod_pow2_eq_of_inBounds (ty.numBits-1) x (by omega) (by omega) + simp [hEq] at this + apply this + +/-! 
+### Add +-/ --- Remainder -instance {ty} : HMod (Scalar ty) (Scalar ty) (Result (Scalar ty)) where - hMod x y := Scalar.rem x y +theorem UScalar.add_equiv {ty} (x y : UScalar ty) : + match x + y with + | ok z => x.val + y.val < 2^ty.numBits ∧ + z.val = x.val + y.val ∧ + z.bv = x.bv + y.bv + | fail _ => ¬ (UScalar.inBounds ty (x.val + y.val)) + | _ => ⊥ := by + have : x + y = add x y := by rfl + rw [this] + simp [add] + have h := tryMk_eq ty (↑x + ↑y) + simp [inBounds] at h + split at h <;> simp_all + zify; simp + zify at h + have := @Int.emod_eq_of_lt (x.val + y.val) (2^ty.numBits) (by omega) (by omega) + simp [*] --- Shift left -instance {ty0 ty1} : HShiftLeft (Scalar ty0) (Scalar ty1) (Result (Scalar ty0)) where - hShiftLeft x y := Scalar.shiftl x y +theorem IScalar.add_equiv {ty} (x y : IScalar ty) : + match x + y with + | ok z => + IScalar.inBounds ty (x.val + y.val) ∧ + z.val = x.val + y.val ∧ + z.bv = x.bv + y.bv + | fail _ => ¬ (IScalar.inBounds ty (x.val + y.val)) + | _ => ⊥ := by + have : x + y = add x y := by rfl + rw [this] + simp [add] + have h := tryMk_eq ty (↑x + ↑y) + simp [inBounds] at h + split at h <;> simp_all + apply BitVec.eq_of_toInt_eq + simp + have := bmod_pow_numBits_eq_of_lt ty (x.val + y.val) (by omega) (by omega) + simp [*] --- Shift right -instance {ty0 ty1} : HShiftRight (Scalar ty0) (Scalar ty1) (Result (Scalar ty0)) where - hShiftRight x y := Scalar.shiftr x y +/-! +Theorems about the addition, with a specification which uses +integers and bit-vectors. +-/ --- Xor -instance {ty} : HXor (Scalar ty) (Scalar ty) (Scalar ty) where - hXor x y := Scalar.xor x y +/-- Generic theorem - shouldn't be used much -/ +theorem UScalar.add_bv_spec {ty} {x y : UScalar ty} + (hmax : ↑x + ↑y ≤ UScalar.max ty) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := by + have h := @add_equiv ty x y + split at h <;> simp_all [max] + have : 0 < 2^ty.numBits := by simp + omega + +/-- Generic theorem - shouldn't be used much -/ +theorem IScalar.add_bv_spec {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ ↑x + ↑y) + (hmax : ↑x + ↑y ≤ IScalar.max ty) : + ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := by + have h := @add_equiv ty x y + split at h <;> simp_all [min, max] + omega --- Or -instance {ty} : HOr (Scalar ty) (Scalar ty) (Scalar ty) where - hOr x y := Scalar.or x y +theorem Usize.add_bv_spec {x y : Usize} (hmax : x.val + y.val ≤ Usize.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + UScalar.add_bv_spec (by scalar_tac) --- And -instance {ty} : HAnd (Scalar ty) (Scalar ty) (Scalar ty) where - hAnd x y := Scalar.and x y +theorem U8.add_bv_spec {x y : U8} (hmax : x.val + y.val ≤ U8.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + UScalar.add_bv_spec (by scalar_tac) --- Not -instance {ty} : HNot (Scalar ty) where - hnot x := Scalar.not x +theorem U16.add_bv_spec {x y : U16} (hmax : x.val + y.val ≤ U16.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + UScalar.add_bv_spec (by scalar_tac) -example (x : Scalar ty) : Scalar ty := ¬ x +theorem U32.add_bv_spec {x y : U32} (hmax : x.val + y.val ≤ U32.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + UScalar.add_bv_spec (by scalar_tac) --- core checked arithmetic operations +theorem U64.add_bv_spec {x y : U64} (hmax : x.val + y.val ≤ U64.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + UScalar.add_bv_spec (by scalar_tac) -/- A helper function that converts failure to none and 
success to some - TODO: move up to Core module? -/ -def Option.ofResult {a : Type u} (x : Result a) : - Option a := - match x with - | ok x => some x - | _ => none +theorem U128.add_bv_spec {x y : U128} (hmax : x.val + y.val ≤ U128.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + UScalar.add_bv_spec (by scalar_tac) -/- [core::num::{T}::checked_add] -/ -def core.num.checked_add (x y : Scalar ty) : Option (Scalar ty) := - Option.ofResult (x + y) +theorem Isize.add_bv_spec {x y : Isize} + (hmin : Isize.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ Isize.max) : + ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + IScalar.add_bv_spec (by scalar_tac) (by scalar_tac) -def U8.checked_add (x y : U8) : Option U8 := core.num.checked_add x y -def U16.checked_add (x y : U16) : Option U16 := core.num.checked_add x y -def U32.checked_add (x y : U32) : Option U32 := core.num.checked_add x y -def U64.checked_add (x y : U64) : Option U64 := core.num.checked_add x y -def U128.checked_add (x y : U128) : Option U128 := core.num.checked_add x y -def Usize.checked_add (x y : Usize) : Option Usize := core.num.checked_add x y -def I8.checked_add (x y : I8) : Option I8 := core.num.checked_add x y -def I16.checked_add (x y : I16) : Option I16 := core.num.checked_add x y -def I32.checked_add (x y : I32) : Option I32 := core.num.checked_add x y -def I64.checked_add (x y : I64) : Option I64 := core.num.checked_add x y -def I128.checked_add (x y : I128) : Option I128 := core.num.checked_add x y -def Isize.checked_add (x y : Isize) : Option Isize := core.num.checked_add x y +theorem I8.add_bv_spec {x y : I8} + (hmin : I8.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I8.max) : + ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + IScalar.add_bv_spec (by scalar_tac) (by scalar_tac) -/- [core::num::{T}::checked_sub] -/ -def core.num.checked_sub (x y : Scalar ty) : Option (Scalar ty) := - Option.ofResult (x - y) +theorem I16.add_bv_spec {x y : I16} + (hmin : I16.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I16.max) : + ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + IScalar.add_bv_spec (by scalar_tac) (by scalar_tac) -def U8.checked_sub (x y : U8) : Option U8 := core.num.checked_sub x y -def U16.checked_sub (x y : U16) : Option U16 := core.num.checked_sub x y -def U32.checked_sub (x y : U32) : Option U32 := core.num.checked_sub x y -def U64.checked_sub (x y : U64) : Option U64 := core.num.checked_sub x y -def U128.checked_sub (x y : U128) : Option U128 := core.num.checked_sub x y -def Usize.checked_sub (x y : Usize) : Option Usize := core.num.checked_sub x y -def I8.checked_sub (x y : I8) : Option I8 := core.num.checked_sub x y -def I16.checked_sub (x y : I16) : Option I16 := core.num.checked_sub x y -def I32.checked_sub (x y : I32) : Option I32 := core.num.checked_sub x y -def I64.checked_sub (x y : I64) : Option I64 := core.num.checked_sub x y -def I128.checked_sub (x y : I128) : Option I128 := core.num.checked_sub x y -def Isize.checked_sub (x y : Isize) : Option Isize := core.num.checked_sub x y +theorem I32.add_bv_spec {x y : I32} + (hmin : I32.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I32.max) : + ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + IScalar.add_bv_spec (by scalar_tac) (by scalar_tac) -/- [core::num::{T}::checked_mul] -/ -def core.num.checked_mul (x y : Scalar ty) : Option (Scalar ty) := - Option.ofResult (x * y) - -def U8.checked_mul (x y : U8) : Option U8 := core.num.checked_mul x y -def U16.checked_mul (x y : U16) : Option U16 := core.num.checked_mul x y 
-def U32.checked_mul (x y : U32) : Option U32 := core.num.checked_mul x y -def U64.checked_mul (x y : U64) : Option U64 := core.num.checked_mul x y -def U128.checked_mul (x y : U128) : Option U128 := core.num.checked_mul x y -def Usize.checked_mul (x y : Usize) : Option Usize := core.num.checked_mul x y -def I8.checked_mul (x y : I8) : Option I8 := core.num.checked_mul x y -def I16.checked_mul (x y : I16) : Option I16 := core.num.checked_mul x y -def I32.checked_mul (x y : I32) : Option I32 := core.num.checked_mul x y -def I64.checked_mul (x y : I64) : Option I64 := core.num.checked_mul x y -def I128.checked_mul (x y : I128) : Option I128 := core.num.checked_mul x y -def Isize.checked_mul (x y : Isize) : Option Isize := core.num.checked_mul x y +theorem I64.add_bv_spec {x y : I64} + (hmin : I64.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I64.max) : + ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + IScalar.add_bv_spec (by scalar_tac) (by scalar_tac) -/- [core::num::{T}::checked_div] -/ -def core.num.checked_div (x y : Scalar ty) : Option (Scalar ty) := - Option.ofResult (x / y) - -def U8.checked_div (x y : U8) : Option U8 := core.num.checked_div x y -def U16.checked_div (x y : U16) : Option U16 := core.num.checked_div x y -def U32.checked_div (x y : U32) : Option U32 := core.num.checked_div x y -def U64.checked_div (x y : U64) : Option U64 := core.num.checked_div x y -def U128.checked_div (x y : U128) : Option U128 := core.num.checked_div x y -def Usize.checked_div (x y : Usize) : Option Usize := core.num.checked_div x y -def I8.checked_div (x y : I8) : Option I8 := core.num.checked_div x y -def I16.checked_div (x y : I16) : Option I16 := core.num.checked_div x y -def I32.checked_div (x y : I32) : Option I32 := core.num.checked_div x y -def I64.checked_div (x y : I64) : Option I64 := core.num.checked_div x y -def I128.checked_div (x y : I128) : Option I128 := core.num.checked_div x y -def Isize.checked_div (x y : Isize) : Option Isize := core.num.checked_div x y +theorem I128.add_bv_spec {x y : I128} + (hmin : I128.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I128.max) : + ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y ∧ z.bv = x.bv + y.bv := + IScalar.add_bv_spec (by scalar_tac) (by scalar_tac) -/- [core::num::{T}::checked_rem] -/ -def core.num.checked_rem (x y : Scalar ty) : Option (Scalar ty) := - Option.ofResult (x % y) - -def U8.checked_rem (x y : U8) : Option U8 := core.num.checked_rem x y -def U16.checked_rem (x y : U16) : Option U16 := core.num.checked_rem x y -def U32.checked_rem (x y : U32) : Option U32 := core.num.checked_rem x y -def U64.checked_rem (x y : U64) : Option U64 := core.num.checked_rem x y -def U128.checked_rem (x y : U128) : Option U128 := core.num.checked_rem x y -def Usize.checked_rem (x y : Usize) : Option Usize := core.num.checked_rem x y -def I8.checked_rem (x y : I8) : Option I8 := core.num.checked_rem x y -def I16.checked_rem (x y : I16) : Option I16 := core.num.checked_rem x y -def I32.checked_rem (x y : I32) : Option I32 := core.num.checked_rem x y -def I64.checked_rem (x y : I64) : Option I64 := core.num.checked_rem x y -def I128.checked_rem (x y : I128) : Option I128 := core.num.checked_rem x y -def Isize.checked_rem (x y : Isize) : Option Isize := core.num.checked_rem x y - -theorem Scalar.add_equiv {ty} {x y : Scalar ty} : - match x + y with - | ok z => Scalar.in_bounds ty (↑x + ↑y) ∧ (↑z : Int) = ↑x + ↑y - | fail _ => ¬ (Scalar.in_bounds ty (↑x + ↑y)) - | _ => ⊥ := by - -- Applying the unfoldings only inside the match - conv in _ + _ => unfold HAdd.hAdd 
instHAddScalarResult; simp [add] - have h := tryMk_eq ty (↑x + ↑y) - simp [in_bounds] at h - split at h <;> simp_all [check_bounds_eq_in_bounds] - --- Generic theorem - shouldn't be used much -@[pspec] -theorem Scalar.add_spec {ty} {x y : Scalar ty} - (hmin : Scalar.min ty ≤ ↑x + y.val) - (hmax : ↑x + ↑y ≤ Scalar.max ty) : - (∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y) := by - have h := @add_equiv ty x y - split at h <;> simp_all +/-! +Theorems about the addition, with a specification which uses +only integers. Those are the most common to use, so we mark them with the +`progress` attribute. +-/ -theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} - (hmax : ↑x + ↑y ≤ Scalar.max ty) : - ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := by - have hmin : Scalar.min ty ≤ ↑x + ↑y := by - have hx := x.hmin - have hy := y.hmin - cases ty <;> simp [min, ScalarTy.isSigned] at * <;> omega - apply add_spec <;> assumption - -/- Fine-grained theorems -/ -@[pspec] theorem Usize.add_spec {x y : Usize} (hmax : x.val + y.val ≤ Usize.max) : +/-- Generic theorem - shouldn't be used much -/ +@[progress] +theorem UScalar.add_spec {ty} {x y : UScalar ty} + (hmax : ↑x + ↑y ≤ UScalar.max ty) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y := by + have h := @add_equiv ty x y + split at h <;> simp_all [max] + have : 0 < 2^ty.numBits := by simp + omega + +/-- Generic theorem - shouldn't be used much -/ +@[progress] +theorem IScalar.add_spec {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ ↑x + ↑y) + (hmax : ↑x + ↑y ≤ IScalar.max ty) : ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := by - apply Scalar.add_unsigned_spec <;> simp [ScalarTy.isSigned, Scalar.max, *] + have h := @add_equiv ty x y + split at h <;> simp_all [min, max] + omega -@[pspec] theorem U8.add_spec {x y : U8} (hmax : x.val + y.val ≤ U8.max) : - ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := by - apply Scalar.add_unsigned_spec <;> simp [ScalarTy.isSigned, Scalar.max, *] +@[progress] theorem Usize.add_spec {x y : Usize} (hmax : x.val + y.val ≤ Usize.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y := + UScalar.add_spec (by scalar_tac) -@[pspec] theorem U16.add_spec {x y : U16} (hmax : x.val + y.val ≤ U16.max) : - ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := by - apply Scalar.add_unsigned_spec <;> simp [ScalarTy.isSigned, Scalar.max, *] +@[progress] theorem U8.add_spec {x y : U8} (hmax : x.val + y.val ≤ U8.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y := + UScalar.add_spec (by scalar_tac) -@[pspec] theorem U32.add_spec {x y : U32} (hmax : x.val + y.val ≤ U32.max) : - ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := by - apply Scalar.add_unsigned_spec <;> simp [ScalarTy.isSigned, Scalar.max, *] +@[progress] theorem U16.add_spec {x y : U16} (hmax : x.val + y.val ≤ U16.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y := + UScalar.add_spec (by scalar_tac) -@[pspec] theorem U64.add_spec {x y : U64} (hmax : x.val + y.val ≤ U64.max) : - ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := by - apply Scalar.add_unsigned_spec <;> simp [ScalarTy.isSigned, Scalar.max, *] +@[progress] theorem U32.add_spec {x y : U32} (hmax : x.val + y.val ≤ U32.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y := + UScalar.add_spec (by scalar_tac) -@[pspec] theorem U128.add_spec {x y : U128} (hmax : x.val + y.val ≤ U128.max) : - ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := by - apply Scalar.add_unsigned_spec <;> simp [ScalarTy.isSigned, Scalar.max, *] +@[progress] theorem U64.add_spec {x y : U64} (hmax : x.val + y.val ≤ U64.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y := + 
UScalar.add_spec (by scalar_tac) + +@[progress] theorem U128.add_spec {x y : U128} (hmax : x.val + y.val ≤ U128.max) : + ∃ z, x + y = ok z ∧ (↑z : Nat) = ↑x + ↑y := + UScalar.add_spec (by scalar_tac) -@[pspec] theorem Isize.add_spec {x y : Isize} +@[progress] theorem Isize.add_spec {x y : Isize} (hmin : Isize.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ Isize.max) : ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := - Scalar.add_spec hmin hmax + IScalar.add_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I8.add_spec {x y : I8} +@[progress] theorem I8.add_spec {x y : I8} (hmin : I8.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I8.max) : ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := - Scalar.add_spec hmin hmax + IScalar.add_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I16.add_spec {x y : I16} +@[progress] theorem I16.add_spec {x y : I16} (hmin : I16.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I16.max) : ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := - Scalar.add_spec hmin hmax + IScalar.add_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I32.add_spec {x y : I32} +@[progress] theorem I32.add_spec {x y : I32} (hmin : I32.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I32.max) : ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := - Scalar.add_spec hmin hmax + IScalar.add_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I64.add_spec {x y : I64} +@[progress] theorem I64.add_spec {x y : I64} (hmin : I64.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I64.max) : ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := - Scalar.add_spec hmin hmax + IScalar.add_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I128.add_spec {x y : I128} +@[progress] theorem I128.add_spec {x y : I128} (hmin : I128.min ≤ ↑x + ↑y) (hmax : ↑x + ↑y ≤ I128.max) : ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := - Scalar.add_spec hmin hmax - -theorem core.num.checked_add_spec {ty} {x y : Scalar ty} : - match core.num.checked_add x y with - | some z => Scalar.in_bounds ty (↑x + ↑y) ∧ ↑z = (↑x + ↑y : Int) - | none => ¬ (Scalar.in_bounds ty (↑x + ↑y)) := by - have h := Scalar.tryMk_eq ty (↑x + ↑y) - simp only [checked_add, Option.ofResult] - cases heq: x + y <;> simp_all <;> simp [HAdd.hAdd, Scalar.add] at heq - <;> simp [Add.add] at heq - <;> simp_all - -theorem Scalar.sub_equiv {ty} {x y : Scalar ty} : + IScalar.add_spec (by scalar_tac) (by scalar_tac) + +/-! +### Sub +-/ + + +theorem UScalar.sub_equiv {ty} (x y : UScalar ty) : + match x - y with + | ok z => + y.val ≤ x.val ∧ + x.val = z.val + y.val ∧ + z.bv = x.bv - y.bv + | fail _ => x.val < y.val + | _ => ⊥ := by + have : x - y = sub x y := by rfl + simp [this, sub] + dcases h : x.val < y.val <;> simp [h] + simp_all + simp only [UScalar.val] + simp + split_conjs + . have: (x.val - y.val) % 2^ty.numBits = x.val - y.val := by + have : 0 < 2^ty.numBits := by simp + have := x.hBounds + apply Nat.mod_eq_of_lt; omega + simp [this] + omega + . 
zify; simp + have : (x.val - y.val : Nat) = (x.val : Int) - y.val := by omega + rw [this]; clear this + ring_nf + rw [Int.add_emod] + have : ((2^ty.numBits - y.val) : Nat) % (2^ty.numBits : Int) = + (- (y.val : Int)) % (2^ty.numBits : Int) := by + have : (2^ty.numBits - y.val : Nat) = (2^ty.numBits : Int) - y.val := by + have hBounds := y.hBounds + zify at *; simp at * + have : (2^ty.numBits : Nat) = (2^ty.numBits : Int) := by simp + omega + rw [this] + -- TODO: Int.emod_sub_emod not in this version of mathlib + have := Int.emod_add_emod (2^ty.numBits) (2^ty.numBits) (-y.val) + ring_nf at this + ring_nf + rw [← this] + simp + rw [this] + rw [← Int.add_emod] + ring_nf + +theorem IScalar.sub_equiv {ty} (x y : IScalar ty) : match x - y with - | ok z => Scalar.in_bounds ty (↑x - ↑y) ∧ (↑z : Int) = ↑x - ↑y - | fail _ => ¬ (Scalar.in_bounds ty (↑x - ↑y)) + | ok z => + IScalar.inBounds ty (x.val - y.val) ∧ + z.val = x.val - y.val ∧ + z.bv = x.bv - y.bv + | fail _ => ¬ (IScalar.inBounds ty (x.val - y.val)) | _ => ⊥ := by - -- Applying the unfoldings only inside the match - conv in _ - _ => unfold HSub.hSub instHSubScalarResult; simp [sub] + have : x - y = sub x y := by rfl + simp [this, sub] have h := tryMk_eq ty (↑x - ↑y) - simp [in_bounds] at h - split at h <;> simp_all [check_bounds_eq_in_bounds] - --- Generic theorem - shouldn't be used much -@[pspec] -theorem Scalar.sub_spec {ty} {x y : Scalar ty} - (hmin : Scalar.min ty ≤ ↑x - ↑y) - (hmax : ↑x - ↑y ≤ Scalar.max ty) : - ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by + simp [inBounds] at h + split at h <;> simp_all + apply BitVec.eq_of_toInt_eq + simp + have := bmod_pow_numBits_eq_of_lt ty (x.val - y.val) (by omega) (by omega) + simp [*] + +/-! +Theorems with a specification which uses integers and bit-vectors +-/ + +/- Generic theorem - shouldn't be used much -/ +theorem UScalar.sub_bv_spec {ty} {x y : UScalar ty} + (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val ∧ z.bv = x.bv - y.bv := by have h := @sub_equiv ty x y split at h <;> simp_all + omega -theorem Scalar.sub_unsigned_spec {ty : ScalarTy} (s : ¬ ty.isSigned) - {x y : Scalar ty} (hmin : Scalar.min ty ≤ ↑x - ↑y) : - ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by - have : ↑x - ↑y ≤ Scalar.max ty := by - have hx := x.hmin - have hxm := x.hmax - have hy := y.hmin - cases ty <;> simp [min, max, ScalarTy.isSigned] at * <;> omega - apply sub_spec <;> assumption - -/- Fine-grained theorems -/ -@[pspec] theorem Usize.sub_spec {x y : Usize} (hmin : Usize.min ≤ x.val - y.val) : - ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by - apply Scalar.sub_unsigned_spec <;> simp [Scalar.min, ScalarTy.isSigned]; omega +/- Generic theorem - shouldn't be used much -/ +theorem IScalar.sub_bv_spec {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ ↑x - ↑y) + (hmax : ↑x - ↑y ≤ IScalar.max ty) : + ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y ∧ z.bv = x.bv - y.bv := by + have h := @sub_equiv ty x y + split at h <;> simp_all [min, max] + omega + +theorem Usize.sub_bv_spec {x y : Usize} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val ∧ z.bv = x.bv - y.bv := + UScalar.sub_bv_spec h + +theorem U8.sub_bv_spec {x y : U8} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val ∧ z.bv = x.bv - y.bv := + UScalar.sub_bv_spec h + +theorem U16.sub_bv_spec {x y : U16} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val ∧ z.bv = x.bv - y.bv := + UScalar.sub_bv_spec h + +theorem U32.sub_bv_spec {x y : U32} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val ∧ 
z.bv = x.bv - y.bv := + UScalar.sub_bv_spec h + +theorem U64.sub_bv_spec {x y : U64} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val ∧ z.bv = x.bv - y.bv := + UScalar.sub_bv_spec h + +theorem U128.sub_bv_spec {x y : U128} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val ∧ z.bv = x.bv - y.bv := + UScalar.sub_bv_spec h + +theorem Isize.sub_bv_spec {x y : Isize} + (hmin : Isize.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ Isize.max) : + ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y ∧ z.bv = x.bv - y.bv := + IScalar.sub_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I8.sub_bv_spec {x y : I8} + (hmin : I8.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I8.max) : + ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y ∧ z.bv = x.bv - y.bv := + IScalar.sub_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I16.sub_bv_spec {x y : I16} + (hmin : I16.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I16.max) : + ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y ∧ z.bv = x.bv - y.bv := + IScalar.sub_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I32.sub_bv_spec {x y : I32} + (hmin : I32.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I32.max) : + ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y ∧ z.bv = x.bv - y.bv := + IScalar.sub_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I64.sub_bv_spec {x y : I64} + (hmin : I64.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I64.max) : + ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y ∧ z.bv = x.bv - y.bv := + IScalar.sub_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I128.sub_bv_spec {x y : I128} + (hmin : I128.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I128.max) : + ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y ∧ z.bv = x.bv - y.bv := + IScalar.sub_bv_spec (by scalar_tac) (by scalar_tac) + +/-! +Theorems with a specification which only uses integers +-/ -@[pspec] theorem U8.sub_spec {x y : U8} (hmin : U8.min ≤ x.val - y.val) : - ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by - apply Scalar.sub_unsigned_spec <;> simp_all [Scalar.min, ScalarTy.isSigned] +/- Generic theorem - shouldn't be used much -/ +@[progress] +theorem UScalar.sub_spec {ty} {x y : UScalar ty} + (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val := by + have h := @sub_equiv ty x y + split at h <;> simp_all + omega -@[pspec] theorem U16.sub_spec {x y : U16} (hmin : U16.min ≤ x.val - y.val) : +/- Generic theorem - shouldn't be used much -/ +@[progress] +theorem IScalar.sub_spec {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ ↑x - ↑y) + (hmax : ↑x - ↑y ≤ IScalar.max ty) : ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by - apply Scalar.sub_unsigned_spec <;> simp_all [Scalar.min, ScalarTy.isSigned] + have h := @sub_equiv ty x y + split at h <;> simp_all [min, max] + omega -@[pspec] theorem U32.sub_spec {x y : U32} (hmin : U32.min ≤ x.val - y.val) : - ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by - apply Scalar.sub_unsigned_spec <;> simp_all [Scalar.min, ScalarTy.isSigned] +@[progress] theorem Usize.sub_spec {x y : Usize} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val := + UScalar.sub_spec h -@[pspec] theorem U64.sub_spec {x y : U64} (hmin : U64.min ≤ x.val - y.val) : - ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by - apply Scalar.sub_unsigned_spec <;> simp_all [Scalar.min, ScalarTy.isSigned] +@[progress] theorem U8.sub_spec {x y : U8} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val := + UScalar.sub_spec h -@[pspec] theorem U128.sub_spec {x y : U128} (hmin : U128.min ≤ x.val - y.val) : - ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by - apply Scalar.sub_unsigned_spec <;> simp_all [Scalar.min, ScalarTy.isSigned] 
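+
+/- Illustrative sketch (added for exposition; not part of the original sources):
+   the unsigned `sub` specs state the postcondition as `x.val = z.val + y.val`
+   rather than `z.val = x.val - y.val`, which avoids reasoning about truncated
+   subtraction on `Nat`. The subtraction-shaped equality is easily recovered: -/
+example (x y : U8) (h : y.val ≤ x.val) :
+    ∃ z, x - y = ok z ∧ z.val = x.val - y.val := by
+  obtain ⟨ z, hz, hval ⟩ := U8.sub_spec h
+  exact ⟨ z, hz, by omega ⟩
+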
+@[progress] theorem U16.sub_spec {x y : U16} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val := + UScalar.sub_spec h + +@[progress] theorem U32.sub_spec {x y : U32} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val := + UScalar.sub_spec h -@[pspec] theorem Isize.sub_spec {x y : Isize} (hmin : Isize.min ≤ ↑x - ↑y) - (hmax : ↑x - ↑y ≤ Isize.max) : +@[progress] theorem U64.sub_spec {x y : U64} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val := + UScalar.sub_spec h + +@[progress] theorem U128.sub_spec {x y : U128} (h : y.val ≤ x.val) : + ∃ z, x - y = ok z ∧ x.val = z.val + y.val := + UScalar.sub_spec h + +@[progress] theorem Isize.sub_spec {x y : Isize} + (hmin : Isize.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ Isize.max) : ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := - Scalar.sub_spec hmin hmax + IScalar.sub_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I8.sub_spec {x y : I8} (hmin : I8.min ≤ ↑x - ↑y) - (hmax : ↑x - ↑y ≤ I8.max) : +@[progress] theorem I8.sub_spec {x y : I8} + (hmin : I8.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I8.max) : ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := - Scalar.sub_spec hmin hmax + IScalar.sub_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I16.sub_spec {x y : I16} (hmin : I16.min ≤ ↑x - ↑y) - (hmax : ↑x - ↑y ≤ I16.max) : +@[progress] theorem I16.sub_spec {x y : I16} + (hmin : I16.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I16.max) : ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := - Scalar.sub_spec hmin hmax + IScalar.sub_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I32.sub_spec {x y : I32} (hmin : I32.min ≤ ↑x - ↑y) - (hmax : ↑x - ↑y ≤ I32.max) : +@[progress] theorem I32.sub_spec {x y : I32} + (hmin : I32.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I32.max) : ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := - Scalar.sub_spec hmin hmax + IScalar.sub_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I64.sub_spec {x y : I64} (hmin : I64.min ≤ ↑x - ↑y) - (hmax : ↑x - ↑y ≤ I64.max) : +@[progress] theorem I64.sub_spec {x y : I64} + (hmin : I64.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I64.max) : ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := - Scalar.sub_spec hmin hmax + IScalar.sub_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I128.sub_spec {x y : I128} (hmin : I128.min ≤ ↑x - ↑y) - (hmax : ↑x - ↑y ≤ I128.max) : +@[progress] theorem I128.sub_spec {x y : I128} + (hmin : I128.min ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ I128.max) : ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := - Scalar.sub_spec hmin hmax + IScalar.sub_spec (by scalar_tac) (by scalar_tac) --- Generic theorem - shouldn't be used much -theorem Scalar.mul_spec {ty} {x y : Scalar ty} - (hmin : Scalar.min ty ≤ ↑x * ↑y) - (hmax : ↑x * ↑y ≤ Scalar.max ty) : - ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by - conv => congr; ext; lhs; simp [HMul.hMul] - simp [mul, tryMk, tryMkOpt, ofOption] - split_ifs - . simp [pure] - . tauto - -theorem core.num.checked_sub_spec {ty} {x y : Scalar ty} : - match core.num.checked_sub x y with - | some z => Scalar.in_bounds ty (↑x - ↑y) ∧ ↑z = (↑x - ↑y : Int) - | none => ¬ (Scalar.in_bounds ty (↑x - ↑y)) := by - have h := Scalar.tryMk_eq ty (↑x - ↑y) - simp only [checked_sub, Option.ofResult] - have add_neg_eq : x.val + (-y.val) = x.val - y.val := by omega -- TODO: why do we need this?? 
- cases heq: x - y <;> simp_all <;> simp only [HSub.hSub, Scalar.sub, Sub.sub, Int.sub] at heq - <;> simp_all - -theorem Scalar.mul_equiv {ty} {x y : Scalar ty} : - match x * y with - | ok z => Scalar.in_bounds ty (↑x * ↑y) ∧ (↑z : Int) = ↑x * ↑y - | fail _ => ¬ (Scalar.in_bounds ty (↑x * ↑y)) - | _ => ⊥ := by - -- Applying the unfoldings only inside the match - conv in _ * _ => unfold HMul.hMul instHMulScalarResult; simp [mul] - have h := tryMk_eq ty (↑x * ↑y) - simp [in_bounds] at h - split at h <;> simp_all [check_bounds_eq_in_bounds] - -theorem Scalar.mul_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} - (hmax : ↑x * ↑y ≤ Scalar.max ty) : - ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by - have : Scalar.min ty ≤ ↑x * ↑y := by - have hx := x.hmin - have hy := y.hmin - cases ty <;> simp [ScalarTy.isSigned] at * <;> apply mul_nonneg hx hy - apply mul_spec <;> assumption - -/- Fine-grained theorems -/ -@[pspec] theorem Usize.mul_spec {x y : Usize} (hmax : x.val * y.val ≤ Usize.max) : - ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by - apply Scalar.mul_unsigned_spec <;> simp_all [Scalar.max, ScalarTy.isSigned] +/-! +### Mul +-/ -@[pspec] theorem U8.mul_spec {x y : U8} (hmax : x.val * y.val ≤ U8.max) : - ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by - apply Scalar.mul_unsigned_spec <;> simp_all [Scalar.max, ScalarTy.isSigned] +/-! +Theorems with a specification which use integers and bit-vectors +-/ -@[pspec] theorem U16.mul_spec {x y : U16} (hmax : x.val * y.val ≤ U16.max) : - ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by - apply Scalar.mul_unsigned_spec <;> simp_all [Scalar.max, ScalarTy.isSigned] +theorem UScalar.mul_equiv {ty} (x y : UScalar ty) : + match mul x y with + | ok z => x.val * y.val ≤ UScalar.max ty ∧ (↑z : Nat) = ↑x * ↑y ∧ z.bv = x.bv * y.bv + | fail _ => UScalar.max ty < x.val * y.val + | .div => False := by + simp [mul] + have := tryMk_eq ty (x.val * y.val) + split <;> simp_all + simp_all [tryMk, tryMkOpt] + rename_i hEq; simp only [← hEq, ofNatCore, val] + split_conjs + . simp [max]; omega + . zify; simp [max] + . 
have : 0 < 2^ty.numBits := by simp + simp [max] + omega -@[pspec] theorem U32.mul_spec {x y : U32} (hmax : x.val * y.val ≤ U32.max) : - ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by - apply Scalar.mul_unsigned_spec <;> simp_all [Scalar.max, ScalarTy.isSigned] +/-- Generic theorem - shouldn't be used much -/ +theorem UScalar.mul_bv_spec {ty} {x y : UScalar ty} + (hmax : ↑x * ↑y ≤ UScalar.max ty) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := by + have : x * y = mul x y := by rfl + have := mul_equiv x y + split at this <;> simp_all + omega + +theorem IScalar.mul_equiv {ty} (x y : IScalar ty) : + match mul x y with + | ok z => IScalar.min ty ≤ x.val * y.val ∧ x.val * y.val ≤ IScalar.max ty ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | fail _ => ¬(IScalar.min ty ≤ x.val * y.val ∧ x.val * y.val ≤ IScalar.max ty) + | .div => False := by + simp [mul] + have := tryMk_eq ty (x.val * y.val) + split <;> simp_all [min, max] <;> + simp_all [tryMk, tryMkOpt] <;> + rename_i hEq <;> simp only [← hEq, ofIntCore, val] <;> + simp [← BitVec.toInt_inj] <;> + omega + +/-- Generic theorem - shouldn't be used much -/ +theorem IScalar.mul_bv_spec {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ ↑x * ↑y) + (hmax : ↑x * ↑y ≤ IScalar.max ty) : + ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := by + have : x * y = mul x y := by rfl + have := mul_equiv x y + split at this <;> simp_all + +theorem Usize.mul_bv_spec {x y : Usize} (hmax : x.val * y.val ≤ Usize.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + UScalar.mul_bv_spec (by scalar_tac) + +theorem U8.mul_bv_spec {x y : U8} (hmax : x.val * y.val ≤ U8.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + UScalar.mul_bv_spec (by scalar_tac) + +theorem U16.mul_bv_spec {x y : U16} (hmax : x.val * y.val ≤ U16.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + UScalar.mul_bv_spec (by scalar_tac) + +theorem U32.mul_bv_spec {x y : U32} (hmax : x.val * y.val ≤ U32.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + UScalar.mul_bv_spec (by scalar_tac) + +theorem U64.mul_bv_spec {x y : U64} (hmax : x.val * y.val ≤ U64.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + UScalar.mul_bv_spec (by scalar_tac) + +theorem U128.mul_bv_spec {x y : U128} (hmax : x.val * y.val ≤ U128.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + UScalar.mul_bv_spec (by scalar_tac) + +theorem Isize.mul_bv_spec {x y : Isize} + (hmin : Isize.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ Isize.max) : + ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + IScalar.mul_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I8.mul_bv_spec {x y : I8} + (hmin : I8.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ I8.max) : + ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + IScalar.mul_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I16.mul_bv_spec {x y : I16} + (hmin : I16.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ I16.max) : + ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + IScalar.mul_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I32.mul_bv_spec {x y : I32} + (hmin : I32.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ I32.max) : + ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + IScalar.mul_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I64.mul_bv_spec {x y : I64} (hmin : I64.min ≤ ↑x * ↑y) + (hmax : ↑x * ↑y ≤ I64.max) : + ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + 
IScalar.mul_bv_spec (by scalar_tac) (by scalar_tac) + +theorem I128.mul_bv_spec {x y : I128} (hmin : I128.min ≤ ↑x * ↑y) + (hmax : ↑x * ↑y ≤ I128.max) : + ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y ∧ z.bv = x.bv * y.bv := + IScalar.mul_bv_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem U64.mul_spec {x y : U64} (hmax : x.val * y.val ≤ U64.max) : - ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by - apply Scalar.mul_unsigned_spec <;> simp_all [Scalar.max, ScalarTy.isSigned] -@[pspec] theorem U128.mul_spec {x y : U128} (hmax : x.val * y.val ≤ U128.max) : +/-! +Theorems with a specification which only use integers +-/ + +/-- Generic theorem - shouldn't be used much -/ +theorem UScalar.mul_spec {ty} {x y : UScalar ty} + (hmax : ↑x * ↑y ≤ UScalar.max ty) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y := by + have ⟨ z, h⟩ := mul_bv_spec hmax + simp [h] + +/-- Generic theorem - shouldn't be used much -/ +theorem IScalar.mul_spec {ty} {x y : IScalar ty} + (hmin : IScalar.min ty ≤ ↑x * ↑y) + (hmax : ↑x * ↑y ≤ IScalar.max ty) : ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by - apply Scalar.mul_unsigned_spec <;> simp_all [Scalar.max, ScalarTy.isSigned] + have ⟨ z, h⟩ := @mul_bv_spec ty x y (by scalar_tac) (by scalar_tac) + simp [h] + +@[progress] theorem Usize.mul_spec {x y : Usize} (hmax : x.val * y.val ≤ Usize.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y := + UScalar.mul_spec (by scalar_tac) + +@[progress] theorem U8.mul_spec {x y : U8} (hmax : x.val * y.val ≤ U8.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y := + UScalar.mul_spec (by scalar_tac) + +@[progress] theorem U16.mul_spec {x y : U16} (hmax : x.val * y.val ≤ U16.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y := + UScalar.mul_spec (by scalar_tac) + +@[progress] theorem U32.mul_spec {x y : U32} (hmax : x.val * y.val ≤ U32.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y := + UScalar.mul_spec (by scalar_tac) + +@[progress] theorem U64.mul_spec {x y : U64} (hmax : x.val * y.val ≤ U64.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y := + UScalar.mul_spec (by scalar_tac) -@[pspec] theorem Isize.mul_spec {x y : Isize} (hmin : Isize.min ≤ ↑x * ↑y) - (hmax : ↑x * ↑y ≤ Isize.max) : +@[progress] theorem U128.mul_spec {x y : U128} (hmax : x.val * y.val ≤ U128.max) : + ∃ z, x * y = ok z ∧ (↑z : Nat) = ↑x * ↑y := + UScalar.mul_spec (by scalar_tac) + +@[progress] theorem Isize.mul_spec {x y : Isize} + (hmin : Isize.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ Isize.max) : ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := - Scalar.mul_spec hmin hmax + IScalar.mul_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I8.mul_spec {x y : I8} (hmin : I8.min ≤ ↑x * ↑y) - (hmax : ↑x * ↑y ≤ I8.max) : +@[progress] theorem I8.mul_spec {x y : I8} + (hmin : I8.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ I8.max) : ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := - Scalar.mul_spec hmin hmax + IScalar.mul_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I16.mul_spec {x y : I16} (hmin : I16.min ≤ ↑x * ↑y) - (hmax : ↑x * ↑y ≤ I16.max) : +@[progress] theorem I16.mul_spec {x y : I16} + (hmin : I16.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ I16.max) : ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := - Scalar.mul_spec hmin hmax + IScalar.mul_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I32.mul_spec {x y : I32} (hmin : I32.min ≤ ↑x * ↑y) - (hmax : ↑x * ↑y ≤ I32.max) : +@[progress] theorem I32.mul_spec {x y : I32} + (hmin : I32.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ I32.max) : ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := - Scalar.mul_spec hmin hmax + IScalar.mul_spec (by scalar_tac) (by scalar_tac) 
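+
+/- Illustrative sketch (added for exposition; not part of the original sources):
+   signed multiplication needs both bounds. Assuming the non-linear rule set of
+   `scalar_tac` (activated with `+nonLin`, see `Aeneas.Arith.Lemmas`) can derive
+   `0 ≤ ↑x * ↑x` from `0 ≤ ↑x`, the lower bound follows when squaring a
+   non-negative value: -/
+example (x : I32) (hx : 0 ≤ (↑x : Int)) (hmax : ↑x * ↑x ≤ I32.max) :
+    ∃ z, x * x = ok z ∧ (↑z : Int) = ↑x * ↑x :=
+  I32.mul_spec (by scalar_tac +nonLin) (by scalar_tac)
+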
-@[pspec] theorem I64.mul_spec {x y : I64} (hmin : I64.min ≤ ↑x * ↑y) +@[progress] theorem I64.mul_spec {x y : I64} (hmin : I64.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ I64.max) : ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := - Scalar.mul_spec hmin hmax + IScalar.mul_spec (by scalar_tac) (by scalar_tac) -@[pspec] theorem I128.mul_spec {x y : I128} (hmin : I128.min ≤ ↑x * ↑y) +@[progress] theorem I128.mul_spec {x y : I128} (hmin : I128.min ≤ ↑x * ↑y) (hmax : ↑x * ↑y ≤ I128.max) : ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := - Scalar.mul_spec hmin hmax - -theorem core.num.checked_mul_spec {ty} {x y : Scalar ty} : - match core.num.checked_mul x y with - | some z => Scalar.in_bounds ty (↑x * ↑y) ∧ ↑z = (↑x * ↑y : Int) - | none => ¬ (Scalar.in_bounds ty (↑x * ↑y)) := by - have h := Scalar.tryMk_eq ty (↑x * ↑y) - simp only [checked_mul, Option.ofResult] - have : Int.mul ↑x ↑y = ↑x * ↑y := by simp -- TODO: why do we need this?? - cases heq: x * y <;> simp_all <;> simp only [HMul.hMul, Scalar.mul, Mul.mul] at heq - <;> simp_all - -theorem Scalar.div_equiv {ty} {x y : Scalar ty} : - match x / y with - | ok z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y) ∧ (↑z : Int) = scalar_div ↑x ↑y - | fail _ => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y)) - | _ => ⊥ := by - -- Applying the unfoldings only inside the match - conv in _ / _ => unfold HDiv.hDiv instHDivScalarResult; simp [div] - have h := tryMk_eq ty (scalar_div ↑x ↑y) - simp [in_bounds] at h - split_ifs <;> simp <;> - split at h <;> simp_all [check_bounds_eq_in_bounds] - --- Generic theorem - shouldn't be used much -@[pspec] -theorem Scalar.div_spec {ty} {x y : Scalar ty} - (hnz : ↑y ≠ (0 : Int)) - (hmin : Scalar.min ty ≤ scalar_div ↑x ↑y) - (hmax : scalar_div ↑x ↑y ≤ Scalar.max ty) : - ∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := by - simp [HDiv.hDiv, div, Div.div] - simp [tryMk, tryMkOpt, ofOption, *] - -theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : Scalar ty} - (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x / y = ok z ∧ (↑z : Int) = ↑x / ↑y := by - have h : Scalar.min ty = 0 := by cases ty <;> simp [ScalarTy.isSigned, min] at * - have hx := x.hmin - have hy := y.hmin - simp [h] at hx hy - have hmin : 0 ≤ x.val / y.val := Int.ediv_nonneg hx hy - have hmax : ↑x / ↑y ≤ Scalar.max ty := by - have := Int.ediv_le_self ↑y hx - have := x.hmax + IScalar.mul_spec (by scalar_tac) (by scalar_tac) + +/-! +### Div +-/ + +/-! 
+Theorems with a specification which use integers and bit-vectors +-/ + +/-- Generic theorem - shouldn't be used much -/ +theorem UScalar.div_bv_spec {ty} (x : UScalar ty) {y : UScalar ty} + (hzero : y.val ≠ 0) : + ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y ∧ z.bv = x.bv / y.bv := by + have hzero' : y.bv ≠ 0#ty.numBits := by + intro h + zify at h + simp_all + conv => congr; ext; lhs; simp [HDiv.hDiv] + simp [hzero', div, tryMk, tryMkOpt, ofOption, hmax, ofNatCore] + simp only [val] + simp + +theorem Int.bmod_pow2_IScalarTy_numBits_minus_one (ty : IScalarTy) : + Int.bmod (2 ^ (ty.numBits - 1)) (2 ^ ty.numBits) = - 2 ^ (ty.numBits - 1) := by + rw [Int.bmod] + /- Just doing a case disjunction on the number of bits because + those proofs are annoying -/ + dcases ty <;> simp + have := System.Platform.numBits_eq + cases this <;> simp [*] + +theorem IScalar.neg_imp_neg_val_toNat_mod_pow_eq_neg_val {ty} (x : IScalar ty) + (hNeg : x.bv.toInt < 0) : + ((-x.val).toNat : Int) % 2^ty.numBits = -(x.val : Int) := by + have hmsb : x.bv.msb = true := by + have := @BitVec.msb_eq_toInt _ x.bv + simp only [hNeg] at this + apply this + have hx := @BitVec.toInt_eq_msb_cond _ x.bv + simp [hmsb] at hx + have hBounds := x.hBounds + have pow2Ineq : (2^(ty.numBits - 1) : Int) < 2^ty.numBits := by + have := ty.numBits_nonzero + have : (0 : Int) < 2^(ty.numBits - 1) := by simp + have : ty.numBits = ty.numBits - 1 + 1 := by omega + conv => rhs; rw [this] + rw [Int.pow_succ'] omega - have hs := @div_spec ty x y hnz - simp [*] at hs - apply hs - -/- Fine-grained theorems -/ -@[pspec] theorem Usize.div_spec (x : Usize) {y : Usize} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x / y = ok z ∧ (↑z : Int) = ↑x / ↑y := by - apply Scalar.div_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U8.div_spec (x : U8) {y : U8} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x / y = ok z ∧ (↑z : Int) = ↑x / ↑y := by - apply Scalar.div_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U16.div_spec (x : U16) {y : U16} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x / y = ok z ∧ (↑z : Int) = ↑x / ↑y := by - apply Scalar.div_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U32.div_spec (x : U32) {y : U32} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x / y = ok z ∧ (↑z : Int) = ↑x / ↑y := by - apply Scalar.div_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U64.div_spec (x : U64) {y : U64} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x / y = ok z ∧ (↑z : Int) = ↑x / ↑y := by - apply Scalar.div_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U128.div_spec (x : U128) {y : U128} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x / y = ok z ∧ (↑z : Int) = ↑x / ↑y := by - apply Scalar.div_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem Isize.div_spec (x : Isize) {y : Isize} - (hnz : ↑y ≠ (0 : Int)) - (hmin : Isize.min ≤ scalar_div ↑x ↑y) - (hmax : scalar_div ↑x ↑y ≤ Isize.max): - ∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := - Scalar.div_spec hnz hmin hmax - -@[pspec] theorem I8.div_spec (x : I8) {y : I8} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I8.min ≤ scalar_div ↑x ↑y) - (hmax : scalar_div ↑x ↑y ≤ I8.max): - ∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := - Scalar.div_spec hnz hmin hmax - -@[pspec] theorem I16.div_spec (x : I16) {y : I16} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I16.min ≤ scalar_div ↑x ↑y) - (hmax : scalar_div ↑x ↑y ≤ I16.max): - ∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := - Scalar.div_spec hnz hmin hmax - -@[pspec] theorem I32.div_spec (x : I32) {y : I32} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I32.min ≤ 
scalar_div ↑x ↑y) - (hmax : scalar_div ↑x ↑y ≤ I32.max): - ∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := - Scalar.div_spec hnz hmin hmax - -@[pspec] theorem I64.div_spec (x : I64) {y : I64} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I64.min ≤ scalar_div ↑x ↑y) - (hmax : scalar_div ↑x ↑y ≤ I64.max): - ∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := - Scalar.div_spec hnz hmin hmax - -@[pspec] theorem I128.div_spec (x : I128) {y : I128} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I128.min ≤ scalar_div ↑x ↑y) - (hmax : scalar_div ↑x ↑y ≤ I128.max): - ∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := - Scalar.div_spec hnz hmin hmax - -theorem core.num.checked_div_spec {ty} {x y : Scalar ty} : - match core.num.checked_div x y with - | some z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y) ∧ ↑z = (scalar_div ↑x ↑y : Int) - | none => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y)) := by - have h := Scalar.tryMk_eq ty (scalar_div ↑x ↑y) - simp only [checked_div, Option.ofResult] - cases heq0: (y.val = 0 : Bool) <;> - cases heq1: x / y <;> simp_all <;> simp only [HDiv.hDiv, Scalar.div, Div.div] at heq1 - <;> simp_all - -theorem Scalar.rem_equiv {ty} {x y : Scalar ty} : - match x % y with - | ok z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y) ∧ (↑z : Int) = scalar_rem ↑x ↑y - | fail _ => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y)) - | _ => ⊥ := by - -- Applying the unfoldings only inside the match - conv in _ % _ => unfold HMod.hMod instHModScalarResult; simp [rem] - have h := tryMk_eq ty (scalar_rem ↑x ↑y) - simp [in_bounds] at h - split_ifs <;> simp <;> - split at h <;> simp_all [check_bounds_eq_in_bounds] - --- Generic theorem - shouldn't be used much -@[pspec] -theorem Scalar.rem_spec {ty} {x y : Scalar ty} - (hnz : ↑y ≠ (0 : Int)) - (hmin : Scalar.min ty ≤ scalar_rem ↑x ↑y) - (hmax : scalar_rem ↑x ↑y ≤ Scalar.max ty) : - ∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y := by - simp [HMod.hMod, rem] - simp [tryMk, tryMkOpt, ofOption, *] - -theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : Scalar ty} - (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x % y = ok z ∧ (↑z : Int) = ↑x % ↑y := by - have h : Scalar.min ty = 0 := by cases ty <;> simp [ScalarTy.isSigned, min] at * - have hx := x.hmin - have hy := y.hmin - simp [h] at hx hy - have hmin : (0 : Int) ≤ x % y := Int.emod_nonneg ↑x hnz - have hmax : ↑x % ↑y ≤ Scalar.max ty := by - have h : (0 : Int) < y := by int_tac - have h := Int.emod_lt_of_pos ↑x h - have := y.hmax + have hyToNat : 2 ^ ty.numBits - x.bv.toNat = (-x.val).toNat := by + rw [hx] + simp + norm_cast + have hyValToNatMod : ((-x.val).toNat : Nat) % 2^ty.numBits = (-x.val).toNat := by + have : ↑(-x.val).toNat < 2 ^ ty.numBits := by + zify + apply Int.lt_of_neg_lt_neg + have : - (-x.val).toNat = x.val := by omega + rw [this]; clear this + have := x.hmin + omega + have := @Nat.mod_eq_of_lt (-x.val).toNat (2^ty.numBits) (by omega) + apply this + zify at hyValToNatMod + rw [hyValToNatMod] + omega + +theorem IScalar.neg_imp_toNat_neg_eq_neg_toInt {ty} (x : IScalar ty) (hNeg : x.val < 0): + (-x.bv).toNat = -x.bv.toInt := by + have hmsb : x.bv.msb = true := by + have := @BitVec.msb_eq_toInt _ x.bv + simp only [val] at hNeg + simp only [hNeg] at this + apply this + have hx := @BitVec.toInt_eq_msb_cond _ x.bv + simp [hmsb] at hx + + have hxNeg : x.val < 0 := by + have := @BitVec.msb_eq_toInt _ x.bv + simp_all + + conv => lhs; simp only [Neg.neg, BitVec.neg] + simp only [BitVec.toInt_eq_toNat_bmod, BitVec.toNat_umod] + + have hxToNatMod : (x.bv.toNat 
: Int) % 2^ty.numBits = x.bv.toNat := by + apply Int.emod_eq_of_lt <;> omega + + have hPow : (2 ^ ty.numBits + 1 : Int) / 2 = 2^(ty.numBits - 1) := by + have : ty.numBits = ty.numBits - 1 + 1 := by + have := ty.numBits_nonzero + omega + conv => lhs; rw [this] + rw [Int.pow_succ'] + rw [Int.add_ediv_of_dvd_left] <;> simp + + have : ¬ ((x.bv.toNat : Int) % ↑(2 ^ ty.numBits : Nat) < (↑(2 ^ ty.numBits : Nat) + 1) / 2) := by + have hIneq := @BitVec.msb_eq_toNat _ x.bv + rw [hmsb] at hIneq + simp at hIneq + simp + rw [hPow] + + rw [hxToNatMod] + zify at hIneq omega - have hs := @rem_spec ty x y hnz - simp [*] at hs - simp [*] + rw [Int.bmod_def] + simp only [this] + simp + have : (2 ^ ty.numBits - x.bv.toNat : Nat) % (2 ^ ty.numBits : Int) = + (2^ty.numBits - x.bv.toNat : Nat) := by + apply Int.emod_eq_of_lt + . omega + . have := x.hBounds + simp only [val] at * + have : (2 ^ ty.numBits - x.bv.toNat : Nat) = (2 ^ ty.numBits - x.bv.toNat : Int) := by + have : (2 ^ ty.numBits : Nat) = (2 ^ ty.numBits : Int) := by simp + omega + rw [this] + have : x.bv.toNat > 0 := by + by_contra + have hxz : x.bv.toNat = 0 := by omega + have : x.bv.toInt = 0 := by + simp only [BitVec.toInt_eq_toNat_bmod, BitVec.toNat_umod, Int.bmod_def, hxz] + simp [hPow] + omega + omega + rw [this]; clear this + rw [hxToNatMod] + + have : (2 ^ ty.numBits : Nat) = (2 ^ ty.numBits : Int) := by simp + omega + +/-- Generic theorem - shouldn't be used much -/ +theorem IScalar.div_bv_spec {ty} {x y : IScalar ty} + (hzero : y.val ≠ 0) (hNoOverflow : ¬ (x.val = IScalar.min ty ∧ y.val = -1)) : + ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y ∧ z.bv = BitVec.sdiv x.bv y.bv := by + conv => congr; ext; lhs; simp [HDiv.hDiv] + simp [div, tryMk, tryMkOpt, ofOption, ofIntCore, hzero, hNoOverflow] + simp only [val] + simp [BitVec.sdiv_eq, BitVec.udiv_def] + have pow2Ineq : (2^(ty.numBits - 1) : Int) < 2^ty.numBits := by + have := ty.numBits_nonzero + have : (0 : Int) < 2^(ty.numBits - 1) := by simp + have : ty.numBits = ty.numBits - 1 + 1 := by omega + conv => rhs; rw [this] + rw [Int.pow_succ'] + omega + have hxBounds := x.hBounds + have hyBounds := y.hBounds + --have hxyBounds := tdiv_in_bounds x y hnoOverflow + split + + . -- 0 ≤ x.bv.toInt + -- 0 ≤ y.bv.toInt + rw [BitVec.toInt_ofNat] + simp + have hx : x.bv.toNat = x.bv.toInt := by + have := @BitVec.toInt_eq_msb_cond _ x.bv + simp_all + have hy : y.bv.toNat = y.bv.toInt := by + have := @BitVec.toInt_eq_msb_cond _ y.bv + simp_all + simp [hx, hy] + simp at hx hy + have := @Int.tdiv_nonneg x.val y.val (by omega) (by omega) + have : -2 ^ (ty.numBits - 1) ≤ 0 := by simp + have : (x.val).tdiv y.val < 2 ^ (ty.numBits - 1) := by + rw [Int.tdiv_eq_ediv] <;> try omega + have := @Int.ediv_le_self x.val y.val (by omega) + omega + + have := bmod_pow_numBits_eq_of_lt ty (Int.tdiv x.val y.val) (by omega) (by omega) + rw [← Int.tdiv_eq_ediv] <;> omega + + . 
-- 0 ≤ x.bv.toInt + -- y.bv.toInt < 0 + rename_i hxIneq hyIneq + have hx := @BitVec.toInt_eq_msb_cond _ x.bv + simp [hxIneq] at hx + have hy := @BitVec.toInt_eq_msb_cond _ y.bv + simp [hyIneq] at hy + have hyNeg : y.val < 0 := by + have := @BitVec.msb_eq_toInt _ y.bv + simp_all + have : -2 ^ (ty.numBits - 1) ≤ Int.tdiv x.val y.val := by + have : Int.tdiv x.val (-y.val) ≤ 2^(ty.numBits - 1) := by + rw [Int.tdiv_eq_ediv] <;> try omega + have := @Int.ediv_le_self x.val (-y.val) (by omega) + simp at * + have := x.hmax + omega + replace this := Int.neg_le_neg this + simp at this + apply this + have hyToNat : 2 ^ ty.numBits - y.bv.toNat = (-y.val).toNat := by + rw [hy] + simp + norm_cast + rw [BitVec.toInt_neg, BitVec.toInt_ofNat] + simp + rw [hyToNat] + have : ((-y.val).toNat : Int) % 2^ty.numBits = -(y.val : Int) := by + apply IScalar.neg_imp_neg_val_toNat_mod_pow_eq_neg_val + simp; omega + rw [this]; clear this + simp + rw [← hx] + have : (- (x.val / y.val)).bmod (2^ty.numBits) = - (x.val / y.val) := by + have : -(x.val / ↑y) < 2 ^ (ty.numBits - 1) := by + have : x.val / (-y.val) < 2 ^ (ty.numBits - 1) := by + have := @Int.ediv_le_self x.val (-y.val) (by omega) + have := x.hmax + omega + simp at this + apply this + have : 0 ≤ -(x.val / ↑y) := by + have : - (x.val / y.val) = x.val / (-y.val) := by simp + rw [this]; clear this + apply Int.ediv_nonneg <;> omega + have := bmod_pow_numBits_eq_of_lt ty (- (x.val / y.val)) (by omega) (by omega) + rw [this] + rw [this]; clear this + simp + have : (x.val / y.val).bmod (2^ty.numBits) = x.val / y.val := by + have : -2 ^ (ty.numBits - 1) ≤ x.val / ↑y := by + apply Int.le_of_neg_le_neg + have : - (x.val / y.val) = x.val / -y.val := by simp + rw [this]; clear this + conv => rhs; simp + have := @Int.ediv_le_self x.val (-y.val) (by omega) + omega + have : x.val / ↑y < 2 ^ (ty.numBits - 1) := by + have : 0 < 2 ^ (ty.numBits - 1) := by simp + have : x.val / y.val ≤ 0 := by apply Int.ediv_nonpos <;> omega + omega + have := bmod_pow_numBits_eq_of_lt ty (x.val / y.val) (by omega) (by omega) + rw [this] + + rw [this]; clear this + + have : x.val.tdiv y.val = - (x.val.tdiv (-y.val)) := by simp + rw [this] + rw [Int.tdiv_eq_ediv] <;> try omega + simp + + . -- x.bv.toInt < 0 + -- 0 ≤ y.bv.toInt + rename_i hxIneq hyIneq + have hx := @BitVec.toInt_eq_msb_cond _ x.bv + simp [hxIneq] at hx + have hy := @BitVec.toInt_eq_msb_cond _ y.bv + simp [hyIneq] at hy + have hxNeg : x.val < 0 := by + have := @BitVec.msb_eq_toInt _ x.bv + simp_all + have hyPos : 0 ≤ y.val := by + have := @BitVec.msb_eq_toInt _ y.bv + simp_all + have : -2 ^ (ty.numBits - 1) ≤ x.val / y.val := by + have := @Int.ediv_le_ediv (-2 ^ (ty.numBits - 1)) x.val y.val (by omega) (by omega) + have := @Int.self_le_ediv x.val y.val (by omega) (by omega) + omega + have hxToNat : 2 ^ ty.numBits - x.bv.toNat = (-x.val).toNat := by + rw [hx] + simp + norm_cast + rw [BitVec.toInt_neg, BitVec.toInt_ofNat] + simp + + rw [hxToNat] + have : ((-x.val).toNat : Int) % 2^ty.numBits = -(x.val : Int) := by + apply IScalar.neg_imp_neg_val_toNat_mod_pow_eq_neg_val + simp; omega + rw [this]; clear this + + /- We have to treat separately the degenerate case where `x` touches the upper bound + and `y = 1` -/ + dcases hxDivY : -x.val / y.val = 2^(ty.numBits - 1) + . 
rw [← hy] + rw [hxDivY] + have ⟨ hx, hy ⟩ : x.val = - 2^(ty.numBits - 1) ∧ y.val = 1 := by + have := @Int.le_div_eq_bound_imp_eq (-x.val) y.val (2^(ty.numBits - 1)) + (by omega) (by omega) (by omega) (by omega) + omega + simp [hx, hy] + + have : Int.bmod (2 ^ (ty.numBits - 1)) (2 ^ ty.numBits) = + - 2 ^ (ty.numBits - 1) := + Int.bmod_pow2_IScalarTy_numBits_minus_one ty + rw [this] + simp + rw [this] + . have : 0 ≤ (-x.val) / y.val := by + apply Int.ediv_nonneg <;> omega + have : -x.val / y.val < 2^(ty.numBits - 1) := by + have : -x.val ≤ 2^(ty.numBits - 1) := by omega + have := @Int.ediv_le_self (-x.val) y.val (by omega) + omega + rw [← hy] + have : (-x.val / y.val).bmod (2 ^ ty.numBits) = + (-x.val / y.val) := by + apply bmod_pow_numBits_eq_of_lt ty _ (by omega) (by omega) + rw [this]; clear this + have : (-(-x.val / ↑y)).bmod (2 ^ ty.numBits) = + (-(-x.val / ↑y)) := by + apply bmod_pow_numBits_eq_of_lt ty _ (by omega) (by omega) + rw [this]; clear this + rw [← Int.tdiv_eq_ediv] <;> try omega + simp + + . -- x.bv.toInt < 0 + -- y.bv.toInt < 0 + rename_i hxIneq hyIneq + have hx := @BitVec.toInt_eq_msb_cond _ x.bv + simp [hxIneq] at hx + have hy := @BitVec.toInt_eq_msb_cond _ y.bv + simp [hyIneq] at hy + have hxNeg : x.val < 0 := by + have := @BitVec.msb_eq_toInt _ x.bv + simp_all + have hyNeg : y.val < 0 := by + have := @BitVec.msb_eq_toInt _ y.bv + simp_all + have hxToNat : 2 ^ ty.numBits - x.bv.toNat = (-x.val).toNat := by + rw [hx] + simp + norm_cast + have hyToNat : 2 ^ ty.numBits - y.bv.toNat = (-y.val).toNat := by + rw [hy] + simp + norm_cast + rw [hxToNat, hyToNat] + + have : (-x.val).toNat % 2^ty.numBits = (-x.val).toNat := by + apply Nat.mod_eq_of_lt + omega + rw [this] + have : (-y.val).toNat % 2^ty.numBits = (-y.val).toNat := by + apply Nat.mod_eq_of_lt + omega + rw [this] + + rw [BitVec.toInt_ofNat] + + /- We have to treat separately the degenerate case where `x` touches the lower bound + and `y = -1`, because then `x / y` actually overflows -/ + have hxyInBouds : (-x.val) / (-y.val) ≠ 2^(ty.numBits - 1) := by + -- We do the proof by contradiction + intro hEq + have hContra : x.val = - 2^(ty.numBits - 1) ∧ y.val = -1 := by + have := @Int.le_div_eq_bound_imp_eq (-x.val) (-y.val) (2^(ty.numBits - 1)) + (by omega) (by omega) (by omega) (by omega) + omega + simp [hContra, min] at hNoOverflow + + have : -(2 ^ (ty.numBits - 1) : Int) ≤ ↑((-x.val).toNat / (-y.val).toNat) := by + have := @Int.ediv_nonneg (-x.val).toNat (-y.val).toNat (by omega) (by omega) + have : -(2 ^ (ty.numBits - 1) : Int) ≤ 0 := by simp + omega + + have : ((-x.val).toNat / (-y.val).toNat) < (2 ^ (ty.numBits - 1) : Int) := by + -- First prove a ≤ bound + have hIneq : ((-x.val).toNat / (-y.val).toNat) ≤ (2 ^ (ty.numBits - 1) : Int) := by + have := @Int.ediv_le_self (-x.val).toNat (-y.val).toNat (by omega) + omega + -- Then use the hypothesis about the fact that we're not equal to the bound + zify at hIneq + have : (-x.val).toNat = -x.val := by omega + rw [this] at hIneq; rw [this] + have : (-y.val).toNat = -y.val := by omega + rw [this] at hIneq; rw [this] + omega + have := bmod_pow_numBits_eq_of_lt ty ((-x.val).toNat / (-y.val).toNat : Nat) (by omega) (by omega) + rw [this] + + zify; simp + + have : (-x.val) ⊔ 0 = -x.val := by omega + simp only [this]; clear this + have : -↑y ⊔ 0 = -y.val := by omega + simp only [this]; clear this + + rw [← Int.tdiv_eq_ediv] <;> try omega + simp + +theorem U8.div_bv_spec (x : U8) {y : U8} (hnz : ↑y ≠ (0 : Nat)) : + ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y ∧ z.bv = x.bv / y.bv := + 
UScalar.div_bv_spec x hnz + +theorem U16.div_bv_spec (x : U16) {y : U16} (hnz : ↑y ≠ (0 : Nat)) : + ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y ∧ z.bv = x.bv / y.bv := + UScalar.div_bv_spec x hnz + +theorem U32.div_bv_spec (x : U32) {y : U32} (hnz : ↑y ≠ (0 : Nat)) : + ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y ∧ z.bv = x.bv / y.bv := + UScalar.div_bv_spec x hnz + +theorem U64.div_bv_spec (x : U64) {y : U64} (hnz : ↑y ≠ (0 : Nat)) : + ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y ∧ z.bv = x.bv / y.bv := + UScalar.div_bv_spec x hnz + +theorem U128.div_bv_spec (x : U128) {y : U128} (hnz : ↑y ≠ (0 : Nat)) : + ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y ∧ z.bv = x.bv / y.bv := + UScalar.div_bv_spec x hnz + +theorem Usize.div_bv_spec (x : Usize) {y : Usize} (hnz : ↑y ≠ (0 : Nat)) : + ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y ∧ z.bv = x.bv / y.bv := + UScalar.div_bv_spec x hnz + +theorem I8.div_bv_spec {x y : I8} (hnz : ↑y ≠ (0 : Int)) + (hNoOverflow : ¬ (x.val = I8.min ∧ y.val = -1)) : + ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y ∧ z.bv = BitVec.sdiv x.bv y.bv := + IScalar.div_bv_spec hnz (by scalar_tac) + +theorem I16.div_bv_spec {x y : I16} (hnz : ↑y ≠ (0 : Int)) + (hNoOverflow : ¬ (x.val = I16.min ∧ y.val = -1)) : + ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y ∧ z.bv = BitVec.sdiv x.bv y.bv := + IScalar.div_bv_spec hnz (by scalar_tac) + +theorem I32.div_bv_spec {x y : I32} (hnz : ↑y ≠ (0 : Int)) + (hNoOverflow : ¬ (x.val = I32.min ∧ y.val = -1)) : + ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y ∧ z.bv = BitVec.sdiv x.bv y.bv := + IScalar.div_bv_spec hnz (by scalar_tac) + +theorem I64.div_bv_spec {x y : I64} (hnz : ↑y ≠ (0 : Int)) + (hNoOverflow : ¬ (x.val = I64.min ∧ y.val = -1)) : + ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y ∧ z.bv = BitVec.sdiv x.bv y.bv := + IScalar.div_bv_spec hnz (by scalar_tac) + +theorem I128.div_bv_spec {x y : I128} (hnz : ↑y ≠ (0 : Int)) + (hNoOverflow : ¬ (x.val = I128.min ∧ y.val = -1)) : + ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y ∧ z.bv = BitVec.sdiv x.bv y.bv := + IScalar.div_bv_spec hnz (by scalar_tac) + +theorem Isize.div_bv_spec {x y : Isize} (hnz : ↑y ≠ (0 : Int)) + (hNoOverflow : ¬ (x.val = Isize.min ∧ y.val = -1)) : + ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y ∧ z.bv = BitVec.sdiv x.bv y.bv := + IScalar.div_bv_spec hnz (by scalar_tac) + +/-! 
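+The signed specifications above phrase division with `Int.tdiv`, which truncates
+towards zero like Rust's `/`, whereas Lean's `/` on `Int` is Euclidean division.
+The two examples below are illustrative sketches only: the first assumes `decide`
+can evaluate the concrete goal, the second merely re-applies `U32.div_bv_spec`.
+-/
+
+-- Illustrative: truncated division (Rust `/`) vs. Euclidean division (Lean `/`).
+example : Int.tdiv (-7) 2 = -3 ∧ (-7 : Int) / 2 = -4 := by decide
+
+-- Illustrative sketch: re-using the bit-vector spec stated above.
+example (x y : U32) (hnz : ↑y ≠ (0 : Nat)) :
+    ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y ∧ z.bv = x.bv / y.bv :=
+  U32.div_bv_spec x hnz
+
+/-!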
+Theorems with a specification which only uses integers
+-/
+
+/-- Generic theorem - shouldn't be used much -/
+theorem UScalar.div_spec {ty} (x : UScalar ty) {y : UScalar ty}
+  (hzero : y.val ≠ 0) :
+  ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y := by
+  have ⟨ z, hz ⟩ := UScalar.div_bv_spec x hzero
+  simp [hz]
+
+/-- Generic theorem - shouldn't be used much -/
+theorem IScalar.div_spec {ty} {x y : IScalar ty}
+  (hzero : y.val ≠ 0)
+  (hNoOverflow : ¬ (x.val = IScalar.min ty ∧ y.val = -1)) :
+  ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y := by
+  have ⟨ z, hz ⟩ := IScalar.div_bv_spec hzero hNoOverflow
+  simp [hz]
+
+@[progress] theorem U8.div_spec (x : U8) {y : U8} (hnz : ↑y ≠ (0 : Nat)) :
+  ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y :=
+  UScalar.div_spec x hnz
+
+@[progress] theorem U16.div_spec (x : U16) {y : U16} (hnz : ↑y ≠ (0 : Nat)) :
+  ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y :=
+  UScalar.div_spec x hnz
+
+@[progress] theorem U32.div_spec (x : U32) {y : U32} (hnz : ↑y ≠ (0 : Nat)) :
+  ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y :=
+  UScalar.div_spec x hnz
+
+@[progress] theorem U64.div_spec (x : U64) {y : U64} (hnz : ↑y ≠ (0 : Nat)) :
+  ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y :=
+  UScalar.div_spec x hnz
+
+@[progress] theorem U128.div_spec (x : U128) {y : U128} (hnz : ↑y ≠ (0 : Nat)) :
+  ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y :=
+  UScalar.div_spec x hnz
+
+@[progress] theorem Usize.div_spec (x : Usize) {y : Usize} (hnz : ↑y ≠ (0 : Nat)) :
+  ∃ z, x / y = ok z ∧ (↑z : Nat) = ↑x / ↑y :=
+  UScalar.div_spec x hnz
+
+@[progress] theorem I8.div_spec {x y : I8} (hnz : ↑y ≠ (0 : Int))
+  (hNoOverflow : ¬ (x.val = I8.min ∧ y.val = -1)) :
+  ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y :=
+  IScalar.div_spec hnz (by scalar_tac)
+
+@[progress] theorem I16.div_spec {x y : I16} (hnz : ↑y ≠ (0 : Int))
+  (hNoOverflow : ¬ (x.val = I16.min ∧ y.val = -1)) :
+  ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y :=
+  IScalar.div_spec hnz (by scalar_tac)
+
+@[progress] theorem I32.div_spec {x y : I32} (hnz : ↑y ≠ (0 : Int))
+  (hNoOverflow : ¬ (x.val = I32.min ∧ y.val = -1)) :
+  ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y :=
+  IScalar.div_spec hnz (by scalar_tac)
+
+@[progress] theorem I64.div_spec {x y : I64} (hnz : ↑y ≠ (0 : Int))
+  (hNoOverflow : ¬ (x.val = I64.min ∧ y.val = -1)) :
+  ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y :=
+  IScalar.div_spec hnz (by scalar_tac)
+
+@[progress] theorem I128.div_spec {x y : I128} (hnz : ↑y ≠ (0 : Int))
+  (hNoOverflow : ¬ (x.val = I128.min ∧ y.val = -1)) :
+  ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y :=
+  IScalar.div_spec hnz (by scalar_tac)
+
+@[progress] theorem Isize.div_spec {x y : Isize} (hnz : ↑y ≠ (0 : Int))
+  (hNoOverflow : ¬ (x.val = Isize.min ∧ y.val = -1)) :
+  ∃ z, x / y = ok z ∧ (↑z : Int) = Int.tdiv ↑x ↑y :=
+  IScalar.div_spec hnz (by scalar_tac)
+
+/-!
+### Remainder
+-/
+
+/-!
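+The remainder specifications in this section use `Int.tmod`, the truncated
+remainder matching Rust's `%` on signed integers; unlike Lean's Euclidean `%`,
+it can be negative. The example below is an illustrative sketch and assumes
+`decide` can evaluate the concrete goal.
+-/
+
+-- Illustrative: truncated remainder (Rust `%`) vs. Euclidean remainder (Lean `%`).
+example : Int.tmod (-7) 3 = -1 ∧ (-7 : Int) % 3 = 2 := by decide
+
+/-!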
+Theorems with a specification which uses integers and bit-vectors +-/ + +/-- Generic theorem - shouldn't be used much -/ +theorem UScalar.rem_bv_spec {ty} (x : UScalar ty) {y : UScalar ty} (hzero : y.val ≠ 0) : + ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y ∧ z.bv = x.bv % y.bv := by + conv => congr; ext; lhs; simp [HMod.hMod] + simp [hzero, rem, tryMk, tryMkOpt, ofOption, hmax, ofNatCore] + simp only [val] + simp + +/-- Generic theorem - shouldn't be used much -/ +theorem IScalar.rem_bv_spec {ty} (x : IScalar ty) {y : IScalar ty} (hzero : y.val ≠ 0) : + ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y ∧ z.bv = BitVec.srem x.bv y.bv := by + conv => congr; ext; lhs; simp [HMod.hMod] + simp [hzero, rem, tryMk, tryMkOpt, ofOption, hmax, ofIntCore] + simp only [val] + simp + + simp [BitVec.srem_eq] + have pow2Ineq : (2^(ty.numBits - 1) : Int) < 2^ty.numBits := by + have := ty.numBits_nonzero + have : (0 : Int) < 2^(ty.numBits - 1) := by simp + have : ty.numBits = ty.numBits - 1 + 1 := by omega + conv => rhs; rw [this] + rw [Int.pow_succ'] + omega + have hxBounds := x.hBounds + have hyBounds := y.hBounds + have := ty.numBits_nonzero + split + + . -- 0 ≤ x + -- 0 ≤ y + rename_i hxMsb hyMsb + have hx := @BitVec.toInt_eq_msb_cond _ x.bv + simp [hxMsb] at hx + have hy := @BitVec.toInt_eq_msb_cond _ y.bv + simp [hyMsb] at hy + rw [Int.tmod_eq_emod] <;> try omega + simp only [BitVec.toInt_eq_toNat_bmod, BitVec.toNat_umod] + have : ((x.bv.toNat % y.bv.toNat : Nat) : Int) < 2 ^ (ty.numBits - 1) := by + have := @Nat.mod_lt x.bv.toNat y.bv.toNat (by omega) + zify at this + omega + have : ((x.bv.toNat % y.bv.toNat : Nat) : Int).bmod (2 ^ ty.numBits) = x.bv.toNat % y.bv.toNat := by + apply bmod_pow_numBits_eq_of_lt ty _ (by omega) (by omega) + rw [this]; clear this + simp only [hx, hy] + + . -- 0 ≤ x + -- y < 0 + rename_i hxMsb hyMsb + have hx := @BitVec.toInt_eq_msb_cond _ x.bv + simp [hxMsb] at hx + have hy := @BitVec.toInt_eq_msb_cond _ y.bv + simp [hyMsb] at hy + + have hxNeg : 0 ≤ x.val := by + have := @BitVec.msb_eq_toInt _ x.bv + simp_all + have hyNeg : y.val < 0 := by + have := @BitVec.msb_eq_toInt _ y.bv + simp_all + + have : x.val.tmod y.val = x.val.tmod (-y.val) := by simp + rw [this]; clear this + + rw [Int.tmod_eq_emod] <;> try omega + simp only [BitVec.toInt_eq_toNat_bmod, BitVec.toNat_umod] + + have : ((x.bv.toNat % (-y.bv).toNat : Nat) : Int) < 2 ^ (ty.numBits - 1) := by + have := @Nat.mod_le x.bv.toNat (-y.bv).toNat + omega + have : ((x.bv.toNat % (-y.bv).toNat : Nat) : Int).bmod (2 ^ ty.numBits) = x.bv.toNat % (-y.bv).toNat := by + apply bmod_pow_numBits_eq_of_lt ty _ (by omega) (by omega) + rw [this]; clear this + simp only [hx] + + have := IScalar.neg_imp_toNat_neg_eq_neg_toInt y (by omega) + simp only [this, val] + + . 
-- x < 0 + -- 0 ≤ y + rename_i hxMsb hyMsb + have hx := @BitVec.toInt_eq_msb_cond _ x.bv + simp [hxMsb] at hx + have hy := @BitVec.toInt_eq_msb_cond _ y.bv + simp [hyMsb] at hy + + have hxNeg : x.val < 0 := by + have := @BitVec.msb_eq_toInt _ x.bv + simp_all + have hyNeg : 0 ≤ y.val := by + have := @BitVec.msb_eq_toInt _ y.bv + simp_all + + have hModEq : ((-x.bv) % y.bv).toInt = ((-x.bv).toNat % y.bv.toNat : Nat) := by + simp only [BitVec.toInt_eq_toNat_bmod, BitVec.toNat_umod] + + have : ((-x.bv).toNat % y.bv.toNat : Nat) < (2 ^ (ty.numBits - 1) : Int) := by + have := @Nat.mod_lt (-x.bv).toNat y.bv.toNat (by omega) + simp only [val] at * + -- TODO: this is annoying + have : (2 ^ (ty.numBits - 1) : Nat) = (2 ^ (ty.numBits - 1) : Int) := by simp + omega + + have := @bmod_pow_numBits_eq_of_lt ty ((-x.bv).toNat % y.bv.toNat : Nat) + (by omega) (by omega) + rw [this] + + have : 0 ≤ (-x.bv % y.bv).toInt := by omega + + have := BitVec.toInt_neg_of_pos_eq_neg (-x.bv % y.bv) (by omega) (by omega) + rw [this]; clear this + + have : (-x.bv % y.bv).toInt = (-x.bv).toNat % y.bv.toNat := by + rw [hModEq]; simp + rw [this]; clear this + + have : x.val.tmod y.val = - (-x.val).tmod y.val := by simp + rw [this]; clear this + + have hx := IScalar.neg_imp_toNat_neg_eq_neg_toInt x (by omega) + simp only [hx, ← hy] + + rw [Int.tmod_eq_emod] <;> try omega + + simp only [val] + + . -- x < 0 + -- y < 0 + + rename_i hxMsb hyMsb + have hx := @BitVec.toInt_eq_msb_cond _ x.bv + simp [hxMsb] at hx + have hy := @BitVec.toInt_eq_msb_cond _ y.bv + simp [hyMsb] at hy + + have hxNeg : x.val < 0 := by + have := @BitVec.msb_eq_toInt _ x.bv + simp_all + have hyNeg : y.val < 0 := by + have := @BitVec.msb_eq_toInt _ y.bv + simp_all + + have : (x.val).tmod (y.val) = -(-x.val).tmod (-y.val) := by simp + rw [this]; clear this + + rw [Int.tmod_eq_emod] <;> try omega + + have hx := IScalar.neg_imp_toNat_neg_eq_neg_toInt x (by omega) + have hy := IScalar.neg_imp_toNat_neg_eq_neg_toInt y (by omega) + + have : 0 ≤ -x.bv.toInt % -y.bv.toInt := by + scalar_tac +nonLin + + have : -2 ^ (ty.numBits - 1) ≤ -x.bv.toInt % -y.bv.toInt := by omega + + have hxmyToInt : (-x.bv % -y.bv).toInt = (-x.bv.toInt) % (-y.bv.toInt) := by + conv => lhs; simp only [BitVec.toInt_eq_toNat_bmod, BitVec.toNat_umod] + push_cast + simp only [hx, hy] + apply bmod_pow_numBits_eq_of_lt + . omega + . 
simp only [val] at *
+      have := @Int.emod_lt_of_pos (-x.bv.toInt) (-y.bv.toInt) (by omega)
+      omega
+
+    have : 0 ≤ (-x.bv % -y.bv).toInt := by
+      simp only [hxmyToInt]
+      omega
+
+    have := BitVec.toInt_neg_of_pos_eq_neg (-x.bv % -y.bv) (by omega) (by omega)
+    rw [this]; clear this
+
+    simp only [hxmyToInt]
+    simp
+
+theorem U8.rem_bv_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y ∧ z.bv = x.bv % y.bv :=
+  UScalar.rem_bv_spec x hnz
+
+theorem U16.rem_bv_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y ∧ z.bv = x.bv % y.bv :=
+  UScalar.rem_bv_spec x hnz
+
+theorem U32.rem_bv_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y ∧ z.bv = x.bv % y.bv :=
+  UScalar.rem_bv_spec x hnz
+
+theorem U64.rem_bv_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y ∧ z.bv = x.bv % y.bv :=
+  UScalar.rem_bv_spec x hnz
+
+theorem U128.rem_bv_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y ∧ z.bv = x.bv % y.bv :=
+  UScalar.rem_bv_spec x hnz
+
+theorem Usize.rem_bv_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y ∧ z.bv = x.bv % y.bv :=
+  UScalar.rem_bv_spec x hnz
+
+theorem I8.rem_bv_spec (x : I8) {y : I8} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y ∧ z.bv = BitVec.srem x.bv y.bv :=
+  IScalar.rem_bv_spec x hnz
+
+theorem I16.rem_bv_spec (x : I16) {y : I16} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y ∧ z.bv = BitVec.srem x.bv y.bv :=
+  IScalar.rem_bv_spec x hnz
+
+theorem I32.rem_bv_spec (x : I32) {y : I32} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y ∧ z.bv = BitVec.srem x.bv y.bv :=
+  IScalar.rem_bv_spec x hnz
+
+theorem I64.rem_bv_spec (x : I64) {y : I64} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y ∧ z.bv = BitVec.srem x.bv y.bv :=
+  IScalar.rem_bv_spec x hnz
+
+theorem I128.rem_bv_spec (x : I128) {y : I128} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y ∧ z.bv = BitVec.srem x.bv y.bv :=
+  IScalar.rem_bv_spec x hnz
+
+theorem Isize.rem_bv_spec (x : Isize) {y : Isize} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y ∧ z.bv = BitVec.srem x.bv y.bv :=
+  IScalar.rem_bv_spec x hnz
+
+/-!
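+An illustrative sketch (it proves nothing new): re-applying the signed
+bit-vector specification stated above for `I32`.
+-/
+
+example (x y : I32) (hnz : y.val ≠ 0) :
+    ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y ∧ z.bv = BitVec.srem x.bv y.bv :=
+  I32.rem_bv_spec x hnz
+
+/-!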
+Theorems with a specification which only uses integers
+-/
+
+/-- Generic theorem - shouldn't be used much -/
+theorem UScalar.rem_spec {ty} (x : UScalar ty) {y : UScalar ty} (hzero : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y := by
+  have ⟨ z, hz ⟩ := rem_bv_spec x hzero
+  simp [hz]
+
+/-- Generic theorem - shouldn't be used much -/
+theorem IScalar.rem_spec {ty} (x : IScalar ty) {y : IScalar ty} (hzero : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y := by
+  have ⟨ z, hz ⟩ := rem_bv_spec x hzero
+  simp [hz]
+
+@[progress] theorem U8.rem_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y :=
+  UScalar.rem_spec x hnz
+
+@[progress] theorem U16.rem_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y :=
+  UScalar.rem_spec x hnz
+
+@[progress] theorem U32.rem_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y :=
+  UScalar.rem_spec x hnz
+
+@[progress] theorem U64.rem_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y :=
+  UScalar.rem_spec x hnz
+
+@[progress] theorem U128.rem_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y :=
+  UScalar.rem_spec x hnz
+
+@[progress] theorem Usize.rem_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Nat) = ↑x % ↑y :=
+  UScalar.rem_spec x hnz
+
+@[progress] theorem I8.rem_spec (x : I8) {y : I8} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y :=
+  IScalar.rem_spec x hnz
+
+@[progress] theorem I16.rem_spec (x : I16) {y : I16} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y :=
+  IScalar.rem_spec x hnz
+
+@[progress] theorem I32.rem_spec (x : I32) {y : I32} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y :=
+  IScalar.rem_spec x hnz
+
+@[progress] theorem I64.rem_spec (x : I64) {y : I64} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y :=
+  IScalar.rem_spec x hnz
+
+@[progress] theorem I128.rem_spec (x : I128) {y : I128} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y :=
+  IScalar.rem_spec x hnz
+
+@[progress] theorem Isize.rem_spec (x : Isize) {y : Isize} (hnz : y.val ≠ 0) :
+  ∃ z, x % y = ok z ∧ (↑z : Int) = Int.tmod ↑x ↑y :=
+  IScalar.rem_spec x hnz
+
+/-!
+## Bit shifts +-/ + +theorem UScalar.ShiftRight_val_eq {ty0 ty1} (x : UScalar ty0) (y : UScalar ty1) + (hy : y.val < ty0.numBits) : + ∃ z, x >>> y = ok z ∧ + z.val = x.val >>> y.val + := by + simp only [HShiftRight.hShiftRight, shiftRight_UScalar, shiftRight, hy, reduceIte] + simp only [BitVec.ushiftRight_eq, ok.injEq, _root_.exists_eq_left', val] + simp [HShiftRight.hShiftRight, BitVec.ushiftRight] + +theorem UScalar.ShiftRight_bv_eq {ty0 ty1} (x : UScalar ty0) (y : UScalar ty1) + (hy : y.val < ty0.numBits) : + ∃ z, x >>> y = ok z ∧ z.bv = x.bv >>> y.val + := by + simp only [HShiftRight.hShiftRight, shiftRight_UScalar, shiftRight, hy, reduceIte] + simp only [BitVec.ushiftRight_eq, ok.injEq, _root_.exists_eq_left', val] + +@[progress] theorem U8.ShiftRight_bv_spec (x : U8) (y : UScalar ty1) (hy : y.val < 8) : + ∃ (z : U8), x >>> y = ok z ∧ z.bv = x.bv >>> y.val + := by apply UScalar.ShiftRight_bv_eq; simp [*] + +@[progress] theorem U16.ShiftRight_bv_spec (x : U16) (y : UScalar ty1) (hy : y.val < 16) : + ∃ (z : U16), x >>> y = ok z ∧ z.bv = x.bv >>> y.val + := by apply UScalar.ShiftRight_bv_eq; simp [*] + +@[progress] theorem U32.ShiftRight_bv_spec (x : U32) (y : UScalar ty1) (hy : y.val < 32) : + ∃ (z : U32), x >>> y = ok z ∧ z.bv = x.bv >>> y.val + := by apply UScalar.ShiftRight_bv_eq; simp [*] + +@[progress] theorem U64.ShiftRight_bv_spec (x : U64) (y : UScalar ty1) (hy : y.val < 64) : + ∃ (z : U64), x >>> y = ok z ∧ z.bv = x.bv >>> y.val + := by apply UScalar.ShiftRight_bv_eq; simp [*] + +@[progress] theorem U128.ShiftRight_bv_spec (x : U128) (y : UScalar ty1) (hy : y.val < 128) : + ∃ (z : U128), x >>> y = ok z ∧ z.bv = x.bv >>> y.val + := by apply UScalar.ShiftRight_bv_eq; simp [*] + +@[progress] theorem Usize.ShiftRight_bv_spec (x : Usize) (y : UScalar ty1) (hy : y.val < UScalarTy.Usize.numBits) : + ∃ (z : Usize), x >>> y = ok z ∧ z.bv = x.bv >>> y.val + := by apply UScalar.ShiftRight_bv_eq; simp only [*] + +theorem UScalar.ShiftLeft_val_eq {ty0 ty1} (x : UScalar ty0) (y : UScalar ty1) + (hy : y.val < ty0.numBits) : + ∃ z, x <<< y = ok z ∧ + z.val = (x.val <<< y.val) % 2^ty0.numBits + := by + simp only [HShiftLeft.hShiftLeft, shiftLeft_UScalar, shiftLeft, hy, reduceIte] + simp only [BitVec.shiftLeft_eq, ok.injEq, _root_.exists_eq_left', val] + simp [ShiftLeft.shiftLeft] + +theorem UScalar.ShiftLeft_bv_eq {ty0 ty1} (x : UScalar ty0) (y : UScalar ty1) + (hy : y.val < ty0.numBits) : + ∃ z, x <<< y = ok z ∧ z.bv = x.bv <<< y.val + := by + simp only [HShiftLeft.hShiftLeft, shiftLeft_UScalar, shiftLeft, hy, reduceIte] + simp only [BitVec.shiftLeft_eq, ok.injEq, _root_.exists_eq_left', val] + +@[progress] theorem U8.ShiftLeft_bv_spec (x : U8) (y : UScalar ty1) (hy : y.val < 8) : + ∃ (z : U8), x <<< y = ok z ∧ z.bv = x.bv <<< y.val + := by apply UScalar.ShiftLeft_bv_eq; simp [*] + +@[progress] theorem U16.ShiftLeft_bv_spec (x : U16) (y : UScalar ty1) (hy : y.val < 16) : + ∃ (z : U16), x <<< y = ok z ∧ z.bv = x.bv <<< y.val + := by apply UScalar.ShiftLeft_bv_eq; simp [*] + +@[progress] theorem U32.ShiftLeft_bv_spec (x : U32) (y : UScalar ty1) (hy : y.val < 32) : + ∃ (z : U32), x <<< y = ok z ∧ z.bv = x.bv <<< y.val + := by apply UScalar.ShiftLeft_bv_eq; simp [*] + +@[progress] theorem U64.ShiftLeft_bv_spec (x : U64) (y : UScalar ty1) (hy : y.val < 64) : + ∃ (z : U64), x <<< y = ok z ∧ z.bv = x.bv <<< y.val + := by apply UScalar.ShiftLeft_bv_eq; simp [*] + +@[progress] theorem U128.ShiftLeft_bv_spec (x : U128) (y : UScalar ty1) (hy : y.val < 128) : + ∃ (z : U128), x <<< y = ok z ∧ z.bv = x.bv 
<<< y.val + := by apply UScalar.ShiftLeft_bv_eq; simp [*] + +@[progress] theorem Usize.ShiftLeft_bv_spec (x : Usize) (y : UScalar ty1) (hy : y.val < UScalarTy.Usize.numBits) : + ∃ (z : Usize), x <<< y = ok z ∧ z.bv = x.bv <<< y.val + := by apply UScalar.ShiftLeft_bv_eq; simp only [*] + +/-! +## Casts +-/ + +@[simp, progress_pure cast_fromBool ty b] +theorem UScalar.cast_fromBool_val_eq ty (b : Bool) : (UScalar.cast_fromBool ty b).val = b.toNat := by + simp [cast_fromBool] + split <;> simp only [val, *] <;> simp + have := ty.numBits_nonzero + omega + +@[simp, progress_pure cast_fromBool ty b] +theorem IScalar.cast_fromBool_val_eq ty (b : Bool) :(IScalar.cast_fromBool ty b).val = b.toInt := by + simp [cast_fromBool] + split <;> simp only [val, *] <;> simp + dcases ty <;> simp + have := System.Platform.numBits_eq + cases this <;> + rename_i h <;> + rw [h] <;> simp + +@[scalar_tac UScalar.cast_fromBool ty b] +theorem UScalar.cast_fromBool_bound_eq ty (b : Bool) : (UScalar.cast_fromBool ty b).val ≤ 1 := by + simp [cast_fromBool] + split <;> simp only [val] <;> simp + have := @Nat.mod_eq_of_lt 1 (2^ty.numBits) (by simp [ty.numBits_nonzero]) + rw [this] -@[pspec] theorem Usize.rem_spec (x : Usize) {y : Usize} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x % y = ok z ∧ (↑z : Int) = ↑x % ↑y := by - apply Scalar.rem_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U8.rem_spec (x : U8) {y : U8} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x % y = ok z ∧ (↑z : Int) = ↑x % ↑y := by - apply Scalar.rem_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U16.rem_spec (x : U16) {y : U16} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x % y = ok z ∧ (↑z : Int) = ↑x % ↑y := by - apply Scalar.rem_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U32.rem_spec (x : U32) {y : U32} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x % y = ok z ∧ (↑z : Int) = ↑x % ↑y := by - apply Scalar.rem_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U64.rem_spec (x : U64) {y : U64} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x % y = ok z ∧ (↑z : Int) = ↑x % ↑y := by - apply Scalar.rem_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem U128.rem_spec (x : U128) {y : U128} (hnz : ↑y ≠ (0 : Int)) : - ∃ z, x % y = ok z ∧ (↑z : Int) = ↑x % ↑y := by - apply Scalar.rem_unsigned_spec <;> simp [ScalarTy.isSigned, *] - -@[pspec] theorem I8.rem_spec (x : I8) {y : I8} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I8.min ≤ scalar_rem ↑x ↑y) - (hmax : scalar_rem ↑x ↑y ≤ I8.max): - ∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y := - Scalar.rem_spec hnz hmin hmax - -@[pspec] theorem I16.rem_spec (x : I16) {y : I16} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I16.min ≤ scalar_rem ↑x ↑y) - (hmax : scalar_rem ↑x ↑y ≤ I16.max): - ∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y := - Scalar.rem_spec hnz hmin hmax - -@[pspec] theorem I32.rem_spec (x : I32) {y : I32} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I32.min ≤ scalar_rem ↑x ↑y) - (hmax : scalar_rem ↑x ↑y ≤ I32.max): - ∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y := - Scalar.rem_spec hnz hmin hmax - -@[pspec] theorem I64.rem_spec (x : I64) {y : I64} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I64.min ≤ scalar_rem ↑x ↑y) - (hmax : scalar_rem ↑x ↑y ≤ I64.max): - ∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y := - Scalar.rem_spec hnz hmin hmax - -@[pspec] theorem I128.rem_spec (x : I128) {y : I128} - (hnz : ↑y ≠ (0 : Int)) - (hmin : I128.min ≤ scalar_rem ↑x ↑y) - (hmax : scalar_rem ↑x ↑y ≤ I128.max): - ∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y := - Scalar.rem_spec hnz hmin hmax - -theorem 
core.num.checked_rem_spec {ty} {x y : Scalar ty} : - match core.num.checked_rem x y with - | some z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y) ∧ ↑z = (scalar_rem ↑x ↑y : Int) - | none => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y)) := by - have h := Scalar.tryMk_eq ty (scalar_rem ↑x ↑y) - simp only [checked_rem, Option.ofResult] - cases heq0: (y.val = 0 : Bool) <;> - cases heq1: x % y <;> simp_all <;> simp only [HMod.hMod, Scalar.rem, Mod.mod] at heq1 - <;> simp_all - --- Leading zeros -def core.num.Usize.leading_zeros (x : Usize) : U32 := sorry -def core.num.U8.leading_zeros (x : U8) : U32 := sorry -def core.num.U16.leading_zeros (x : U16) : U32 := sorry -def core.num.U32.leading_zeros (x : U32) : U32 := sorry -def core.num.U64.leading_zeros (x : U64) : U32 := sorry -def core.num.U128.leading_zeros (x : U128) : U32 := sorry - -def core.num.Isize.leading_zeros (x : Isize) : U32 := sorry -def core.num.I8.leading_zeros (x : I8) : U32 := sorry -def core.num.I16.leading_zeros (x : I16) : U32 := sorry -def core.num.I32.leading_zeros (x : I32) : U32 := sorry -def core.num.I64.leading_zeros (x : I64) : U32 := sorry -def core.num.I128.leading_zeros (x : I128) : U32 := sorry +@[simp] +theorem UScalar.cast_fromBool_bv_eq ty (b : Bool) : (UScalar.cast_fromBool ty b).bv = (BitVec.ofBool b).zeroExtend _ := by + simp [cast_fromBool, BitVec.setWidth_eq] + dcases b <;> simp + apply @BitVec.toNat_injective ty.numBits + simp + +@[simp] +theorem IScalar.cast_fromBool_bv_eq ty (b : Bool) :(IScalar.cast_fromBool ty b).bv = (BitVec.ofBool b).zeroExtend _ := by + simp [cast_fromBool, BitVec.setWidth_eq] + dcases b <;> simp + apply @BitVec.toNat_injective ty.numBits + simp + +@[scalar_tac IScalar.cast_fromBool ty b] +theorem IScalar.cast_fromBool_bound_eq ty (b : Bool) : + 0 ≤ (IScalar.cast_fromBool ty b).val ∧ (IScalar.cast_fromBool ty b).val ≤ 1 := by + simp [cast_fromBool] + split <;> simp only [val] + . have : (1#ty.numBits).toInt = 1 := by + simp [BitVec.toInt] + dcases ty <;> simp + dcases System.Platform.numBits_eq <;> simp [*] + simp [this] + . 
simp + +theorem UScalar.cast_val_eq {src_ty : UScalarTy} (tgt_ty : UScalarTy) (x : UScalar src_ty) : + (cast tgt_ty x).val = x.val % 2^(tgt_ty.numBits) := by + simp only [cast, val] + simp + +@[simp, scalar_tac UScalar.cast .U16 x] +theorem U8.cast_U16_val_eq (x : U8) : (UScalar.cast .U16 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp, scalar_tac UScalar.cast .U32 x] +theorem U8.cast_U32_val_eq (x : U8) : (UScalar.cast .U32 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp, scalar_tac UScalar.cast .U64 x] +theorem U8.cast_U64_val_eq (x : U8) : (UScalar.cast .U64 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp, scalar_tac UScalar.cast .U128 x] +theorem U8.cast_U128_val_eq (x : U8) : (UScalar.cast .U128 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp, scalar_tac UScalar.cast .Usize x] +theorem U8.cast_Usize_val_eq (x : U8) : (UScalar.cast .Usize x).val = x.val := by + simp [UScalar.cast_val_eq]; dcases System.Platform.numBits_eq <;> simp [*] <;> scalar_tac + +@[simp, scalar_tac UScalar.cast .U32 x] +theorem U16.cast_U32_val_eq (x : U16) : (UScalar.cast .U32 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp, scalar_tac UScalar.cast .U64 x] +theorem U16.cast_U64_val_eq (x : U16) : (UScalar.cast .U64 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp, scalar_tac UScalar.cast .U128 x] +theorem U16.cast_U128_val_eq (x : U16) : (UScalar.cast .U128 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp, scalar_tac UScalar.cast .Usize x] +theorem U16.cast_Usize_val_eq (x : U16) : (UScalar.cast .Usize x).val = x.val := by + simp [UScalar.cast_val_eq]; dcases System.Platform.numBits_eq <;> simp [*] <;> scalar_tac + +@[simp, scalar_tac UScalar.cast .U64 x] +theorem U32.cast_U64_val_eq (x : U32) : (UScalar.cast .U64 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp, scalar_tac UScalar.cast .U128 x] +theorem U32.cast_U128_val_eq (x : U32) : (UScalar.cast .U128 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp, scalar_tac UScalar.cast .Usize x] +theorem U32.cast_Usize_val_eq (x : U32) : (UScalar.cast .Usize x).val = x.val := by + simp [UScalar.cast_val_eq]; dcases System.Platform.numBits_eq <;> simp [*]; scalar_tac + +@[simp, scalar_tac UScalar.cast .U128 x] +theorem U64.cast_U128_val_eq (x : U64) : (UScalar.cast .U128 x).val = x.val := by + simp [UScalar.cast_val_eq]; scalar_tac + +@[simp] +theorem UScalar.cast_val_mod_pow_greater_numBits_eq {src_ty : UScalarTy} (tgt_ty : UScalarTy) (x : UScalar src_ty) (h : src_ty.numBits ≤ tgt_ty.numBits) : + (cast tgt_ty x).val = x.val := by + simp [UScalar.cast_val_eq] + have hBounds := x.hBounds + apply Nat.mod_eq_of_lt + have : 0 < 2^src_ty.numBits := by simp + have := @Nat.pow_le_pow_of_le_right 2 (by simp) src_ty.numBits tgt_ty.numBits (by omega) + omega + +@[simp] +theorem UScalar.cast_val_mod_pow_of_inBounds_eq {src_ty : UScalarTy} (tgt_ty : UScalarTy) (x : UScalar src_ty) (h : x.val < 2^tgt_ty.numBits) : + (cast tgt_ty x).val = x.val := by + simp [UScalar.cast_val_eq] + apply Nat.mod_eq_of_lt + assumption + +@[simp] +theorem UScalar.cast_bv_eq {src_ty : UScalarTy} (tgt_ty : UScalarTy) (x : UScalar src_ty) : + (cast tgt_ty x).bv = x.bv.setWidth tgt_ty.numBits := by + simp [UScalar.cast] + +example (x : U16) : (x.cast .U32).val = x.val := by simp +example : ((U32.ofNat 42).cast .U16).val = 42 := by simp + +theorem IScalar.cast_val_eq {src_ty : IScalarTy} (tgt_ty : 
IScalarTy) (x : IScalar src_ty) : + (cast tgt_ty x).val = Int.bmod x.val (2^(Min.min tgt_ty.numBits src_ty.numBits)) := by + simp only [cast, val] + simp only [BitVec.toInt_signExtend, val] + rw [BitVec.toInt_eq_toNat_bmod] + rw [Int.bmod_bmod_of_dvd] + apply Nat.pow_dvd_pow + simp + +@[simp] +theorem IScalar.val_mod_pow_greater_numBits {src_ty : IScalarTy} (tgt_ty : IScalarTy) (x : IScalar src_ty) (h : src_ty.numBits ≤ tgt_ty.numBits) : + (cast tgt_ty x).val = x.val := by + simp [IScalar.cast_val_eq] + have hBounds := x.hBounds + simp [h] + have := src_ty.numBits_nonzero + have : src_ty.numBits = src_ty.numBits - 1 + 1 := by omega + rw [this] + apply Int.bmod_pow2_eq_of_inBounds <;> omega + +@[simp] +theorem IScalar.val_mod_pow_inBounds {src_ty : IScalarTy} (tgt_ty : IScalarTy) (x : IScalar src_ty) + (hMin : -2^(tgt_ty.numBits - 1) ≤ x.val) (hMax : x.val < 2^(tgt_ty.numBits - 1)) : + (cast tgt_ty x).val = x.val := by + simp [IScalar.cast_val_eq] + have hBounds := x.hBounds + have := src_ty.numBits_nonzero + have := tgt_ty.numBits_nonzero + have : tgt_ty.numBits ⊓ src_ty.numBits = tgt_ty.numBits ⊓ src_ty.numBits - 1 + 1 := by omega + rw [this] + have : -2 ^ (tgt_ty.numBits ⊓ src_ty.numBits - 1) ≤ x.val ∧ + x.val < 2 ^ (tgt_ty.numBits ⊓ src_ty.numBits - 1) := by + have : tgt_ty.numBits ⊓ src_ty.numBits = tgt_ty.numBits ∨ tgt_ty.numBits ⊓ src_ty.numBits = src_ty.numBits := by + rw [Nat.min_def] + split <;> simp + cases this <;> rename_i hEq <;> simp [hEq] <;> omega + apply Int.bmod_pow2_eq_of_inBounds <;> omega + + +/-! +# Checked Operations +## Checked Operations: Definitions +-/ + +/-! +### Checked Addition +-/ + +/- [core::num::{T}::checked_add] -/ +def core.num.checked_add_UScalar {ty} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (x + y) + +def U8.checked_add (x y : U8) : Option U8 := core.num.checked_add_UScalar x y +def U16.checked_add (x y : U16) : Option U16 := core.num.checked_add_UScalar x y +def U32.checked_add (x y : U32) : Option U32 := core.num.checked_add_UScalar x y +def U64.checked_add (x y : U64) : Option U64 := core.num.checked_add_UScalar x y +def U128.checked_add (x y : U128) : Option U128 := core.num.checked_add_UScalar x y +def Usize.checked_add (x y : Usize) : Option Usize := core.num.checked_add_UScalar x y + +/- [core::num::{T}::checked_add] -/ +def core.num.checked_add_IScalar {ty} (x y : IScalar ty) : Option (IScalar ty) := + Option.ofResult (x + y) + +def I8.checked_add (x y : I8) : Option I8 := core.num.checked_add_IScalar x y +def I16.checked_add (x y : I16) : Option I16 := core.num.checked_add_IScalar x y +def I32.checked_add (x y : I32) : Option I32 := core.num.checked_add_IScalar x y +def I64.checked_add (x y : I64) : Option I64 := core.num.checked_add_IScalar x y +def I128.checked_add (x y : I128) : Option I128 := core.num.checked_add_IScalar x y +def Isize.checked_add (x y : Isize) : Option Isize := core.num.checked_add_IScalar x y + +/-! 
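+The checked operations are definitional wrappers around the fallible operations
+(`Option.ofResult` applied to the `Result`-valued operation), so they unfold by
+`rfl`; the example below is only an illustrative sanity check of this.
+-/
+
+example (x y : U32) : U32.checked_add x y = Option.ofResult (x + y) := rfl
+
+/-!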
+### Checked Subtraction +-/ + +/- [core::num::{T}::checked_sub] -/ +def core.num.checked_sub_UScalar {ty} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (x - y) + +def U8.checked_sub (x y : U8) : Option U8 := core.num.checked_sub_UScalar x y +def U16.checked_sub (x y : U16) : Option U16 := core.num.checked_sub_UScalar x y +def U32.checked_sub (x y : U32) : Option U32 := core.num.checked_sub_UScalar x y +def U64.checked_sub (x y : U64) : Option U64 := core.num.checked_sub_UScalar x y +def U128.checked_sub (x y : U128) : Option U128 := core.num.checked_sub_UScalar x y +def Usize.checked_sub (x y : Usize) : Option Usize := core.num.checked_sub_UScalar x y + +/- [core::num::{T}::checked_sub] -/ +def core.num.checked_sub_IScalar {ty} (x y : IScalar ty) : Option (IScalar ty) := + Option.ofResult (x - y) + +def I8.checked_sub (x y : I8) : Option I8 := core.num.checked_sub_IScalar x y +def I16.checked_sub (x y : I16) : Option I16 := core.num.checked_sub_IScalar x y +def I32.checked_sub (x y : I32) : Option I32 := core.num.checked_sub_IScalar x y +def I64.checked_sub (x y : I64) : Option I64 := core.num.checked_sub_IScalar x y +def I128.checked_sub (x y : I128) : Option I128 := core.num.checked_sub_IScalar x y +def Isize.checked_sub (x y : Isize) : Option Isize := core.num.checked_sub_IScalar x y + +/-! +### Checked Multiplication +-/ + +/- [core::num::{T}::checked_mul] -/ +def core.num.checked_mul_UScalar {ty} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (UScalar.mul x y) + +def U8.checked_mul (x y : U8) : Option U8 := core.num.checked_mul_UScalar x y +def U16.checked_mul (x y : U16) : Option U16 := core.num.checked_mul_UScalar x y +def U32.checked_mul (x y : U32) : Option U32 := core.num.checked_mul_UScalar x y +def U64.checked_mul (x y : U64) : Option U64 := core.num.checked_mul_UScalar x y +def U128.checked_mul (x y : U128) : Option U128 := core.num.checked_mul_UScalar x y +def Usize.checked_mul (x y : Usize) : Option Usize := core.num.checked_mul_UScalar x y + +/- [core::num::{T}::checked_mul] -/ +def core.num.checked_mul_IScalar {ty} (x y : IScalar ty) : Option (IScalar ty) := + Option.ofResult (IScalar.mul x y) + +def I8.checked_mul (x y : I8) : Option I8 := core.num.checked_mul_IScalar x y +def I16.checked_mul (x y : I16) : Option I16 := core.num.checked_mul_IScalar x y +def I32.checked_mul (x y : I32) : Option I32 := core.num.checked_mul_IScalar x y +def I64.checked_mul (x y : I64) : Option I64 := core.num.checked_mul_IScalar x y +def I128.checked_mul (x y : I128) : Option I128 := core.num.checked_mul_IScalar x y +def Isize.checked_mul (x y : Isize) : Option Isize := core.num.checked_mul_IScalar x y + +/-! 
+### Checked Division +-/ + +/- [core::num::{T}::checked_div] -/ +def core.num.checked_div_UScalar {ty} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (UScalar.div x y) + +def U8.checked_div (x y : U8) : Option U8 := core.num.checked_div_UScalar x y +def U16.checked_div (x y : U16) : Option U16 := core.num.checked_div_UScalar x y +def U32.checked_div (x y : U32) : Option U32 := core.num.checked_div_UScalar x y +def U64.checked_div (x y : U64) : Option U64 := core.num.checked_div_UScalar x y +def U128.checked_div (x y : U128) : Option U128 := core.num.checked_div_UScalar x y +def Usize.checked_div (x y : Usize) : Option Usize := core.num.checked_div_UScalar x y + +/- [core::num::{T}::checked_div] -/ +def core.num.checked_div_IScalar {ty} (x y : IScalar ty) : Option (IScalar ty) := + Option.ofResult (IScalar.div x y) + +def I8.checked_div (x y : I8) : Option I8 := core.num.checked_div_IScalar x y +def I16.checked_div (x y : I16) : Option I16 := core.num.checked_div_IScalar x y +def I32.checked_div (x y : I32) : Option I32 := core.num.checked_div_IScalar x y +def I64.checked_div (x y : I64) : Option I64 := core.num.checked_div_IScalar x y +def I128.checked_div (x y : I128) : Option I128 := core.num.checked_div_IScalar x y +def Isize.checked_div (x y : Isize) : Option Isize := core.num.checked_div_IScalar x y + +/-! +### Checked Remainder +-/ + +/- [core::num::{T}::checked_rem] -/ +def core.num.checked_rem_UScalar {ty} (x y : UScalar ty) : Option (UScalar ty) := + Option.ofResult (UScalar.rem x y) + +def U8.checked_rem (x y : U8) : Option U8 := core.num.checked_rem_UScalar x y +def U16.checked_rem (x y : U16) : Option U16 := core.num.checked_rem_UScalar x y +def U32.checked_rem (x y : U32) : Option U32 := core.num.checked_rem_UScalar x y +def U64.checked_rem (x y : U64) : Option U64 := core.num.checked_rem_UScalar x y +def U128.checked_rem (x y : U128) : Option U128 := core.num.checked_rem_UScalar x y +def Usize.checked_rem (x y : Usize) : Option Usize := core.num.checked_rem_UScalar x y + +/- [core::num::{T}::checked_rem] -/ +def core.num.checked_rem_IScalar {ty} (x y : IScalar ty) : Option (IScalar ty) := + Option.ofResult (IScalar.rem x y) + +def I8.checked_rem (x y : I8) : Option I8 := core.num.checked_rem_IScalar x y +def I16.checked_rem (x y : I16) : Option I16 := core.num.checked_rem_IScalar x y +def I32.checked_rem (x y : I32) : Option I32 := core.num.checked_rem_IScalar x y +def I64.checked_rem (x y : I64) : Option I64 := core.num.checked_rem_IScalar x y +def I128.checked_rem (x y : I128) : Option I128 := core.num.checked_rem_IScalar x y +def Isize.checked_rem (x y : Isize) : Option Isize := core.num.checked_rem_IScalar x y + +/-! +## Checked Operations: Theorems +-/ + +/-! +### Checked Add +-/ + +/-! 
+Unsigned checked add +-/ +theorem core.num.checked_add_UScalar_bv_spec {ty} (x y : UScalar ty) : + match core.num.checked_add_UScalar x y with + | some z => x.val + y.val ≤ UScalar.max ty ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => UScalar.max ty < x.val + y.val := by + have h := UScalar.add_equiv x y + have hAdd : x + y = UScalar.add x y := by rfl + rw [hAdd] at h + dcases hEq : UScalar.add x y <;> simp_all [Option.ofResult, checked_add_UScalar, UScalar.max] <;> + (have : 0 < 2^ty.numBits := by simp) <;> + omega + +@[progress_pure checked_add x y] +theorem U8.checked_add_bv_spec (x y : U8) : + match U8.checked_add x y with + | some z => x.val + y.val ≤ U8.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => U8.max < x.val + y.val := by + have := core.num.checked_add_UScalar_bv_spec x y + simp_all [U8.checked_add, UScalar.max, U8.bv] + cases h: core.num.checked_add_UScalar x y <;> simp_all [max, numBits] + +@[progress_pure checked_add x y] +theorem U16.checked_add_bv_spec (x y : U16) : + match U16.checked_add x y with + | some z => x.val + y.val ≤ U16.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => U16.max < x.val + y.val := by + have := core.num.checked_add_UScalar_bv_spec x y + simp_all [U16.checked_add, UScalar.max, U16.bv] + cases h: core.num.checked_add_UScalar x y <;> simp_all [max, numBits] + +@[progress_pure checked_add x y] +theorem U32.checked_add_bv_spec (x y : U32) : + match U32.checked_add x y with + | some z => x.val + y.val ≤ U32.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => U32.max < x.val + y.val := by + have := core.num.checked_add_UScalar_bv_spec x y + simp_all [U32.checked_add, UScalar.max, U32.bv] + cases h: core.num.checked_add_UScalar x y <;> simp_all [max, numBits] + +@[progress_pure checked_add x y] +theorem U64.checked_add_bv_spec (x y : U64) : + match U64.checked_add x y with + | some z => x.val + y.val ≤ U64.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => U64.max < x.val + y.val := by + have := core.num.checked_add_UScalar_bv_spec x y + simp_all [U64.checked_add, UScalar.max, U64.bv] + cases h: core.num.checked_add_UScalar x y <;> simp_all [max, numBits] + +@[progress_pure checked_add x y] +theorem U128.checked_add_bv_spec (x y : U128) : + match U128.checked_add x y with + | some z => x.val + y.val ≤ U128.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => U128.max < x.val + y.val := by + have := core.num.checked_add_UScalar_bv_spec x y + simp_all [U128.checked_add, UScalar.max, U128.bv] + cases h: core.num.checked_add_UScalar x y <;> simp_all [max, numBits] + +@[progress_pure checked_add x y] +theorem Usize.checked_add_bv_spec (x y : Usize) : + match Usize.checked_add x y with + | some z => x.val + y.val ≤ Usize.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => Usize.max < x.val + y.val := by + have := core.num.checked_add_UScalar_bv_spec x y + simp_all [Usize.checked_add, UScalar.max, Usize.bv] + cases h: core.num.checked_add_UScalar x y <;> simp_all [max, numBits] + +/-! 
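+A sketch of how the unsigned specification above can be consumed when the
+checked addition is known to succeed; it assumes `simp` reduces the `match`
+once the scrutinee has been rewritten.
+-/
+
+example (x y z : U32) (h : U32.checked_add x y = some z) :
+    z.val = x.val + y.val := by
+  have hspec := U32.checked_add_bv_spec x y
+  simp only [h] at hspec
+  exact hspec.2.1
+
+/-!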
+Signed checked add +-/ +theorem core.num.checked_add_IScalar_bv_spec {ty} (x y : IScalar ty) : + match core.num.checked_add_IScalar x y with + | some z => IScalar.min ty ≤ x.val + y.val ∧ x.val + y.val ≤ IScalar.max ty ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => ¬ (IScalar.min ty ≤ x.val + y.val ∧ x.val + y.val ≤ IScalar.max ty) := by + have h := IScalar.add_equiv x y + have hAdd : x + y = IScalar.add x y := by rfl + rw [hAdd] at h + dcases hEq : IScalar.add x y <;> simp_all [Option.ofResult, checked_add_IScalar, IScalar.min, IScalar.max] <;> + omega + +@[progress_pure checked_add x y] +theorem I8.checked_add_bv_spec (x y : I8) : + match core.num.checked_add_IScalar x y with + | some z => I8.min ≤ x.val + y.val ∧ x.val + y.val ≤ I8.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => ¬ (I8.min ≤ x.val + y.val ∧ x.val + y.val ≤ I8.max) := by + have := core.num.checked_add_IScalar_bv_spec x y + simp_all only [I8.checked_add, IScalar.min, IScalar.max, I8.bv, min, max, numBits] + cases h: core.num.checked_add_IScalar x y <;> simp_all only [numBits] <;> simp + +@[progress_pure checked_add x y] +theorem I16.checked_add_bv_spec (x y : I16) : + match core.num.checked_add_IScalar x y with + | some z => I16.min ≤ x.val + y.val ∧ x.val + y.val ≤ I16.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => ¬ (I16.min ≤ x.val + y.val ∧ x.val + y.val ≤ I16.max) := by + have := core.num.checked_add_IScalar_bv_spec x y + simp_all only [I16.checked_add, IScalar.min, IScalar.max, I16.bv, min, max, numBits] + cases h: core.num.checked_add_IScalar x y <;> simp_all only [numBits] <;> simp + +@[progress_pure checked_add x y] +theorem I32.checked_add_bv_spec (x y : I32) : + match core.num.checked_add_IScalar x y with + | some z => I32.min ≤ x.val + y.val ∧ x.val + y.val ≤ I32.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => ¬ (I32.min ≤ x.val + y.val ∧ x.val + y.val ≤ I32.max) := by + have := core.num.checked_add_IScalar_bv_spec x y + simp_all only [I32.checked_add, IScalar.min, IScalar.max, I32.bv, min, max, numBits] + cases h: core.num.checked_add_IScalar x y <;> simp_all only [numBits] <;> simp + +@[progress_pure checked_add x y] +theorem I64.checked_add_bv_spec (x y : I64) : + match core.num.checked_add_IScalar x y with + | some z => I64.min ≤ x.val + y.val ∧ x.val + y.val ≤ I64.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => ¬ (I64.min ≤ x.val + y.val ∧ x.val + y.val ≤ I64.max) := by + have := core.num.checked_add_IScalar_bv_spec x y + simp_all only [I64.checked_add, IScalar.min, IScalar.max, I64.bv, min, max, numBits] + cases h: core.num.checked_add_IScalar x y <;> simp_all only [numBits] <;> simp + +@[progress_pure checked_add x y] +theorem I128.checked_add_bv_spec (x y : I128) : + match core.num.checked_add_IScalar x y with + | some z => I128.min ≤ x.val + y.val ∧ x.val + y.val ≤ I128.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => ¬ (I128.min ≤ x.val + y.val ∧ x.val + y.val ≤ I128.max) := by + have := core.num.checked_add_IScalar_bv_spec x y + simp_all only [I128.checked_add, IScalar.min, IScalar.max, I128.bv, min, max, numBits] + cases h: core.num.checked_add_IScalar x y <;> simp_all only [numBits] <;> simp + +@[progress_pure checked_add x y] +theorem Isize.checked_add_bv_spec (x y : Isize) : + match core.num.checked_add_IScalar x y with + | some z => Isize.min ≤ x.val + y.val ∧ x.val + y.val ≤ Isize.max ∧ z.val = x.val + y.val ∧ z.bv = x.bv + y.bv + | none => ¬ (Isize.min ≤ x.val + y.val ∧ x.val + y.val ≤ Isize.max) := by + have := 
core.num.checked_add_IScalar_bv_spec x y
+  simp_all only [Isize.checked_add, IScalar.min, IScalar.max, Isize.bv, min, max, numBits]
+  cases h: core.num.checked_add_IScalar x y <;> simp_all only [numBits] <;> simp
+
+/-!
+### Checked Sub
+-/
+
+/-!
+Unsigned checked sub
+-/
+theorem core.num.checked_sub_UScalar_bv_spec {ty} (x y : UScalar ty) :
+  match core.num.checked_sub_UScalar x y with
+  | some z => y.val ≤ x.val ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv
+  | none => x.val < y.val := by
+  have h := UScalar.sub_equiv x y
+  have hsub : x - y = UScalar.sub x y := by rfl
+  rw [hsub] at h
+  dcases hEq : UScalar.sub x y <;> simp_all [Option.ofResult, checked_sub_UScalar]
+
+@[progress_pure checked_sub x y]
+theorem U8.checked_sub_bv_spec (x y : U8) :
+  match U8.checked_sub x y with
+  | some z => y.val ≤ x.val ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv
+  | none => x.val < y.val := by
+  have := core.num.checked_sub_UScalar_bv_spec x y
+  simp_all [U8.checked_sub, UScalar.max, U8.bv]
+  cases h: core.num.checked_sub_UScalar x y <;> simp_all
+
+@[progress_pure checked_sub x y]
+theorem U16.checked_sub_bv_spec (x y : U16) :
+  match U16.checked_sub x y with
+  | some z => y.val ≤ x.val ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv
+  | none => x.val < y.val := by
+  have := core.num.checked_sub_UScalar_bv_spec x y
+  simp_all [U16.checked_sub, UScalar.max, U16.bv]
+  cases h: core.num.checked_sub_UScalar x y <;> simp_all
+
+@[progress_pure checked_sub x y]
+theorem U32.checked_sub_bv_spec (x y : U32) :
+  match U32.checked_sub x y with
+  | some z => y.val ≤ x.val ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv
+  | none => x.val < y.val := by
+  have := core.num.checked_sub_UScalar_bv_spec x y
+  simp_all [U32.checked_sub, UScalar.max, U32.bv]
+  cases h: core.num.checked_sub_UScalar x y <;> simp_all
+
+@[progress_pure checked_sub x y]
+theorem U64.checked_sub_bv_spec (x y : U64) :
+  match U64.checked_sub x y with
+  | some z => y.val ≤ x.val ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv
+  | none => x.val < y.val := by
+  have := core.num.checked_sub_UScalar_bv_spec x y
+  simp_all [U64.checked_sub, UScalar.max, U64.bv]
+  cases h: core.num.checked_sub_UScalar x y <;> simp_all
+
+@[progress_pure checked_sub x y]
+theorem U128.checked_sub_bv_spec (x y : U128) :
+  match U128.checked_sub x y with
+  | some z => y.val ≤ x.val ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv
+  | none => x.val < y.val := by
+  have := core.num.checked_sub_UScalar_bv_spec x y
+  simp_all [U128.checked_sub, UScalar.max, U128.bv]
+  cases h: core.num.checked_sub_UScalar x y <;> simp_all
+
+@[progress_pure checked_sub x y]
+theorem Usize.checked_sub_bv_spec (x y : Usize) :
+  match Usize.checked_sub x y with
+  | some z => y.val ≤ x.val ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv
+  | none => x.val < y.val := by
+  have := core.num.checked_sub_UScalar_bv_spec x y
+  simp_all [Usize.checked_sub, UScalar.max, Usize.bv]
+  cases h: core.num.checked_sub_UScalar x y <;> simp_all
+
+/-!
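+A similar sketch for the failing direction of the unsigned subtraction spec;
+again it assumes `simp` reduces the `match` after rewriting the scrutinee.
+-/
+
+example (x y : U32) (h : U32.checked_sub x y = none) : x.val < y.val := by
+  have hspec := U32.checked_sub_bv_spec x y
+  simp only [h] at hspec
+  exact hspec
+
+/-!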
+Signed checked sub +-/ +theorem core.num.checked_sub_IScalar_bv_spec {ty} (x y : IScalar ty) : + match core.num.checked_sub_IScalar x y with + | some z => IScalar.min ty ≤ x.val - y.val ∧ x.val - y.val ≤ IScalar.max ty ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv + | none => ¬ (IScalar.min ty ≤ x.val - y.val ∧ x.val - y.val ≤ IScalar.max ty) := by + have h := IScalar.sub_equiv x y + have hsub : x - y = IScalar.sub x y := by rfl + rw [hsub] at h + dcases hEq : IScalar.sub x y <;> simp_all [Option.ofResult, checked_sub_IScalar, IScalar.min, IScalar.max] <;> + (have : 0 < 2^ty.numBits := by simp) <;> + omega + +@[progress_pure checked_sub x y] +theorem I8.checked_sub_bv_spec (x y : I8) : + match core.num.checked_sub_IScalar x y with + | some z => I8.min ≤ x.val - y.val ∧ x.val - y.val ≤ I8.max ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv + | none => ¬ (I8.min ≤ x.val - y.val ∧ x.val - y.val ≤ I8.max) := by + have := core.num.checked_sub_IScalar_bv_spec x y + simp_all only [I8.checked_sub, IScalar.min, IScalar.max, I8.bv, min, max, numBits] + cases h: core.num.checked_sub_IScalar x y <;> simp_all only <;> simp + +@[progress_pure checked_sub x y] +theorem I16.checked_sub_bv_spec (x y : I16) : + match core.num.checked_sub_IScalar x y with + | some z => I16.min ≤ x.val - y.val ∧ x.val - y.val ≤ I16.max ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv + | none => ¬ (I16.min ≤ x.val - y.val ∧ x.val - y.val ≤ I16.max) := by + have := core.num.checked_sub_IScalar_bv_spec x y + simp_all only [I16.checked_sub, IScalar.min, IScalar.max, I16.bv, min, max, numBits] + cases h: core.num.checked_sub_IScalar x y <;> simp_all only <;> simp + +@[progress_pure checked_sub x y] +theorem I32.checked_sub_bv_spec (x y : I32) : + match core.num.checked_sub_IScalar x y with + | some z => I32.min ≤ x.val - y.val ∧ x.val - y.val ≤ I32.max ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv + | none => ¬ (I32.min ≤ x.val - y.val ∧ x.val - y.val ≤ I32.max) := by + have := core.num.checked_sub_IScalar_bv_spec x y + simp_all only [I32.checked_sub, IScalar.min, IScalar.max, I32.bv, min, max, numBits] + cases h: core.num.checked_sub_IScalar x y <;> simp_all only <;> simp + +@[progress_pure checked_sub x y] +theorem I64.checked_sub_bv_spec (x y : I64) : + match core.num.checked_sub_IScalar x y with + | some z => I64.min ≤ x.val - y.val ∧ x.val - y.val ≤ I64.max ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv + | none => ¬ (I64.min ≤ x.val - y.val ∧ x.val - y.val ≤ I64.max) := by + have := core.num.checked_sub_IScalar_bv_spec x y + simp_all only [I64.checked_sub, IScalar.min, IScalar.max, I64.bv, min, max, numBits] + cases h: core.num.checked_sub_IScalar x y <;> simp_all only <;> simp + +@[progress_pure checked_sub x y] +theorem I128.checked_sub_bv_spec (x y : I128) : + match core.num.checked_sub_IScalar x y with + | some z => I128.min ≤ x.val - y.val ∧ x.val - y.val ≤ I128.max ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv + | none => ¬ (I128.min ≤ x.val - y.val ∧ x.val - y.val ≤ I128.max) := by + have := core.num.checked_sub_IScalar_bv_spec x y + simp_all only [I128.checked_sub, IScalar.min, IScalar.max, I128.bv, min, max, numBits] + cases h: core.num.checked_sub_IScalar x y <;> simp_all only <;> simp + +@[progress_pure checked_sub x y] +theorem Isize.checked_sub_bv_spec (x y : Isize) : + match core.num.checked_sub_IScalar x y with + | some z => Isize.min ≤ x.val - y.val ∧ x.val - y.val ≤ Isize.max ∧ z.val = x.val - y.val ∧ z.bv = x.bv - y.bv + | none => ¬ (Isize.min ≤ x.val - y.val ∧ x.val - y.val ≤ Isize.max) := by + have := 
core.num.checked_sub_IScalar_bv_spec x y + simp_all only [Isize.checked_sub, IScalar.min, IScalar.max, Isize.bv, min, max, numBits] + cases h: core.num.checked_sub_IScalar x y <;> simp_all only <;> simp + +/-! +### Checked Mul +-/ + +/-! +Unsigned checked mul +-/ +theorem core.num.checked_mul_UScalar_bv_spec {ty} (x y : UScalar ty) : + match core.num.checked_mul_UScalar x y with + | some z => x.val * y.val ≤ UScalar.max ty ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => UScalar.max ty < x.val * y.val := by + have h := UScalar.mul_equiv x y + simp [checked_mul_UScalar] + dcases hEq : UScalar.mul x y <;> simp_all [Option.ofResult] + +@[progress_pure checked_mul x y] +theorem U8.checked_mul_bv_spec (x y : U8) : + match U8.checked_mul x y with + | some z => x.val * y.val ≤ U8.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => U8.max < x.val * y.val := by + have := core.num.checked_mul_UScalar_bv_spec x y + simp_all only [U8.checked_mul, UScalar.max, U8.bv, min, max, numBits] + cases h: core.num.checked_mul_UScalar x y <;> simp_all only [and_self] + +@[progress_pure checked_mul x y] +theorem U16.checked_mul_bv_spec (x y : U16) : + match U16.checked_mul x y with + | some z => x.val * y.val ≤ U16.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => U16.max < x.val * y.val := by + have := core.num.checked_mul_UScalar_bv_spec x y + simp_all only [U16.checked_mul, UScalar.max, U16.bv, min, max, numBits] + cases h: core.num.checked_mul_UScalar x y <;> simp_all only [and_self] + +@[progress_pure checked_mul x y] +theorem U32.checked_mul_bv_spec (x y : U32) : + match U32.checked_mul x y with + | some z => x.val * y.val ≤ U32.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => U32.max < x.val * y.val := by + have := core.num.checked_mul_UScalar_bv_spec x y + simp_all only [U32.checked_mul, UScalar.max, U32.bv, min, max, numBits] + cases h: core.num.checked_mul_UScalar x y <;> simp_all only [and_self] + +@[progress_pure checked_mul x y] +theorem U64.checked_mul_bv_spec (x y : U64) : + match U64.checked_mul x y with + | some z => x.val * y.val ≤ U64.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => U64.max < x.val * y.val := by + have := core.num.checked_mul_UScalar_bv_spec x y + simp_all only [U64.checked_mul, UScalar.max, U64.bv, min, max, numBits] + cases h: core.num.checked_mul_UScalar x y <;> simp_all only [and_self] + +@[progress_pure checked_mul x y] +theorem U128.checked_mul_bv_spec (x y : U128) : + match U128.checked_mul x y with + | some z => x.val * y.val ≤ U128.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => U128.max < x.val * y.val := by + have := core.num.checked_mul_UScalar_bv_spec x y + simp_all only [U128.checked_mul, UScalar.max, U128.bv, min, max, numBits] + cases h: core.num.checked_mul_UScalar x y <;> simp_all only [and_self] + +@[progress_pure checked_mul x y] +theorem Usize.checked_mul_bv_spec (x y : Usize) : + match Usize.checked_mul x y with + | some z => x.val * y.val ≤ Usize.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => Usize.max < x.val * y.val := by + have := core.num.checked_mul_UScalar_bv_spec x y + simp_all only [Usize.checked_mul, UScalar.max, Usize.bv, min, max, numBits] + cases h: core.num.checked_mul_UScalar x y <;> simp_all only [and_self] + +/-! 
+Signed checked mul +-/ +theorem core.num.checked_mul_IScalar_bv_spec {ty} (x y : IScalar ty) : + match core.num.checked_mul_IScalar x y with + | some z => IScalar.min ty ≤ x.val * y.val ∧ x.val * y.val ≤ IScalar.max ty ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => ¬ (IScalar.min ty ≤ x.val * y.val ∧ x.val * y.val ≤ IScalar.max ty) := by + have h := IScalar.mul_equiv x y + simp [checked_mul_IScalar] + dcases hEq : IScalar.mul x y <;> simp_all [Option.ofResult] + +@[progress_pure checked_mul x y] +theorem I8.checked_mul_bv_spec (x y : I8) : + match core.num.checked_mul_IScalar x y with + | some z => I8.min ≤ x.val * y.val ∧ x.val * y.val ≤ I8.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => ¬ (I8.min ≤ x.val * y.val ∧ x.val * y.val ≤ I8.max) := by + have := core.num.checked_mul_IScalar_bv_spec x y + simp_all only [I8.checked_mul, IScalar.min, IScalar.max, I8.bv, min, max, numBits] + cases h: core.num.checked_mul_IScalar x y <;> simp_all only [not_false_eq_true, and_self] + +@[progress_pure checked_mul x y] +theorem I16.checked_mul_bv_spec (x y : I16) : + match core.num.checked_mul_IScalar x y with + | some z => I16.min ≤ x.val * y.val ∧ x.val * y.val ≤ I16.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => ¬ (I16.min ≤ x.val * y.val ∧ x.val * y.val ≤ I16.max) := by + have := core.num.checked_mul_IScalar_bv_spec x y + simp_all only [I16.checked_mul, IScalar.min, IScalar.max, I16.bv, min, max, numBits] + cases h: core.num.checked_mul_IScalar x y <;> simp_all only [not_false_eq_true, and_self] + +@[progress_pure checked_mul x y] +theorem I32.checked_mul_bv_spec (x y : I32) : + match core.num.checked_mul_IScalar x y with + | some z => I32.min ≤ x.val * y.val ∧ x.val * y.val ≤ I32.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => ¬ (I32.min ≤ x.val * y.val ∧ x.val * y.val ≤ I32.max) := by + have := core.num.checked_mul_IScalar_bv_spec x y + simp_all only [I32.checked_mul, IScalar.min, IScalar.max, I32.bv, min, max, numBits] + cases h: core.num.checked_mul_IScalar x y <;> simp_all only [not_false_eq_true, and_self] + +@[progress_pure checked_mul x y] +theorem I64.checked_mul_bv_spec (x y : I64) : + match core.num.checked_mul_IScalar x y with + | some z => I64.min ≤ x.val * y.val ∧ x.val * y.val ≤ I64.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => ¬ (I64.min ≤ x.val * y.val ∧ x.val * y.val ≤ I64.max) := by + have := core.num.checked_mul_IScalar_bv_spec x y + simp_all only [I64.checked_mul, IScalar.min, IScalar.max, I64.bv, min, max, numBits] + cases h: core.num.checked_mul_IScalar x y <;> simp_all only [not_false_eq_true, and_self] + +@[progress_pure checked_mul x y] +theorem I128.checked_mul_bv_spec (x y : I128) : + match core.num.checked_mul_IScalar x y with + | some z => I128.min ≤ x.val * y.val ∧ x.val * y.val ≤ I128.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => ¬ (I128.min ≤ x.val * y.val ∧ x.val * y.val ≤ I128.max) := by + have := core.num.checked_mul_IScalar_bv_spec x y + simp_all only [I128.checked_mul, IScalar.min, IScalar.max, I128.bv, min, max, numBits] + cases h: core.num.checked_mul_IScalar x y <;> simp_all only [not_false_eq_true, and_self] + +@[progress_pure checked_mul x y] +theorem Isize.checked_mul_bv_spec (x y : Isize) : + match core.num.checked_mul_IScalar x y with + | some z => Isize.min ≤ x.val * y.val ∧ x.val * y.val ≤ Isize.max ∧ z.val = x.val * y.val ∧ z.bv = x.bv * y.bv + | none => ¬ (Isize.min ≤ x.val * y.val ∧ x.val * y.val ≤ Isize.max) := by + have := core.num.checked_mul_IScalar_bv_spec x y + 
simp_all only [Isize.checked_mul, IScalar.min, IScalar.max, Isize.bv, min, max, numBits] + cases h: core.num.checked_mul_IScalar x y <;> simp_all only [not_false_eq_true, and_self] + +/-! +### Checked Division +-/ + +/-! +Unsigned checked div +-/ +theorem core.num.checked_div_UScalar_bv_spec {ty} (x y : UScalar ty) : + match core.num.checked_div_UScalar x y with + | some z => y.val ≠ 0 ∧ z.val = x.val / y.val ∧ z.bv = x.bv / y.bv + | none => y.val = 0 := by + simp [checked_div_UScalar, Option.ofResult, UScalar.div] + split_ifs + . zify at * + simp_all + . rename_i hnz + simp + have hnz' : y.val ≠ 0 := by zify at *; simp_all + have ⟨ z, hz ⟩ := UScalar.div_bv_spec x hnz' + have : x / y = x.div y := by rfl + simp [this, UScalar.div, hnz] at hz + simp [hz, hnz'] + +@[progress_pure checked_div x y] +theorem U8.checked_div_bv_spec (x y : U8) : + match U8.checked_div x y with + | some z => y.val ≠ 0 ∧ z.val = x.val / y.val ∧ z.bv = x.bv / y.bv + | none => y.val = 0 := by + have := core.num.checked_div_UScalar_bv_spec x y + simp_all [U8.checked_div, UScalar.max, U8.bv] + cases h: core.num.checked_div_UScalar x y <;> simp_all + +@[progress_pure checked_div x y] +theorem U16.checked_div_bv_spec (x y : U16) : + match U16.checked_div x y with + | some z => y.val ≠ 0 ∧ z.val = x.val / y.val ∧ z.bv = x.bv / y.bv + | none => y.val = 0 := by + have := core.num.checked_div_UScalar_bv_spec x y + simp_all [U16.checked_div, UScalar.max, U16.bv] + cases h: core.num.checked_div_UScalar x y <;> simp_all + +@[progress_pure checked_div x y] +theorem U32.checked_div_bv_spec (x y : U32) : + match U32.checked_div x y with + | some z => y.val ≠ 0 ∧ z.val = x.val / y.val ∧ z.bv = x.bv / y.bv + | none => y.val = 0 := by + have := core.num.checked_div_UScalar_bv_spec x y + simp_all [U32.checked_div, UScalar.max, U32.bv] + cases h: core.num.checked_div_UScalar x y <;> simp_all + +@[progress_pure checked_div x y] +theorem U64.checked_div_bv_spec (x y : U64) : + match U64.checked_div x y with + | some z => y.val ≠ 0 ∧ z.val = x.val / y.val ∧ z.bv = x.bv / y.bv + | none => y.val = 0 := by + have := core.num.checked_div_UScalar_bv_spec x y + simp_all [U64.checked_div, UScalar.max, U64.bv] + cases h: core.num.checked_div_UScalar x y <;> simp_all + +@[progress_pure checked_div x y] +theorem U128.checked_div_bv_spec (x y : U128) : + match U128.checked_div x y with + | some z => y.val ≠ 0 ∧ z.val = x.val / y.val ∧ z.bv = x.bv / y.bv + | none => y.val = 0 := by + have := core.num.checked_div_UScalar_bv_spec x y + simp_all [U128.checked_div, UScalar.max, U128.bv] + cases h: core.num.checked_div_UScalar x y <;> simp_all + +@[progress_pure checked_div x y] +theorem Usize.checked_div_bv_spec (x y : Usize) : + match Usize.checked_div x y with + | some z => y.val ≠ 0 ∧ z.val = x.val / y.val ∧ z.bv = x.bv / y.bv + | none => y.val = 0 := by + have := core.num.checked_div_UScalar_bv_spec x y + simp_all [Usize.checked_div, UScalar.max, Usize.bv] + cases h: core.num.checked_div_UScalar x y <;> simp_all + +/-! +Signed checked div +-/ +theorem core.num.checked_div_IScalar_bv_spec {ty} (x y : IScalar ty) : + match core.num.checked_div_IScalar x y with + | some z => y.val ≠ 0 ∧ ¬ (x.val = IScalar.min ty ∧ y.val = -1) ∧ z.val = Int.tdiv x.val y.val ∧ z.bv = BitVec.sdiv x.bv y.bv + | none => y.val = 0 ∨ (x.val = IScalar.min ty ∧ y.val = -1) := by + simp [checked_div_IScalar, Option.ofResult, IScalar.div] + split_ifs + . zify at * + simp_all + . 
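+    -- in this branch the divisor is nonzero and the `IScalar.min / -1` overflow case is excluded;
+    -- we name these hypotheses and reuse the bit-vector spec `IScalar.div_bv_spec`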
rename_i hnz hNoOverflow + simp + have hnz' : y.val ≠ 0 := by zify at *; simp_all + have ⟨ z, hz ⟩ := @IScalar.div_bv_spec _ x y hnz' (by simp; tauto) + have : x / y = x.div y := by rfl + simp [this, IScalar.div, hnz, hNoOverflow] at hz + split_ifs at hz + simp at hz + simp [hz, hnz'] + tauto + . simp_all + +@[progress_pure checked_div x y] +theorem I8.checked_div_bv_spec (x y : I8) : + match core.num.checked_div_IScalar x y with + | some z => y.val ≠ 0 ∧ ¬ (x.val = I8.min ∧ y.val = -1) ∧ z.val = Int.tdiv x.val y.val ∧ z.bv = BitVec.sdiv x.bv y.bv + | none => y.val = 0 ∨ (x.val = I8.min ∧ y.val = -1) := by + have := core.num.checked_div_IScalar_bv_spec x y + simp_all only [I8.checked_div, I8.bv, IScalar.min, min, max, numBits] + cases h: core.num.checked_div_IScalar x y <;> simp_all only [ne_eq, not_false_eq_true, and_self, this] + +@[progress_pure checked_div x y] +theorem I16.checked_div_bv_spec (x y : I16) : + match core.num.checked_div_IScalar x y with + | some z => y.val ≠ 0 ∧ ¬ (x.val = I16.min ∧ y.val = -1) ∧ z.val = Int.tdiv x.val y.val ∧ z.bv = BitVec.sdiv x.bv y.bv + | none => y.val = 0 ∨ (x.val = I16.min ∧ y.val = -1) := by + have := core.num.checked_div_IScalar_bv_spec x y + simp_all only [I16.checked_div, I16.bv, IScalar.min, min, max, numBits] + cases h: core.num.checked_div_IScalar x y <;> simp_all only [ne_eq, not_false_eq_true, and_self, this] + +@[progress_pure checked_div x y] +theorem I32.checked_div_bv_spec (x y : I32) : + match core.num.checked_div_IScalar x y with + | some z => y.val ≠ 0 ∧ ¬ (x.val = I32.min ∧ y.val = -1) ∧ z.val = Int.tdiv x.val y.val ∧ z.bv = BitVec.sdiv x.bv y.bv + | none => y.val = 0 ∨ (x.val = I32.min ∧ y.val = -1) := by + have := core.num.checked_div_IScalar_bv_spec x y + simp_all only [I32.checked_div, I32.bv, IScalar.min, min, max, numBits] + cases h: core.num.checked_div_IScalar x y <;> simp_all only [ne_eq, not_false_eq_true, and_self, this] + +@[progress_pure checked_div x y] +theorem I64.checked_div_bv_spec (x y : I64) : + match core.num.checked_div_IScalar x y with + | some z => y.val ≠ 0 ∧ ¬ (x.val = I64.min ∧ y.val = -1) ∧ z.val = Int.tdiv x.val y.val ∧ z.bv = BitVec.sdiv x.bv y.bv + | none => y.val = 0 ∨ (x.val = I64.min ∧ y.val = -1) := by + have := core.num.checked_div_IScalar_bv_spec x y + simp_all only [I64.checked_div, I64.bv, IScalar.min, min, max, numBits] + cases h: core.num.checked_div_IScalar x y <;> simp_all only [ne_eq, not_false_eq_true, and_self, this] + +@[progress_pure checked_div x y] +theorem I128.checked_div_bv_spec (x y : I128) : + match core.num.checked_div_IScalar x y with + | some z => y.val ≠ 0 ∧ ¬ (x.val = I128.min ∧ y.val = -1) ∧ z.val = Int.tdiv x.val y.val ∧ z.bv = BitVec.sdiv x.bv y.bv + | none => y.val = 0 ∨ (x.val = I128.min ∧ y.val = -1) := by + have := core.num.checked_div_IScalar_bv_spec x y + simp_all only [I128.checked_div, I128.bv, IScalar.min, min, max, numBits] + cases h: core.num.checked_div_IScalar x y <;> simp_all only [ne_eq, not_false_eq_true, and_self, this] + +@[progress_pure checked_div x y] +theorem Isize.checked_div_bv_spec (x y : Isize) : + match core.num.checked_div_IScalar x y with + | some z => y.val ≠ 0 ∧ ¬ (x.val = Isize.min ∧ y.val = -1) ∧ z.val = Int.tdiv x.val y.val ∧ z.bv = BitVec.sdiv x.bv y.bv + | none => y.val = 0 ∨ (x.val = Isize.min ∧ y.val = -1) := by + have := core.num.checked_div_IScalar_bv_spec x y + simp_all only [Isize.checked_div, Isize.bv, IScalar.min, min, max, numBits] + cases h: core.num.checked_div_IScalar x y <;> simp_all only [ne_eq, not_false_eq_true, 
and_self, this] + +/-! +### Checked Remainder +-/ + +/-! +Unsigned checked remainder +-/ +theorem core.num.checked_rem_UScalar_bv_spec {ty} (x y : UScalar ty) : + match core.num.checked_rem_UScalar x y with + | some z => y.val ≠ 0 ∧ z.val = x.val % y.val ∧ z.bv = x.bv % y.bv + | none => y.val = 0 := by + simp [checked_rem_UScalar, Option.ofResult, UScalar.rem] + split_ifs + . zify at * + simp_all + . rename_i hnz + simp + have hnz' : y.val ≠ 0 := by zify at *; simp_all + have ⟨ z, hz ⟩ := UScalar.rem_bv_spec x hnz' + have : x % y = x.rem y := by rfl + simp [this, UScalar.rem, hnz] at hz + simp [hz, hnz'] + +@[progress_pure checked_rem x y] +theorem U8.checked_rem_bv_spec (x y : U8) : + match U8.checked_rem x y with + | some z => y.val ≠ 0 ∧ z.val = x.val % y.val ∧ z.bv = x.bv % y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_UScalar_bv_spec x y + simp_all [U8.checked_rem, UScalar.max, U8.bv] + cases h: core.num.checked_rem_UScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem U16.checked_rem_bv_spec (x y : U16) : + match U16.checked_rem x y with + | some z => y.val ≠ 0 ∧ z.val = x.val % y.val ∧ z.bv = x.bv % y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_UScalar_bv_spec x y + simp_all [U16.checked_rem, UScalar.max, U16.bv] + cases h: core.num.checked_rem_UScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem U32.checked_rem_bv_spec (x y : U32) : + match U32.checked_rem x y with + | some z => y.val ≠ 0 ∧ z.val = x.val % y.val ∧ z.bv = x.bv % y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_UScalar_bv_spec x y + simp_all [U32.checked_rem, UScalar.max, U32.bv] + cases h: core.num.checked_rem_UScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem U64.checked_rem_bv_spec (x y : U64) : + match U64.checked_rem x y with + | some z => y.val ≠ 0 ∧ z.val = x.val % y.val ∧ z.bv = x.bv % y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_UScalar_bv_spec x y + simp_all [U64.checked_rem, UScalar.max, U64.bv] + cases h: core.num.checked_rem_UScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem U128.checked_rem_bv_spec (x y : U128) : + match U128.checked_rem x y with + | some z => y.val ≠ 0 ∧ z.val = x.val % y.val ∧ z.bv = x.bv % y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_UScalar_bv_spec x y + simp_all [U128.checked_rem, UScalar.max, U128.bv] + cases h: core.num.checked_rem_UScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem Usize.checked_rem_bv_spec (x y : Usize) : + match Usize.checked_rem x y with + | some z => y.val ≠ 0 ∧ z.val = x.val % y.val ∧ z.bv = x.bv % y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_UScalar_bv_spec x y + simp_all [Usize.checked_rem, UScalar.max, Usize.bv] + cases h: core.num.checked_rem_UScalar x y <;> simp_all + +/-! +Signed checked rem +-/ +theorem core.num.checked_rem_IScalar_bv_spec {ty} (x y : IScalar ty) : + match core.num.checked_rem_IScalar x y with + | some z => y.val ≠ 0 ∧ z.val = Int.tmod x.val y.val ∧ z.bv = BitVec.srem x.bv y.bv + | none => y.val = 0 := by + simp [checked_rem_IScalar, Option.ofResult, IScalar.rem] + split_ifs + . zify at * + simp_all + . 
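+    -- nonzero-divisor branch: name the hypothesis and reuse `IScalar.rem_bv_spec`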
rename_i hnz + simp + have hnz' : y.val ≠ 0 := by zify at *; simp_all + have ⟨ z, hz ⟩ := @IScalar.rem_bv_spec _ x y hnz' + have : x % y = x.rem y := by rfl + simp [this, IScalar.rem, hnz] at hz + simp [*] + +@[progress_pure checked_rem x y] +theorem I8.checked_rem_bv_spec (x y : I8) : + match core.num.checked_rem_IScalar x y with + | some z => y.val ≠ 0 ∧ z.val = Int.tmod x.val y.val ∧ z.bv = BitVec.srem x.bv y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_IScalar_bv_spec x y + simp_all only [I8.checked_rem, I8.bv, IScalar.min] + cases h: core.num.checked_rem_IScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem I16.checked_rem_bv_spec (x y : I16) : + match core.num.checked_rem_IScalar x y with + | some z => y.val ≠ 0 ∧ z.val = Int.tmod x.val y.val ∧ z.bv = BitVec.srem x.bv y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_IScalar_bv_spec x y + simp_all only [I16.checked_rem, I16.bv, IScalar.min] + cases h: core.num.checked_rem_IScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem I32.checked_rem_bv_spec (x y : I32) : + match core.num.checked_rem_IScalar x y with + | some z => y.val ≠ 0 ∧ z.val = Int.tmod x.val y.val ∧ z.bv = BitVec.srem x.bv y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_IScalar_bv_spec x y + simp_all only [I32.checked_rem, I32.bv, IScalar.min] + cases h: core.num.checked_rem_IScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem I64.checked_rem_bv_spec (x y : I64) : + match core.num.checked_rem_IScalar x y with + | some z => y.val ≠ 0 ∧ z.val = Int.tmod x.val y.val ∧ z.bv = BitVec.srem x.bv y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_IScalar_bv_spec x y + simp_all only [I64.checked_rem, I64.bv, IScalar.min] + cases h: core.num.checked_rem_IScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem I128.checked_rem_bv_spec (x y : I128) : + match core.num.checked_rem_IScalar x y with + | some z => y.val ≠ 0 ∧ z.val = Int.tmod x.val y.val ∧ z.bv = BitVec.srem x.bv y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_IScalar_bv_spec x y + simp_all only [I128.checked_rem, I128.bv, IScalar.min] + cases h: core.num.checked_rem_IScalar x y <;> simp_all + +@[progress_pure checked_rem x y] +theorem Isize.checked_rem_bv_spec (x y : Isize) : + match core.num.checked_rem_IScalar x y with + | some z => y.val ≠ 0 ∧ z.val = Int.tmod x.val y.val ∧ z.bv = BitVec.srem x.bv y.bv + | none => y.val = 0 := by + have := core.num.checked_rem_IScalar_bv_spec x y + simp_all only [Isize.checked_rem, Isize.bv, IScalar.min] + cases h: core.num.checked_rem_IScalar x y <;> simp_all + +/-! +# Leading zeros +-/ + +/- TODO: move to Mathlib? 
+ Also not sure this is the best way of defining this quantity -/ +def BitVec.leadingZerosAux {w : Nat} (x : BitVec w) (i : Nat) : Nat := + if i < w then + if ¬ x.getMsbD i then leadingZerosAux x (i + 1) + else i + else 0 + +def BitVec.leadingZeros {w : Nat} (x : BitVec w) : Nat := + leadingZerosAux x 0 + +#assert BitVec.leadingZeros 1#16 = 15 +#assert BitVec.leadingZeros 1#32 = 31 +#assert BitVec.leadingZeros 255#32 = 24 + +@[progress_pure_def] def core.num.Usize.leading_zeros (x : Usize) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.U8.leading_zeros (x : U8) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.U16.leading_zeros (x : U16) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.U32.leading_zeros (x : U32) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.U64.leading_zeros (x : U64) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.U128.leading_zeros (x : U128) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ + +@[progress_pure_def] def core.num.Isize.leading_zeros (x : Isize) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.I8.leading_zeros (x : I8) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.I16.leading_zeros (x : I16) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.I32.leading_zeros (x : I32) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.I64.leading_zeros (x : I64) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ +@[progress_pure_def] def core.num.I128.leading_zeros (x : I128) : U32 := ⟨ BitVec.leadingZeros x.bv ⟩ -- Clone @[reducible, simp] def core.clone.impls.CloneUsize.clone (x : Usize) : Usize := x @@ -900,6 +3242,10 @@ def core.num.I128.leading_zeros (x : I128) : U32 := sorry @[reducible, simp] def core.clone.impls.CloneI64.clone (x : I64) : I64 := x @[reducible, simp] def core.clone.impls.CloneI128.clone (x : I128) : I128 := x +/-! +# Clone and Copy +-/ + @[reducible] def core.clone.CloneUsize : core.clone.Clone Usize := { clone := fun x => ok (core.clone.impls.CloneUsize.clone x) @@ -1020,494 +3366,695 @@ def core.marker.CopyIsize : core.marker.Copy Isize := { cloneInst := core.clone.CloneIsize } --- This is easier defined this way than with the modulo operation (because of the --- unsigned cases). -def int_overflowing_add (ty : ScalarTy) (x y : Int) : Int × Bool := - let z := x + y - let b := false - let range := Scalar.max ty - Scalar.min ty + 1 - let r := (z, b) - let r := if r.1 > Scalar.max ty then (r.1 - range, true) else r - let r := if r.1 < Scalar.min ty then (r.1 + range, true) else r - r - -def int_overflowing_add_in_bounds {ty} (x y : Scalar ty) : - let r := int_overflowing_add ty x.val y.val - Scalar.min ty ≤ r.1 ∧ r.1 ≤ Scalar.max ty := by - simp [int_overflowing_add] - split <;> split <;> cases ty <;> simp at * <;> - scalar_tac - -def int_overflowing_add_unsigned_overflow {ty} (h: ¬ ty.isSigned) (x y : Scalar ty) : - let r := int_overflowing_add ty x.val y.val - x.val + y.val = if r.2 then r.1 + Scalar.max ty + 1 else r.1 := by - simp [int_overflowing_add] - split <;> split <;> cases ty <;> simp [ScalarTy.isSigned] at * <;> - scalar_tac - -def Scalar.overflowing_add {ty} (x y : Scalar ty) : Result (Scalar ty × Bool) := - let r := int_overflowing_add ty x.val y.val - have h := int_overflowing_add_in_bounds x y - let z : Scalar ty := ⟨ r.1, h.left, h.right ⟩ - ok (z, r.2) +/-! 
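+The `overflowing_*` operations below return the wrapped result together with a boolean flag, which
+is meant to be `true` exactly when the exact mathematical result does not fit in the target type
+(this is the behaviour of Rust's `overflowing_add`). For instance, on `u8`, `255 + 1` wraps to `0`
+with the flag set; `UScalar.overflowing_add_eq` below makes this precise. A small sanity sketch in
+plain `Nat` arithmetic (independent of the definitions in this file):
+```lean
+example : (255 + 1) % 2^8 = 0 := by decide
+example : (200 + 100) % 2^8 = 44 := by decide
+```
+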
+# Overflowing Operations +-/ + +-- TODO: we should redefine this, in particular so that it doesn't live in the `Result` monad + +def UScalar.overflowing_add {ty} (x y : UScalar ty) : UScalar ty × Bool := + (⟨ BitVec.ofNat _ (x.val + y.val) ⟩, 2^ty.numBits ≤ x.val + y.val) + +def IScalar.overflowing_add (ty : IScalarTy) (x y : IScalar ty) : IScalar ty × Bool := + (⟨ BitVec.ofInt _ (x.val + y.val) ⟩, + ¬ (-2^(ty.numBits -1) ≤ x.val + y.val ∧ x.val + y.val < 2^ty.numBits)) /- [core::num::{u8}::overflowing_add] -/ -def core.num.U8.overflowing_add := @Scalar.overflowing_add ScalarTy.U8 +def core.num.U8.overflowing_add := @UScalar.overflowing_add .U8 /- [core::num::{u16}::overflowing_add] -/ -def core.num.U16.overflowing_add := @Scalar.overflowing_add ScalarTy.U16 +def core.num.U16.overflowing_add := @UScalar.overflowing_add .U16 /- [core::num::{u32}::overflowing_add] -/ -def core.num.U32.overflowing_add := @Scalar.overflowing_add ScalarTy.U32 +def core.num.U32.overflowing_add := @UScalar.overflowing_add .U32 /- [core::num::{u64}::overflowing_add] -/ -def core.num.U64.overflowing_add := @Scalar.overflowing_add ScalarTy.U64 +def core.num.U64.overflowing_add := @UScalar.overflowing_add .U64 /- [core::num::{u128}::overflowing_add] -/ -def core.num.U128.overflowing_add := @Scalar.overflowing_add ScalarTy.U128 +def core.num.U128.overflowing_add := @UScalar.overflowing_add .U128 /- [core::num::{usize}::overflowing_add] -/ -def core.num.Usize.overflowing_add := @Scalar.overflowing_add ScalarTy.Usize +def core.num.Usize.overflowing_add := @UScalar.overflowing_add .Usize /- [core::num::{i8}::overflowing_add] -/ -def core.num.I8.overflowing_add := @Scalar.overflowing_add ScalarTy.I8 +def core.num.I8.overflowing_add := @IScalar.overflowing_add .I8 /- [core::num::{i16}::overflowing_add] -/ -def core.num.I16.overflowing_add := @Scalar.overflowing_add ScalarTy.I16 +def core.num.I16.overflowing_add := @IScalar.overflowing_add .I16 /- [core::num::{i32}::overflowing_add] -/ -def core.num.I32.overflowing_add := @Scalar.overflowing_add ScalarTy.I32 +def core.num.I32.overflowing_add := @IScalar.overflowing_add .I32 /- [core::num::{i64}::overflowing_add] -/ -def core.num.I64.overflowing_add := @Scalar.overflowing_add ScalarTy.I64 +def core.num.I64.overflowing_add := @IScalar.overflowing_add .I64 /- [core::num::{i128}::overflowing_add] -/ -def core.num.I128.overflowing_add := @Scalar.overflowing_add ScalarTy.I128 +def core.num.I128.overflowing_add := @IScalar.overflowing_add .I128 /- [core::num::{isize}::overflowing_add] -/ -def core.num.Isize.overflowing_add := @Scalar.overflowing_add ScalarTy.Isize - -@[pspec] -theorem core.num.U8.overflowing_add_spec (x y : U8) : - ∃ z b, overflowing_add x y = ok (z, b) ∧ - if x.val + y.val > U8.max then z.val = x.val + y.val - U8.max - 1 ∧ b = true - else z.val = x.val + y.val ∧ b = false - := by - simp [overflowing_add, Scalar.overflowing_add, int_overflowing_add] - split <;> split <;> simp_all <;> scalar_tac - -@[pspec] -theorem core.num.U16.overflowing_add_spec (x y : U16) : - ∃ z b, overflowing_add x y = ok (z, b) ∧ - if x.val + y.val > U16.max then z.val = x.val + y.val - U16.max - 1 ∧ b = true - else z.val = x.val + y.val ∧ b = false - := by - simp [overflowing_add, Scalar.overflowing_add, int_overflowing_add] - split <;> split <;> simp_all <;> scalar_tac - -@[pspec] -theorem core.num.U32.overflowing_add_spec (x y : U32) : - ∃ z b, overflowing_add x y = ok (z, b) ∧ - if x.val + y.val > U32.max then z.val = x.val + y.val - U32.max - 1 ∧ b = true - else z.val = x.val + y.val 
∧ b = false - := by - simp [overflowing_add, Scalar.overflowing_add, int_overflowing_add] - split <;> split <;> simp_all <;> scalar_tac - -@[pspec] -theorem core.num.U64.overflowing_add_spec (x y : U64) : - ∃ z b, overflowing_add x y = ok (z, b) ∧ - if x.val + y.val > U64.max then z.val = x.val + y.val - U64.max - 1 ∧ b = true - else z.val = x.val + y.val ∧ b = false - := by - simp [overflowing_add, Scalar.overflowing_add, int_overflowing_add] - split <;> split <;> simp_all <;> scalar_tac - -@[pspec] -theorem core.num.U128.overflowing_add_spec (x y : U128) : - ∃ z b, overflowing_add x y = ok (z, b) ∧ - if x.val + y.val > U128.max then z.val = x.val + y.val - U128.max - 1 ∧ b = true - else z.val = x.val + y.val ∧ b = false - := by - simp [overflowing_add, Scalar.overflowing_add, int_overflowing_add] - split <;> split <;> simp_all <;> scalar_tac - -@[pspec] -theorem core.num.Usize.overflowing_add_spec (x y : Usize) : - ∃ z b, overflowing_add x y = ok (z, b) ∧ - if x.val + y.val > Usize.max then z.val = x.val + y.val - Usize.max - 1 ∧ b = true - else z.val = x.val + y.val ∧ b = false +def core.num.Isize.overflowing_add := @IScalar.overflowing_add .Isize + +attribute [-simp] Bool.exists_bool + +theorem UScalar.overflowing_add_eq {ty} (x y : UScalar ty) : + let z := overflowing_add x y + if x.val + y.val > UScalar.max ty then + z.fst.val + UScalar.size ty = x.val + y.val ∧ + z.snd = true + else + z.fst.val = x.val + y.val ∧ + z.snd = false := by - simp [overflowing_add, Scalar.overflowing_add, int_overflowing_add] - split <;> split <;> simp_all <;> scalar_tac - --- Saturating add -def int_saturating_add (ty : ScalarTy) (x y : Int) : Int := - let r := x + y - let r := if r > Scalar.max ty then Scalar.max ty else r - let r := if r < Scalar.min ty then Scalar.min ty else r - r - -def int_saturating_add_in_bounds {ty} (x y : Scalar ty) : - let r := int_saturating_add ty x.val y.val - Scalar.min ty ≤ r ∧ r ≤ Scalar.max ty := by - simp [int_saturating_add] - split <;> constructor <;> cases ty <;> scalar_tac - -def Scalar.saturating_add {ty} (x y : Scalar ty) : Scalar ty := - let r := int_saturating_add ty x.val y.val - have h := int_saturating_add_in_bounds x y - ⟨ r, h.1, h.2 ⟩ + simp [overflowing_add] + simp only [val, BitVec.toNat_ofNat, max] + split <;> rename_i hLt + . split_conjs + . have : (x.bv.toNat + y.bv.toNat) % 2^ty.numBits = + (x.bv.toNat + y.bv.toNat - 2^ty.numBits) % 2^ty.numBits := by + rw [Nat.mod_eq_sub_mod] + omega + rw [this]; clear this + + have := @Nat.mod_eq_of_lt (x.bv.toNat + y.bv.toNat - 2^ty.numBits) (2^ty.numBits) (by omega) + rw [this]; clear this + simp [size] + scalar_tac + . omega + . split_conjs + . apply Nat.mod_eq_of_lt + omega + . 
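+      -- no overflow in this branch: the sum is at most `UScalar.max ty`, so the flag is `false`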
omega + +@[progress_pure overflowing_add x y] +theorem core.num.U8.overflowing_add_eq (x y : U8) : + let z := overflowing_add x y + if x.val + y.val > UScalar.max .U8 then z.fst.val + UScalar.size .U8 = x.val + y.val ∧ z.snd = true + else z.fst.val = x.val + y.val ∧ z.snd = false + := UScalar.overflowing_add_eq x y + +@[progress_pure overflowing_add x y] +theorem core.num.U16.overflowing_add_eq (x y : U16) : + let z := overflowing_add x y + if x.val + y.val > UScalar.max .U16 then z.fst.val + UScalar.size .U16 = x.val + y.val ∧ z.snd = true + else z.fst.val = x.val + y.val ∧ z.snd = false + := UScalar.overflowing_add_eq x y + +@[progress_pure overflowing_add x y] +theorem core.num.U32.overflowing_add_eq (x y : U32) : + let z := overflowing_add x y + if x.val + y.val > UScalar.max .U32 then z.fst.val + UScalar.size .U32 = x.val + y.val ∧ z.snd = true + else z.fst.val = x.val + y.val ∧ z.snd = false + := UScalar.overflowing_add_eq x y + +@[progress_pure overflowing_add x y] +theorem core.num.U64.overflowing_add_eq (x y : U64) : + let z := overflowing_add x y + if x.val + y.val > UScalar.max .U64 then z.fst.val + UScalar.size .U64 = x.val + y.val ∧ z.snd = true + else z.fst.val = x.val + y.val ∧ z.snd = false + := UScalar.overflowing_add_eq x y + +@[progress_pure overflowing_add x y] +theorem core.num.U128.overflowing_add_eq (x y : U128) : + let z := overflowing_add x y + if x.val + y.val > UScalar.max .U128 then z.fst.val + UScalar.size .U128 = x.val + y.val ∧ z.snd = true + else z.fst.val = x.val + y.val ∧ z.snd = false + := UScalar.overflowing_add_eq x y + +@[progress_pure overflowing_add x y] +theorem core.num.Usize.overflowing_add_eq (x y : Usize) : + let z := overflowing_add x y + if x.val + y.val > UScalar.max .Usize then z.fst.val + UScalar.size .Usize = x.val + y.val ∧ z.snd = true + else z.fst.val = x.val + y.val ∧ z.snd = false + := UScalar.overflowing_add_eq x y + +/-! +# Saturating Operations +-/ + +/-! +Saturating add: unsigned +-/ +def UScalar.saturating_add {ty : UScalarTy} (x y : UScalar ty) : UScalar ty := + ⟨ BitVec.ofNat _ (Min.min (UScalar.max ty) (x.val + y.val)) ⟩ /- [core::num::{u8}::saturating_add] -/ -def core.num.U8.saturating_add := @Scalar.saturating_add ScalarTy.U8 +def core.num.U8.saturating_add := @UScalar.saturating_add UScalarTy.U8 /- [core::num::{u16}::saturating_add] -/ -def core.num.U16.saturating_add := @Scalar.saturating_add ScalarTy.U16 +def core.num.U16.saturating_add := @UScalar.saturating_add UScalarTy.U16 /- [core::num::{u32}::saturating_add] -/ -def core.num.U32.saturating_add := @Scalar.saturating_add ScalarTy.U32 +def core.num.U32.saturating_add := @UScalar.saturating_add UScalarTy.U32 /- [core::num::{u64}::saturating_add] -/ -def core.num.U64.saturating_add := @Scalar.saturating_add ScalarTy.U64 +def core.num.U64.saturating_add := @UScalar.saturating_add UScalarTy.U64 /- [core::num::{u128}::saturating_add] -/ -def core.num.U128.saturating_add := @Scalar.saturating_add ScalarTy.U128 +def core.num.U128.saturating_add := @UScalar.saturating_add UScalarTy.U128 /- [core::num::{usize}::saturating_add] -/ -def core.num.Usize.saturating_add := @Scalar.saturating_add ScalarTy.Usize +def core.num.Usize.saturating_add := @UScalar.saturating_add UScalarTy.Usize + +/-! 
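+Unsigned saturating addition clamps from above at `UScalar.max`. The signed version below clamps on
+both sides, at `IScalar.min` and `IScalar.max`. A quick sketch of the intended clamping in plain
+integer arithmetic (the literals are just the `i8` bounds, used for illustration only):
+```lean
+example : max (-128 : ℤ) (min 127 (100 + 100)) = 127 := by decide
+example : max (-128 : ℤ) (min 127 ((-100) + (-100))) = -128 := by decide
+```
+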
+Saturating add: signed +-/ +def IScalar.saturating_add {ty : IScalarTy} (x y : IScalar ty) : IScalar ty := + ⟨ BitVec.ofInt _ (Max.max (IScalar.min ty) (Min.min (IScalar.max ty) (x.val + y.val))) ⟩ /- [core::num::{i8}::saturating_add] -/ -def core.num.I8.saturating_add := @Scalar.saturating_add ScalarTy.I8 +def core.num.I8.saturating_add := @IScalar.saturating_add IScalarTy.I8 /- [core::num::{i16}::saturating_add] -/ -def core.num.I16.saturating_add := @Scalar.saturating_add ScalarTy.I16 +def core.num.I16.saturating_add := @IScalar.saturating_add IScalarTy.I16 /- [core::num::{i32}::saturating_add] -/ -def core.num.I32.saturating_add := @Scalar.saturating_add ScalarTy.I32 +def core.num.I32.saturating_add := @IScalar.saturating_add IScalarTy.I32 /- [core::num::{i64}::saturating_add] -/ -def core.num.I64.saturating_add := @Scalar.saturating_add ScalarTy.I64 +def core.num.I64.saturating_add := @IScalar.saturating_add IScalarTy.I64 /- [core::num::{i128}::saturating_add] -/ -def core.num.I128.saturating_add := @Scalar.saturating_add ScalarTy.I128 +def core.num.I128.saturating_add := @IScalar.saturating_add IScalarTy.I128 /- [core::num::{isize}::saturating_add] -/ -def core.num.Isize.saturating_add := @Scalar.saturating_add ScalarTy.Isize - -theorem core.num.U8.saturating_add_spec (x y : U8) : - let z := saturating_add x y - if x.val + y.val > U8.max then z.val = U8.max - else z.val = x.val + y.val - := by - simp [saturating_add, Scalar.saturating_add, int_saturating_add] - split <;> split <;> split <;> scalar_tac +def core.num.Isize.saturating_add := @IScalar.saturating_add IScalarTy.Isize -theorem core.num.U16.saturating_add_spec (x y : U16) : - let z := saturating_add x y - if x.val + y.val > U16.max then z.val = U16.max - else z.val = x.val + y.val - := by - simp [saturating_add, Scalar.saturating_add, int_saturating_add] - split <;> split <;> split <;> scalar_tac - -theorem core.num.U32.saturating_add_spec (x y : U32) : - let z := saturating_add x y - if x.val + y.val > U32.max then z.val = U32.max - else z.val = x.val + y.val - := by - simp [saturating_add, Scalar.saturating_add, int_saturating_add] - split <;> split <;> split <;> scalar_tac - -theorem core.num.U64.saturating_add_spec (x y : U64) : - let z := saturating_add x y - if x.val + y.val > U64.max then z.val = U64.max - else z.val = x.val + y.val - := by - simp [saturating_add, Scalar.saturating_add, int_saturating_add] - split <;> split <;> split <;> scalar_tac - -theorem core.num.U128.saturating_add_spec (x y : U128) : - let z := saturating_add x y - if x.val + y.val > U128.max then z.val = U128.max - else z.val = x.val + y.val - := by - simp [saturating_add, Scalar.saturating_add, int_saturating_add] - split <;> split <;> split <;> scalar_tac - -theorem core.num.Usize.saturating_add_spec (x y : Usize) : - let z := saturating_add x y - if x.val + y.val > Usize.max then z.val = Usize.max - else z.val = x.val + y.val - := by - simp [saturating_add, Scalar.saturating_add, int_saturating_add] - split <;> split <;> split <;> scalar_tac - --- Saturating sub -def int_saturating_sub (ty : ScalarTy) (x y : Int) : Int := - let r := x - y - let r := if r > Scalar.max ty then Scalar.max ty else r - let r := if r < Scalar.min ty then Scalar.min ty else r - r - -def int_saturating_sub_in_bounds {ty} (x y : Scalar ty) : - let r := int_saturating_sub ty x.val y.val - Scalar.min ty ≤ r ∧ r ≤ Scalar.max ty := by - simp [int_saturating_sub] - split <;> constructor <;> cases ty <;> scalar_tac - -def Scalar.saturating_sub {ty} (x y : Scalar ty) : 
Scalar ty := - let r := int_saturating_sub ty x.val y.val - have h := int_saturating_sub_in_bounds x y - ⟨ r, h.1, h.2 ⟩ +/-! +Saturating sub: unsigned +-/ +def UScalar.saturating_sub {ty : UScalarTy} (x y : UScalar ty) : UScalar ty := + ⟨ BitVec.ofNat _ (Max.max 0 (x.val - y.val)) ⟩ /- [core::num::{u8}::saturating_sub] -/ -def core.num.U8.saturating_sub := @Scalar.saturating_sub ScalarTy.U8 +def core.num.U8.saturating_sub := @UScalar.saturating_sub UScalarTy.U8 /- [core::num::{u16}::saturating_sub] -/ -def core.num.U16.saturating_sub := @Scalar.saturating_sub ScalarTy.U16 +def core.num.U16.saturating_sub := @UScalar.saturating_sub UScalarTy.U16 /- [core::num::{u32}::saturating_sub] -/ -def core.num.U32.saturating_sub := @Scalar.saturating_sub ScalarTy.U32 +def core.num.U32.saturating_sub := @UScalar.saturating_sub UScalarTy.U32 /- [core::num::{u64}::saturating_sub] -/ -def core.num.U64.saturating_sub := @Scalar.saturating_sub ScalarTy.U64 +def core.num.U64.saturating_sub := @UScalar.saturating_sub UScalarTy.U64 /- [core::num::{u128}::saturating_sub] -/ -def core.num.U128.saturating_sub := @Scalar.saturating_sub ScalarTy.U128 +def core.num.U128.saturating_sub := @UScalar.saturating_sub UScalarTy.U128 /- [core::num::{usize}::saturating_sub] -/ -def core.num.Usize.saturating_sub := @Scalar.saturating_sub ScalarTy.Usize +def core.num.Usize.saturating_sub := @UScalar.saturating_sub UScalarTy.Usize + +/-! +Saturating sub: signed +-/ +def IScalar.saturating_sub {ty : IScalarTy} (x y : IScalar ty) : IScalar ty := + ⟨ BitVec.ofInt _ (Max.max (IScalar.min ty) (Min.min (IScalar.max ty) (x.val - y.val))) ⟩ /- [core::num::{i8}::saturating_sub] -/ -def core.num.I8.saturating_sub := @Scalar.saturating_sub ScalarTy.I8 +def core.num.I8.saturating_sub := @IScalar.saturating_sub IScalarTy.I8 /- [core::num::{i16}::saturating_sub] -/ -def core.num.I16.saturating_sub := @Scalar.saturating_sub ScalarTy.I16 +def core.num.I16.saturating_sub := @IScalar.saturating_sub IScalarTy.I16 /- [core::num::{i32}::saturating_sub] -/ -def core.num.I32.saturating_sub := @Scalar.saturating_sub ScalarTy.I32 +def core.num.I32.saturating_sub := @IScalar.saturating_sub IScalarTy.I32 /- [core::num::{i64}::saturating_sub] -/ -def core.num.I64.saturating_sub := @Scalar.saturating_sub ScalarTy.I64 +def core.num.I64.saturating_sub := @IScalar.saturating_sub IScalarTy.I64 /- [core::num::{i128}::saturating_sub] -/ -def core.num.I128.saturating_sub := @Scalar.saturating_sub ScalarTy.I128 +def core.num.I128.saturating_sub := @IScalar.saturating_sub IScalarTy.I128 /- [core::num::{isize}::saturating_sub] -/ -def core.num.Isize.saturating_sub := @Scalar.saturating_sub ScalarTy.Isize - -theorem core.num.U8.saturating_sub_spec (x y : U8) : - let z := saturating_sub x y - if x.val - y.val < 0 then z.val = 0 - else z.val = x.val - y.val - := by - simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] - split <;> split <;> split <;> scalar_tac +def core.num.Isize.saturating_sub := @IScalar.saturating_sub IScalarTy.Isize -theorem core.num.U16.saturating_sub_spec (x y : U16) : - let z := saturating_sub x y - if x.val - y.val < 0 then z.val = 0 - else z.val = x.val - y.val - := by - simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] - split <;> split <;> split <;> scalar_tac - -theorem core.num.U32.saturating_sub_spec (x y : U32) : - let z := saturating_sub x y - if x.val - y.val < 0 then z.val = 0 - else z.val = x.val - y.val - := by - simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] - split <;> split <;> split 
<;> scalar_tac - -theorem core.num.U64.saturating_sub_spec (x y : U64) : - let z := saturating_sub x y - if x.val - y.val < 0 then z.val = 0 - else z.val = x.val - y.val - := by - simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] - split <;> split <;> split <;> scalar_tac - -theorem core.num.U128.saturating_sub_spec (x y : U128) : - let z := saturating_sub x y - if x.val - y.val < 0 then z.val = 0 - else z.val = x.val - y.val - := by - simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] - split <;> split <;> split <;> scalar_tac +/-! +# Wrapping Operations +-/ -theorem core.num.Usize.saturating_sub_spec (x y : Usize) : - let z := saturating_sub x y - if x.val - y.val < 0 then z.val = 0 - else z.val = x.val - y.val - := by - simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] - split <;> split <;> split <;> scalar_tac +/-! +## Wrapping Add +-/ --- Wrapping add -def Scalar.wrapping_add {ty} (x y : Scalar ty) : Scalar ty := sorry +def UScalar.wrapping_add {ty} (x y : UScalar ty) : UScalar ty := ⟨ x.bv + y.bv ⟩ /- [core::num::{u8}::wrapping_add] -/ -def core.num.U8.wrapping_add : U8 → U8 → U8 := @Scalar.wrapping_add ScalarTy.U8 +@[progress_pure_def] +def core.num.U8.wrapping_add : U8 → U8 → U8 := @UScalar.wrapping_add UScalarTy.U8 /- [core::num::{u16}::wrapping_add] -/ -def core.num.U16.wrapping_add : U16 → U16 → U16 := @Scalar.wrapping_add ScalarTy.U16 +@[progress_pure_def] +def core.num.U16.wrapping_add : U16 → U16 → U16 := @UScalar.wrapping_add UScalarTy.U16 /- [core::num::{u32}::wrapping_add] -/ -def core.num.U32.wrapping_add : U32 → U32 → U32 := @Scalar.wrapping_add ScalarTy.U32 +@[progress_pure_def] +def core.num.U32.wrapping_add : U32 → U32 → U32 := @UScalar.wrapping_add UScalarTy.U32 /- [core::num::{u64}::wrapping_add] -/ -def core.num.U64.wrapping_add : U64 → U64 → U64 := @Scalar.wrapping_add ScalarTy.U64 +@[progress_pure_def] +def core.num.U64.wrapping_add : U64 → U64 → U64 := @UScalar.wrapping_add UScalarTy.U64 /- [core::num::{u128}::wrapping_add] -/ -def core.num.U128.wrapping_add : U128 → U128 → U128 := @Scalar.wrapping_add ScalarTy.U128 +@[progress_pure_def] +def core.num.U128.wrapping_add : U128 → U128 → U128 := @UScalar.wrapping_add UScalarTy.U128 /- [core::num::{usize}::wrapping_add] -/ -def core.num.Usize.wrapping_add : Usize → Usize → Usize := @Scalar.wrapping_add ScalarTy.Usize +@[progress_pure_def] +def core.num.Usize.wrapping_add : Usize → Usize → Usize := @UScalar.wrapping_add UScalarTy.Usize + +def IScalar.wrapping_add {ty} (x y : IScalar ty) : IScalar ty := ⟨ x.bv + y.bv ⟩ /- [core::num::{i8}::wrapping_add] -/ -def core.num.I8.wrapping_add : I8 → I8 → I8 := @Scalar.wrapping_add ScalarTy.I8 +@[progress_pure_def] +def core.num.I8.wrapping_add : I8 → I8 → I8 := @IScalar.wrapping_add IScalarTy.I8 /- [core::num::{i16}::wrapping_add] -/ -def core.num.I16.wrapping_add : I16 → I16 → I16 := @Scalar.wrapping_add ScalarTy.I16 +@[progress_pure_def] +def core.num.I16.wrapping_add : I16 → I16 → I16 := @IScalar.wrapping_add IScalarTy.I16 /- [core::num::{i32}::wrapping_add] -/ -def core.num.I32.wrapping_add : I32 → I32 → I32 := @Scalar.wrapping_add ScalarTy.I32 +@[progress_pure_def] +def core.num.I32.wrapping_add : I32 → I32 → I32 := @IScalar.wrapping_add IScalarTy.I32 /- [core::num::{i64}::wrapping_add] -/ -def core.num.I64.wrapping_add : I64 → I64 → I64 := @Scalar.wrapping_add ScalarTy.I64 +@[progress_pure_def] +def core.num.I64.wrapping_add : I64 → I64 → I64 := @IScalar.wrapping_add IScalarTy.I64 /- [core::num::{i128}::wrapping_add] -/ -def 
core.num.I128.wrapping_add : I128 → I128 → I128 := @Scalar.wrapping_add ScalarTy.I128 +@[progress_pure_def] +def core.num.I128.wrapping_add : I128 → I128 → I128 := @IScalar.wrapping_add IScalarTy.I128 /- [core::num::{isize}::wrapping_add] -/ -def core.num.Isize.wrapping_add : Isize → Isize → Isize := @Scalar.wrapping_add ScalarTy.Isize +@[progress_pure_def] +def core.num.Isize.wrapping_add : Isize → Isize → Isize := @IScalar.wrapping_add IScalarTy.Isize + +@[simp] theorem UScalar.wrapping_add_bv_eq {ty} (x y : UScalar ty) : + (wrapping_add x y).bv = x.bv + y.bv := by + simp [wrapping_add] + +@[simp] theorem U8.wrapping_add_bv_eq (x y : U8) : + (core.num.U8.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.U8.wrapping_add, bv] + +@[simp] theorem U16.wrapping_add_bv_eq (x y : U16) : + (core.num.U16.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.U16.wrapping_add, bv] --- TODO: reasoning lemmas for wrapping add +@[simp] theorem U32.wrapping_add_bv_eq (x y : U32) : + (core.num.U32.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.U32.wrapping_add, bv] --- Wrapping sub -def Scalar.wrapping_sub {ty} (x y : Scalar ty) : Scalar ty := sorry +@[simp] theorem U64.wrapping_add_bv_eq (x y : U64) : + (core.num.U64.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.U64.wrapping_add, bv] + +@[simp] theorem U128.wrapping_add_bv_eq (x y : U128) : + (core.num.U128.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.U128.wrapping_add, bv] + +@[simp] theorem Usize.wrapping_add_bv_eq (x y : Usize) : + (core.num.Usize.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.Usize.wrapping_add, bv] + +@[simp] theorem IScalar.wrapping_add_bv_eq {ty} (x y : IScalar ty) : + (wrapping_add x y).bv = x.bv + y.bv := by + simp [wrapping_add] + +@[simp] theorem I8.wrapping_add_bv_eq (x y : I8) : + (core.num.I8.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.I8.wrapping_add, bv] + +@[simp] theorem I16.wrapping_add_bv_eq (x y : I16) : + (core.num.I16.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.I16.wrapping_add, bv] + +@[simp] theorem I32.wrapping_add_bv_eq (x y : I32) : + (core.num.I32.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.I32.wrapping_add, bv] + +@[simp] theorem I64.wrapping_add_bv_eq (x y : I64) : + (core.num.I64.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.I64.wrapping_add, bv] + +@[simp] theorem I128.wrapping_add_bv_eq (x y : I128) : + (core.num.I128.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.I128.wrapping_add, bv] + +@[simp] theorem Isize.wrapping_add_bv_eq (x y : Isize) : + (core.num.Isize.wrapping_add x y).bv = x.bv + y.bv := by + simp [core.num.Isize.wrapping_add, bv] + +@[simp] theorem UScalar.wrapping_add_val_eq {ty} (x y : UScalar ty) : + (wrapping_add x y).val = (x.val + y.val) % (UScalar.max ty + 1) := by + simp only [wrapping_add, val, max] + have : 0 < 2^ty.numBits := by simp + have : 2 ^ ty.numBits - 1 + 1 = 2^ty.numBits := by omega + simp [this] + +@[simp] theorem U8.wrapping_add_val_eq (x y : U8) : + (core.num.U8.wrapping_add x y).val = (x.val + y.val) % (UScalar.max .U8 + 1) := + UScalar.wrapping_add_val_eq x y + +@[simp] theorem U16.wrapping_add_val_eq (x y : U16) : + (core.num.U16.wrapping_add x y).val = (x.val + y.val) % (UScalar.max .U16 + 1) := + UScalar.wrapping_add_val_eq x y + +@[simp] theorem U32.wrapping_add_val_eq (x y : U32) : + (core.num.U32.wrapping_add x y).val = (x.val + y.val) % (UScalar.max .U32 + 1) := + UScalar.wrapping_add_val_eq x y + +@[simp] theorem 
U64.wrapping_add_val_eq (x y : U64) : + (core.num.U64.wrapping_add x y).val = (x.val + y.val) % (UScalar.max .U64 + 1) := + UScalar.wrapping_add_val_eq x y + +@[simp] theorem U128.wrapping_add_val_eq (x y : U128) : + (core.num.U128.wrapping_add x y).val = (x.val + y.val) % (UScalar.max .U128 + 1) := + UScalar.wrapping_add_val_eq x y + +@[simp] theorem Usize.wrapping_add_val_eq (x y : Usize) : + (core.num.Usize.wrapping_add x y).val = (x.val + y.val) % (UScalar.max .Usize + 1) := + UScalar.wrapping_add_val_eq x y + +@[simp] theorem IScalar.wrapping_add_val_eq {ty} (x y : IScalar ty) : + (wrapping_add x y).val = Int.bmod (x.val + y.val) (2^ty.numBits) := by + simp only [wrapping_add, val] + simp + +@[simp] theorem I8.wrapping_add_val_eq (x y : I8) : + (core.num.I8.wrapping_add x y).val = Int.bmod (x.val + y.val) (2^8) := + IScalar.wrapping_add_val_eq x y + +@[simp] theorem I16.wrapping_add_val_eq (x y : I16) : + (core.num.I16.wrapping_add x y).val = Int.bmod (x.val + y.val) (2^16) := + IScalar.wrapping_add_val_eq x y + +@[simp] theorem I32.wrapping_add_val_eq (x y : I32) : + (core.num.I32.wrapping_add x y).val = Int.bmod (x.val + y.val) (2^32) := + IScalar.wrapping_add_val_eq x y + +@[simp] theorem I64.wrapping_add_val_eq (x y : I64) : + (core.num.I64.wrapping_add x y).val = Int.bmod (x.val + y.val) (2^64) := + IScalar.wrapping_add_val_eq x y + +@[simp] theorem I128.wrapping_add_val_eq (x y : I128) : + (core.num.I128.wrapping_add x y).val = Int.bmod (x.val + y.val) (2^128) := + IScalar.wrapping_add_val_eq x y + +@[simp] theorem Isize.wrapping_add_val_eq (x y : Isize) : + (core.num.Isize.wrapping_add x y).val = Int.bmod (x.val + y.val) (2^System.Platform.numBits) := + IScalar.wrapping_add_val_eq x y + +/-! +### Wrapping Sub +-/ + +def UScalar.wrapping_sub {ty} (x y : UScalar ty) : UScalar ty := ⟨ x.bv - y.bv ⟩ /- [core::num::{u8}::wrapping_sub] -/ -def core.num.U8.wrapping_sub : U8 → U8 → U8 := @Scalar.wrapping_sub ScalarTy.U8 +@[progress_pure_def] +def core.num.U8.wrapping_sub : U8 → U8 → U8 := @UScalar.wrapping_sub UScalarTy.U8 /- [core::num::{u16}::wrapping_sub] -/ -def core.num.U16.wrapping_sub : U16 → U16 → U16 := @Scalar.wrapping_sub ScalarTy.U16 +@[progress_pure_def] +def core.num.U16.wrapping_sub : U16 → U16 → U16 := @UScalar.wrapping_sub UScalarTy.U16 /- [core::num::{u32}::wrapping_sub] -/ -def core.num.U32.wrapping_sub : U32 → U32 → U32 := @Scalar.wrapping_sub ScalarTy.U32 +@[progress_pure_def] +def core.num.U32.wrapping_sub : U32 → U32 → U32 := @UScalar.wrapping_sub UScalarTy.U32 /- [core::num::{u64}::wrapping_sub] -/ -def core.num.U64.wrapping_sub : U64 → U64 → U64 := @Scalar.wrapping_sub ScalarTy.U64 +@[progress_pure_def] +def core.num.U64.wrapping_sub : U64 → U64 → U64 := @UScalar.wrapping_sub UScalarTy.U64 /- [core::num::{u128}::wrapping_sub] -/ -def core.num.U128.wrapping_sub : U128 → U128 → U128 := @Scalar.wrapping_sub ScalarTy.U128 +@[progress_pure_def] +def core.num.U128.wrapping_sub : U128 → U128 → U128 := @UScalar.wrapping_sub UScalarTy.U128 /- [core::num::{usize}::wrapping_sub] -/ -def core.num.Usize.wrapping_sub : Usize → Usize → Usize := @Scalar.wrapping_sub ScalarTy.Usize +@[progress_pure_def] +def core.num.Usize.wrapping_sub : Usize → Usize → Usize := @UScalar.wrapping_sub UScalarTy.Usize + +def IScalar.wrapping_sub {ty} (x y : IScalar ty) : IScalar ty := ⟨ x.bv - y.bv ⟩ /- [core::num::{i8}::wrapping_sub] -/ -def core.num.I8.wrapping_sub : I8 → I8 → I8 := @Scalar.wrapping_sub ScalarTy.I8 +@[progress_pure_def] +def core.num.I8.wrapping_sub : I8 → I8 → I8 := 
@IScalar.wrapping_sub IScalarTy.I8 /- [core::num::{i16}::wrapping_sub] -/ -def core.num.I16.wrapping_sub : I16 → I16 → I16 := @Scalar.wrapping_sub ScalarTy.I16 +@[progress_pure_def] +def core.num.I16.wrapping_sub : I16 → I16 → I16 := @IScalar.wrapping_sub IScalarTy.I16 /- [core::num::{i32}::wrapping_sub] -/ -def core.num.I32.wrapping_sub : I32 → I32 → I32 := @Scalar.wrapping_sub ScalarTy.I32 +@[progress_pure_def] +def core.num.I32.wrapping_sub : I32 → I32 → I32 := @IScalar.wrapping_sub IScalarTy.I32 /- [core::num::{i64}::wrapping_sub] -/ -def core.num.I64.wrapping_sub : I64 → I64 → I64 := @Scalar.wrapping_sub ScalarTy.I64 +@[progress_pure_def] +def core.num.I64.wrapping_sub : I64 → I64 → I64 := @IScalar.wrapping_sub IScalarTy.I64 /- [core::num::{i128}::wrapping_sub] -/ -def core.num.I128.wrapping_sub : I128 → I128 → I128 := @Scalar.wrapping_sub ScalarTy.I128 +@[progress_pure_def] +def core.num.I128.wrapping_sub : I128 → I128 → I128 := @IScalar.wrapping_sub IScalarTy.I128 /- [core::num::{isize}::wrapping_sub] -/ -def core.num.Isize.wrapping_sub : Isize → Isize → Isize := @Scalar.wrapping_sub ScalarTy.Isize +@[progress_pure_def] +def core.num.Isize.wrapping_sub : Isize → Isize → Isize := @IScalar.wrapping_sub IScalarTy.Isize + +@[simp] theorem UScalar.wrapping_sub_bv_eq {ty} (x y : UScalar ty) : + (wrapping_sub x y).bv = x.bv - y.bv := by + simp [wrapping_sub] + +@[simp] theorem U8.wrapping_sub_bv_eq (x y : U8) : + (core.num.U8.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.U8.wrapping_sub, bv] + +@[simp] theorem U16.wrapping_sub_bv_eq (x y : U16) : + (core.num.U16.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.U16.wrapping_sub, bv] + +@[simp] theorem U32.wrapping_sub_bv_eq (x y : U32) : + (core.num.U32.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.U32.wrapping_sub, bv] + +@[simp] theorem U64.wrapping_sub_bv_eq (x y : U64) : + (core.num.U64.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.U64.wrapping_sub, bv] + +@[simp] theorem U128.wrapping_sub_bv_eq (x y : U128) : + (core.num.U128.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.U128.wrapping_sub, bv] + +@[simp] theorem Usize.wrapping_sub_bv_eq (x y : Usize) : + (core.num.Usize.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.Usize.wrapping_sub, bv] + +@[simp] theorem IScalar.wrapping_sub_bv_eq {ty} (x y : IScalar ty) : + (wrapping_sub x y).bv = x.bv - y.bv := by + simp [wrapping_sub] + +@[simp] theorem I8.wrapping_sub_bv_eq (x y : I8) : + (core.num.I8.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.I8.wrapping_sub, bv] + +@[simp] theorem I16.wrapping_sub_bv_eq (x y : I16) : + (core.num.I16.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.I16.wrapping_sub, bv] + +@[simp] theorem I32.wrapping_sub_bv_eq (x y : I32) : + (core.num.I32.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.I32.wrapping_sub, bv] --- TODO: reasoning lemmas for wrapping sub +@[simp] theorem I64.wrapping_sub_bv_eq (x y : I64) : + (core.num.I64.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.I64.wrapping_sub, bv] --- Rotate left -def Scalar.rotate_left {ty} (x : Scalar ty) (shift : U32) : Scalar ty := sorry +@[simp] theorem I128.wrapping_sub_bv_eq (x y : I128) : + (core.num.I128.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.I128.wrapping_sub, bv] + +@[simp] theorem Isize.wrapping_sub_bv_eq (x y : Isize) : + (core.num.Isize.wrapping_sub x y).bv = x.bv - y.bv := by + simp [core.num.Isize.wrapping_sub, bv] + +@[simp] theorem UScalar.wrapping_sub_val_eq {ty} (x y : 
UScalar ty) : + (wrapping_sub x y).val = (x.val + (UScalar.size ty - y.val)) % UScalar.size ty := by + simp only [wrapping_sub, val, size] + have : 0 < 2^ty.numBits := by simp + have : 2 ^ ty.numBits - 1 + 1 = 2^ty.numBits := by omega + simp [this] + ring_nf + +@[simp] theorem U8.wrapping_sub_val_eq (x y : U8) : + (core.num.U8.wrapping_sub x y).val = (x.val + (UScalar.size .U8 - y.val)) % UScalar.size .U8 := + UScalar.wrapping_sub_val_eq x y + +@[simp] theorem U16.wrapping_sub_val_eq (x y : U16) : + (core.num.U16.wrapping_sub x y).val = (x.val + (UScalar.size .U16 - y.val)) % UScalar.size .U16 := + UScalar.wrapping_sub_val_eq x y + +@[simp] theorem U32.wrapping_sub_val_eq (x y : U32) : + (core.num.U32.wrapping_sub x y).val = (x.val + (UScalar.size .U32 - y.val)) % UScalar.size .U32 := + UScalar.wrapping_sub_val_eq x y + +@[simp] theorem U64.wrapping_sub_val_eq (x y : U64) : + (core.num.U64.wrapping_sub x y).val = (x.val + (UScalar.size .U64 - y.val)) % UScalar.size .U64 := + UScalar.wrapping_sub_val_eq x y + +@[simp] theorem U128.wrapping_sub_val_eq (x y : U128) : + (core.num.U128.wrapping_sub x y).val = (x.val + (UScalar.size .U128 - y.val)) % UScalar.size .U128 := + UScalar.wrapping_sub_val_eq x y + +@[simp] theorem Usize.wrapping_sub_val_eq (x y : Usize) : + (core.num.Usize.wrapping_sub x y).val = (x.val + (UScalar.size .Usize - y.val)) % UScalar.size .Usize := + UScalar.wrapping_sub_val_eq x y + +@[simp] theorem IScalar.wrapping_sub_val_eq {ty} (x y : IScalar ty) : + (wrapping_sub x y).val = Int.bmod (x.val - y.val) (2^ty.numBits) := by + simp only [wrapping_sub, val] + simp + +@[simp] theorem I8.wrapping_sub_val_eq (x y : I8) : + (core.num.I8.wrapping_sub x y).val = Int.bmod (x.val - y.val) (2^8) := + IScalar.wrapping_sub_val_eq x y + +@[simp] theorem I16.wrapping_sub_val_eq (x y : I16) : + (core.num.I16.wrapping_sub x y).val = Int.bmod (x.val - y.val) (2^16) := + IScalar.wrapping_sub_val_eq x y + +@[simp] theorem I32.wrapping_sub_val_eq (x y : I32) : + (core.num.I32.wrapping_sub x y).val = Int.bmod (x.val - y.val) (2^32) := + IScalar.wrapping_sub_val_eq x y + +@[simp] theorem I64.wrapping_sub_val_eq (x y : I64) : + (core.num.I64.wrapping_sub x y).val = Int.bmod (x.val - y.val) (2^64) := + IScalar.wrapping_sub_val_eq x y + +@[simp] theorem I128.wrapping_sub_val_eq (x y : I128) : + (core.num.I128.wrapping_sub x y).val = Int.bmod (x.val - y.val) (2^128) := + IScalar.wrapping_sub_val_eq x y + +@[simp] theorem Isize.wrapping_sub_val_eq (x y : Isize) : + (core.num.Isize.wrapping_sub x y).val = Int.bmod (x.val - y.val) (2^System.Platform.numBits) := + IScalar.wrapping_sub_val_eq x y + + +/-! +# Rotation +-/ + +/-! 
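+The rotations below are thin wrappers around the bit-vector rotations from core. A quick sanity
+sketch on byte-sized bit-vectors (plain `BitVec` literals, independent of the wrappers below):
+```lean
+example : (0b10000001#8).rotateLeft 1 = 0b00000011#8 := by decide
+example : (0b10000001#8).rotateRight 1 = 0b11000000#8 := by decide
+```
+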
+## Rotate Left +-/ +def UScalar.rotate_left {ty} (x : UScalar ty) (shift : U32) : UScalar ty := + ⟨ x.bv.rotateLeft shift.val ⟩ /- [core::num::{u8}::rotate_left] -/ -def core.num.U8.rotate_left := @Scalar.rotate_left ScalarTy.U8 +@[progress_pure_def] +def core.num.U8.rotate_left : U8 → U32 → U8 := @UScalar.rotate_left .U8 /- [core::num::{u16}::rotate_left] -/ -def core.num.U16.rotate_left := @Scalar.rotate_left ScalarTy.U16 +@[progress_pure_def] +def core.num.U16.rotate_left : U16 → U32 → U16 := @UScalar.rotate_left .U16 /- [core::num::{u32}::rotate_left] -/ -def core.num.U32.rotate_left := @Scalar.rotate_left ScalarTy.U32 +@[progress_pure_def] +def core.num.U32.rotate_left : U32 → U32 → U32 := @UScalar.rotate_left .U32 /- [core::num::{u64}::rotate_left] -/ -def core.num.U64.rotate_left := @Scalar.rotate_left ScalarTy.U64 +@[progress_pure_def] +def core.num.U64.rotate_left : U64 → U32 → U64 := @UScalar.rotate_left .U64 /- [core::num::{u128}::rotate_left] -/ -def core.num.U128.rotate_left := @Scalar.rotate_left ScalarTy.U128 +@[progress_pure_def] +def core.num.U128.rotate_left : U128 → U32 → U128 := @UScalar.rotate_left .U128 /- [core::num::{usize}::rotate_left] -/ -def core.num.Usize.rotate_left := @Scalar.rotate_left ScalarTy.Usize +@[progress_pure_def] +def core.num.Usize.rotate_left : Usize → U32 → Usize := @UScalar.rotate_left .Usize -/- [core::num::{i8}::rotate_left] -/ -def core.num.I8.rotate_left := @Scalar.rotate_left ScalarTy.I8 +def IScalar.rotate_left {ty} (x : IScalar ty) (shift : U32) : IScalar ty := + ⟨ x.bv.rotateLeft shift.val ⟩ -/- [core::num::{i16}::rotate_left] -/ -def core.num.I16.rotate_left := @Scalar.rotate_left ScalarTy.I16 +/- [core::num::{i8}::rotate_left] -/ +@[progress_pure_def] +def core.num.I8.rotate_left : I8 → U32 → I8 := @IScalar.rotate_left .I8 -/- [core::num::{i32}::rotate_left] -/ -def core.num.I32.rotate_left := @Scalar.rotate_left ScalarTy.I32 +/- [core::num::{i16}::rotate_left] -/ +@[progress_pure_def] +def core.num.I16.rotate_left : I16 → U32 → I16 := @IScalar.rotate_left .I16 -/- [core::num::{i64}::rotate_left] -/ -def core.num.I64.rotate_left := @Scalar.rotate_left ScalarTy.I64 +/- [core::num::{i32}::rotate_left] -/ +@[progress_pure_def] +def core.num.I32.rotate_left : I32 → U32 → I32 := @IScalar.rotate_left .I32 -/- [core::num::{i128}::rotate_left] -/ -def core.num.I128.rotate_left := @Scalar.rotate_left ScalarTy.I128 +/- [core::num::{i64}::rotate_left] -/ +@[progress_pure_def] +def core.num.I64.rotate_left : I64 → U32 → I64 := @IScalar.rotate_left .I64 -/- [core::num::{isize}::rotate_left] -/ -def core.num.Isize.rotate_left := @Scalar.rotate_left ScalarTy.Isize +/- [core::num::{i128}::rotate_left] -/ +@[progress_pure_def] +def core.num.I128.rotate_left : I128 → U32 → I128 := @IScalar.rotate_left .I128 --- TODO: reasoning lemmas for rotate left +/- [core::num::{isize}::rotate_left] -/ +@[progress_pure_def] +def core.num.Isize.rotate_left : Isize → U32 → Isize := @IScalar.rotate_left .Isize --- Rotate right -def Scalar.rotate_right {ty} (x : Scalar ty) (shift : U32) : Scalar ty := sorry +/-! 
+## Rotate Right +-/ +def UScalar.rotate_right {ty} (x : UScalar ty) (shift : U32) : UScalar ty := + ⟨ x.bv.rotateRight shift.val ⟩ /- [core::num::{u8}::rotate_right] -/ -def core.num.U8.rotate_right := @Scalar.rotate_right ScalarTy.U8 +@[progress_pure_def] +def core.num.U8.rotate_right : U8 → U32 → U8 := @UScalar.rotate_right .U8 /- [core::num::{u16}::rotate_right] -/ -def core.num.U16.rotate_right := @Scalar.rotate_right ScalarTy.U16 +@[progress_pure_def] +def core.num.U16.rotate_right : U16 → U32 → U16 := @UScalar.rotate_right .U16 /- [core::num::{u32}::rotate_right] -/ -def core.num.U32.rotate_right := @Scalar.rotate_right ScalarTy.U32 +@[progress_pure_def] +def core.num.U32.rotate_right : U32 → U32 → U32 := @UScalar.rotate_right .U32 /- [core::num::{u64}::rotate_right] -/ -def core.num.U64.rotate_right := @Scalar.rotate_right ScalarTy.U64 +@[progress_pure_def] +def core.num.U64.rotate_right : U64 → U32 → U64 := @UScalar.rotate_right .U64 /- [core::num::{u128}::rotate_right] -/ -def core.num.U128.rotate_right := @Scalar.rotate_right ScalarTy.U128 +@[progress_pure_def] +def core.num.U128.rotate_right : U128 → U32 → U128 := @UScalar.rotate_right .U128 /- [core::num::{usize}::rotate_right] -/ -def core.num.Usize.rotate_right := @Scalar.rotate_right ScalarTy.Usize +@[progress_pure_def] +def core.num.Usize.rotate_right : Usize → U32 → Usize := @UScalar.rotate_right .Usize -/- [core::num::{i8}::rotate_right] -/ -def core.num.I8.rotate_right := @Scalar.rotate_right ScalarTy.I8 +def IScalar.rotate_right {ty} (x : IScalar ty) (shift : U32) : IScalar ty := + ⟨ x.bv.rotateRight shift.val ⟩ -/- [core::num::{i16}::rotate_right] -/ -def core.num.I16.rotate_right := @Scalar.rotate_right ScalarTy.I16 +/- [core::num::{i8}::rotate_right] -/ +@[progress_pure_def] +def core.num.I8.rotate_right : I8 → U32 → I8 := @IScalar.rotate_right .I8 -/- [core::num::{i32}::rotate_right] -/ -def core.num.I32.rotate_right := @Scalar.rotate_right ScalarTy.I32 +/- [core::num::{i16}::rotate_right] -/ +@[progress_pure_def] +def core.num.I16.rotate_right : I16 → U32 → I16 := @IScalar.rotate_right .I16 -/- [core::num::{i64}::rotate_right] -/ -def core.num.I64.rotate_right := @Scalar.rotate_right ScalarTy.I64 +/- [core::num::{i32}::rotate_right] -/ +@[progress_pure_def] +def core.num.I32.rotate_right : I32 → U32 → I32 := @IScalar.rotate_right .I32 -/- [core::num::{i128}::rotate_right] -/ -def core.num.I128.rotate_right := @Scalar.rotate_right ScalarTy.I128 +/- [core::num::{i64}::rotate_right] -/ +@[progress_pure_def] +def core.num.I64.rotate_right : I64 → U32 → I64 := @IScalar.rotate_right .I64 -/- [core::num::{isize}::rotate_right] -/ -def core.num.Isize.rotate_right := @Scalar.rotate_right ScalarTy.Isize +/- [core::num::{i128}::rotate_right] -/ +@[progress_pure_def] +def core.num.I128.rotate_right : I128 → U32 → I128 := @IScalar.rotate_right .I128 --- TODO: reasoning lemmas for rotate right +/- [core::num::{isize}::rotate_right] -/ +@[progress_pure_def] +def core.num.Isize.rotate_right : Isize → U32 → Isize := @IScalar.rotate_right .Isize end Std diff --git a/backends/lean/Aeneas/Std/ScalarCore.lean b/backends/lean/Aeneas/Std/ScalarCore.lean index c0556a37..ea067e66 100644 --- a/backends/lean/Aeneas/Std/ScalarCore.lean +++ b/backends/lean/Aeneas/Std/ScalarCore.lean @@ -4,7 +4,7 @@ import Aeneas.Std.Core import Aeneas.Std.Core import Aeneas.Diverge.Core import Aeneas.Progress.Core -import Aeneas.ScalarTac.IntTac +import Aeneas.ScalarTac.ScalarTac namespace Aeneas @@ -13,330 +13,502 @@ namespace Std -- Deactivate the warnings 
which appear when we use `#assert` set_option linter.hashCommand false ----------------------- --- MACHINE INTEGERS -- ----------------------- +/-! +# Machine Integers + +Because they tend to behave quite differently, we have two classes of machine integers: one for +signed integers, and another for unsigned integers. Inside of each class, we factor out definitions. +-/ open Result Error open System.Platform.getNumBits --- TODO: is there a way of only importing System.Platform.getNumBits? --- -@[simp] def size_num_bits : Nat := (System.Platform.getNumBits ()).val - --- Remark: Lean seems to use < for the comparisons with the upper bounds by convention. - --- The "structured" bounds -def Isize.smin : Int := - (HPow.hPow 2 (size_num_bits - 1)) -def Isize.smax : Int := (HPow.hPow 2 (size_num_bits - 1)) - 1 -def I8.smin : Int := - (HPow.hPow 2 7) -def I8.smax : Int := HPow.hPow 2 7 - 1 -def I16.smin : Int := - (HPow.hPow 2 15) -def I16.smax : Int := HPow.hPow 2 15 - 1 -def I32.smin : Int := -(HPow.hPow 2 31) -def I32.smax : Int := HPow.hPow 2 31 - 1 -def I64.smin : Int := -(HPow.hPow 2 63) -def I64.smax : Int := HPow.hPow 2 63 - 1 -def I128.smin : Int := -(HPow.hPow 2 127) -def I128.smax : Int := HPow.hPow 2 127 - 1 -def Usize.smin : Nat := 0 -def Usize.smax : Nat := HPow.hPow 2 size_num_bits - 1 -def U8.smin : Nat := 0 -def U8.smax : Nat := HPow.hPow 2 8 - 1 -def U16.smin : Nat := 0 -def U16.smax : Nat := HPow.hPow 2 16 - 1 -def U32.smin : Nat := 0 -def U32.smax : Nat := HPow.hPow 2 32 - 1 -def U64.smin : Nat := 0 -def U64.smax : Nat := HPow.hPow 2 64 - 1 -def U128.smin : Nat := 0 -def U128.smax : Nat := HPow.hPow 2 128 - 1 - --- The "normalized" bounds, that we use in practice -def I8.min : Int := -128 -def I8.max : Int := 127 -def I16.min : Int := -32768 -def I16.max : Int := 32767 -def I32.min : Int := -2147483648 -def I32.max : Int := 2147483647 -def I64.min : Int := -9223372036854775808 -def I64.max : Int := 9223372036854775807 -def I128.min : Int := -170141183460469231731687303715884105728 -def I128.max : Int := 170141183460469231731687303715884105727 -@[simp] -def U8.min : Nat := 0 -def U8.max : Nat := 255 -@[simp] -def U16.min : Nat := 0 -def U16.max : Nat := 65535 -@[simp] -def U32.min : Nat := 0 -def U32.max : Nat := 4294967295 -@[simp] -def U64.min : Nat := 0 -def U64.max : Nat := 18446744073709551615 -@[simp] -def U128.min : Nat := 0 -def U128.max : Nat := 340282366920938463463374607431768211455 -@[simp] -def Usize.min : Nat := 0 - -def Isize.refined_min : { n:Int // n = I32.min ∨ n = I64.min } := - ⟨ Isize.smin, by - simp [Isize.smin] - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> simp [*] <;> decide ⟩ - -def Isize.refined_max : { n:Int // n = I32.max ∨ n = I64.max } := - ⟨ Isize.smax, by - simp [Isize.smax] - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> simp [*] <;> decide ⟩ - -def Usize.refined_max : { n:Nat // n = U32.max ∨ n = U64.max } := - ⟨ Usize.smax, by - simp [Usize.smax] - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> simp [*] <;> decide ⟩ - -def Isize.min := Isize.refined_min.val -def Isize.max := Isize.refined_max.val -def Usize.max := Usize.refined_max.val - -theorem Usize.bounds_eq : - Usize.max = U32.max ∨ Usize.max = U64.max := by - simp [Usize.min, Usize.max, refined_max, smin, smax] - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> simp [*] <;> decide - -theorem Isize.bounds_eq : - (Isize.min = I32.min ∧ Isize.max = I32.max) - ∨ 
(Isize.min = I64.min ∧ Isize.max = I64.max) := by
-  simp [Isize.min, Isize.max, refined_min, refined_max, smin, smax]
-  cases System.Platform.numBits_eq <;>
-  unfold System.Platform.numBits at * <;> simp [*] <;> decide

+/-- Kinds of unsigned integers -/
+inductive UScalarTy where
+| Usize
+| U8
+| U16
+| U32
+| U64
+| U128

-inductive ScalarTy where
+/-- Kinds of signed integers -/
+inductive IScalarTy where
 | Isize
 | I8
 | I16
 | I32
 | I64
 | I128
-| Usize
-| U8
-| U16
-| U32
-| U64
-| U128

-@[reducible]
-def ScalarTy.isSigned (ty : ScalarTy) : Bool :=
+def UScalarTy.numBits (ty : UScalarTy) : Nat :=
+  match ty with
+  | Usize => System.Platform.numBits
+  | U8 => 8
+  | U16 => 16
+  | U32 => 32
+  | U64 => 64
+  | U128 => 128
+
+def IScalarTy.numBits (ty : IScalarTy) : Nat :=
   match ty with
-  | Isize
-  | I8
-  | I16
-  | I32
-  | I64
-  | I128 => true
-  | Usize
-  | U8
-  | U16
-  | U32
-  | U64
-  | U128 => false
+  | Isize => System.Platform.numBits
+  | I8 => 8
+  | I16 => 16
+  | I32 => 32
+  | I64 => 64
+  | I128 => 128
+
+/-- Unsigned integer -/
+structure UScalar (ty : UScalarTy) where
+  /- The internal representation is a bit-vector -/
+  bv : BitVec ty.numBits
+deriving Repr, BEq, DecidableEq

--- FIXME(chore): bulk prove them via macro?
-instance : Fact (¬ ScalarTy.isSigned .Usize) where
-  out := by decide
+def UScalar.val {ty} (x : UScalar ty) : ℕ := x.bv.toNat

-instance : Fact (¬ ScalarTy.isSigned .U8) where
-  out := by decide
+/-- Signed integer -/
+structure IScalar (ty : IScalarTy) where
+  /- The internal representation is a bit-vector -/
+  bv : BitVec ty.numBits
+deriving Repr, BEq, DecidableEq

-instance : Fact (¬ ScalarTy.isSigned .U16) where
-  out := by decide
+def IScalar.val {ty} (x : IScalar ty) : ℤ := x.bv.toInt

-instance : Fact (¬ ScalarTy.isSigned .U32) where
-  out := by decide
+/-!
+# Bounds, Size

-instance : Fact (¬ ScalarTy.isSigned .U64) where
-  out := by decide
+**Remark:** we mark most constants as irreducible because otherwise it leads to issues
+when using tactics like `assumption`: unification often attempts to reduce
+complex expressions (for instance by trying to reduce an expression like `2^128`, which
+is extremely expensive).
+-/

-instance : Fact (¬ ScalarTy.isSigned .U128) where
-  out := by decide
+irreducible_def UScalar.max (ty : UScalarTy) : Nat := 2^ty.numBits-1
+irreducible_def IScalar.min (ty : IScalarTy) : Int := -2^(ty.numBits - 1)
+irreducible_def IScalar.max (ty : IScalarTy) : Int := 2^(ty.numBits - 1)-1
+
+irreducible_def UScalar.size (ty : UScalarTy) : Nat := 2^ty.numBits
+irreducible_def IScalar.size (ty : IScalarTy) : Int := 2^ty.numBits
+
+/-! ## Num Bits -/
+irreducible_def U8.numBits : Nat := UScalarTy.U8.numBits
+irreducible_def U16.numBits : Nat := UScalarTy.U16.numBits
+irreducible_def U32.numBits : Nat := UScalarTy.U32.numBits
+irreducible_def U64.numBits : Nat := UScalarTy.U64.numBits
+irreducible_def U128.numBits : Nat := UScalarTy.U128.numBits
+irreducible_def Usize.numBits : Nat := UScalarTy.Usize.numBits
+
+irreducible_def I8.numBits : Nat := IScalarTy.I8.numBits
+irreducible_def I16.numBits : Nat := IScalarTy.I16.numBits
+irreducible_def I32.numBits : Nat := IScalarTy.I32.numBits
+irreducible_def I64.numBits : Nat := IScalarTy.I64.numBits
+irreducible_def I128.numBits : Nat := IScalarTy.I128.numBits
+irreducible_def Isize.numBits : Nat := IScalarTy.Isize.numBits
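+/- Illustrative sketch (not part of the patch): as defined above, a scalar is just a wrapper
+   around a bit-vector, and `val` is its `Nat`/`Int` interpretation. The two `example`s
+   below, together with their `rfl` proofs, are assumptions added for illustration. -/
+example (b : BitVec 8) : (⟨b⟩ : UScalar .U8).val = b.toNat := rfl
+example (b : BitVec 8) : (⟨b⟩ : IScalar .I8).val = b.toInt := rfl
+
+/-!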
## Bounds -/ +irreducible_def U8.max : Nat := 2^U8.numBits - 1 +irreducible_def U16.max : Nat := 2^U16.numBits - 1 +irreducible_def U32.max : Nat := 2^U32.numBits - 1 +irreducible_def U64.max : Nat := 2^U64.numBits - 1 +irreducible_def U128.max : Nat := 2^U128.numBits - 1 +irreducible_def Usize.max : Nat := 2^Usize.numBits - 1 + +irreducible_def I8.min : Int := -2^(I8.numBits - 1) +irreducible_def I8.max : Int := 2^(I8.numBits - 1) - 1 +irreducible_def I16.min : Int := -2^(I16.numBits - 1) +irreducible_def I16.max : Int := 2^(I16.numBits - 1) - 1 +irreducible_def I32.min : Int := -2^(I32.numBits - 1) +irreducible_def I32.max : Int := 2^(I32.numBits - 1) - 1 +irreducible_def I64.min : Int := -2^(I64.numBits - 1) +irreducible_def I64.max : Int := 2^(I64.numBits - 1) - 1 +irreducible_def I128.min : Int := -2^(I128.numBits - 1) +irreducible_def I128.max : Int := 2^(I128.numBits - 1) - 1 +irreducible_def Isize.min : Int := -2^(Isize.numBits - 1) +irreducible_def Isize.max : Int := 2^(Isize.numBits - 1) - 1 + +/-! ## Size -/ +irreducible_def U8.size : Nat := 2^U8.numBits +irreducible_def U16.size : Nat := 2^U16.numBits +irreducible_def U32.size : Nat := 2^U32.numBits +irreducible_def U64.size : Nat := 2^U64.numBits +irreducible_def U128.size : Nat := 2^U128.numBits +irreducible_def Usize.size : Nat := 2^Usize.numBits + +irreducible_def I8.size : Nat := 2^I8.numBits +irreducible_def I16.size : Nat := 2^I16.numBits +irreducible_def I32.size : Nat := 2^I32.numBits +irreducible_def I64.size : Nat := 2^I64.numBits +irreducible_def I128.size : Nat := 2^I128.numBits +irreducible_def Isize.size : Nat := 2^Isize.numBits + +/-! ## "Reduced" Constants -/ +/-! ### Size -/ +def I8.rSize : Int := 256 +def I16.rSize : Int := 65536 +def I32.rSize : Int := 4294967296 +def I64.rSize : Int := 18446744073709551616 +def I128.rSize : Int := 340282366920938463463374607431768211456 + +def U8.rSize : Nat := 256 +def U16.rSize : Nat := 65536 +def U32.rSize : Nat := 4294967296 +def U64.rSize : Nat := 18446744073709551616 +def U128.rSize : Nat := 340282366920938463463374607431768211456 + +/-! ### Bounds -/ +def U8.rMax : Nat := 255 +def U16.rMax : Nat := 65535 +def U32.rMax : Nat := 4294967295 +def U64.rMax : Nat := 18446744073709551615 +def U128.rMax : Nat := 340282366920938463463374607431768211455 +def Usize.rMax : Nat := 2^System.Platform.numBits-1 + +def I8.rMin : Int := -128 +def I8.rMax : Int := 127 +def I16.rMin : Int := -32768 +def I16.rMax : Int := 32767 +def I32.rMin : Int := -2147483648 +def I32.rMax : Int := 2147483647 +def I64.rMin : Int := -9223372036854775808 +def I64.rMax : Int := 9223372036854775807 +def I128.rMin : Int := -170141183460469231731687303715884105728 +def I128.rMax : Int := 170141183460469231731687303715884105727 +def Isize.rMin : Int := -2^(System.Platform.numBits - 1) +def Isize.rMax : Int := 2^(System.Platform.numBits - 1)-1 + +def UScalar.rMax (ty : UScalarTy) : Nat := + match ty with + | .Usize => Usize.rMax + | .U8 => U8.rMax + | .U16 => U16.rMax + | .U32 => U32.rMax + | .U64 => U64.rMax + | .U128 => U128.rMax + +def IScalar.rMin (ty : IScalarTy) : Int := + match ty with + | .Isize => Isize.rMin + | .I8 => I8.rMin + | .I16 => I16.rMin + | .I32 => I32.rMin + | .I64 => I64.rMin + | .I128 => I128.rMin + +def IScalar.rMax (ty : IScalarTy) : Int := + match ty with + | .Isize => Isize.rMax + | .I8 => I8.rMax + | .I16 => I16.rMax + | .I32 => I32.rMax + | .I64 => I64.rMax + | .I128 => I128.rMax + +/-! 
# Theorems -/ +theorem UScalarTy.numBits_nonzero (ty : UScalarTy) : ty.numBits ≠ 0 := by + dcases ty <;> simp [numBits] + dcases System.Platform.numBits_eq <;> simp_all + +theorem IScalarTy.numBits_nonzero (ty : IScalarTy) : ty.numBits ≠ 0 := by + dcases ty <;> simp [numBits] + dcases System.Platform.numBits_eq <;> simp_all + +@[simp, scalar_tac_simp] theorem UScalarTy.U8_numBits_eq : UScalarTy.U8.numBits = 8 := by rfl +@[simp, scalar_tac_simp] theorem UScalarTy.U16_numBits_eq : UScalarTy.U16.numBits = 16 := by rfl +@[simp, scalar_tac_simp] theorem UScalarTy.U32_numBits_eq : UScalarTy.U32.numBits = 32 := by rfl +@[simp, scalar_tac_simp] theorem UScalarTy.U64_numBits_eq : UScalarTy.U64.numBits = 64 := by rfl +@[simp, scalar_tac_simp] theorem UScalarTy.U128_numBits_eq : UScalarTy.U128.numBits = 128 := by rfl +@[simp, scalar_tac_simp] theorem UScalarTy.Usize_numBits_eq : UScalarTy.Usize.numBits = System.Platform.numBits := by rfl + +@[simp, scalar_tac_simp] theorem IScalarTy.I8_numBits_eq : IScalarTy.I8.numBits = 8 := by rfl +@[simp, scalar_tac_simp] theorem IScalarTy.I16_numBits_eq : IScalarTy.I16.numBits = 16 := by rfl +@[simp, scalar_tac_simp] theorem IScalarTy.I32_numBits_eq : IScalarTy.I32.numBits = 32 := by rfl +@[simp, scalar_tac_simp] theorem IScalarTy.I64_numBits_eq : IScalarTy.I64.numBits = 64 := by rfl +@[simp, scalar_tac_simp] theorem IScalarTy.I128_numBits_eq : IScalarTy.I128.numBits = 128 := by rfl +@[simp, scalar_tac_simp] theorem IScalarTy.Isize_numBits_eq : IScalarTy.Isize.numBits = System.Platform.numBits := by rfl + +@[simp] theorem UScalar.max_UScalarTy_U8_eq : UScalar.max .U8 = U8.max := by simp [UScalar.max, U8.max, U8.numBits] +@[simp] theorem UScalar.max_UScalarTy_U16_eq : UScalar.max .U16 = U16.max := by simp [UScalar.max, U16.max, U16.numBits] +@[simp] theorem UScalar.max_UScalarTy_U32_eq : UScalar.max .U32 = U32.max := by simp [UScalar.max, U32.max, U32.numBits] +@[simp] theorem UScalar.max_UScalarTy_U64_eq : UScalar.max .U64 = U64.max := by simp [UScalar.max, U64.max, U64.numBits] +@[simp] theorem UScalar.max_UScalarTy_U128_eq : UScalar.max .U128 = U128.max := by simp [UScalar.max, U128.max, U128.numBits] + +@[simp] theorem IScalar.min_IScalarTy_I8_eq : IScalar.min .I8 = I8.min := by simp [IScalar.min, I8.min, I8.numBits] +@[simp] theorem IScalar.max_IScalarTy_I8_eq : IScalar.max .I8 = I8.max := by simp [IScalar.max, I8.max, I8.numBits] +@[simp] theorem IScalar.min_IScalarTy_I16_eq : IScalar.min .I16 = I16.min := by simp [IScalar.min, I16.min, I16.numBits] +@[simp] theorem IScalar.max_IScalarTy_I16_eq : IScalar.max .I16 = I16.max := by simp [IScalar.max, I16.max, I16.numBits] +@[simp] theorem IScalar.min_IScalarTy_I32_eq : IScalar.min .I32 = I32.min := by simp [IScalar.min, I32.min, I32.numBits] +@[simp] theorem IScalar.max_IScalarTy_I32_eq : IScalar.max .I32 = I32.max := by simp [IScalar.max, I32.max, I32.numBits] +@[simp] theorem IScalar.min_IScalarTy_I64_eq : IScalar.min .I64 = I64.min := by simp [IScalar.min, I64.min, I64.numBits] +@[simp] theorem IScalar.max_IScalarTy_I64_eq : IScalar.max .I64 = I64.max := by simp [IScalar.max, I64.max, I64.numBits] +@[simp] theorem IScalar.min_IScalarTy_I128_eq : IScalar.min .I128 = I128.min := by simp [IScalar.min, I128.min, I128.numBits] +@[simp] theorem IScalar.max_IScalarTy_I128_eq : IScalar.max .I128 = I128.max := by simp [IScalar.max, I128.max, I128.numBits] + +local syntax "simp_uscalar_bounds" : tactic +local macro_rules +| `(tactic|simp_uscalar_bounds) => + `(tactic| + simp [ + UScalar.rMax, UScalar.max, + 
Usize.rMax, Usize.rMax, Usize.max, + U8.rMax, U8.max, U16.rMax, U16.max, U32.rMax, U32.max, + U64.rMax, U64.max, U128.rMax, U128.max, + U8.numBits, U16.numBits, U32.numBits, U64.numBits, U128.numBits]) + +local syntax "simp_iscalar_bounds" : tactic +local macro_rules +| `(tactic|simp_iscalar_bounds) => + `(tactic| + simp [ + IScalar.rMax, IScalar.max, + IScalar.rMin, IScalar.min, + Isize.rMax, Isize.rMax, Isize.max, + I8.rMax, I8.max, I16.rMax, I16.max, I32.rMax, I32.max, + I64.rMax, I64.max, I128.rMax, I128.max, + Isize.rMin, Isize.rMin, Isize.min, + I8.rMin, I8.min, I16.rMin, I16.min, I32.rMin, I32.min, + I64.rMin, I64.min, I128.rMin, I128.min, + I8.numBits, I16.numBits, I32.numBits, I64.numBits, I128.numBits]) +theorem Usize.bounds_eq : + Usize.max = U32.max ∨ Usize.max = U64.max := by + simp [Usize.max, UScalar.max, Usize.numBits] + cases System.Platform.numBits_eq <;> + simp [*] <;> + simp_uscalar_bounds -def Scalar.smin (ty : ScalarTy) : Int := - match ty with - | .Isize => Isize.smin - | .I8 => I8.smin - | .I16 => I16.smin - | .I32 => I32.smin - | .I64 => I64.smin - | .I128 => I128.smin - | .Usize => Usize.smin - | .U8 => U8.smin - | .U16 => U16.smin - | .U32 => U32.smin - | .U64 => U64.smin - | .U128 => U128.smin - -def Scalar.smax (ty : ScalarTy) : Int := +theorem Isize.bounds_eq : + (Isize.min = I32.min ∧ Isize.max = I32.max) + ∨ (Isize.min = I64.min ∧ Isize.max = I64.max) := by + simp [Isize.min, Isize.max, IScalar.min, IScalar.max, Isize.numBits] + cases System.Platform.numBits_eq <;> + simp [*] <;> simp [*, I32.min, I32.numBits, I32.max, I64.min, I64.numBits, I64.max] + +theorem UScalar.rMax_eq_max (ty : UScalarTy) : UScalar.rMax ty = UScalar.max ty := by + dcases ty <;> + simp_uscalar_bounds + +theorem IScalar.rbound_eq_bound (ty : IScalarTy) : + IScalar.rMin ty = IScalar.min ty ∧ IScalar.rMax ty = IScalar.max ty := by + dcases ty <;> split_conjs <;> + simp_iscalar_bounds + +theorem IScalar.rMin_eq_min (ty : IScalarTy) : IScalar.rMin ty = IScalar.min ty := by + apply (IScalar.rbound_eq_bound ty).left + +theorem IScalar.rMax_eq_max (ty : IScalarTy) : IScalar.rMax ty = IScalar.max ty := by + apply (IScalar.rbound_eq_bound ty).right + +/-! +# "Conservative" Bounds + +We use those because we can't compare to the isize bounds (which can't +reduce at compile-time). Whenever we perform an arithmetic operation like +addition we need to check that the result is in bounds: we first compare +to the conservative bounds, which reduces, then compare to the real bounds. + +This is useful for the various #asserts that we want to reduce at +type-checking time, or when defining constants. 
+-/ + +def UScalarTy.cNumBits (ty : UScalarTy) : Nat := match ty with - | .Isize => Isize.smax - | .I8 => I8.smax - | .I16 => I16.smax - | .I32 => I32.smax - | .I64 => I64.smax - | .I128 => I128.smax - | .Usize => Usize.smax - | .U8 => U8.smax - | .U16 => U16.smax - | .U32 => U32.smax - | .U64 => U64.smax - | .U128 => U128.smax - -def Scalar.min (ty : ScalarTy) : Int := + | .Usize => U32.numBits + | _ => ty.numBits + +def IScalarTy.cNumBits (ty : IScalarTy) : Nat := match ty with - | .Isize => Isize.min - | .I8 => I8.min - | .I16 => I16.min - | .I32 => I32.min - | .I64 => I64.min - | .I128 => I128.min - | .Usize => Usize.min - | .U8 => U8.min - | .U16 => U16.min - | .U32 => U32.min - | .U64 => U64.min - | .U128 => U128.min - -def Scalar.max (ty : ScalarTy) : Int := + | .Isize => I32.numBits + | _ => ty.numBits + +theorem UScalarTy.cNumBits_le (ty : UScalarTy) : ty.cNumBits ≤ ty.numBits := by + dcases ty <;> simp [cNumBits, numBits, U32.numBits] + dcases System.Platform.numBits_eq <;> simp [*] + +theorem IScalarTy.cNumBits_le (ty : IScalarTy) : ty.cNumBits ≤ ty.numBits := by + dcases ty <;> simp [cNumBits, numBits, I32.numBits] + dcases System.Platform.numBits_eq <;> simp [*] + +theorem UScalarTy.cNumBits_nonzero (ty : UScalarTy) : ty.cNumBits ≠ 0 := by + dcases ty <;> simp [cNumBits, U32.numBits] + +theorem IScalarTy.cNumBits_nonzero (ty : IScalarTy) : ty.cNumBits ≠ 0 := by + dcases ty <;> simp [cNumBits, I32.numBits] + +def UScalar.cMax (ty : UScalarTy) : Nat := match ty with - | .Isize => Isize.max - | .I8 => I8.max - | .I16 => I16.max - | .I32 => I32.max - | .I64 => I64.max - | .I128 => I128.max - | .Usize => Usize.max - | .U8 => U8.max - | .U16 => U16.max - | .U32 => U32.max - | .U64 => U64.max - | .U128 => U128.max - -def Scalar.smin_eq (ty : ScalarTy) : Scalar.min ty = Scalar.smin ty := by - cases ty <;> rfl - -def Scalar.smax_eq (ty : ScalarTy) : Scalar.max ty = Scalar.smax ty := by - cases ty <;> rfl - --- "Conservative" bounds --- We use those because we can't compare to the isize bounds (which can't --- reduce at compile-time). Whenever we perform an arithmetic operation like --- addition we need to check that the result is in bounds: we first compare --- to the conservative bounds, which reduce, then compare to the real bounds. --- This is useful for the various #asserts that we want to reduce at --- type-checking time. -def Scalar.cMin (ty : ScalarTy) : Int := + | .Usize => UScalar.rMax .U32 + | _ => UScalar.rMax ty + +def IScalar.cMin (ty : IScalarTy) : Int := match ty with - | .Isize => Scalar.min .I32 - | _ => Scalar.min ty + | .Isize => IScalar.rMin .I32 + | _ => IScalar.rMin ty -def Scalar.cMax (ty : ScalarTy) : Int := +def IScalar.cMax (ty : IScalarTy) : Int := match ty with - | .Isize => Scalar.max .I32 - | .Usize => Scalar.max .U32 - | _ => Scalar.max ty - -theorem Scalar.min_lt_max (ty : ScalarTy) : Scalar.min ty < Scalar.max ty := by - cases ty <;> simp [Scalar.min, Scalar.max] <;> try decide - . simp [Isize.min, Isize.max] - have h1 := Isize.refined_min.property - have h2 := Isize.refined_max.property - cases h1 <;> cases h2 <;> simp [*] <;> decide - . 
simp [Usize.max] - have h := Usize.refined_max.property - cases h <;> simp [*] <;> decide - -theorem Scalar.min_le_max (ty : ScalarTy) : Scalar.min ty ≤ Scalar.max ty := by - have := Scalar.min_lt_max ty - int_tac - -theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by - cases ty <;> (simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at *; try decide) - have h := Isize.refined_min.property - cases h <;> simp [*, Isize.min] - decide - -theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by - cases ty <;> (simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at *; try decide) - . have h := Isize.refined_max.property - cases h <;> simp [*, Isize.max]; decide - . have h := Usize.refined_max.property - cases h <;> simp [*, Usize.max]; decide - -theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by - have := Scalar.cMin_bound ty + | .Isize => IScalar.rMax .I32 + | _ => IScalar.rMax ty + +def UScalar.hBounds {ty} (x : UScalar ty) : x.val < 2^ty.numBits := by + dcases h: x.bv + simp [h, val] + +def UScalar.hSize {ty} (x : UScalar ty) : x.val < UScalar.size ty := by + dcases h: x.bv + simp [h, val, size] + +def UScalar.rMax_eq_pow_numBits (ty : UScalarTy) : UScalar.rMax ty = 2^ty.numBits - 1 := by + dcases ty <;> simp [rMax] <;> simp_uscalar_bounds + +def UScalar.cMax_eq_pow_cNumBits (ty : UScalarTy) : UScalar.cMax ty = 2^ty.cNumBits - 1 := by + dcases ty <;> simp [cMax, UScalarTy.cNumBits] <;> simp_uscalar_bounds + +def UScalar.cMax_le_rMax (ty : UScalarTy) : UScalar.cMax ty ≤ UScalar.rMax ty := by + have := rMax_eq_pow_numBits ty + have := cMax_eq_pow_cNumBits ty + have := ty.cNumBits_le + have := @Nat.pow_le_pow_of_le_right 2 (by simp) ty.cNumBits ty.numBits ty.cNumBits_le omega -theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by - have := Scalar.cMax_bound ty +def UScalar.hrBounds {ty} (x : UScalar ty) : x.val ≤ UScalar.rMax ty := by + have := UScalar.hBounds x + have := UScalar.rMax_eq_pow_numBits ty omega -/-- The scalar type. - - We could use a subtype, but it using a custom structure type allows us - to have more control over the coercions and the simplifications (we tried - using a subtype and it caused issues especially as we had to make the Scalar - type non-reducible, so that we could have more control, but leading to - some natural equalities not being obvious to the simplifier anymore). 
- -/ -structure Scalar (ty : ScalarTy) where - val : Int - hmin : Scalar.min ty ≤ val - hmax : val ≤ Scalar.max ty -deriving Repr, BEq, DecidableEq +def UScalar.hmax {ty} (x : UScalar ty) : x.val < 2^ty.numBits := x.hBounds + +def IScalar.hBounds {ty} (x : IScalar ty) : + -2^(ty.numBits - 1) ≤ x.val ∧ x.val < 2^(ty.numBits - 1) := by + match x with + | ⟨ ⟨ fin ⟩ ⟩ => + simp [val, min, max, BitVec.toInt] + dcases ty <;> simp at * <;> try omega + have hFinLt := fin.isLt + cases h: System.Platform.numBits_eq <;> + simp_all only [IScalarTy.Isize_numBits_eq, true_or, Nat.add_one_sub_one] <;> + omega + +def IScalar.rMin_eq_pow_numBits (ty : IScalarTy) : IScalar.rMin ty = -2^(ty.numBits - 1) := by + dcases ty <;> simp [cMax] <;> simp_iscalar_bounds + +def IScalar.rMax_eq_pow_numBits (ty : IScalarTy) : IScalar.rMax ty = 2^(ty.numBits - 1) - 1 := by + dcases ty <;> simp [rMax] <;> simp_iscalar_bounds + +def IScalar.cMin_eq_pow_cNumBits (ty : IScalarTy) : IScalar.cMin ty = -2^(ty.cNumBits - 1) := by + dcases ty <;> simp [cMin, IScalarTy.cNumBits] <;> simp_iscalar_bounds + +def IScalar.cMax_eq_pow_cNumBits (ty : IScalarTy) : IScalar.cMax ty = 2^(ty.cNumBits - 1) - 1 := by + dcases ty <;> simp [cMax, IScalarTy.cNumBits] <;> simp_iscalar_bounds + +def IScalar.rMin_le_cMin (ty : IScalarTy) : IScalar.rMin ty ≤ IScalar.cMin ty := by + have := rMin_eq_pow_numBits ty + have := cMin_eq_pow_cNumBits ty + have := ty.cNumBits_le + have := ty.cNumBits_nonzero + have := @Int.pow_le_pow_of_le_right 2 (by simp) (ty.cNumBits - 1) (ty.numBits - 1) (by omega) + zify at this + omega + +def IScalar.cMax_le_rMax (ty : IScalarTy) : IScalar.cMax ty ≤ IScalar.rMax ty := by + have := rMax_eq_pow_numBits ty + have := cMax_eq_pow_cNumBits ty + have := ty.cNumBits_le + have := ty.cNumBits_nonzero + have := @Int.pow_le_pow_of_le_right 2 (by simp) (ty.cNumBits - 1) (ty.numBits - 1) (by omega) + zify at this + omega + +def IScalar.hrBounds {ty} (x : IScalar ty) : + IScalar.rMin ty ≤ x.val ∧ x.val ≤ IScalar.rMax ty := by + have := IScalar.hBounds x + have := IScalar.rMin_eq_pow_numBits ty + have := IScalar.rMax_eq_pow_numBits ty + omega + +def IScalar.hmin {ty} (x : IScalar ty) : -2^(ty.numBits - 1) ≤ x.val := x.hBounds.left +def IScalar.hmax {ty} (x : IScalar ty) : x.val < 2^(ty.numBits - 1) := x.hBounds.right + +instance {ty} : BEq (UScalar ty) where + beq a b := a.bv = b.bv -instance {ty} : BEq (Scalar ty) where - beq a b := a.val = b.val +instance {ty} : BEq (IScalar ty) where + beq a b := a.bv = b.bv -instance {ty} : LawfulBEq (Scalar ty) where +instance {ty} : LawfulBEq (UScalar ty) where + eq_of_beq {a b} := by cases a; cases b; simp [BEq.beq] + rfl {a} := by cases a; simp [BEq.beq] + +instance {ty} : LawfulBEq (IScalar ty) where eq_of_beq {a b} := by cases a; cases b; simp[BEq.beq] rfl {a} := by cases a; simp [BEq.beq] -instance (ty : ScalarTy) : CoeOut (Scalar ty) Int where +instance (ty : UScalarTy) : CoeOut (UScalar ty) Nat where + coe := λ v => v.val + +instance (ty : IScalarTy) : CoeOut (IScalar ty) Int where coe := λ v => v.val /- Activate the ↑ notation -/ -attribute [coe] Scalar.val +attribute [coe] UScalar.val IScalar.val + +theorem UScalar.bound_suffices (ty : UScalarTy) (x : Nat) : + x ≤ UScalar.cMax ty -> x < 2^ty.numBits + := by + intro h + have := UScalar.rMax_eq_pow_numBits ty + have : 0 < 2^ty.numBits := by simp + have := cMax_le_rMax ty + omega + +theorem IScalar.bound_suffices (ty : IScalarTy) (x : Int) : + IScalar.cMin ty ≤ x ∧ x ≤ IScalar.cMax ty -> + -2^(ty.numBits - 1) ≤ x ∧ x < 2^(ty.numBits - 1) + := 
by + intro h + have := ty.cNumBits_nonzero + have := ty.numBits_nonzero + have := ty.cNumBits_le + have := IScalar.rMin_eq_pow_numBits ty + have := IScalar.rMax_eq_pow_numBits ty + have := rMin_le_cMin ty + have := cMax_le_rMax ty + omega + +/- TODO: remove? Having a check on the bounds is a good sanity check, and it allows to prove + nice theorems like `(ofIntCore x ..).val = x`. But on the other hand `BitVec` also has powerful + simplification lemmas. -/ +def UScalar.ofNatCore {ty : UScalarTy} (x : Nat) (_ : x < 2^ty.numBits) : UScalar ty := + { bv := BitVec.ofNat _ x } + +-- TODO: remove? +def IScalar.ofIntCore {ty : IScalarTy} (x : Int) (_ : -2^(ty.numBits-1) ≤ x ∧ x < 2^(ty.numBits - 1)) : IScalar ty := + { bv := BitVec.ofInt _ x } -theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) : - Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty -> - Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty - := - λ h => by - apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> omega +@[reducible] def UScalar.ofNat {ty : UScalarTy} (x : Nat) + (hInBounds : x ≤ UScalar.cMax ty := by decide) : UScalar ty := + UScalar.ofNatCore x (UScalar.bound_suffices ty x hInBounds) -def Scalar.ofIntCore {ty : ScalarTy} (x : Int) - (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty := - { val := x, hmin := h.left, hmax := h.right } +@[reducible] def IScalar.ofInt {ty : IScalarTy} (x : Int) + (hInBounds : IScalar.cMin ty ≤ x ∧ x ≤ IScalar.cMax ty := by decide) : IScalar ty := + IScalar.ofIntCore x (IScalar.bound_suffices ty x hInBounds) -@[reducible] def Scalar.ofInt {ty : ScalarTy} (x : Int) - (hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty := by decide) : Scalar ty := - Scalar.ofIntCore x (Scalar.bound_suffices ty x hInBounds) +@[simp] abbrev UScalar.inBounds (ty : UScalarTy) (x : Nat) : Prop := + x < 2^ty.numBits -@[simp] abbrev Scalar.in_bounds (ty : ScalarTy) (x : Int) : Prop := - Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty +@[simp] abbrev IScalar.inBounds (ty : IScalarTy) (x : Int) : Prop := + - 2^(ty.numBits - 1) ≤ x ∧ x < 2^(ty.numBits - 1) -@[simp] abbrev Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool := - (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty) +@[simp] abbrev UScalar.check_bounds (ty : UScalarTy) (x : Nat) : Bool := + x < 2^ty.numBits + +@[simp] abbrev IScalar.check_bounds (ty : IScalarTy) (x : Int) : Bool := + -2^(ty.numBits - 1) ≤ x ∧ x < 2^(ty.numBits - 1) /- Discussion: This coercion can be slightly annoying at times, because if we write - something like `u = 3` (where `u` is, for instance, as `U32`), then instead of - coercing `u` to `Int`, Lean will lift `3` to `U32`). + something like `u = 3` (where `u` is, for instance, a `U32`), then instead of + coercing `u` to `Nat`, Lean will lift `3` to `U32`). For now we deactivate it. 
-- TODO(raitobezarius): the inbounds constraint is a bit ugly as we can pretty trivially @@ -345,326 +517,544 @@ instance {ty: ScalarTy} [InBounds ty (Int.ofNat n)]: OfNat (Scalar ty) (n: ℕ) ofNat := Scalar.ofInt n -/ -theorem Scalar.check_bounds_imp_in_bounds {ty : ScalarTy} {x : Int} - (h: Scalar.check_bounds ty x) : - Scalar.in_bounds ty x := by - simp at * - have ⟨ hmin, hmax ⟩ := h - have hbmin := Scalar.cMin_bound ty - have hbmax := Scalar.cMax_bound ty - cases hmin <;> cases hmax <;> apply And.intro <;> omega - -theorem Scalar.check_bounds_eq_in_bounds (ty : ScalarTy) (x : Int) : - Scalar.check_bounds ty x ↔ Scalar.in_bounds ty x := by +theorem UScalar.check_bounds_imp_inBounds {ty : UScalarTy} {x : Nat} + (h: UScalar.check_bounds ty x) : + UScalar.inBounds ty x := by + simp at *; apply h + +theorem UScalar.check_bounds_eq_inBounds (ty : UScalarTy) (x : Nat) : + UScalar.check_bounds ty x ↔ UScalar.inBounds ty x := by constructor <;> intro h - . apply (check_bounds_imp_in_bounds h) + . apply (check_bounds_imp_inBounds h) . simp_all --- Further thoughts: look at what has been done here: --- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/Fin/Basic.lean --- and --- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/UInt.lean --- which both contain a fair amount of reasoning already! -def Scalar.tryMkOpt (ty : ScalarTy) (x : Int) : Option (Scalar ty) := - if h:Scalar.check_bounds ty x then - -- If we do: - -- ``` - -- let ⟨ hmin, hmax ⟩ := (Scalar.check_bounds_imp_in_bounds h) - -- Scalar.ofIntCore x hmin hmax - -- ``` - -- then normalization blocks (for instance, some proofs which use reflexivity fail). - -- However, the version below doesn't block reduction (TODO: investigate): - some (Scalar.ofIntCore x (Scalar.check_bounds_imp_in_bounds h)) +theorem IScalar.check_bounds_imp_inBounds {ty : IScalarTy} {x : Int} + (h: IScalar.check_bounds ty x) : + IScalar.inBounds ty x := by + simp at *; apply h + +theorem IScalar.check_bounds_eq_inBounds (ty : IScalarTy) (x : Int) : + IScalar.check_bounds ty x ↔ IScalar.inBounds ty x := by + constructor <;> intro h + . apply (check_bounds_imp_inBounds h) + . 
simp_all + +def UScalar.tryMkOpt (ty : UScalarTy) (x : Nat) : Option (UScalar ty) := + if h:UScalar.check_bounds ty x then + some (UScalar.ofNatCore x (UScalar.check_bounds_imp_inBounds h)) + else none + +def UScalar.tryMk (ty : UScalarTy) (x : Nat) : Result (UScalar ty) := + Result.ofOption (tryMkOpt ty x) integerOverflow + +def IScalar.tryMkOpt (ty : IScalarTy) (x : Int) : Option (IScalar ty) := + if h:IScalar.check_bounds ty x then + some (IScalar.ofIntCore x (IScalar.check_bounds_imp_inBounds h)) else none -def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) := +def IScalar.tryMk (ty : IScalarTy) (x : Int) : Result (IScalar ty) := Result.ofOption (tryMkOpt ty x) integerOverflow -theorem Scalar.tryMk_eq (ty : ScalarTy) (x : Int) : +theorem UScalar.tryMkOpt_eq (ty : UScalarTy) (x : Nat) : + match tryMkOpt ty x with + | some y => y.val = x ∧ inBounds ty x + | none => ¬ (inBounds ty x) := by + simp [tryMkOpt, ofNatCore] + have h := check_bounds_eq_inBounds ty x + split_ifs <;> simp_all + simp [UScalar.val, UScalarTy.numBits, max] at * + dcases ty <;> simp_all [U8.max, U16.max, U32.max, U64.max, U128.max, Usize.max, max] + cases h: System.Platform.numBits_eq <;> simp_all + +theorem UScalar.tryMk_eq (ty : UScalarTy) (x : Nat) : match tryMk ty x with - | ok y => y.val = x ∧ in_bounds ty x - | fail _ => ¬ (in_bounds ty x) + | ok y => y.val = x ∧ inBounds ty x + | fail _ => ¬ (inBounds ty x) | _ => False := by - simp [tryMk, ofOption, tryMkOpt, ofIntCore] - have h := check_bounds_eq_in_bounds ty x + have := UScalar.tryMkOpt_eq ty x + simp [tryMk, ofOption] + dcases h: tryMkOpt ty x <;> simp_all + +theorem IScalar.tryMkOpt_eq (ty : IScalarTy) (x : Int) : + match tryMkOpt ty x with + | some y => y.val = x ∧ inBounds ty x + | none => ¬ (inBounds ty x) := by + simp [tryMkOpt, ofIntCore] + have h := check_bounds_eq_inBounds ty x split_ifs <;> simp_all - -@[simp] theorem Scalar.tryMk_eq_div (ty : ScalarTy) (x : Int) : - tryMk ty x = div ↔ False := by - simp [tryMk, ofOption, tryMkOpt] - split_ifs <;> simp - -@[simp] theorem zero_in_cbounds {ty : ScalarTy} : Scalar.cMin ty ≤ 0 ∧ 0 ≤ Scalar.cMax ty := by - cases ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide - --- The scalar types --- We declare the definitions as reducible so that Lean can unfold them (useful --- for type class resolution for instance). 
-@[reducible] def Isize := Scalar .Isize -@[reducible] def I8 := Scalar .I8 -@[reducible] def I16 := Scalar .I16 -@[reducible] def I32 := Scalar .I32 -@[reducible] def I64 := Scalar .I64 -@[reducible] def I128 := Scalar .I128 -@[reducible] def Usize := Scalar .Usize -@[reducible] def U8 := Scalar .U8 -@[reducible] def U16 := Scalar .U16 -@[reducible] def U32 := Scalar .U32 -@[reducible] def U64 := Scalar .U64 -@[reducible] def U128 := Scalar .U128 - --- ofIntCore + simp [IScalar.val, IScalarTy.numBits, min, max] at * + dcases ty <;> + simp_all [I8.min, I16.min, I32.min, I64.min, I128.min, Isize.min, + I8.max, I16.max, I32.max, I64.max, I128.max, Isize.max, + min, max] <;> + simp [Int.bmod] <;> split <;> (try omega) <;> + cases h: System.Platform.numBits_eq <;> simp_all <;> omega + +theorem IScalar.tryMk_eq (ty : IScalarTy) (x : Int) : + match tryMk ty x with + | ok y => y.val = x ∧ inBounds ty x + | fail _ => ¬ (inBounds ty x) + | _ => False := by + have := tryMkOpt_eq ty x + simp [tryMk, ofIntCore] + dcases h : tryMkOpt ty x <;> simp_all + +@[simp] theorem UScalar.zero_in_cbounds {ty : UScalarTy} : 0 < 2^ty.numBits := by + simp + +@[simp] theorem IScalar.zero_in_cbounds {ty : IScalarTy} : + -2^(ty.numBits - 1) ≤ 0 ∧ 0 < 2^(ty.numBits - 1) := by + cases ty <;> simp + +/-! The scalar types. -/ +abbrev Usize := UScalar .Usize +abbrev U8 := UScalar .U8 +abbrev U16 := UScalar .U16 +abbrev U32 := UScalar .U32 +abbrev U64 := UScalar .U64 +abbrev U128 := UScalar .U128 +abbrev Isize := IScalar .Isize +abbrev I8 := IScalar .I8 +abbrev I16 := IScalar .I16 +abbrev I32 := IScalar .I32 +abbrev I64 := IScalar .I64 +abbrev I128 := IScalar .I128 + +/-! ofNatCore -/ -- TODO: typeclass? -def Isize.ofIntCore := @Scalar.ofIntCore .Isize -def I8.ofIntCore := @Scalar.ofIntCore .I8 -def I16.ofIntCore := @Scalar.ofIntCore .I16 -def I32.ofIntCore := @Scalar.ofIntCore .I32 -def I64.ofIntCore := @Scalar.ofIntCore .I64 -def I128.ofIntCore := @Scalar.ofIntCore .I128 -def Usize.ofIntCore := @Scalar.ofIntCore .Usize -def U8.ofIntCore := @Scalar.ofIntCore .U8 -def U16.ofIntCore := @Scalar.ofIntCore .U16 -def U32.ofIntCore := @Scalar.ofIntCore .U32 -def U64.ofIntCore := @Scalar.ofIntCore .U64 -def U128.ofIntCore := @Scalar.ofIntCore .U128 - --- ofInt +def Usize.ofNatCore := @UScalar.ofNatCore .Usize +def U8.ofNatCore := @UScalar.ofNatCore .U8 +def U16.ofNatCore := @UScalar.ofNatCore .U16 +def U32.ofNatCore := @UScalar.ofNatCore .U32 +def U64.ofNatCore := @UScalar.ofNatCore .U64 +def U128.ofNatCore := @UScalar.ofNatCore .U128 + +/-! ofIntCore -/ +def Isize.ofIntCore := @IScalar.ofIntCore .Isize +def I8.ofIntCore := @IScalar.ofIntCore .I8 +def I16.ofIntCore := @IScalar.ofIntCore .I16 +def I32.ofIntCore := @IScalar.ofIntCore .I32 +def I64.ofIntCore := @IScalar.ofIntCore .I64 +def I128.ofIntCore := @IScalar.ofIntCore .I128 + +/-! ofNat -/ -- TODO: typeclass? 
-abbrev Isize.ofInt := @Scalar.ofInt .Isize -abbrev I8.ofInt := @Scalar.ofInt .I8 -abbrev I16.ofInt := @Scalar.ofInt .I16 -abbrev I32.ofInt := @Scalar.ofInt .I32 -abbrev I64.ofInt := @Scalar.ofInt .I64 -abbrev I128.ofInt := @Scalar.ofInt .I128 -abbrev Usize.ofInt := @Scalar.ofInt .Usize -abbrev U8.ofInt := @Scalar.ofInt .U8 -abbrev U16.ofInt := @Scalar.ofInt .U16 -abbrev U32.ofInt := @Scalar.ofInt .U32 -abbrev U64.ofInt := @Scalar.ofInt .U64 -abbrev U128.ofInt := @Scalar.ofInt .U128 - --- TODO: factor those lemmas out -@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by - simp [Scalar.ofInt, Scalar.ofIntCore] - -@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofIntCore x h).val = x := by - apply Scalar.ofInt_val_eq h - -instance (ty : ScalarTy) : Inhabited (Scalar ty) := by - constructor; cases ty <;> apply (Scalar.ofInt 0) - --- TODO: reducible? 
-@[reducible] def core_isize_min : Isize := Scalar.ofIntCore Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) -@[reducible] def core_isize_max : Isize := Scalar.ofIntCore Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) -@[reducible] def core_i8_min : I8 := Scalar.ofInt I8.min -@[reducible] def core_i8_max : I8 := Scalar.ofInt I8.max -@[reducible] def core_i16_min : I16 := Scalar.ofInt I16.min -@[reducible] def core_i16_max : I16 := Scalar.ofInt I16.max -@[reducible] def core_i32_min : I32 := Scalar.ofInt I32.min -@[reducible] def core_i32_max : I32 := Scalar.ofInt I32.max -@[reducible] def core_i64_min : I64 := Scalar.ofInt I64.min -@[reducible] def core_i64_max : I64 := Scalar.ofInt I64.max -@[reducible] def core_i128_min : I128 := Scalar.ofInt I128.min -@[reducible] def core_i128_max : I128 := Scalar.ofInt I128.max - --- TODO: reducible? -@[reducible] def core_usize_min : Usize := Scalar.ofIntCore Usize.min (by simp [Scalar.min, Scalar.max]) -@[reducible] def core_usize_max : Usize := Scalar.ofIntCore Usize.max (by simp [Scalar.min, Scalar.max]) -@[reducible] def core_u8_min : U8 := Scalar.ofInt U8.min -@[reducible] def core_u8_max : U8 := Scalar.ofInt U8.max -@[reducible] def core_u16_min : U16 := Scalar.ofInt U16.min -@[reducible] def core_u16_max : U16 := Scalar.ofInt U16.max -@[reducible] def core_u32_min : U32 := Scalar.ofInt U32.min -@[reducible] def core_u32_max : U32 := Scalar.ofInt U32.max -@[reducible] def core_u64_min : U64 := Scalar.ofInt U64.min -@[reducible] def core_u64_max : U64 := Scalar.ofInt U64.max -@[reducible] def core_u128_min : U128 := Scalar.ofInt U128.min -@[reducible] def core_u128_max : U128 := Scalar.ofInt U128.max - --- Comparisons -instance {ty} : LT (Scalar ty) where +abbrev Usize.ofNat := @UScalar.ofNat .Usize +abbrev U8.ofNat := @UScalar.ofNat .U8 +abbrev U16.ofNat := @UScalar.ofNat .U16 +abbrev U32.ofNat := @UScalar.ofNat .U32 +abbrev U64.ofNat := @UScalar.ofNat .U64 +abbrev U128.ofNat := @UScalar.ofNat .U128 + +/-! 
ofInt -/ +abbrev Isize.ofInt := @IScalar.ofInt .Isize +abbrev I8.ofInt := @IScalar.ofInt .I8 +abbrev I16.ofInt := @IScalar.ofInt .I16 +abbrev I32.ofInt := @IScalar.ofInt .I32 +abbrev I64.ofInt := @IScalar.ofInt .I64 +abbrev I128.ofInt := @IScalar.ofInt .I128 + +@[simp, scalar_tac_simp] theorem UScalar.ofNat_val_eq {ty : UScalarTy} (h : x < 2^ty.numBits) : + (UScalar.ofNatCore x h).val = x := by + simp [UScalar.ofNat, UScalar.ofNatCore, UScalar.val, max] + dcases ty <;> simp_all + cases h: System.Platform.numBits_eq <;> simp_all + +@[simp, scalar_tac_simp] theorem U8.ofNat_val_eq (h : x < 2^UScalarTy.U8.numBits) : (U8.ofNatCore x h).val = x := by + apply UScalar.ofNat_val_eq h + +@[simp, scalar_tac_simp] theorem U16.ofNat_val_eq (h : x < 2^UScalarTy.U16.numBits) : (U16.ofNatCore x h).val = x := by + apply UScalar.ofNat_val_eq h + +@[simp, scalar_tac_simp] theorem U32.ofNat_val_eq (h : x < 2^UScalarTy.U32.numBits) : (U32.ofNatCore x h).val = x := by + apply UScalar.ofNat_val_eq h + +@[simp, scalar_tac_simp] theorem U64.ofNat_val_eq (h : x < 2^UScalarTy.U64.numBits) : (U64.ofNatCore x h).val = x := by + apply UScalar.ofNat_val_eq h + +@[simp, scalar_tac_simp] theorem U128.ofNat_val_eq (h : x < 2^UScalarTy.U128.numBits) : (U128.ofNatCore x h).val = x := by + apply UScalar.ofNat_val_eq h + +@[simp, scalar_tac_simp] theorem Usize.ofNat_val_eq (h : x < 2^UScalarTy.Usize.numBits) : (Usize.ofNatCore x h).val = x := by + apply UScalar.ofNat_val_eq h + +@[simp, scalar_tac_simp] theorem IScalar.ofInt_val_eq {ty : IScalarTy} (h : - 2^(ty.numBits - 1) ≤ x ∧ x < 2^(ty.numBits - 1)) : + (IScalar.ofIntCore x h).val = x := by + simp [IScalar.ofInt, IScalar.ofIntCore, IScalar.val] + dcases ty <;> + simp_all <;> + simp [Int.bmod] <;> split <;> (try omega) <;> + cases h: System.Platform.numBits_eq <;> simp_all <;> omega + +@[simp, scalar_tac_simp] theorem I8.ofInt_val_eq (h : -2^(IScalarTy.I8.numBits-1) ≤ x ∧ x < 2^(IScalarTy.I8.numBits-1)) : (I8.ofIntCore x h).val = x := by + apply IScalar.ofInt_val_eq + +@[simp, scalar_tac_simp] theorem I16.ofInt_val_eq (h : -2^(IScalarTy.I16.numBits-1) ≤ x ∧ x < 2^(IScalarTy.I16.numBits-1)) : (I16.ofIntCore x h).val = x := by + apply IScalar.ofInt_val_eq + +@[simp, scalar_tac_simp] theorem I32.ofInt_val_eq (h : -2^(IScalarTy.I32.numBits-1) ≤ x ∧ x < 2^(IScalarTy.I32.numBits-1)) : (I32.ofIntCore x h).val = x := by + apply IScalar.ofInt_val_eq + +@[simp, scalar_tac_simp] theorem I64.ofInt_val_eq (h : -2^(IScalarTy.I64.numBits-1) ≤ x ∧ x < 2^(IScalarTy.I64.numBits-1)) : (I64.ofIntCore x h).val = x := by + apply IScalar.ofInt_val_eq + +@[simp, scalar_tac_simp] theorem I128.ofInt_val_eq (h : -2^(IScalarTy.I128.numBits-1) ≤ x ∧ x < 2^(IScalarTy.I128.numBits-1)) : (I128.ofIntCore x h).val = x := by + apply IScalar.ofInt_val_eq + +@[simp, scalar_tac_simp] theorem Isize.ofInt_val_eq (h : -2^(IScalarTy.Isize.numBits-1) ≤ x ∧ x < 2^(IScalarTy.Isize.numBits-1)) : (Isize.ofIntCore x h).val = x := by + apply IScalar.ofInt_val_eq + +theorem UScalar.eq_equiv_bv_eq {ty : UScalarTy} (x y : UScalar ty) : + x = y ↔ x.bv = y.bv := by + cases x; cases y; simp + +theorem U8.eq_equiv_bv_eq (x y : U8) : x = y ↔ x.bv = y.bv := by apply UScalar.eq_equiv_bv_eq +theorem U16.eq_equiv_bv_eq (x y : U16) : x = y ↔ x.bv = y.bv := by apply UScalar.eq_equiv_bv_eq +theorem U32.eq_equiv_bv_eq (x y : U32) : x = y ↔ x.bv = y.bv := by apply UScalar.eq_equiv_bv_eq +theorem U64.eq_equiv_bv_eq (x y : U64) : x = y ↔ x.bv = y.bv := by apply UScalar.eq_equiv_bv_eq +theorem U128.eq_equiv_bv_eq (x y : U128) : x = y ↔ x.bv 
= y.bv := by apply UScalar.eq_equiv_bv_eq +theorem Usize.eq_equiv_bv_eq (x y : Usize) : x = y ↔ x.bv = y.bv := by apply UScalar.eq_equiv_bv_eq + +theorem UScalar.ofNatCore_bv {ty : UScalarTy} (x : Nat) h : + (@UScalar.ofNatCore ty x h).bv = BitVec.ofNat _ x := by + simp only [ofNatCore, bv] + +@[simp] theorem U8.ofNat_bv (x : Nat) h : (U8.ofNat x h).bv = BitVec.ofNat _ x := by apply UScalar.ofNatCore_bv +@[simp] theorem U16.ofNat_bv (x : Nat) h : (U16.ofNat x h).bv = BitVec.ofNat _ x := by apply UScalar.ofNatCore_bv +@[simp] theorem U32.ofNat_bv (x : Nat) h : (U32.ofNat x h).bv = BitVec.ofNat _ x := by apply UScalar.ofNatCore_bv +@[simp] theorem U64.ofNat_bv (x : Nat) h : (U64.ofNat x h).bv = BitVec.ofNat _ x := by apply UScalar.ofNatCore_bv +@[simp] theorem U128.ofNat_bv (x : Nat) h : (U128.ofNat x h).bv = BitVec.ofNat _ x := by apply UScalar.ofNatCore_bv +@[simp] theorem Usize.ofNat_bv (x : Nat) h : (Usize.ofNat x h).bv = BitVec.ofNat _ x := by apply UScalar.ofNatCore_bv + +@[simp] theorem UScalar.BitVec_ofNat_val {ty} (x : UScalar ty) : BitVec.ofNat ty.numBits x.val = x := by simp + +theorem IScalar.eq_equiv_bv_eq {ty : IScalarTy} (x y : IScalar ty) : + x = y ↔ x.bv = y.bv := by + cases x; cases y; simp + +theorem I8.eq_equiv_bv_eq (x y : I8) : x = y ↔ x.bv = y.bv := by apply IScalar.eq_equiv_bv_eq +theorem I16.eq_equiv_bv_eq (x y : I16) : x = y ↔ x.bv = y.bv := by apply IScalar.eq_equiv_bv_eq +theorem I32.eq_equiv_bv_eq (x y : I32) : x = y ↔ x.bv = y.bv := by apply IScalar.eq_equiv_bv_eq +theorem I64.eq_equiv_bv_eq (x y : I64) : x = y ↔ x.bv = y.bv := by apply IScalar.eq_equiv_bv_eq +theorem I128.eq_equiv_bv_eq (x y : I128) : x = y ↔ x.bv = y.bv := by apply IScalar.eq_equiv_bv_eq +theorem Isize.eq_equiv_bv_eq (x y : Isize) : x = y ↔ x.bv = y.bv := by apply IScalar.eq_equiv_bv_eq + +theorem IScalar.ofIntCore_bv {ty : IScalarTy} (x : Int) h : + (@IScalar.ofIntCore ty x h).bv = BitVec.ofInt _ x := by + simp only [ofIntCore, bv] + +@[simp] theorem I8.ofInt_bv (x : Int) h : (I8.ofInt x h).bv = BitVec.ofInt _ x := by apply IScalar.ofIntCore_bv +@[simp] theorem I16.ofInt_bv (x : Int) h : (I16.ofInt x h).bv = BitVec.ofInt _ x := by apply IScalar.ofIntCore_bv +@[simp] theorem I32.ofInt_bv (x : Int) h : (I32.ofInt x h).bv = BitVec.ofInt _ x := by apply IScalar.ofIntCore_bv +@[simp] theorem I64.ofInt_bv (x : Int) h : (I64.ofInt x h).bv = BitVec.ofInt _ x := by apply IScalar.ofIntCore_bv +@[simp] theorem I128.ofInt_bv (x : Int) h : (I128.ofInt x h).bv = BitVec.ofInt _ x := by apply IScalar.ofIntCore_bv +@[simp] theorem Isize.ofInt_bv (x : Int) h : (Isize.ofInt x h).bv = BitVec.ofInt _ x := by apply IScalar.ofIntCore_bv + +instance (ty : UScalarTy) : Inhabited (UScalar ty) := by + constructor; cases ty <;> apply (UScalar.ofNat 0 (by simp)) + +instance (ty : IScalarTy) : Inhabited (IScalar ty) := by + constructor; cases ty <;> apply (IScalar.ofInt 0 (by simp [IScalar.cMin, IScalar.cMax, IScalar.rMin, IScalar.rMax]; simp_iscalar_bounds)) + +theorem IScalar.min_lt_max (ty : IScalarTy) : IScalar.min ty < IScalar.max ty := by + cases ty <;> simp [IScalar.min, IScalar.max] <;> (try simp_iscalar_bounds) + have : (0 : Int) < 2 ^ (System.Platform.numBits - 1) := by simp + omega + +theorem IScalar.min_le_max (ty : IScalarTy) : IScalar.min ty ≤ IScalar.max ty := by + have := IScalar.min_lt_max ty + scalar_tac + +@[reducible] def core_u8_min : U8 := UScalar.ofNat 0 +@[reducible] def core_u8_max : U8 := UScalar.ofNat U8.rMax +@[reducible] def core_u16_min : U16 := UScalar.ofNat 0 +@[reducible] def core_u16_max 
: U16 := UScalar.ofNat U16.rMax +@[reducible] def core_u32_min : U32 := UScalar.ofNat 0 +@[reducible] def core_u32_max : U32 := UScalar.ofNat U32.rMax +@[reducible] def core_u64_min : U64 := UScalar.ofNat 0 +@[reducible] def core_u64_max : U64 := UScalar.ofNat U64.rMax +@[reducible] def core_u128_min : U128 := UScalar.ofNat 0 +@[reducible] def core_u128_max : U128 := UScalar.ofNat U128.rMax +@[reducible] def core_usize_min : Usize := UScalar.ofNatCore 0 (by simp) +@[reducible] def core_usize_max : Usize := UScalar.ofNatCore Usize.max (by simp [Usize.max, Usize.numBits, UScalar.rMax]) + +@[reducible] def core_i8_min : I8 := IScalar.ofInt I8.rMin +@[reducible] def core_i8_max : I8 := IScalar.ofInt I8.rMax +@[reducible] def core_i16_min : I16 := IScalar.ofInt I16.rMin +@[reducible] def core_i16_max : I16 := IScalar.ofInt I16.rMax +@[reducible] def core_i32_min : I32 := IScalar.ofInt I32.rMin +@[reducible] def core_i32_max : I32 := IScalar.ofInt I32.rMax +@[reducible] def core_i64_min : I64 := IScalar.ofInt I64.rMin +@[reducible] def core_i64_max : I64 := IScalar.ofInt I64.rMax +@[reducible] def core_i128_min : I128 := IScalar.ofInt I128.rMin +@[reducible] def core_i128_max : I128 := IScalar.ofInt I128.rMax +@[reducible] def core_isize_min : Isize := IScalar.ofIntCore Isize.min (by simp [Isize.min, Isize.numBits, Isize.rMin]) +@[reducible] def core_isize_max : Isize := IScalar.ofIntCore Isize.max (by simp [Isize.max, Isize.numBits, Isize.rMax]; (have : (0 : Int) < 2 ^ (System.Platform.numBits - 1) := by simp); omega) + + +/-! # Comparisons -/ +instance {ty} : LT (UScalar ty) where lt a b := LT.lt a.val b.val -instance {ty} : LE (Scalar ty) where le a b := LE.le a.val b.val +instance {ty} : LE (UScalar ty) where le a b := LE.le a.val b.val + +instance {ty} : LT (IScalar ty) where + lt a b := LT.lt a.val b.val + +instance {ty} : LE (IScalar ty) where le a b := LE.le a.val b.val + +/- Not marking this one with @[simp] on purpose: if we have `x = y` somewhere in the context, + we may want to use it to substitute `y` with `x` somewhere. -/ +@[scalar_tac_simp] theorem UScalar.eq_equiv {ty : UScalarTy} (x y : UScalar ty) : + x = y ↔ (↑x : Nat) = ↑y := by + cases x; cases y; simp_all [UScalar.val, BitVec.toNat_eq] + +@[simp] theorem UScalar.eq_imp {ty : UScalarTy} (x y : UScalar ty) : + (↑x : Nat) = ↑y → x = y := (eq_equiv x y).mpr + +@[simp, scalar_tac_simp] theorem UScalar.lt_equiv {ty : UScalarTy} (x y : UScalar ty) : + x < y ↔ (↑x : Nat) < ↑y := by + rw [LT.lt, instLTUScalar] + +@[simp] theorem UScalar.lt_imp {ty : UScalarTy} (x y : UScalar ty) : + (↑x : Nat) < (↑y) → x < y := (lt_equiv x y).mpr + +@[simp, scalar_tac_simp] theorem UScalar.le_equiv {ty : UScalarTy} (x y : UScalar ty) : + x ≤ y ↔ (↑x : Nat) ≤ ↑y := by + rw [LE.le, instLEUScalar] + +@[simp] theorem UScalar.le_imp {ty : UScalarTy} (x y : UScalar ty) : + (↑x : Nat) ≤ ↑y → x ≤ y := (le_equiv x y).mpr --- Not marking this one with @[simp] on purpose: if we have `x = y` somewhere in the context, --- we may want to use it to substitute `y` with `x` somewhere. 
-theorem Scalar.eq_equiv {ty : ScalarTy} (x y : Scalar ty) : +@[scalar_tac_simp] theorem IScalar.eq_equiv {ty : IScalarTy} (x y : IScalar ty) : x = y ↔ (↑x : Int) = ↑y := by - cases x; cases y; simp_all + cases x; cases y; simp_all [IScalar.val] + constructor <;> intro <;> + first | simp [*] | apply BitVec.eq_of_toInt_eq; simp [*] -@[simp] theorem Scalar.eq_imp {ty : ScalarTy} (x y : Scalar ty) : +@[simp] theorem IScalar.eq_imp {ty : IScalarTy} (x y : IScalar ty) : (↑x : Int) = ↑y → x = y := (eq_equiv x y).mpr -@[simp] theorem Scalar.lt_equiv {ty : ScalarTy} (x y : Scalar ty) : - x < y ↔ (↑x : Int) < ↑y := by simp [LT.lt] +@[simp, scalar_tac_simp] theorem IScalar.lt_equiv {ty : IScalarTy} (x y : IScalar ty) : + x < y ↔ (↑x : Int) < ↑y := by + rw [LT.lt, instLTIScalar] -@[simp] theorem Scalar.lt_imp {ty : ScalarTy} (x y : Scalar ty) : +@[simp, scalar_tac_simp] theorem IScalar.lt_imp {ty : IScalarTy} (x y : IScalar ty) : (↑x : Int) < (↑y) → x < y := (lt_equiv x y).mpr -@[simp] theorem Scalar.le_equiv {ty : ScalarTy} (x y : Scalar ty) : +@[simp] theorem IScalar.le_equiv {ty : IScalarTy} (x y : IScalar ty) : x ≤ y ↔ (↑x : Int) ≤ ↑y := by simp [LE.le] -@[simp] theorem Scalar.le_imp {ty : ScalarTy} (x y : Scalar ty) : +@[simp] theorem IScalar.le_imp {ty : IScalarTy} (x y : IScalar ty) : (↑x : Int) ≤ ↑y → x ≤ y := (le_equiv x y).mpr -instance Scalar.decLt {ty} (a b : Scalar ty) : Decidable (LT.lt a b) := Int.decLt .. -instance Scalar.decLe {ty} (a b : Scalar ty) : Decidable (LE.le a b) := Int.decLe .. +instance UScalar.decLt {ty} (a b : UScalar ty) : Decidable (LT.lt a b) := Nat.decLt .. +instance UScalar.decLe {ty} (a b : UScalar ty) : Decidable (LE.le a b) := Nat.decLe .. +instance IScalar.decLt {ty} (a b : IScalar ty) : Decidable (LT.lt a b) := Int.decLt .. +instance IScalar.decLe {ty} (a b : IScalar ty) : Decidable (LE.le a b) := Int.decLe .. 
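+/- Illustrative sketch (not part of the patch): the `*_equiv`/`*_imp` lemmas above let one
+   move between comparisons on scalars and comparisons on the underlying `Nat`/`Int` values.
+   The two `example`s below, and their proof terms, are assumptions about typical usage. -/
+example (x y : U32) (h : x.val < y.val) : x < y := UScalar.lt_imp x y h
+example (x y : I32) (h : x ≤ y) : x.val ≤ y.val := (IScalar.le_equiv x y).mp h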
-theorem Scalar.eq_of_val_eq {ty} : ∀ {i j : Scalar ty}, Eq i.val j.val → Eq i j - | ⟨_, _, _⟩, ⟨_, _, _⟩, rfl => rfl +theorem UScalar.eq_of_val_eq {ty} : ∀ {i j : UScalar ty}, Eq i.val j.val → Eq i j + | ⟨_, _⟩, ⟨_, _⟩, rfl => rfl -theorem Scalar.val_eq_of_eq {ty} {i j : Scalar ty} (h : Eq i j) : Eq i.val j.val := - h ▸ rfl +theorem IScalar.eq_of_val_eq {ty} : ∀ {i j : IScalar ty}, Eq i.val j.val → Eq i j := by + intro i j hEq + dcases i; dcases j + simp [IScalar.val] at hEq; simp + apply BitVec.eq_of_toInt_eq; assumption -theorem Scalar.ne_of_val_ne {ty} {i j : Scalar ty} (h : Not (Eq i.val j.val)) : Not (Eq i j) := +theorem UScalar.val_eq_of_eq {ty} {i j : UScalar ty} (h : Eq i j) : Eq i.val j.val := h ▸ rfl +theorem IScalar.val_eq_of_eq {ty} {i j : IScalar ty} (h : Eq i j) : Eq i.val j.val := h ▸ rfl + +theorem UScalar.ne_of_val_ne {ty} {i j : UScalar ty} (h : Not (Eq i.val j.val)) : Not (Eq i j) := + fun h' => absurd (val_eq_of_eq h') h + +theorem IScalar.ne_of_val_ne {ty} {i j : IScalar ty} (h : Not (Eq i.val j.val)) : Not (Eq i j) := fun h' => absurd (val_eq_of_eq h') h -instance (ty : ScalarTy) : DecidableEq (Scalar ty) := +instance (ty : UScalarTy) : DecidableEq (UScalar ty) := fun i j => match decEq i.val j.val with - | isTrue h => isTrue (Scalar.eq_of_val_eq h) - | isFalse h => isFalse (Scalar.ne_of_val_ne h) + | isTrue h => isTrue (UScalar.eq_of_val_eq h) + | isFalse h => isFalse (UScalar.ne_of_val_ne h) -@[simp] theorem Scalar.neq_to_neq_val {ty} : ∀ {i j : Scalar ty}, (¬ i = j) ↔ ¬ i.val = j.val := by +instance (ty : IScalarTy) : DecidableEq (IScalar ty) := + fun i j => + match decEq i.val j.val with + | isTrue h => isTrue (IScalar.eq_of_val_eq h) + | isFalse h => isFalse (IScalar.ne_of_val_ne h) + +@[simp, scalar_tac_simp] theorem UScalar.neq_to_neq_val {ty} : ∀ {i j : UScalar ty}, (¬ i = j) ↔ ¬ i.val = j.val := by + simp [eq_equiv] + +@[simp, scalar_tac_simp] theorem IScalar.neq_to_neq_val {ty} : ∀ {i j : IScalar ty}, (¬ i = j) ↔ ¬ i.val = j.val := by simp [eq_equiv] @[simp] -theorem Scalar.val_not_eq_imp_not_eq (x y : Scalar ty) (h : ScalarTac.Int.not_eq x.val y.val) : +theorem UScalar.val_not_eq_imp_not_eq (x y : UScalar ty) (h : ScalarTac.Nat.not_eq x.val y.val) : + ¬ x = y := by + simp_all; scalar_tac + +@[simp] +theorem IScalar.val_not_eq_imp_not_eq (x y : IScalar ty) (h : ScalarTac.Int.not_eq x.val y.val) : ¬ x = y := by - simp_all; int_tac + simp_all; scalar_tac -instance (ty: ScalarTy) : Preorder (Scalar ty) where +instance (ty: UScalarTy) : Preorder (UScalar ty) where le_refl := fun a => by simp le_trans := fun a b c => by intro Hab Hbc - exact (le_trans ((Scalar.le_equiv _ _).1 Hab) ((Scalar.le_equiv _ _).1 Hbc)) + exact (le_trans ((UScalar.le_equiv _ _).1 Hab) ((UScalar.le_equiv _ _).1 Hbc)) lt_iff_le_not_le := fun a b => by - trans (a: Int) < (b: Int); exact (Scalar.lt_equiv _ _) + trans (a: Nat) < (b: Nat); exact (UScalar.lt_equiv _ _) + trans (a: Nat) ≤ (b: Nat) ∧ ¬ (b: Nat) ≤ (a: Nat); exact lt_iff_le_not_le + repeat rewrite [← UScalar.le_equiv]; rfl + +instance (ty: IScalarTy) : Preorder (IScalar ty) where + le_refl := fun a => by simp + le_trans := fun a b c => by + intro Hab Hbc + exact (le_trans ((IScalar.le_equiv _ _).1 Hab) ((IScalar.le_equiv _ _).1 Hbc)) + lt_iff_le_not_le := fun a b => by + trans (a: Int) < (b: Int); exact (IScalar.lt_equiv _ _) trans (a: Int) ≤ (b: Int) ∧ ¬ (b: Int) ≤ (a: Int); exact lt_iff_le_not_le - repeat rewrite [← Scalar.le_equiv]; rfl + repeat rewrite [← IScalar.le_equiv]; rfl + +instance (ty: UScalarTy) : PartialOrder 
(UScalar ty) where + le_antisymm := fun a b Hab Hba => + UScalar.eq_imp _ _ ((@le_antisymm Nat _ _ _ ((UScalar.le_equiv a b).1 Hab) ((UScalar.le_equiv b a).1 Hba))) + +instance (ty: IScalarTy) : PartialOrder (IScalar ty) where + le_antisymm := fun a b Hab Hba => + IScalar.eq_imp _ _ ((@le_antisymm Int _ _ _ ((IScalar.le_equiv a b).1 Hab) ((IScalar.le_equiv b a).1 Hba))) -instance (ty: ScalarTy) : PartialOrder (Scalar ty) where - le_antisymm := fun a b Hab Hba => Scalar.eq_imp _ _ ((@le_antisymm Int _ _ _ ((Scalar.le_equiv a b).1 Hab) ((Scalar.le_equiv b a).1 Hba))) +instance UScalarDecidableLE (ty: UScalarTy) : DecidableRel (· ≤ · : UScalar ty -> UScalar ty -> Prop) := by + simp [instLEUScalar] + -- Lift this to the decidability of the Int version. + infer_instance -instance ScalarDecidableLE (ty: ScalarTy) : DecidableRel (· ≤ · : Scalar ty -> Scalar ty -> Prop) := by - simp [instLEScalar] +instance IScalarDecidableLE (ty: IScalarTy) : DecidableRel (· ≤ · : IScalar ty -> IScalar ty -> Prop) := by + simp [instLEIScalar] -- Lift this to the decidability of the Int version. infer_instance -instance (ty: ScalarTy) : LinearOrder (Scalar ty) where +instance (ty: UScalarTy) : LinearOrder (UScalar ty) where + le_total := fun a b => by + rcases (Nat.le_total a b) with H | H + left; exact (UScalar.le_equiv _ _).2 H + right; exact (UScalar.le_equiv _ _).2 H + decidableLE := UScalarDecidableLE ty + +instance (ty: IScalarTy) : LinearOrder (IScalar ty) where le_total := fun a b => by rcases (Int.le_total a b) with H | H - left; exact (Scalar.le_equiv _ _).2 H - right; exact (Scalar.le_equiv _ _).2 H - decidableLE := ScalarDecidableLE ty + left; exact (IScalar.le_equiv _ _).2 H + right; exact (IScalar.le_equiv _ _).2 H + decidableLE := IScalarDecidableLE ty + +/-! # Coercion Theorems --- Coercion theorems --- This is helpful whenever you want to "push" casts to the innermost nodes --- and make the cast normalization happen more magically. + This is helpful whenever you want to "push" casts to the innermost nodes + and make the cast normalization happen more magically. -/ @[simp, norm_cast] -theorem coe_max {ty: ScalarTy} (a b: Scalar ty): ↑(Max.max a b) = (Max.max (↑a) (↑b): ℤ) := by - -- TODO: there should be a shorter way to prove this. - rw [max_def, max_def] +theorem UScalar.coe_max {ty: UScalarTy} (a b: UScalar ty): ↑(Max.max a b) = (Max.max (↑a) (↑b): ℕ) := by + rw[_root_.max_def, _root_.max_def] split_ifs <;> simp_all --- Max theory --- TODO: do the min theory later on. 
- -theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 (by simp) ≤ x := by - apply (Scalar.le_equiv _ _).2 - convert x.hmin - cases ty <;> simp [ScalarTy.isSigned] at s <;> simp [Scalar.min] - -@[simp] -theorem Scalar.max_unsigned_left_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty): - Max.max (Scalar.ofInt 0 (by simp)) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x) - -@[simp] -theorem Scalar.max_unsigned_right_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty): - Max.max x (Scalar.ofInt 0 (by simp)) = x := max_eq_left (Scalar.zero_le_unsigned s.out x) - --- Some conversions -@[simp] abbrev Scalar.toNat {ty} (x : Scalar ty) : Nat := x.val.toNat -@[simp] abbrev U8.toNat (x : U8) : Nat := x.val.toNat -@[simp] abbrev U16.toNat (x : U16) : Nat := x.val.toNat -@[simp] abbrev U32.toNat (x : U32) : Nat := x.val.toNat -@[simp] abbrev U64.toNat (x : U64) : Nat := x.val.toNat -@[simp] abbrev U128.toNat (x : U128) : Nat := x.val.toNat -@[simp] abbrev Usize.toNat (x : Usize) : Nat := x.val.toNat -@[simp] abbrev I8.toNat (x : I8) : Nat := x.val.toNat -@[simp] abbrev I16.toNat (x : I16) : Nat := x.val.toNat -@[simp] abbrev I32.toNat (x : I32) : Nat := x.val.toNat -@[simp] abbrev I64.toNat (x : I64) : Nat := x.val.toNat -@[simp] abbrev I128.toNat (x : I128) : Nat := x.val.toNat -@[simp] abbrev Isize.toNat (x : Isize) : Nat := x.val.toNat +@[simp, norm_cast] +theorem IScalar.coe_max {ty: IScalarTy} (a b: IScalar ty): ↑(Max.max a b) = (Max.max (↑a) (↑b): ℤ) := by + rw[_root_.max_def, _root_.max_def] + split_ifs <;> simp_all -@[simp] -theorem Scalar.unsigned_ofNat_toNat (x : Scalar ty) (h : ¬ ty.isSigned := by decide) : - (x.toNat : Int) = x.val := by - have := x.hmin - simp; cases ty <;> simp_all [min] - -@[scalar_tac x.toNat] theorem U8.unsigned_ofNat_toNat (x : U8) : (x.toNat : Int) = x.val := Scalar.unsigned_ofNat_toNat x -@[scalar_tac x.toNat] theorem U16.unsigned_ofNat_toNat (x : U16) : (x.toNat : Int) = x.val := Scalar.unsigned_ofNat_toNat x -@[scalar_tac x.toNat] theorem U32.unsigned_ofNat_toNat (x : U32) : (x.toNat : Int) = x.val := Scalar.unsigned_ofNat_toNat x -@[scalar_tac x.toNat] theorem U64.unsigned_ofNat_toNat (x : U64) : (x.toNat : Int) = x.val := Scalar.unsigned_ofNat_toNat x -@[scalar_tac x.toNat] theorem U128.unsigned_ofNat_toNat (x : U128) : (x.toNat : Int) = x.val := Scalar.unsigned_ofNat_toNat x -@[scalar_tac x.toNat] theorem Usize.unsigned_ofNat_toNat (x : Usize) : (x.toNat : Int) = x.val := Scalar.unsigned_ofNat_toNat x +/-! Max theory -/ +-- TODO: do the min theory later on. -@[simp] -theorem Scalar.unsigned_add_nat_toNat (h : ¬ ty.isSigned) (x : Scalar ty) (n : Nat) : - (x.val + n).toNat = x.val.toNat + n := by - cases ty <;> simp_all <;> int_tac +theorem UScalar.zero_le {ty} (x: UScalar ty): UScalar.ofNat 0 (by simp) ≤ x := by simp @[simp] -theorem Scalar.unsigned_nat_add_toNat (h : ¬ ty.isSigned) (x : Scalar ty) (n : Nat) : - (n + x.val).toNat = x.val.toNat + n := by - cases ty <;> simp_all <;> int_tac +theorem UScalar.max_left_zero_eq {ty} (x: UScalar ty): + Max.max (UScalar.ofNat 0 (by simp)) x = x := max_eq_right (UScalar.zero_le x) @[simp] -theorem Scalar.unsigned_add_pos_toNat (h : ¬ ty.isSigned) (x : Scalar ty) (n : Int) (h' : 0 ≤ n) : - (x.val + n).toNat = x.val.toNat + n.toNat := by - cases ty <;> simp_all <;> int_tac +theorem UScalar.max_right_zero_eq {ty} (x: UScalar ty): + Max.max x (UScalar.ofNat 0 (by simp)) = x := max_eq_left (UScalar.zero_le x) + +/-! 
Some conversions -/ +@[simp] abbrev IScalar.toNat (x : IScalar ty) : Nat := x.val.toNat +@[simp] abbrev I8.toNat (x : I8) : Nat := x.val.toNat +@[simp] abbrev I16.toNat (x : I16) : Nat := x.val.toNat +@[simp] abbrev I32.toNat (x : I32) : Nat := x.val.toNat +@[simp] abbrev I64.toNat (x : I64) : Nat := x.val.toNat +@[simp] abbrev I128.toNat (x : I128) : Nat := x.val.toNat +@[simp] abbrev Isize.toNat (x : Isize) : Nat := x.val.toNat + +def U8.bv (x : U8) : BitVec 8 := UScalar.bv x +def U16.bv (x : U16) : BitVec 16 := UScalar.bv x +def U32.bv (x : U32) : BitVec 32 := UScalar.bv x +def U64.bv (x : U64) : BitVec 64 := UScalar.bv x +def U128.bv (x : U128) : BitVec 128 := UScalar.bv x +def Usize.bv (x : Usize) : BitVec System.Platform.numBits := UScalar.bv x + +def I8.bv (x : I8) : BitVec 8 := IScalar.bv x +def I16.bv (x : I16) : BitVec 16 := IScalar.bv x +def I32.bv (x : I32) : BitVec 32 := IScalar.bv x +def I64.bv (x : I64) : BitVec 64 := IScalar.bv x +def I128.bv (x : I128) : BitVec 128 := IScalar.bv x +def Isize.bv (x : Isize) : BitVec System.Platform.numBits := IScalar.bv x + +@[simp, scalar_tac_simp] theorem UScalar.bv_toNat_eq {ty : UScalarTy} (x : UScalar ty) : + (UScalar.bv x).toNat = x.val := by + simp [val] + +@[simp, scalar_tac_simp] theorem U8.bv_toNat_eq (x : U8) : x.bv.toNat = x.val := by apply UScalar.bv_toNat_eq +@[simp, scalar_tac_simp] theorem U16.bv_toNat_eq (x : U16) : x.bv.toNat = x.val := by apply UScalar.bv_toNat_eq +@[simp, scalar_tac_simp] theorem U32.bv_toNat_eq (x : U32) : x.bv.toNat = x.val := by apply UScalar.bv_toNat_eq +@[simp, scalar_tac_simp] theorem U64.bv_toNat_eq (x : U64) : x.bv.toNat = x.val := by apply UScalar.bv_toNat_eq +@[simp, scalar_tac_simp] theorem U128.bv_toNat_eq (x : U128) : x.bv.toNat = x.val := by apply UScalar.bv_toNat_eq +@[simp, scalar_tac_simp] theorem Usize.bv_toNat_eq (x : Usize) : x.bv.toNat = x.val := by apply UScalar.bv_toNat_eq + +@[simp, scalar_tac_simp] theorem IScalar.bv_toInt_eq {ty : IScalarTy} (x : IScalar ty) : + (IScalar.bv x).toInt = x.val := by + simp [val] + +@[simp, scalar_tac_simp] theorem I8.bv_toInt_eq (x : I8) : x.bv.toInt = x.val := by apply IScalar.bv_toInt_eq +@[simp, scalar_tac_simp] theorem I16.bv_toInt_eq (x : I16) : x.bv.toInt = x.val := by apply IScalar.bv_toInt_eq +@[simp, scalar_tac_simp] theorem I32.bv_toInt_eq (x : I32) : x.bv.toInt = x.val := by apply IScalar.bv_toInt_eq +@[simp, scalar_tac_simp] theorem I64.bv_toInt_eq (x : I64) : x.bv.toInt = x.val := by apply IScalar.bv_toInt_eq +@[simp, scalar_tac_simp] theorem I128.bv_toInt_eq (x : I128) : x.bv.toInt = x.val := by apply IScalar.bv_toInt_eq +@[simp, scalar_tac_simp] theorem Isize.bv_toInt_eq (x : Isize) : x.bv.toInt = x.val := by apply IScalar.bv_toInt_eq + +theorem U8.lt_succ_max (x: U8) : x.val < 256 := by have := x.hBounds; simp at this; omega +theorem U16.lt_succ_max (x: U16) : x.val < 65536 := by have := x.hBounds; simp at this; omega +theorem U32.lt_succ_max (x: U32) : x.val < 4294967296 := by have := x.hBounds; simp at this; omega +theorem U64.lt_succ_max (x: U64) : x.val < 18446744073709551616 := by have := x.hBounds; simp at this; omega +theorem U128.lt_succ_max (x: U128) : x.val < 340282366920938463463374607431768211456 := by have := x.hBounds; simp at this; omega + +theorem U8.le_max (x: U8) : x.val ≤ 255 := by have := x.hBounds; simp at this; omega +theorem U16.le_max (x: U16) : x.val ≤ 65535 := by have := x.hBounds; simp at this; omega +theorem U32.le_max (x: U32) : x.val ≤ 4294967295 := by have := x.hBounds; simp at this; omega +theorem 
U64.le_max (x: U64) : x.val ≤ 18446744073709551615 := by have := x.hBounds; simp at this; omega +theorem U128.le_max (x: U128) : x.val ≤ 340282366920938463463374607431768211455 := by have := x.hBounds; simp at this; omega + +@[simp] theorem UScalar.BitVec_ofNat_val_eq (x : UScalar ty) : BitVec.ofNat ty.numBits x.val = x.bv := by + cases x; simp only [val, BitVec.ofNat_toNat, BitVec.setWidth_eq] + +theorem U8.BitVec_ofNat_val_eq (x : U8) : BitVec.ofNat 8 x.val = x.bv := by apply UScalar.BitVec_ofNat_val_eq +theorem U16.BitVec_ofNat_val_eq (x : U16) : BitVec.ofNat 16 x.val = x.bv := by apply UScalar.BitVec_ofNat_val_eq +theorem U32.BitVec_ofNat_val_eq (x : U32) : BitVec.ofNat 32 x.val = x.bv := by apply UScalar.BitVec_ofNat_val_eq +theorem U64.BitVec_ofNat_val_eq (x : U64) : BitVec.ofNat 64 x.val = x.bv := by apply UScalar.BitVec_ofNat_val_eq +theorem U128.BitVec_ofNat_val_eq (x : U128) : BitVec.ofNat 128 x.val = x.bv := by apply UScalar.BitVec_ofNat_val_eq +theorem Usize.BitVec_ofNat_val_eq (x : Usize) : BitVec.ofNat System.Platform.numBits x.val = x.bv := by apply UScalar.BitVec_ofNat_val_eq + +/-! +Adding theorems to the `zify_simps` simplification set. +-/ +attribute [zify_simps] UScalar.eq_equiv IScalar.eq_equiv + UScalar.lt_equiv IScalar.lt_equiv + UScalar.le_equiv IScalar.le_equiv -@[simp] -theorem Scalar.unsigned_pos_add_toNat (h : ¬ ty.isSigned) (x : Scalar ty) (n : Int) (h' : 0 ≤ n) : - (n + x.val).toNat = n.toNat + x.val.toNat := by - cases ty <;> simp_all <;> int_tac +attribute [zify_simps] U8.bv_toNat_eq U16.bv_toNat_eq U32.bv_toNat_eq + U64.bv_toNat_eq U128.bv_toNat_eq Usize.bv_toNat_eq end Std diff --git a/backends/lean/Aeneas/Std/ScalarNotations.lean b/backends/lean/Aeneas/Std/ScalarNotations.lean index d9f01e3b..3e8e1d0c 100644 --- a/backends/lean/Aeneas/Std/ScalarNotations.lean +++ b/backends/lean/Aeneas/Std/ScalarNotations.lean @@ -14,39 +14,42 @@ open Lean Meta Elab Term PrettyPrinter sometimes leaves meta-variables in place, which then causes issues when type-checking functions. For instance, it happens when we have const-generics in the translation: the constants contain meta-variables, which are then - used in the types, which cause issues later. An example is given below: + used in the types, which cause issues later. For this reason we first try + solving the goal with `decide`, which often works in the cases which are + problematic for `scalar_tac`, then we try with `scalar_tac`. 
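+    For instance, in `(x + y)#u32` under a hypothesis `x + y ≤ 1000` (one of the
+    examples later in this file), `decide` cannot discharge the bound proof because
+    it mentions the variables `x` and `y`, while `scalar_tac` can, using the hypothesis.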
-/ -macro:max x:term:max noWs "#isize" : term => `(Isize.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u8" : term => `(U8.ofNat $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u16" : term => `(U16.ofNat $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u32" : term => `(U32.ofNat $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u64" : term => `(U64.ofNat $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u128" : term => `(U128.ofNat $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#usize" : term => `(Usize.ofNat $x (by first | decide | scalar_tac)) + macro:max x:term:max noWs "#i8" : term => `(I8.ofInt $x (by first | decide | scalar_tac)) macro:max x:term:max noWs "#i16" : term => `(I16.ofInt $x (by first | decide | scalar_tac)) macro:max x:term:max noWs "#i32" : term => `(I32.ofInt $x (by first | decide | scalar_tac)) macro:max x:term:max noWs "#i64" : term => `(I64.ofInt $x (by first | decide | scalar_tac)) macro:max x:term:max noWs "#i128" : term => `(I128.ofInt $x (by first | decide | scalar_tac)) -macro:max x:term:max noWs "#usize" : term => `(Usize.ofInt $x (by first | decide | scalar_tac)) -macro:max x:term:max noWs "#u8" : term => `(U8.ofInt $x (by first | decide | scalar_tac)) -macro:max x:term:max noWs "#u16" : term => `(U16.ofInt $x (by first | decide | scalar_tac)) -macro:max x:term:max noWs "#u32" : term => `(U32.ofInt $x (by first | decide | scalar_tac)) -macro:max x:term:max noWs "#u64" : term => `(U64.ofInt $x (by first | decide | scalar_tac)) -macro:max x:term:max noWs "#u128" : term => `(U128.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#isize" : term => `(Isize.ofInt $x (by first | decide | scalar_tac)) -- Some pretty printing (for the goals) -@[app_unexpander U8.ofInt] -def unexpU8ofInt : Unexpander | `($_ $n $_) => `($n#u8) | _ => throw () +@[app_unexpander U8.ofNat] +def unexpU8ofNat : Unexpander | `($_ $n $_) => `($n#u8) | _ => throw () -@[app_unexpander U16.ofInt] -def unexpU16ofInt : Unexpander | `($_ $n $_) => `($n#u16) | _ => throw () +@[app_unexpander U16.ofNat] +def unexpU16ofNat : Unexpander | `($_ $n $_) => `($n#u16) | _ => throw () -@[app_unexpander U32.ofInt] -def unexpU32ofInt : Unexpander | `($_ $n $_) => `($n#u32) | _ => throw () +@[app_unexpander U32.ofNat] +def unexpU32ofNat : Unexpander | `($_ $n $_) => `($n#u32) | _ => throw () -@[app_unexpander U64.ofInt] -def unexpU64ofInt : Unexpander | `($_ $n $_) => `($n#u64) | _ => throw () +@[app_unexpander U64.ofNat] +def unexpU64ofNat : Unexpander | `($_ $n $_) => `($n#u64) | _ => throw () -@[app_unexpander U128.ofInt] -def unexpU128ofInt : Unexpander | `($_ $n $_) => `($n#u128) | _ => throw () +@[app_unexpander U128.ofNat] +def unexpU128ofNat : Unexpander | `($_ $n $_) => `($n#u128) | _ => throw () -@[app_unexpander Usize.ofInt] -def unexpUsizeofInt : Unexpander | `($_ $n $_) => `($n#usize) | _ => throw () +@[app_unexpander Usize.ofNat] +def unexpUsizeofNat : Unexpander | `($_ $n $_) => `($n#usize) | _ => throw () @[app_unexpander I8.ofInt] def unexpI8ofInt : Unexpander | `($_ $n $_) => `($n#i8) | _ => throw () @@ -68,7 +71,8 @@ def unexpIsizeofInt : Unexpander | `($_ $n $_) => `($n#isize) | _ => throw () -- Notation for pattern matching -- We make the precedence looser than the negation. 
-notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ +notation:70 a:70 "#uscalar" => UScalar.mk (a) +notation:70 a:70 "#iscalar" => IScalar.mk (a) /- Testing the notations -/ example := 0#u32 @@ -99,38 +103,49 @@ example (x : I32) : Bool := example (x : U32) : Bool := match x with - | 0#scalar => true + | 0#uscalar => true | _ => false example (x : U32) : Bool := match x with - | 1#scalar => true + | 1#uscalar => true | _ => false example (x : I32) : Bool := match x with - | (-1)#scalar => true + | (-1)#iscalar => true | _ => false -example {ty} (x : Scalar ty) : ℤ := +/- +-- FIXME +example {ty} (x : UScalar ty) : ℕ := match x with - | v#scalar => v + | v#uscalar => v -example {ty} (x : Scalar ty) : Bool := +example {ty} (x : IScalar ty) : ℤ := match x with - | 1#scalar => true + | v#iscalar => v +-/ + +/- +-- FIXME +example {ty} (x : UScalar ty) : Bool := + match x with + | 1#uscalar => true | _ => false -example {ty} (x : Scalar ty) : Bool := +example {ty} (x : IScalar ty) : Bool := match x with - | -(1 : Int)#scalar => true + | -(1 : Int)#iscalar => true | _ => false +-/ -- Testing the notations example : Result Usize := 0#usize + 1#usize -- More complex expressions -example (x y : Int) (h : 0 ≤ x + y ∧ x + y ≤ 1000) : U32 := (x + y)#u32 +example (x y : Nat) (h : x + y ≤ 1000) : U32 := (x + y)#u32 +example (x y : Int) (h : 0 ≤ x + y ∧ x + y ≤ 1000) : I32 := (x + y)#i32 namespace Scalar.Examples diff --git a/backends/lean/Aeneas/Std/Vec.lean b/backends/lean/Aeneas/Std/Vec.lean index 331bb31a..7110ea85 100644 --- a/backends/lean/Aeneas/Std/Vec.lean +++ b/backends/lean/Aeneas/Std/Vec.lean @@ -28,8 +28,8 @@ instance [BEq α] : BEq (Vec α) := SubtypeBEq _ instance [BEq α] [LawfulBEq α] : LawfulBEq (Vec α) := SubtypeLawfulBEq _ @[scalar_tac v] -theorem Vec.len_ineq {α : Type u} (v : Vec α) : 0 ≤ v.val.length ∧ v.val.length ≤ Scalar.max ScalarTy.Usize := by - cases v; simp[Scalar.max, *] +theorem Vec.len_ineq {α : Type u} (v : Vec α) : v.val.length ≤ Usize.max := by + cases v; simp[*] -- TODO: move/remove? @[scalar_tac v] @@ -41,7 +41,7 @@ abbrev Vec.length {α : Type u} (v : Vec α) : Nat := v.val.length @[simp] abbrev Vec.v {α : Type u} (v : Vec α) : List α := v.val -example {a: Type u} (v : Vec a) : v.length ≤ Scalar.max ScalarTy.Usize := by +example {a: Type u} (v : Vec a) : v.length ≤ Usize.max := by scalar_tac abbrev Vec.new (α : Type u): Vec α := ⟨ [], by simp ⟩ @@ -52,26 +52,61 @@ instance (α : Type u) : Inhabited (Vec α) := by @[simp] abbrev Vec.len {α : Type u} (v : Vec α) : Usize := - Usize.ofIntCore v.val.length (by constructor <;> scalar_tac) + Usize.ofNatCore v.val.length (by scalar_tac) -@[simp] +@[simp, scalar_tac_simp] theorem Vec.len_val {α : Type u} (v : Vec α) : (Vec.len v).val = v.length := - by rfl + by simp + +@[reducible] instance {α : Type u} : GetElem (Vec α) Nat α (fun a i => i < a.val.length) where + getElem a i h := getElem a.val i h + +@[reducible] instance {α : Type u} : GetElem? (Vec α) Nat α (fun a i => i < a.val.length) where + getElem? a i := getElem? a.val i + getElem! a i := getElem! a.val i + +@[simp, scalar_tac_simp] theorem Vec.getElem?_Nat_eq {α : Type u} (v : Vec α) (i : Nat) : v[i]? = v.val[i]? := by rfl +@[simp, scalar_tac_simp] theorem Vec.getElem!_Nat_eq {α : Type u} [Inhabited α] (v : Vec α) (i : Nat) : v[i]! = v.val[i]! := by rfl + +@[reducible] instance {α : Type u} : GetElem (Vec α) Usize α (fun a i => i < a.val.length) where + getElem a i h := getElem a.val i.val h + +@[reducible] instance {α : Type u} : GetElem? 
(Vec α) Usize α (fun a i => i < a.val.length) where + getElem? a i := getElem? a.val i.val + getElem! a i := getElem! a.val i.val + +@[simp, scalar_tac_simp] theorem Vec.getElem?_Usize_eq {α : Type u} (v : Vec α) (i : Usize) : v[i]? = v.val[i.val]? := by rfl +@[simp, scalar_tac_simp] theorem Vec.getElem!_Usize_eq {α : Type u} [Inhabited α] (v : Vec α) (i : Usize) : v[i]! = v.val[i.val]! := by rfl + +@[simp, scalar_tac_simp] abbrev Vec.get? {α : Type u} (v : Vec α) (i : Nat) : Option α := getElem? v i +@[simp, scalar_tac_simp] abbrev Vec.get! {α : Type u} [Inhabited α] (v : Vec α) (i : Nat) : α := getElem! v i + +def Vec.set {α : Type u} (v: Vec α) (i: Usize) (x: α) : Vec α := + ⟨ v.val.set i.val x, by have := v.property; simp [*] ⟩ + +def Vec.set_opt {α : Type u} (v: Vec α) (i: Usize) (x: Option α) : Vec α := + ⟨ v.val.set_opt i.val x, by have := v.property; simp [*] ⟩ + +@[simp] +theorem Vec.set_val_eq {α : Type u} (v: Vec α) (i: Usize) (x: α) : + (v.set i x) = v.val.set i.val x := by + simp [set] + +@[simp] +theorem Vec.set_opt_val_eq {α : Type u} (v: Vec α) (i: Usize) (x: Option α) : + (v.set_opt i x) = v.val.set_opt i.val x := by + simp [set_opt] @[irreducible] def Vec.push {α : Type u} (v : Vec α) (x : α) : Result (Vec α) := let nlen := List.length v.val + 1 if h : nlen ≤ U32.max || nlen ≤ Usize.max then - have h : nlen ≤ Usize.max := by - simp [Usize.max] at * - have hm := Usize.refined_max.property - cases h <;> cases hm <;> simp [U32.max, U64.max] at * <;> try omega - ok ⟨ List.concat v.val x, by simp at *; assumption ⟩ + ok ⟨ List.concat v.val x, by simp; scalar_tac ⟩ else fail maximumSizeExceeded -@[pspec] +@[progress] theorem Vec.push_spec {α : Type u} (v : Vec α) (x : α) (h : v.val.length < Usize.max) : ∃ v1, v.push x = ok v1 ∧ v1.val = v.val ++ [x] := by @@ -81,73 +116,62 @@ theorem Vec.push_spec {α : Type u} (v : Vec α) (x : α) (h : v.val.length < Us def Vec.insert {α : Type u} (v: Vec α) (i: Usize) (x: α) : Result (Vec α) := if i.val < v.length then - ok ⟨ v.val.update i.toNat x, by have := v.property; simp [*] ⟩ + ok ⟨ v.val.set i x, by have := v.property; simp [*] ⟩ else fail arrayOutOfBounds -@[pspec] +@[progress] theorem Vec.insert_spec {α : Type u} (v: Vec α) (i: Usize) (x: α) (hbound : i.val < v.length) : - ∃ nv, v.insert i x = ok nv ∧ nv.val = v.val.update i.toNat x := by + ∃ nv, v.insert i x = ok nv ∧ nv.val = v.val.set i x := by simp [insert, *] def Vec.index_usize {α : Type u} (v: Vec α) (i: Usize) : Result α := - match v.val.indexOpt i.toNat with + match v[i.val]? with | none => fail .arrayOutOfBounds | some x => ok x -@[pspec] +@[progress] theorem Vec.index_usize_spec {α : Type u} [Inhabited α] (v: Vec α) (i: Usize) (hbound : i.val < v.length) : - ∃ x, v.index_usize i = ok x ∧ x = v.val.index i.toNat := by + ∃ x, v.index_usize i = ok x ∧ x = v.val[i.val]! := by simp only [index_usize] - -- TODO: dependent rewrite - have h := List.indexOpt_eq_index v.val i.toNat (by scalar_tac) + simp at * simp [*] -def Vec.update {α : Type u} (v: Vec α) (i: Usize) (x: α) : Vec α := - ⟨ v.val.update i.toNat x, by have := v.property; simp [*] ⟩ - -@[simp] -theorem Vec.update_val_eq {α : Type u} (v: Vec α) (i: Usize) (x: α) : - (v.update i x).val = v.val.update i.toNat x := by - simp [update] - -def Vec.update_usize {α : Type u} (v: Vec α) (i: Usize) (x: α) : Result (Vec α) := - match v.val.indexOpt i.toNat with +def Vec.update {α : Type u} (v: Vec α) (i: Usize) (x: α) : Result (Vec α) := + match v.val[i.val]? 
with | none => fail .arrayOutOfBounds | some _ => - ok ⟨ v.val.update i.toNat x, by have := v.property; simp [*] ⟩ + ok ⟨ v.val.set i x, by have := v.property; simp [*] ⟩ -@[pspec] -theorem Vec.update_usize_spec {α : Type u} (v: Vec α) (i: Usize) (x : α) +@[progress] +theorem Vec.update_spec {α : Type u} (v: Vec α) (i: Usize) (x : α) (hbound : i.val < v.length) : - ∃ nv, v.update_usize i x = ok nv ∧ - nv = v.update i x + ∃ nv, v.update i x = ok nv ∧ + nv = v.set i x := by - simp only [update_usize] - have h := List.indexOpt_bounds v.val i.toNat - split - . simp_all [length]; scalar_tac - . simp [Vec.update] + simp only [update, set] + simp at * + split <;> simp_all -@[scalar_tac v.update i x] -theorem Vec.update_length {α : Type u} (v: Vec α) (i: Usize) (x: α) : - (v.update i x).length = v.length := by simp +@[scalar_tac_simp] +theorem Vec.set_length {α : Type u} (v: Vec α) (i: Usize) (x: α) : + (v.set i x).length = v.length := by simp def Vec.index_mut_usize {α : Type u} (v: Vec α) (i: Usize) : Result (α × (α → Vec α)) := match Vec.index_usize v i with | ok x => - ok (x, Vec.update v i) + ok (x, Vec.set v i) | fail e => fail e | div => div -@[pspec] +@[progress] theorem Vec.index_mut_usize_spec {α : Type u} [Inhabited α] (v: Vec α) (i: Usize) (hbound : i.val < v.length) : - ∃ x, v.index_mut_usize i = ok (x, v.update i) ∧ - x = v.val.index i.toNat + ∃ x, v.index_mut_usize i = ok (x, v.set i) ∧ + x = v.val[i.val]! := by simp only [index_mut_usize] have ⟨ x, h ⟩ := index_usize_spec v i hbound @@ -156,13 +180,13 @@ theorem Vec.index_mut_usize_spec {α : Type u} [Inhabited α] (v: Vec α) (i: Us /- [alloc::vec::Vec::index]: forward function -/ def Vec.index {T I : Type} (inst : core.slice.index.SliceIndex I (Slice T)) (self : Vec T) (i : I) : Result inst.Output := - sorry -- TODO + inst.index i self /- [alloc::vec::Vec::index_mut]: forward function -/ def Vec.index_mut {T I : Type} (inst : core.slice.index.SliceIndex I (Slice T)) (self : Vec T) (i : I) : Result (inst.Output × (inst.Output → Vec T)) := - sorry -- TODO + inst.index_mut i self /- Trait implementation: [alloc::vec::Vec] -/ @[reducible] @@ -185,14 +209,16 @@ def Vec.coreopsindexIndexMutInst {T I : Type} @[simp] theorem Vec.index_slice_index {α : Type} (v : Vec α) (i : Usize) : Vec.index (core.slice.index.SliceIndexUsizeSliceTInst α) v i = - Vec.index_usize v i := - sorry + Vec.index_usize v i := by + simp [Vec.index, Vec.index_usize, Slice.index_usize] + rfl @[simp] theorem Vec.index_mut_slice_index {α : Type} (v : Vec α) (i : Usize) : Vec.index_mut (core.slice.index.SliceIndexUsizeSliceTInst α) v i = - index_mut_usize v i := - sorry + index_mut_usize v i := by + simp [Vec.index_mut, Vec.index_mut_usize, Slice.index_mut_usize] + rfl end alloc.vec @@ -204,7 +230,7 @@ def alloc.slice.Slice.to_vec /-- [core::slice::{@Slice}::reverse] -/ def core.slice.Slice.reverse {T : Type} (s : Slice T) : Slice T := - ⟨ s.val.reverse, by sorry ⟩ + ⟨ s.val.reverse, by scalar_tac ⟩ def alloc.vec.Vec.with_capacity (T : Type) (_ : Usize) : alloc.vec.Vec T := Vec.new T @@ -245,26 +271,89 @@ def core.ops.deref.DerefMutVec {T : Type} : def alloc.vec.Vec.resize {T : Type} (cloneInst : core.clone.Clone T) (v : alloc.vec.Vec T) (new_len : Usize) (value : T) : Result (alloc.vec.Vec T) := do if new_len.val < v.length then - ok ⟨ v.val.resize new_len.toNat value, by scalar_tac ⟩ + ok ⟨ v.val.resize new_len value, by scalar_tac ⟩ else let value ← cloneInst.clone value - ok ⟨ v.val.resize new_len.toNat value, by scalar_tac ⟩ + ok ⟨ v.val.resize new_len value, by 
scalar_tac ⟩ -@[pspec] +@[progress] theorem alloc.vec.Vec.resize_spec {T} (cloneInst : core.clone.Clone T) (v : alloc.vec.Vec T) (new_len : Usize) (value : T) (hClone : cloneInst.clone value = ok value) : ∃ nv, alloc.vec.Vec.resize cloneInst v new_len value = ok nv ∧ - nv.val = v.val.resize new_len.toNat value := by + nv.val = v.val.resize new_len value := by rw [resize] split . simp . simp [*] -@[simp] -theorem alloc.vec.Vec.update_index_eq α [Inhabited α] (x : alloc.vec.Vec α) (i : Usize) : - x.update i (x.val.index i.toNat) = x := by - simp [Vec, Subtype.eq_iff] +@[simp, scalar_tac_simp] +theorem alloc.vec.Vec.set_getElem!_eq α [Inhabited α] (x : alloc.vec.Vec α) (i : Usize) : + x.set i x[i]! = x := by + simp only [getElem!_Usize_eq] + simp only [Vec, set_val_eq, Subtype.eq_iff, List.set_getElem!] + +namespace Tests + example + (α : Type) + (slots : alloc.vec.Vec (List α)) + (n : Usize) + (_ : ∀ i < slots.length, slots.val[i]! = .nil) + (Hlen : (↑slots.len : ℕ) + (↑n : ℕ) ≤ Usize.max) + (_ : 0 < (↑n : ℕ)) + (slots1 : alloc.vec.Vec (List α)) + (hEq : (↑slots1 : List (List α)) = (↑slots : List (List α)) ++ [.nil]) + (n1 : Usize) + (_ : (↑n : ℕ) = (↑n1 : ℕ) + 1) + (_ : ∀ i < slots1.length, slots.val[i]! = .nil) : + (↑slots1.len : ℕ) + (↑n1 : ℕ) ≤ Usize.max + := by + scalar_tac + + example + (α : Type) + (capacity : Usize) + (dividend divisor : Usize) + (Hfactor : 0 < (↑dividend : ℕ) ∧ + (↑dividend : ℕ) < (↑divisor : ℕ) ∧ + (↑capacity : ℕ) * (↑dividend : ℕ) ≤ Usize.max ∧ + (↑capacity : ℕ) * (↑dividend : ℕ) ≥ (↑divisor : ℕ)) + (slots : alloc.vec.Vec (List α)) + (h2 : (↑slots.len : ℕ) = (↑(alloc.vec.Vec.new (List α)).len : ℕ) + (↑capacity : ℕ)) + (i1 : Usize) + (i2 : Usize) : + (↑(↑divisor : ℕ) : ℤ) ≤ + (↑(↑slots : List (List α)).length : ℤ) * (↑(↑dividend : ℕ) : ℤ) + := by + scalar_tac + + example + (v : alloc.vec.Vec U32) + (i : Usize) + (x : U32) + (i1 : Usize) + (h : (↑i : ℕ) < v.val.length) + (_ : x = v[i]!) + (_ : (↑i1 : ℕ) = (↑i : ℕ) + 1) : + (↑i : ℕ) + 1 ≤ v.val.length + := by + scalar_tac + + attribute [-simp] List.getElem!_eq_getElem?_getD + example + (α : Type) + (slots : alloc.vec.Vec (List α)) + (Hslots : ∀ i < slots.length, slots[i]! = []) + (slots1 : alloc.vec.Vec (List α)) + (_ : (↑slots1 : List (List α)) = (↑slots : List (List α)) ++ [[]]) + (i : ℕ) + (hi : i < slots.length) : + (↑slots : List (List α))[i]! = [] + := by + simp at * -- TODO: being forced to do this is annoying + simp [*] + +end Tests end Std diff --git a/backends/lean/Aeneas/Termination.lean b/backends/lean/Aeneas/Termination.lean index 20b75fe3..3b2432b1 100644 --- a/backends/lean/Aeneas/Termination.lean +++ b/backends/lean/Aeneas/Termination.lean @@ -10,14 +10,6 @@ namespace Utils open Lean Lean.Elab Command Term Lean.Meta Tactic --- Inspired by the `clear` tactic -def clearFvarIds (fvarIds : Array FVarId) : TacticM Unit := do - let fvarIds ← withMainContext <| sortFVarIds fvarIds - for fvarId in fvarIds.reverse do - withMainContext do - let mvarId ← (← getMainGoal).clear fvarId - replaceMainGoal [mvarId] - /- Utility function for proofs of termination (i.e., inside `decreasing_by`). Clean up the local context by removing all assumptions containing occurrences @@ -42,8 +34,8 @@ def removeInvImageAssumptions : TacticM Unit := do | .const name _ => name == ``invImage | _ => false)) false (← inferType expr) let filtDecls ← liftM (decls.filterM containsInvertImage) - -- It can happen that other variables depend on the variables we want to clear: - -- filter them. 
+ /- It can happen that other variables depend on the variables we want to clear: + filter them. -/ let allFVarsInTypes ← decls.foldlM (fun hs d => do let hs ← getFVarIds (← inferType d.toExpr) hs -- Explore the body if it is not opaque diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index d863074b..f37860ff 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -95,11 +95,6 @@ open Lean.Elab.Command let name := cs.constName! explore_decl name -private def test1 : Nat := 0 -private def test2 (x : Nat) : Nat := x -print_decl test1 -print_decl test2 - def printDecls (decls : List LocalDecl) : MetaM Unit := do let decls ← decls.foldrM (λ decl msg => do pure (m!"\n{decl.toExpr} : {← inferType decl.toExpr}" ++ msg)) m!"" @@ -382,10 +377,38 @@ def splitConjTarget : TacticM Unit := do setGoals (lmvar.mvarId! :: rmvar.mvarId! :: goals) -- Destruct an equaliy and return the two sides -def destEq (e : Expr) : MetaM (Expr × Expr) := do +def destEqOpt (e : Expr) : MetaM (Option (Expr × Expr)) := do e.consumeMData.withApp fun f args => - if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2) - else throwError "Not an equality: {e}" + if f.isConstOf ``Eq ∧ args.size = 3 then pure (some (args.get! 1, args.get! 2)) + else pure none + +-- Destruct an equaliy and return the two sides +def destEq (e : Expr) : MetaM (Expr × Expr) := do + match ← destEqOpt e with + | none => throwError "Not an equality: {e}" + | some e => pure e + +def destProdTypeOpt (ty : Expr) : Option (Expr × Expr) := do + ty.consumeMData.withApp fun fn args => + if fn.isConst ∧ fn.constName == ``Prod ∧ args.size = 2 then + some (args[0]!, args[1]!) + else none + +partial def destProdsType (ty : Expr) : List Expr := + match destProdTypeOpt ty with + | none => [ty] + | some (ty0, ty1) => ty0 :: destProdsType ty1 + +def destProdValOpt (x : Expr) : Option (Expr × Expr) := do + x.consumeMData.withApp fun f args => + if f.isConst ∧ f.constName = ``Prod.mk ∧ args.size = 4 then + some (args[2]!, args[3]!) 
+ else none + +partial def destProdsVal (x : Expr) : List Expr := + match destProdValOpt x with + | none => [x] + | some (x0, x1) => x0 :: destProdsVal x1 -- Return the set of FVarIds in the expression -- TODO: this collects fvars introduced in the inner bindings @@ -406,13 +429,18 @@ def assumptionTac : TacticM Unit := -- List all the local declarations matching the goal def getAllMatchingAssumptions (type : Expr) : MetaM (List (LocalDecl × Name)) := do + let typeType ← inferType type let decls ← (← getLCtx).getAllDecls decls.filterMapM fun localDecl => do -- Make sure we revert the meta-variables instantiations by saving the state and restoring it let s ← saveState - let x := - if (← isDefEq type localDecl.type) then (some (localDecl, localDecl.userName)) - else none + let x ← do + /- First check if the type can be matched (some assumptions are actually *variables*)-/ + if (← isDefEq typeType (← inferType localDecl.type)) then + if (← isDefEq type localDecl.type) then + pure (some (localDecl, localDecl.userName)) + else pure none + else pure none restoreState s pure x @@ -433,11 +461,17 @@ def singleAssumptionTac : TacticM Unit := do assumptionTac else trace[Utils] "The goal contains meta-variables" - -- There are meta-variables that + /- There are meta-variables that we need to instantiate + + Remark: at some point I tried using a discrimination tree to filter the assumptions, + in particular inside the `progress` tactic as may need to call the `singleAssumptionTac` + several times, but discrimination trees don't work if the expression we match over + contains meta-variables. + -/ match ← (getAllMatchingAssumptions goal) with | [(localDecl, _)] => - -- There is a single assumption which matches the goal: use it - -- Note that we need to call isDefEq again to properly instantiate the meta-variables + /- There is a single assumption which matches the goal: use it + Note that we need to call isDefEq again to properly instantiate the meta-variables -/ let _ ← isDefEq goal localDecl.type mvarId.assign (mkFVar localDecl.fvarId) | [] => @@ -685,7 +719,7 @@ example (h : ∃ x y z, x + y + z ≥ 0) : ∃ x, x ≥ 0 := by initialize a simp context without doing an elaboration - as a consequence we write our own here. -/ def mkSimpCtx (simpOnly : Bool) (config : Simp.Config) (kind : SimpKind) - (simprocs : List Name) (declsToUnfold : List Name) + (simprocs : List Name) (addSimpThms : List SimpTheorems) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : Tactic.TacticM (Simp.Context × Simp.SimprocsArray) := do -- Initialize either with the builtin simp theorems or with all the simp theorems @@ -720,7 +754,7 @@ def mkSimpCtx (simpOnly : Bool) (config : Simp.Config) (kind : SimpKind) let congrTheorems ← getSimpCongrTheorems let defaultSimprocs ← if simpOnly then pure {} else Simp.getSimprocs let simprocs ← simprocs.foldlM (fun simprocs name => simprocs.add name true) defaultSimprocs - let ctx ← Simp.mkContext config (simpTheorems := #[simpThms]) congrTheorems + let ctx ← Simp.mkContext config (simpTheorems := ⟨ simpThms :: addSimpThms ⟩) congrTheorems pure (ctx, #[simprocs]) inductive Location where @@ -733,7 +767,8 @@ inductive Location where | targets (hypotheses : Array Syntax) (type : Bool) -- Adapted from Tactic.simpLocation -def customSimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (discharge? : Option Simp.Discharge := none) +def customSimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) + (discharge? 
: Option Simp.Discharge := none) (loc : Location) : TacticM Simp.Stats := do match loc with | Location.targets hyps simplifyTarget => @@ -753,28 +788,29 @@ def customSimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (dis simpLocation.go ctx simprocs discharge? tgts (simplifyTarget := true) /- Call the simp tactic. -/ -def simpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : List Name) +def simpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : List Name) (simpThms : List SimpTheorems) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) (loc : Location) : Tactic.TacticM Unit := do -- Initialize the simp context - let (ctx, simprocs) ← mkSimpCtx simpOnly config .simp simprocs declsToUnfold thms hypsToUse + let (ctx, simprocs) ← mkSimpCtx simpOnly config .simp simprocs simpThms declsToUnfold thms hypsToUse -- Apply the simplifier let _ ← customSimpLocation ctx simprocs (discharge? := .none) loc /- Call the dsimp tactic. -/ -def dsimpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : List Name) +def dsimpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : List Name) (simpThms : List SimpTheorems) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) (loc : Tactic.Location) : Tactic.TacticM Unit := do -- Initialize the simp context - let (ctx, simprocs) ← mkSimpCtx simpOnly config .dsimp simprocs declsToUnfold thms hypsToUse + let (ctx, simprocs) ← mkSimpCtx simpOnly config .dsimp simprocs simpThms declsToUnfold thms hypsToUse -- Apply the simplifier dsimpLocation ctx simprocs loc -- Call the simpAll tactic -def simpAll (config : Simp.Config) (simpOnly : Bool) (simprocs : List Name) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : +def simpAll (config : Simp.Config) (simpOnly : Bool) (simprocs : List Name) (simpThms : List SimpTheorems) + (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : Tactic.TacticM Unit := do -- Initialize the simp context - let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll simprocs declsToUnfold thms hypsToUse + let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll simprocs simpThms declsToUnfold thms hypsToUse -- Apply the simplifier let (result?, _) ← Lean.Meta.simpAll (← getMainGoal) ctx (simprocs := simprocs) match result? with @@ -868,30 +904,35 @@ def evalAesopSaturate (options : Aesop.Options') (ruleSets : Array Name) : Tacti def normalizeLetBindings (e : Expr) : MetaM Expr := zetaReduce e -/-- For the attributes - - If we apply an attribute to a definition in a group of mutually recursive definitions - (say, to `foo` in the group [`foo`, `bar`]), the attribute gets applied to `foo` but also to - the recursive definition which encodes `foo` and `bar` (Lean encodes mutually recursive - definitions in one recursive definition, e.g., `foo._mutual`, before deriving the individual - definitions, e.g., `foo` and `bar`, from this one). This definition should be named `foo._mutual` - or `bar._mutual`, and we generally want to ignore it. - - TODO: same problem happens if we use decreases clauses, etc. - - Below, we implement a small utility to do so. 
- -/ -def attrIgnoreAuxDef (name : Name) (default : AttrM α) (x : AttrM α) : AttrM α := do - -- TODO: this is a hack - if let .str _ "_mutual" := name then - trace[Utils] "Ignoring a mutually recursive definition: {name}" - default - else if let .str _ "_unary" := name then - trace[Utils] "Ignoring a unary def: {name}" - default - else - -- Normal execution - x +section + variable [Monad m] [MonadOptions m] [MonadTrace m] [MonadLiftT IO m] [AddMessageContext m] [MonadError m] + variable {α : Type} + + /-- For the attributes + + If we apply an attribute to a definition in a group of mutually recursive definitions + (say, to `foo` in the group [`foo`, `bar`]), the attribute gets applied to `foo` but also to + the recursive definition which encodes `foo` and `bar` (Lean encodes mutually recursive + definitions in one recursive definition, e.g., `foo._mutual`, before deriving the individual + definitions, e.g., `foo` and `bar`, from this one). This definition should be named `foo._mutual` + or `bar._mutual`, and we generally want to ignore it. + + TODO: same problem happens if we use decreases clauses, etc. + + Below, we implement a small utility to do so. + -/ + def attrIgnoreAuxDef (name : Name) (default : m α) (x : m α) : m α := do + -- TODO: this is a hack + if let .str _ "_mutual" := name then + trace[Utils] "Ignoring a mutually recursive definition: {name}" + default + else if let .str _ "_unary" := name then + trace[Utils] "Ignoring a unary def: {name}" + default + else + -- Normal execution + x +end /-- Split anything in the context, and return the resulting set of subgoals. Raise an exception if we couldn't split. @@ -985,33 +1026,544 @@ example (x y : Int) : True := by example (x y : Int) : True := by dcases h: x = y <;> simp -def extractGoal : TacticM Unit := do +/-- Inspired by the `clear` tactic -/ +def clearFvarIds (fvarIds : Array FVarId) : TacticM Unit := do + let fvarIds ← withMainContext <| sortFVarIds fvarIds + for fvarId in fvarIds.reverse do + withMainContext do + let mvarId ← (← getMainGoal).clear fvarId + replaceMainGoal [mvarId] + +/-- Minimize the goal by removing all the unnecessary variables and assumptions -/ +partial def minimizeGoal : TacticM Unit := do withMainContext do + /- Retrieve the goal -/ + let goal ← getMainGoal + let goalFVarIds ← getFVarIds (← goal.getType) + /- Explore the local declarations to check which ones are need. + We do this recursively until we reach a fixed-point. -/ let ctx ← Lean.MonadLCtx.getLCtx let decls ← ctx.getDecls + let declsFVarIds := Std.HashSet.ofList (decls.map (fun d => d.fvarId)) + /- -/ + let mut changed := true + let mut neededIds := goalFVarIds + -- We need to filter the variables: some of them might come from quantifiers + neededIds := neededIds.filter (fun x => x ∈ declsFVarIds) + let mut exploredIds : Std.HashSet FVarId := Std.HashSet.empty + while changed do + changed := false + for decl in decls do + /- Shortcut: do not re-explore the already explored ids -/ + if decl.fvarId ∉ exploredIds then + trace[Utils] "Exploring: {decl.userName}" + exploredIds := exploredIds.insert decl.fvarId + /- Explore the type and the body: if they contain needed ids, add it -/ + let mut declIds ← getFVarIds decl.type + match decl.value? 
with + | none => pure () + | some value => + declIds := declIds.union (← getFVarIds value) + declIds := declIds.filter (fun x => x ∈ declsFVarIds) + trace[Utils] "declIds: {← declIds.toArray.mapM (fun x => x.getUserName)}" + let mut inter := false + for x in declIds do + if x ∈ neededIds then + inter := true + break + /- Check if there is an intersection -/ + if inter then + neededIds := neededIds.insert decl.fvarId + neededIds := neededIds.union declIds + changed := true + trace[Utils] "Done exploring the context" + /- Clear all the fvars which were not listed -/ + trace[Utils] "neededIds: {← neededIds.toArray.mapM (fun x => x.getUserName)}" + let allIds ← getFVarIdsAt goal + let allIds := allIds.filter (fun x => x ∉ neededIds) + clearFvarIds allIds + +elab "minimize_goal" : tactic => do + withMainContext do + minimizeGoal + +/-- Print the goal as an auxiliary lemma that can be copy-pasted by the user -/ +def extractGoal (ref : Syntax) (fullGoal : Bool) : TacticM Unit := do + /- First minimize the goal, if necessary -/ + if ¬ fullGoal then + minimizeGoal + withMainContext do + /- Rename the local declarations to avoid collisions -/ + let mut ctx ← Lean.MonadLCtx.getLCtx + let rec stripHygieneAux (n : Name) : MetaM (Bool × Name) := do + trace[Utils] "stripping: {n.toString}" + match n with + | .str pre str => + let (strip, pre) ← stripHygieneAux pre + if strip ∨ str == "_@" ∨ str == "_hyg" then + pure (true, pre) + else pure (false, .str pre str) + | .anonymous => pure (false, .anonymous) + | .num pre i => + let (strip, pre) ← stripHygieneAux pre + if strip then pure (true, pre) else pure (false, .num pre i) + let stripHygiene n : MetaM Name := do pure (← stripHygieneAux n).snd + + let rec renameDecls (allNames : Std.HashSet Name) (decls : List LocalDecl) : MetaM LocalContext := do + match decls with + | [] => Lean.MonadLCtx.getLCtx + | decl :: decls => + trace[Utils] "declName: {decl.userName.toString}" + let userName ← stripHygiene decl.userName + trace[Utils] "declName after stripping hygiene parts: {userName.toString}" + if userName ∈ allNames then + let lctx ← Lean.MonadLCtx.getLCtx + let newName := lctx.getUnusedName userName + let lctx := lctx.setUserName decl.fvarId newName + let allNames := allNames.insert newName + withLCtx' lctx do + renameDecls allNames decls + else + let allNames := allNames.insert userName + renameDecls allNames decls + let lctx ← renameDecls Std.HashSet.empty (← (← Lean.MonadLCtx.getLCtx).getDecls).reverse + withLCtx' lctx do + /- Extract the goal -/ + let decls ← ctx.getDecls let assumptions : List Format ← decls.mapM fun decl => do let ty ← Meta.ppExprWithInfos decl.type - /- TODO: we might want to update the names of the local - declarations, to use proper names for the variables - which are shadowed/have been introduced automatically - by the tactics/elaboration -/ let name ← Meta.ppExprWithInfos (Expr.fvar decl.fvarId) pure ("\n (" ++ name.fmt ++ " : " ++ ty.fmt ++ ")") let assumptions := Format.joinSep assumptions "" let mgoal ← getMainGoal let goal ← Meta.ppExprWithInfos (← mgoal.getType) - let msg := "example " ++ assumptions ++ " :\n " ++ goal.fmt ++ "\n := sorry" - println! msg + let msg := "example" ++ assumptions ++ " :\n " ++ goal.fmt ++ "\n := by sorry" + logInfoAt ref m!"{msg}" -elab "extract_goal" : tactic => do +elab ref:"extract_goal0" full:"full"? : tactic => do withMainContext do - extractGoal + extractGoal ref full.isSome + +syntax "extract_goal" ("full")? 
: tactic +macro_rules +| `(tactic|extract_goal) => + `(tactic|set_option pp.coercions.types true in extract_goal0) +| `(tactic|extract_goal full) => + `(tactic|set_option pp.coercions.types true in extract_goal0 full) + +/-- +info: example + (x : Nat) + (y : Nat) + (h_1 : x ≤ y) + (h : y ≤ y) : + x ≤ y + := by sorry +-/ +#guard_msgs in +set_option linter.unusedVariables false in +example (x x y : Nat) (h : x ≤ y) (h : y ≤ y) : x ≤ y := by + extract_goal + omega + +/-- +info: example + (x : Nat) + (y : Nat) + (h : x ≤ y) : + y ≥ x + := by sorry +-/ +#guard_msgs in example (x : Nat) (y : Nat) (_ : Nat) (h : x ≤ y) : y ≥ x := by - set_option linter.unusedTactic false in extract_goal omega +/-- +info: example + (v : List Nat) + (i : Nat) + (x_3 : Nat) + (v1 : List Nat) + (h_1 : i ≤ v.length) + (h : i < v.length) + (x_2 : x_3 = v.get! i) + (x_1 : i = i + 1) + (x✝ : v1.length = v.length) : + v1.length = v.length + := by sorry +-/ +#guard_msgs in +set_option linter.unusedVariables false in +example + (v : List Nat) + (i : Nat) + (x : Nat) + (i1 : Usize) + (v1 : List Nat) + (h : i ≤ v.length) + (h : i < v.length) + (_ : x = v.get! i) + (_ : i = i + 1) + (_ : v1.length = v.length) : + v1.length = v.length + := by + extract_goal + simp [*] + +/-- Introduce an auxiliary assertion for the goal -/ +def extractAssert (ref : Syntax) : TacticM Unit := do + withMainContext do + let goal ← (← getMainGoal).getType + let goal ← Lean.Meta.Tactic.TryThis.delabToRefinableSyntax goal + let tac : TSyntax `tactic ← `(tactic|have : $goal := by sorry) + /- Remark: there exists addHaveSuggestion -/ + Meta.Tactic.TryThis.addSuggestion ref tac (origSpan? := ← getRef) + +elab tk:"extract_assert" : tactic => do + withMainContext do + extractAssert tk + +/-- +info: Try this: have : y ≥ x := by sorry +-/ +#guard_msgs in +set_option linter.unusedTactic false in +example (x : Nat) (y : Nat) (_ : Nat) (h : x ≤ y) : y ≥ x := by + extract_assert + omega + +/- Group a list of expressions into a (non-dependent) tuple -/ +def mkProdsVal (xl : List Expr) : MetaM Expr := + match xl with + | [] => + pure (Expr.const ``PUnit.unit [Level.succ .zero]) + | [x] => do + pure x + | x :: xl => do + let xl ← mkProdsVal xl + mkAppM ``Prod.mk #[x, xl] + +def mkProdType (x y : Expr) : MetaM Expr := + mkAppM ``Prod #[x, y] + +def mkProd (x y : Expr) : MetaM Expr := + mkAppM ``Prod.mk #[x, y] + +/- Deconstruct a sigma type. + + For instance, deconstructs `(a : Type) × List a` into + `Type` and `λ a => List a`. + -/ +def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do + ty.withApp fun f args => do + if ¬ f.isConstOf ``Sigma ∨ args.size ≠ 2 then + throwError "Invalid argument to getSigmaTypes: {ty}" + else + pure (args.get! 0, args.get! 1) + +/- Make a sigma type. + + `x` should be a variable, and `ty` and type which (might) uses `x` + -/ +def mkSigmaType (x : Expr) (sty : Expr) : MetaM Expr := do + trace[Utils] "mkSigmaType: {x} {sty}" + let alpha ← inferType x + let beta ← mkLambdaFVars #[x] sty + trace[Utils] "mkSigmaType: ({alpha}) ({beta})" + mkAppOptM ``Sigma #[some alpha, some beta] + +/- Generate a Sigma type from a list of *variables* (all the expressions + must be variables). 
+ + Example: + - xl = [(a:Type), (ls:List a), (i:Int)] + + Generates: + `(a:Type) × (ls:List a) × (i:Int)` + + -/ +def mkSigmasType (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Utils] "mkSigmasType: []" + pure (Expr.const ``PUnit [Level.succ .zero]) + | [x] => do + trace[Utils] "mkSigmasType: [{x}]" + let ty ← inferType x + pure ty + | x :: xl => do + trace[Utils] "mkSigmasType: [{x}::{xl}]" + let sty ← mkSigmasType xl + mkSigmaType x sty + +/- Generate a product type from a list of *variables*. + + Example: + - xl = [(ls:List a), (i:Int)] + + Generates: + `List a × Int` + -/ +def mkProdsType (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Utils] "mkProdsType: []" + pure (Expr.const ``PUnit [Level.succ .zero]) + | [x] => do + trace[Utils] "mkProdsType: [{x}]" + let ty ← inferType x + pure ty + | x :: xl => do + trace[Utils] "mkProdsType: [{x}::{xl}]" + let ty ← inferType x + let xl_ty ← mkProdsType xl + mkAppM ``Prod #[ty, xl_ty] + +/- Split the input arguments between the types and the "regular" arguments. + + We do something simple: we treat an input argument as an + input type iff it appears in the type of the following arguments. + + Note that what really matters is that we find the arguments which appear + in the output type. + + Also, we stop at the first input that we treat as an + input type. + -/ +def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × Array Expr) := do + -- Look for the first parameter which appears in the subsequent parameters + let rec splitAux (in_tys : List Expr) : MetaM (Std.HashSet FVarId × List Expr × List Expr) := + match in_tys with + | [] => do + let fvars ← getFVarIds (← inferType out_ty) + pure (fvars, [], []) + | ty :: in_tys => do + let (fvars, in_tys, in_args) ← splitAux in_tys + -- Have we already found where to split between type variables/regular + -- variables? + if ¬ in_tys.isEmpty then + -- The fvars set is now useless: no need to update it anymore + pure (fvars, ty :: in_tys, in_args) + else + -- Check if ty appears in the set of free variables: + let ty_id := ty.fvarId! + if fvars.contains ty_id then + -- We must split here. Note that we don't need to update the fvars + -- set: it is not useful anymore + pure (fvars, [ty], in_args) + else + -- We must split later: update the fvars set + let fvars := fvars.insertMany (← getFVarIds (← inferType ty)) + pure (fvars, [], ty :: in_args) + let (_, in_tys, in_args) ← splitAux in_tys.toList + pure (Array.mk in_tys, Array.mk in_args) + +/- Apply a lambda expression to some arguments, simplifying the lambdas -/ +def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do + lambdaTelescopeN e xs.size fun vars body => + -- Create the substitution + let s : Std.HashMap FVarId Expr := Std.HashMap.ofList (List.zip (vars.toList.map Expr.fvarId!) xs.toList) + -- Substitute in the body + pure (body.replace fun e => + match e with + | Expr.fvar fvarId => match s.get? fvarId with + | none => e + | some v => v + | _ => none) + +/- Group a list of expressions into a dependent tuple. + + Example: + xl = [`a : Type`, `ls : List a`] + returns: + `⟨ (a:Type), (ls: List a) ⟩` + + We need the type argument because as the elements in the tuple are + "concrete", we can't in all generality figure out the type of the tuple. 
+ + Example: + `⟨ True, 3 ⟩ : (x : Bool) × (if x then Int else Unit)` + -/ +def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Utils] "mkSigmasVal: []" + pure (Expr.const ``PUnit.unit [Level.succ .zero]) + | [x] => do + trace[Utils] "mkSigmasVal: [{x}]" + pure x + | fst :: xl => do + trace[Utils] "mkSigmasVal: [{fst}::{xl}]" + -- Deconstruct the type + let (alpha, beta) ← getSigmaTypes ty + -- Compute the "second" field + -- Specialize beta for fst + let nty ← applyLambdaToArgs beta #[fst] + -- Recursive call + let snd ← mkSigmasVal nty xl + -- Put everything together + trace[Utils] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" + mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] + +def mkAnonymous (s : String) (i : Nat) : Name := + .num (.str .anonymous s) i + +/- Given a list of values `[x0:ty0, ..., xn:ty1]`, where every `xi` might use the previous + `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following + expression: + ``` + fun x:((x0:ty0) × ... × (xn:tyn) => -- **Dependent** tuple + match x with + | (x0, ..., xn) => out + ``` + + The `index` parameter is used for naming purposes: we use it to numerotate the + bound variables that we introduce. + + We use this function to currify functions (the function bodies given to the + fixed-point operator must be unary functions). + + Example: + ======== + - xl = `[a:Type, ls:List a, i:Int]` + - out = `a` + - index = 0 + + generates (getting rid of most of the syntactic sugar): + ``` + λ scrut0 => match scrut0 with + | Sigma.mk x scrut1 => + match scrut1 with + | Sigma.mk ls i => + a + ``` +-/ +partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := + match xl with + | [] => do + -- This would be unexpected + throwError "mkSigmasMatch: empty list of input parameters" + | [x] => do + -- In the example given for the explanations: this is the inner match case + trace[Utils] "mkSigmasMatch: [{x}]" + mkLambdaFVars #[x] out + | fst :: xl => do + /- In the example given for the explanations: this is the outer match case + Remark: for the naming purposes, we use the same convention as for the + fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at + those definitions might help) + + We want to build the match expression: + ``` + λ scrut => + match scrut with + | Sigma.mk x ... -- the hole is given by a recursive call on the tail + ``` -/ + trace[Utils] "mkSigmasMatch: [{fst}::{xl}]" + let alpha ← inferType fst + let snd_ty ← mkSigmasType xl + let beta ← mkLambdaFVars #[fst] snd_ty + let snd ← mkSigmasMatch xl out (index + 1) + let mk ← mkLambdaFVars #[fst] snd + -- Introduce the "scrut" variable + let scrut_ty ← mkSigmaType fst snd_ty + withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do + trace[Utils] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" + -- TODO: make the computation of the motive more efficient + let motive ← do + let out_ty ← inferType out + match out_ty with + | .sort _ | .lit _ | .const .. 
=> + -- The type of the motive doesn't depend on the scrutinee + mkLambdaFVars #[scrut] out_ty + | _ => + /- The type of the motive *may* depend on the scrutinee + TODO: make this more efficient (we could change the output type of + mkSigmasMatch -/ + mkSigmasMatch (fst :: xl) out_ty + -- The final expression: putting everything together + trace[Utils] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + -- Abstracting the "scrut" variable + let sm ← mkLambdaFVars #[scrut] sm + trace[Utils] "mkSigmasMatch: sm: {sm}" + pure sm + +/- This is similar to `mkSigmasMatch`, but with non-dependent tuples + + Remark: factor out with `mkSigmasMatch`? This is extremely similar. +-/ +partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := + match xl with + | [] => do + -- This would be unexpected + throwError "mkProdsMatch: empty list of input parameters" + | [x] => do + -- In the example given for the explanations: this is the inner match case + trace[Utils] "mkProdsMatch: [{x}]" + mkLambdaFVars #[x] out + | fst :: xl => do + trace[Utils] "mkProdsMatch: [{fst}::{xl}]" + let alpha ← inferType fst + let beta ← mkProdsType xl + let snd ← mkProdsMatch xl out (index + 1) + let mk ← mkLambdaFVars #[fst] snd + -- Introduce the "scrut" variable + let scrut_ty ← mkProdType alpha beta + withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do + trace[Utils] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})" + -- TODO: make the computation of the motive more efficient + let motive ← do + let out_ty ← inferType out + match out_ty with + | .sort _ | .lit _ | .const .. => + -- The type of the motive doesn't depend on the scrutinee + mkLambdaFVars #[scrut] out_ty + | _ => + /- The type of the motive *may* depend on the scrutinee + TODO: make this more efficient (we could change the output type of + mkProdsMatch) -/ + mkProdsMatch (fst :: xl) out_ty + /-let motive ← do + let out_ty ← inferType out + mkLambdaFVars #[scrut] out_ty-/ + -- The final expression: putting everything together + trace[Utils] "mkProdsMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let sm ← mkAppOptM ``Prod.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + -- Abstracting the "scrut" variable + let sm ← mkLambdaFVars #[scrut] sm + trace[Utils] "mkProdsMatch: sm: {sm}" + pure sm + +/- Same as `mkSigmasMatch` but also accepts an empty list of inputs, in which case + it generates the expression: + ``` + λ () => e + ``` -/ +def mkSigmasMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := + if xl.isEmpty then do + let scrut_ty := Expr.const ``PUnit [Level.succ .zero] + withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do + mkLambdaFVars #[scrut] out + else + mkSigmasMatch xl out + +/- Same as `mkProdsMatch` but also accepts an empty list of inputs, in which case + it generates the expression: + ``` + λ () => e + ``` -/ +def mkProdsMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := + if xl.isEmpty then do + let scrut_ty := Expr.const ``PUnit [Level.succ .zero] + withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do + mkLambdaFVars #[scrut] out + else + mkProdsMatch xl out + + end Utils end Aeneas diff --git a/backends/lean/Aeneas/Zify.lean b/backends/lean/Aeneas/Zify.lean new file mode 100644 index 00000000..8be2826f --- /dev/null +++ b/backends/lean/Aeneas/Zify.lean @@ -0,0 +1,4 @@ +import Aeneas.Arith.Lemmas 
+open Aeneas.Arith + +attribute [zify_simps] ZMod.val_intCast ZMod.castInt_val_sub diff --git a/src/Config.ml b/src/Config.ml index d990986b..bfb24417 100644 --- a/src/Config.ml +++ b/src/Config.ml @@ -284,6 +284,37 @@ let decompose_nested_let_patterns = ref false *) let unfold_monadic_let_bindings = ref false +(** Perform the following transformation: + + {[ + let y <-- f x (* Must be an application, is not necessarily monadic *) + let (a, b) := y (* Tuple decomposition *) + ]} + + becomes: + + {[ + let (a, b) <-- f x + ]} + *) +let merge_let_app_decompose_tuple = ref false + +(** Perform the following transformation: + + {[ + let y <-- ok e + + ~~> + + let y <-- toResult e + ]} + + We only do this on a specific set of pure functions calls - those + functions are identified in the "builtin" information about external + function calls. + *) +let lift_pure_function_calls = ref false + (** Introduce calls to [massert] (monadic assertion). The pattern below is very frequent especially as it is introduced by diff --git a/src/Logging.ml b/src/Logging.ml index 12b5ec79..866cacd4 100644 --- a/src/Logging.ml +++ b/src/Logging.ml @@ -46,7 +46,7 @@ let simplify_aggregates_unchanged_fields_log = create_logger "PureMicroPasses.simplify_aggregates_unchanged_fields" (** Logger for ExtractBase *) -let extract_log = create_logger "ExtractBase" +let extract_log = create_logger "Extract" (** Logger for ExtractBuiltin *) let builtin_log = create_logger "Builtin" diff --git a/src/Main.ml b/src/Main.ml index 47f74a29..50d98794 100644 --- a/src/Main.ml +++ b/src/Main.ml @@ -364,7 +364,9 @@ let () = record_fields_short_names := true; (* We exploit the fact that the variant name should always be prefixed with the type name to prevent collisions *) - variant_concatenate_type_name := false + variant_concatenate_type_name := false; + (* *) merge_let_app_decompose_tuple := true; + lift_pure_function_calls := true | HOL4 -> (* We don't support fuel for the HOL4 backend *) if !use_fuel then ( diff --git a/src/NameMatcher.ml b/src/NameMatcher.ml new file mode 100644 index 00000000..1d3d4f37 --- /dev/null +++ b/src/NameMatcher.ml @@ -0,0 +1 @@ +include Charon.NameMatcher diff --git a/src/Translate.ml b/src/Translate.ml index c674dab7..35bc2853 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -977,7 +977,7 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (fi : extract_file_info) (* Add the custom includes *) List.iter (fun m -> Printf.fprintf out "import %s\n" m) fi.custom_includes; (* Always open the Primitives namespace *) - Printf.fprintf out "open Aeneas.Std\n"; + Printf.fprintf out "open Aeneas.Std Result Error\n"; (* It happens that we generate duplicated namespaces, like `betree.betree`. We deactivate the linter for this, because otherwise it leads to too much noise. 
*) @@ -1206,8 +1206,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : ^ " because of previous error\nName pattern: '" ^ name_pattern ^ "'" ); ctx) - ctx - (GlobalDeclId.Map.values crate.global_decls) + ctx trans_globals in let ctx = diff --git a/src/TranslateCore.ml b/src/TranslateCore.ml index 84539ad1..8c6be995 100644 --- a/src/TranslateCore.ml +++ b/src/TranslateCore.ml @@ -20,22 +20,21 @@ let trans_ctx_to_pure_fmt_env (ctx : trans_ctx) : PrintPure.fmt_env = let name_to_string (ctx : trans_ctx) = Print.Types.name_to_string (trans_ctx_to_fmt_env ctx) -let match_name_find_opt (ctx : trans_ctx) (name : Types.name) - (m : 'a NameMatcherMap.t) : 'a option = - let mctx = Charon.NameMatcher.ctx_from_crate ctx.crate in - let open ExtractBuiltin in - NameMatcherMap.find_opt mctx name m - -let match_name_with_generics_find_opt (ctx : trans_ctx) (name : Types.name) - (generics : Types.generic_args) (m : 'a NameMatcherMap.t) : 'a option = - let mctx = Charon.NameMatcher.ctx_from_crate ctx.crate in - let open ExtractBuiltin in - NameMatcherMap.find_with_generics_opt mctx name generics m - let name_to_simple_name (ctx : trans_ctx) (n : Types.name) : string list = let mctx = Charon.NameMatcher.ctx_from_crate ctx.crate in name_to_simple_name mctx n +let match_name_find_opt (ctx : trans_ctx) (name : Types.name) + (m : 'a NameMatcher.NameMatcherMap.t) : 'a option = + let mctx = NameMatcher.ctx_from_crate ctx.crate in + ExtractBuiltin.NameMatcherMap.find_opt mctx name m + +let match_name_with_generics_find_opt (ctx : trans_ctx) (name : Types.name) + (generics : Types.generic_args) (m : 'a NameMatcher.NameMatcherMap.t) : + 'a option = + let mctx = NameMatcher.ctx_from_crate ctx.crate in + ExtractBuiltin.NameMatcherMap.find_with_generics_opt mctx name generics m + let trait_name_with_generics_to_simple_name (ctx : trans_ctx) ?(prefix : Types.name option = None) (n : Types.name) (p : Types.generic_params) (g : Types.generic_args) : string list = diff --git a/src/dune b/src/dune index f563565f..b2caf8ee 100644 --- a/src/dune +++ b/src/dune @@ -57,6 +57,7 @@ LlbcOfJson Logging Meta + NameMatcher PrePasses Print PrintPure diff --git a/src/extract/Extract.ml b/src/extract/Extract.ml index 74796b5d..ed327450 100644 --- a/src/extract/Extract.ml +++ b/src/extract/Extract.ml @@ -32,18 +32,12 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) declarations *) ctx | _ -> ( - (* Check if the function is builtin *) - let builtin = - let open ExtractBuiltin in - let funs_map = builtin_funs_map () in - match_name_find_opt ctx.trans_ctx def.f.item_meta.name funs_map - in (* Use the builtin names if necessary *) - match builtin with - | Some (filter_info, fun_info) -> + match def.f.builtin_info with + | Some info -> (* Builtin function: register the filtering information, if there is *) let ctx = - match filter_info with + match info.filter_params with | Some keep -> { ctx with @@ -54,10 +48,9 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) | _ -> ctx in let f = def.f in - let open ExtractBuiltin in let fun_id = (Pure.FunId (FRegular f.def_id), f.loop_id) in - ctx_add f.item_meta.span (FunId (FromLlbc fun_id)) - fun_info.extract_name ctx + ctx_add f.item_meta.span (FunId (FromLlbc fun_id)) info.extract_name + ctx | None -> (* Not builtin *) (* If this is a trait method implementation, we prefix the name with the @@ -87,7 +80,7 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) (** Simply add the global name to the context. 
*) let extract_global_decl_register_names (ctx : extraction_ctx) - (def : A.global_decl) : extraction_ctx = + (def : global_decl) : extraction_ctx = ctx_add_global_decl_and_body def ctx (** The following function factorizes the extraction of ADT values. @@ -590,8 +583,19 @@ and extract_function_call (span : Meta.span) (ctx : extraction_ctx) } | Pure (UpdateAtIndex Slice) -> Some { explicit_types = [ Implicit ]; explicit_const_generics = [] } + | Pure ToResult -> + Some { explicit_types = [ Implicit ]; explicit_const_generics = [] } | Pure _ -> None in + (* Special case for [ToResult]: we don't want to print a space between the + coercion symbol and the expression - TODO: this is a bit ad-hoc *) + let print_first_space = + if Config.backend () = Lean then + match fun_id with + | Pure ToResult -> false + | _ -> true + else true + in (* Filter the generics. We might need to filter some of the type arguments, if the type @@ -616,9 +620,10 @@ and extract_function_call (span : Meta.span) (ctx : extraction_ctx) "(\"ERROR: ill-formed builtin: invalid number of filtering \ arguments\")"); (* Print the arguments *) + let print_space = ref print_first_space in List.iter (fun ve -> - F.pp_print_space fmt (); + if !print_space then F.pp_print_space fmt () else print_space := true; extract_texpression span ctx fmt true ve) args; (* Close the box for the function call *) @@ -824,6 +829,17 @@ and extract_lets (span : Meta.span) (ctx : extraction_ctx) (fmt : F.formatter) (re : texpression) : extraction_ctx = (* Open a box for the let-binding *) F.pp_open_hovbox fmt ctx.indent_incr; + (* Should we add a type annotation? It is necessary when the bound expression + has a coercion. *) + let extract_type_annot = + monadic + && backend () = Lean + && + match re.e with + | App ({ e = Qualif { id = FunOrOp (Fun (Pure ToResult)); _ }; _ }, _) -> + true + | _ -> false + in let ctx = (* There are two cases: - do we use a notation like [x <-- y;] @@ -871,7 +887,14 @@ and extract_lets (span : Meta.span) (ctx : extraction_ctx) (fmt : F.formatter) else ( F.pp_print_string fmt "let"; F.pp_print_space fmt ()); + if extract_type_annot then F.pp_print_string fmt "("; let ctx = extract_typed_pattern span ctx fmt true true lv in + if extract_type_annot then ( + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty span ctx fmt TypeDeclId.Set.empty false lv.ty; + F.pp_print_string fmt ")"); F.pp_print_space fmt (); let eq = match backend () with @@ -2184,8 +2207,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (** Similar to {!extract_trait_decl_register_names} *) let extract_trait_decl_register_parent_clause_names (ctx : extraction_ctx) (trait_decl : trait_decl) - (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : - extraction_ctx = + (builtin_info : Pure.builtin_trait_decl_info option) : extraction_ctx = (* Compute the clause names *) let clause_names = match builtin_info with @@ -2226,8 +2248,7 @@ let extract_trait_decl_register_parent_clause_names (ctx : extraction_ctx) (** Similar to {!extract_trait_decl_register_names} *) let extract_trait_decl_register_constant_names (ctx : extraction_ctx) (trait_decl : trait_decl) - (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : - extraction_ctx = + (builtin_info : Pure.builtin_trait_decl_info option) : extraction_ctx = let consts = trait_decl.consts in (* Compute the names *) let constant_names = @@ -2261,8 +2282,7 @@ let extract_trait_decl_register_constant_names (ctx : extraction_ctx) (** 
Similar to {!extract_trait_decl_register_names} *) let extract_trait_decl_type_names (ctx : extraction_ctx) (trait_decl : trait_decl) - (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : - extraction_ctx = + (builtin_info : Pure.builtin_trait_decl_info option) : extraction_ctx = let types = trait_decl.types in (* Compute the names *) let type_names = @@ -2300,8 +2320,8 @@ let extract_trait_decl_type_names (ctx : extraction_ctx) (** Similar to {!extract_trait_decl_register_names} *) let extract_trait_decl_method_names (ctx : extraction_ctx) (trait_decl : trait_decl) - (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : - extraction_ctx = + (builtin_info : Pure.builtin_trait_decl_info option) : extraction_ctx = + log#ltrace (lazy (__FUNCTION__ ^ ": " ^ trait_decl.name)); let methods = trait_decl.methods in (* Compute the names *) let method_names = @@ -2310,6 +2330,10 @@ let extract_trait_decl_method_names (ctx : extraction_ctx) (* Not a builtin function *) let compute_item_name (item_name : string) (id : fun_decl_id) : string * string = + log#ldebug + (lazy + (__FUNCTION__ ^ "(" ^ trait_decl.name ^ "): compute_item_name: " + ^ item_name)); let trans : pure_fun_translation = match FunDeclId.Map.find_opt id ctx.trans_funs with | Some decl -> decl @@ -2329,6 +2353,10 @@ let extract_trait_decl_method_names (ctx : extraction_ctx) let f = { f with item_meta = { f.item_meta with name = llbc_name } } in + log#ldebug + (lazy + (__FUNCTION__ ^ ": compute_item_name: llbc_name=" + ^ name_to_string ctx f.item_meta.name)); let name = ctx_compute_fun_name f true ctx in (* Add a prefix if necessary *) let name = @@ -2346,7 +2374,6 @@ let extract_trait_decl_method_names (ctx : extraction_ctx) let funs_map = StringMap.of_list info.methods in List.map (fun (item_name, _) -> - let open ExtractBuiltin in let info = StringMap.find item_name funs_map in let fun_name = info.extract_name in (item_name, fun_name)) @@ -2363,15 +2390,9 @@ let extract_trait_decl_method_names (ctx : extraction_ctx) (** Similar to {!extract_type_decl_register_names} *) let extract_trait_decl_register_names (ctx : extraction_ctx) (trait_decl : trait_decl) : extraction_ctx = - (* Lookup the information if this is a builtin trait *) - let open ExtractBuiltin in - let builtin_info = - match_name_find_opt ctx.trans_ctx trait_decl.item_meta.name - (builtin_trait_decls_map ()) - in let ctx = let trait_name, trait_constructor = - match builtin_info with + match trait_decl.builtin_info with | None -> ( ctx_compute_trait_decl_name ctx trait_decl, ctx_compute_trait_decl_constructor ctx trait_decl ) @@ -2384,6 +2405,7 @@ let extract_trait_decl_register_names (ctx : extraction_ctx) ctx_add trait_decl.item_meta.span (TraitDeclConstructorId trait_decl.def_id) trait_constructor ctx in + let builtin_info = trait_decl.builtin_info in (* Parent clauses *) let ctx = extract_trait_decl_register_parent_clause_names ctx trait_decl builtin_info @@ -2403,26 +2425,13 @@ let extract_trait_impl_register_names (ctx : extraction_ctx) (trait_impl : trait_impl) : extraction_ctx = let decl_id = trait_impl.impl_trait.trait_decl_id in let trait_decl = TraitDeclId.Map.find decl_id ctx.trans_trait_decls in - (* Check if the trait implementation is builtin *) - let builtin_info = - let open ExtractBuiltin in - (* Lookup the original Rust impl to retrieve the original trait ref (we - use it to derive the name)*) - let trait_impl = - TraitImplId.Map.find trait_impl.def_id ctx.crate.trait_impls - in - let decl_ref = trait_impl.impl_trait in - 
match_name_with_generics_find_opt ctx.trans_ctx trait_decl.item_meta.name - decl_ref.decl_generics - (builtin_trait_impls_map ()) - in (* Register some builtin information (if necessary) *) let ctx, builtin_info = - match builtin_info with + match trait_impl.builtin_info with | None -> (ctx, None) - | Some (filter, info) -> + | Some builtin_info -> let ctx = - match filter with + match builtin_info.filter_params with | None -> ctx | Some filter -> { @@ -2432,7 +2441,7 @@ let extract_trait_impl_register_names (ctx : extraction_ctx) ctx.trait_impls_filter_type_args_map; } in - (ctx, Some info) + (ctx, Some builtin_info) in (* Everything is taken care of by {!extract_trait_decl_register_names} *but* @@ -2441,7 +2450,7 @@ let extract_trait_impl_register_names (ctx : extraction_ctx) let name = match builtin_info with | None -> ctx_compute_trait_impl_name ctx trait_decl trait_impl - | Some name -> name + | Some info -> info.impl_name in ctx_add trait_decl.item_meta.span (TraitImplId trait_impl.def_id) name ctx diff --git a/src/extract/ExtractBase.ml b/src/extract/ExtractBase.ml index 78d8a948..c788da76 100644 --- a/src/extract/ExtractBase.ml +++ b/src/extract/ExtractBase.ml @@ -842,7 +842,11 @@ let unop_name (unop : unop) : string = match ty with | None -> "not" | Some int_ty -> int_name int_ty ^ "_not") - | Lean -> "¬" + | Lean -> begin + match ty with + | None -> "¬" + | Some _ -> "~~~" + end | Coq -> if Option.is_none ty then "negb" else "scalar_not" | HOL4 -> "~") | Neg (int_ty : integer_type) -> ( @@ -1163,12 +1167,9 @@ let builtin_variants () : (builtin_ty * VariantId.id * string) list = ] | Lean -> [ - (TResult, result_ok_id, "Result.ok"); - (TResult, result_fail_id, "Result.fail"); - (* For panic: we omit the prefix "Error." because the type is always - clear from the context. Also, "Error" is often used by user-defined - types (when we omit the crate as a prefix). *) - (TError, error_failure_id, ".panic"); + (TResult, result_ok_id, "ok"); + (TResult, result_fail_id, "fail"); + (TError, error_failure_id, "panic"); (* No Fuel::Zero on purpose *) (* No Fuel::Succ on purpose *) ] @@ -1223,6 +1224,7 @@ let builtin_pure_functions () : (pure_builtin_fun_id * string) list = (FuelEqZero, "is_zero"); (UpdateAtIndex Slice, "slice_update_usize"); (UpdateAtIndex Array, "array_update_usize"); + (ToResult, "return"); ] | Coq -> (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) @@ -1232,6 +1234,7 @@ let builtin_pure_functions () : (pure_builtin_fun_id * string) list = (Assert, "massert"); (UpdateAtIndex Slice, "slice_update_usize"); (UpdateAtIndex Array, "array_update_usize"); + (ToResult, "return_"); ] | Lean -> (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) @@ -1239,8 +1242,9 @@ let builtin_pure_functions () : (pure_builtin_fun_id * string) list = (Return, "return"); (Fail, "fail_"); (Assert, "massert"); - (UpdateAtIndex Slice, "Slice.update_usize"); - (UpdateAtIndex Array, "Array.update_usize"); + (UpdateAtIndex Slice, "Slice.update"); + (UpdateAtIndex Array, "Array.update"); + (ToResult, "↑"); ] | HOL4 -> (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) @@ -1250,6 +1254,7 @@ let builtin_pure_functions () : (pure_builtin_fun_id * string) list = (Assert, "massert"); (UpdateAtIndex Slice, "slice_update_usize"); (UpdateAtIndex Array, "array_update_usize"); + (ToResult, "return"); ] let names_map_init () : names_map_init = @@ -1447,22 +1452,22 @@ let name_last_elem_as_ident (span : Meta.span) (n : llbc_name) : string = we remove it. 
We ignore disambiguators (there may be collisions, but we check if there are). *) -let ctx_prepare_name (span : Meta.span) (ctx : extraction_ctx) +let ctx_prepare_name (meta : T.item_meta) (ctx : extraction_ctx) (name : llbc_name) : llbc_name = (* Rmk.: initially we only filtered the disambiguators equal to 0 *) match name with | (PeIdent (crate, _) as id) :: name -> if crate = ctx.crate.name then name else id :: name | _ -> - craise __FILE__ __LINE__ span + craise __FILE__ __LINE__ meta.span ("Unexpected name shape: " ^ TranslateCore.name_to_string ctx.trans_ctx name) (** Helper *) -let ctx_compute_simple_name (span : Meta.span) (ctx : extraction_ctx) +let ctx_compute_simple_name (meta : T.item_meta) (ctx : extraction_ctx) (name : llbc_name) : string list = (* Rmk.: initially we only filtered the disambiguators equal to 0 *) - let name = ctx_prepare_name span ctx name in + let name = if meta.is_local then ctx_prepare_name meta ctx name else name in name_to_simple_name ctx.trans_ctx name (** Helper *) @@ -1472,7 +1477,7 @@ let ctx_compute_simple_type_name = ctx_compute_simple_name let ctx_compute_type_name_no_suffix (ctx : extraction_ctx) (item_meta : Types.item_meta) (name : llbc_name) : string = let name = rename_llbc_name item_meta.attr_info name in - flatten_name (ctx_compute_simple_type_name item_meta.span ctx name) + flatten_name (ctx_compute_simple_type_name item_meta ctx name) (** Provided a basename, compute a type name. @@ -1564,9 +1569,9 @@ let ctx_compute_struct_constructor (def : type_decl) (ctx : extraction_ctx) let tname = ctx_compute_type_name def.item_meta ctx basename in ExtractBuiltin.mk_struct_constructor tname -let ctx_compute_fun_name_no_suffix (span : Meta.span) (ctx : extraction_ctx) +let ctx_compute_fun_name_no_suffix (meta : T.item_meta) (ctx : extraction_ctx) (fname : llbc_name) : string = - let fname = ctx_compute_simple_name span ctx fname in + let fname = ctx_compute_simple_name meta ctx fname in (* TODO: don't convert to snake case for Coq, HOL4, F* *) let fname = flatten_name fname in match backend () with @@ -1574,15 +1579,14 @@ let ctx_compute_fun_name_no_suffix (span : Meta.span) (ctx : extraction_ctx) | Lean -> fname (** Provided a basename, compute the name of a global declaration. *) -let ctx_compute_global_name (span : Meta.span) (ctx : extraction_ctx) +let ctx_compute_global_name (meta : T.item_meta) (ctx : extraction_ctx) (name : llbc_name) : string = + let name = ctx_compute_simple_name meta ctx name in match Config.backend () with | Coq | FStar | HOL4 -> - let parts = - List.map to_snake_case (ctx_compute_simple_name span ctx name) - in + let parts = List.map to_snake_case name in String.concat "_" parts - | Lean -> flatten_name (ctx_compute_simple_name span ctx name) + | Lean -> flatten_name name (** Helper function: generate a suffix for a function name, i.e., generates a suffix like "_loop", "loop1", etc. to append to a function name. @@ -1613,10 +1617,10 @@ let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) : string = - loop id (if pertinent) TODO: use the fun id for the builtin functions. 
*) -let ctx_compute_fun_name (span : Meta.span) (ctx : extraction_ctx) +let ctx_compute_fun_name (meta : T.item_meta) (ctx : extraction_ctx) (fname : llbc_name) (num_loops : int) (loop_id : LoopId.id option) : string = - let fname = ctx_compute_fun_name_no_suffix span ctx fname in + let fname = ctx_compute_fun_name_no_suffix meta ctx fname in (* Compute the suffix *) let suffix = default_fun_suffix num_loops loop_id in (* Concatenate *) @@ -1646,8 +1650,7 @@ let ctx_compute_trait_impl_name (ctx : extraction_ctx) (trait_decl : trait_decl) let params = trait_impl.llbc_generics in let args = trait_impl.llbc_impl_trait.decl_generics in let name = - ctx_prepare_name trait_impl.item_meta.span ctx - trait_decl.item_meta.name + ctx_prepare_name trait_impl.item_meta ctx trait_decl.item_meta.name in let name = rename_llbc_name trait_impl.item_meta.attr_info name in trait_name_with_generics_to_simple_name ctx.trans_ctx name params args @@ -1799,17 +1802,17 @@ let ctx_compute_trait_type_clause_name (ctx : extraction_ctx) the same purpose as in [llbc_name]. - loop identifier, if this is for a loop *) -let ctx_compute_termination_measure_name (span : Meta.span) +let ctx_compute_termination_measure_name (meta : T.item_meta) (ctx : extraction_ctx) (_fid : A.FunDeclId.id) (fname : llbc_name) (num_loops : int) (loop_id : LoopId.id option) : string = - let fname = ctx_compute_fun_name_no_suffix span ctx fname in + let fname = ctx_compute_fun_name_no_suffix meta ctx fname in let lp_suffix = default_fun_loop_suffix num_loops loop_id in (* Compute the suffix *) let suffix = match Config.backend () with | FStar -> "_decreases" | Lean -> "_terminates" - | Coq | HOL4 -> craise __FILE__ __LINE__ span "Unexpected" + | Coq | HOL4 -> craise __FILE__ __LINE__ meta.span "Unexpected" in (* Concatenate *) fname ^ lp_suffix ^ suffix @@ -1828,16 +1831,16 @@ let ctx_compute_termination_measure_name (span : Meta.span) the same purpose as in [llbc_name]. 
- loop identifier, if this is for a loop *) -let ctx_compute_decreases_proof_name (span : Meta.span) (ctx : extraction_ctx) +let ctx_compute_decreases_proof_name (meta : T.item_meta) (ctx : extraction_ctx) (_fid : A.FunDeclId.id) (fname : llbc_name) (num_loops : int) (loop_id : LoopId.id option) : string = - let fname = ctx_compute_fun_name_no_suffix span ctx fname in + let fname = ctx_compute_fun_name_no_suffix meta ctx fname in let lp_suffix = default_fun_loop_suffix num_loops loop_id in (* Compute the suffix *) let suffix = match Config.backend () with | Lean -> "_decreases" - | FStar | Coq | HOL4 -> craise __FILE__ __LINE__ span "Unexpected" + | FStar | Coq | HOL4 -> craise __FILE__ __LINE__ meta.span "Unexpected" in (* Concatenate *) fname ^ lp_suffix ^ suffix @@ -2107,7 +2110,7 @@ let ctx_add_decreases_proof (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = let name = rename_llbc_name def.item_meta.attr_info def.item_meta.name in let name = - ctx_compute_decreases_proof_name def.item_meta.span ctx def.def_id name + ctx_compute_decreases_proof_name def.item_meta ctx def.def_id name def.num_loops def.loop_id in ctx_add def.item_meta.span @@ -2118,14 +2121,14 @@ let ctx_add_termination_measure (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = let name = rename_llbc_name def.item_meta.attr_info def.item_meta.name in let name = - ctx_compute_termination_measure_name def.item_meta.span ctx def.def_id name + ctx_compute_termination_measure_name def.item_meta ctx def.def_id name def.num_loops def.loop_id in ctx_add def.item_meta.span (TerminationMeasureId (FRegular def.def_id, def.loop_id)) name ctx -let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : +let ctx_add_global_decl_and_body (def : global_decl) (ctx : extraction_ctx) : extraction_ctx = (* TODO: update once the body id can be an option *) let decl = GlobalId def.def_id in @@ -2133,18 +2136,16 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : (* Check if the global corresponds to an builtin global that we should map to a custom definition in our standard library (for instance, happens with "core::num::usize::MAX") *) - match - match_name_find_opt ctx.trans_ctx def.item_meta.name builtin_globals_map - with - | Some name -> + match def.builtin_info with + | Some info -> (* Yes: register the custom binding *) - ctx_add def.item_meta.span decl name ctx + ctx_add def.item_meta.span decl info.global_name ctx | None -> (* Not the case: "standard" registration *) let name = rename_llbc_name def.item_meta.attr_info def.item_meta.name in - let name = ctx_compute_global_name def.item_meta.span ctx name in + let name = ctx_compute_global_name def.item_meta ctx name in - let body = FunId (FromLlbc (FunId (FRegular def.body), None)) in + let body = FunId (FromLlbc (FunId (FRegular def.body_id), None)) in (* If this is a provided constant (i.e., the default value for a constant in a trait declaration) we add a suffix. Otherwise there is a clash between the name for the default constant and the name for the field @@ -2202,6 +2203,10 @@ let ctx_compute_fun_name (def : fun_decl) (is_trait_decl_field : bool) | _ -> def.item_meta in let llbc_name = rename_llbc_name item_meta.attr_info def.item_meta.name in + log#ldebug + (lazy + (__FUNCTION__ ^ ": llbc_name after renaming: " + ^ name_to_string ctx llbc_name)); (* When a trait method has a default implementation, this becomes a [fun_decl] that we may want to extract. 
By default, its name is [Trait::method], which for lean creates a name clash with the method name as a field in the trait @@ -2216,8 +2221,12 @@ let ctx_compute_fun_name (def : fun_decl) (is_trait_decl_field : bool) llbc_name @ [ PeIdent ("default", Disambiguator.zero) ] | _ -> llbc_name in - ctx_compute_fun_name def.item_meta.span ctx llbc_name def.num_loops - def.loop_id + log#ldebug + (lazy + (__FUNCTION__ + ^ ": llbc_name after adding 'default' suffix (for default methods): " + ^ name_to_string ctx llbc_name)); + ctx_compute_fun_name def.item_meta ctx llbc_name def.num_loops def.loop_id (* TODO: move to Extract *) let ctx_add_fun_decl (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = diff --git a/src/extract/ExtractBuiltin.ml b/src/extract/ExtractBuiltin.ml index 08a173e0..7c30abd7 100644 --- a/src/extract/ExtractBuiltin.ml +++ b/src/extract/ExtractBuiltin.ml @@ -5,7 +5,7 @@ *) open Config -open Charon.NameMatcher (* TODO: include? *) +open NameMatcher (* TODO: include? *) include ExtractName (* TODO: only open? *) let log = Logging.builtin_log @@ -81,39 +81,11 @@ let builtin_globals : (string * string) list = ("core::num::{i128}::MAX", "core_i128_max"); ] -let builtin_globals_map : string NameMatcherMap.t = +let builtin_globals_map : Pure.builtin_global_info NameMatcherMap.t = NameMatcherMap.of_list - (List.map (fun (x, y) -> (parse_pattern x, y)) builtin_globals) - -type builtin_variant_info = { fields : (string * string) list } -[@@deriving show] - -type builtin_enum_variant_info = { - rust_variant_name : string; - extract_variant_name : string; - fields : string list option; -} -[@@deriving show] - -type builtin_type_body_info = - | Struct of string * (string * string) list - (* The constructor name and the map for the field names *) - | Enum of builtin_enum_variant_info list -(* For every variant, a map for the field names *) -[@@deriving show] - -type builtin_type_info = { - rust_name : pattern; - extract_name : string; - keep_params : bool list option; - (** We might want to filter some of the type parameters. - - For instance, `Vec` type takes a type parameter for the allocator, - which we want to ignore. - *) - body_info : builtin_type_body_info option; -} -[@@deriving show] + (List.map + (fun (x, y) -> (parse_pattern x, { Pure.global_name = y })) + builtin_globals) type type_variant_kind = | KOpaque @@ -141,17 +113,17 @@ let mk_struct_constructor (type_name : string) : string = parameters. For instance, in the case of the `Vec` functions, there is a type parameter for the allocator to use, which we want to filter. *) -let builtin_types () : builtin_type_info list = +let builtin_types () : Pure.builtin_type_info list = let mk_type (rust_name : string) ?(custom_name : string option = None) ?(keep_params : bool list option = None) - ?(kind : type_variant_kind = KOpaque) () : builtin_type_info = + ?(kind : type_variant_kind = KOpaque) () : Pure.builtin_type_info = let rust_name = parse_pattern rust_name in let extract_name = match custom_name with | None -> flatten_name (pattern_to_type_extract_name rust_name) | Some name -> flatten_name (split_on_separator name) in - let body_info : builtin_type_body_info option = + let body_info : Pure.builtin_type_body_info option = match kind with | KOpaque -> None | KStruct fields -> @@ -181,11 +153,12 @@ let builtin_types () : builtin_type_info list = | Lean -> extract_name ^ "." 
^ variant | HOL4 -> extract_name ^ variant in - { - rust_variant_name = variant; - extract_variant_name; - fields = None; - }) + ({ + rust_variant_name = variant; + extract_variant_name; + fields = None; + } + : Pure.builtin_enum_variant_info)) variants in Some (Enum variants) @@ -256,17 +229,12 @@ let builtin_types () : builtin_type_info list = let mk_builtin_types_map () = NameMatcherMap.of_list - (List.map (fun info -> (info.rust_name, info)) (builtin_types ())) + (List.map + (fun (info : Pure.builtin_type_info) -> (info.rust_name, info)) + (builtin_types ())) let builtin_types_map = mk_memoized mk_builtin_types_map -type builtin_fun_info = { - extract_name : string; - can_fail : bool; - stateful : bool; -} -[@@deriving show] - let int_and_smaller_list : (string * string) list = let uint_names = List.rev [ "u8"; "u16"; "u32"; "u64"; "u128" ] in let int_names = List.rev [ "i8"; "i16"; "i32"; "i64"; "i128" ] in @@ -295,12 +263,12 @@ let int_and_smaller_list : (string * string) list = parameters. For instance, in the case of the `Vec` functions, there is a type parameter for the allocator to use, which we want to filter. *) -let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list = - (* Small utility *) +let mk_builtin_funs () : (pattern * Pure.builtin_fun_info) list = + (* Small utility. *) let mk_fun (rust_name : string) ?(filter : bool list option = None) - ?(can_fail = true) ?(stateful = false) + ?(can_fail = true) ?(stateful = false) ?(lift = true) ?(extract_name : string option = None) () : - pattern * bool list option * builtin_fun_info = + pattern * Pure.builtin_fun_info = let rust_name = try parse_pattern rust_name with Failure _ -> @@ -312,12 +280,20 @@ let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list = | Some name -> split_on_separator name in let basename = flatten_name extract_name in - let f = { extract_name = basename; can_fail; stateful } in - (rust_name, filter, f) + let f : Pure.builtin_fun_info = + { + filter_params = filter; + extract_name = basename; + can_fail; + stateful; + lift; + } + in + (rust_name, f) in let mk_scalar_fun (rust_name : string -> string) (extract_name : string -> string) ?(can_fail = true) () : - (pattern * bool list option * builtin_fun_info) list = + (pattern * Pure.builtin_fun_info) list = List.map (fun ty -> mk_fun (rust_name ty) @@ -326,13 +302,13 @@ let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list = all_int_names in [ - mk_fun "core::mem::replace" ~can_fail:false (); - mk_fun "core::mem::take" ~can_fail:false (); + mk_fun "core::mem::replace" ~can_fail:false ~lift:false (); + mk_fun "core::mem::take" ~can_fail:false ~lift:false (); mk_fun "core::slice::{[@T]}::len" ~extract_name:(Some (backend_choice "slice::len" "Slice::len")) - ~can_fail:false (); + ~can_fail:false ~lift:false (); mk_fun "alloc::vec::{alloc::vec::Vec<@T, alloc::alloc::Global>}::new" - ~extract_name:(Some "alloc::vec::Vec::new") ~can_fail:false (); + ~extract_name:(Some "alloc::vec::Vec::new") ~can_fail:false ~lift:false (); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::push" ~filter:(Some [ true; false ]) (); @@ -341,7 +317,7 @@ let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list = (); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::len" ~filter:(Some [ true; false ]) - ~can_fail:false (); + ~can_fail:false ~lift:false (); mk_fun "alloc::vec::{core::ops::index::Index, \ @I>}::index" @@ -422,14 +398,15 @@ let mk_builtin_funs () : (pattern * bool list option * 
builtin_fun_info) list = ~extract_name:(Some "alloc.slice.Slice.to_vec") (); mk_fun "alloc::vec::{alloc::vec::Vec<@T, alloc::alloc::Global>}::with_capacity" - ~extract_name:(Some "alloc.vec.Vec.with_capacity") ~can_fail:false (); + ~extract_name:(Some "alloc.vec.Vec.with_capacity") ~can_fail:false + ~lift:false (); mk_fun "core::slice::{[@T]}::reverse" ~extract_name:(Some "core.slice.Slice.reverse") ~can_fail:false (); mk_fun "alloc::vec::{core::ops::deref::Deref>}::deref" ~extract_name:(Some "alloc.vec.DerefVec.deref") ~filter:(Some [ true; false ]) - ~can_fail:false (); + ~can_fail:false ~lift:false (); mk_fun "alloc::vec::{core::ops::deref::DerefMut>}::deref_mut" @@ -521,15 +498,15 @@ let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list = mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::resize" ~filter:(Some [ true; false ]) (); - mk_fun "core::mem::swap" ~can_fail:false (); + mk_fun "core::mem::swap" ~can_fail:false ~lift:false (); mk_fun "core::option::{core::option::Option<@T>}::take" ~extract_name: (backend_choice None (Some "core::option::Option::take")) - ~can_fail:false (); + ~can_fail:false ~lift:false (); mk_fun "core::option::{core::option::Option<@T>}::is_none" ~extract_name: (backend_choice None (Some "core::option::Option::is_none")) - ~can_fail:false (); + ~can_fail:false ~lift:false (); mk_fun "core::clone::Clone::clone_from" (); (* Into> *) mk_fun "core::convert::{core::convert::Into<@T, @U>}::into" @@ -568,6 +545,11 @@ let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list = "core::slice::index::{core::slice::index::SliceIndex, \ [@T]>}::index_mut" (); + (* *) + mk_fun "alloc::boxed::{core::convert::AsMut, @T>}::as_mut" + ~can_fail:false + ~filter:(Some [ true; false ]) + (); ] @ List.flatten (List.map @@ -587,22 +569,19 @@ let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list = (false, "saturating_sub"); (false, "wrapping_add"); (false, "wrapping_sub"); - (true, "overflowing_add"); + (false, "overflowing_add"); (false, "rotate_left"); (false, "rotate_right"); ]) all_int_names)) -let builtin_funs : unit -> (pattern * bool list option * builtin_fun_info) list - = +let builtin_funs : unit -> (pattern * Pure.builtin_fun_info) list = mk_memoized mk_builtin_funs let mk_builtin_funs_map () = let m = NameMatcherMap.of_list - (List.map - (fun (name, filter, info) -> (name, (filter, info))) - (builtin_funs ())) + (List.map (fun (name, info) -> (name, info)) (builtin_funs ())) in log#ltrace (lazy ("builtin_funs_map:\n" ^ NameMatcherMap.to_string (fun _ -> "...") m)); @@ -613,11 +592,9 @@ let builtin_funs_map = mk_memoized mk_builtin_funs_map type effect_info = { can_fail : bool; stateful : bool } let mk_builtin_fun_effects () : (pattern * effect_info) list = - let builtin_funs : (pattern * bool list option * builtin_fun_info) list = - builtin_funs () - in + let builtin_funs : (pattern * Pure.builtin_fun_info) list = builtin_funs () in List.map - (fun ((pattern, _, info) : _ * _ * builtin_fun_info) -> + (fun ((pattern, info) : _ * Pure.builtin_fun_info) -> let info = { can_fail = info.can_fail; stateful = info.stateful } in (pattern, info)) builtin_funs @@ -627,26 +604,12 @@ let mk_builtin_fun_effects_map () = let builtin_fun_effects_map = mk_memoized mk_builtin_fun_effects_map -type builtin_trait_decl_info = { - rust_name : pattern; - extract_name : string; - constructor : string; - parent_clauses : string list; - consts : (string * string) list; - types : (string * string) list; - (** Every type has: - - a 
Rust name - - an extraction name *) - methods : (string * builtin_fun_info) list; -} -[@@deriving show] - let builtin_trait_decls_info () = let mk_trait (rust_name : string) ?(extract_name : string option = None) ?(parent_clauses : string list = []) ?(types : string list = []) ?(methods : string list = []) ?(methods_with_extract : (string * string) list option = None) () : - builtin_trait_decl_info = + Pure.builtin_trait_decl_info = let rust_name = parse_pattern rust_name in let extract_name = match extract_name with @@ -681,8 +644,14 @@ let builtin_trait_decls_info () = if !record_fields_short_names then item_name else extract_name ^ "_" ^ item_name in - let fwd = - { extract_name = basename; can_fail = true; stateful = false } + let fwd : Pure.builtin_fun_info = + { + filter_params = None; + extract_name = basename; + can_fail = true; + stateful = false; + lift = true; + } in (item_name, fwd) in @@ -690,7 +659,15 @@ let builtin_trait_decls_info () = | Some methods -> List.map (fun (item_name, extract_name) -> - (item_name, { extract_name; can_fail = true; stateful = false })) + ( item_name, + ({ + filter_params = None; + extract_name; + can_fail = true; + stateful = false; + lift = true; + } + : Pure.builtin_fun_info) )) methods in { @@ -747,22 +724,25 @@ let builtin_trait_decls_info () = (); (* Debug *) mk_trait "core::fmt::Debug" ~types:[ "T" ] ~methods:[ "fmt" ] (); + (* *) mk_trait "core::convert::TryFrom" ~methods:[ "try_from" ] (); mk_trait "core::convert::TryInto" ~methods:[ "try_into" ] (); + mk_trait "core::convert::AsMut" ~methods:[ "as_mut" ] (); ] let mk_builtin_trait_decls_map () = NameMatcherMap.of_list (List.map - (fun info -> (info.rust_name, info)) + (fun (info : Pure.builtin_trait_decl_info) -> (info.rust_name, info)) (builtin_trait_decls_info ())) let builtin_trait_decls_map = mk_memoized mk_builtin_trait_decls_map -let builtin_trait_impls_info () : (pattern * (bool list option * string)) list = +let builtin_trait_impls_info () : (pattern * Pure.builtin_trait_impl_info) list + = let fmt (rust_name : string) ?(extract_name : string option = None) ?(filter : bool list option = None) () : - pattern * (bool list option * string) = + pattern * Pure.builtin_trait_impl_info = let rust_name = parse_pattern rust_name in let name = let name = @@ -772,7 +752,7 @@ let builtin_trait_impls_info () : (pattern * (bool list option * string)) list = in flatten_name name in - (rust_name, (filter, name)) + (rust_name, { filter_params = filter; impl_name = name }) in [ (* core::ops::Deref> *) @@ -849,6 +829,9 @@ let builtin_trait_impls_info () : (pattern * (bool list option * string)) list = "core::slice::index::SliceIndex, \ [@Self]>" (); + fmt "core::convert::AsMut, @Self>" + ~filter:(Some [ true; false ]) + (); ] (* From *) @ List.map diff --git a/src/extract/ExtractTypes.ml b/src/extract/ExtractTypes.ml index 8874113c..383d3527 100644 --- a/src/extract/ExtractTypes.ml +++ b/src/extract/ExtractTypes.ml @@ -42,7 +42,10 @@ let extract_literal (span : Meta.span) (fmt : F.formatter) (is_pattern : bool) F.pp_print_string fmt ("%" ^ iname) | Lean -> (* We don't use the same notation for patterns and regular literals *) - if is_pattern then F.pp_print_string fmt "#scalar" + if is_pattern then + if Scalars.integer_type_is_signed sv.int_ty then + F.pp_print_string fmt "#iscalar" + else F.pp_print_string fmt "#uscalar" else let iname = String.lowercase_ascii (int_name sv.int_ty) in F.pp_print_string fmt ("#" ^ iname) @@ -136,7 +139,13 @@ let extract_unop (span : Meta.span) (extract_expr : bool -> 
texpression -> unit) let cast_str = match backend () with | Coq | FStar -> "scalar_cast" - | Lean -> "Scalar.cast" + | Lean -> + let signed_src = Scalars.integer_type_is_signed src in + let signed_tgt = Scalars.integer_type_is_signed tgt in + if signed_src = signed_tgt then + if signed_src then "IScalar.cast" else "UScalar.cast" + else if signed_src then "IScalar.hcast" + else "UScalar.hcast" | HOL4 -> admit_string __FILE__ __LINE__ span "Unreachable" in let src = @@ -149,7 +158,10 @@ let extract_unop (span : Meta.span) (extract_expr : bool -> texpression -> unit) let cast_str = match backend () with | Coq | FStar -> "scalar_cast_bool" - | Lean -> "Scalar.cast_bool" + | Lean -> + if Scalars.integer_type_is_signed tgt then + "IScalar.cast_fromBool" + else "UScalar.cast_fromBool" | HOL4 -> admit_string __FILE__ __LINE__ span "Unreachable" in let tgt = integer_type_to_string tgt in @@ -796,14 +808,9 @@ and extract_trait_instance_id (span : Meta.span) (ctx : extraction_ctx) *) let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : extraction_ctx = - (* Lookup the builtin information, if there is *) - let open ExtractBuiltin in - let info = - match_name_find_opt ctx.trans_ctx def.item_meta.name (builtin_types_map ()) - in - (* Register the filtering information, if there is *) + (* Register the filtering information, if the type has builtin information *) let ctx = - match info with + match def.builtin_info with | Some { keep_params = Some keep; _ } -> { ctx with @@ -814,7 +821,7 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : in (* Compute and register the type decl name *) let def_name = - match info with + match def.builtin_info with | None -> ctx_compute_type_decl_name ctx def | Some info -> info.extract_name in @@ -836,7 +843,7 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : | Struct fields -> (* Compute the names *) let field_names, cons_name = - match info with + match def.builtin_info with | None | Some { body_info = None; _ } -> let field_names = FieldId.mapi @@ -884,7 +891,7 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : ctx | Enum variants -> let variant_names = - match info with + match def.builtin_info with | None -> VariantId.mapi (fun variant_id (variant : variant) -> diff --git a/src/pure/PrintPure.ml b/src/pure/PrintPure.ml index 033bd443..5d519c4d 100644 --- a/src/pure/PrintPure.ml +++ b/src/pure/PrintPure.ml @@ -593,6 +593,7 @@ let pure_builtin_fun_id_to_string (fid : pure_builtin_fun_id) : string = | Return -> "@return" | Fail -> "@fail" | Assert -> "@assert" + | ToResult -> "@toResult" | FuelDecrease -> "@fuel_decrease" | FuelEqZero -> "@fuel_eq_zero" | UpdateAtIndex array_or_slice -> begin diff --git a/src/pure/Pure.ml b/src/pure/Pure.ml index 1afcaf17..343a455a 100644 --- a/src/pure/Pure.ml +++ b/src/pure/Pure.ml @@ -72,8 +72,6 @@ type 'a de_bruijn_var = 'a Types.de_bruijn_var [@@deriving show, ord] - [State]: the type of the state, when using state-error monads. 
Note that this state is opaque to Aeneas (the user can define it, or leave it as builtin) - - TODO: add a prefix "T" *) type builtin_ty = | TState @@ -93,6 +91,104 @@ type builtin_ty = *) [@@deriving show, ord] +type array_or_slice = Array | Slice [@@deriving show, ord] + +(** Identifiers of builtin functions that we use only in the pure translation *) +type pure_builtin_fun_id = + | Return (** The monadic return *) + | Fail (** The monadic fail *) + | Assert (** Assertion *) + | FuelDecrease + (** Decrease fuel, provided it is non zero (used for F* ) - TODO: this is ugly *) + | FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *) + | UpdateAtIndex of array_or_slice + (** Update an array or a slice at a given index. + + Note that in LLBC we only use an index function: if we want to + modify an element in an array/slice, we create a mutable borrow + to this element, then use the borrow to perform the update. The + update functions are introduced in the pure code by a micro-pass. + *) + | ToResult + (** Lifts a pure expression to a monadic expression. + + We use this when using `ok ...` would result in let-bindings getting + simplified away (in a backend like Lean). *) +[@@deriving show, ord] + +(* Builtin declarations coming from external libraries. + + Those are not to be understood as the builtin definitions like `U32`: + the builtin declarations described with, e.g., `builtin_type_info`, are + declarations coming from external libraries, which we should thus not + extract (for instance: `std::vec::Vec`, `std::option::Option`, etc.). +*) + +type builtin_variant_info = { fields : (string * string) list } +[@@deriving show, ord] + +type builtin_enum_variant_info = { + rust_variant_name : string; + extract_variant_name : string; + fields : string list option; +} +[@@deriving show, ord] + +type builtin_type_body_info = + | Struct of string * (string * string) list + (* The constructor name and the map for the field names *) + | Enum of builtin_enum_variant_info list +(* For every variant, a map for the field names *) +[@@deriving show, ord] + +type builtin_type_info = { + rust_name : NameMatcher.pattern; + extract_name : string; + keep_params : bool list option; + (** We might want to filter some of the type parameters. + + For instance, `Vec` type takes a type parameter for the allocator, + which we want to ignore. + *) + body_info : builtin_type_body_info option; +} +[@@deriving show, ord] + +type builtin_global_info = { global_name : string } [@@deriving show] + +type builtin_fun_info = { + filter_params : bool list option; + extract_name : string; + can_fail : bool; + stateful : bool; + lift : bool; + (** If the function can not fail, should we still lift it to the [Result] + monad? This can make reasoning easier, as we can then use [progress] + to do proofs in a Hoare-Logic style, rather than doing equational + reasoning.
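+ + For instance (illustrative sketch, the exact output depends on the backend): if an infallible function [f] is marked with [lift], the Lean extraction produces something like + {[ + let y ← ↑(f x) + ]} + instead of [let y := f x], so that the let-binding can be stepped through with [progress].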
*) +} +[@@deriving show] + +type builtin_trait_decl_info = { + rust_name : NameMatcher.pattern; + extract_name : string; + constructor : string; + parent_clauses : string list; + consts : (string * string) list; + types : (string * string) list; + (** Every type has: + - a Rust name + - an extraction name *) + methods : (string * builtin_fun_info) list; +} +[@@deriving show] + +type builtin_trait_impl_info = { + filter_params : bool list option; + impl_name : string; +} +[@@deriving show] + (* TODO: we should never directly manipulate [Return] and [Fail], but rather * the monadic functions [return] and [fail] (makes treatment of error and * state-error monads more uniform) *) @@ -343,6 +439,9 @@ class ['self] iter_type_decl_base = self#visit_literal_type e var.ty method visit_item_meta : 'env -> T.item_meta -> unit = fun _ _ -> () + + method visit_builtin_type_info : 'env -> builtin_type_info -> unit = + fun _ _ -> () end (** Ancestor for map visitor for [type_decl] *) @@ -367,6 +466,10 @@ class ['self] map_type_decl_base = } method visit_item_meta : 'env -> T.item_meta -> T.item_meta = fun _ x -> x + + method visit_builtin_type_info + : 'env -> builtin_type_info -> builtin_type_info = + fun _ x -> x end (** Ancestor for reduce visitor for [type_decl] *) @@ -388,6 +491,9 @@ class virtual ['self] reduce_type_decl_base = self#plus (self#plus x0 x1) x2 method visit_item_meta : 'env -> T.item_meta -> 'a = fun _ _ -> self#zero + + method visit_builtin_type_info : 'env -> builtin_type_info -> 'a = + fun _ _ -> self#zero end (** Ancestor for mapreduce visitor for [type_decl] *) @@ -411,6 +517,10 @@ class virtual ['self] mapreduce_type_decl_base = method visit_item_meta : 'env -> T.item_meta -> T.item_meta * 'a = fun _ x -> (x, self#zero) + + method visit_builtin_type_info + : 'env -> builtin_type_info -> builtin_type_info * 'a = + fun _ x -> (x, self#zero) end type field = { @@ -510,6 +620,7 @@ and type_decl = { to derive them from the original LLBC types from before the simplification of types like boxes and references. *) kind : type_decl_kind; + builtin_info : builtin_type_info option; preds : predicates; } [@@deriving @@ -720,26 +831,6 @@ type unop = | Cast of literal_type * literal_type [@@deriving show, ord] -type array_or_slice = Array | Slice [@@deriving show, ord] - -(** Identifiers of builtin functions that we use only in the pure translation *) -type pure_builtin_fun_id = - | Return (** The monadic return *) - | Fail (** The monadic fail *) - | Assert (** Assertion *) - | FuelDecrease - (** Decrease fuel, provided it is non zero (used for F* ) - TODO: this is ugly *) - | FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *) - | UpdateAtIndex of array_or_slice - (** Update an array or a slice at a given index. - - Note that in LLBC we only use an index function: if we want to - modify an element in an array/slice, we create a mutable borrow - to this element, then use the borrow to perform the update. The - update functions are introduced in the pure code by a micro-pass. 
- *) -[@@deriving show, ord] - type fun_id_or_trait_method_ref = | FunId of A.fun_id | TraitMethod of trait_ref * string * fun_decl_id @@ -765,11 +856,13 @@ type fun_id = (** A function only used in the pure translation *) [@@deriving show, ord] +type binop = E.binop [@@deriving show, ord] + (** A function or an operation id *) type fun_or_op_id = | Fun of fun_id | Unop of unop - | Binop of E.binop * integer_type + | Binop of binop * integer_type [@@deriving show, ord] (** An identifier for an ADT constructor *) @@ -1294,6 +1387,7 @@ type 'a binder = { type fun_decl = { def_id : FunDeclId.id; item_meta : T.item_meta; + builtin_info : builtin_fun_info option; kind : item_kind; backend_attributes : backend_attributes; num_loops : int; @@ -1318,6 +1412,7 @@ type global_decl = { def_id : GlobalDeclId.id; span : span; item_meta : T.item_meta; + builtin_info : builtin_global_info option; name : string; (** We use the name only for printing purposes (for debugging): the name used at extraction time will be derived from the @@ -1339,6 +1434,7 @@ type trait_decl = { def_id : trait_decl_id; name : string; item_meta : T.item_meta; + builtin_info : builtin_trait_decl_info option; generics : generic_params; explicit_info : explicit_info; (** Information about which inputs parameters are explicit/implicit *) @@ -1361,6 +1457,7 @@ type trait_impl = { def_id : trait_impl_id; name : string; item_meta : T.item_meta; + builtin_info : builtin_trait_impl_info option; impl_trait : trait_decl_ref; llbc_impl_trait : Types.trait_decl_ref; (** Same remark as for {!field:llbc_generics}. *) diff --git a/src/pure/PureMicroPasses.ml b/src/pure/PureMicroPasses.ml index 5d7db57f..2dd09725 100644 --- a/src/pure/PureMicroPasses.ml +++ b/src/pure/PureMicroPasses.ml @@ -8,32 +8,34 @@ open Errors (** The local logger *) let log = Logging.pure_micro_passes_log -let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string = - let fmt = trans_ctx_to_pure_fmt_env ctx in +type ctx = { fun_decls : fun_decl FunDeclId.Map.t; trans_ctx : trans_ctx } + +let fun_decl_to_string (ctx : ctx) (def : Pure.fun_decl) : string = + let fmt = trans_ctx_to_pure_fmt_env ctx.trans_ctx in PrintPure.fun_decl_to_string fmt def -let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string = - let fmt = trans_ctx_to_pure_fmt_env ctx in +let fun_sig_to_string (ctx : ctx) (sg : Pure.fun_sig) : string = + let fmt = trans_ctx_to_pure_fmt_env ctx.trans_ctx in PrintPure.fun_sig_to_string fmt sg -let var_to_string (ctx : trans_ctx) (v : var) : string = - let fmt = trans_ctx_to_pure_fmt_env ctx in +let var_to_string (ctx : ctx) (v : var) : string = + let fmt = trans_ctx_to_pure_fmt_env ctx.trans_ctx in PrintPure.var_to_string fmt v -let texpression_to_string (ctx : trans_ctx) (x : texpression) : string = - let fmt = trans_ctx_to_pure_fmt_env ctx in +let texpression_to_string (ctx : ctx) (x : texpression) : string = + let fmt = trans_ctx_to_pure_fmt_env ctx.trans_ctx in PrintPure.texpression_to_string fmt false "" " " x -let switch_to_string (ctx : trans_ctx) scrut (x : switch_body) : string = - let fmt = trans_ctx_to_pure_fmt_env ctx in +let switch_to_string (ctx : ctx) scrut (x : switch_body) : string = + let fmt = trans_ctx_to_pure_fmt_env ctx.trans_ctx in PrintPure.switch_to_string fmt "" " " scrut x -let struct_update_to_string (ctx : trans_ctx) supd : string = - let fmt = trans_ctx_to_pure_fmt_env ctx in +let struct_update_to_string (ctx : ctx) supd : string = + let fmt = trans_ctx_to_pure_fmt_env ctx.trans_ctx in 
PrintPure.struct_update_to_string fmt "" " " supd -let typed_pattern_to_string (ctx : trans_ctx) pat : string = - let fmt = trans_ctx_to_pure_fmt_env ctx in +let typed_pattern_to_string (ctx : ctx) pat : string = + let fmt = trans_ctx_to_pure_fmt_env ctx.trans_ctx in PrintPure.typed_pattern_to_string fmt pat (** Small utility. @@ -624,7 +626,7 @@ let remove_span (def : fun_decl) : fun_decl = e ]} *) -let intro_massert (_ctx : trans_ctx) (def : fun_decl) : fun_decl = +let intro_massert (_ctx : ctx) (def : fun_decl) : fun_decl = let span = def.item_meta.span in let visitor = object @@ -695,7 +697,7 @@ let intro_massert (_ctx : trans_ctx) (def : fun_decl) : fun_decl = The subsequent passes, in particular the ones which inline the useless assignments, simplify this further. *) -let simplify_decompose_struct (ctx : trans_ctx) (def : fun_decl) : fun_decl = +let simplify_decompose_struct (ctx : ctx) (def : fun_decl) : fun_decl = let span = def.item_meta.span in let visitor = object @@ -706,7 +708,9 @@ let simplify_decompose_struct (ctx : trans_ctx) (def : fun_decl) : fun_decl = match (lv.value, lv.ty) with | PatAdt adt_pat, TAdt (TAdtId adt_id, generics) -> (* Detect if this is an enumeration or not *) - let tdef = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in + let tdef = + TypeDeclId.Map.find adt_id ctx.trans_ctx.type_ctx.type_decls + in let is_enum = TypesUtils.type_decl_is_enum tdef in (* We deconstruct the ADT with a single let-binding in two situations: - if the ADT is an enumeration (which must have exactly one branch) @@ -721,7 +725,7 @@ let simplify_decompose_struct (ctx : trans_ctx) (def : fun_decl) : fun_decl = like Coq don't, in which case we have to deconstruct the whole ADT at once (`let (a, b, c) = x in`) *) || TypesUtils.type_decl_from_type_id_is_tuple_struct - ctx.type_ctx.type_infos (T.TAdtId adt_id) + ctx.trans_ctx.type_ctx.type_infos (T.TAdtId adt_id) && not !Config.use_tuple_projectors in if use_let_with_cons then @@ -779,7 +783,7 @@ let simplify_decompose_struct (ctx : trans_ctx) (def : fun_decl) : fun_decl = Note however that we do not apply this transformation if the structure is to be extracted as a tuple. *) -let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = +let intro_struct_updates (ctx : ctx) (def : fun_decl) : fun_decl = let visitor = object (self) inherit [_] map_expression as super @@ -800,11 +804,13 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = generics = _; } -> (* Lookup the def *) - let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in + let decl = + TypeDeclId.Map.find adt_id ctx.trans_ctx.type_ctx.type_decls + in (* Check if the def will be extracted as a tuple *) if TypesUtils.type_decl_from_decl_id_is_tuple_struct - ctx.type_ctx.type_infos adt_id + ctx.trans_ctx.type_ctx.type_infos adt_id then ignore () else (* Check that there are as many arguments as there are fields - note @@ -816,7 +822,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let is_rec = match TypeDeclId.Map.find adt_id - ctx.type_ctx.type_decls_groups + ctx.trans_ctx.type_ctx.type_decls_groups with | NonRecGroup _ -> false | RecGroup _ -> true @@ -898,7 +904,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = ... 
]} *) -let simplify_let_bindings (ctx : trans_ctx) (def : fun_decl) : fun_decl = +let simplify_let_bindings (ctx : ctx) (def : fun_decl) : fun_decl = let obj = object (self) inherit [_] map_expression as super @@ -973,7 +979,7 @@ let simplify_let_bindings (ctx : trans_ctx) (def : fun_decl) : fun_decl = | Qualif { id = AdtCons { adt_id; _ }; _ } -> not (PureUtils.type_decl_from_type_id_is_tuple_struct - ctx.type_ctx.type_infos adt_id) + ctx.trans_ctx.type_ctx.type_infos adt_id) | Qualif { id = Proj _; _ } -> false | _ -> true in @@ -1014,7 +1020,7 @@ let simplify_let_bindings (ctx : trans_ctx) (def : fun_decl) : fun_decl = This micro pass removes those duplicate function calls. *) -let simplify_duplicate_calls (_ctx : trans_ctx) (def : fun_decl) : fun_decl = +let simplify_duplicate_calls (_ctx : ctx) (def : fun_decl) : fun_decl = let visitor = object (self) inherit [_] map_expression as super @@ -1053,6 +1059,54 @@ let simplify_duplicate_calls (_ctx : trans_ctx) (def : fun_decl) : fun_decl = in { def with body = Some body } +(** A helper predicate *) +let lift_unop (unop : unop) : bool = + match unop with + | Not None -> false + | Not (Some _) | Neg _ | Cast _ -> true + +(** A helper predicate *) +let inline_unop unop = not (lift_unop unop) + +(** A helper predicate *) +let lift_binop (binop : binop) : bool = + match binop with + | Eq | Lt | Le | Ne | Ge | Gt -> false + | BitXor + | BitAnd + | BitOr + | Div + | Rem + | Add + | Sub + | Mul + | CheckedAdd + | CheckedSub + | CheckedMul + | Shl + | Shr -> true + +(** A helper predicate *) +let inline_binop binop = not (lift_binop binop) + +(** A helper predicate *) +let lift_fun (ctx : ctx) (fun_id : fun_id) : bool = + (* Lookup if the function is builtin: we only lift builtin functions + which were explictly marked to be lifted. *) + match fun_id with + | FromLlbc (FunId (FRegular fid), _) -> begin + match FunDeclId.Map.find_opt fid ctx.fun_decls with + | None -> false + | Some def -> ( + match def.builtin_info with + | None -> false + | Some info -> info.lift) + end + | _ -> false + +(** A helper predicate *) +let inline_fun (_ : fun_id) : bool = false + (** Inline the useless variable (re-)assignments: A lot of intermediate variable assignments are introduced through the @@ -1071,8 +1125,8 @@ let simplify_duplicate_calls (_ctx : trans_ctx) (def : fun_decl) : fun_decl = [inline_pure]: if [true], inline all the pure assignments where the variable on the left is anonymous, but the assignments where the r-expression is - a non-primitive function call (i.e.: inline the binops, ADT constructions, - etc.). + a function call (i.e.: ADT constructions, etc.), except certain cases of + function calls. [inline_identity]: if [true], inline the identity functions (i.e., lambda functions of the shape [fun x -> x]). @@ -1083,8 +1137,8 @@ let simplify_duplicate_calls (_ctx : trans_ctx) (def : fun_decl) : fun_decl = pass (if they are useless). 
*) let inline_useless_var_assignments ~(inline_named : bool) ~(inline_const : bool) - ~(inline_pure : bool) ~(inline_identity : bool) (ctx : trans_ctx) - (def : fun_decl) : fun_decl = + ~(inline_pure : bool) ~(inline_identity : bool) (ctx : ctx) (def : fun_decl) + : fun_decl = let obj = object (self) inherit [_] map_expression as super @@ -1141,9 +1195,9 @@ let inline_useless_var_assignments ~(inline_named : bool) ~(inline_const : bool) match qualif.id with | AdtCons _ -> true (* ADT constructor *) | Proj _ -> true (* Projector *) - | FunOrOp (Unop _ | Binop _) -> - true (* primitive function call *) - | FunOrOp (Fun _) -> false (* non-primitive function call *) + | FunOrOp (Unop unop) -> inline_unop unop + | FunOrOp (Binop (binop, _)) -> inline_binop binop + | FunOrOp (Fun fun_id) -> inline_fun fun_id | _ -> false) | StructUpdate _ -> true (* ADT constructor *) | _ -> false @@ -1186,8 +1240,8 @@ let inline_useless_var_assignments ~(inline_named : bool) ~(inline_const : bool) let re = self#visit_texpression env re in if PureUtils.is_var re - && type_decl_from_type_id_is_tuple_struct ctx.type_ctx.type_infos - adt_id + && type_decl_from_type_id_is_tuple_struct + ctx.trans_ctx.type_ctx.type_infos adt_id then (* Update the substitution environment *) let env = VarId.Map.add lv_var.id re env in @@ -1222,7 +1276,7 @@ let inline_useless_var_assignments ~(inline_named : bool) ~(inline_const : bool) (** Filter the useless assignments (removes the useless variables, filters the function calls) *) -let filter_useless (_ctx : trans_ctx) (def : fun_decl) : fun_decl = +let filter_useless (_ctx : ctx) (def : fun_decl) : fun_decl = (* We first need a transformation on *left-values*, which filters the useless variables and tells us whether the value contains any variable which has not been replaced by [_] (in which case we need to keep the assignment, @@ -1442,7 +1496,7 @@ let simplify_let_then_ok _ctx (def : fun_decl) = Mkstruct x.f0 x.f1 x.f2 ~~> x ]} *) -let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = +let simplify_aggregates (ctx : ctx) (def : fun_decl) : fun_decl = let expr_visitor = object inherit [_] map_expression as super @@ -1465,7 +1519,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* This is a struct *) (* Retrieve the definiton, to find how many fields there are *) let adt_decl = - TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls + TypeDeclId.Map.find adt_id ctx.trans_ctx.type_ctx.type_decls in let fields = match adt_decl.kind with @@ -1649,8 +1703,8 @@ type simp_aggr_env = { else x ]} *) -let simplify_aggregates_unchanged_fields (ctx : trans_ctx) (def : fun_decl) : - fun_decl = +let simplify_aggregates_unchanged_fields (ctx : ctx) (def : fun_decl) : fun_decl + = let log = Logging.simplify_aggregates_unchanged_fields_log in let span = def.item_meta.span in (* Some helpers *) @@ -1861,8 +1915,7 @@ let simplify_aggregates_unchanged_fields (ctx : trans_ctx) (def : fun_decl) : those function bodies into independent definitions while removing occurrences of the {!Pure.Loop} node. 
*) -let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : - fun_decl * fun_decl list = +let decompose_loops (_ctx : ctx) (def : fun_decl) : fun_decl * fun_decl list = match def.body with | None -> (def, []) | Some body -> @@ -2007,10 +2060,14 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : *but* replace its span with the span of the loop *) let item_meta = { def.item_meta with span = loop.span } in + sanity_check __FILE__ __LINE__ (def.builtin_info = None) + def.item_meta.span; + let loop_def : fun_decl = { def_id = def.def_id; item_meta; + builtin_info = def.builtin_info; kind = def.kind; backend_attributes = def.backend_attributes; num_loops; @@ -2076,7 +2133,7 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl = function calls, and when translating end abstractions. Here, we can do something simpler, in one micro-pass. *) -let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = +let eliminate_box_functions (_ctx : ctx) (def : fun_decl) : fun_decl = (* The map visitor *) let obj = object @@ -2109,7 +2166,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = { def with body } (** Simplify the lambdas by applying beta-reduction *) -let apply_beta_reduction (_ctx : trans_ctx) (def : fun_decl) : fun_decl = +let apply_beta_reduction (_ctx : ctx) (def : fun_decl) : fun_decl = (* The map visitor *) let visitor = object (self) @@ -2187,7 +2244,7 @@ let apply_beta_reduction (_ctx : trans_ctx) (def : fun_decl) : fun_decl = Array.update a i x ]} *) -let simplify_array_slice_update (ctx : trans_ctx) (def : fun_decl) : fun_decl = +let simplify_array_slice_update (ctx : ctx) (def : fun_decl) : fun_decl = let span = def.item_meta.span in (* The difficulty is that the let-binding which uses the backward function @@ -2452,8 +2509,7 @@ let simplify_array_slice_update (ctx : trans_ctx) (def : fun_decl) : fun_decl = [decompose_nested_pats]: decompose the nested patterns *) let decompose_let_bindings (decompose_monadic : bool) - (decompose_nested_pats : bool) (_ctx : trans_ctx) (def : fun_decl) : - fun_decl = + (decompose_nested_pats : bool) (_ctx : ctx) (def : fun_decl) : fun_decl = match def.body with | None -> def | Some body -> @@ -2573,20 +2629,18 @@ let decompose_let_bindings (decompose_monadic : bool) See the explanations in {!val:Config.decompose_monadic_let_bindings} *) -let decompose_monadic_let_bindings (ctx : trans_ctx) (def : fun_decl) : fun_decl - = +let decompose_monadic_let_bindings (ctx : ctx) (def : fun_decl) : fun_decl = decompose_let_bindings true false ctx def (** Decompose the nested let patterns. See the explanations in {!val:Config.decompose_nested_let_patterns} *) -let decompose_nested_let_patterns (ctx : trans_ctx) (def : fun_decl) : fun_decl - = +let decompose_nested_let_patterns (ctx : ctx) (def : fun_decl) : fun_decl = decompose_let_bindings false true ctx def (** Unfold the monadic let-bindings to explicit matches. 
*) -let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = +let unfold_monadic_let_bindings (_ctx : ctx) (def : fun_decl) : fun_decl = match def.body with | None -> def | Some body -> @@ -2653,8 +2707,170 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = (* Return *) { def with body = Some body } -let end_passes : - (bool ref option * string * (trans_ctx -> fun_decl -> fun_decl)) list = +(** Perform the following transformation: + + {[ + let y <-- ok e + + ~~> + + let y <-- toResult e + ]} + + We only do this on a specific set of pure functions calls - those + functions are identified in the "builtin" information about external + function calls. + *) +let lift_pure_function_calls (ctx : ctx) (def : fun_decl) : fun_decl = + let span = def.item_meta.span in + + let try_lift_expr (super_visit_e : texpression -> texpression) + (visit_e : texpression -> texpression) (app : texpression) : + bool * texpression = + (* Check if the function should be lifted *) + let f, args = destruct_apps app in + let f = super_visit_e f in + let args = List.map visit_e args in + (* *) + let lift = + match f.e with + | Qualif { id = FunOrOp (Unop unop); _ } -> lift_unop unop + | Qualif { id = FunOrOp (Binop (binop, _)); _ } -> lift_binop binop + | Qualif { id = FunOrOp (Fun fun_id); _ } -> lift_fun ctx fun_id + | _ -> false + in + let app = mk_apps span f args in + if lift then (true, mk_to_result_texpression span app) else (false, app) + in + + (* The map visitor *) + let visitor = + object (self) + inherit [_] map_expression as super + + method! visit_texpression env e0 = + (* Check if this is an expression of the shape: [ok (f ...)] where + `f` has been identified as a function which should be lifted. *) + match destruct_apps e0 with + | ( ({ e = Qualif { id = FunOrOp (Fun (Pure ToResult)); _ }; _ } as + to_result_expr), + [ app ] ) -> + (* Attempt to lift the expression *) + let lifted, app = + try_lift_expr + (super#visit_texpression env) + (self#visit_texpression env) + app + in + + if lifted then app else mk_app span to_result_expr app + | { e = Let (monadic, pat, bound, next); ty }, [] -> + let next = self#visit_texpression env next in + (* Attempt to lift only if the let-expression is not already monadic *) + let lifted, bound = + if monadic then (true, self#visit_texpression env bound) + else + try_lift_expr + (super#visit_texpression env) + (self#visit_texpression env) + bound + in + { e = Let (lifted, pat, bound, next); ty } + | f, args -> + let f = super#visit_texpression env f in + let args = List.map (self#visit_texpression env) args in + mk_apps span f args + end + in + (* Update the body *) + match def.body with + | None -> def + | Some body -> + let body = + Some + { + body with + body = visitor#visit_texpression VarId.Map.empty body.body; + } + in + { def with body } + +(** Perform the following transformation: + + {[ + let y <-- f x (* Must be an application, is not necessarily monadic *) + let (a, b) := y (* Tuple decomposition *) + ... + ]} + + becomes: + + {[ + let (a, b) <-- f x + ... + ]} + *) +let merge_let_app_then_decompose_tuple (_ctx : ctx) (def : fun_decl) : fun_decl + = + let span = def.item_meta.span in + (* We may need to introduce fresh variables *) + let var_cnt = get_opt_body_min_var_counter def.body in + let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in + + let visitor = + object (self) + inherit [_] map_expression + + method! 
visit_Let env monadic0 pat0 bound0 next0 = + let bound0 = self#visit_texpression env bound0 in + (* Check if we need to merge two let-bindings *) + if is_pat_var pat0 then + let var0, _ = as_pat_var span pat0 in + match next0.e with + | Let (false, pat1, { e = Var var_id; _ }, next1) + when var_id = var0.id -> begin + (* Check if we are decomposing a tuple *) + if is_pat_tuple pat1 then + (* Introduce fresh variables for all the dummy variables + to make sure we can turn the pattern into an expression *) + let pat1 = typed_pattern_replace_dummy_vars fresh_var_id pat1 in + let pat1_expr = + Option.get (typed_pattern_to_texpression span pat1) + in + (* Register the mapping from the variable we remove to the expression *) + let env = VarId.Map.add var0.id pat1_expr env in + (* Continue *) + let next1 = self#visit_texpression env next1 in + Let (monadic0, pat1, bound0, next1) + else + let next0 = self#visit_texpression env next0 in + Let (monadic0, pat0, bound0, next0) + end + | _ -> + let next0 = self#visit_texpression env next0 in + Let (monadic0, pat0, bound0, next0) + else + let next0 = self#visit_texpression env next0 in + Let (monadic0, pat0, bound0, next0) + + (* Replace the variables *) + method! visit_Var env var_id = + match VarId.Map.find_opt var_id env with + | None -> Var var_id + | Some e -> e.e + end + in + + match def.body with + | None -> def + | Some body -> + let body = + { body with body = visitor#visit_texpression VarId.Map.empty body.body } + in + { def with body = Some body } + +let end_passes : (bool ref option * string * (ctx -> fun_decl -> fun_decl)) list + = [ (* Convert the unit variables to [()] if they are used as right-values or * [_] if they are used as left values. *) @@ -2682,6 +2898,10 @@ let end_passes : (None, "eliminate_box_functions", eliminate_box_functions); (* Remove the duplicated function calls *) (None, "simplify_duplicate_calls", simplify_duplicate_calls); + (* Merge let bindings which bind an expression then decompose a tuple *) + ( Some Config.merge_let_app_decompose_tuple, + "merge_let_app_then_decompose_tuple", + merge_let_app_then_decompose_tuple ); (* Filter the useless variables, assignments, function calls, etc. *) (None, "filter_useless", filter_useless); (* Simplify the lets immediately followed by a return. @@ -2744,10 +2964,15 @@ let end_passes : ( Some Config.unfold_monadic_let_bindings, "unfold_monadic_let_bindings", unfold_monadic_let_bindings ); + (* Introduce calls to [toResult] to lift pure function calls to monadic + function calls *) + ( Some Config.lift_pure_function_calls, + "lift_pure_function_calls", + lift_pure_function_calls ); ] (** Auxiliary function for {!apply_passes_to_def} *) -let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = +let apply_end_passes_to_def (ctx : ctx) (def : fun_decl) : fun_decl = List.fold_left (fun def (option, pass_name, pass) -> let apply = @@ -2791,7 +3016,7 @@ end module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) (** Filter the useless loop input parameters. *) -let filter_loop_inputs (ctx : trans_ctx) (transl : pure_fun_translation list) : +let filter_loop_inputs (ctx : ctx) (transl : pure_fun_translation list) : pure_fun_translation list = (* We need to explore groups of mutually recursive functions. 
In order to compute which parameters are useless, we need to explore the @@ -3113,18 +3338,13 @@ let filter_loop_inputs (ctx : trans_ctx) (transl : pure_fun_translation list) : the function as reducible, we allow tactics like [simp] or [progress] to see through the definition. *) -let compute_reducible (_ctx : trans_ctx) (transl : pure_fun_translation list) : +let compute_reducible (_ctx : ctx) (transl : pure_fun_translation list) : pure_fun_translation list = let update_one (trans : pure_fun_translation) : pure_fun_translation = match trans.f.body with | None -> trans | Some body -> ( - (* Check if the body is exactly a call to a loop function. - Note that we check that the arguments are exactly the input - variables - otherwise we may not want the call to be reducible; - for instance when using the [progress] tactic we might want to - use a more specialized specification theorem. *) - let app, args = destruct_apps body.body in + let app, _ = destruct_apps body.body in match app.e with | Qualif { @@ -3132,20 +3352,10 @@ let compute_reducible (_ctx : trans_ctx) (transl : pure_fun_translation list) : generics = _; } when fid = FRegular trans.f.def_id -> - if - List.length body.inputs = List.length args - && List.for_all - (fun ((var, arg) : var * texpression) -> - match arg.e with - | Var var_id -> var_id = var.id - | _ -> false) - (List.combine body.inputs args) - then - let f = - { trans.f with backend_attributes = { reducible = true } } - in - { trans with f } - else trans + let f = + { trans.f with backend_attributes = { reducible = true } } + in + { trans with f } | _ -> trans) in List.map update_one transl @@ -3159,7 +3369,7 @@ let compute_reducible (_ctx : trans_ctx) (transl : pure_fun_translation list) : [ctx]: used only for printing. *) -let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_and_loops = +let apply_passes_to_def (ctx : ctx) (def : fun_decl) : fun_and_loops = (* Debug *) log#ltrace (lazy ("PureMicroPasses.apply_passes_to_def: " ^ def.name)); @@ -3202,8 +3412,14 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_and_loops = functions. Note that here, keeping the forward function it is not *necessary* but convenient. 
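    For example, once all the passes have run, a function whose body is a single
    call to its loop function gets marked as reducible by [compute_reducible]
    above, which shows up as an [@[reducible]] attribute in the generated Lean
    code (see the change to [sum] in the [Arrays.lean] test diff below):

    {[
      @[reducible] def sum (s : Slice U32) : Result U32 :=
        sum_loop s 0#u32 0#usize
    ]}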
*) -let apply_passes_to_pure_fun_translations (ctx : trans_ctx) +let apply_passes_to_pure_fun_translations (trans_ctx : trans_ctx) (transl : fun_decl list) : pure_fun_translation list = + let fun_decls = + FunDeclId.Map.of_list + (List.map (fun (f : fun_decl) -> (f.def_id, f)) transl) + in + let ctx = { trans_ctx; fun_decls } in + (* Apply the micro-passes *) let transl = List.map (apply_passes_to_def ctx) transl in diff --git a/src/pure/PureUtils.ml b/src/pure/PureUtils.ml index ac2a515e..e4d70315 100644 --- a/src/pure/PureUtils.ml +++ b/src/pure/PureUtils.ml @@ -295,10 +295,45 @@ let is_cvar (e : texpression) : bool = | CVar _ -> true | _ -> false -let as_pat_var (span : Meta.span) (p : typed_pattern) : var * mplace option = +let as_opt_pat_var (p : typed_pattern) : (var * mplace option) option = match p.value with - | PatVar (v, mp) -> (v, mp) - | _ -> craise __FILE__ __LINE__ span "Not a var" + | PatVar (v, mp) -> Some (v, mp) + | _ -> None + +let as_pat_var (span : Meta.span) (p : typed_pattern) : var * mplace option = + match as_opt_pat_var p with + | None -> craise __FILE__ __LINE__ span "Not a var" + | Some (v, mp) -> (v, mp) + +let is_pat_var (p : typed_pattern) : bool = Option.is_some (as_opt_pat_var p) + +let as_opt_pat_tuple (p : typed_pattern) : typed_pattern list option = + match p with + | { + value = PatAdt { variant_id = None; field_values }; + ty = TAdt (TTuple, _); + } -> Some field_values + | _ -> None + +(** Replace all the dummy variables in a pattern with fresh variables *) +let typed_pattern_replace_dummy_vars (fresh_var_id : unit -> VarId.id) + (p : typed_pattern) : typed_pattern = + let visitor = + object + inherit [_] map_typed_pattern as super + + method! visit_typed_pattern env p = + match p.value with + | PatDummy -> + let id = fresh_var_id () in + { p with value = PatVar ({ id; basename = None; ty = p.ty }, None) } + | _ -> super#visit_typed_pattern env p + end + in + visitor#visit_typed_pattern () p + +let is_pat_tuple (p : typed_pattern) : bool = + Option.is_some (as_opt_pat_tuple p) let is_global (e : texpression) : bool = match e.e with @@ -695,11 +730,12 @@ let unwrap_result_ty (span : Meta.span) (ty : ty) : ty = let mk_result_fail_texpression (span : Meta.span) (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in - let ty = TAdt (TBuiltin TResult, mk_generic_args_from_types type_args) in + let generics = mk_generic_args_from_types type_args in + let ty = TAdt (TBuiltin TResult, generics) in let id = AdtCons { adt_id = TBuiltin TResult; variant_id = Some result_fail_id } in - let qualif = { id; generics = mk_generic_args_from_types type_args } in + let qualif = { id; generics } in let cons_e = Qualif qualif in let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -713,11 +749,12 @@ let mk_result_fail_texpression_with_error_id (span : Meta.span) let mk_result_ok_texpression (span : Meta.span) (v : texpression) : texpression = let type_args = [ v.ty ] in - let ty = TAdt (TBuiltin TResult, mk_generic_args_from_types type_args) in + let generics = mk_generic_args_from_types type_args in + let ty = TAdt (TBuiltin TResult, generics) in let id = AdtCons { adt_id = TBuiltin TResult; variant_id = Some result_ok_id } in - let qualif = { id; generics = mk_generic_args_from_types type_args } in + let qualif = { id; generics } in let cons_e = Qualif qualif in let cons_ty = mk_arrow v.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -799,6 +836,7 @@ let trait_decl_is_empty (trait_decl : trait_decl) : bool = 
def_id = _; name = _; item_meta = _; + builtin_info = _; generics = _; explicit_info = _; llbc_generics = _; @@ -818,6 +856,7 @@ let trait_impl_is_empty (trait_impl : trait_impl) : bool = def_id = _; name = _; item_meta = _; + builtin_info = _; impl_trait = _; llbc_impl_trait = _; generics = _; @@ -929,3 +968,15 @@ let typed_pattern_get_vars (pat : typed_pattern) : VarId.Set.t = in visitor#visit_typed_pattern () pat; !vars + +let mk_to_result_texpression (span : Meta.span) (e : texpression) : texpression + = + let type_args = [ e.ty ] in + let generics = mk_generic_args_from_types type_args in + let ty = TAdt (TBuiltin TResult, generics) in + let id = FunOrOp (Fun (Pure ToResult)) in + let qualif = { id; generics } in + let qualif = Qualif qualif in + let qualif_ty = mk_arrow e.ty ty in + let qualif = { e = qualif; ty = qualif_ty } in + mk_app span qualif e diff --git a/src/symbolic/SymbolicToPure.ml b/src/symbolic/SymbolicToPure.ml index bcd83ae3..ec291333 100644 --- a/src/symbolic/SymbolicToPure.ml +++ b/src/symbolic/SymbolicToPure.ml @@ -15,6 +15,11 @@ module S = SymbolicAst (** The local logger *) let log = Logging.symbolic_to_pure_log +let match_name_find_opt = TranslateCore.match_name_find_opt + +let match_name_with_generics_find_opt = + TranslateCore.match_name_with_generics_find_opt + type type_ctx = { llbc_type_decls : T.type_decl TypeDeclId.Map.t; type_decls : type_decl TypeDeclId.Map.t; @@ -722,10 +727,16 @@ let translate_type_decl (ctx : Contexts.decls_ctx) (def : T.type_decl) : let explicit_info = compute_explicit_info generics [] in let kind = translate_type_decl_kind span def.T.kind in let item_meta = def.item_meta in + (* Lookup the builtin information, if there is *) + let builtin_info = + match_name_find_opt ctx def.item_meta.name + (ExtractBuiltin.builtin_types_map ()) + in { def_id; name; item_meta; + builtin_info; generics; explicit_info; llbc_generics = def.generics; @@ -3092,7 +3103,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* Note that cast can fail *) let effect_info = { - can_fail = true; + can_fail = not (Config.backend () = Lean); stateful_group = false; stateful = false; can_diverge = false; @@ -4728,10 +4739,16 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Assemble the declaration *) let backend_attributes = { reducible = false } in + (* Check if the function is builtin *) + let builtin_info = + let funs_map = ExtractBuiltin.builtin_funs_map () in + match_name_find_opt ctx.decls_ctx def.item_meta.name funs_map + in let def : fun_decl = { def_id; item_meta = def.item_meta; + builtin_info; kind = def.kind; backend_attributes; num_loops; @@ -4815,10 +4832,16 @@ let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl) (name, translate_trait_method span translate_ty bound_fn)) methods in + (* Lookup the builtin information, if there is *) + let builtin_info = + match_name_find_opt ctx trait_decl.item_meta.name + (ExtractBuiltin.builtin_trait_decls_map ()) + in { def_id; name; item_meta; + builtin_info; generics; explicit_info; llbc_generics; @@ -4873,10 +4896,19 @@ let translate_trait_impl (ctx : Contexts.decls_ctx) (trait_impl : A.trait_impl) (name, translate_trait_method span translate_ty bound_fn)) methods in + (* Lookup the builtin information, if there is *) + let builtin_info = + let decl_id = trait_impl.impl_trait.trait_decl_id in + let trait_decl = TraitDeclId.Map.find decl_id ctx.crate.trait_decls in + match_name_with_generics_find_opt ctx 
trait_decl.item_meta.name + llbc_impl_trait.decl_generics + (ExtractBuiltin.builtin_trait_impls_map ()) + in { def_id; name; item_meta; + builtin_info; impl_trait; llbc_impl_trait; generics; @@ -4913,10 +4945,14 @@ let translate_global (ctx : Contexts.decls_ctx) (decl : A.global_decl) : let ty = translate_fwd_ty (Some decl.item_meta.span) ctx.type_ctx.type_infos ty in + let builtin_info = + match_name_find_opt ctx item_meta.name ExtractBuiltin.builtin_globals_map + in { span = item_meta.span; def_id; item_meta; + builtin_info; name; llbc_generics; generics; diff --git a/tests/lean/AdtBorrows.lean b/tests/lean/AdtBorrows.lean index e306521b..19dcc2c3 100644 --- a/tests/lean/AdtBorrows.lean +++ b/tests/lean/AdtBorrows.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [adt_borrows] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -15,12 +15,12 @@ namespace adt_borrows /- [adt_borrows::{adt_borrows::SharedWrapper<'a, T>}::create]: Source: 'tests/src/adt-borrows.rs', lines 10:4-12:5 -/ def SharedWrapper.create {T : Type} (x : T) : Result (SharedWrapper T) := - Result.ok x + ok x /- [adt_borrows::{adt_borrows::SharedWrapper<'a, T>}::unwrap]: Source: 'tests/src/adt-borrows.rs', lines 14:4-16:5 -/ def SharedWrapper.unwrap {T : Type} (self : SharedWrapper T) : Result T := - Result.ok self + ok self /- [adt_borrows::use_shared_wrapper]: Source: 'tests/src/adt-borrows.rs', lines 19:0-24:1 -/ @@ -38,12 +38,12 @@ structure SharedWrapper1 (T : Type) where /- [adt_borrows::{adt_borrows::SharedWrapper1<'a, T>}#1::create]: Source: 'tests/src/adt-borrows.rs', lines 31:4-33:5 -/ def SharedWrapper1.create {T : Type} (x : T) : Result (SharedWrapper1 T) := - Result.ok { x } + ok { x } /- [adt_borrows::{adt_borrows::SharedWrapper1<'a, T>}#1::unwrap]: Source: 'tests/src/adt-borrows.rs', lines 35:4-37:5 -/ def SharedWrapper1.unwrap {T : Type} (self : SharedWrapper1 T) : Result T := - Result.ok self.x + ok self.x /- [adt_borrows::use_shared_wrapper1]: Source: 'tests/src/adt-borrows.rs', lines 40:0-45:1 -/ @@ -63,21 +63,20 @@ structure SharedWrapper2 (T : Type) where Source: 'tests/src/adt-borrows.rs', lines 53:4-55:5 -/ def SharedWrapper2.create {T : Type} (x : T) (y : T) : Result (SharedWrapper2 T) := - Result.ok { x, y } + ok { x, y } /- [adt_borrows::{adt_borrows::SharedWrapper2<'a, 'b, T>}#2::unwrap]: Source: 'tests/src/adt-borrows.rs', lines 57:4-59:5 -/ def SharedWrapper2.unwrap {T : Type} (self : SharedWrapper2 T) : Result (T × T) := - Result.ok (self.x, self.y) + ok (self.x, self.y) /- [adt_borrows::use_shared_wrapper2]: Source: 'tests/src/adt-borrows.rs', lines 62:0-69:1 -/ def use_shared_wrapper2 : Result Unit := do let w ← SharedWrapper2.create 0#i32 1#i32 - let p ← SharedWrapper2.unwrap w - let (px, py) := p + let (px, py) ← SharedWrapper2.unwrap w massert (0#i32 = px) massert (1#i32 = py) @@ -89,14 +88,14 @@ def use_shared_wrapper2 : Result Unit := Source: 'tests/src/adt-borrows.rs', lines 74:4-76:5 -/ def MutWrapper.create {T : Type} (x : T) : Result ((MutWrapper T) × (MutWrapper T → T)) := - Result.ok (x, fun ret => ret) + ok (x, fun ret => ret) /- [adt_borrows::{adt_borrows::MutWrapper<'a, T>}#3::unwrap]: Source: 'tests/src/adt-borrows.rs', lines 78:4-80:5 -/ def MutWrapper.unwrap {T : Type} (self : MutWrapper T) : Result (T × (T → MutWrapper T)) := let back := fun ret => ret - Result.ok (self, back) + ok (self, back) /- 
[adt_borrows::{adt_borrows::MutWrapper<'a, T>}#3::id]: Source: 'tests/src/adt-borrows.rs', lines 82:4-84:5 -/ @@ -105,7 +104,7 @@ def MutWrapper.id Result ((MutWrapper T) × (MutWrapper T → MutWrapper T)) := let back := fun ret => ret - Result.ok (self, back) + ok (self, back) /- [adt_borrows::use_mut_wrapper]: Source: 'tests/src/adt-borrows.rs', lines 87:0-93:1 -/ @@ -126,7 +125,7 @@ def use_mut_wrapper_id do let (mw, id_back) ← MutWrapper.id x let back := fun ret => id_back ret - Result.ok (mw, back) + ok (mw, back) /- [adt_borrows::MutWrapper1] Source: 'tests/src/adt-borrows.rs', lines 99:0-101:1 -/ @@ -138,14 +137,14 @@ structure MutWrapper1 (T : Type) where def MutWrapper1.create {T : Type} (x : T) : Result ((MutWrapper1 T) × (MutWrapper1 T → T)) := let back := fun ret => ret.x - Result.ok ({ x }, back) + ok ({ x }, back) /- [adt_borrows::{adt_borrows::MutWrapper1<'a, T>}#4::unwrap]: Source: 'tests/src/adt-borrows.rs', lines 108:4-110:5 -/ def MutWrapper1.unwrap {T : Type} (self : MutWrapper1 T) : Result (T × (T → MutWrapper1 T)) := let back := fun ret => { x := ret } - Result.ok (self.x, back) + ok (self.x, back) /- [adt_borrows::{adt_borrows::MutWrapper1<'a, T>}#4::id]: Source: 'tests/src/adt-borrows.rs', lines 112:4-114:5 -/ @@ -153,7 +152,7 @@ def MutWrapper1.id {T : Type} (self : MutWrapper1 T) : Result ((MutWrapper1 T) × (MutWrapper1 T → MutWrapper1 T)) := - Result.ok (self, fun ret => ret) + ok (self, fun ret => ret) /- [adt_borrows::use_mut_wrapper1]: Source: 'tests/src/adt-borrows.rs', lines 117:0-123:1 -/ @@ -187,7 +186,7 @@ def MutWrapper2.create := let back'a := fun ret => ret.x let back'b := fun ret => ret.y - Result.ok ({ x, y }, back'a, back'b) + ok ({ x, y }, back'a, back'b) /- [adt_borrows::{adt_borrows::MutWrapper2<'a, 'b, T>}#5::unwrap]: Source: 'tests/src/adt-borrows.rs', lines 139:4-141:5 -/ @@ -197,7 +196,7 @@ def MutWrapper2.unwrap := let back'a := fun ret => { self with x := ret } let back'b := fun ret => { self with y := ret } - Result.ok ((self.x, self.y), back'a, back'b) + ok ((self.x, self.y), back'a, back'b) /- [adt_borrows::{adt_borrows::MutWrapper2<'a, 'b, T>}#5::id]: Source: 'tests/src/adt-borrows.rs', lines 143:4-145:5 -/ @@ -208,7 +207,7 @@ def MutWrapper2.id := let back'a := fun ret => { self with x := ret.x } let back'b := fun ret => { self with y := ret.y } - Result.ok (self, back'a, back'b) + ok (self, back'a, back'b) /- [adt_borrows::use_mut_wrapper2]: Source: 'tests/src/adt-borrows.rs', lines 148:0-157:1 -/ @@ -235,12 +234,12 @@ def use_mut_wrapper2_id let (mw, id_back, id_back1) ← MutWrapper2.id x let back'a := fun ret => { x with x := (id_back { mw with x := ret.x }).x } let back'b := fun ret => { x with y := (id_back1 { mw with y := ret.y }).y } - Result.ok (mw, back'a, back'b) + ok (mw, back'a, back'b) /- [adt_borrows::array_shared_borrow]: Source: 'tests/src/adt-borrows.rs', lines 170:0-172:1 -/ def array_shared_borrow {N : Usize} (x : Array U32 N) : Result (Array U32 N) := - Result.ok x + ok x /- [adt_borrows::array_mut_borrow]: Source: 'tests/src/adt-borrows.rs', lines 174:0-176:1 -/ @@ -248,7 +247,7 @@ def array_mut_borrow {N : Usize} (x : Array U32 N) : Result ((Array U32 N) × (Array U32 N → Array U32 N)) := - Result.ok (x, fun ret => ret) + ok (x, fun ret => ret) /- [adt_borrows::use_array_mut_borrow1]: Source: 'tests/src/adt-borrows.rs', lines 178:0-180:1 -/ @@ -270,18 +269,18 @@ def use_array_mut_borrow2 let back := fun ret => let x2 := array_mut_borrow_back1 ret array_mut_borrow_back x2 - Result.ok (a, back) + ok (a, back) /- 
[adt_borrows::boxed_slice_shared_borrow]: Source: 'tests/src/adt-borrows.rs', lines 187:0-189:1 -/ def boxed_slice_shared_borrow (x : Slice U32) : Result (Slice U32) := - Result.ok x + ok x /- [adt_borrows::boxed_slice_mut_borrow]: Source: 'tests/src/adt-borrows.rs', lines 191:0-193:1 -/ def boxed_slice_mut_borrow (x : Slice U32) : Result ((Slice U32) × (Slice U32 → Slice U32)) := - Result.ok (x, fun ret => ret) + ok (x, fun ret => ret) /- [adt_borrows::use_boxed_slice_mut_borrow1]: Source: 'tests/src/adt-borrows.rs', lines 195:0-197:1 -/ @@ -300,7 +299,7 @@ def use_boxed_slice_mut_borrow2 fun ret => let s1 := boxed_slice_mut_borrow_back1 ret boxed_slice_mut_borrow_back s1 - Result.ok (s, back) + ok (s, back) /- [adt_borrows::SharedList] Source: 'tests/src/adt-borrows.rs', lines 207:0-210:1 -/ @@ -312,15 +311,15 @@ inductive SharedList (T : Type) where Source: 'tests/src/adt-borrows.rs', lines 214:4-216:5 -/ def SharedList.push {T : Type} (self : SharedList T) (x : T) : Result (SharedList T) := - Result.ok (SharedList.Cons x self) + ok (SharedList.Cons x self) /- [adt_borrows::{adt_borrows::SharedList<'a, T>}#6::pop]: Source: 'tests/src/adt-borrows.rs', lines 218:4-224:5 -/ def SharedList.pop {T : Type} (self : SharedList T) : Result (T × (SharedList T)) := match self with - | SharedList.Nil => Result.fail .panic - | SharedList.Cons hd tl => Result.ok (hd, tl) + | SharedList.Nil => fail panic + | SharedList.Cons hd tl => ok (hd, tl) /- [adt_borrows::MutList] Source: 'tests/src/adt-borrows.rs', lines 227:0-230:1 -/ @@ -341,7 +340,7 @@ def MutList.push | MutList.Cons t ml1 => (t, ml1) | _ => (x, self) (ml, x1) - Result.ok (MutList.Cons x self, back) + ok (MutList.Cons x self, back) /- [adt_borrows::{adt_borrows::MutList<'a, T>}#7::pop]: Source: 'tests/src/adt-borrows.rs', lines 238:4-244:5 -/ @@ -350,16 +349,16 @@ def MutList.pop Result ((T × (MutList T)) × ((T × (MutList T)) → MutList T)) := match self with - | MutList.Nil => Result.fail .panic + | MutList.Nil => fail panic | MutList.Cons hd tl => let back := fun ret => let (t, ml) := ret MutList.Cons t ml - Result.ok ((hd, tl), back) + ok ((hd, tl), back) /- [adt_borrows::wrap_shared_in_option]: Source: 'tests/src/adt-borrows.rs', lines 247:0-249:1 -/ def wrap_shared_in_option {T : Type} (x : T) : Result (Option T) := - Result.ok (some x) + ok (some x) /- [adt_borrows::wrap_mut_in_option]: Source: 'tests/src/adt-borrows.rs', lines 251:0-253:1 -/ @@ -368,7 +367,7 @@ def wrap_mut_in_option let back := fun ret => match ret with | some t => t | _ => x - Result.ok (some x, back) + ok (some x, back) /- [adt_borrows::List] Source: 'tests/src/adt-borrows.rs', lines 255:0-258:1 -/ @@ -383,11 +382,11 @@ divergent def nth_shared_loop match ls with | List.Cons x tl => if i = 0#u32 - then Result.ok (some x) + then ok (some x) else do let i1 ← i - 1#u32 nth_shared_loop tl i1 - | List.Nil => Result.ok none + | List.Nil => ok none /- [adt_borrows::nth_shared]: Source: 'tests/src/adt-borrows.rs', lines 260:0-270:1 -/ @@ -411,16 +410,16 @@ divergent def nth_mut_loop | some t1 => t1 | _ => x List.Cons t tl - Result.ok (some x, back) + ok (some x, back) else do let i1 ← i - 1#u32 let (o, back) ← nth_mut_loop tl i1 let back1 := fun ret => let tl1 := back ret List.Cons x tl1 - Result.ok (o, back1) + ok (o, back1) | List.Nil => let back := fun ret => List.Nil - Result.ok (none, back) + ok (none, back) /- [adt_borrows::nth_mut]: Source: 'tests/src/adt-borrows.rs', lines 272:0-282:1 -/ @@ -437,7 +436,7 @@ def update_array_mut_borrow (a : Array U32 32#usize) : 
Result ((Array U32 32#usize) × (Array U32 32#usize → Array U32 32#usize)) := - Result.ok (a, fun ret => ret) + ok (a, fun ret => ret) /- [adt_borrows::array_mut_borrow_loop1]: loop 0: Source: 'tests/src/adt-borrows.rs', lines 289:4-291:5 -/ @@ -448,8 +447,8 @@ divergent def array_mut_borrow_loop1_loop do let (a1, update_array_mut_borrow_back) ← update_array_mut_borrow a let a2 ← array_mut_borrow_loop1_loop true a1 - Result.ok (update_array_mut_borrow_back a2) - else Result.ok a + ok (update_array_mut_borrow_back a2) + else ok a /- [adt_borrows::array_mut_borrow_loop1]: Source: 'tests/src/adt-borrows.rs', lines 288:0-292:1 -/ @@ -471,8 +470,8 @@ divergent def array_mut_borrow_loop2_loop let (a2, back) ← array_mut_borrow_loop2_loop true a1 let back1 := fun ret => let a3 := back ret update_array_mut_borrow_back a3 - Result.ok (a2, back1) - else Result.ok (a, fun ret => ret) + ok (a2, back1) + else ok (a, fun ret => ret) /- [adt_borrows::array_mut_borrow_loop2]: Source: 'tests/src/adt-borrows.rs', lines 294:0-299:1 -/ @@ -486,7 +485,7 @@ def array_mut_borrow_loop2 /- [adt_borrows::copy_shared_array]: Source: 'tests/src/adt-borrows.rs', lines 301:0-303:1 -/ def copy_shared_array (a : Array U32 32#usize) : Result (Array U32 32#usize) := - Result.ok a + ok a /- [adt_borrows::array_shared_borrow_loop1]: loop 0: Source: 'tests/src/adt-borrows.rs', lines 306:4-308:5 -/ @@ -496,7 +495,7 @@ divergent def array_shared_borrow_loop1_loop then do let a1 ← copy_shared_array a array_shared_borrow_loop1_loop true a1 - else Result.ok () + else ok () /- [adt_borrows::array_shared_borrow_loop1]: Source: 'tests/src/adt-borrows.rs', lines 305:0-309:1 -/ @@ -513,7 +512,7 @@ divergent def array_shared_borrow_loop2_loop then do let a1 ← copy_shared_array a array_shared_borrow_loop2_loop true a1 - else Result.ok a + else ok a /- [adt_borrows::array_shared_borrow_loop2]: Source: 'tests/src/adt-borrows.rs', lines 311:0-316:1 -/ diff --git a/tests/lean/Arrays.lean b/tests/lean/Arrays.lean index b98c3b89..26fd0246 100644 --- a/tests/lean/Arrays.lean +++ b/tests/lean/Arrays.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [arrays] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -38,19 +38,19 @@ def array_to_mut_slice_ def array_len {T : Type} (s : Array T 32#usize) : Result Usize := do let s1 ← Array.to_slice s - Result.ok (Slice.len s1) + ok (Slice.len s1) /- [arrays::shared_array_len]: Source: 'tests/src/arrays.rs', lines 32:0-34:1 -/ def shared_array_len {T : Type} (s : Array T 32#usize) : Result Usize := do let s1 ← Array.to_slice s - Result.ok (Slice.len s1) + ok (Slice.len s1) /- [arrays::shared_slice_len]: Source: 'tests/src/arrays.rs', lines 36:0-38:1 -/ def shared_slice_len {T : Type} (s : Slice T) : Result Usize := - Result.ok (Slice.len s) + ok (Slice.len s) /- [arrays::index_array_shared]: Source: 'tests/src/arrays.rs', lines 40:0-42:1 -/ @@ -163,13 +163,13 @@ def update_update_array := do let (a, index_mut_back) ← Array.index_mut_usize s i - let a1 ← Array.update_usize a j 0#u32 - Result.ok (index_mut_back a1) + let a1 ← Array.update a j 0#u32 + ok (index_mut_back a1) /- [arrays::array_local_deep_copy]: Source: 'tests/src/arrays.rs', lines 121:0-123:1 -/ def array_local_deep_copy (x : Array U32 32#usize) : Result Unit := - Result.ok () + ok () /- [arrays::array_update1]: Source: 'tests/src/arrays.rs', lines 126:0-128:1 -/ @@ -177,58 +177,58 @@ def 
array_update1 (a : Slice U32) (i : Usize) (x : U32) : Result (Slice U32) := do let i1 ← i + 1#usize let i2 ← x + 1#u32 - Slice.update_usize a i1 i2 + Slice.update a i1 i2 /- [arrays::array_update2]: Source: 'tests/src/arrays.rs', lines 131:0-134:1 -/ def array_update2 (a : Slice U32) (i : Usize) (x : U32) : Result (Slice U32) := do let i1 ← x + 1#u32 - let a1 ← Slice.update_usize a i i1 + let a1 ← Slice.update a i i1 let i2 ← i + 1#usize - Slice.update_usize a1 i2 i1 + Slice.update a1 i2 i1 /- [arrays::array_update3]: Source: 'tests/src/arrays.rs', lines 136:0-140:1 -/ def array_update3 (a : Slice U32) (i : Usize) (x : U32) : Result (Slice U32) := do - let a1 ← Slice.update_usize a i x + let a1 ← Slice.update a i x let i1 ← i + 1#usize - let a2 ← Slice.update_usize a1 i1 x + let a2 ← Slice.update a1 i1 x let i2 ← i + 2#usize - Slice.update_usize a2 i2 x + Slice.update a2 i2 x /- [arrays::take_array]: Source: 'tests/src/arrays.rs', lines 142:0-142:33 -/ def take_array (a : Array U32 2#usize) : Result Unit := - Result.ok () + ok () /- [arrays::take_array_borrow]: Source: 'tests/src/arrays.rs', lines 143:0-143:41 -/ def take_array_borrow (a : Array U32 2#usize) : Result Unit := - Result.ok () + ok () /- [arrays::take_slice]: Source: 'tests/src/arrays.rs', lines 144:0-144:31 -/ def take_slice (s : Slice U32) : Result Unit := - Result.ok () + ok () /- [arrays::take_mut_slice]: Source: 'tests/src/arrays.rs', lines 145:0-145:39 -/ def take_mut_slice (s : Slice U32) : Result (Slice U32) := - Result.ok s + ok s /- [arrays::const_array]: Source: 'tests/src/arrays.rs', lines 147:0-149:1 -/ def const_array : Result (Array U32 2#usize) := - Result.ok (Array.make 2#usize [ 0#u32, 0#u32 ]) + ok (Array.make 2#usize [ 0#u32, 0#u32 ]) /- [arrays::const_slice]: Source: 'tests/src/arrays.rs', lines 151:0-153:1 -/ def const_slice : Result Unit := do let _ ← Array.to_slice (Array.make 2#usize [ 0#u32, 0#u32 ]) - Result.ok () + ok () /- [arrays::take_all]: Source: 'tests/src/arrays.rs', lines 161:0-173:1 -/ @@ -241,7 +241,7 @@ def take_all : Result Unit := take_slice s let (s1, _) ← Array.to_slice_mut (Array.make 2#usize [ 0#u32, 0#u32 ]) let _ ← take_mut_slice s1 - Result.ok () + ok () /- [arrays::index_array]: Source: 'tests/src/arrays.rs', lines 175:0-177:1 -/ @@ -263,7 +263,7 @@ def index_slice_u32_0 (x : Slice U32) : Result U32 := def index_mut_slice_u32_0 (x : Slice U32) : Result (U32 × (Slice U32)) := do let i ← Slice.index_usize x 0#usize - Result.ok (i, x) + ok (i, x) /- [arrays::index_all]: Source: 'tests/src/arrays.rs', lines 190:0-202:1 -/ @@ -285,18 +285,18 @@ def index_all : Result U32 := def update_array (x : Array U32 2#usize) : Result Unit := do let _ ← Array.index_mut_usize x 0#usize - Result.ok () + ok () /- [arrays::update_array_mut_borrow]: Source: 'tests/src/arrays.rs', lines 207:0-209:1 -/ def update_array_mut_borrow (x : Array U32 2#usize) : Result (Array U32 2#usize) := - Array.update_usize x 0#usize 1#u32 + Array.update x 0#usize 1#u32 /- [arrays::update_mut_slice]: Source: 'tests/src/arrays.rs', lines 210:0-212:1 -/ def update_mut_slice (x : Slice U32) : Result (Slice U32) := - Slice.update_usize x 0#usize 1#u32 + Slice.update x 0#usize 1#u32 /- [arrays::update_all]: Source: 'tests/src/arrays.rs', lines 214:0-220:1 -/ @@ -307,7 +307,7 @@ def update_all : Result Unit := let x ← update_array_mut_borrow (Array.make 2#usize [ 0#u32, 0#u32 ]) let (s, _) ← Array.to_slice_mut x let _ ← update_mut_slice s - Result.ok () + ok () /- [arrays::incr_array]: Source: 'tests/src/arrays.rs', lines 
222:0-224:1 -/ @@ -315,7 +315,7 @@ def incr_array (x : Array U32 2#usize) : Result (Array U32 2#usize) := do let i ← Array.index_usize x 0#usize let i1 ← i + 1#u32 - Array.update_usize x 0#usize i1 + Array.update x 0#usize i1 /- [arrays::incr_slice]: Source: 'tests/src/arrays.rs', lines 226:0-228:1 -/ @@ -323,7 +323,7 @@ def incr_slice (x : Slice U32) : Result (Slice U32) := do let i ← Slice.index_usize x 0#usize let i1 ← i + 1#u32 - Slice.update_usize x 0#usize i1 + Slice.update x 0#usize i1 /- [arrays::range_all]: Source: 'tests/src/arrays.rs', lines 233:0-237:1 -/ @@ -335,7 +335,7 @@ def range_all : Result Unit := (Array.make 4#usize [ 0#u32, 0#u32, 0#u32, 0#u32 ]) { start := 1#usize, end_ := 3#usize } let _ ← update_mut_slice s - Result.ok () + ok () /- [arrays::deref_array_borrow]: Source: 'tests/src/arrays.rs', lines 242:0-245:1 -/ @@ -348,12 +348,12 @@ def deref_array_mut_borrow (x : Array U32 2#usize) : Result (U32 × (Array U32 2#usize)) := do let i ← Array.index_usize x 0#usize - Result.ok (i, x) + ok (i, x) /- [arrays::take_array_t]: Source: 'tests/src/arrays.rs', lines 255:0-255:34 -/ def take_array_t (a : Array AB 2#usize) : Result Unit := - Result.ok () + ok () /- [arrays::non_copyable_array]: Source: 'tests/src/arrays.rs', lines 257:0-265:1 -/ @@ -371,12 +371,12 @@ divergent def sum_loop (s : Slice U32) (sum1 : U32) (i : Usize) : Result U32 := let sum3 ← sum1 + i2 let i3 ← i + 1#usize sum_loop s sum3 i3 - else Result.ok sum1 + else ok sum1 /- [arrays::sum]: Source: 'tests/src/arrays.rs', lines 270:0-278:1 -/ -def sum (s : Slice U32) : Result U32 := - sum_loop s 0#u32 0#usize +@[reducible] def sum (s : Slice U32) : Result U32 := + sum_loop s 0#u32 0#usize /- [arrays::sum2]: loop 0: Source: 'tests/src/arrays.rs', lines 284:4-287:5 -/ @@ -392,7 +392,7 @@ divergent def sum2_loop let sum3 ← sum1 + i4 let i5 ← i + 1#usize sum2_loop s s2 sum3 i5 - else Result.ok sum1 + else ok sum1 /- [arrays::sum2]: Source: 'tests/src/arrays.rs', lines 280:0-289:1 -/ @@ -409,19 +409,19 @@ def f0 : Result Unit := do let (s, _) ← Array.to_slice_mut (Array.make 2#usize [ 1#u32, 2#u32 ]) let _ ← Slice.index_mut_usize s 0#usize - Result.ok () + ok () /- [arrays::f1]: Source: 'tests/src/arrays.rs', lines 296:0-299:1 -/ def f1 : Result Unit := do let _ ← Array.index_mut_usize (Array.make 2#usize [ 1#u32, 2#u32 ]) 0#usize - Result.ok () + ok () /- [arrays::f2]: Source: 'tests/src/arrays.rs', lines 301:0-301:20 -/ def f2 (i : U32) : Result Unit := - Result.ok () + ok () /- [arrays::f4]: Source: 'tests/src/arrays.rs', lines 310:0-312:1 -/ @@ -443,7 +443,7 @@ def f3 : Result U32 := /- [arrays::SZ] Source: 'tests/src/arrays.rs', lines 314:0-314:25 -/ -def SZ_body : Result Usize := Result.ok 32#usize +def SZ_body : Result Usize := ok 32#usize def SZ : Usize := eval_global SZ_body /- [arrays::f5]: @@ -458,7 +458,7 @@ def ite : Result Unit := let (s, _) ← Array.to_slice_mut (Array.make 2#usize [ 0#u32, 0#u32 ]) let _ ← index_mut_slice_u32_0 s let _ ← index_mut_slice_u32_0 s - Result.ok () + ok () /- [arrays::zero_slice]: loop 0: Source: 'tests/src/arrays.rs', lines 334:4-337:5 -/ @@ -467,10 +467,10 @@ divergent def zero_slice_loop if i < len then do - let a1 ← Slice.update_usize a i 0#u8 + let a1 ← Slice.update a i 0#u8 let i1 ← i + 1#usize zero_slice_loop a1 i1 len - else Result.ok a + else ok a /- [arrays::zero_slice]: Source: 'tests/src/arrays.rs', lines 331:0-338:1 -/ @@ -485,7 +485,7 @@ divergent def iter_mut_slice_loop (len : Usize) (i : Usize) : Result Unit := then do let i1 ← i + 1#usize 
iter_mut_slice_loop len i1 - else Result.ok () + else ok () /- [arrays::iter_mut_slice]: Source: 'tests/src/arrays.rs', lines 340:0-346:1 -/ @@ -493,7 +493,7 @@ def iter_mut_slice (a : Slice U8) : Result (Slice U8) := do let len := Slice.len a iter_mut_slice_loop len 0#usize - Result.ok a + ok a /- [arrays::sum_mut_slice]: loop 0: Source: 'tests/src/arrays.rs', lines 351:4-354:5 -/ @@ -507,13 +507,13 @@ divergent def sum_mut_slice_loop let s1 ← s + i2 let i3 ← i + 1#usize sum_mut_slice_loop a i3 s1 - else Result.ok s + else ok s /- [arrays::sum_mut_slice]: Source: 'tests/src/arrays.rs', lines 348:0-356:1 -/ def sum_mut_slice (a : Slice U32) : Result (U32 × (Slice U32)) := do let i ← sum_mut_slice_loop a 0#usize 0#u32 - Result.ok (i, a) + ok (i, a) end arrays diff --git a/tests/lean/AsMut.lean b/tests/lean/AsMut.lean new file mode 100644 index 00000000..b213b241 --- /dev/null +++ b/tests/lean/AsMut.lean @@ -0,0 +1,25 @@ +-- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS +-- [as_mut] +import Aeneas +open Aeneas.Std Result Error +set_option linter.dupNamespace false +set_option linter.hashCommand false +set_option linter.unusedVariables false + +namespace as_mut + +/- [as_mut::use_box_as_mut]: + Source: 'tests/src/as_mut.rs', lines 2:0-4:1 -/ +def use_box_as_mut {T : Type} (x : T) : Result (T × (T → T)) := + ok (alloc.boxed.AsMutBoxT.as_mut x) + +/- [as_mut::use_as_mut]: + Source: 'tests/src/as_mut.rs', lines 6:0-8:1 -/ +def use_as_mut + {S : Type} {T : Type} (coreconvertAsMutInst : core.convert.AsMut T S) + (x : T) : + Result (S × (S → T)) + := + coreconvertAsMutInst.as_mut x + +end as_mut diff --git a/tests/lean/Avl/Funs.lean b/tests/lean/Avl/Funs.lean index 2eb7e02d..1c151737 100644 --- a/tests/lean/Avl/Funs.lean +++ b/tests/lean/Avl/Funs.lean @@ -2,7 +2,7 @@ -- [avl]: function definitions import Aeneas import Avl.Types -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -13,11 +13,10 @@ namespace avl Source: 'src/avl.rs', lines 8:4-16:5 -/ def OrdI32.cmp (self : I32) (other : I32) : Result Ordering := if self < other - then Result.ok Ordering.Less - else - if self = other - then Result.ok Ordering.Equal - else Result.ok Ordering.Greater + then ok Ordering.Less + else if self = other + then ok Ordering.Equal + else ok Ordering.Greater /- Trait implementation: [avl::{avl::Ord for i32}] Source: 'src/avl.rs', lines 7:0-17:1 -/ @@ -36,10 +35,10 @@ def Node.rotate_left (Node.mk z.value o z.right z.balance_factor) if root1.balance_factor = 0#i8 then - Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 1#i8)) + ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 1#i8)) root1.right (-1)#i8) else - Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 0#i8)) + ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 0#i8)) root1.right 0#i8) /- [avl::{avl::Node}#1::rotate_right]: @@ -52,11 +51,11 @@ def Node.rotate_right (Node.mk z.value z.left o z.balance_factor) if root1.balance_factor = 0#i8 then - Result.ok (Node.mk root1.value root1.left (some (Node.mk x.value - x.left x.right (-1)#i8)) 1#i8) + ok (Node.mk root1.value root1.left (some (Node.mk x.value x.left + x.right (-1)#i8)) 1#i8) else - Result.ok (Node.mk root1.value root1.left (some (Node.mk x.value - x.left x.right 0#i8)) 0#i8) + ok (Node.mk root1.value root1.left (some (Node.mk x.value x.left + x.right 0#i8)) 0#i8) /- [avl::{avl::Node}#1::rotate_left_right]: Source: 'src/avl.rs', 
lines 138:4-186:5 -/ @@ -72,16 +71,16 @@ def Node.rotate_left_right (Node.mk y.value o1 o2 y.balance_factor) if root1.balance_factor = 0#i8 then - Result.ok (Node.mk root1.value (some (Node.mk z.value z.left a 0#i8)) (some + ok (Node.mk root1.value (some (Node.mk z.value z.left a 0#i8)) (some (Node.mk x.value x.left x.right 0#i8)) 0#i8) else if root1.balance_factor < 0#i8 then - Result.ok (Node.mk root1.value (some (Node.mk z.value z.left a 0#i8)) - (some (Node.mk x.value x.left x.right 1#i8)) 0#i8) + ok (Node.mk root1.value (some (Node.mk z.value z.left a 0#i8)) (some + (Node.mk x.value x.left x.right 1#i8)) 0#i8) else - Result.ok (Node.mk root1.value (some (Node.mk z.value z.left a (-1)#i8)) - (some (Node.mk x.value x.left x.right 0#i8)) 0#i8) + ok (Node.mk root1.value (some (Node.mk z.value z.left a (-1)#i8)) (some + (Node.mk x.value x.left x.right 0#i8)) 0#i8) /- [avl::{avl::Node}#1::rotate_right_left]: Source: 'src/avl.rs', lines 188:4-236:5 -/ @@ -97,16 +96,16 @@ def Node.rotate_right_left (Node.mk y.value o1 o2 y.balance_factor) if root1.balance_factor = 0#i8 then - Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 0#i8)) - (some (Node.mk z.value a z.right 0#i8)) 0#i8) + ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 0#i8)) (some + (Node.mk z.value a z.right 0#i8)) 0#i8) else if root1.balance_factor > 0#i8 then - Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right - (-1)#i8)) (some (Node.mk z.value a z.right 0#i8)) 0#i8) + ok (Node.mk root1.value (some (Node.mk x.value x.left x.right (-1)#i8)) + (some (Node.mk z.value a z.right 0#i8)) 0#i8) else - Result.ok (Node.mk root1.value (some (Node.mk x.value x.left x.right - 0#i8)) (some (Node.mk z.value a z.right 1#i8)) 0#i8) + ok (Node.mk root1.value (some (Node.mk x.value x.left x.right 0#i8)) + (some (Node.mk z.value a z.right 1#i8)) 0#i8) /- [avl::{avl::Node}#2::insert_in_left]: Source: 'src/avl.rs', lines 240:4-275:5 -/ @@ -130,14 +129,14 @@ mutual divergent def Node.insert_in_left do let node1 ← Node.rotate_right (Node.mk node.value o2 node.right i) left - Result.ok (false, node1) + ok (false, node1) else do let node1 ← Node.rotate_left_right (Node.mk node.value o2 node.right i) left - Result.ok (false, node1) - else Result.ok (i != 0#i8, Node.mk node.value o node.right i) - else Result.ok (false, Node.mk node.value o node.right node.balance_factor) + ok (false, node1) + else ok (i != 0#i8, Node.mk node.value o node.right i) + else ok (false, Node.mk node.value o node.right node.balance_factor) /- [avl::{avl::Tree}#3::insert_in_opt_node]: Source: 'src/avl.rs', lines 356:4-371:5 -/ @@ -147,11 +146,11 @@ divergent def Tree.insert_in_opt_node := match node with | none => let n := Node.mk value none none 0#i8 - Result.ok (true, some n) + ok (true, some n) | some node1 => do let (b, node2) ← Node.insert OrdInst node1 value - Result.ok (b, some node2) + ok (b, some node2) /- [avl::{avl::Node}#2::insert_in_right]: Source: 'src/avl.rs', lines 277:4-315:5 -/ @@ -175,14 +174,14 @@ divergent def Node.insert_in_right do let node1 ← Node.rotate_left (Node.mk node.value node.left o2 i) right - Result.ok (false, node1) + ok (false, node1) else do let node1 ← Node.rotate_right_left (Node.mk node.value node.left o2 i) right - Result.ok (false, node1) - else Result.ok (i != 0#i8, Node.mk node.value node.left o i) - else Result.ok (false, Node.mk node.value node.left o node.balance_factor) + ok (false, node1) + else ok (i != 0#i8, Node.mk node.value node.left o i) + else ok (false, Node.mk node.value 
node.left o node.balance_factor) /- [avl::{avl::Node}#2::insert]: Source: 'src/avl.rs', lines 318:4-334:5 -/ @@ -194,7 +193,7 @@ divergent def Node.insert let ordering ← OrdInst.cmp value node.value match ordering with | Ordering.Less => Node.insert_in_left OrdInst node value - | Ordering.Equal => Result.ok (false, node) + | Ordering.Equal => ok (false, node) | Ordering.Greater => Node.insert_in_right OrdInst node value end @@ -202,7 +201,7 @@ end /- [avl::{avl::Tree}#3::new]: Source: 'src/avl.rs', lines 338:4-340:5 -/ def Tree.new {T : Type} (OrdInst : Ord T) : Result (Tree T) := - Result.ok { root := none } + ok { root := none } /- [avl::{avl::Tree}#3::find]: loop 0: Source: 'src/avl.rs', lines 345:8-354:5 -/ @@ -211,17 +210,18 @@ divergent def Tree.find_loop Result Bool := match current_tree with - | none => Result.ok false + | none => ok false | some current_node => do let o ← OrdInst.cmp current_node.value value match o with | Ordering.Less => Tree.find_loop OrdInst value current_node.right - | Ordering.Equal => Result.ok true + | Ordering.Equal => ok true | Ordering.Greater => Tree.find_loop OrdInst value current_node.left /- [avl::{avl::Tree}#3::find]: Source: 'src/avl.rs', lines 342:4-354:5 -/ +@[reducible] def Tree.find {T : Type} (OrdInst : Ord T) (self : Tree T) (value : T) : Result Bool := Tree.find_loop OrdInst value self.root @@ -234,6 +234,6 @@ def Tree.insert := do let (b, o) ← Tree.insert_in_opt_node OrdInst self.root value - Result.ok (b, { root := o }) + ok (b, { root := o }) end avl diff --git a/tests/lean/Avl/Properties.lean b/tests/lean/Avl/Properties.lean index 6d3f9ed7..7e616773 100644 --- a/tests/lean/Avl/Properties.lean +++ b/tests/lean/Avl/Properties.lean @@ -75,16 +75,16 @@ def Node.forall (p: Node T -> Prop) (node : Node T) : Prop := p node ∧ Subtree.forall p node.left ∧ Subtree.forall p node.right termination_by Node.size node -decreasing_by all_goals (simp_wf; simp [Node.left, Node.right]; split <;> simp <;> scalar_tac) +decreasing_by all_goals (simp_wf; fsimp [Node.left, Node.right]; split; fsimp <;> scalar_tac) end @[simp] theorem Subtree.forall_left {p: Node T -> Prop} {t: Node T}: Node.forall p t -> Subtree.forall p t.left := by - cases t; simp_all (config := {maxDischargeDepth := 1}) [Node.forall] + cases t; fsimp_all [Node.forall] @[simp] theorem Subtree.forall_right {p: Node T -> Prop} {t: Node T}: Subtree.forall p t -> Subtree.forall p t.right := by - cases t; simp_all (config := {maxDischargeDepth := 1}) [Node.forall] + cases t; fsimp_all [Node.forall] mutual theorem Subtree.forall_imp {p q: Node T -> Prop} {t: Subtree T}: (∀ x, p x -> q x) -> Subtree.forall p t -> Subtree.forall q t @@ -99,9 +99,9 @@ theorem Subtree.forall_imp {p q: Node T -> Prop} {t: Subtree T}: (∀ x, p x -> theorem Node.forall_imp {p q: Node T -> Prop} {t: Node T}: (∀ x, p x -> q x) -> Node.forall p t -> Node.forall q t := by match t with | .mk x left right height => - simp [Node.forall] + fsimp [Node.forall] intros Himp Hleft Hright Hx - simp [*] + fsimp [*] split_conjs <;> apply @Subtree.forall_imp T p q <;> tauto end @@ -117,7 +117,7 @@ def Subtree.balanceFactor (t: Subtree T): ℤ := @[simp] theorem Subtree.some_balanceFactor (t: Node T) : Subtree.balanceFactor (some t) = t.balanceFactor := by - simp [balanceFactor] + fsimp [balanceFactor] @[simp, reducible] def Node.invAuxNotBalanced [LinearOrder T] (node : Node T) : Prop := @@ -140,7 +140,7 @@ theorem Node.inv_imp_current [LinearOrder T] {node : Node T} (hInv : node.inv) : (∀ x ∈ Subtree.v node.left, x < node.value) ∧ (∀ x ∈ 
Subtree.v node.right, node.value < x) ∧ -1 ≤ node.balanceFactor ∧ node.balanceFactor ≤ 1 := by - simp_all (config := {maxDischargeDepth := 1}) [Node.inv, Node.forall, Node.invAux] + fsimp_all [Node.inv, Node.forall, Node.invAux] @[reducible] def Subtree.inv [LinearOrder T] (st : Subtree T) : Prop := @@ -165,26 +165,26 @@ theorem Node.inv_mk [LinearOrder T] (value : T) (left right : Option (Node T)) ( Subtree.inv left ∧ Subtree.inv right) := by have : ∀ (n : Option (Node T)), Subtree.forall invAux n = Subtree.inv n := by unfold Subtree.forall - simp [Subtree.inv] + fsimp [Subtree.inv] constructor <;> - simp [*, Node.inv, Node.forall] + fsimp [*, Node.inv, Node.forall] @[simp] theorem Node.inv_left [LinearOrder T] {t: Node T}: t.inv -> Subtree.inv t.left := by - simp [Node.inv] + fsimp [Node.inv] intro - cases t; simp_all (config := {maxDischargeDepth := 1}) + cases t; fsimp_all @[simp] theorem Node.inv_right [LinearOrder T] {t: Node T}: t.inv -> Subtree.inv t.right := by - simp [Node.inv] + fsimp [Node.inv] intro - cases t; simp_all (config := {maxDischargeDepth := 1}) + cases t; fsimp_all theorem Node.inv_imp_balance_factor_eq [LinearOrder T] {t: Node T} (hInv : t.inv) : t.balance_factor.val = t.balanceFactor := by - simp [inv, Node.forall, invAux] at hInv - cases t; simp_all (config := {maxDischargeDepth := 1}) + fsimp [inv, Node.forall, invAux] at hInv + cases t; fsimp_all @[simp] theorem Node.lt_imp_not_in_right [LinearOrder T] (node : Node T) (hInv : node.inv) (x : T) @@ -224,7 +224,7 @@ theorem Node.value_not_in_left [LinearOrder T] (node : Node T) (hInv : node.inv) have := ne_of_lt this tauto -@[pspec] +@[progress] theorem Tree.find_loop_spec {T : Type} (OrdInst : Ord T) [DecidableEq T] [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] @@ -235,24 +235,24 @@ theorem Tree.find_loop_spec match t with | none => simp | some (.mk v left right height) => - dsimp only + fsimp only have hCmp := Ospec.infallible -- TODO progress keep Hordering as ⟨ ordering ⟩; clear hCmp have hInvLeft := Node.inv_left hInv have hInvRight := Node.inv_right hInv - cases ordering <;> dsimp only + cases ordering <;> fsimp only . /- node.value < value -/ progress have hNotIn := Node.lt_imp_not_in_left _ hInv - simp_all (config := {maxDischargeDepth := 1}) - intro; simp_all (config := {maxDischargeDepth := 1}) + fsimp_all + intro; fsimp_all . /- node.value = value -/ - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all . /- node.value > value -/ progress have hNotIn := Node.lt_imp_not_in_right _ hInv - simp_all (config := {maxDischargeDepth := 1}) - intro; simp_all (config := {maxDischargeDepth := 1}) + fsimp_all + intro; fsimp_all theorem Tree.find_spec {T : Type} (OrdInst : Ord T) @@ -262,12 +262,12 @@ theorem Tree.find_spec (b ↔ value ∈ t.v) := by rw [find] progress - simp [Tree.v]; assumption + fsimp [Tree.v]; assumption -- TODO: move set_option maxHeartbeats 5000000 -@[pspec] +@[progress] theorem Node.rotate_left_spec {T : Type} [LinearOrder T] (x z : T) (a b c : Option (Node T)) (bf_x bf_z : I8) @@ -293,68 +293,68 @@ theorem Node.rotate_left_spec Node.height ntree = 2 + Subtree.height b := by rw [rotate_left] - simp [core.mem.replace] + fsimp [core.mem.replace] -- Some proofs common to both cases -- Elements in the left subtree are < z have : ∀ (y : T), (y = x ∨ y ∈ Subtree.v a) ∨ y ∈ Subtree.v b → y < z := by - simp [invAux] at hInvZ + fsimp [invAux] at hInvZ intro y hIn -- TODO: automate that cases hIn . rename _ => hIn cases hIn - . simp_all (config := {maxDischargeDepth := 1}) + . 
fsimp_all . -- Proving: y ∈ a → y < z -- Using: y < x ∧ x < z rename _ => hIn have hInv1 : y < x := by tauto have hInv2 := hInvX.right.right z - simp at hInv2 + fsimp at hInv2 apply lt_trans hInv1 hInv2 . tauto -- Elements in the right subtree are < z have : ∀ y ∈ Subtree.v c, z < y := by - simp [invAux] at hInvZ + fsimp [invAux] at hInvZ tauto -- Two cases depending on whether the BF of Z is 0 or 1 split . -- BF(Z) == 0 - simp at * - simp [*] + fsimp at * + fsimp [*] -- TODO: scalar_tac should succeed below have hHeightEq : Subtree.height b = Subtree.height c := by - simp_all (config := {maxDischargeDepth := 1}) [balanceFactor, Node.invAux] + fsimp_all [balanceFactor, Node.invAux] scalar_tac -- TODO: scalar_tac should succeed below have : 1 + Subtree.height c = Subtree.height a + 2 := by - simp_all (config := {maxDischargeDepth := 1}) [balanceFactor, Node.invAux] + fsimp_all [balanceFactor, Node.invAux] scalar_tac - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all split_conjs . -- Partial invariant for the final tree starting at z - simp [Node.invAux, balanceFactor, *] + fsimp [Node.invAux, balanceFactor, *] split_conjs <;> (try omega) <;> tauto . -- Partial invariant for the subtree x - simp [Node.invAux, balanceFactor, *] - split_conjs <;> (try omega) <;> simp_all (config := {maxDischargeDepth := 1}) + fsimp [Node.invAux, balanceFactor, *] + split_conjs <;> (try omega) <;> fsimp_all . -- The sets are the same apply Set.ext; simp; tauto . -- The height didn't change - simp [balanceFactor] at * + fsimp [balanceFactor] at * replace hInvX := hInvX.left - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all scalar_tac . -- BF(Z) == 1 rename _ => hNotEq - simp at * - simp [*] - simp_all (config := {maxDischargeDepth := 1}) + fsimp at * + fsimp [*] + fsimp_all have : bf_z.val = 1 := by - simp [Node.invAux] at hInvZ + fsimp [Node.invAux] at hInvZ omega clear hNotEq hBfZ have : Subtree.height c = 1 + Subtree.height b := by - simp [balanceFactor, Node.invAux] at * + fsimp [balanceFactor, Node.invAux] at * replace hInvZ := hInvZ.left omega have : max (Subtree.height c) (Subtree.height b) = Subtree.height c := by @@ -362,26 +362,26 @@ theorem Node.rotate_left_spec -- TODO: we shouldn't need this: scalar_tac should succeed have : Subtree.height c = 1 + Subtree.height a := by -- TODO: scalar_tac fails here (conversion int/nat) - simp_all (config := {maxDischargeDepth := 1}) [balanceFactor, Node.invAux] + fsimp_all [balanceFactor, Node.invAux] omega have : Subtree.height a = Subtree.height b := by - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all split_conjs . -- Invariant for whole tree (starting at z) - simp [invAux, balanceFactor] + fsimp [invAux, balanceFactor] split_conjs <;> (try omega) <;> tauto . -- Invariant for subtree x - simp [invAux, balanceFactor] - split_conjs <;> (try omega) <;> simp_all (config := {maxDischargeDepth := 1}) + fsimp [invAux, balanceFactor] + split_conjs <;> (try omega) <;> fsimp_all . -- The sets are the same apply Set.ext; simp; tauto . 
-- The height didn't change - simp [balanceFactor] at * + fsimp [balanceFactor] at * replace hInvX := hInvX.left - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all scalar_tac -@[pspec] +@[progress] theorem Node.rotate_right_spec {T : Type} [LinearOrder T] (x z : T) (a b c : Option (Node T)) (bf_x bf_z : I8) @@ -407,17 +407,17 @@ theorem Node.rotate_right_spec Node.height ntree = 2 + Subtree.height b := by rw [rotate_right] - simp [core.mem.replace] + fsimp [core.mem.replace] -- Some proofs common to both cases -- Elements in the right subtree are > z have : ∀ (y : T), (y = x ∨ y ∈ Subtree.v b) ∨ y ∈ Subtree.v c → z < y := by - simp [invAux] at * + fsimp [invAux] at * intro y hIn -- TODO: automate that cases hIn . rename _ => hIn cases hIn - . simp [*] + . fsimp [*] . tauto . -- Proving: y ∈ c → z < y -- Using: z < x ∧ x < y @@ -426,47 +426,47 @@ theorem Node.rotate_right_spec apply lt_trans <;> tauto -- Elements in the left subtree are < z have : ∀ y ∈ Subtree.v a, y < z := by - simp_all (config := {maxDischargeDepth := 1}) [invAux] + fsimp_all [invAux] -- Two cases depending on whether the BF of Z is 0 or 1 split . -- BF(Z) == 0 - simp at * - simp [*] + fsimp at * + fsimp [*] have hHeightEq : Subtree.height a = Subtree.height b := by - simp_all (config := {maxDischargeDepth := 1}) [balanceFactor, Node.invAux] + fsimp_all [balanceFactor, Node.invAux] -- TODO: scalar_tac fails here (conversion int/nat) omega -- TODO: we shouldn't need this: scalar_tac should succeed have : 1 + Subtree.height a = Subtree.height c + 2 := by -- TODO: scalar_tac fails here (conversion int/nat) - simp_all (config := {maxDischargeDepth := 1}) [balanceFactor, Node.invAux] + fsimp_all [balanceFactor, Node.invAux] omega - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all split_conjs . -- Partial invariant for the final tree starting at z - simp [Node.invAux, balanceFactor, *] + fsimp [Node.invAux, balanceFactor, *] split_conjs <;> (try omega) <;> tauto . -- Partial invariant for the subtree x - simp [Node.invAux, balanceFactor, *] - split_conjs <;> (try omega) <;> simp_all (config := {maxDischargeDepth := 1}) + fsimp [Node.invAux, balanceFactor, *] + split_conjs <;> (try omega) <;> fsimp_all . -- The sets are the same apply Set.ext; simp; tauto . -- The height didn't change - simp [balanceFactor] at * + fsimp [balanceFactor] at * replace hInvX := hInvX.left - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all scalar_tac . -- BF(Z) == -1 rename _ => hNotEq - simp at * - simp [*] - simp_all (config := {maxDischargeDepth := 1}) + fsimp at * + fsimp [*] + fsimp_all have : bf_z.val = -1 := by - simp [Node.invAux] at hInvZ + fsimp [Node.invAux] at hInvZ omega clear hNotEq hBfZ have : Subtree.height a = 1 + Subtree.height b := by - simp [balanceFactor, Node.invAux] at * + fsimp [balanceFactor, Node.invAux] at * replace hInvZ := hInvZ.left omega have : max (Subtree.height a) (Subtree.height b) = Subtree.height a := by @@ -474,26 +474,26 @@ theorem Node.rotate_right_spec -- TODO: we shouldn't need this: scalar_tac should succeed have : Subtree.height a = 1 + Subtree.height c := by -- TODO: scalar_tac fails here (conversion int/nat) - simp_all (config := {maxDischargeDepth := 1}) [balanceFactor, Node.invAux] + fsimp_all [balanceFactor, Node.invAux] omega have : Subtree.height c = Subtree.height b := by - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all split_conjs . 
-- Invariant for whole tree (starting at z) - simp [invAux, balanceFactor] + fsimp [invAux, balanceFactor] split_conjs <;> (try omega) <;> tauto . -- Invariant for subtree x - simp [invAux, balanceFactor] - split_conjs <;> (try omega) <;> simp_all (config := {maxDischargeDepth := 1}) + fsimp [invAux, balanceFactor] + split_conjs <;> (try omega) <;> fsimp_all . -- The sets are the same apply Set.ext; simp; tauto . -- The height didn't change - simp [balanceFactor] at * + fsimp [balanceFactor] at * replace hInvX := hInvX.left - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all scalar_tac -@[pspec] +@[progress] theorem Node.rotate_left_right_spec {T : Type} [LinearOrder T] (x y z : T) (bf_x bf_y bf_z : I8) @@ -522,36 +522,36 @@ theorem Node.rotate_left_right_spec Node.height ntree = 2 + Subtree.height t0 := by intro x_tree y_tree z_tree tree - simp [rotate_left_right] -- TODO: this inlines the local decls + fsimp [rotate_left_right] -- TODO: this inlines the local decls -- Some facts about the heights and the balance factors -- TODO: automate that have : Node.height z_tree = Subtree.height t1 + 2 := by - simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + fsimp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega have : Node.height y_tree = Subtree.height t0 + 1 := by - simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + fsimp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega have : bf_y.val + Subtree.height a = Subtree.height b := by - simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega - simp [x_tree, y_tree, z_tree] at * + fsimp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + fsimp [x_tree, y_tree, z_tree] at * -- TODO: automate the < proofs -- Auxiliary proofs for invAux for y have : ∀ (e : T), (e = z ∨ e ∈ Subtree.v t0) ∨ e ∈ Subtree.v a → e < y := by intro e hIn - simp [invAux] at * + fsimp [invAux] at * cases hIn . rename _ => hIn -- TODO: those cases are cumbersome cases hIn - . simp_all (config := {maxDischargeDepth := 1}) + . fsimp_all . have : e < z := by tauto have : z < y := by tauto apply lt_trans <;> tauto . tauto have : ∀ (e : T), (e = x ∨ e ∈ Subtree.v b) ∨ e ∈ Subtree.v t1 → y < e := by - intro e hIn; simp [invAux] at * + intro e hIn; fsimp [invAux] at * cases hIn . rename _ => hIn cases hIn - . simp_all (config := {maxDischargeDepth := 1}) + . fsimp_all . tauto . have : y < x := by replace hInvX := hInvX.right.left y @@ -560,65 +560,65 @@ theorem Node.rotate_left_right_spec apply lt_trans <;> tauto -- Auxiliary proofs for invAux for z have : ∀ e ∈ Subtree.v t0, e < z := by - intro x hIn; simp [invAux] at * + intro x hIn; fsimp [invAux] at * tauto have : ∀ e ∈ Subtree.v a, z < e := by - intro e hIn; simp [invAux] at * + intro e hIn; fsimp [invAux] at * replace hInvZ := hInvZ.right.right.left e tauto -- Auxiliary proofs for invAux for x have : ∀ e ∈ Subtree.v b, e < x := by - intro e hIn; simp [invAux] at * + intro e hIn; fsimp [invAux] at * replace hInvX := hInvX.right.left e tauto have : ∀ e ∈ Subtree.v t1, x < e := by - intro e hIn; simp [invAux] at * + intro e hIn; fsimp [invAux] at * tauto -- Case disjunction on the balance factor of Y split . -- BF(Y) = 0 - simp [balanceFactor] at * - split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + fsimp [balanceFactor] at * + split_conjs <;> (try fsimp [Node.invAux, balanceFactor, *]) . -- invAux for y split_conjs <;> (try omega) <;> (try tauto) . -- invAux for z - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . 
-- invAux for x - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- The sets are the same - apply Set.ext; simp [tree, z_tree, y_tree]; tauto + apply Set.ext; fsimp [tree, z_tree, y_tree]; tauto . -- Height scalar_tac . split <;> simp . -- BF(Y) < 0 - have : bf_y.val = -1 := by simp [Node.invAux] at *; omega - simp [balanceFactor] at * - split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + have : bf_y.val = -1 := by fsimp [Node.invAux] at *; omega + fsimp [balanceFactor] at * + split_conjs <;> (try fsimp [Node.invAux, balanceFactor, *]) . -- invAux for y split_conjs <;> (try omega) <;> (try tauto) . -- invAux for z - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- invAux for x - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- The sets are the same - apply Set.ext; simp [tree, z_tree, y_tree]; tauto + apply Set.ext; fsimp [tree, z_tree, y_tree]; tauto . -- Height scalar_tac . -- BF(Y) > 0 - have : bf_y.val = 1 := by simp [Node.invAux] at *; omega - split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + have : bf_y.val = 1 := by fsimp [Node.invAux] at *; omega + split_conjs <;> (try fsimp [Node.invAux, balanceFactor, *]) . -- invAux for y split_conjs <;> (try omega) <;> (try tauto) . -- invAux for z - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- invAux for x - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- The sets are the same - apply Set.ext; simp [tree, z_tree, y_tree]; tauto + apply Set.ext; fsimp [tree, z_tree, y_tree]; tauto . -- Height scalar_tac -@[pspec] +@[progress] theorem Node.rotate_right_left_spec {T : Type} [LinearOrder T] (x y z : T) (bf_x bf_y bf_z : I8) @@ -647,36 +647,36 @@ theorem Node.rotate_right_left_spec Node.height ntree = 2 + Subtree.height t1 := by intro x_tree y_tree z_tree tree - simp [rotate_right_left] -- TODO: this inlines the local decls + fsimp [rotate_right_left] -- TODO: this inlines the local decls -- Some facts about the heights and the balance factors -- TODO: automate that have : Node.height z_tree = Subtree.height t1 + 2 := by - simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + fsimp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega have : Node.height y_tree = Subtree.height t0 + 1 := by - simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + fsimp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega have : bf_y.val + Subtree.height b = Subtree.height a := by - simp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega - simp [x_tree, y_tree, z_tree] at * + fsimp [y_tree, z_tree, inv, invAux, balanceFactor] at *; omega + fsimp [x_tree, y_tree, z_tree] at * -- TODO: automate the < proofs -- Auxiliary proofs for invAux for y have : ∀ (e : T), (e = z ∨ e ∈ Subtree.v a) ∨ e ∈ Subtree.v t0 → y < e := by intro e hIn - simp [invAux] at * + fsimp [invAux] at * cases hIn . rename _ => hIn -- TODO: those cases are cumbersome cases hIn - . simp_all (config := {maxDischargeDepth := 1}) + . fsimp_all . tauto . have : z < e := by tauto have : y < z := by tauto apply lt_trans <;> tauto have : ∀ (e : T), (e = x ∨ e ∈ Subtree.v t1) ∨ e ∈ Subtree.v b → e < y := by - intro e hIn; simp [invAux] at * + intro e hIn; fsimp [invAux] at * cases hIn . rename _ => hIn cases hIn - . simp_all (config := {maxDischargeDepth := 1}) + . fsimp_all . 
have : x < y := by replace hInvX := hInvX.right.right y tauto @@ -685,61 +685,61 @@ theorem Node.rotate_right_left_spec . tauto -- Auxiliary proofs for invAux for z have : ∀ e ∈ Subtree.v t0, z < e := by - intro x hIn; simp [invAux] at * + intro x hIn; fsimp [invAux] at * tauto have : ∀ e ∈ Subtree.v a, e < z := by - intro e hIn; simp [invAux] at * + intro e hIn; fsimp [invAux] at * replace hInvZ := hInvZ.right.left e tauto -- Auxiliary proofs for invAux for x have : ∀ e ∈ Subtree.v b, x < e := by - intro e hIn; simp [invAux] at * + intro e hIn; fsimp [invAux] at * replace hInvX := hInvX.right.right e tauto have : ∀ e ∈ Subtree.v t1, e < x := by - intro e hIn; simp [invAux] at * + intro e hIn; fsimp [invAux] at * tauto -- Case disjunction on the balance factor of Y split . -- BF(Y) = 0 - simp [balanceFactor] at * - split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + fsimp [balanceFactor] at * + split_conjs <;> (try fsimp [Node.invAux, balanceFactor, *]) . -- invAux for y split_conjs <;> (try omega) <;> (try tauto) . -- invAux for z - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- invAux for x - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- The sets are the same - apply Set.ext; simp [tree, z_tree, y_tree]; tauto + apply Set.ext; fsimp [tree, z_tree, y_tree]; tauto . -- Height scalar_tac . split <;> simp . -- BF(Y) > 0 - have : bf_y.val = 1 := by simp [Node.invAux] at *; omega - simp [balanceFactor] at * - split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + have : bf_y.val = 1 := by fsimp [Node.invAux] at *; omega + fsimp [balanceFactor] at * + split_conjs <;> (try fsimp [Node.invAux, balanceFactor, *]) . -- invAux for y split_conjs <;> (try omega) <;> (try tauto) . -- invAux for z - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- invAux for x - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- The sets are the same - apply Set.ext; simp [tree, z_tree, y_tree]; tauto + apply Set.ext; fsimp [tree, z_tree, y_tree]; tauto . -- Height scalar_tac . -- BF(Y) < 0 - have : bf_y.val = -1 := by simp [Node.invAux] at *; omega - split_conjs <;> (try simp [Node.invAux, balanceFactor, *]) + have : bf_y.val = -1 := by fsimp [Node.invAux] at *; omega + split_conjs <;> (try fsimp [Node.invAux, balanceFactor, *]) . -- invAux for y split_conjs <;> (try omega) <;> (try tauto) . -- invAux for z - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- invAux for x - split_conjs <;> (try scalar_tac) + split_conjs <;> (try assumption) <;> (try scalar_tac) . -- The sets are the same - apply Set.ext; simp [tree, z_tree, y_tree]; tauto + apply Set.ext; fsimp [tree, z_tree, y_tree]; tauto . -- Height scalar_tac @@ -756,7 +756,7 @@ theorem Node.right_height_lt_height (n : Node T) : mutual -@[pspec] +@[progress] theorem Node.insert_spec {T : Type} (OrdInst : Ord T) [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] (node : Node T) (value : T) @@ -770,19 +770,19 @@ theorem Node.insert_spec rw [Node.insert] have hCmp := Ospec.infallible -- TODO progress as ⟨ ordering ⟩ - split <;> rename _ => hEq <;> clear hCmp <;> simp at * + split <;> rename _ => hEq <;> clear hCmp <;> fsimp at * . -- value < node.value progress as ⟨ updt, node', h1, h2 ⟩ - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all . 
-- value = node.value - cases node; simp_all (config := {maxDischargeDepth := 1}) + cases node; fsimp_all . -- node.value < value progress as ⟨ updt, node', h1, h2 ⟩ - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all termination_by (node.height, 1) decreasing_by all_goals simp_wf -@[pspec] +@[progress] theorem Tree.insert_in_opt_node_spec {T : Type} (OrdInst : Ord T) [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] (tree : Option (Node T)) (value : T) @@ -794,19 +794,19 @@ theorem Tree.insert_in_opt_node_spec else Subtree.height tree' = Subtree.height tree) ∧ (b → Subtree.height tree > 0 → Subtree.balanceFactor tree' ≠ 0) := by rw [Tree.insert_in_opt_node] - cases hNode : tree <;> simp [hNode] + cases hNode : tree <;> fsimp [hNode] . -- tree = none - simp [Node.invAux, Node.balanceFactor] + fsimp [Node.invAux, Node.balanceFactor] . -- tree = some rename Node T => node - have hNodeInv : Node.inv node := by simp_all (config := {maxDischargeDepth := 1}) + have hNodeInv : Node.inv node := by fsimp_all progress as ⟨ updt, tree' ⟩ - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all termination_by (Subtree.height tree, 2) -decreasing_by simp_wf; simp [*] +decreasing_by simp_wf; fsimp [*] -- TODO: any modification triggers the replay of the whole proof -@[pspec] +@[progress] theorem Node.insert_in_left_spec {T : Type} (OrdInst : Ord T) [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] @@ -819,21 +819,21 @@ theorem Node.insert_in_left_spec (if b then node'.height = node.height + 1 else node'.height = node.height) ∧ (b → node'.balanceFactor ≠ 0) := by rw [Node.insert_in_left] - have hInvLeft : Subtree.inv node.left := by cases node; simp_all (config := {maxDischargeDepth := 1}) + have hInvLeft : Subtree.inv node.left := by cases node; fsimp_all progress as ⟨ updt, left_opt' ⟩ split . -- the height of the subtree changed have hBalanceFactor : node.balance_factor = node.balanceFactor ∧ -1 ≤ node.balanceFactor ∧ node.balanceFactor ≤ 1 := by - cases node; simp_all (config := {maxDischargeDepth := 1}) [Node.invAux] + cases node; fsimp_all [Node.invAux] progress as ⟨ i ⟩ split . -- i = -2 simp - cases h: left_opt' with - | none => simp_all (config := {maxDischargeDepth := 1}) -- absurd - | some left' => - simp [h] + cases h: left_opt' + . fsimp_all -- absurd + . rename_i left' + fsimp [h] cases node with | mk x left right balance_factor => split . -- rotate_right @@ -841,73 +841,71 @@ theorem Node.insert_in_left_spec cases h:left' with | mk z a b bf_z => progress as ⟨ tree', hInv', hTree'Set, hTree'Height ⟩ -- TODO: syntax for preconditions - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor] + . fsimp_all + . fsimp_all + . fsimp_all [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor] scalar_tac - . simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] - . -- End of the proof - simp [*] + . fsimp [*] split_conjs . -- set reasoning - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all apply Set.ext; simp intro x; tauto . -- height - simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] + fsimp_all [Node.invAux, Node.balanceFactor] -- This assertion is not necessary for the proof, but it is important that it holds. 
-- We can prove it because of the post-conditions `b → node'.balanceFactor ≠ 0` (see above) have : bf_z.val = -1 := by scalar_tac scalar_tac . -- rotate_left_right simp - cases h:left' with | mk z t0 y bf_z => - cases h: y with - | none => - -- Can't get there - simp_all (config := {maxDischargeDepth := 1}) [Node.balanceFactor, Node.invAux] - | some y => + cases h:left' + rename_i z t0 y bf_z + cases h: y + . -- Can't get there + fsimp_all [Node.balanceFactor, Node.invAux] + . rename_i y cases h: y with | mk y a b bf_y => progress as ⟨ tree', hInv', hTree'Set, hTree'Height ⟩ -- TODO: syntax for preconditions - . simp_all (config := {maxDischargeDepth := 1}) [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor]; scalar_tac - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor]; scalar_tac + . fsimp_all [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor]; scalar_tac + . fsimp_all + . fsimp_all + . fsimp_all [Node.invAux, Node.balanceFactor]; scalar_tac . -- End of the proof - simp [*] + fsimp [*] split_conjs - . apply Set.ext; simp_all (config := {maxDischargeDepth := 1}) + . apply Set.ext; fsimp_all intro x; tauto - . simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] + . fsimp_all [Node.invAux, Node.balanceFactor] scalar_tac . -- i ≠ -2: the height of the tree did not change - simp [*] + fsimp [*] split_conjs - . cases node; simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] + . cases node; fsimp_all [Node.invAux, Node.balanceFactor] split_conjs <;> scalar_tac . apply Set.ext; simp - cases node; simp_all (config := {maxDischargeDepth := 1}) + cases node; fsimp_all tauto - . simp_all (config := {maxDischargeDepth := 1}) + . fsimp_all cases node with | mk node_value left right balance_factor => - split <;> simp [Node.balanceFactor] at * <;> scalar_tac - . simp_all (config := {maxDischargeDepth := 1}) [Node.balanceFactor] + split <;> fsimp [Node.balanceFactor] at * <;> scalar_tac + . fsimp_all [Node.balanceFactor] scalar_tac . -- the height of the subtree did not change - simp [*] + fsimp [*] split_conjs . cases node; - simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] + fsimp_all [Node.invAux, Node.balanceFactor] . apply Set.ext; simp; intro x - cases node; simp_all (config := {maxDischargeDepth := 1}) + cases node; fsimp_all tauto - . simp_all (config := {maxDischargeDepth := 1}) - cases node; simp_all (config := {maxDischargeDepth := 1}) + . fsimp_all + cases node; fsimp_all termination_by (node.height, 0) decreasing_by simp_wf -@[pspec] +@[progress] theorem Node.insert_in_right_spec {T : Type} (OrdInst : Ord T) [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] @@ -920,91 +918,94 @@ theorem Node.insert_in_right_spec (if b then node'.height = node.height + 1 else node'.height = node.height) ∧ (b → node'.balanceFactor ≠ 0) := by rw [Node.insert_in_right] - have hInvLeft : Subtree.inv node.right := by cases node; simp_all (config := {maxDischargeDepth := 1}) + have hInvLeft : Subtree.inv node.right := by cases node; fsimp_all progress as ⟨ updt, right_opt' ⟩ split . 
-- the height of the subtree changed have hBalanceFactor : node.balance_factor = node.balanceFactor ∧ -1 ≤ node.balanceFactor ∧ node.balanceFactor ≤ 1 := by - cases node; simp_all (config := {maxDischargeDepth := 1}) [Node.invAux] + cases node; fsimp_all [Node.invAux] progress as ⟨ i ⟩ split . -- i = 2 simp - cases h: right_opt' with - | none => simp_all (config := {maxDischargeDepth := 1}) -- absurd - | some right' => - simp [h] + cases h: right_opt' + . fsimp_all -- absurd + . rename_i right' + fsimp [h] split . -- rotate_left - cases node with | mk x a right balance_factor => + cases node + rename_i x a right balance_factor -- TODO: fix progress - cases h:right' with | mk z b c bf_z => + cases h:right' + rename_i z b c bf_z progress as ⟨ tree', hInv', hTree'Set, hTree'Height ⟩ -- TODO: syntax for preconditions - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor]; scalar_tac - . simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] + . fsimp_all + . fsimp_all + . fsimp_all [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor]; scalar_tac . -- End of the proof - simp [*] + fsimp [*] split_conjs . -- set reasoning - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all . -- height - simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] + fsimp_all [Node.invAux, Node.balanceFactor] -- Remark: here we have: -- bf_z.val = -1 scalar_tac . -- rotate_right_left - cases node with | mk x t1 right balance_factor => + cases node + rename_i x t1 right balance_factor simp - cases h:right' with | mk z y t0 bf_z => - cases h: y with - | none => - -- Can't get there - simp_all (config := {maxDischargeDepth := 1}) [Node.balanceFactor, Node.invAux] - | some y => - cases h: y with | mk y b a bf_y => + cases h:right' + rename_i z y t0 bf_z + cases h: y + . -- Can't get there + fsimp_all [Node.balanceFactor, Node.invAux] + . rename_i y + cases h: y + rename_i y b a bf_y progress as ⟨ tree', hInv', hTree'Set, hTree'Height ⟩ -- TODO: syntax for preconditions - . simp_all (config := {maxDischargeDepth := 1}) [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor]; scalar_tac - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor]; scalar_tac + . fsimp_all [Node.inv, Node.invAux, Node.invAuxNotBalanced, Node.balanceFactor]; scalar_tac + . fsimp_all + . fsimp_all + . fsimp_all [Node.invAux, Node.balanceFactor]; scalar_tac . -- End of the proof - simp [*] + fsimp [*] split_conjs - . apply Set.ext; simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] + . apply Set.ext; fsimp_all + . fsimp (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] at * scalar_tac . -- i ≠ -2: the height of the tree did not change - simp [*] + fsimp [*] split_conjs - . cases node; simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] + . cases node; fsimp_all [Node.invAux, Node.balanceFactor] split_conjs <;> scalar_tac . apply Set.ext; simp - cases node; simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) + cases node; fsimp_all + . 
fsimp_all cases node with | mk node_value left right balance_factor => - split <;> simp [Node.balanceFactor] at * <;> scalar_tac - . simp_all (config := {maxDischargeDepth := 1}) [Node.balanceFactor] + split <;> fsimp [Node.balanceFactor] at * <;> scalar_tac + . fsimp_all [Node.balanceFactor] scalar_tac . -- the height of the subtree did not change - simp [*] -- TODO: annoying to use this simp everytime: put this in progress + fsimp [*] -- TODO: annoying to use this fsimp everytime: put this in progress split_conjs . cases node; - simp_all (config := {maxDischargeDepth := 1}) [Node.invAux, Node.balanceFactor] + fsimp_all [Node.invAux, Node.balanceFactor] . apply Set.ext; simp; intro x - cases node; simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) - cases node; simp_all (config := {maxDischargeDepth := 1}) + cases node; fsimp_all + . fsimp_all + cases node; fsimp_all termination_by (node.height, 0) decreasing_by simp_wf end -@[pspec] +@[progress] theorem Tree.insert_spec {T : Type} (OrdInst : Ord T) [LinOrd : LinearOrder T] [Ospec: OrdSpecLinearOrderEq OrdInst] (tree : Tree T) (value : T) @@ -1015,11 +1016,11 @@ theorem Tree.insert_spec {T : Type} tree'.v = tree.v ∪ {value} := by rw [Tree.insert] progress as ⟨ updt, tree' ⟩ - simp [*] + fsimp [*] -@[pspec] +@[progress] theorem Tree.new_spec {T : Type} (OrdInst : Ord T) : ∃ t, Tree.new OrdInst = ok t ∧ t.v = ∅ ∧ t.height = 0 := by - simp [new, Tree.v, Tree.height] + fsimp [new, Tree.v, Tree.height] end avl diff --git a/tests/lean/Avl/Types.lean b/tests/lean/Avl/Types.lean index 77e9568e..b05d84ae 100644 --- a/tests/lean/Avl/Types.lean +++ b/tests/lean/Avl/Types.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [avl]: type definitions import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/BaseTutorial.lean b/tests/lean/BaseTutorial.lean index 995b1137..7e6a1ec3 100644 --- a/tests/lean/BaseTutorial.lean +++ b/tests/lean/BaseTutorial.lean @@ -74,9 +74,9 @@ def mul2_add1 (x : U32) : Result U32 := do results explicit. -/ def mul2_add1_desugared (x : U32) : Result U32 := - match Scalar.add x x with + match UScalar.add x x with | ok x1 => -- Success case - match Scalar.add x1 (U32.ofInt 1) with + match UScalar.add x1 (U32.ofNat 1) with | ok x2 => ok x2 | error => error | error => error -- Propagating the errors @@ -100,7 +100,7 @@ theorem mul2_add1_spec We simply state that [2 * x + 1] must not overflow. The `.val` notation is used to coerce values. Here, we coerce `x`, which is - a bounded machine integer, to an unbounded mathematical integer, which is + a bounded machine integer, to an unbounded natural number, which is easier to work with. Note that it is also possible to use the `↑x` notation to tell Lean to apply a coercion, though Lean may not always be able to figure out which coercion to apply (`x.val` is always more precise). 
@@ -108,7 +108,7 @@ theorem mul2_add1_spec (h : 2 * x.val + 1 ≤ U32.max) /- The postcondition -/ : ∃ y, mul2_add1 x = ok y ∧ -- The call succeeds - ↑ y = 2 * ↑x + (1 : Int) -- The output has the expected value + ↑ y = 2 * ↑x + (1 : Nat) -- The output has the expected value := by /- The proof -/ -- Start by a call to the rewriting tactic to reveal the body of [mul2_add1] @@ -149,12 +149,12 @@ theorem mul2_add1_spec For this reason, we provide the possibility of registering theorems in a database so that [progress] can automatically look them up. This is done by marking - theorems with custom attributes, like [pspec] below. + theorems with custom attributes, like [progress] below. Theorems in the standard library like [U32.add_spec] have already been marked with such attributes, meaning we don't need to tell [progress] to use them. -/ -@[pspec] -- the [pspec] attribute saves the theorem in a database, for [progress] to use it +@[progress] -- the [progress] attribute saves the theorem in a database, for [progress] to use it theorem mul2_add1_spec2 (x : U32) (h : 2 * x.val + 1 ≤ U32.max) : ∃ y, mul2_add1 x = ok y ∧ ↑ y = 2 * ↑x + (1 : Int) @@ -164,7 +164,7 @@ theorem mul2_add1_spec2 (x : U32) (h : 2 * x.val + 1 ≤ U32.max) progress as ⟨ x2 ⟩ -- same simp at *; scalar_tac -/- Because we marked [mul2_add1_spec2] theorem with [pspec], [progress] can +/- Because we marked [mul2_add1_spec2] theorem with [progress], [progress] can now automatically look it up. For instance, below: -/ -- A dummy function which uses [mul2_add1] @@ -172,7 +172,7 @@ def use_mul2_add1 (x : U32) (y : U32) : Result U32 := do let x1 ← mul2_add1 x x1 + y -@[pspec] +@[progress] theorem use_mul2_add1_spec (x : U32) (y : U32) (h : 2 * x.val + 1 + y.val ≤ U32.max) : ∃ z, use_mul2_add1 x y = ok z ∧ ↑z = 2 * ↑x + (1 : Int) + ↑y := by @@ -240,13 +240,13 @@ divergent def list_nth (T : Type) (l : CList T) (i : U32) : Result T := /- Conversion to Lean's standard list type. - Note that because we use the suffix "CList.", we can use the notation [l.to_list] + Note that because we use the suffix "CList.", we can use the notation [l.toList] if [l] has type [CList ...]. -/ -def CList.to_list {α : Type} (x : CList α) : List α := +def CList.toList {α : Type} (x : CList α) : List α := match x with | CNil => [] - | CCons hd tl => hd :: tl.to_list + | CCons hd tl => hd :: tl.toList /- Let's prove that [list_nth] indeed accesses the ith element of the list. @@ -263,11 +263,11 @@ def CList.to_list {α : Type} (x : CList α) : List α := -/ theorem list_nth_spec {T : Type} [Inhabited T] (l : CList T) (i : U32) -- Precondition: the index is in bounds - (h : i.val < l.to_list.length) + (h : i.val < l.toList.length) -- Postcondition : ∃ x, list_nth T l i = ok x ∧ -- [x] is the ith element of [l] after conversion to [List] - x = l.to_list.index i.toNat + x = l.toList[i.val] := by -- Here we have to be careful when unfolding the body of [list_nth]: we could -- use the [simp] tactic, but it will sometimes loop on recursive definitions. @@ -276,12 +276,10 @@ theorem list_nth_spec {T : Type} [Inhabited T] (l : CList T) (i : U32) match l with | CNil => -- We can't get there: we can derive a contradiction from the precondition: - -- we have that [i < 0] (because [i < CNil.to_list.len]) and at the same + -- we have that [i < 0] (because [i < CNil.toList.len]) and at the same -- time [0 ≤ i] (because [i] is a [U32] unsigned integer). 
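/- A minimal standalone sketch, for illustration only, of the contradiction used in the
   `CNil` branch below: once `CNil.toList` is simplified to `[]`, the precondition
   becomes `i.val < 0`, which is impossible since `i.val` is a `Nat`. -/
example (i : U32) (h : i.val < 0) : False := by
  -- `i.val : Nat`, so `i.val < 0` is absurd; `scalar_tac` would also close this goal
  omega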
- -- First, let's simplify [to_list CNil] to [0] - simp [CList.to_list] at h - -- Proving we have a contradiction - scalar_tac + -- First, let's simplify [toList CNil] to [0] + simp [CList.toList] at h | CCons hd tl => -- Simplify the match simp only [] @@ -296,7 +294,7 @@ theorem list_nth_spec {T : Type} [Inhabited T] (l : CList T) (i : U32) -- Simplify the condition and the [if then else] simp [hi] -- Prove the final equality - simp [CList.to_list] + simp [CList.toList] else -- The interesting branch -- Simplify the condition and the [if then else] @@ -304,10 +302,10 @@ theorem list_nth_spec {T : Type} [Inhabited T] (l : CList T) (i : U32) -- i0 := i - 1 progress as ⟨ i1, hi1 ⟩ -- [progress] can handle recursion - simp [CList.to_list] at h -- we need to simplify this inequality to prove the precondition + simp [CList.toList] at h -- we need to simplify this inequality to prove the precondition progress as ⟨ l1 ⟩ -- Proving the postcondition - -- We need this to trigger the simplification of [index to.to_list i.val] + -- We need this to trigger the simplification of [index to.toList i.val] -- -- Among other things, the call to [simp] below will apply the theorem -- [List.index_nzero_cons], which has the precondition [i.val ≠ 0]. [simp] @@ -316,8 +314,8 @@ theorem list_nth_spec {T : Type} [Inhabited T] (l : CList T) (i : U32) -- by giving it [*] as argument, we tell [simp] to use all the assumptions -- to perform rewritings. In particular, it will use [i.val ≠ 0] to -- apply [List.index_nzero_cons]. - have : i.toNat ≠ 0 := by scalar_tac -- Remark: [simp at hi] also works - simp [CList.to_list, *] + have : i.val ≠ 0 := by scalar_tac -- Remark: [simp at hi] also works + simp [CList.toList, *] /-#===========================================================================# # diff --git a/tests/lean/Betree/Funs.lean b/tests/lean/Betree/Funs.lean index fdf8c3bc..07e6a08f 100644 --- a/tests/lean/Betree/Funs.lean +++ b/tests/lean/Betree/Funs.lean @@ -3,7 +3,7 @@ import Aeneas import Betree.Types import Betree.FunsExternal -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -45,12 +45,12 @@ def betree.store_leaf_node def betree.fresh_node_id (counter : U64) : Result (U64 × U64) := do let counter1 ← counter + 1#u64 - Result.ok (counter, counter1) + ok (counter, counter1) /- [betree::betree::{betree::betree::NodeIdCounter}::new]: Source: 'src/betree.rs', lines 206:4-208:5 -/ def betree.NodeIdCounter.new : Result betree.NodeIdCounter := - Result.ok { next_node_id := 0#u64 } + ok { next_node_id := 0#u64 } /- [betree::betree::{betree::betree::NodeIdCounter}::fresh_id]: Source: 'src/betree.rs', lines 210:4-214:5 -/ @@ -58,7 +58,7 @@ def betree.NodeIdCounter.fresh_id (self : betree.NodeIdCounter) : Result (U64 × betree.NodeIdCounter) := do let i ← self.next_node_id + 1#u64 - Result.ok (self.next_node_id, { next_node_id := i }) + ok (self.next_node_id, { next_node_id := i }) /- [betree::betree::upsert_update]: Source: 'src/betree.rs', lines 234:0-273:1 -/ @@ -67,8 +67,8 @@ def betree.upsert_update match prev with | none => match st with - | betree.UpsertFunState.Add v => Result.ok v - | betree.UpsertFunState.Sub _ => Result.ok 0#u64 + | betree.UpsertFunState.Add v => ok v + | betree.UpsertFunState.Sub _ => ok 0#u64 | some prev1 => match st with | betree.UpsertFunState.Add v => @@ -76,11 +76,10 @@ def betree.upsert_update let margin ← core_u64_max - prev1 if margin >= v then prev1 + v - else 
Result.ok core_u64_max - | betree.UpsertFunState.Sub v => - if prev1 >= v - then prev1 - v - else Result.ok 0#u64 + else ok core_u64_max + | betree.UpsertFunState.Sub v => if prev1 >= v + then prev1 - v + else ok 0#u64 /- [betree::betree::{betree::betree::List}#1::len]: loop 0: Source: 'src/betree.rs', lines 279:8-282:9 -/ @@ -91,10 +90,11 @@ divergent def betree.List.len_loop do let len1 ← len + 1#u64 betree.List.len_loop tl len1 - | betree.List.Nil => Result.ok len + | betree.List.Nil => ok len /- [betree::betree::{betree::betree::List}#1::len]: Source: 'src/betree.rs', lines 276:4-284:5 -/ +@[reducible] def betree.List.len {T : Type} (self : betree.List T) : Result U64 := betree.List.len_loop self 0#u64 @@ -107,10 +107,11 @@ divergent def betree.List.reverse_loop match self with | betree.List.Cons hd tl => betree.List.reverse_loop tl (betree.List.Cons hd out) - | betree.List.Nil => Result.ok out + | betree.List.Nil => ok out /- [betree::betree::{betree::betree::List}#1::reverse]: Source: 'src/betree.rs', lines 304:4-312:5 -/ +@[reducible] def betree.List.reverse {T : Type} (self : betree.List T) : Result (betree.List T) := betree.List.reverse_loop self betree.List.Nil @@ -128,13 +129,14 @@ divergent def betree.List.split_at_loop do let n1 ← n - 1#u64 betree.List.split_at_loop n1 (betree.List.Cons hd beg) tl - | betree.List.Nil => Result.fail .panic + | betree.List.Nil => fail panic else do let l ← betree.List.reverse beg - Result.ok (l, self) + ok (l, self) /- [betree::betree::{betree::betree::List}#1::split_at]: Source: 'src/betree.rs', lines 287:4-302:5 -/ +@[reducible] def betree.List.split_at {T : Type} (self : betree.List T) (n : U64) : Result ((betree.List T) × (betree.List T)) @@ -146,7 +148,7 @@ def betree.List.split_at def betree.List.push_front {T : Type} (self : betree.List T) (x : T) : Result (betree.List T) := let (tl, _) := core.mem.replace self betree.List.Nil - Result.ok (betree.List.Cons x tl) + ok (betree.List.Cons x tl) /- [betree::betree::{betree::betree::List}#1::pop_front]: Source: 'src/betree.rs', lines 322:4-332:5 -/ @@ -154,15 +156,15 @@ def betree.List.pop_front {T : Type} (self : betree.List T) : Result (T × (betree.List T)) := let (ls, _) := core.mem.replace self betree.List.Nil match ls with - | betree.List.Cons x tl => Result.ok (x, tl) - | betree.List.Nil => Result.fail .panic + | betree.List.Cons x tl => ok (x, tl) + | betree.List.Nil => fail panic /- [betree::betree::{betree::betree::List}#1::hd]: Source: 'src/betree.rs', lines 334:4-339:5 -/ def betree.List.hd {T : Type} (self : betree.List T) : Result T := match self with - | betree.List.Cons hd _ => Result.ok hd - | betree.List.Nil => Result.fail .panic + | betree.List.Cons hd _ => ok hd + | betree.List.Nil => fail panic /- [betree::betree::{betree::betree::List<(u64, T)>}#2::head_has_key]: Source: 'src/betree.rs', lines 343:4-348:5 -/ @@ -170,8 +172,8 @@ def betree.ListPairU64T.head_has_key {T : Type} (self : betree.List (U64 × T)) (key : U64) : Result Bool := match self with | betree.List.Cons hd _ => let (i, _) := hd - Result.ok (i = key) - | betree.List.Nil => Result.ok false + ok (i = key) + | betree.List.Nil => ok false /- [betree::betree::{betree::betree::List<(u64, T)>}#2::partition_at_pivot]: loop 0: Source: 'src/betree.rs', lines 359:8-368:9 -/ @@ -194,10 +196,11 @@ divergent def betree.ListPairU64T.partition_at_pivot_loop do let l ← betree.List.reverse beg let l1 ← betree.List.reverse end1 - Result.ok (l, l1) + ok (l, l1) /- [betree::betree::{betree::betree::List<(u64, 
T)>}#2::partition_at_pivot]: Source: 'src/betree.rs', lines 355:4-370:5 -/ +@[reducible] def betree.ListPairU64T.partition_at_pivot {T : Type} (self : betree.List (U64 × T)) (pivot : U64) : Result ((betree.List (U64 × T)) × (betree.List (U64 × T))) @@ -213,17 +216,15 @@ def betree.Leaf.split Result (State × (betree.Internal × betree.NodeIdCounter)) := do - let p ← betree.List.split_at content params.split_size - let (content0, content1) := p - let p1 ← betree.List.hd content1 - let (pivot, _) := p1 + let (content0, content1) ← betree.List.split_at content params.split_size + let (pivot, _) ← betree.List.hd content1 let (id0, node_id_cnt1) ← betree.NodeIdCounter.fresh_id node_id_cnt let (id1, node_id_cnt2) ← betree.NodeIdCounter.fresh_id node_id_cnt1 let (st1, _) ← betree.store_leaf_node id0 content0 st let (st2, _) ← betree.store_leaf_node id1 content1 st1 let n := betree.Node.Leaf { id := id0, size := params.split_size } let n1 := betree.Node.Leaf { id := id1, size := params.split_size } - Result.ok (st2, (betree.Internal.mk self.id pivot n n1, node_id_cnt2)) + ok (st2, (betree.Internal.mk self.id pivot n n1, node_id_cnt2)) /- [betree::betree::{betree::betree::Node}#5::lookup_in_bindings]: loop 0: Source: 'src/betree.rs', lines 650:8-660:5 -/ @@ -233,12 +234,12 @@ divergent def betree.Node.lookup_in_bindings_loop | betree.List.Cons hd tl => let (i, i1) := hd if i = key - then Result.ok (some i1) + then ok (some i1) else if i > key - then Result.ok none + then ok none else betree.Node.lookup_in_bindings_loop key tl - | betree.List.Nil => Result.ok none + | betree.List.Nil => ok none /- [betree::betree::{betree::betree::Node}#5::lookup_in_bindings]: Source: 'src/betree.rs', lines 649:4-660:5 -/ @@ -258,7 +259,7 @@ divergent def betree.Node.lookup_first_message_for_key_loop | betree.List.Cons x next_msgs => let (i, _) := x if i >= key - then Result.ok (msgs, fun ret => ret) + then ok (msgs, fun ret => ret) else do let (l, back) ← @@ -266,8 +267,8 @@ divergent def betree.Node.lookup_first_message_for_key_loop let back1 := fun ret => let next_msgs1 := back ret betree.List.Cons x next_msgs1 - Result.ok (l, back1) - | betree.List.Nil => Result.ok (betree.List.Nil, fun ret => ret) + ok (l, back1) + | betree.List.Nil => ok (betree.List.Nil, fun ret => ret) /- [betree::betree::{betree::betree::Node}#5::lookup_first_message_for_key]: Source: 'src/betree.rs', lines 792:4-810:5 -/ @@ -294,8 +295,8 @@ divergent def betree.Node.apply_upserts_loop let (msg, msgs1) ← betree.List.pop_front msgs let (_, m) := msg match m with - | betree.Message.Insert _ => Result.fail .panic - | betree.Message.Delete => Result.fail .panic + | betree.Message.Insert _ => fail panic + | betree.Message.Delete => fail panic | betree.Message.Upsert s => do let v ← betree.upsert_update prev s @@ -304,7 +305,7 @@ divergent def betree.Node.apply_upserts_loop do let v ← core.option.Option.unwrap prev let msgs1 ← betree.List.push_front msgs (key, betree.Message.Insert v) - Result.ok (v, msgs1) + ok (v, msgs1) /- [betree::betree::{betree::betree::Node}#5::apply_upserts]: Source: 'src/betree.rs', lines 820:4-844:5 -/ @@ -326,11 +327,11 @@ mutual divergent def betree.Internal.lookup_in_children then do let (st1, (o, n)) ← betree.Node.lookup self.left key st - Result.ok (st1, (o, betree.Internal.mk self.id self.pivot n self.right)) + ok (st1, (o, betree.Internal.mk self.id self.pivot n self.right)) else do let (st1, (o, n)) ← betree.Node.lookup self.right key st - Result.ok (st1, (o, betree.Internal.mk self.id self.pivot self.left n)) + 
ok (st1, (o, betree.Internal.mk self.id self.pivot self.left n)) /- [betree::betree::{betree::betree::Node}#5::lookup]: Source: 'src/betree.rs', lines 712:4-785:5 -/ @@ -352,11 +353,11 @@ divergent def betree.Node.lookup do let (st2, (o, node1)) ← betree.Internal.lookup_in_children node key st1 - Result.ok (st2, (o, betree.Node.Internal node1)) + ok (st2, (o, betree.Node.Internal node1)) else match msg with - | betree.Message.Insert v => Result.ok (st1, (some v, self)) - | betree.Message.Delete => Result.ok (st1, (none, self)) + | betree.Message.Insert v => ok (st1, (some v, self)) + | betree.Message.Delete => ok (st1, (none, self)) | betree.Message.Upsert _ => do let (st2, (v, node1)) ← @@ -364,16 +365,16 @@ divergent def betree.Node.lookup let (v1, pending1) ← betree.Node.apply_upserts pending v key let msgs1 := lookup_first_message_for_key_back pending1 let (st3, _) ← betree.store_internal_node node1.id msgs1 st2 - Result.ok (st3, (some v1, betree.Node.Internal node1)) + ok (st3, (some v1, betree.Node.Internal node1)) | betree.List.Nil => do let (st2, (o, node1)) ← betree.Internal.lookup_in_children node key st1 - Result.ok (st2, (o, betree.Node.Internal node1)) + ok (st2, (o, betree.Node.Internal node1)) | betree.Node.Leaf node => do let (st1, bindings) ← betree.load_leaf_node node.id st let o ← betree.Node.lookup_in_bindings key bindings - Result.ok (st1, (o, self)) + ok (st1, (o, self)) end @@ -391,8 +392,8 @@ divergent def betree.Node.filter_messages_for_key_loop do let (_, msgs1) ← betree.List.pop_front msgs betree.Node.filter_messages_for_key_loop key msgs1 - else Result.ok msgs - | betree.List.Nil => Result.ok betree.List.Nil + else ok msgs + | betree.List.Nil => ok betree.List.Nil /- [betree::betree::{betree::betree::Node}#5::filter_messages_for_key]: Source: 'src/betree.rs', lines 683:4-692:5 -/ @@ -421,9 +422,9 @@ divergent def betree.Node.lookup_first_message_after_key_loop let back1 := fun ret => let next_msgs1 := back ret betree.List.Cons p next_msgs1 - Result.ok (l, back1) - else Result.ok (msgs, fun ret => ret) - | betree.List.Nil => Result.ok (betree.List.Nil, fun ret => ret) + ok (l, back1) + else ok (msgs, fun ret => ret) + | betree.List.Nil => ok (betree.List.Nil, fun ret => ret) /- [betree::betree::{betree::betree::Node}#5::lookup_first_message_after_key]: Source: 'src/betree.rs', lines 694:4-706:5 -/ @@ -453,16 +454,15 @@ def betree.Node.apply_to_internal do let msgs2 ← betree.Node.filter_messages_for_key key msgs1 let msgs3 ← betree.List.push_front msgs2 (key, new_msg) - Result.ok (lookup_first_message_for_key_back msgs3) + ok (lookup_first_message_for_key_back msgs3) | betree.Message.Delete => do let msgs2 ← betree.Node.filter_messages_for_key key msgs1 let msgs3 ← betree.List.push_front msgs2 (key, betree.Message.Delete) - Result.ok (lookup_first_message_for_key_back msgs3) + ok (lookup_first_message_for_key_back msgs3) | betree.Message.Upsert s => do - let p ← betree.List.hd msgs1 - let (_, m) := p + let (_, m) ← betree.List.hd msgs1 match m with | betree.Message.Insert prev => do @@ -470,25 +470,25 @@ def betree.Node.apply_to_internal let (_, msgs2) ← betree.List.pop_front msgs1 let msgs3 ← betree.List.push_front msgs2 (key, betree.Message.Insert v) - Result.ok (lookup_first_message_for_key_back msgs3) + ok (lookup_first_message_for_key_back msgs3) | betree.Message.Delete => do let (_, msgs2) ← betree.List.pop_front msgs1 let v ← betree.upsert_update none s let msgs3 ← betree.List.push_front msgs2 (key, betree.Message.Insert v) - Result.ok 
(lookup_first_message_for_key_back msgs3) + ok (lookup_first_message_for_key_back msgs3) | betree.Message.Upsert _ => do let (msgs2, lookup_first_message_after_key_back) ← betree.Node.lookup_first_message_after_key key msgs1 let msgs3 ← betree.List.push_front msgs2 (key, new_msg) let msgs4 := lookup_first_message_after_key_back msgs3 - Result.ok (lookup_first_message_for_key_back msgs4) + ok (lookup_first_message_for_key_back msgs4) else do let msgs2 ← betree.List.push_front msgs1 (key, new_msg) - Result.ok (lookup_first_message_for_key_back msgs2) + ok (lookup_first_message_for_key_back msgs2) /- [betree::betree::{betree::betree::Node}#5::apply_messages_to_internal]: loop 0: Source: 'src/betree.rs', lines 522:8-525:9 -/ @@ -503,7 +503,7 @@ divergent def betree.Node.apply_messages_to_internal_loop let (i, m) := new_msg let msgs1 ← betree.Node.apply_to_internal msgs i m betree.Node.apply_messages_to_internal_loop msgs1 new_msgs_tl - | betree.List.Nil => Result.ok msgs + | betree.List.Nil => ok msgs /- [betree::betree::{betree::betree::Node}#5::apply_messages_to_internal]: Source: 'src/betree.rs', lines 518:4-526:5 -/ @@ -526,14 +526,14 @@ divergent def betree.Node.lookup_mut_in_bindings_loop | betree.List.Cons hd tl => let (i, _) := hd if i >= key - then Result.ok (bindings, fun ret => ret) + then ok (bindings, fun ret => ret) else do let (l, back) ← betree.Node.lookup_mut_in_bindings_loop key tl let back1 := fun ret => let tl1 := back ret betree.List.Cons hd tl1 - Result.ok (l, back1) - | betree.List.Nil => Result.ok (betree.List.Nil, fun ret => ret) + ok (l, back1) + | betree.List.Nil => ok (betree.List.Nil, fun ret => ret) /- [betree::betree::{betree::betree::Node}#5::lookup_mut_in_bindings]: Source: 'src/betree.rs', lines 664:4-677:5 -/ @@ -564,28 +564,26 @@ def betree.Node.apply_to_leaf | betree.Message.Insert v => do let bindings3 ← betree.List.push_front bindings2 (key, v) - Result.ok (lookup_mut_in_bindings_back bindings3) - | betree.Message.Delete => - Result.ok (lookup_mut_in_bindings_back bindings2) + ok (lookup_mut_in_bindings_back bindings3) + | betree.Message.Delete => ok (lookup_mut_in_bindings_back bindings2) | betree.Message.Upsert s => do let (_, i) := hd let v ← betree.upsert_update (some i) s let bindings3 ← betree.List.push_front bindings2 (key, v) - Result.ok (lookup_mut_in_bindings_back bindings3) + ok (lookup_mut_in_bindings_back bindings3) else match new_msg with | betree.Message.Insert v => do let bindings2 ← betree.List.push_front bindings1 (key, v) - Result.ok (lookup_mut_in_bindings_back bindings2) - | betree.Message.Delete => - Result.ok (lookup_mut_in_bindings_back bindings1) + ok (lookup_mut_in_bindings_back bindings2) + | betree.Message.Delete => ok (lookup_mut_in_bindings_back bindings1) | betree.Message.Upsert s => do let v ← betree.upsert_update none s let bindings2 ← betree.List.push_front bindings1 (key, v) - Result.ok (lookup_mut_in_bindings_back bindings2) + ok (lookup_mut_in_bindings_back bindings2) /- [betree::betree::{betree::betree::Node}#5::apply_messages_to_leaf]: loop 0: Source: 'src/betree.rs', lines 467:8-470:9 -/ @@ -600,7 +598,7 @@ divergent def betree.Node.apply_messages_to_leaf_loop let (i, m) := new_msg let bindings1 ← betree.Node.apply_to_leaf bindings i m betree.Node.apply_messages_to_leaf_loop bindings1 new_msgs_tl - | betree.List.Nil => Result.ok bindings + | betree.List.Nil => ok bindings /- [betree::betree::{betree::betree::Node}#5::apply_messages_to_leaf]: Source: 'src/betree.rs', lines 463:4-471:5 -/ @@ -622,35 +620,35 @@ mutual 
divergent def betree.Internal.flush × betree.NodeIdCounter))) := do - let p ← betree.ListPairU64T.partition_at_pivot content self.pivot - let (msgs_left, msgs_right) := p + let (msgs_left, msgs_right) ← + betree.ListPairU64T.partition_at_pivot content self.pivot let len_left ← betree.List.len msgs_left if len_left >= params.min_flush_size then do - let (st1, p1) ← + let (st1, p) ← betree.Node.apply_messages self.left params node_id_cnt msgs_left st - let (n, node_id_cnt1) := p1 + let (n, node_id_cnt1) := p let len_right ← betree.List.len msgs_right if len_right >= params.min_flush_size then do - let (st2, p2) ← + let (st2, p1) ← betree.Node.apply_messages self.right params node_id_cnt1 msgs_right st1 - let (n1, node_id_cnt2) := p2 - Result.ok (st2, (betree.List.Nil, (betree.Internal.mk self.id self.pivot - n n1, node_id_cnt2))) + let (n1, node_id_cnt2) := p1 + ok (st2, (betree.List.Nil, (betree.Internal.mk self.id self.pivot n n1, + node_id_cnt2))) else - Result.ok (st1, (msgs_right, (betree.Internal.mk self.id self.pivot n + ok (st1, (msgs_right, (betree.Internal.mk self.id self.pivot n self.right, node_id_cnt1))) else do - let (st1, p1) ← + let (st1, p) ← betree.Node.apply_messages self.right params node_id_cnt msgs_right st - let (n, node_id_cnt1) := p1 - Result.ok (st1, (msgs_left, (betree.Internal.mk self.id self.pivot - self.left n, node_id_cnt1))) + let (n, node_id_cnt1) := p + ok (st1, (msgs_left, (betree.Internal.mk self.id self.pivot self.left n, + node_id_cnt1))) /- [betree::betree::{betree::betree::Node}#5::apply_messages]: Source: 'src/betree.rs', lines 601:4-645:5 -/ @@ -673,11 +671,11 @@ divergent def betree.Node.apply_messages betree.Internal.flush node params node_id_cnt content1 st1 let (node1, node_id_cnt1) := p let (st3, _) ← betree.store_internal_node node1.id content2 st2 - Result.ok (st3, (betree.Node.Internal node1, node_id_cnt1)) + ok (st3, (betree.Node.Internal node1, node_id_cnt1)) else do let (st2, _) ← betree.store_internal_node node.id content1 st1 - Result.ok (st2, (self, node_id_cnt)) + ok (st2, (self, node_id_cnt)) | betree.Node.Leaf node => do let (st1, content) ← betree.load_leaf_node node.id st @@ -690,12 +688,11 @@ divergent def betree.Node.apply_messages let (st2, (new_node, node_id_cnt1)) ← betree.Leaf.split node content1 params node_id_cnt st1 let (st3, _) ← betree.store_leaf_node node.id betree.List.Nil st2 - Result.ok (st3, (betree.Node.Internal new_node, node_id_cnt1)) + ok (st3, (betree.Node.Internal new_node, node_id_cnt1)) else do let (st2, _) ← betree.store_leaf_node node.id content1 st1 - Result.ok (st2, (betree.Node.Leaf { node with size := len }, - node_id_cnt)) + ok (st2, (betree.Node.Leaf { node with size := len }, node_id_cnt)) end @@ -720,7 +717,7 @@ def betree.BeTree.new let node_id_cnt ← betree.NodeIdCounter.new let (id, node_id_cnt1) ← betree.NodeIdCounter.fresh_id node_id_cnt let (st1, _) ← betree.store_leaf_node id betree.List.Nil st - Result.ok (st1, + ok (st1, { params := { min_flush_size, split_size }, node_id_cnt := node_id_cnt1, @@ -737,7 +734,7 @@ def betree.BeTree.apply let (st1, p) ← betree.Node.apply self.root self.params self.node_id_cnt key msg st let (n, nic) := p - Result.ok (st1, { self with node_id_cnt := nic, root := n }) + ok (st1, { self with node_id_cnt := nic, root := n }) /- [betree::betree::{betree::betree::BeTree}#6::insert]: Source: 'src/betree.rs', lines 873:4-876:5 -/ @@ -772,14 +769,14 @@ def betree.BeTree.lookup := do let (st1, (o, n)) ← betree.Node.lookup self.root key st - Result.ok (st1, (o, { self 
with root := n })) + ok (st1, (o, { self with root := n })) /- [betree::main]: Source: 'src/main.rs', lines 4:0-4:12 -/ def main : Result Unit := - Result.ok () + ok () /- Unit test for [betree::main] -/ -#assert (main == Result.ok ()) +#assert (main == ok ()) end betree diff --git a/tests/lean/Betree/FunsExternal_Template.lean b/tests/lean/Betree/FunsExternal_Template.lean index 4469c7c1..a8fbb0e5 100644 --- a/tests/lean/Betree/FunsExternal_Template.lean +++ b/tests/lean/Betree/FunsExternal_Template.lean @@ -3,7 +3,7 @@ -- This is a template file: rename it to "FunsExternal.lean" and fill the holes. import Aeneas import Betree.Types -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/Betree/Types.lean b/tests/lean/Betree/Types.lean index e111c4cf..ec232eba 100644 --- a/tests/lean/Betree/Types.lean +++ b/tests/lean/Betree/Types.lean @@ -2,7 +2,7 @@ -- [betree]: type definitions import Aeneas import Betree.TypesExternal -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/Betree/TypesExternal_Template.lean b/tests/lean/Betree/TypesExternal_Template.lean index c59e5386..63ef6ead 100644 --- a/tests/lean/Betree/TypesExternal_Template.lean +++ b/tests/lean/Betree/TypesExternal_Template.lean @@ -2,7 +2,7 @@ -- [betree]: external types. -- This is a template file: rename it to "TypesExternal.lean" and fill the holes. import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/Bitwise.lean b/tests/lean/Bitwise.lean index 577629ca..1fd4d6c5 100644 --- a/tests/lean/Bitwise.lean +++ b/tests/lean/Bitwise.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [bitwise] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -25,16 +25,16 @@ def shift_i32 (a : I32) : Result I32 := /- [bitwise::xor_u32]: Source: 'tests/src/bitwise.rs', lines 19:0-21:1 -/ def xor_u32 (a : U32) (b : U32) : Result U32 := - Result.ok (a ^^^ b) + ok (a ^^^ b) /- [bitwise::or_u32]: Source: 'tests/src/bitwise.rs', lines 23:0-25:1 -/ def or_u32 (a : U32) (b : U32) : Result U32 := - Result.ok (a ||| b) + ok (a ||| b) /- [bitwise::and_u32]: Source: 'tests/src/bitwise.rs', lines 27:0-29:1 -/ def and_u32 (a : U32) (b : U32) : Result U32 := - Result.ok (a &&& b) + ok (a &&& b) end bitwise diff --git a/tests/lean/Bst/Funs.lean b/tests/lean/Bst/Funs.lean index 8b2e5aab..ddd05e8a 100644 --- a/tests/lean/Bst/Funs.lean +++ b/tests/lean/Bst/Funs.lean @@ -2,7 +2,7 @@ -- [bst]: function definitions import Aeneas import Bst.Types -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -12,7 +12,7 @@ namespace bst /- [bst::{bst::TreeSet}::new]: Source: 'src/bst.rs', lines 28:4-30:5 -/ def TreeSet.new {T : Type} (OrdInst : Ord T) : Result (TreeSet T) := - Result.ok { root := none } + ok { root := none } /- [bst::{bst::TreeSet}::find]: loop 0: Source: 'src/bst.rs', lines 35:8-44:5 -/ @@ -21,17 +21,18 @@ divergent def TreeSet.find_loop Result Bool := match current_tree with - | none => Result.ok false + | none 
=> ok false | some current_node => do let o ← OrdInst.cmp current_node.value value match o with | Ordering.Less => TreeSet.find_loop OrdInst value current_node.right - | Ordering.Equal => Result.ok true + | Ordering.Equal => ok true | Ordering.Greater => TreeSet.find_loop OrdInst value current_node.left /- [bst::{bst::TreeSet}::find]: Source: 'src/bst.rs', lines 32:4-44:5 -/ +@[reducible] def TreeSet.find {T : Type} (OrdInst : Ord T) (self : TreeSet T) (value : T) : Result Bool := TreeSet.find_loop OrdInst value self.root @@ -44,7 +45,7 @@ divergent def TreeSet.insert_loop := match current_tree with | none => let n := Node.mk value none none - Result.ok (true, some n) + ok (true, some n) | some current_node => do let o ← OrdInst.cmp current_node.value value @@ -53,14 +54,13 @@ divergent def TreeSet.insert_loop do let (b, current_tree1) ← TreeSet.insert_loop OrdInst value current_node.right - Result.ok (b, some (Node.mk current_node.value current_node.left - current_tree1)) - | Ordering.Equal => Result.ok (false, current_tree) + ok (b, some (Node.mk current_node.value current_node.left current_tree1)) + | Ordering.Equal => ok (false, current_tree) | Ordering.Greater => do let (b, current_tree1) ← TreeSet.insert_loop OrdInst value current_node.left - Result.ok (b, some (Node.mk current_node.value current_tree1 + ok (b, some (Node.mk current_node.value current_tree1 current_node.right)) /- [bst::{bst::TreeSet}::insert]: @@ -71,6 +71,6 @@ def TreeSet.insert := do let (b, ts) ← TreeSet.insert_loop OrdInst value self.root - Result.ok (b, { root := ts }) + ok (b, { root := ts }) end bst diff --git a/tests/lean/Bst/Types.lean b/tests/lean/Bst/Types.lean index 8e633753..35c37339 100644 --- a/tests/lean/Bst/Types.lean +++ b/tests/lean/Bst/Types.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [bst]: type definitions import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/Builtin.lean b/tests/lean/Builtin.lean index 588f525e..fe1ddfef 100644 --- a/tests/lean/Builtin.lean +++ b/tests/lean/Builtin.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [builtin] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -11,12 +11,12 @@ namespace builtin /- [builtin::clone_bool]: Source: 'tests/src/builtin.rs', lines 6:0-8:1 -/ def clone_bool (x : Bool) : Result Bool := - Result.ok (core.clone.impls.CloneBool.clone x) + ok (core.clone.impls.CloneBool.clone x) /- [builtin::clone_u32]: Source: 'tests/src/builtin.rs', lines 10:0-12:1 -/ def clone_u32 (x : U32) : Result U32 := - Result.ok (core.clone.impls.CloneU32.clone x) + ok (core.clone.impls.CloneU32.clone x) /- [builtin::into_from]: Source: 'tests/src/builtin.rs', lines 14:0-16:1 -/ @@ -34,28 +34,28 @@ def into_same {T : Type} (x : T) : Result T := /- [builtin::from_same]: Source: 'tests/src/builtin.rs', lines 22:0-24:1 -/ def from_same {T : Type} (x : T) : Result T := - Result.ok (core.convert.FromSame.from_ x) + ok (core.convert.FromSame.from_ x) /- [builtin::copy]: Source: 'tests/src/builtin.rs', lines 26:0-28:1 -/ def copy {T : Type} (coremarkerCopyInst : core.marker.Copy T) (x : T) : Result T := - Result.ok x + ok x /- [builtin::u32_from_le_bytes]: Source: 'tests/src/builtin.rs', lines 30:0-32:1 -/ def u32_from_le_bytes (x : Array U8 4#usize) : 
Result U32 := - Result.ok (core.num.U32.from_le_bytes x) + ok (core.num.U32.from_le_bytes x) /- [builtin::u32_to_le_bytes]: Source: 'tests/src/builtin.rs', lines 34:0-36:1 -/ def u32_to_le_bytes (x : U32) : Result (Array U8 4#usize) := - Result.ok (core.num.U32.to_le_bytes x) + ok (core.num.U32.to_le_bytes x) /- [builtin::use_debug_clause]: Source: 'tests/src/builtin.rs', lines 38:0-38:49 -/ def use_debug_clause {T : Type} (corefmtDebugInst : core.fmt.Debug T) (t : T) : Result Unit := - Result.ok () + ok () end builtin diff --git a/tests/lean/Constants.lean b/tests/lean/Constants.lean index 74e6603d..550da8bd 100644 --- a/tests/lean/Constants.lean +++ b/tests/lean/Constants.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [constants] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -10,17 +10,17 @@ namespace constants /- [constants::X0] Source: 'tests/src/constants.rs', lines 8:0-8:22 -/ -def X0_body : Result U32 := Result.ok 0#u32 +def X0_body : Result U32 := ok 0#u32 def X0 : U32 := eval_global X0_body /- [constants::X1] Source: 'tests/src/constants.rs', lines 10:0-10:29 -/ -def X1_body : Result U32 := Result.ok core_u32_max +def X1_body : Result U32 := ok core_u32_max def X1 : U32 := eval_global X1_body /- [constants::X2] Source: 'tests/src/constants.rs', lines 13:0-16:2 -/ -def X2_body : Result U32 := Result.ok 3#u32 +def X2_body : Result U32 := ok 3#u32 def X2 : U32 := eval_global X2_body /- [constants::incr]: @@ -36,7 +36,7 @@ def X3 : U32 := eval_global X3_body /- [constants::mk_pair0]: Source: 'tests/src/constants.rs', lines 26:0-28:1 -/ def mk_pair0 (x : U32) (y : U32) : Result (U32 × U32) := - Result.ok (x, y) + ok (x, y) /- [constants::Pair] Source: 'tests/src/constants.rs', lines 39:0-42:1 -/ @@ -47,7 +47,7 @@ structure Pair (T1 : Type) (T2 : Type) where /- [constants::mk_pair1]: Source: 'tests/src/constants.rs', lines 30:0-32:1 -/ def mk_pair1 (x : U32) (y : U32) : Result (Pair U32 U32) := - Result.ok { x, y } + ok { x, y } /- [constants::P0] Source: 'tests/src/constants.rs', lines 34:0-34:42 -/ @@ -61,12 +61,12 @@ def P1 : Pair U32 U32 := eval_global P1_body /- [constants::P2] Source: 'tests/src/constants.rs', lines 36:0-36:34 -/ -def P2_body : Result (U32 × U32) := Result.ok (0#u32, 1#u32) +def P2_body : Result (U32 × U32) := ok (0#u32, 1#u32) def P2 : (U32 × U32) := eval_global P2_body /- [constants::P3] Source: 'tests/src/constants.rs', lines 37:0-37:51 -/ -def P3_body : Result (Pair U32 U32) := Result.ok { x := 0#u32, y := 1#u32 } +def P3_body : Result (Pair U32 U32) := ok { x := 0#u32, y := 1#u32 } def P3 : Pair U32 U32 := eval_global P3_body /- [constants::Wrap] @@ -77,7 +77,7 @@ structure Wrap (T : Type) where /- [constants::{constants::Wrap}::new]: Source: 'tests/src/constants.rs', lines 57:4-59:5 -/ def Wrap.new {T : Type} (value : T) : Result (Wrap T) := - Result.ok { value } + ok { value } /- [constants::Y] Source: 'tests/src/constants.rs', lines 44:0-44:38 -/ @@ -87,7 +87,7 @@ def Y : Wrap I32 := eval_global Y_body /- [constants::unwrap_y]: Source: 'tests/src/constants.rs', lines 46:0-48:1 -/ def unwrap_y : Result I32 := - Result.ok Y.value + ok Y.value /- [constants::YVAL] Source: 'tests/src/constants.rs', lines 50:0-50:33 -/ @@ -96,13 +96,13 @@ def YVAL : I32 := eval_global YVAL_body /- [constants::get_z1::Z1] Source: 'tests/src/constants.rs', lines 65:4-65:22 -/ -def get_z1.Z1_body : Result I32 := Result.ok 
3#i32 +def get_z1.Z1_body : Result I32 := ok 3#i32 def get_z1.Z1 : I32 := eval_global get_z1.Z1_body /- [constants::get_z1]: Source: 'tests/src/constants.rs', lines 64:0-67:1 -/ def get_z1 : Result I32 := - Result.ok get_z1.Z1 + ok get_z1.Z1 /- [constants::add]: Source: 'tests/src/constants.rs', lines 69:0-71:1 -/ @@ -111,12 +111,12 @@ def add (a : I32) (b : I32) : Result I32 := /- [constants::Q1] Source: 'tests/src/constants.rs', lines 77:0-77:22 -/ -def Q1_body : Result I32 := Result.ok 5#i32 +def Q1_body : Result I32 := ok 5#i32 def Q1 : I32 := eval_global Q1_body /- [constants::Q2] Source: 'tests/src/constants.rs', lines 78:0-78:23 -/ -def Q2_body : Result I32 := Result.ok Q1 +def Q2_body : Result I32 := ok Q1 def Q2 : I32 := eval_global Q2_body /- [constants::Q3] @@ -134,7 +134,7 @@ def get_z2 : Result I32 := /- [constants::S1] Source: 'tests/src/constants.rs', lines 83:0-83:23 -/ -def S1_body : Result U32 := Result.ok 6#u32 +def S1_body : Result U32 := ok 6#u32 def S1 : U32 := eval_global S1_body /- [constants::S2] @@ -144,7 +144,7 @@ def S2 : U32 := eval_global S2_body /- [constants::S3] Source: 'tests/src/constants.rs', lines 85:0-85:35 -/ -def S3_body : Result (Pair U32 U32) := Result.ok P3 +def S3_body : Result (Pair U32 U32) := ok P3 def S3 : Pair U32 U32 := eval_global S3_body /- [constants::S4] @@ -159,12 +159,12 @@ structure V (T : Type) (N : Usize) where /- [constants::{constants::V}#1::LEN] Source: 'tests/src/constants.rs', lines 94:4-94:29 -/ -def V.LEN_body (T : Type) (N : Usize) : Result Usize := Result.ok N +def V.LEN_body (T : Type) (N : Usize) : Result Usize := ok N def V.LEN (T : Type) (N : Usize) : Usize := eval_global (V.LEN_body T N) /- [constants::use_v]: Source: 'tests/src/constants.rs', lines 97:0-99:1 -/ def use_v (T : Type) (N : Usize) : Result Usize := - Result.ok (V.LEN T N) + ok (V.LEN T N) end constants diff --git a/tests/lean/Demo/Demo.lean b/tests/lean/Demo/Demo.lean index 270bf34a..ebdbe2e8 100644 --- a/tests/lean/Demo/Demo.lean +++ b/tests/lean/Demo/Demo.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [demo] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -14,9 +14,9 @@ def choose {T : Type} (b : Bool) (x : T) (y : T) : Result (T × (T → (T × T))) := if b then let back := fun ret => (ret, y) - Result.ok (x, back) + ok (x, back) else let back := fun ret => (x, ret) - Result.ok (y, back) + ok (y, back) /- [demo::mul2_add1]: Source: 'tests/src/demo.rs', lines 15:0-17:1 -/ @@ -44,7 +44,7 @@ def use_incr : Result Unit := let x ← incr 0#u32 let x1 ← incr x let _ ← incr x1 - Result.ok () + ok () /- [demo::CList] Source: 'tests/src/demo.rs', lines 36:0-39:1 -/ @@ -58,11 +58,11 @@ divergent def list_nth {T : Type} (l : CList T) (i : U32) : Result T := match l with | CList.CCons x tl => if i = 0#u32 - then Result.ok x + then ok x else do let i1 ← i - 1#u32 list_nth tl i1 - | CList.CNil => Result.fail .panic + | CList.CNil => fail panic /- [demo::list_nth1]: loop 0: Source: 'tests/src/demo.rs', lines 57:4-65:1 -/ @@ -70,11 +70,11 @@ divergent def list_nth1_loop {T : Type} (l : CList T) (i : U32) : Result T := match l with | CList.CCons x tl => if i = 0#u32 - then Result.ok x + then ok x else do let i1 ← i - 1#u32 list_nth1_loop tl i1 - | CList.CNil => Result.fail .panic + | CList.CNil => fail panic /- [demo::list_nth1]: Source: 'tests/src/demo.rs', lines 56:0-65:1 -/ @@ -90,21 +90,21 @@ divergent def 
list_nth_mut | CList.CCons x tl => if i = 0#u32 then let back := fun ret => CList.CCons ret tl - Result.ok (x, back) + ok (x, back) else do let i1 ← i - 1#u32 let (t, list_nth_mut_back) ← list_nth_mut tl i1 let back := fun ret => let tl1 := list_nth_mut_back ret CList.CCons x tl1 - Result.ok (t, back) - | CList.CNil => Result.fail .panic + ok (t, back) + | CList.CNil => fail panic /- [demo::i32_id]: Source: 'tests/src/demo.rs', lines 82:0-88:1 -/ divergent def i32_id (i : I32) : Result I32 := if i = 0#i32 - then Result.ok 0#i32 + then ok 0#i32 else do let i1 ← i - 1#i32 let i2 ← i32_id i1 @@ -120,8 +120,8 @@ divergent def list_tail let (c, list_tail_back) ← list_tail tl let back := fun ret => let tl1 := list_tail_back ret CList.CCons t tl1 - Result.ok (c, back) - | CList.CNil => Result.ok (CList.CNil, fun ret => ret) + ok (c, back) + | CList.CNil => ok (CList.CNil, fun ret => ret) /- Trait declaration: [demo::Counter] Source: 'tests/src/demo.rs', lines 99:0-101:1 -/ @@ -133,7 +133,7 @@ structure Counter (Self : Type) where def CounterUsize.incr (self : Usize) : Result (Usize × Usize) := do let self1 ← self + 1#usize - Result.ok (self, self1) + ok (self, self1) /- Trait implementation: [demo::{demo::Counter for usize}] Source: 'tests/src/demo.rs', lines 103:0-109:1 -/ diff --git a/tests/lean/Demo/Properties.lean b/tests/lean/Demo/Properties.lean index b6e7468f..705a5577 100644 --- a/tests/lean/Demo/Properties.lean +++ b/tests/lean/Demo/Properties.lean @@ -7,7 +7,7 @@ namespace demo #check U32.add_spec --- @[pspec] +-- @[progress] theorem mul2_add1_spec (x : U32) (h : 2 * x + 1 ≤ (U32.max : Int)) : ∃ y, mul2_add1 x = ok y ∧ ↑y = 2 * ↑x + (1 : Int) @@ -17,9 +17,9 @@ theorem mul2_add1_spec (x : U32) (h : 2 * x + 1 ≤ (U32.max : Int)) progress as ⟨ i' ⟩ scalar_tac -theorem use_mul2_add1_spec (x : U32) (y : U32) (h : 2 * x + 1 + y ≤ (U32.max : Int)) : +theorem use_mul2_add1_spec (x : U32) (y : U32) (h : 2 * x + 1 + y ≤ (U32.max : Nat)) : ∃ z, use_mul2_add1 x y = ok z ∧ - ↑z = 2 * ↑x + (1 : Int) + ↑y := by + ↑z = 2 * ↑x + (1 : Nat) + ↑y := by rw [use_mul2_add1] progress with mul2_add1_spec as ⟨ i ⟩ progress as ⟨ i' ⟩ @@ -27,20 +27,20 @@ theorem use_mul2_add1_spec (x : U32) (y : U32) (h : 2 * x + 1 + y ≤ (U32.max : open CList -@[simp] def CList.to_list {α : Type} (x : CList α) : List α := +@[simp] def CList.toList {α : Type} (x : CList α) : List α := match x with | CNil => [] - | CCons hd tl => hd :: tl.to_list + | CCons hd tl => hd :: tl.toList theorem list_nth_spec {T : Type} [Inhabited T] (l : CList T) (i : U32) - (h : i.val < l.to_list.length) : + (h : i.val < l.toList.length) : ∃ x, list_nth l i = ok x ∧ - x = l.to_list.index i.toNat + x = l.toList[i.val]! 
:= by rw [list_nth] match l with | CNil => - simp_all; scalar_tac + simp_all | CCons hd tl => simp_all if hi: i = 0#u32 then @@ -65,7 +65,7 @@ decreasing_by simp_wf; scalar_tac theorem list_tail_spec {T : Type} [Inhabited T] (l : CList T) : ∃ back, list_tail l = ok (CList.CNil, back) ∧ - ∀ tl', ∃ l', back tl' = l' ∧ l'.to_list = l.to_list ++ tl'.to_list := by + ∀ tl', ∃ l', back tl' = l' ∧ l'.toList = l.toList ++ tl'.toList := by rw [list_tail] match l with | CNil => diff --git a/tests/lean/External/Funs.lean b/tests/lean/External/Funs.lean index e95dd7b0..b557053a 100644 --- a/tests/lean/External/Funs.lean +++ b/tests/lean/External/Funs.lean @@ -3,7 +3,7 @@ import Aeneas import External.Types import External.FunsExternal -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -25,6 +25,6 @@ def incr let (st1, (i, get_mut_back)) ← core.cell.Cell.get_mut rc st let i1 ← i + 1#u32 let (_, rc1) := get_mut_back i1 st1 - Result.ok (st1, rc1) + ok (st1, rc1) end external diff --git a/tests/lean/External/FunsExternal_Template.lean b/tests/lean/External/FunsExternal_Template.lean index 8c1d3dd1..5348acc3 100644 --- a/tests/lean/External/FunsExternal_Template.lean +++ b/tests/lean/External/FunsExternal_Template.lean @@ -3,7 +3,7 @@ -- This is a template file: rename it to "FunsExternal.lean" and fill the holes. import Aeneas import External.Types -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/External/Types.lean b/tests/lean/External/Types.lean index efd5d665..3ae433f4 100644 --- a/tests/lean/External/Types.lean +++ b/tests/lean/External/Types.lean @@ -2,7 +2,7 @@ -- [external]: type definitions import Aeneas import External.TypesExternal -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/External/TypesExternal_Template.lean b/tests/lean/External/TypesExternal_Template.lean index b0ffcdcb..6c2567cd 100644 --- a/tests/lean/External/TypesExternal_Template.lean +++ b/tests/lean/External/TypesExternal_Template.lean @@ -2,7 +2,7 @@ -- [external]: external types. -- This is a template file: rename it to "TypesExternal.lean" and fill the holes. 
import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/Hashmap/Funs.lean b/tests/lean/Hashmap/Funs.lean index 1f69f751..d1147a69 100644 --- a/tests/lean/Hashmap/Funs.lean +++ b/tests/lean/Hashmap/Funs.lean @@ -3,7 +3,7 @@ import Aeneas import Hashmap.Types import Hashmap.FunsExternal -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -13,12 +13,12 @@ namespace hashmap /- [hashmap::hash_key]: Source: 'tests/src/hashmap.rs', lines 38:0-43:1 -/ def hash_key (k : Usize) : Result Usize := - Result.ok k + ok k /- [hashmap::{core::clone::Clone for hashmap::Fraction}#1::clone]: Source: 'tests/src/hashmap.rs', lines 45:9-45:14 -/ def ClonehashmapFraction.clone (self : Fraction) : Result Fraction := - Result.ok self + ok self /- Trait implementation: [hashmap::{core::clone::Clone for hashmap::Fraction}#1] Source: 'tests/src/hashmap.rs', lines 45:9-45:14 -/ @@ -46,7 +46,7 @@ divergent def HashMap.allocate_slots_loop let slots1 ← alloc.vec.Vec.push slots AList.Nil let n1 ← n - 1#usize HashMap.allocate_slots_loop slots1 n1 - else Result.ok slots + else ok slots /- [hashmap::{hashmap::HashMap}::allocate_slots]: Source: 'tests/src/hashmap.rs', lines 69:4-75:5 -/ @@ -67,7 +67,7 @@ def HashMap.new_with_capacity let slots ← HashMap.allocate_slots (alloc.vec.Vec.new (AList T)) capacity let i ← capacity * max_load_factor.dividend let i1 ← i / max_load_factor.divisor - Result.ok + ok { num_entries := 0#usize, max_load_factor, @@ -98,19 +98,19 @@ divergent def HashMap.clear_loop let i2 ← i + 1#usize let slots1 := index_mut_back AList.Nil HashMap.clear_loop slots1 i2 - else Result.ok slots + else ok slots /- [hashmap::{hashmap::HashMap}::clear]: Source: 'tests/src/hashmap.rs', lines 102:4-110:5 -/ def HashMap.clear {T : Type} (self : HashMap T) : Result (HashMap T) := do let hm ← HashMap.clear_loop self.slots 0#usize - Result.ok { self with num_entries := 0#usize, slots := hm } + ok { self with num_entries := 0#usize, slots := hm } /- [hashmap::{hashmap::HashMap}::len]: Source: 'tests/src/hashmap.rs', lines 112:4-114:5 -/ def HashMap.len {T : Type} (self : HashMap T) : Result Usize := - Result.ok self.num_entries + ok self.num_entries /- [hashmap::{hashmap::HashMap}::insert_in_list]: loop 0: Source: 'tests/src/hashmap.rs', lines 1:0-135:9 -/ @@ -121,12 +121,12 @@ divergent def HashMap.insert_in_list_loop match ls with | AList.Cons ckey cvalue tl => if ckey = key - then Result.ok (false, AList.Cons ckey value tl) + then ok (false, AList.Cons ckey value tl) else do let (b, tl1) ← HashMap.insert_in_list_loop key value tl - Result.ok (b, AList.Cons ckey cvalue tl1) - | AList.Nil => Result.ok (true, AList.Cons key value AList.Nil) + ok (b, AList.Cons ckey cvalue tl1) + | AList.Nil => ok (true, AList.Cons key value AList.Nil) /- [hashmap::{hashmap::HashMap}::insert_in_list]: Source: 'tests/src/hashmap.rs', lines 119:4-136:5 -/ @@ -156,9 +156,9 @@ def HashMap.insert_no_resize do let i1 ← self.num_entries + 1#usize let v := index_mut_back a1 - Result.ok { self with num_entries := i1, slots := v } + ok { self with num_entries := i1, slots := v } else let v := index_mut_back a1 - Result.ok { self with slots := v } + ok { self with slots := v } /- [hashmap::{hashmap::HashMap}::move_elements_from_list]: loop 0: Source: 'tests/src/hashmap.rs', lines 201:12-208:17 -/ @@ 
-169,7 +169,7 @@ divergent def HashMap.move_elements_from_list_loop do let ntable1 ← HashMap.insert_no_resize ntable k v HashMap.move_elements_from_list_loop ntable1 tl - | AList.Nil => Result.ok ntable + | AList.Nil => ok ntable /- [hashmap::{hashmap::HashMap}::move_elements_from_list]: Source: 'tests/src/hashmap.rs', lines 198:4-211:5 -/ @@ -197,10 +197,11 @@ divergent def HashMap.move_elements_loop let i2 ← i + 1#usize let slots1 := index_mut_back a1 HashMap.move_elements_loop ntable1 slots1 i2 - else Result.ok (ntable, slots) + else ok (ntable, slots) /- [hashmap::{hashmap::HashMap}::move_elements]: Source: 'tests/src/hashmap.rs', lines 185:4-195:5 -/ +@[reducible] def HashMap.move_elements {T : Type} (ntable : HashMap T) (slots : alloc.vec.Vec (AList T)) : Result ((HashMap T) × (alloc.vec.Vec (AList T))) @@ -219,11 +220,9 @@ def HashMap.try_resize {T : Type} (self : HashMap T) : Result (HashMap T) := do let i1 ← capacity * 2#usize let ntable ← HashMap.new_with_capacity T i1 self.max_load_factor - let p ← HashMap.move_elements ntable self.slots - let (ntable1, _) := p - Result.ok - { self with max_load := ntable1.max_load, slots := ntable1.slots } - else Result.ok { self with saturated := true } + let (ntable1, _) ← HashMap.move_elements ntable self.slots + ok { self with max_load := ntable1.max_load, slots := ntable1.slots } + else ok { self with saturated := true } /- [hashmap::{hashmap::HashMap}::insert]: Source: 'tests/src/hashmap.rs', lines 151:4-158:5 -/ @@ -236,9 +235,9 @@ def HashMap.insert let i ← HashMap.len self1 if i > self1.max_load then if self1.saturated - then Result.ok self1 + then ok self1 else HashMap.try_resize self1 - else Result.ok self1 + else ok self1 /- [hashmap::{hashmap::HashMap}::contains_key_in_list]: loop 0: Source: 'tests/src/hashmap.rs', lines 1:0-233:9 -/ @@ -247,9 +246,9 @@ divergent def HashMap.contains_key_in_list_loop match ls with | AList.Cons ckey _ tl => if ckey = key - then Result.ok true + then ok true else HashMap.contains_key_in_list_loop key tl - | AList.Nil => Result.ok false + | AList.Nil => ok false /- [hashmap::{hashmap::HashMap}::contains_key_in_list]: Source: 'tests/src/hashmap.rs', lines 221:4-234:5 -/ @@ -278,9 +277,9 @@ divergent def HashMap.get_in_list_loop match ls with | AList.Cons ckey cvalue tl => if ckey = key - then Result.ok (some cvalue) + then ok (some cvalue) else HashMap.get_in_list_loop key tl - | AList.Nil => Result.ok none + | AList.Nil => ok none /- [hashmap::{hashmap::HashMap}::get_in_list]: Source: 'tests/src/hashmap.rs', lines 239:4-248:5 -/ @@ -318,15 +317,15 @@ divergent def HashMap.get_mut_in_list_loop | some t1 => t1 | _ => cvalue AList.Cons ckey t tl - Result.ok (some cvalue, back) + ok (some cvalue, back) else do let (o, back) ← HashMap.get_mut_in_list_loop tl key let back1 := fun ret => let tl1 := back ret AList.Cons ckey cvalue tl1 - Result.ok (o, back1) + ok (o, back1) | AList.Nil => let back := fun ret => AList.Nil - Result.ok (none, back) + ok (none, back) /- [hashmap::{hashmap::HashMap}::get_mut_in_list]: Source: 'tests/src/hashmap.rs', lines 256:4-265:5 -/ @@ -356,7 +355,7 @@ def HashMap.get_mut let a1 := get_mut_in_list_back ret let v := index_mut_back a1 { self with slots := v } - Result.ok (o, back) + ok (o, back) /- [hashmap::{hashmap::HashMap}::remove_from_list]: loop 0: Source: 'tests/src/hashmap.rs', lines 1:0-299:17 -/ @@ -368,13 +367,13 @@ divergent def HashMap.remove_from_list_loop then let (mv_ls, _) := core.mem.replace ls AList.Nil match mv_ls with - | AList.Cons _ cvalue tl1 => Result.ok 
(some cvalue, tl1) - | AList.Nil => Result.fail .panic + | AList.Cons _ cvalue tl1 => ok (some cvalue, tl1) + | AList.Nil => fail panic else do let (o, tl1) ← HashMap.remove_from_list_loop key tl - Result.ok (o, AList.Cons ckey t tl1) - | AList.Nil => Result.ok (none, AList.Nil) + ok (o, AList.Cons ckey t tl1) + | AList.Nil => ok (none, AList.Nil) /- [hashmap::{hashmap::HashMap}::remove_from_list]: Source: 'tests/src/hashmap.rs', lines 276:4-302:5 -/ @@ -398,14 +397,13 @@ def HashMap.remove T)) self.slots hash_mod let (x, a1) ← HashMap.remove_from_list key a match x with - | none => - let v := index_mut_back a1 - Result.ok (none, { self with slots := v }) + | none => let v := index_mut_back a1 + ok (none, { self with slots := v }) | some _ => do let i1 ← self.num_entries - 1#usize let v := index_mut_back a1 - Result.ok (x, { self with num_entries := i1, slots := v }) + ok (x, { self with num_entries := i1, slots := v }) /- [hashmap::insert_on_disk]: Source: 'tests/src/hashmap.rs', lines 336:0-343:1 -/ diff --git a/tests/lean/Hashmap/FunsExternal_Template.lean b/tests/lean/Hashmap/FunsExternal_Template.lean index 2d400eec..2ee452a5 100644 --- a/tests/lean/Hashmap/FunsExternal_Template.lean +++ b/tests/lean/Hashmap/FunsExternal_Template.lean @@ -3,7 +3,7 @@ -- This is a template file: rename it to "FunsExternal.lean" and fill the holes. import Aeneas import Hashmap.Types -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/Hashmap/Properties.lean b/tests/lean/Hashmap/Properties.lean index 852f317c..73f2ac96 100644 --- a/tests/lean/Hashmap/Properties.lean +++ b/tests/lean/Hashmap/Properties.lean @@ -3,9 +3,6 @@ import Hashmap.Funs open Aeneas.Std open Result ---set_option profiler true ---set_option profiler.threshold 1 - namespace hashmap namespace AList @@ -29,42 +26,42 @@ namespace HashMap def distinct_keys (ls : List (Usize × α)) := ls.pairwise_rel (λ x y => x.fst ≠ y.fst) -def hash_mod_key (k : Usize) (l : Int) : Int := +def hash_mod_key (k : Usize) (l : Nat) : Nat := match hash_key k with | .ok k => k.val % l | _ => 0 @[simp] theorem hash_mod_key_eq : hash_mod_key k l = k.val % l := by - simp [hash_mod_key, hash_key] + fsimp [hash_mod_key, hash_key] -def slot_s_inv_hash (l i : Int) (ls : List (Usize × α)) : Prop := +def slot_s_inv_hash (l i : Nat) (ls : List (Usize × α)) : Prop := ls.allP (λ (k, _) => hash_mod_key k l = i) -def slot_s_inv (l i : Int) (ls : List (Usize × α)) : Prop := +def slot_s_inv (l i : Nat) (ls : List (Usize × α)) : Prop := distinct_keys ls ∧ slot_s_inv_hash l i ls -def slot_t_inv (l i : Int) (s : AList α) : Prop := slot_s_inv l i s.v +def slot_t_inv (l i : Nat) (s : AList α) : Prop := slot_s_inv l i s.v -@[simp] theorem distinct_keys_nil : @distinct_keys α [] := by simp [distinct_keys] -@[simp] theorem slot_s_inv_hash_nil : @slot_s_inv_hash l i α [] := by simp [slot_s_inv_hash] -@[simp] theorem slot_s_inv_nil : @slot_s_inv α l i [] := by simp [slot_s_inv] -@[simp] theorem slot_t_inv_nil : @slot_t_inv α l i .Nil := by simp [slot_t_inv] +@[simp] theorem distinct_keys_nil : @distinct_keys α [] := by fsimp [distinct_keys] +@[simp] theorem slot_s_inv_hash_nil : @slot_s_inv_hash l i α [] := by fsimp [slot_s_inv_hash] +@[simp] theorem slot_s_inv_nil : @slot_s_inv α l i [] := by fsimp [slot_s_inv] +@[simp] theorem slot_t_inv_nil : @slot_t_inv α l i .Nil := by fsimp [slot_t_inv] @[simp] theorem distinct_keys_cons (kv : Usize × α) (tl : List 
(Usize × α)) : - distinct_keys (kv :: tl) ↔ ((tl.allP fun (k', _) => ¬↑kv.1 = ↑k') ∧ distinct_keys tl) := by simp [distinct_keys] + distinct_keys (kv :: tl) ↔ ((tl.allP fun (k', _) => ¬↑kv.1 = ↑k') ∧ distinct_keys tl) := by fsimp [distinct_keys] @[simp] theorem slot_s_inv_hash_cons (kv : Usize × α) (tl : List (Usize × α)) : slot_s_inv_hash l i (kv :: tl) ↔ (hash_mod_key kv.1 l = i ∧ tl.allP (λ (k, _) => hash_mod_key k l = i) ∧ slot_s_inv_hash l i tl) := - by simp [slot_s_inv_hash] + by fsimp [slot_s_inv_hash] @[simp] theorem slot_s_inv_cons (kv : Usize × α) (tl : List (Usize × α)) : slot_s_inv l i (kv :: tl) ↔ ((tl.allP fun (k', _) => ¬↑kv.1 = ↑k') ∧ distinct_keys tl ∧ hash_mod_key kv.1 l = i ∧ tl.allP (λ (k, _) => hash_mod_key k l = i) ∧ slot_s_inv l i tl) := by - simp [slot_s_inv]; tauto + fsimp [slot_s_inv]; tauto -- Interpret the hashmap as a list of lists def v (hm : HashMap α) : List (List (Usize × α)) := @@ -80,7 +77,7 @@ instance : Inhabited (AList α) where @[simp] def slots_s_inv (s : List (AList α)) : Prop := - ∀ (i : Nat), i < s.length → slot_t_inv s.length i (s.index i) + ∀ (i : Nat), i < s.length → slot_t_inv s.length i s[i]! def slots_t_inv (s : alloc.vec.Vec (AList α)) : Prop := slots_s_inv s.v @@ -88,7 +85,7 @@ def slots_t_inv (s : alloc.vec.Vec (AList α)) : Prop := @[simp] def slots_s_lookup (s : List (AList α)) (k : Usize) : Option α := let i := hash_mod_key k s.length - let slot := s.index i.toNat + let slot := s[i]! slot.lookup k abbrev Slots α := alloc.vec.Vec (AList α) @@ -138,19 +135,15 @@ def frame_load (hm nhm : HashMap α) : Prop := nhm.max_load = hm.max_load ∧ nhm.saturated = hm.saturated --- This rewriting lemma is problematic below -attribute [-simp] Bool.exists_bool +-- Those rewriting lemmas are problematic +attribute [-simp] Bool.exists_bool List.getElem!_eq_getElem?_getD --- These simp lemmas were introduced by upstream changes and are problematic +-- These fsimp lemmas were introduced by upstream changes and are problematic attribute [-simp] List.length_flatten List.flatten_eq_nil_iff List.lookup_eq_none_iff attribute [local simp] List.lookup /- Adding some theorems for `scalar_tac` -/ --- We first activate the rule set for non linear arithmetic - this is needed because of the modulo operations -set_option scalarTac.nonLin true - --- Custom, local rule @[local scalar_tac h] theorem inv_imp_eqs_ineqs {hm : HashMap α} (h : hm.inv) : 0 < hm.slots.length ∧ hm.num_entries.val = hm.al_v.length := by @@ -161,59 +154,60 @@ set_option maxHeartbeats 0 open AList -@[pspec] +@[progress] theorem allocate_slots_spec {α : Type} (slots : alloc.vec.Vec (AList α)) (n : Usize) - (Hslots : ∀ (i : Nat), i < slots.length → slots.val.index i = Nil) + (Hslots : ∀ (i : Nat), i < slots.length → slots[i]! = Nil) (Hlen : slots.len + n.val ≤ Usize.max) : ∃ slots1, allocate_slots slots n = ok slots1 ∧ - (∀ (i : Nat), i < slots1.length → slots1.val.index i = Nil) ∧ + (∀ (i : Nat), i < slots1.length → slots1[i]! = Nil) ∧ slots1.len = slots.len + n.val := by rw [allocate_slots] rw [allocate_slots_loop] + fsimp at * if h: 0 < n.val then - simp [h] + fsimp [h] progress as ⟨ slots1 ⟩ progress as ⟨ n1 ⟩ have Hslots1Nil : - ∀ (i : Nat), i < slots1.length → slots1.val.index i = Nil := by + ∀ (i : Nat), i < slots1.length → slots1[i]! 
= Nil := by intro i h0 - simp [*] - if hi : i < slots.val.length then - simp [*] + fsimp [*] + if hi : i < slots.length then + fsimp [*] else - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all have : i - slots.val.length = 0 := by scalar_tac - simp [*] + fsimp [*] have Hslots1Len : alloc.vec.Vec.len slots1 + n1.val ≤ Usize.max := by - simp_all (config := {maxDischargeDepth := 1}) + scalar_tac progress as ⟨ slots2 ⟩ - constructor + split_conjs . intro i h0 - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all . simp_all + scalar_tac else - simp [h] - simp_all (config := {maxDischargeDepth := 1}) - scalar_tac -termination_by n.val.toNat + fsimp [h] + fsimp_all +termination_by n.val decreasing_by scalar_decr_tac theorem forall_nil_imp_flatten_len_zero (slots : List (List α)) - (Hnil : ∀ i, i < slots.length → slots.index i = []) : + (Hnil : ∀ i, i < slots.length → slots[i]! = []) : slots.flatten = [] := by - induction slots <;> simp_all (config := {maxDischargeDepth := 1}) + induction slots <;> fsimp_all have Hhead := Hnil 0 (by simp) - simp at Hhead - simp_all (config := {maxDischargeDepth := 1}) + fsimp at Hhead + fsimp_all rename _ → _ => Hind apply Hind intros i h0 have := Hnil (i + 1) (by scalar_tac) have : 0 < i + 1 := by scalar_tac - -- TODO: simp_all (config := {maxDischargeDepth := 1}) eliminates Hnil - simp at *; simp_all (config := {maxDischargeDepth := 1}) + -- TODO: fsimp_all eliminates Hnil + fsimp at *; fsimp_all -@[pspec] +@[progress] theorem new_with_capacity_spec (capacity : Usize) (max_load_factor : Fraction) (Hcapa : 0 < capacity.val) @@ -227,49 +221,51 @@ theorem new_with_capacity_spec ∀ k, hm.lookup k = none := by rw [new_with_capacity] progress as ⟨ slots, Hnil ⟩ - . simp [alloc.vec.Vec.new] at *; scalar_tac - . progress as ⟨ i1 ⟩ - progress as ⟨ i2 ⟩ - simp [inv, inv_load] - have : (Slots.al_v slots).length = 0 := by - have := forall_nil_imp_flatten_len_zero (slots.val.map AList.v) - (by intro i h0 - -- TODO: simp_all (config := {maxDischargeDepth := 1}) eliminates Hnil !? - simp at * - simp_all (config := {maxDischargeDepth := 1})) - simp_all (config := {maxDischargeDepth := 1}) - have : 0 < slots.val.length := by simp_all (config := {maxDischargeDepth := 1}) [alloc.vec.Vec.len, alloc.vec.Vec.new]; scalar_tac - have : slots_t_inv slots := by - simp [slots_t_inv, slot_t_inv] - intro i h0 - simp_all (config := {maxDischargeDepth := 1}) - split_conjs - . simp_all (config := {maxDischargeDepth := 1}) [al_v, Slots.al_v, v] - . assumption - . scalar_tac - . simp_all (config := {maxDischargeDepth := 1}) [alloc.vec.Vec.len, alloc.vec.Vec.new] - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) [alloc.vec.Vec.len, alloc.vec.Vec.new] - . simp_all (config := {maxDischargeDepth := 1}) [alloc.vec.Vec.len, alloc.vec.Vec.new] - . simp_all (config := {maxDischargeDepth := 1}) [al_v, Slots.al_v, v] - . simp_all [HashMap.v, length] - . simp [lookup] - intro k - have : 0 ≤ k.val % slots.val.length ∧ k.val % slots.val.length < slots.val.length := by scalar_tac - simp [*] - -@[pspec] + progress as ⟨ i1 ⟩ + progress as ⟨ i2 ⟩ + fsimp [inv, inv_load] + have : (Slots.al_v slots).length = 0 := by + have := forall_nil_imp_flatten_len_zero (slots.val.map AList.v) + (by intro i h0 + -- TODO: fsimp_all eliminates Hnil !? + fsimp at * + fsimp_all) + fsimp_all + have : 0 < slots.val.length := by scalar_tac + have : slots_t_inv slots := by + fsimp [slots_t_inv, slot_t_inv] + intro i h0 + fsimp_all + split_conjs + . 
fsimp_all [al_v, Slots.al_v, v] + . assumption + . scalar_tac + . fsimp_all [alloc.vec.Vec.len, alloc.vec.Vec.new] + . fsimp_all + . scalar_tac + . fsimp_all [alloc.vec.Vec.len, alloc.vec.Vec.new] + . fsimp_all [al_v, Slots.al_v, v] + . simp_all [HashMap.v, length] + . fsimp [lookup] + intro k + have : k.val % slots.val.length < slots.val.length := by + scalar_tac +nonLin + simp at Hnil + have := Hnil _ this + fsimp [this] + +@[progress] theorem new_spec (α : Type) : ∃ hm, new α = ok hm ∧ hm.inv ∧ hm.len_s = 0 ∧ ∀ k, hm.lookup k = none := by rw [new] progress as ⟨ hm ⟩ - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all --set_option pp.all true -example (key : Usize) : key == key := by simp [beq_iff_eq] +example (key : Usize) : key == key := by fsimp [beq_iff_eq] -theorem insert_in_list_spec_aux {α : Type} (l : Int) (key: Usize) (value: α) (l0: AList α) +theorem insert_in_list_spec_aux {α : Type} (l : Nat) (key: Usize) (value: α) (l0: AList α) (hinv : slot_s_inv_hash l (hash_mod_key key l) l0.v) (hdk : distinct_keys l0.v) : ∃ b l1, @@ -290,40 +286,37 @@ theorem insert_in_list_spec_aux {α : Type} (l : Int) (key: Usize) (value: α) ( -- We need this auxiliary property to prove that the keys distinct properties is preserved (∀ k, k ≠ key → l0.v.allP (λ (k1, _) => k ≠ k1) → l1.v.allP (λ (k1, _) => k ≠ k1)) := by - cases l0 with - | Nil => - exists true -- TODO: why do we need to do this? - simp [insert_in_list] + cases l0; swap + . fsimp [insert_in_list] rw [insert_in_list_loop] - simp (config := {contextual := true}) [AList.v] - | Cons k v tl0 => - if h: k = key then - rw [insert_in_list] - rw [insert_in_list_loop] - simp [h] - split_conjs <;> simp_all (config := {maxDischargeDepth := 1}) [slot_s_inv_hash] - else - rw [insert_in_list] - rw [insert_in_list_loop] - simp [h] - have : slot_s_inv_hash l (hash_mod_key key l) (AList.v tl0) := by - simp_all (config := {maxDischargeDepth := 1}) [AList.v, slot_s_inv_hash] - have : distinct_keys (AList.v tl0) := by - simp [distinct_keys] at hdk - simp [hdk, distinct_keys] - progress as ⟨ b, tl1 ⟩ - have : slot_s_inv_hash l (hash_mod_key key l) (AList.v (AList.Cons k v tl1)) := by - simp [AList.v, slot_s_inv_hash] at * - simp [*] - have : distinct_keys ((k, v) :: AList.v tl1) := by - simp [distinct_keys] at * - simp [*] - -- TODO: canonize addition by default? - exists b - simp_all (config := {maxDischargeDepth := 2}) [Int.add_assoc, Int.add_comm, Int.add_left_comm] - -@[pspec] -theorem insert_in_list_spec {α : Type} (l : Int) (key: Usize) (value: α) (l0: AList α) + fsimp (config := {contextual := true}) [AList.v] + . rename_i k v tl0 + if h: k = key then + rw [insert_in_list] + rw [insert_in_list_loop] + fsimp [h] + split_conjs <;> fsimp_all [slot_s_inv_hash] + else + rw [insert_in_list] + rw [insert_in_list_loop] + fsimp [h] + have : slot_s_inv_hash l (hash_mod_key key l) (AList.v tl0) := by + fsimp_all [AList.v, slot_s_inv_hash] + have : distinct_keys (AList.v tl0) := by + fsimp [distinct_keys] at hdk + fsimp [hdk, distinct_keys] + progress as ⟨ b, tl1 ⟩ + have : slot_s_inv_hash l (hash_mod_key key l) (AList.v (AList.Cons k v tl1)) := by + fsimp [AList.v, slot_s_inv_hash] at * + fsimp [*] + have : distinct_keys ((k, v) :: AList.v tl1) := by + fsimp [distinct_keys] at * + fsimp [*] + -- TODO: canonize addition by default? 
+ simp_all [Int.add_assoc, Int.add_comm, Int.add_left_comm] + +@[progress] +theorem insert_in_list_spec {α : Type} (l : Nat) (key: Usize) (value: α) (l0: AList α) (hinv : slot_s_inv_hash l (hash_mod_key key l) l0.v) (hdk : distinct_keys l0.v) : ∃ b l1, @@ -351,7 +344,7 @@ theorem if_update_eq {α β : Type u} (b : Bool) (y : α) (e : Result α) (f : α → Result β) : (if b then Bind.bind e f else f y) = Bind.bind (if b then e else pure y) f := by - split <;> simp [Pure.pure] + split <;> fsimp [Pure.pure] def frame_slots_params (hm1 hm2 : HashMap α) := -- The max load factor is the same @@ -359,7 +352,7 @@ def frame_slots_params (hm1 hm2 : HashMap α) := -- The number of slots is the same hm2.slots.val.length = hm1.slots.val.length -@[pspec] +@[progress] theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value : α) (hinv : hm.inv) (hnsat : hm.lookup key = none → hm.len_s < Usize.max) : ∃ nhm, hm.insert_no_resize key value = ok nhm ∧ @@ -377,36 +370,38 @@ theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value | some _ => nhm.len_s = hm.len_s) := by rw [insert_no_resize] -- Simplify. Note that this also simplifies some function calls, like array index - simp [hash_key, bind_tc_ok] + fsimp [hash_key, bind_tc_ok] progress as ⟨ hash_mod, hhm ⟩ - have _ : 0 ≤ hash_mod.val ∧ hash_mod.val < alloc.vec.Vec.length hm.slots := by scalar_tac + fsimp at hhm + have _ : hash_mod.val < alloc.vec.Vec.length hm.slots := by + scalar_tac +nonLin progress as ⟨ l, h_leq ⟩ have h_slot : slot_s_inv_hash hm.slots.length (hash_mod_key key hm.slots.length) l.v := by - simp [inv, slots_t_inv] at hinv - have h := (hinv.right.left hash_mod.toNat (by scalar_tac)).right - simp [slot_t_inv, hhm] at h - simp_all (config := {maxDischargeDepth := 1}) + fsimp [inv, slots_t_inv] at hinv + have h := (hinv.right.left hash_mod.val (by scalar_tac)).right + fsimp [slot_t_inv, hhm] at h + fsimp_all progress as ⟨ inserted, l0, _, _, _, _, hlen ⟩ - . simp [inv, slots_t_inv, slot_t_inv, slot_s_inv] at hinv - have h := hinv.right.left hash_mod.toNat (by scalar_tac) - simp [h, h_leq] + . fsimp [inv, slots_t_inv, slot_t_inv, slot_s_inv] at hinv + have h := hinv.right.left hash_mod.val (by scalar_tac) + fsimp [h, h_leq] rw [if_update_eq] -- TODO: necessary because we don't have a join -- TODO: progress to ... 
have hipost : - ∃ i0, (if inserted = true then hm.num_entries + Usize.ofInt 1 else pure hm.num_entries) = ok i0 ∧ + ∃ i0, (if inserted = true then hm.num_entries + Usize.ofNat 1 else pure hm.num_entries) = ok i0 ∧ i0.val = if inserted then hm.num_entries.val + 1 else hm.num_entries.val := by if inserted then - simp [*] - have hbounds : hm.num_entries.val + (Usize.ofInt 1).val ≤ Usize.max := by - simp [lookup] at hnsat - simp_all (config := {maxDischargeDepth := 1}) [] + fsimp [*] + have hbounds : hm.num_entries.val + (Usize.ofNat 1).val ≤ Usize.max := by + fsimp [lookup] at hnsat + fsimp_all [] scalar_tac progress as ⟨ z, hp ⟩ - simp [hp] + fsimp [hp] else - simp [*, Pure.pure] + fsimp [*, Pure.pure] progress as ⟨ i0 ⟩ -- TODO: hide the variables and only keep the props -- TODO: allow providing terms to progress to instantiate the meta variables @@ -418,69 +413,68 @@ theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value max_load_factor := hm.max_load_factor, max_load := hm.max_load, saturated := hm.saturated, - slots := hm.slots.update hash_mod l0 } + slots := hm.slots.set hash_mod l0 } have _ : match hm.lookup key with | none => nhm.len_s = hm.len_s + 1 | some _ => nhm.len_s = hm.len_s := by - simp only [lookup, len_s, al_v, HashMap.v, slots_s_lookup] at * + fsimp only [lookup, len_s, al_v, HashMap.v, slots_s_lookup] at * -- We have to do a case disjunction - simp_all (config := {maxDischargeDepth := 1}) [List.map_update_eq] + fsimp_all -- TODO: dependent rewrites - have _ : (key.val % hm.slots.val.length).toNat < (List.map AList.v hm.slots.val).length := by - simp [*] + have _ : key.val % hm.slots.val.length < (List.map AList.v hm.slots.val).length := by + fsimp [*] split <;> rename_i heq <;> - simp [heq] at hlen <;> + fsimp [heq] at hlen <;> -- TODO: canonize addition by default? We need a tactic to simplify arithmetic equalities -- with addition and substractions ((ℤ, +) is a group or something - there should exist a tactic -- somewhere in mathlib?) - simp [List.length_flatten_update_as_int_eq, nhm, *] - int_tac + fsimp [List.length_flatten_set_as_int_eq, nhm, *] + scalar_tac split_conjs - . simp [inv] at * + . fsimp [inv] at * split_conjs . match h: lookup hm key with | none => - simp [h, lookup, nhm] at * - simp_all (config := {maxDischargeDepth := 1}) + fsimp [h, lookup, nhm] at * + fsimp_all | some _ => - simp_all (config := {maxDischargeDepth := 1}) [lookup, nhm] - . simp [slots_t_inv, slot_t_inv] at * + fsimp_all [lookup, nhm] + . fsimp [slots_t_inv, slot_t_inv] at * intro i _ - have _ := hinv.right.left i (by simp_all (config := {maxDischargeDepth := 1})) + have _ := hinv.right.left i (by fsimp_all) -- We need a case disjunction - cases h_ieq : key.val % List.length hm.slots.val == i <;> simp_all (config := {maxDischargeDepth := 2}) [slot_s_inv] - . simp [hinv] - . simp_all (config := {maxDischargeDepth := 1}) [frame_load, inv_base, inv_load] + cases h_ieq : key.val % List.length hm.slots.val == i <;> simp_all [slot_s_inv] + . fsimp [hinv] + . fsimp_all [frame_load, inv_base, inv_load] . simp_all [frame_slots_params] - . simp [lookup] at * - simp_all (config := {maxDischargeDepth := 2}) - . simp [lookup] at * + . fsimp [lookup] at * + simp_all + . 
fsimp [lookup] at * intro k hk -- We have to make a case disjunction: either the hashes are different, -- in which case we don't even lookup the same slots, or the hashes -- are the same, in which case we have to reason about what happens -- in one slot let k_hash_mod := k.val % hm.slots.length - have _ : 0 ≤ k_hash_mod ∧ k_hash_mod < alloc.vec.Vec.length hm.slots := by - simp_all (config := {maxDischargeDepth := 1}) [k_hash_mod] -- TODO: shouldn't need to do this - scalar_tac - cases h_hm: k_hash_mod == hash_mod.val <;> simp_all (config := {zetaDelta := true, maxDischargeDepth := 2}) + have _ : k_hash_mod < alloc.vec.Vec.length hm.slots := by + scalar_tac +nonLin + cases h_hm: k_hash_mod == hash_mod.val <;> simp_all (config := {zetaDelta := true}) - simp_all (config := {maxDischargeDepth := 1}) [nhm] + fsimp_all [nhm] private theorem slot_allP_not_key_lookup (slot : AList α) (h : slot.v.allP fun (k', _) => ¬k = k') : slot.lookup k = none := by - induction slot <;> simp_all (config := {maxDischargeDepth := 1}) + induction slot <;> fsimp_all -@[pspec] +@[progress] theorem move_elements_from_list_spec {T : Type} (ntable : HashMap T) (slot : AList T) (hinv : ntable.inv) - {l i : Int} (hSlotInv : slot_t_inv l i slot) + {l i : Nat} (hSlotInv : slot_t_inv l i slot) (hDisjoint1 : ∀ key v, ntable.lookup key = some v → slot.lookup key = none) (hDisjoint2 : ∀ key v, slot.lookup key = some v → ntable.lookup key = none) (hLen : ntable.al_v.length + slot.v.length ≤ Usize.max) @@ -496,47 +490,47 @@ theorem move_elements_from_list_spec rw [move_elements_from_list]; rw [move_elements_from_list_loop] cases slot with | Nil => - simp [hinv, frame_slots_params] + fsimp [hinv, frame_slots_params] | Cons key value slot1 => simp have hLookupKey : ntable.lookup key = none := by by_contra - cases h: ntable.lookup key <;> simp_all (config := {maxDischargeDepth := 1}) - have : ntable.lookup key = none → ntable.len_s < Usize.max := by simp_all (config := {maxDischargeDepth := 1}); scalar_tac + cases h: ntable.lookup key <;> fsimp_all + have : ntable.lookup key = none → ntable.len_s < Usize.max := by fsimp_all; scalar_tac progress as ⟨ ntable1, _, _, hLookup11, hLookup12, hLength1 ⟩ - simp [hLookupKey] at hLength1 + fsimp [hLookupKey] at hLength1 have hTable1LookupImp : ∀ (key : Usize) (v : T), ntable1.lookup key = some v → slot1.lookup key = none := by intro key' v hLookup if h: key = key' then - simp_all (config := {maxDischargeDepth := 1}) [slot_t_inv] + fsimp_all [slot_t_inv] apply slot_allP_not_key_lookup - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all else - simp_all (config := {maxDischargeDepth := 1}) - cases h: ntable.lookup key' <;> simp_all (config := {maxDischargeDepth := 2}) + fsimp_all + cases h: ntable.lookup key' <;> simp_all have := hDisjoint1 _ _ h - simp_all (config := {maxDischargeDepth := 2}) + simp_all have hSlot1LookupImp : ∀ (key : Usize) (v : T), slot1.lookup key = some v → ntable1.lookup key = none := by intro key' v hLookup if h: key' = key then by_contra rename _ => hNtable1NotNone - cases h: ntable1.lookup key' <;> simp [h] at hNtable1NotNone + cases h: ntable1.lookup key' <;> fsimp [h] at hNtable1NotNone have := hTable1LookupImp _ _ h - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all else have := hLookup12 key' h have := hDisjoint2 key' v - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all have : slot_t_inv l i slot1 := by - simp [slot_t_inv] at hSlotInv - simp [slot_t_inv, hSlotInv] + fsimp [slot_t_inv] at hSlotInv + fsimp [slot_t_inv, hSlotInv] progress as 
⟨ ntable2, hInv2, _, hLookup21, hLookup22, hLookup23, hLen1 ⟩ -- TODO: allow progress to receive instantiation hints -- The conclusion -- TODO: use aesop here split_conjs - . simp [*] + . fsimp [*] . simp_all [frame_slots_params] . intro key' v hLookup have := hLookup21 key' v @@ -547,50 +541,51 @@ theorem move_elements_from_list_spec have := hDisjoint2 key' v have := hTable1LookupImp key' v have := hSlot1LookupImp key' v - simp_all (config := {maxDischargeDepth := 1}) [Slots.lookup] + fsimp_all [Slots.lookup] else have := hLookup12 key' - simp_all (config := {maxDischargeDepth := 2}) + simp_all . intro key' v hLookup1 if h: key' = key then - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all else have := hLookup12 key' h have := hLookup22 key' v - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all . intro key' v hLookup1 if h: key' = key then have := hLookup22 key' v - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all else have := hLookup23 key' v - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all . scalar_tac private theorem slots_forall_nil_imp_lookup_none (slots : Slots T) (hLen : slots.val.length ≠ 0) - (hEmpty : ∀ (j : Nat), j < slots.val.length → slots.val.index j = AList.Nil) : + (hEmpty : ∀ (j : Nat), j < slots.val.length → slots[j]! = AList.Nil) : ∀ key, slots.lookup key = none := by intro key - simp [Slots.lookup] + fsimp [Slots.lookup] -- TODO: simplify - have : 0 ≤ key.val % slots.val.length ∧ key.val % slots.val.length < slots.val.length := by - scalar_tac - have := hEmpty (key.val % (slots.val.length : Int)).toNat (by simp [*]) - simp [*] + have : key.val % slots.val.length < slots.val.length := by + scalar_tac +nonLin + have := hEmpty (key.val % slots.val.length) (by fsimp [*]) + fsimp at * + fsimp [*] private theorem slots_index_len_le_flatten_len (slots : List (AList α)) (i : Nat) (h : i < slots.length) : - (slots.index i).length ≤ (List.map AList.v slots).flatten.length := by + (slots[i]!).length ≤ (List.map AList.v slots).flatten.length := by match slots with | [] => - simp at * + fsimp at * | slot :: slots' => - simp at * + fsimp at * if hi : i = 0 then - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all else have := slots_index_len_le_flatten_len slots' (i - 1) (by scalar_tac) - simp [*] + fsimp [*] scalar_tac /- If we successfully lookup a key from a slot, the hash of the key modulo the number of slots must @@ -599,7 +594,7 @@ private theorem slots_index_len_le_flatten_len -/ private theorem slots_inv_lookup_imp_eq (slots : Slots α) (hInv : slots_t_inv slots) (i : Nat) (hi : i < slots.val.length) (key : Usize) : - (slots.val.index i).lookup key ≠ none → i = (key.val % slots.val.length).toNat := by + (slots[i]!).lookup key ≠ none → i = key.val % slots.val.length := by suffices hSlot : ∀ (slot : List (Usize × α)), slot_s_inv slots.val.length i slot → slot.lookup key ≠ none → @@ -607,21 +602,21 @@ private theorem slots_inv_lookup_imp_eq (slots : Slots α) (hInv : slots_t_inv s from by rw [slots_t_inv, slots_s_inv] at hInv replace hInv := hInv i hi - simp [slot_t_inv] at hInv + fsimp [slot_t_inv] at hInv have := hSlot _ hInv - scalar_tac + apply this intro slot - induction slot <;> simp_all (config := {maxDischargeDepth := 1}) - intros; simp_all (config := {maxDischargeDepth := 1}) - split at * <;> simp_all (config := {maxDischargeDepth := 1}) + induction slot <;> fsimp_all + intros; fsimp_all + split at * <;> fsimp_all private theorem move_slots_updated_table_lookup_imp (i : Nat) (ntable ntable1 ntable2 : HashMap α) (slots slots1 : 
Slots α) (slot : AList α) (hi : i < slots.val.length) (hSlotsInv : slots_t_inv slots) - (hSlotEq : slot = slots.val.index i) - (hSlotsEq : slots1.val = slots.val.update i .Nil) + (hSlotEq : slot = slots[i]!) + (hSlotsEq : slots1.val = slots.val.set i .Nil) (hTableLookup : ∀ (key : Usize) (v : α), ntable1.lookup key = some v → ntable.lookup key = some v ∨ slot.lookup key = some v) (hTable1Lookup : ∀ (key : Usize) (v : α), ntable2.lookup key = some v → @@ -634,21 +629,21 @@ private theorem move_slots_updated_table_lookup_imp cases hTable1Lookup with | inl hTable1Lookup => replace hTableLookup := hTableLookup hTable1Lookup - cases hTableLookup <;> try simp [*] + cases hTableLookup <;> try fsimp [*] right - have := slots_inv_lookup_imp_eq slots hSlotsInv i hi key (by simp_all (config := {maxDischargeDepth := 1})) - simp_all (config := {maxDischargeDepth := 1}) [Slots.lookup] + have := slots_inv_lookup_imp_eq slots hSlotsInv i hi key (by fsimp_all) + fsimp_all [Slots.lookup] | inr hTable1Lookup => right -- The key can't be for the slot we replaced - cases heq : (key.val % slots.val.length).toNat == i <;> simp_all (config := {maxDischargeDepth := 1}) [Slots.lookup] + cases heq : key.val % slots.val.length == i <;> fsimp_all [Slots.lookup] private theorem move_one_slot_lookup_equiv {α : Type} (ntable ntable1 ntable2 : HashMap α) (slot : AList α) (slots slots1 : Slots α) - (i : Nat) (h1 : i < slots.length) - (hSlotEq : slot = slots.val.index i) - (hSlots1Eq : slots1.val = slots.val.update i .Nil) + (i : Nat) + (hSlotEq : slot = slots[i]!) + (hSlots1Eq : slots1.val = slots.val.set i .Nil) (hLookup1 : ∀ (key : Usize) (v : α), ntable.lookup key = some v → ntable1.lookup key = some v) (hLookup2 : ∀ (key : Usize) (v : α), slot.lookup key = some v → ntable1.lookup key = some v) (hLookup3 : ∀ (key : Usize) (v : α), ntable1.lookup key = some v → ntable2.lookup key = some v) @@ -656,60 +651,60 @@ private theorem move_one_slot_lookup_equiv {α : Type} (ntable ntable1 ntable2 : (∀ key v, slots.lookup key = some v → ntable2.lookup key = some v) ∧ (∀ key v, ntable.lookup key = some v → ntable2.lookup key = some v) := by constructor <;> intro key v hLookup - . if hi: (key.val % slots.val.length).toNat = i then + . if hi: key.val % slots.val.length = i then -- We lookup in slot have := hLookup2 key v - simp_all (config := {maxDischargeDepth := 1}) [Slots.lookup] + fsimp_all [Slots.lookup] have := hLookup3 key v - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all else -- We lookup in slots have := hLookup4 key v - simp_all (config := {maxDischargeDepth := 1}) [Slots.lookup] + fsimp_all [Slots.lookup] . 
have := hLookup1 key v have := hLookup3 key v - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all private theorem slots_lookup_none_imp_slot_lookup_none (slots : Slots α) (hInv : slots_t_inv slots) (i : Nat) (hi : i < slots.val.length) : - ∀ (key : Usize), slots.lookup key = none → (slots.val.index i).lookup key = none := by + ∀ (key : Usize), slots.lookup key = none → (slots[i]!).lookup key = none := by intro key hLookup - if heq : (key.val % slots.val.length).toNat = i then - simp_all (config := {maxDischargeDepth := 1}) [Slots.lookup] + if heq : key.val % slots.val.length = i then + fsimp_all [Slots.lookup] else have := slots_inv_lookup_imp_eq slots hInv i (by scalar_tac) key by_contra - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all private theorem slot_lookup_not_none_imp_slots_lookup_not_none (slots : Slots α) (hInv : slots_t_inv slots) (i : Nat) (hi : i < slots.val.length) : - ∀ (key : Usize), (slots.val.index i).lookup key ≠ none → slots.lookup key ≠ none := by + ∀ (key : Usize), (slots[i]!).lookup key ≠ none → slots.lookup key ≠ none := by intro key hLookup hNone have := slots_lookup_none_imp_slot_lookup_none slots hInv i hi key hNone apply hLookup this private theorem slots_forall_nil_imp_al_v_nil (slots : Slots α) - (hEmpty : ∀ (i : Nat), i < slots.val.length → slots.val.index i = AList.Nil) : + (hEmpty : ∀ (i : Nat), i < slots.val.length → slots[i]! = AList.Nil) : slots.al_v = [] := by suffices h : ∀ (slots : List (AList α)), - (∀ (i : Nat), i < slots.length → slots.index i = Nil) → + (∀ (i : Nat), i < slots.length → slots[i]! = Nil) → (slots.map AList.v).flatten = [] from by replace h := h slots.val (by intro i h0; exact hEmpty i h0) - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all clear slots hEmpty intro slots hEmpty - induction slots <;> simp_all (config := {maxDischargeDepth := 1}) + induction slots <;> fsimp_all have hHead := hEmpty 0 (by scalar_tac) - simp at hHead - simp [hHead] + fsimp at hHead + fsimp [hHead] rename (_ → _) => ih apply ih; intro i h0 replace hEmpty := hEmpty (i + 1) (by omega) - -- TODO: simp at hEmpty + -- TODO: fsimp at hEmpty have : 0 < i + 1 := by omega - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all theorem move_elements_loop_spec {α : Type} (ntable : HashMap α) (slots : Slots α) @@ -718,9 +713,8 @@ theorem move_elements_loop_spec (hinv : ntable.inv) (hSlotsNonZero : slots.val.length ≠ 0) (hSlotsInv : slots_t_inv slots) - (hEmpty : ∀ j, j < i.toNat → slots.val.index j = AList.Nil) + (hEmpty : ∀ j, j < i.val → slots[j]! = AList.Nil) (hDisjoint1 : ∀ key v, ntable.lookup key = some v → slots.lookup key = none) - (hDisjoint2 : ∀ key v, slots.lookup key = some v → ntable.lookup key = none) (hLen : ntable.al_v.length + slots.al_v.length ≤ Usize.max) : ∃ ntable1 slots1, ntable.move_elements_loop slots i = ok (ntable1, slots1) ∧ @@ -730,75 +724,68 @@ theorem move_elements_loop_spec (∀ key v, ntable1.lookup key = some v → ntable.lookup key = some v ∨ slots.lookup key = some v) ∧ (∀ key v, slots.lookup key = some v → ntable1.lookup key = some v) ∧ (∀ key v, ntable.lookup key = some v → ntable1.lookup key = some v) ∧ - (∀ (j : Nat), j < slots1.length → slots1.val.index j = AList.Nil) + (∀ (j : Nat), j < slots1.length → slots1[j]! = AList.Nil) := by rw [move_elements_loop] simp dcases hi: i.val < slots.val.length . 
-- Continue the proof - have hIneq : 0 ≤ i.val ∧ i.val < slots.val.length := by scalar_tac - simp [hi] + have hIneq : i.val < slots.val.length := by scalar_tac + fsimp [hi] progress as ⟨ slot, hSlotEq ⟩ have hInvSlot : slot_t_inv slots.val.length i.val slot := by - simp [slots_t_inv] at hSlotsInv - simp [*] - have := hSlotsInv i.toNat - simp_all (config := {maxDischargeDepth := 1}) + fsimp [slots_t_inv] at hSlotsInv + fsimp [*] have ntableLookupImpSlot : ∀ (key : Usize) (v : α), ntable.lookup key = some v → slot.lookup key = none := by intro key v hLookup by_contra - have : i.toNat = (key.val % slots.val.length).toNat := by - have := slots_inv_lookup_imp_eq slots hSlotsInv i.toNat (by scalar_tac) key - simp_all (config := {maxDischargeDepth := 1}) - cases h: slot.lookup key <;> simp_all (config := {maxDischargeDepth := 1}) + have : i.val = key.val % slots.val.length := by + have := slots_inv_lookup_imp_eq slots hSlotsInv i.val (by scalar_tac) key + fsimp_all + cases h: slot.lookup key <;> fsimp_all have : ntable.al_v.length + slot.v.length ≤ Usize.max := by - have := slots_index_len_le_flatten_len slots.val i.toNat (by scalar_tac) - simp_all (config := {maxDischargeDepth := 1}) [Slots.al_v]; scalar_tac + have := slots_index_len_le_flatten_len slots.val i.val (by scalar_tac) + fsimp_all [Slots.al_v]; scalar_tac progress as ⟨ ntable1, _, _, hDisjointNtable1, hLookup11, hLookup12, hLen1 ⟩ . intro key v hLookup by_contra - cases h : ntable.lookup key <;> simp_all (config := {maxDischargeDepth := 1}) + cases h : ntable.lookup key <;> fsimp_all progress as ⟨ i' ⟩ - have : i' ≤ alloc.vec.Vec.len (alloc.vec.Vec.update slots i Nil) := by - simp_all (config := {maxDischargeDepth := 1}) [alloc.vec.Vec.len]; scalar_tac - have : slots_t_inv (alloc.vec.Vec.update slots i Nil) := by - simp [slots_t_inv] at * + have : i' ≤ alloc.vec.Vec.len (alloc.vec.Vec.set slots i Nil) := by + fsimp_all [alloc.vec.Vec.len]; scalar_tac + have : slots_t_inv (alloc.vec.Vec.set slots i Nil) := by + fsimp [slots_t_inv] at * intro j h0 - cases h: j == i.toNat <;> simp_all (config := {maxDischargeDepth := 2}) + cases h: j == i.val <;> simp_all have ntable1LookupImpSlots1 : ∀ (key : Usize) (v : α), ntable1.lookup key = some v → - Slots.lookup (alloc.vec.Vec.update slots i Nil) key = none := by + Slots.lookup (alloc.vec.Vec.set slots i Nil) key = none := by intro key v hLookup cases hDisjointNtable1 _ _ hLookup with | inl h => have := ntableLookupImpSlot _ _ h have := hDisjoint1 _ _ h - cases heq : i.toNat == (key.val % slots.val.length).toNat <;> simp_all (config := {maxDischargeDepth := 1}) [Slots.lookup] - rw [eq_comm] at heq - simp [*] + cases heq : i.val == key.val % slots.val.length <;> fsimp_all [Slots.lookup] | inr h => have heq : i = key.val % slots.val.length := by - have := slots_inv_lookup_imp_eq slots hSlotsInv i.toNat (by scalar_tac) key (by simp_all (config := {maxDischargeDepth := 1}) [Slots.lookup]) + have := slots_inv_lookup_imp_eq slots hSlotsInv i.val (by scalar_tac) key (by fsimp_all [Slots.lookup]) scalar_tac - simp_all (config := {maxDischargeDepth := 2}) [Slots.lookup] + simp_all [Slots.lookup] progress as ⟨ ntable2, slots2, _, _, _, hLookup2Rev, hLookup21, hLookup22, hIndexNil ⟩ . intro j h0 - if h : j = i.toNat then - simp_all (config := {maxDischargeDepth := 2}) + if h : j = i.val then + simp_all else have := hEmpty j (by scalar_tac) - simp_all (config := {maxDischargeDepth := 1}) - . 
intro key v hLookup - by_contra h - cases h : ntable1.lookup key <;> simp_all (config := {maxDischargeDepth := 1}) + fsimp_all . have : i.val < (List.map AList.v slots.val).length := by simp; scalar_tac - simp_all (config := {maxDischargeDepth := 2}) [Slots.al_v, List.length_flatten_update_eq, List.map_update_eq, List.length_flatten_update_as_int_eq] + simp_all [Slots.al_v, List.length_flatten_set_eq, List.length_flatten_set_as_int_eq] scalar_tac simp @@ -806,46 +793,46 @@ theorem move_elements_loop_spec (∀ key v, slots.lookup key = some v → ntable2.lookup key = some v) ∧ (∀ key v, ntable.lookup key = some v → ntable2.lookup key = some v) := by exact move_one_slot_lookup_equiv ntable ntable1 ntable2 slot slots - (alloc.vec.Vec.update slots i Nil) i.toNat - (by scalar_tac) (by assumption) (by simp) + (alloc.vec.Vec.set slots i Nil) i.val + (by assumption) (by simp) (by assumption) (by assumption) (by assumption) (by assumption) split_conjs - . simp [*] - . simp_all (config := {maxDischargeDepth := 1}) [frame_slots_params] - . simp_all (config := {maxDischargeDepth := 1}) [Slots.al_v] + . fsimp [*] + . fsimp_all [frame_slots_params] + . fsimp_all [Slots.al_v] -- TODO scalar_tac_preprocess - have : i.toNat < slots.length := by scalar_tac - simp_all (config := {maxDischargeDepth := 2}) [List.length_flatten_update_as_int_eq] + have : i.val < slots.length := by scalar_tac + simp_all [List.length_flatten_set_as_int_eq] scalar_tac . intro key v hLookup - apply move_slots_updated_table_lookup_imp i.toNat ntable ntable1 ntable2 slots (alloc.vec.Vec.update slots i Nil) slot (by scalar_tac) <;> + apply move_slots_updated_table_lookup_imp i.val ntable ntable1 ntable2 slots (alloc.vec.Vec.set slots i Nil) slot (by scalar_tac) <;> first | assumption | simp . apply hLookupPreserve.left . apply hLookupPreserve.right . intro j h0 apply hIndexNil j h0 - . simp [hi, *] - -- TODO: simp_all (config := {maxDischargeDepth := 1}) removes hEmpty!! + . fsimp [hi, *] + -- TODO: fsimp_all removes hEmpty!! have hi : i = alloc.vec.Vec.len slots := by scalar_tac - have hEmpty : ∀ (j : Nat), j < slots.val.length → slots.val.index j = AList.Nil := by - simp [hi] at hEmpty + have hEmpty : ∀ (j : Nat), j < slots.val.length → slots[j]! = AList.Nil := by + fsimp [hi] at hEmpty exact hEmpty have hNil : slots.al_v = [] := slots_forall_nil_imp_al_v_nil slots hEmpty - have hLenNonZero : slots.val.length ≠ 0 := by simp [*] + have hLenNonZero : slots.val.length ≠ 0 := by fsimp [*] have hLookupEmpty := slots_forall_nil_imp_lookup_none slots hLenNonZero hEmpty - simp [hNil, hLookupEmpty, frame_slots_params] + fsimp [hNil, hLookupEmpty, frame_slots_params] split_conjs . intros - simp [*] + fsimp [*] . intros simp_all . apply hEmpty -termination_by (slots.val.length - i.val).toNat +termination_by slots.val.length - i.val decreasing_by scalar_decr_tac -- TODO: this is expensive -@[pspec] +@[progress] theorem move_elements_spec {α : Type} (ntable : HashMap α) (slots : Slots α) (hinv : ntable.inv) @@ -863,15 +850,11 @@ theorem move_elements_spec (∀ key v, ntable1.lookup key = some v ↔ slots.lookup key = some v) := by rw [move_elements] - have ⟨ ntable1, slots1, hEq, _, _, _, ntable1Lookup, slotsLookup, _, _ ⟩ := - move_elements_loop_spec ntable slots 0#usize (by scalar_tac) hinv - (by scalar_tac) - hSlotsInv - (by intro j h0; scalar_tac) - (by simp [*]) - (by simp [*]) - (by scalar_tac) - simp [hEq]; clear hEq + progress with move_elements_loop_spec as ⟨ ntable1, slots1, _, _, _, ntable1Lookup, slotsLookup ⟩ + . 
-- Remaining precondition + fsimp [*] + -- Postcondition + fsimp have : frame_slots_params ntable ntable1 := by simp_all [frame_slots_params] split_conjs <;> try assumption @@ -879,9 +862,9 @@ theorem move_elements_spec intro key v have := ntable1Lookup key v have := slotsLookup key v - constructor <;> simp_all (config := {maxDischargeDepth := 1}) + constructor <;> fsimp_all -@[pspec] +@[progress] theorem try_resize_spec {α : Type} (hm : HashMap α) (hInv : hm.inv): ∃ hm', hm.try_resize = ok hm' ∧ hm'.inv ∧ @@ -891,65 +874,63 @@ theorem try_resize_spec {α : Type} (hm : HashMap α) (hInv : hm.inv): simp progress as ⟨ n1 ⟩ -- TODO: simplify (Usize.ofInt (OfNat.ofNat 2) try_resize.proof_1).val have : hm.2.1.val ≠ 0 := by - simp [inv, inv_load] at hInv + fsimp [inv, inv_load] at hInv -- TODO: why does hm.max_load_factor appears as hm.2?? -- Can we deactivate field notations? omega progress as ⟨ n2 ⟩ if hSmaller : hm.slots.val.length ≤ n2.val then - simp [hSmaller] + fsimp [hSmaller] have : (alloc.vec.Vec.len hm.slots).val * 2 ≤ Usize.max := by - simp [alloc.vec.Vec.len, inv, inv_load] at * + fsimp [alloc.vec.Vec.len, inv, inv_load] at * -- TODO: this should be automated - have hIneq1 : n1.val ≤ Usize.max / 2 := by simp [*] - simp [Int.le_ediv_iff_mul_le] at hIneq1 + have hIneq1 : n1.val ≤ Usize.max / 2 := by fsimp [*] + fsimp [Int.le_ediv_iff_mul_le] at hIneq1 -- TODO: this should be automated - have hIneq2 : n2.val ≤ n1.val / hm.2.1.val := by simp [*] - rw [Int.le_ediv_iff_mul_le] at hIneq2 <;> try simp [*] + have hIneq2 : n2.val ≤ n1.val / hm.2.1.val := by fsimp [*] + rw [Nat.le_div_iff_mul_le] at hIneq2 <;> try fsimp [*] have : n2.val * 1 ≤ n2.val * hm.max_load_factor.1.val := by - apply Int.mul_le_mul <;> scalar_tac + apply Nat.mul_le_mul <;> scalar_tac scalar_tac progress as ⟨ newLength ⟩ have : 0 < newLength.val := by - simp_all (config := {maxDischargeDepth := 1}) [inv, inv_load] + fsimp_all [inv, inv_load] progress as ⟨ ntable1 ⟩ -- TODO: introduce nice notation to take care of preconditions . -- Pre 1 - simp_all (config := {maxDischargeDepth := 1}) [inv, inv_load] + fsimp_all [inv, inv_load] split_conjs at hInv -- - apply Int.mul_le_of_le_ediv at hSmaller <;> try simp [*] - apply Int.mul_le_of_le_ediv at hSmaller <;> try simp + apply Nat.mul_le_of_le_div at hSmaller; try fsimp [*] + apply Nat.mul_le_of_le_div at hSmaller; try simp -- have : (hm.slots.val.length * hm.2.1.val) * 1 ≤ (hm.slots.val.length * hm.2.1.val) * 2 := by - apply Int.mul_le_mul <;> (try simp [*]); scalar_tac + apply Nat.mul_le_mul <;> (try fsimp [*]) -- ring_nf at * - simp [*] - unfold max_load max_load_factor at * - omega + scalar_tac . -- Pre 2 - simp_all (config := {maxDischargeDepth := 1}) [inv, inv_load] + fsimp_all [inv, inv_load] unfold max_load_factor at * -- TODO: this is really annoying omega . -- End of the proof - have : slots_t_inv hm.slots := by simp_all (config := {maxDischargeDepth := 1}) [inv] -- TODO - have : (Slots.al_v hm.slots).length ≤ Usize.max := by simp_all (config := {maxDischargeDepth := 1}) [inv, al_v, v, Slots.al_v]; scalar_tac + have : slots_t_inv hm.slots := by fsimp_all [inv] -- TODO + have : (Slots.al_v hm.slots).length ≤ Usize.max := by fsimp_all [inv, al_v, v, Slots.al_v]; scalar_tac progress as ⟨ ntable2, slots1, _, _, _, hLookup ⟩ -- TODO: assumption is not powerful enough - simp_all (config := {maxDischargeDepth := 1}) [lookup, al_v, v, alloc.vec.Vec.len] + fsimp_all [lookup, al_v, v, alloc.vec.Vec.len] split_conjs - . 
simp_all (config := {maxDischargeDepth := 1}) [inv, al_v, HashMap.v] + . fsimp_all [inv, al_v, HashMap.v] -- load invariant simp_all [inv_load, frame_slots_params] . intro key replace hLookup := hLookup key - cases h1: (ntable2.slots.val.index (key.val % ntable2.slots.val.length).toNat).v.lookup key <;> - cases h2: (hm.slots.val.index (key.val % hm.slots.val.length).toNat).v.lookup key <;> - simp_all (config := {maxDischargeDepth := 1}) [Slots.lookup] + cases h1: (ntable2.slots.val[key.val % ntable2.slots.val.length]!).v.lookup key <;> + cases h2: (hm.slots.val[key.val % hm.slots.val.length]!).v.lookup key <;> + fsimp_all [Slots.lookup] else - simp [hSmaller] + fsimp [hSmaller] tauto -@[pspec] +@[progress] theorem insert_spec {α} (hm : HashMap α) (key : Usize) (value : α) (hInv : hm.inv) (hNotSat : hm.lookup key = none → hm.len_s < Usize.max) : @@ -965,36 +946,37 @@ theorem insert_spec {α} (hm : HashMap α) (key : Usize) (value : α) := by rw [insert] progress as ⟨ hm1 ⟩ - simp [len] + fsimp [len] split . split - . simp_all (config := {maxDischargeDepth := 1}) + . fsimp_all . progress as ⟨ hm2 ⟩ - simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) + fsimp_all + . fsimp_all -@[pspec] +@[progress] theorem get_in_list_spec {α} (key : Usize) (slot : AList α) : ∃ opt_v, get_in_list key slot = ok opt_v ∧ slot.lookup key = opt_v := by induction slot <;> rw [get_in_list, get_in_list_loop] <;> - simp_all (config := {maxDischargeDepth := 1}) - split <;> simp_all (config := {maxDischargeDepth := 2}) + fsimp_all + split <;> simp_all -@[pspec] +@[progress] theorem get_spec {α} (hm : HashMap α) (key : Usize) (hInv : hm.inv) : ∃ opt_v, get hm key = ok opt_v ∧ hm.lookup key = opt_v := by rw [get] - simp [hash_key, alloc.vec.Vec.len] + fsimp [hash_key, alloc.vec.Vec.len] progress as ⟨ hash_mod ⟩ -- TODO: decompose post by default - simp at * + fsimp at * + have : hash_mod.val < hm.slots.length := by scalar_tac +nonLin progress as ⟨ slot ⟩ progress as ⟨ v ⟩ - simp_all (config := {maxDischargeDepth := 1}) [lookup] + fsimp_all [lookup] -@[pspec] +@[progress] theorem get_mut_in_list_spec {α} (key : Usize) (slot : AList α) - {l i : Int} + {l i : Nat} (hInv : slot_t_inv l i slot) : ∃ opt_v back, get_mut_in_list slot key = ok (opt_v, back) ∧ slot.lookup key = opt_v ∧ @@ -1013,28 +995,28 @@ theorem get_mut_in_list_spec {α} (key : Usize) (slot : AList α) := by induction slot <;> rw [get_mut_in_list, get_mut_in_list_loop] <;> - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all split . -- Non-recursive case - simp_all (config := {maxDischargeDepth := 1}) [slot_t_inv] + fsimp_all [slot_t_inv] . -- Recursive case -- TODO: progress by progress as ⟨ opt_v, back, _, hBackNone, hBackSome ⟩ - . simp_all (config := {maxDischargeDepth := 1}) [slot_t_inv] + . fsimp_all [slot_t_inv] . simp [*] -- Proving the post-condition about back -- Case disjunction on v split_conjs - . simp_all (config := {maxDischargeDepth := 1}) + . fsimp_all . intro v v' heq have := hBackSome v v' split_conjs - . simp_all (config := {maxDischargeDepth := 1}) [slot_t_inv, slot_s_inv, slot_s_inv_hash] - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 1}) + . fsimp_all [slot_t_inv, slot_s_inv, slot_s_inv_hash] + . fsimp_all + . fsimp_all + . 
fsimp_all -@[pspec] +@[progress] theorem get_mut_spec {α} (hm : HashMap α) (key : Usize) (hInv : hm.inv) : ∃ opt_v back, get_mut hm key = ok (opt_v, back) ∧ hm.lookup key = opt_v ∧ @@ -1049,21 +1031,17 @@ theorem get_mut_spec {α} (hm : HashMap α) (key : Usize) (hInv : hm.inv) : ∀ key', key' ≠ key → hm'.lookup key' = hm.lookup key') := by rw [get_mut] - simp [hash_key, alloc.vec.Vec.len] + fsimp [hash_key, alloc.vec.Vec.len] progress as ⟨ hash_mod ⟩ - simp at * - have : 0 ≤ hash_mod.val ∧ hash_mod.val < hm.slots.val.length ∧ hash_mod.toNat < hm.slots.val.length := by scalar_tac + fsimp at * + have : hash_mod.val < hm.slots.val.length ∧ hash_mod.val < hm.slots.val.length := by scalar_tac +nonLin progress as ⟨ slot, index_back ⟩ have : slot_t_inv hm.slots.val.length hash_mod slot := by - simp_all (config := {maxDischargeDepth := 1}) [inv, slots_t_inv] - have := hInv.right.left (key % (hm.slots.val.length : Int)).toNat - simp_all (config := {maxDischargeDepth := 1}) - /-have : slot.lookup key ≠ none := by - simp_all (config := {maxDischargeDepth := 1}) [lookup]-/ + fsimp_all [inv, slots_t_inv] progress as ⟨ opt_v, back, _, hBackNone, hBackSome ⟩ - simp [lookup, *] + fsimp [lookup, *] constructor - . simp_all (config := {maxDischargeDepth := 1}) [lookup] + . fsimp_all [lookup] . -- Backward function split_conjs . -- case: none @@ -1071,24 +1049,24 @@ theorem get_mut_spec {α} (hm : HashMap α) (key : Usize) (hInv : hm.inv) : simp_all -- TODO: tactic to automate this have hSlotsEq : - hm.slots.update hash_mod ((hm.slots.val).index (key.val % (hm.slots.val).length).toNat) = hm.slots := by - simp_all [alloc.vec.Vec.update] - simp [hSlotsEq] + hm.slots.set hash_mod ((hm.slots.val)[(key.val % (hm.slots.val).length)]!) = hm.slots := by + simp_all [alloc.vec.Vec.set] + fsimp [hSlotsEq] . 
-- case: some intro v v' hVeq - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all -- Last postcondition replace hBackSome := hBackSome v v' (by simp) have ⟨ _, _, _, _ ⟩ := hBackSome clear hBackSome intro key' hNotEq -- TODO: simplify - have : 0 ≤ key'.val % hm.slots.val.length ∧ key'.val % hm.slots.val.length < hm.slots.val.length := by scalar_tac + have : key'.val % hm.slots.val.length < hm.slots.val.length := by scalar_tac +nonLin -- We need to do a case disjunction cases h: (key.val % hm.slots.val.length == key'.val % hm.slots.val.length) <;> - simp_all (config := {maxDischargeDepth := 2}) + simp_all -@[pspec] +@[progress] theorem remove_from_list_spec {α} (key : Usize) (slot : AList α) {l i} (hInv : slot_t_inv l i slot) : ∃ v slot', remove_from_list key slot = ok (v, slot') ∧ slot.lookup key = v ∧ @@ -1104,29 +1082,29 @@ theorem remove_from_list_spec {α} (key : Usize) (slot : AList α) {l i} (hInv : | .Cons k v0 tl => simp if hKey : k = key then - simp [hKey] - simp_all (config := {maxDischargeDepth := 1}) [slot_t_inv, slot_s_inv] + fsimp [hKey] + fsimp_all [slot_t_inv, slot_s_inv] apply slot_allP_not_key_lookup - simp [*] + fsimp [*] else - simp [hKey] - have hInv' : slot_t_inv l i tl := by simp_all (config := {maxDischargeDepth := 1}) [slot_t_inv] + fsimp [hKey] + have hInv' : slot_t_inv l i tl := by fsimp_all [slot_t_inv] progress as ⟨ v1, tl1, _, _, hLookupTl1, _ ⟩ simp [*] intro key' hNotEq1 - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all private theorem lookup_not_none_imp_len_s_pos (hm : HashMap α) (key : Usize) (hLookup : hm.lookup key ≠ none) (hNotEmpty : 0 < hm.slots.val.length) : 0 < hm.len_s := by -- TODO: simplify - have : 0 ≤ key.val % hm.slots.val.length ∧ key.val % hm.slots.val.length < hm.slots.val.length := by scalar_tac - have := List.length_index_le_length_flatten hm.v (key.val % hm.slots.val.length).toNat - have := List.lookup_not_none_imp_length_pos (hm.slots.val.index (key.val % hm.slots.val.length).toNat).v key - simp_all (config := {maxDischargeDepth := 2}) [lookup, len_s, al_v, v] + have : key.val % hm.slots.val.length < hm.slots.val.length := by scalar_tac +nonLin + have := List.length_getElem!_le_length_flatten hm.v (key.val % hm.slots.val.length) + have := List.lookup_not_none_imp_length_pos (hm.slots.val[key.val % hm.slots.val.length]!).v key + simp_all [lookup, len_s, al_v, v] scalar_tac -@[pspec] +@[progress] theorem remove_spec {α} (hm : HashMap α) (key : Usize) (hInv : hm.inv) : ∃ v hm', remove hm key = ok (v, hm') ∧ hm.lookup key = v ∧ @@ -1136,46 +1114,43 @@ theorem remove_spec {α} (hm : HashMap α) (key : Usize) (hInv : hm.inv) : | none => hm'.len_s = hm.len_s | some _ => hm'.len_s = hm.len_s - 1 := by rw [remove] - simp [hash_key, alloc.vec.Vec.len] + fsimp [hash_key, alloc.vec.Vec.len] progress as ⟨ hash_mod ⟩ -- TODO: decompose post by default - simp at * + fsimp at * -- TODO: simplify - have : 0 ≤ hash_mod.val ∧ hash_mod.val < hm.slots.val.length := by - scalar_tac + have : hash_mod.val < hm.slots.val.length := by + scalar_tac +nonLin progress as ⟨ slot, index_back ⟩ have : slot_t_inv hm.slots.val.length hash_mod slot := by - simp_all (config := {maxDischargeDepth := 1}) [inv, slots_t_inv] - have := hInv.right.left (key % (hm.slots.val.length : Int)).toNat - simp_all (config := {maxDischargeDepth := 1}) + fsimp_all [inv, slots_t_inv] progress as ⟨ vOpt, slot' ⟩ - cases hOpt : vOpt with - | none => - simp [*] - simp [lookup, *] - simp_all (config := {maxDischargeDepth := 2}) [al_v, v] + cases hOpt : vOpt + . 
fsimp [*] + fsimp [lookup, *] + simp_all [al_v, v] split_conjs . intro key' hNotEq -- We need to make a case disjunction - have : (key' % (hm.slots.val.length : Int)).toNat < hm.slots.val.length := by scalar_tac - cases h: (key.val % hm.slots.val.length).toNat == (key'.val % hm.slots.val.length).toNat <;> - simp_all (config := {maxDischargeDepth := 1}) + have : key' % hm.slots.val.length < hm.slots.val.length := by scalar_tac +nonLin + cases h: key.val % hm.slots.val.length == key'.val % hm.slots.val.length <;> + fsimp_all . -- TODO scalar_tac_preprocess - simp_all (config := {maxDischargeDepth := 2}) + fsimp_all omega - | some v => - simp [*] + . rename_i v + fsimp [*] have : 0 < hm.num_entries.val := by - have := lookup_not_none_imp_len_s_pos hm key (by simp_all (config := {maxDischargeDepth := 1}) [lookup]) (by simp_all (config := {maxDischargeDepth := 1}) [inv]) - simp_all (config := {maxDischargeDepth := 1}) [inv] + have := lookup_not_none_imp_len_s_pos hm key (by fsimp_all [lookup]) (by fsimp_all [inv]) + fsimp_all [inv] progress as ⟨ newSize ⟩ - simp_all (config := {maxDischargeDepth := 2}) [lookup, al_v, HashMap.v] + simp_all [lookup, al_v, HashMap.v] constructor . intro key' hNotEq - have : (key' % (hm.slots.val.length : Int)).toNat < hm.slots.val.length := by scalar_tac - cases h: (key.val % hm.slots.val.length).toNat == (key'.val % hm.slots.val.length).toNat <;> - simp_all (config := {maxDischargeDepth := 1}) - . simp_all (config := {maxDischargeDepth := 2}) [List.length_flatten_update_as_int_eq] + have : key' % hm.slots.val.length < hm.slots.val.length := by scalar_tac +nonLin + cases h: key.val % hm.slots.val.length == key'.val % hm.slots.val.length <;> + fsimp_all + . simp_all [List.length_flatten_set_as_int_eq] scalar_tac end HashMap diff --git a/tests/lean/Hashmap/Types.lean b/tests/lean/Hashmap/Types.lean index 6e6e8ab0..9500a714 100644 --- a/tests/lean/Hashmap/Types.lean +++ b/tests/lean/Hashmap/Types.lean @@ -2,7 +2,7 @@ -- [hashmap]: type definitions import Aeneas import Hashmap.TypesExternal -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/Hashmap/TypesExternal_Template.lean b/tests/lean/Hashmap/TypesExternal_Template.lean index 32cddcef..4772015a 100644 --- a/tests/lean/Hashmap/TypesExternal_Template.lean +++ b/tests/lean/Hashmap/TypesExternal_Template.lean @@ -2,7 +2,7 @@ -- [hashmap]: external types. -- This is a template file: rename it to "TypesExternal.lean" and fill the holes. 
import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/InfiniteLoop.lean b/tests/lean/InfiniteLoop.lean index 05e09ecb..0bf8ab2a 100644 --- a/tests/lean/InfiniteLoop.lean +++ b/tests/lean/InfiniteLoop.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [infinite_loop] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -11,7 +11,7 @@ namespace infinite_loop /- [infinite_loop::bar]: Source: 'tests/src/infinite-loop.rs', lines 4:0-4:11 -/ def bar : Result Unit := - Result.ok () + ok () /- [infinite_loop::foo]: loop 0: Source: 'tests/src/infinite-loop.rs', lines 8:8-8:13 -/ diff --git a/tests/lean/Issue194RecursiveStructProjector.lean b/tests/lean/Issue194RecursiveStructProjector.lean index 0d9038b5..8aafb583 100644 --- a/tests/lean/Issue194RecursiveStructProjector.lean +++ b/tests/lean/Issue194RecursiveStructProjector.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [issue_194_recursive_struct_projector] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -40,11 +40,11 @@ theorem AVLNode.right._simpLemma_ {T : Type} (value : T) (left : Option /- [issue_194_recursive_struct_projector::get_val]: Source: 'tests/src/issue-194-recursive-struct-projector.rs', lines 10:0-12:1 -/ def get_val {T : Type} (x : AVLNode T) : Result T := - Result.ok x.value + ok x.value /- [issue_194_recursive_struct_projector::get_left]: Source: 'tests/src/issue-194-recursive-struct-projector.rs', lines 14:0-16:1 -/ def get_left {T : Type} (x : AVLNode T) : Result (Option (AVLNode T)) := - Result.ok x.left + ok x.left end issue_194_recursive_struct_projector diff --git a/tests/lean/Issue270LoopList.lean b/tests/lean/Issue270LoopList.lean index f3657eed..c3d77508 100644 --- a/tests/lean/Issue270LoopList.lean +++ b/tests/lean/Issue270LoopList.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [issue_270_loop_list] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -19,13 +19,13 @@ inductive List (T : Type) where divergent def foo_loop (t : List (List U8)) : Result Unit := match t with | List.Cons _ tt => foo_loop tt - | List.Nil => Result.ok () + | List.Nil => ok () /- [issue_270_loop_list::foo]: Source: 'tests/src/issue-270-loop-list.rs', lines 7:0-14:1 -/ def foo (v : List (List U8)) : Result Unit := match v with | List.Cons l t => foo_loop t - | List.Nil => Result.ok () + | List.Nil => ok () end issue_270_loop_list diff --git a/tests/lean/Issue440TypeError.lean b/tests/lean/Issue440TypeError.lean index 42bac700..a0c67e8d 100644 --- a/tests/lean/Issue440TypeError.lean +++ b/tests/lean/Issue440TypeError.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [issue_440_type_error] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -20,7 +20,7 @@ def f (x : PeanoNum) (value : Isize) : Result PeanoNum := match x with | PeanoNum.Zero => let (_, x1) := core.mem.replace PeanoNum.Zero (PeanoNum.Succ 
PeanoNum.Zero) - Result.ok x1 - | PeanoNum.Succ _ => Result.ok x + ok x1 + | PeanoNum.Succ _ => ok x end issue_440_type_error diff --git a/tests/lean/Loops.lean b/tests/lean/Loops.lean index 9452b33d..5955eb67 100644 --- a/tests/lean/Loops.lean +++ b/tests/lean/Loops.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [loops] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -20,8 +20,8 @@ divergent def sum_loop (max : U32) (i : U32) (s : U32) : Result U32 := /- [loops::sum]: Source: 'tests/src/loops.rs', lines 8:0-18:1 -/ -def sum (max : U32) : Result U32 := - sum_loop max 0#u32 0#u32 +@[reducible] def sum (max : U32) : Result U32 := + sum_loop max 0#u32 0#u32 /- [loops::sum_with_mut_borrows]: loop 0: Source: 'tests/src/loops.rs', lines 26:4-31:5 -/ @@ -37,6 +37,7 @@ divergent def sum_with_mut_borrows_loop /- [loops::sum_with_mut_borrows]: Source: 'tests/src/loops.rs', lines 23:0-35:1 -/ +@[reducible] def sum_with_mut_borrows (max : U32) : Result U32 := sum_with_mut_borrows_loop max 0#u32 0#u32 @@ -54,6 +55,7 @@ divergent def sum_with_shared_borrows_loop /- [loops::sum_with_shared_borrows]: Source: 'tests/src/loops.rs', lines 38:0-52:1 -/ +@[reducible] def sum_with_shared_borrows (max : U32) : Result U32 := sum_with_shared_borrows_loop max 0#u32 0#u32 @@ -68,10 +70,11 @@ divergent def sum_array_loop let s1 ← s + i1 let i2 ← i + 1#usize sum_array_loop a i2 s1 - else Result.ok s + else ok s /- [loops::sum_array]: Source: 'tests/src/loops.rs', lines 54:0-62:1 -/ +@[reducible] def sum_array {N : Usize} (a : Array U32 N) : Result U32 := sum_array_loop a 0#usize 0#u32 @@ -89,10 +92,11 @@ divergent def clear_loop let i2 ← i + 1#usize let v1 := index_mut_back 0#u32 clear_loop v1 i2 - else Result.ok v + else ok v /- [loops::clear]: Source: 'tests/src/loops.rs', lines 66:0-72:1 -/ +@[reducible] def clear (v : alloc.vec.Vec U32) : Result (alloc.vec.Vec U32) := clear_loop v 0#usize @@ -107,9 +111,9 @@ inductive List (T : Type) where divergent def list_mem_loop (x : U32) (ls : List U32) : Result Bool := match ls with | List.Cons y tl => if y = x - then Result.ok true + then ok true else list_mem_loop x tl - | List.Nil => Result.ok false + | List.Nil => ok false /- [loops::list_mem]: Source: 'tests/src/loops.rs', lines 80:0-89:1 -/ @@ -125,15 +129,15 @@ divergent def list_nth_mut_loop_loop | List.Cons x tl => if i = 0#u32 then let back := fun ret => List.Cons ret tl - Result.ok (x, back) + ok (x, back) else do let i1 ← i - 1#u32 let (t, back) ← list_nth_mut_loop_loop tl i1 let back1 := fun ret => let tl1 := back ret List.Cons x tl1 - Result.ok (t, back1) - | List.Nil => Result.fail .panic + ok (t, back1) + | List.Nil => fail panic /- [loops::list_nth_mut_loop]: Source: 'tests/src/loops.rs', lines 92:0-102:1 -/ @@ -149,11 +153,11 @@ divergent def list_nth_shared_loop_loop match ls with | List.Cons x tl => if i = 0#u32 - then Result.ok x + then ok x else do let i1 ← i - 1#u32 list_nth_shared_loop_loop tl i1 - | List.Nil => Result.fail .panic + | List.Nil => fail panic /- [loops::list_nth_shared_loop]: Source: 'tests/src/loops.rs', lines 105:0-115:1 -/ @@ -169,14 +173,14 @@ divergent def get_elem_mut_loop | List.Cons y tl => if y = x then let back := fun ret => List.Cons ret tl - Result.ok (y, back) + ok (y, back) else do let (i, back) ← get_elem_mut_loop x tl let back1 := fun ret => let tl1 := back ret List.Cons y tl1 - Result.ok (i, back1) - | List.Nil => 
Result.fail .panic + ok (i, back1) + | List.Nil => fail panic /- [loops::get_elem_mut]: Source: 'tests/src/loops.rs', lines 117:0-131:1 -/ @@ -191,7 +195,7 @@ def get_elem_mut let (i, back) ← get_elem_mut_loop x ls let back1 := fun ret => let l := back ret index_mut_back l - Result.ok (i, back1) + ok (i, back1) /- [loops::get_elem_shared]: loop 0: Source: 'tests/src/loops.rs', lines 135:4-147:1 -/ @@ -199,9 +203,9 @@ divergent def get_elem_shared_loop (x : Usize) (ls : List Usize) : Result Usize := match ls with | List.Cons y tl => if y = x - then Result.ok y + then ok y else get_elem_shared_loop x tl - | List.Nil => Result.fail .panic + | List.Nil => fail panic /- [loops::get_elem_shared]: Source: 'tests/src/loops.rs', lines 133:0-147:1 -/ @@ -217,12 +221,12 @@ def get_elem_shared Source: 'tests/src/loops.rs', lines 149:0-151:1 -/ def id_mut {T : Type} (ls : List T) : Result ((List T) × (List T → List T)) := - Result.ok (ls, fun ret => ret) + ok (ls, fun ret => ret) /- [loops::id_shared]: Source: 'tests/src/loops.rs', lines 153:0-155:1 -/ def id_shared {T : Type} (ls : List T) : Result (List T) := - Result.ok ls + ok ls /- [loops::list_nth_mut_loop_with_id]: loop 0: Source: 'tests/src/loops.rs', lines 160:4-169:1 -/ @@ -232,15 +236,15 @@ divergent def list_nth_mut_loop_with_id_loop | List.Cons x tl => if i = 0#u32 then let back := fun ret => List.Cons ret tl - Result.ok (x, back) + ok (x, back) else do let i1 ← i - 1#u32 let (t, back) ← list_nth_mut_loop_with_id_loop i1 tl let back1 := fun ret => let tl1 := back ret List.Cons x tl1 - Result.ok (t, back1) - | List.Nil => Result.fail .panic + ok (t, back1) + | List.Nil => fail panic /- [loops::list_nth_mut_loop_with_id]: Source: 'tests/src/loops.rs', lines 158:0-169:1 -/ @@ -251,7 +255,7 @@ def list_nth_mut_loop_with_id let (t, back) ← list_nth_mut_loop_with_id_loop i ls1 let back1 := fun ret => let l := back ret id_mut_back l - Result.ok (t, back1) + ok (t, back1) /- [loops::list_nth_shared_loop_with_id]: loop 0: Source: 'tests/src/loops.rs', lines 174:4-183:1 -/ @@ -260,11 +264,11 @@ divergent def list_nth_shared_loop_with_id_loop match ls with | List.Cons x tl => if i = 0#u32 - then Result.ok x + then ok x else do let i1 ← i - 1#u32 list_nth_shared_loop_with_id_loop i1 tl - | List.Nil => Result.fail .panic + | List.Nil => fail panic /- [loops::list_nth_shared_loop_with_id]: Source: 'tests/src/loops.rs', lines 172:0-183:1 -/ @@ -288,7 +292,7 @@ divergent def list_nth_mut_loop_pair_loop then let back'a := fun ret => List.Cons ret tl0 let back'b := fun ret => List.Cons ret tl1 - Result.ok ((x0, x1), back'a, back'b) + ok ((x0, x1), back'a, back'b) else do let i1 ← i - 1#u32 @@ -297,9 +301,9 @@ divergent def list_nth_mut_loop_pair_loop List.Cons x0 tl01 let back'b1 := fun ret => let tl11 := back'b ret List.Cons x1 tl11 - Result.ok (p, back'a1, back'b1) - | List.Nil => Result.fail .panic - | List.Nil => Result.fail .panic + ok (p, back'a1, back'b1) + | List.Nil => fail panic + | List.Nil => fail panic /- [loops::list_nth_mut_loop_pair]: Source: 'tests/src/loops.rs', lines 188:0-209:1 -/ @@ -319,12 +323,12 @@ divergent def list_nth_shared_loop_pair_loop match ls1 with | List.Cons x1 tl1 => if i = 0#u32 - then Result.ok (x0, x1) + then ok (x0, x1) else do let i1 ← i - 1#u32 list_nth_shared_loop_pair_loop tl0 tl1 i1 - | List.Nil => Result.fail .panic - | List.Nil => Result.fail .panic + | List.Nil => fail panic + | List.Nil => fail panic /- [loops::list_nth_shared_loop_pair]: Source: 'tests/src/loops.rs', lines 212:0-233:1 -/ @@ -348,7 +352,7 @@ 
divergent def list_nth_mut_loop_pair_merge_loop let back := fun ret => let (t, t1) := ret (List.Cons t tl0, List.Cons t1 tl1) - Result.ok ((x0, x1), back) + ok ((x0, x1), back) else do let i1 ← i - 1#u32 @@ -357,9 +361,9 @@ divergent def list_nth_mut_loop_pair_merge_loop fun ret => let (tl01, tl11) := back ret (List.Cons x0 tl01, List.Cons x1 tl11) - Result.ok (p, back1) - | List.Nil => Result.fail .panic - | List.Nil => Result.fail .panic + ok (p, back1) + | List.Nil => fail panic + | List.Nil => fail panic /- [loops::list_nth_mut_loop_pair_merge]: Source: 'tests/src/loops.rs', lines 237:0-252:1 -/ @@ -379,13 +383,13 @@ divergent def list_nth_shared_loop_pair_merge_loop match ls1 with | List.Cons x1 tl1 => if i = 0#u32 - then Result.ok (x0, x1) + then ok (x0, x1) else do let i1 ← i - 1#u32 list_nth_shared_loop_pair_merge_loop tl0 tl1 i1 - | List.Nil => Result.fail .panic - | List.Nil => Result.fail .panic + | List.Nil => fail panic + | List.Nil => fail panic /- [loops::list_nth_shared_loop_pair_merge]: Source: 'tests/src/loops.rs', lines 255:0-270:1 -/ @@ -406,16 +410,16 @@ divergent def list_nth_mut_shared_loop_pair_loop | List.Cons x1 tl1 => if i = 0#u32 then let back := fun ret => List.Cons ret tl0 - Result.ok ((x0, x1), back) + ok ((x0, x1), back) else do let i1 ← i - 1#u32 let (p, back) ← list_nth_mut_shared_loop_pair_loop tl0 tl1 i1 let back1 := fun ret => let tl01 := back ret List.Cons x0 tl01 - Result.ok (p, back1) - | List.Nil => Result.fail .panic - | List.Nil => Result.fail .panic + ok (p, back1) + | List.Nil => fail panic + | List.Nil => fail panic /- [loops::list_nth_mut_shared_loop_pair]: Source: 'tests/src/loops.rs', lines 273:0-288:1 -/ @@ -438,16 +442,16 @@ divergent def list_nth_mut_shared_loop_pair_merge_loop | List.Cons x1 tl1 => if i = 0#u32 then let back := fun ret => List.Cons ret tl0 - Result.ok ((x0, x1), back) + ok ((x0, x1), back) else do let i1 ← i - 1#u32 let (p, back) ← list_nth_mut_shared_loop_pair_merge_loop tl0 tl1 i1 let back1 := fun ret => let tl01 := back ret List.Cons x0 tl01 - Result.ok (p, back1) - | List.Nil => Result.fail .panic - | List.Nil => Result.fail .panic + ok (p, back1) + | List.Nil => fail panic + | List.Nil => fail panic /- [loops::list_nth_mut_shared_loop_pair_merge]: Source: 'tests/src/loops.rs', lines 292:0-307:1 -/ @@ -470,16 +474,16 @@ divergent def list_nth_shared_mut_loop_pair_loop | List.Cons x1 tl1 => if i = 0#u32 then let back := fun ret => List.Cons ret tl1 - Result.ok ((x0, x1), back) + ok ((x0, x1), back) else do let i1 ← i - 1#u32 let (p, back) ← list_nth_shared_mut_loop_pair_loop tl0 tl1 i1 let back1 := fun ret => let tl11 := back ret List.Cons x1 tl11 - Result.ok (p, back1) - | List.Nil => Result.fail .panic - | List.Nil => Result.fail .panic + ok (p, back1) + | List.Nil => fail panic + | List.Nil => fail panic /- [loops::list_nth_shared_mut_loop_pair]: Source: 'tests/src/loops.rs', lines 311:0-326:1 -/ @@ -502,16 +506,16 @@ divergent def list_nth_shared_mut_loop_pair_merge_loop | List.Cons x1 tl1 => if i = 0#u32 then let back := fun ret => List.Cons ret tl1 - Result.ok ((x0, x1), back) + ok ((x0, x1), back) else do let i1 ← i - 1#u32 let (p, back) ← list_nth_shared_mut_loop_pair_merge_loop tl0 tl1 i1 let back1 := fun ret => let tl11 := back ret List.Cons x1 tl11 - Result.ok (p, back1) - | List.Nil => Result.fail .panic - | List.Nil => Result.fail .panic + ok (p, back1) + | List.Nil => fail panic + | List.Nil => fail panic /- [loops::list_nth_shared_mut_loop_pair_merge]: Source: 'tests/src/loops.rs', lines 330:0-345:1 -/ 
@@ -529,14 +533,14 @@ divergent def ignore_input_mut_borrow_loop (i : U32) : Result Unit := then do let i1 ← i - 1#u32 ignore_input_mut_borrow_loop i1 - else Result.ok () + else ok () /- [loops::ignore_input_mut_borrow]: Source: 'tests/src/loops.rs', lines 349:0-353:1 -/ def ignore_input_mut_borrow (_a : U32) (i : U32) : Result U32 := do ignore_input_mut_borrow_loop i - Result.ok _a + ok _a /- [loops::incr_ignore_input_mut_borrow]: loop 0: Source: 'tests/src/loops.rs', lines 359:4-361:5 -/ @@ -545,7 +549,7 @@ divergent def incr_ignore_input_mut_borrow_loop (i : U32) : Result Unit := then do let i1 ← i - 1#u32 incr_ignore_input_mut_borrow_loop i1 - else Result.ok () + else ok () /- [loops::incr_ignore_input_mut_borrow]: Source: 'tests/src/loops.rs', lines 357:0-362:1 -/ @@ -553,7 +557,7 @@ def incr_ignore_input_mut_borrow (a : U32) (i : U32) : Result U32 := do let a1 ← a + 1#u32 incr_ignore_input_mut_borrow_loop i - Result.ok a1 + ok a1 /- [loops::ignore_input_shared_borrow]: loop 0: Source: 'tests/src/loops.rs', lines 367:4-369:5 -/ @@ -562,13 +566,13 @@ divergent def ignore_input_shared_borrow_loop (i : U32) : Result Unit := then do let i1 ← i - 1#u32 ignore_input_shared_borrow_loop i1 - else Result.ok () + else ok () /- [loops::ignore_input_shared_borrow]: Source: 'tests/src/loops.rs', lines 366:0-370:1 -/ def ignore_input_shared_borrow (_a : U32) (i : U32) : Result U32 := do ignore_input_shared_borrow_loop i - Result.ok _a + ok _a end loops diff --git a/tests/lean/MiniTree.lean b/tests/lean/MiniTree.lean index c7fe4f65..afc7e495 100644 --- a/tests/lean/MiniTree.lean +++ b/tests/lean/MiniTree.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [mini_tree] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -28,11 +28,12 @@ structure Tree where Source: 'tests/src/mini_tree.rs', lines 17:8-19:9 -/ divergent def Tree.explore_loop (current_tree : Option Node) : Result Unit := match current_tree with - | none => Result.ok () + | none => ok () | some current_node => Tree.explore_loop current_node.child /- [mini_tree::{mini_tree::Tree}::explore]: Source: 'tests/src/mini_tree.rs', lines 14:4-20:5 -/ +@[reducible] def Tree.explore (self : Tree) : Result Unit := Tree.explore_loop self.root diff --git a/tests/lean/MutuallyRecursiveTraits.lean b/tests/lean/MutuallyRecursiveTraits.lean index 14b292f3..50652d22 100644 --- a/tests/lean/MutuallyRecursiveTraits.lean +++ b/tests/lean/MutuallyRecursiveTraits.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [mutually_recursive_traits] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/NoNestedBorrows.lean b/tests/lean/NoNestedBorrows.lean index 9815ec4f..69b93f06 100644 --- a/tests/lean/NoNestedBorrows.lean +++ b/tests/lean/NoNestedBorrows.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [no_nested_borrows] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -49,34 +49,34 @@ inductive Sum (T1 : Type) (T2 : Type) where /- [no_nested_borrows::cast_u32_to_i32]: Source: 'tests/src/no_nested_borrows.rs', lines 49:0-51:1 -/ def cast_u32_to_i32 (x : U32) : Result I32 := - Scalar.cast .I32 x + ok 
(UScalar.hcast .I32 x) /- [no_nested_borrows::cast_bool_to_i32]: Source: 'tests/src/no_nested_borrows.rs', lines 53:0-55:1 -/ def cast_bool_to_i32 (x : Bool) : Result I32 := - Scalar.cast_bool .I32 x + ok (IScalar.cast_fromBool .I32 x) /- [no_nested_borrows::cast_bool_to_bool]: Source: 'tests/src/no_nested_borrows.rs', lines 58:0-60:1 -/ def cast_bool_to_bool (x : Bool) : Result Bool := - Result.ok x + ok x /- [no_nested_borrows::test2]: Source: 'tests/src/no_nested_borrows.rs', lines 63:0-73:1 -/ def test2 : Result Unit := do let _ ← 23#u32 + 44#u32 - Result.ok () + ok () /- Unit test for [no_nested_borrows::test2] -/ -#assert (test2 == Result.ok ()) +#assert (test2 == ok ()) /- [no_nested_borrows::get_max]: Source: 'tests/src/no_nested_borrows.rs', lines 75:0-81:1 -/ def get_max (x : U32) (y : U32) : Result U32 := if x >= y - then Result.ok x - else Result.ok y + then ok x + else ok y /- [no_nested_borrows::test3]: Source: 'tests/src/no_nested_borrows.rs', lines 83:0-88:1 -/ @@ -88,7 +88,7 @@ def test3 : Result Unit := massert (z = 15#u32) /- Unit test for [no_nested_borrows::test3] -/ -#assert (test3 == Result.ok ()) +#assert (test3 == ok ()) /- [no_nested_borrows::test_neg1]: Source: 'tests/src/no_nested_borrows.rs', lines 90:0-94:1 -/ @@ -98,7 +98,7 @@ def test_neg1 : Result Unit := massert (y = (-3)#i32) /- Unit test for [no_nested_borrows::test_neg1] -/ -#assert (test_neg1 == Result.ok ()) +#assert (test_neg1 == ok ()) /- [no_nested_borrows::refs_test1]: Source: 'tests/src/no_nested_borrows.rs', lines 97:0-106:1 -/ @@ -106,7 +106,7 @@ def refs_test1 : Result Unit := massert (1#i32 = 1#i32) /- Unit test for [no_nested_borrows::refs_test1] -/ -#assert (refs_test1 == Result.ok ()) +#assert (refs_test1 == ok ()) /- [no_nested_borrows::refs_test2]: Source: 'tests/src/no_nested_borrows.rs', lines 108:0-120:1 -/ @@ -118,15 +118,15 @@ def refs_test2 : Result Unit := massert (2#i32 = 2#i32) /- Unit test for [no_nested_borrows::refs_test2] -/ -#assert (refs_test2 == Result.ok ()) +#assert (refs_test2 == ok ()) /- [no_nested_borrows::test_list1]: Source: 'tests/src/no_nested_borrows.rs', lines 124:0-126:1 -/ def test_list1 : Result Unit := - Result.ok () + ok () /- Unit test for [no_nested_borrows::test_list1] -/ -#assert (test_list1 == Result.ok ()) +#assert (test_list1 == ok ()) /- [no_nested_borrows::test_box1]: Source: 'tests/src/no_nested_borrows.rs', lines 129:0-137:1 -/ @@ -138,12 +138,12 @@ def test_box1 : Result Unit := massert (x = 1#i32) /- Unit test for [no_nested_borrows::test_box1] -/ -#assert (test_box1 == Result.ok ()) +#assert (test_box1 == ok ()) /- [no_nested_borrows::copy_int]: Source: 'tests/src/no_nested_borrows.rs', lines 139:0-141:1 -/ def copy_int (x : I32) : Result I32 := - Result.ok x + ok x /- [no_nested_borrows::test_unreachable]: Source: 'tests/src/no_nested_borrows.rs', lines 145:0-149:1 -/ @@ -168,14 +168,14 @@ def test_copy_int : Result Unit := massert (0#i32 = y) /- Unit test for [no_nested_borrows::test_copy_int] -/ -#assert (test_copy_int == Result.ok ()) +#assert (test_copy_int == ok ()) /- [no_nested_borrows::is_cons]: Source: 'tests/src/no_nested_borrows.rs', lines 174:0-179:1 -/ def is_cons {T : Type} (l : List T) : Result Bool := match l with - | List.Cons _ _ => Result.ok true - | List.Nil => Result.ok false + | List.Cons _ _ => ok true + | List.Nil => ok false /- [no_nested_borrows::test_is_cons]: Source: 'tests/src/no_nested_borrows.rs', lines 181:0-185:1 -/ @@ -185,25 +185,24 @@ def test_is_cons : Result Unit := massert b /- Unit test for 
[no_nested_borrows::test_is_cons] -/ -#assert (test_is_cons == Result.ok ()) +#assert (test_is_cons == ok ()) /- [no_nested_borrows::split_list]: Source: 'tests/src/no_nested_borrows.rs', lines 187:0-192:1 -/ def split_list {T : Type} (l : List T) : Result (T × (List T)) := match l with - | List.Cons hd tl => Result.ok (hd, tl) - | List.Nil => Result.fail .panic + | List.Cons hd tl => ok (hd, tl) + | List.Nil => fail panic /- [no_nested_borrows::test_split_list]: Source: 'tests/src/no_nested_borrows.rs', lines 195:0-200:1 -/ def test_split_list : Result Unit := do - let p ← split_list (List.Cons 0#i32 List.Nil) - let (hd, _) := p + let (hd, _) ← split_list (List.Cons 0#i32 List.Nil) massert (hd = 0#i32) /- Unit test for [no_nested_borrows::test_split_list] -/ -#assert (test_split_list == Result.ok ()) +#assert (test_split_list == ok ()) /- [no_nested_borrows::choose]: Source: 'tests/src/no_nested_borrows.rs', lines 202:0-208:1 -/ @@ -211,9 +210,9 @@ def choose {T : Type} (b : Bool) (x : T) (y : T) : Result (T × (T → (T × T))) := if b then let back := fun ret => (ret, y) - Result.ok (x, back) + ok (x, back) else let back := fun ret => (x, ret) - Result.ok (y, back) + ok (y, back) /- [no_nested_borrows::choose_test]: Source: 'tests/src/no_nested_borrows.rs', lines 210:0-219:1 -/ @@ -227,17 +226,17 @@ def choose_test : Result Unit := massert (y = 0#i32) /- Unit test for [no_nested_borrows::choose_test] -/ -#assert (choose_test == Result.ok ()) +#assert (choose_test == ok ()) /- [no_nested_borrows::test_char]: Source: 'tests/src/no_nested_borrows.rs', lines 222:0-224:1 -/ def test_char : Result Char := - Result.ok 'a' + ok 'a' /- [no_nested_borrows::panic_mut_borrow]: Source: 'tests/src/no_nested_borrows.rs', lines 227:0-229:1 -/ def panic_mut_borrow (i : U32) : Result U32 := - Result.fail .panic + fail panic mutual @@ -262,7 +261,7 @@ divergent def list_length {T : Type} (l : List T) : Result U32 := | List.Cons _ l1 => do let i ← list_length l1 1#u32 + i - | List.Nil => Result.ok 0#u32 + | List.Nil => ok 0#u32 /- [no_nested_borrows::list_nth_shared]: Source: 'tests/src/no_nested_borrows.rs', lines 280:0-293:1 -/ @@ -270,11 +269,11 @@ divergent def list_nth_shared {T : Type} (l : List T) (i : U32) : Result T := match l with | List.Cons x tl => if i = 0#u32 - then Result.ok x + then ok x else do let i1 ← i - 1#u32 list_nth_shared tl i1 - | List.Nil => Result.fail .panic + | List.Nil => fail panic /- [no_nested_borrows::list_nth_mut]: Source: 'tests/src/no_nested_borrows.rs', lines 296:0-309:1 -/ @@ -284,15 +283,15 @@ divergent def list_nth_mut | List.Cons x tl => if i = 0#u32 then let back := fun ret => List.Cons ret tl - Result.ok (x, back) + ok (x, back) else do let i1 ← i - 1#u32 let (t, list_nth_mut_back) ← list_nth_mut tl i1 let back := fun ret => let tl1 := list_nth_mut_back ret List.Cons x tl1 - Result.ok (t, back) - | List.Nil => Result.fail .panic + ok (t, back) + | List.Nil => fail panic /- [no_nested_borrows::list_rev_aux]: Source: 'tests/src/no_nested_borrows.rs', lines 312:0-322:1 -/ @@ -300,7 +299,7 @@ divergent def list_rev_aux {T : Type} (li : List T) (lo : List T) : Result (List T) := match li with | List.Cons hd tl => list_rev_aux tl (List.Cons hd lo) - | List.Nil => Result.ok lo + | List.Nil => ok lo /- [no_nested_borrows::list_rev]: Source: 'tests/src/no_nested_borrows.rs', lines 326:0-329:1 -/ @@ -332,7 +331,7 @@ def test_list_functions : Result Unit := massert (i6 = 2#i32) /- Unit test for [no_nested_borrows::test_list_functions] -/ -#assert (test_list_functions == 
Result.ok ()) +#assert (test_list_functions == ok ()) /- [no_nested_borrows::id_mut_pair1]: Source: 'tests/src/no_nested_borrows.rs', lines 347:0-349:1 -/ @@ -340,7 +339,7 @@ def id_mut_pair1 {T1 : Type} {T2 : Type} (x : T1) (y : T2) : Result ((T1 × T2) × ((T1 × T2) → (T1 × T2))) := - Result.ok ((x, y), fun ret => ret) + ok ((x, y), fun ret => ret) /- [no_nested_borrows::id_mut_pair2]: Source: 'tests/src/no_nested_borrows.rs', lines 351:0-353:1 -/ @@ -348,7 +347,7 @@ def id_mut_pair2 {T1 : Type} {T2 : Type} (p : (T1 × T2)) : Result ((T1 × T2) × ((T1 × T2) → (T1 × T2))) := - Result.ok (p, fun ret => ret) + ok (p, fun ret => ret) /- [no_nested_borrows::id_mut_pair3]: Source: 'tests/src/no_nested_borrows.rs', lines 355:0-357:1 -/ @@ -356,7 +355,7 @@ def id_mut_pair3 {T1 : Type} {T2 : Type} (x : T1) (y : T2) : Result ((T1 × T2) × (T1 → T1) × (T2 → T2)) := - Result.ok ((x, y), fun ret => ret, fun ret => ret) + ok ((x, y), fun ret => ret, fun ret => ret) /- [no_nested_borrows::id_mut_pair4]: Source: 'tests/src/no_nested_borrows.rs', lines 359:0-361:1 -/ @@ -364,7 +363,7 @@ def id_mut_pair4 {T1 : Type} {T2 : Type} (p : (T1 × T2)) : Result ((T1 × T2) × (T1 → T1) × (T2 → T2)) := - Result.ok (p, fun ret => ret, fun ret => ret) + ok (p, fun ret => ret, fun ret => ret) /- [no_nested_borrows::StructWithTuple] Source: 'tests/src/no_nested_borrows.rs', lines 366:0-368:1 -/ @@ -374,17 +373,17 @@ structure StructWithTuple (T1 : Type) (T2 : Type) where /- [no_nested_borrows::new_tuple1]: Source: 'tests/src/no_nested_borrows.rs', lines 370:0-372:1 -/ def new_tuple1 : Result (StructWithTuple U32 U32) := - Result.ok { p := (1#u32, 2#u32) } + ok { p := (1#u32, 2#u32) } /- [no_nested_borrows::new_tuple2]: Source: 'tests/src/no_nested_borrows.rs', lines 374:0-376:1 -/ def new_tuple2 : Result (StructWithTuple I16 I16) := - Result.ok { p := (1#i16, 2#i16) } + ok { p := (1#i16, 2#i16) } /- [no_nested_borrows::new_tuple3]: Source: 'tests/src/no_nested_borrows.rs', lines 378:0-380:1 -/ def new_tuple3 : Result (StructWithTuple U64 I64) := - Result.ok { p := (1#u64, 2#i64) } + ok { p := (1#u64, 2#i64) } /- [no_nested_borrows::StructWithPair] Source: 'tests/src/no_nested_borrows.rs', lines 383:0-385:1 -/ @@ -394,7 +393,7 @@ structure StructWithPair (T1 : Type) (T2 : Type) where /- [no_nested_borrows::new_pair1]: Source: 'tests/src/no_nested_borrows.rs', lines 387:0-393:1 -/ def new_pair1 : Result (StructWithPair U32 U32) := - Result.ok { p := { x := 1#u32, y := 2#u32 } } + ok { p := { x := 1#u32, y := 2#u32 } } /- [no_nested_borrows::test_constants]: Source: 'tests/src/no_nested_borrows.rs', lines 395:0-400:1 -/ @@ -413,15 +412,15 @@ def test_constants : Result Unit := massert (swp.p.x = 1#u32) /- Unit test for [no_nested_borrows::test_constants] -/ -#assert (test_constants == Result.ok ()) +#assert (test_constants == ok ()) /- [no_nested_borrows::test_weird_borrows1]: Source: 'tests/src/no_nested_borrows.rs', lines 404:0-412:1 -/ def test_weird_borrows1 : Result Unit := - Result.ok () + ok () /- Unit test for [no_nested_borrows::test_weird_borrows1] -/ -#assert (test_weird_borrows1 == Result.ok ()) +#assert (test_weird_borrows1 == ok ()) /- [no_nested_borrows::test_mem_replace]: Source: 'tests/src/no_nested_borrows.rs', lines 414:0-418:1 -/ @@ -429,31 +428,31 @@ def test_mem_replace (px : U32) : Result U32 := do let (y, _) := core.mem.replace px 1#u32 massert (y = 0#u32) - Result.ok 2#u32 + ok 2#u32 /- [no_nested_borrows::test_shared_borrow_bool1]: Source: 'tests/src/no_nested_borrows.rs', lines 421:0-430:1 -/ def 
test_shared_borrow_bool1 (b : Bool) : Result U32 := if b - then Result.ok 0#u32 - else Result.ok 1#u32 + then ok 0#u32 + else ok 1#u32 /- [no_nested_borrows::test_shared_borrow_bool2]: Source: 'tests/src/no_nested_borrows.rs', lines 434:0-444:1 -/ def test_shared_borrow_bool2 : Result U32 := - Result.ok 0#u32 + ok 0#u32 /- [no_nested_borrows::test_shared_borrow_enum1]: Source: 'tests/src/no_nested_borrows.rs', lines 449:0-457:1 -/ def test_shared_borrow_enum1 (l : List U32) : Result U32 := match l with - | List.Cons _ _ => Result.ok 1#u32 - | List.Nil => Result.ok 0#u32 + | List.Cons _ _ => ok 1#u32 + | List.Nil => ok 0#u32 /- [no_nested_borrows::test_shared_borrow_enum2]: Source: 'tests/src/no_nested_borrows.rs', lines 461:0-470:1 -/ def test_shared_borrow_enum2 : Result U32 := - Result.ok 0#u32 + ok 0#u32 /- [no_nested_borrows::incr]: Source: 'tests/src/no_nested_borrows.rs', lines 472:0-474:1 -/ @@ -470,7 +469,7 @@ def call_incr (x : U32) : Result U32 := def read_then_incr (x : U32) : Result (U32 × U32) := do let x1 ← x + 1#u32 - Result.ok (x, x1) + ok (x, x1) /- [no_nested_borrows::Tuple] Source: 'tests/src/no_nested_borrows.rs', lines 487:0-487:33 -/ @@ -480,30 +479,30 @@ def Tuple (T1 : Type) (T2 : Type) := T1 × T2 Source: 'tests/src/no_nested_borrows.rs', lines 489:0-491:1 -/ def read_tuple (x : (U32 × U32)) : Result U32 := let (i, _) := x - Result.ok i + ok i /- [no_nested_borrows::update_tuple]: Source: 'tests/src/no_nested_borrows.rs', lines 493:0-495:1 -/ def update_tuple (x : (U32 × U32)) : Result (U32 × U32) := let (_, i) := x - Result.ok (1#u32, i) + ok (1#u32, i) /- [no_nested_borrows::read_tuple_struct]: Source: 'tests/src/no_nested_borrows.rs', lines 497:0-499:1 -/ def read_tuple_struct (x : Tuple U32 U32) : Result U32 := let (i, _) := x - Result.ok i + ok i /- [no_nested_borrows::update_tuple_struct]: Source: 'tests/src/no_nested_borrows.rs', lines 501:0-503:1 -/ def update_tuple_struct (x : Tuple U32 U32) : Result (Tuple U32 U32) := let (_, i) := x - Result.ok (1#u32, i) + ok (1#u32, i) /- [no_nested_borrows::create_tuple_struct]: Source: 'tests/src/no_nested_borrows.rs', lines 505:0-507:1 -/ def create_tuple_struct (x : U32) (y : U64) : Result (Tuple U32 U64) := - Result.ok (x, y) + ok (x, y) /- [no_nested_borrows::IdType] Source: 'tests/src/no_nested_borrows.rs', lines 510:0-510:24 -/ @@ -512,27 +511,27 @@ def create_tuple_struct (x : U32) (y : U64) : Result (Tuple U32 U64) := /- [no_nested_borrows::use_id_type]: Source: 'tests/src/no_nested_borrows.rs', lines 512:0-514:1 -/ def use_id_type {T : Type} (x : IdType T) : Result T := - Result.ok x + ok x /- [no_nested_borrows::create_id_type]: Source: 'tests/src/no_nested_borrows.rs', lines 516:0-518:1 -/ def create_id_type {T : Type} (x : T) : Result (IdType T) := - Result.ok x + ok x /- [no_nested_borrows::not_bool]: Source: 'tests/src/no_nested_borrows.rs', lines 520:0-522:1 -/ def not_bool (x : Bool) : Result Bool := - Result.ok (¬ x) + ok (¬ x) /- [no_nested_borrows::not_u32]: Source: 'tests/src/no_nested_borrows.rs', lines 524:0-526:1 -/ def not_u32 (x : U32) : Result U32 := - Result.ok (¬ x) + ok (~~~ x) /- [no_nested_borrows::not_i32]: Source: 'tests/src/no_nested_borrows.rs', lines 528:0-530:1 -/ def not_i32 (x : I32) : Result I32 := - Result.ok (¬ x) + ok (~~~ x) /- [no_nested_borrows::borrow_mut_tuple]: Source: 'tests/src/no_nested_borrows.rs', lines 532:0-534:1 -/ @@ -540,7 +539,7 @@ def borrow_mut_tuple {T : Type} {U : Type} (x : (T × U)) : Result ((T × U) × ((T × U) → (T × U))) := - Result.ok (x, fun ret => 
ret) + ok (x, fun ret => ret) /- [no_nested_borrows::ExpandSimpliy::Wrapper] Source: 'tests/src/no_nested_borrows.rs', lines 538:4-538:32 -/ @@ -552,8 +551,8 @@ def ExpandSimpliy.check_expand_simplify_symb1 (x : ExpandSimpliy.Wrapper Bool) : Result (ExpandSimpliy.Wrapper Bool) := let (b, _) := x if b - then Result.ok x - else Result.ok x + then ok x + else ok x /- [no_nested_borrows::ExpandSimpliy::Wrapper2] Source: 'tests/src/no_nested_borrows.rs', lines 548:4-551:5 -/ @@ -566,7 +565,7 @@ structure ExpandSimpliy.Wrapper2 where def ExpandSimpliy.check_expand_simplify_symb2 (x : ExpandSimpliy.Wrapper2) : Result ExpandSimpliy.Wrapper2 := if x.b - then Result.ok x - else Result.ok x + then ok x + else ok x end no_nested_borrows diff --git a/tests/lean/Paper.lean b/tests/lean/Paper.lean index b4d04a0b..9d9ba2eb 100644 --- a/tests/lean/Paper.lean +++ b/tests/lean/Paper.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [paper] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -21,7 +21,7 @@ def test_incr : Result Unit := massert (x = 1#i32) /- Unit test for [paper::test_incr] -/ -#assert (test_incr == Result.ok ()) +#assert (test_incr == ok ()) /- [paper::choose]: Source: 'tests/src/paper.rs', lines 18:0-24:1 -/ @@ -29,9 +29,9 @@ def choose {T : Type} (b : Bool) (x : T) (y : T) : Result (T × (T → (T × T))) := if b then let back := fun ret => (ret, y) - Result.ok (x, back) + ok (x, back) else let back := fun ret => (x, ret) - Result.ok (y, back) + ok (y, back) /- [paper::test_choose]: Source: 'tests/src/paper.rs', lines 26:0-34:1 -/ @@ -45,7 +45,7 @@ def test_choose : Result Unit := massert (y = 0#i32) /- Unit test for [paper::test_choose] -/ -#assert (test_choose == Result.ok ()) +#assert (test_choose == ok ()) /- [paper::List] Source: 'tests/src/paper.rs', lines 38:0-41:1 -/ @@ -61,15 +61,15 @@ divergent def list_nth_mut | List.Cons x tl => if i = 0#u32 then let back := fun ret => List.Cons ret tl - Result.ok (x, back) + ok (x, back) else do let i1 ← i - 1#u32 let (t, list_nth_mut_back) ← list_nth_mut tl i1 let back := fun ret => let tl1 := list_nth_mut_back ret List.Cons x tl1 - Result.ok (t, back) - | List.Nil => Result.fail .panic + ok (t, back) + | List.Nil => fail panic /- [paper::sum]: Source: 'tests/src/paper.rs', lines 60:0-69:1 -/ @@ -78,7 +78,7 @@ divergent def sum (l : List I32) : Result I32 := | List.Cons x tl => do let i ← sum tl x + i - | List.Nil => Result.ok 0#i32 + | List.Nil => ok 0#i32 /- [paper::test_nth]: Source: 'tests/src/paper.rs', lines 71:0-76:1 -/ @@ -93,7 +93,7 @@ def test_nth : Result Unit := massert (i = 7#i32) /- Unit test for [paper::test_nth] -/ -#assert (test_nth == Result.ok ()) +#assert (test_nth == ok ()) /- [paper::call_choose]: Source: 'tests/src/paper.rs', lines 79:0-85:1 -/ @@ -103,6 +103,6 @@ def call_choose (p : (U32 × U32)) : Result U32 := let (pz, choose_back) ← choose true px py let pz1 ← pz + 1#u32 let (px1, _) := choose_back pz1 - Result.ok px1 + ok px1 end paper diff --git a/tests/lean/PoloniusList.lean b/tests/lean/PoloniusList.lean index e70a8924..03b93f6d 100644 --- a/tests/lean/PoloniusList.lean +++ b/tests/lean/PoloniusList.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [polonius_list] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ 
-21,13 +21,13 @@ divergent def get_list_at_x match ls with | List.Cons hd tl => if hd = x - then Result.ok (ls, fun ret => ret) + then ok (ls, fun ret => ret) else do let (l, get_list_at_x_back) ← get_list_at_x tl x let back := fun ret => let tl1 := get_list_at_x_back ret List.Cons hd tl1 - Result.ok (l, back) - | List.Nil => Result.ok (List.Nil, fun ret => ret) + ok (l, back) + | List.Nil => ok (List.Nil, fun ret => ret) end polonius_list diff --git a/tests/lean/RenameAttribute.lean b/tests/lean/RenameAttribute.lean index ae9432e4..e60369f6 100644 --- a/tests/lean/RenameAttribute.lean +++ b/tests/lean/RenameAttribute.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [rename_attribute] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -17,12 +17,12 @@ structure BoolTest (Self : Type) where /- [rename_attribute::{rename_attribute::BoolTrait for bool}::get_bool]: Source: 'tests/src/rename_attribute.rs', lines 22:4-24:5 -/ def BoolTraitBool.getTest (self : Bool) : Result Bool := - Result.ok self + ok self /- [rename_attribute::{rename_attribute::BoolTrait for bool}::ret_true]: Source: 'tests/src/rename_attribute.rs', lines 15:4-17:5 -/ def BoolTraitBool.retTest (self : Bool) : Result Bool := - Result.ok true + ok true /- Trait implementation: [rename_attribute::{rename_attribute::BoolTrait for bool}] Source: 'tests/src/rename_attribute.rs', lines 21:0-25:1 -/ @@ -39,7 +39,7 @@ def BoolFn (T : Type) (x : Bool) : Result Bool := let b ← BoolTraitBool.getTest x if b then BoolTraitBool.retTest x - else Result.ok false + else ok false /- [rename_attribute::SimpleEnum] Source: 'tests/src/rename_attribute.rs', lines 36:0-41:1 -/ @@ -69,7 +69,7 @@ def Const_Aeneas11 : U32 := eval_global Const_Aeneas11_body Source: 'tests/src/rename_attribute.rs', lines 56:0-62:1 -/ divergent def Factfn (n : U64) : Result U64 := if n <= 1#u64 - then Result.ok 1#u64 + then ok 1#u64 else do let i ← n - 1#u64 let i1 ← Factfn i @@ -88,6 +88,7 @@ divergent def No_borrows_sum_loop /- [rename_attribute::sum]: Source: 'tests/src/rename_attribute.rs', lines 65:0-75:1 -/ +@[reducible] def No_borrows_sum (max : U32) : Result U32 := No_borrows_sum_loop max 0#u32 0#u32 @@ -95,6 +96,6 @@ def No_borrows_sum (max : U32) : Result U32 := Source: 'tests/src/rename_attribute.rs', lines 15:4-17:5 -/ def BoolTrait.retTest.default {Self : Type} (self_clause : BoolTest Self) (self : Self) : Result Bool := - Result.ok true + ok true end rename_attribute diff --git a/tests/lean/Scalars.lean b/tests/lean/Scalars.lean index 57f426f5..e4071cbc 100644 --- a/tests/lean/Scalars.lean +++ b/tests/lean/Scalars.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [scalars] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -11,22 +11,22 @@ namespace scalars /- [scalars::u32_use_wrapping_add]: Source: 'tests/src/scalars.rs', lines 3:0-5:1 -/ def u32_use_wrapping_add (x : U32) (y : U32) : Result U32 := - Result.ok (core.num.U32.wrapping_add x y) + ok (core.num.U32.wrapping_add x y) /- [scalars::i32_use_wrapping_add]: Source: 'tests/src/scalars.rs', lines 7:0-9:1 -/ def i32_use_wrapping_add (x : I32) (y : I32) : Result I32 := - Result.ok (core.num.I32.wrapping_add x y) + ok (core.num.I32.wrapping_add x y) /- [scalars::u32_use_wrapping_sub]: Source: 
'tests/src/scalars.rs', lines 11:0-13:1 -/ def u32_use_wrapping_sub (x : U32) (y : U32) : Result U32 := - Result.ok (core.num.U32.wrapping_sub x y) + ok (core.num.U32.wrapping_sub x y) /- [scalars::i32_use_wrapping_sub]: Source: 'tests/src/scalars.rs', lines 15:0-17:1 -/ def i32_use_wrapping_sub (x : I32) (y : I32) : Result I32 := - Result.ok (core.num.I32.wrapping_sub x y) + ok (core.num.I32.wrapping_sub x y) /- [scalars::u32_use_shift_right]: Source: 'tests/src/scalars.rs', lines 19:0-21:1 -/ @@ -51,26 +51,29 @@ def i32_use_shift_left (x : I32) : Result I32 := /- [scalars::add_and]: Source: 'tests/src/scalars.rs', lines 35:0-37:1 -/ def add_and (a : U32) (b : U32) : Result U32 := - (b &&& a) + (b &&& a) + do + let (i : U32) ← ↑(b &&& a) + let (i1 : U32) ← ↑(b &&& a) + i + i1 /- [scalars::u32_use_rotate_right]: Source: 'tests/src/scalars.rs', lines 39:0-41:1 -/ def u32_use_rotate_right (x : U32) : Result U32 := - Result.ok (core.num.U32.rotate_right x 2#u32) + ok (core.num.U32.rotate_right x 2#u32) /- [scalars::i32_use_rotate_right]: Source: 'tests/src/scalars.rs', lines 43:0-45:1 -/ def i32_use_rotate_right (x : I32) : Result I32 := - Result.ok (core.num.I32.rotate_right x 2#u32) + ok (core.num.I32.rotate_right x 2#u32) /- [scalars::u32_use_rotate_left]: Source: 'tests/src/scalars.rs', lines 47:0-49:1 -/ def u32_use_rotate_left (x : U32) : Result U32 := - Result.ok (core.num.U32.rotate_left x 2#u32) + ok (core.num.U32.rotate_left x 2#u32) /- [scalars::i32_use_rotate_left]: Source: 'tests/src/scalars.rs', lines 51:0-53:1 -/ def i32_use_rotate_left (x : I32) : Result I32 := - Result.ok (core.num.I32.rotate_left x 2#u32) + ok (core.num.I32.rotate_left x 2#u32) end scalars diff --git a/tests/lean/Slices.lean b/tests/lean/Slices.lean index 7e622052..44c26276 100644 --- a/tests/lean/Slices.lean +++ b/tests/lean/Slices.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [slices] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false diff --git a/tests/lean/SwitchTest.lean b/tests/lean/SwitchTest.lean index a215d2c8..65150ed3 100644 --- a/tests/lean/SwitchTest.lean +++ b/tests/lean/SwitchTest.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [switch_test] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -12,8 +12,8 @@ namespace switch_test Source: 'tests/src/switch_test.rs', lines 4:0-10:1 -/ def match_u32 (x : U32) : Result U32 := match x with - | 0#scalar => Result.ok 0#u32 - | 1#scalar => Result.ok 1#u32 - | _ => Result.ok 2#u32 + | 0#uscalar => ok 0#u32 + | 1#uscalar => ok 1#u32 + | _ => ok 2#u32 end switch_test diff --git a/tests/lean/Traits.lean b/tests/lean/Traits.lean index ac612e3e..b9acf44d 100644 --- a/tests/lean/Traits.lean +++ b/tests/lean/Traits.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [traits] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -17,12 +17,12 @@ structure BoolTrait (Self : Type) where /- [traits::{traits::BoolTrait for bool}::get_bool]: Source: 'tests/src/traits.rs', lines 14:4-16:5 -/ def BoolTraitBool.get_bool (self : Bool) : Result Bool := - Result.ok self + ok self /- [traits::{traits::BoolTrait for 
bool}::ret_true]: Source: 'tests/src/traits.rs', lines 8:4-10:5 -/ def BoolTraitBool.ret_true (self : Bool) : Result Bool := - Result.ok true + ok true /- Trait implementation: [traits::{traits::BoolTrait for bool}] Source: 'tests/src/traits.rs', lines 13:0-17:1 -/ @@ -39,19 +39,19 @@ def test_bool_trait_bool (x : Bool) : Result Bool := let b ← BoolTraitBool.get_bool x if b then BoolTraitBool.ret_true x - else Result.ok false + else ok false /- [traits::{traits::BoolTrait for core::option::Option}#1::get_bool]: Source: 'tests/src/traits.rs', lines 25:4-30:5 -/ def BoolTraitOption.get_bool {T : Type} (self : Option T) : Result Bool := match self with - | none => Result.ok false - | some _ => Result.ok true + | none => ok false + | some _ => ok true /- [traits::{traits::BoolTrait for core::option::Option}#1::ret_true]: Source: 'tests/src/traits.rs', lines 8:4-10:5 -/ def BoolTraitOption.ret_true {T : Type} (self : Option T) : Result Bool := - Result.ok true + ok true /- Trait implementation: [traits::{traits::BoolTrait for core::option::Option}#1] Source: 'tests/src/traits.rs', lines 24:0-31:1 -/ @@ -68,7 +68,7 @@ def test_bool_trait_option {T : Type} (x : Option T) : Result Bool := let b ← BoolTraitOption.get_bool x if b then BoolTraitOption.ret_true x - else Result.ok false + else ok false /- [traits::test_bool_trait]: Source: 'tests/src/traits.rs', lines 37:0-39:1 -/ @@ -84,7 +84,7 @@ structure ToU64 (Self : Type) where /- [traits::{traits::ToU64 for u64}#2::to_u64]: Source: 'tests/src/traits.rs', lines 46:4-48:5 -/ def ToU64U64.to_u64 (self : U64) : Result U64 := - Result.ok self + ok self /- Trait implementation: [traits::{traits::ToU64 for u64}#2] Source: 'tests/src/traits.rs', lines 45:0-49:1 -/ @@ -163,7 +163,7 @@ structure ToType (Self : Type) (T : Type) where /- [traits::{traits::ToType for u64}#5::to_type]: Source: 'tests/src/traits.rs', lines 95:4-97:5 -/ def ToTypeU64Bool.to_type (self : U64) : Result Bool := - Result.ok (self > 0#u64) + ok (self > 0#u64) /- Trait implementation: [traits::{traits::ToType for u64}#5] Source: 'tests/src/traits.rs', lines 94:0-98:1 -/ @@ -213,7 +213,7 @@ def h4 Source: 'tests/src/traits.rs', lines 141:12-143:13 -/ def TestType.test.TestTraittraitsTestTypetestTestType1.test (self : TestType.test.TestType1) : Result Bool := - Result.ok (self > 1#u64) + ok (self > 1#u64) /- [traits::{traits::TestType}#6::test]: Source: 'tests/src/traits.rs', lines 128:4-149:5 -/ @@ -223,7 +223,7 @@ def TestType.test let x1 ← ToU64Inst.to_u64 x if x1 > 0#u64 then TestType.test.TestTraittraitsTestTypetestTestType1.test 0#u64 - else Result.ok false + else ok false /- [traits::BoolWrapper] Source: 'tests/src/traits.rs', lines 152:0-152:33 -/ @@ -248,7 +248,7 @@ def ToTypetraitsBoolWrapperT {T : Type} (ToTypeBoolTInst : ToType Bool T) : /- [traits::WithConstTy::LEN2] Source: 'tests/src/traits.rs', lines 166:4-166:27 -/ def WithConstTy.LEN2_default_body (Self : Type) (LEN : Usize) : Result Usize := - Result.ok 32#usize + ok 32#usize def WithConstTy.LEN2_default (Self : Type) (LEN : Usize) : Usize := eval_global (WithConstTy.LEN2_default_body Self LEN) @@ -264,13 +264,13 @@ structure WithConstTy (Self : Type) (LEN : Usize) where /- [traits::{traits::WithConstTy<32: usize> for bool}#8::LEN1] Source: 'tests/src/traits.rs', lines 177:4-177:27 -/ -def WithConstTyBool32.LEN1_body : Result Usize := Result.ok 12#usize +def WithConstTyBool32.LEN1_body : Result Usize := ok 12#usize def WithConstTyBool32.LEN1 : Usize := eval_global WithConstTyBool32.LEN1_body /- 
[traits::{traits::WithConstTy<32: usize> for bool}#8::f]: Source: 'tests/src/traits.rs', lines 182:4-182:42 -/ def WithConstTyBool32.f (i : U64) (a : Array U8 32#usize) : Result U64 := - Result.ok i + ok i /- Trait implementation: [traits::{traits::WithConstTy<32: usize> for bool}#8] Source: 'tests/src/traits.rs', lines 176:0-183:1 -/ @@ -290,7 +290,7 @@ def use_with_const_ty1 {H : Type} {LEN : Usize} (WithConstTyInst : WithConstTy H LEN) : Result Usize := - Result.ok WithConstTyInst.LEN1 + ok WithConstTyInst.LEN1 /- [traits::use_with_const_ty2]: Source: 'tests/src/traits.rs', lines 189:0-189:76 -/ @@ -299,7 +299,7 @@ def use_with_const_ty2 (w : WithConstTyInst.W) : Result Unit := - Result.ok () + ok () /- [traits::use_with_const_ty3]: Source: 'tests/src/traits.rs', lines 191:0-193:1 -/ @@ -313,7 +313,7 @@ def use_with_const_ty3 /- [traits::test_where1]: Source: 'tests/src/traits.rs', lines 195:0-195:43 -/ def test_where1 {T : Type} (_x : T) : Result Unit := - Result.ok () + ok () /- [traits::test_where2]: Source: 'tests/src/traits.rs', lines 196:0-196:60 -/ @@ -321,7 +321,7 @@ def test_where2 {T : Type} (WithConstTyT32Inst : WithConstTy T 32#usize) (_x : U32) : Result Unit := - Result.ok () + ok () /- Trait declaration: [traits::ParentTrait0] Source: 'tests/src/traits.rs', lines 202:0-206:1 -/ @@ -361,7 +361,7 @@ def order1 : ParentTrait0 U) : Result Unit := - Result.ok () + ok () /- Trait declaration: [traits::ChildTrait1] Source: 'tests/src/traits.rs', lines 224:0-224:38 -/ @@ -440,7 +440,7 @@ def ParentTrait2U32 : ParentTrait2 U32 := { /- [traits::{traits::ChildTrait2 for u32}#13::convert]: Source: 'tests/src/traits.rs', lines 275:4-277:5 -/ def ChildTrait2U32.convert (x : U32) : Result U32 := - Result.ok x + ok x /- Trait implementation: [traits::{traits::ChildTrait2 for u32}#13] Source: 'tests/src/traits.rs', lines 274:0-278:1 -/ @@ -487,7 +487,7 @@ structure Trait (Self : Type) where /- [traits::{traits::Trait for @Array}#14::LEN] Source: 'tests/src/traits.rs', lines 317:4-317:25 -/ -def TraitArray.LEN_body (T : Type) (N : Usize) : Result Usize := Result.ok N +def TraitArray.LEN_body (T : Type) (N : Usize) : Result Usize := ok N def TraitArray.LEN (T : Type) (N : Usize) : Usize := eval_global (TraitArray.LEN_body T N) @@ -502,7 +502,7 @@ def TraitArray (T : Type) (N : Usize) : Trait (Array T N) := { Source: 'tests/src/traits.rs', lines 321:4-321:25 -/ def TraittraitsWrapper.LEN_body {T : Type} (TraitInst : Trait T) : Result Usize := - Result.ok 0#usize + ok 0#usize def TraittraitsWrapper.LEN {T : Type} (TraitInst : Trait T) : Usize := eval_global (TraittraitsWrapper.LEN_body TraitInst) @@ -517,7 +517,7 @@ def TraittraitsWrapper {T : Type} (TraitInst : Trait T) : Trait (Wrapper T) /- [traits::use_wrapper_len]: Source: 'tests/src/traits.rs', lines 324:0-326:1 -/ def use_wrapper_len {T : Type} (TraitInst : Trait T) : Result Usize := - Result.ok (TraittraitsWrapper TraitInst).LEN + ok (TraittraitsWrapper TraitInst).LEN /- [traits::Foo] Source: 'tests/src/traits.rs', lines 328:0-331:1 -/ @@ -529,7 +529,7 @@ structure Foo (T : Type) (U : Type) where Source: 'tests/src/traits.rs', lines 334:4-334:43 -/ def Foo.FOO_body {T : Type} (U : Type) (TraitInst : Trait T) : Result (core.result.Result T I32) := - Result.ok (core.result.Result.Err 0#i32) + ok (core.result.Result.Err 0#i32) def Foo.FOO {T : Type} (U : Type) (TraitInst : Trait T) : core.result.Result T I32 := eval_global (Foo.FOO_body U TraitInst) @@ -540,7 +540,7 @@ def use_foo1 {T : Type} (U : Type) (TraitInst : Trait T) : Result 
(core.result.Result T I32) := - Result.ok (Foo.FOO U TraitInst) + ok (Foo.FOO U TraitInst) /- [traits::use_foo2]: Source: 'tests/src/traits.rs', lines 341:0-343:1 -/ @@ -548,13 +548,13 @@ def use_foo2 (T : Type) {U : Type} (TraitInst : Trait U) : Result (core.result.Result U I32) := - Result.ok (Foo.FOO T TraitInst) + ok (Foo.FOO T TraitInst) /- [traits::BoolTrait::ret_true]: Source: 'tests/src/traits.rs', lines 8:4-10:5 -/ def BoolTrait.ret_true.default {Self : Type} (self_clause : BoolTrait Self) (self : Self) : Result Bool := - Result.ok true + ok true /- Trait declaration: [traits::{traits::TestType}#6::test::TestTrait] Source: 'tests/src/traits.rs', lines 130:8-132:9 -/ diff --git a/tests/lean/Tutorial/Exercises.lean b/tests/lean/Tutorial/Exercises.lean index 8bf88a00..50e243a8 100644 --- a/tests/lean/Tutorial/Exercises.lean +++ b/tests/lean/Tutorial/Exercises.lean @@ -19,7 +19,7 @@ def mul2_add1 (x : U32) : Result U32 := #check U32.add_spec /-- Theorem about `mul2_add1`: without the `progress` tactic -/ --- @[pspec] +-- @[progress] theorem mul2_add1_spec (x : U32) (h : 2 * x.val + 1 ≤ U32.max) : ∃ y, mul2_add1 x = ok y ∧ ↑y = 2 * ↑x + (1 : Int) @@ -32,7 +32,7 @@ theorem mul2_add1_spec (x : U32) (h : 2 * x.val + 1 ≤ U32.max) scalar_tac /-- Theorem about `mul2_add1`: with the `progress` tactic -/ --- @[pspec] +-- @[progress] theorem mul2_add1_spec' (x : U32) (h : 2 * x.val + 1 ≤ U32.max) : ∃ y, mul2_add1 x = ok y ∧ ↑y = 2 * ↑x + (1 : Int) @@ -60,7 +60,7 @@ theorem mul2_add1_add_spec (x : U32) (y : U32) (h : 2 * x.val + 1 + y.val ≤ U3 /- [tutorial::CList] Source: 'src/lib.rs', lines 32:0-32:17 -/ -inductive CList (T : Type) := +inductive CList (T : Type) where | CCons : T → CList T → CList T | CNil : CList T @@ -92,7 +92,7 @@ divergent def list_nth {T : Type} (l : CList T) (i : U32) : Result T := theorem list_nth_spec {T : Type} [Inhabited T] (l : CList T) (i : U32) (h : i.val < l.toList.length) : ∃ x, list_nth l i = ok x ∧ - x = l.toList.index i.toNat + x = l.toList[i.val]! := by rw [list_nth] split @@ -103,7 +103,6 @@ theorem list_nth_spec {T : Type} [Inhabited T] (l : CList T) (i : U32) progress as ⟨ x ⟩ simp_all . simp_all - scalar_tac /- [tutorial::i32_id]: Source: 'src/lib.rs', lines 78:0-78:29 -/ @@ -165,8 +164,8 @@ theorem even_spec (n : U32) : . progress as ⟨ n' ⟩ progress as ⟨ b ⟩ simp [*] - simp [Int.odd_sub] -termination_by n.toNat + simp [Nat.even_add_one] +termination_by n.val decreasing_by scalar_decr_tac /-- The proof about `odd` -/ @@ -178,8 +177,8 @@ theorem odd_spec (n : U32) : . progress as ⟨ n' ⟩ progress as ⟨ b ⟩ simp [*] - simp [Int.even_sub] -termination_by n.toNat + simp [Nat.odd_add_one] +termination_by n.val decreasing_by scalar_decr_tac end @@ -376,32 +375,38 @@ set_option pp.coercions true Small preparation for theorem `list_nth_mut1`. -/ +/- The notation `l[i]!` stands for `getElem! l`, and is the `i`th element of list `l`. + + We deactivate the simp lemma below as it replaces terms of the shape `l[i]!` with more + complicated terms: in the present case it is more annoying than anything. -/ +attribute [-simp] List.getElem!_eq_getElem?_getD + /- Reasoning about `List.index`. You can use the following two lemmas. 
-/ -#check List.index_zero_cons -#check List.index_nzero_cons +#check List.getElem!_cons_zero +#check List.getElem!_cons_nzero /- Example 1: indexing the first element of the list -/ example [Inhabited α] (i : U32) (hd : α) (tl : CList α) (hEq : i = 0#u32) : - (hd :: tl.toList).index i.toNat = hd := by - have hi : i.toNat = 0 := by scalar_tac + (hd :: tl.toList)[i.val]! = hd := by + have hi : i.val = 0 := by scalar_tac simp only [hi] -- - have hIndex := List.index_zero_cons hd tl.toList + have hIndex := @List.getElem!_cons_zero _ hd _ tl.toList simp only [hIndex] /- Example 2: indexing in the tail -/ example [Inhabited α] (i : U32) (hd : α) (tl : CList α) (hEq : i ≠ 0#u32) : - (hd :: tl.toList).index i.toNat = tl.toList.index (i.toNat - 1) := by + (hd :: tl.toList)[i.val]! = tl.toList[i.val - 1]! := by -- Note that `scalar_tac` is aware of `Arith.Nat.not_eq` - have hIndex := List.index_nzero_cons hd tl.toList i.toNat (by scalar_tac) + have hIndex := List.getElem!_cons_nzero hd tl.toList i.val (by scalar_tac) simp only [hIndex] -/- Note that `List.index_zero_cons` and `List.index_nzero_cons` have been +/- Note that `List.getElem!_cons_zero` and `List.getElem!_cons_nzero` have been marked as `@[simp]` and are thus automatically applied. Also note that `simp` can automatically prove the premises of rewriting lemmas, if it has enough information. @@ -410,14 +415,14 @@ example [Inhabited α] (i : U32) (hd : α) (tl : CList α) you expect. -/ example [Inhabited α] (i : U32) (hd : α) (tl : CList α) (hEq : i = 0#u32) : - (hd :: tl.toList).index i.toNat = hd := by + (hd :: tl.toList)[i.val]! = hd := by simp [hEq] -/- Note that `simp_all` manages to automatically apply `List.index_nzero_cons` below, +/- Note that `simp_all` manages to automatically apply `List.getElem!_cons_nzero` below, by using the fact that `i ≠ 0#u32`. -/ example [Inhabited α] (i : U32) (hd : α) (tl : CList α) (hEq : i ≠ 0#u32) : - (hd :: tl.toList).index i.toNat = tl.toList.index (i.toNat - 1) := by + (hd :: tl.toList)[i.val]! = tl.toList[i.val - 1]! := by simp_all /- Below, you will need to reason about `List.update`. @@ -426,8 +431,8 @@ example [Inhabited α] (i : U32) (hd : α) (tl : CList α) Those lemmas have been marked as `@[simp]`, meaning that if `simp` is properly used, it will manage to apply them automatically. -/ -#check List.update_zero_cons -#check List.update_nzero_cons +#check List.set_cons_zero +#check List.set_cons_nzero /- # Some proofs of programs -/ @@ -452,9 +457,9 @@ example [Inhabited α] (i : U32) (hd : α) (tl : CList α) theorem list_nth_mut1_spec {T: Type} [Inhabited T] (l : CList T) (i : U32) (h : i.val < l.toList.length) : ∃ x back, list_nth_mut1 l i = ok (x, back) ∧ - x = l.toList.index i.toNat ∧ + x = l.toList[i.val]! 
∧ -- Specification of the backward function - ∀ x', (back x').toList = l.toList.update i.toNat x' := by + ∀ x', (back x').toList = l.toList.set i.val x' := by rw [list_nth_mut1, list_nth_mut1_loop] sorry @@ -492,7 +497,7 @@ def append_in_place /-- Theorem about `list_tail`: exercise -/ -@[pspec] +@[progress] theorem list_tail_spec {T : Type} (l : CList T) : ∃ back, list_tail l = ok (CList.CNil, back) ∧ ∀ tl', (back tl').toList = l.toList ++ tl'.toList := by @@ -500,7 +505,7 @@ theorem list_tail_spec {T : Type} (l : CList T) : sorry /-- Theorem about `append_in_place`: exercise -/ -@[pspec] +@[progress] theorem append_in_place_spec {T : Type} (l0 l1 : CList T) : ∃ l2, append_in_place l0 l1 = ok l2 ∧ l2.toList = l0.toList ++ l1.toList := by @@ -515,7 +520,7 @@ divergent def reverse_loop | CList.CCons hd tl => reverse_loop tl (CList.CCons hd out) | CList.CNil => Result.ok out -@[pspec] +@[progress] theorem reverse_loop_spec {T : Type} (l : CList T) (out : CList T) : ∃ l', reverse_loop l out = ok l' ∧ True -- Leaving the post-condition as an exercise @@ -567,14 +572,11 @@ divergent def zero_loop /-- Auxiliary definitions to interpret a vector of u32 as a mathematical integer -/ @[simp] -def toInt_aux (l : List U32) : ℤ := +def toInt (l : List U32) : ℤ := match l with | [] => 0 | x :: l => - x + 2 ^ 32 * toInt_aux l - -@[reducible] -def toInt (x : alloc.vec.Vec U32) : ℤ := toInt_aux x.val + x + 2 ^ 32 * toInt l /-- The theorem about `zero_loop`: exercise. @@ -587,14 +589,14 @@ def toInt (x : alloc.vec.Vec U32) : ℤ := toInt_aux x.val Ex.: `dcases x = y` will introduce two goals, one with the assumption `x = y` and the other with the assumption `x ≠ y`. You can name this assumption by writing: `dcases h : x = y` -/ -@[pspec] +@[progress] theorem zero_loop_spec (x : alloc.vec.Vec U32) (i : Usize) (h : i.val ≤ x.length) : ∃ x', zero_loop x i = ok x' ∧ x'.length = x.length ∧ - (∀ j, j < i.toNat → x'.val.index j = x.val.index j) ∧ - (∀ j, i.toNat ≤ j → j < x.length → x'.val.index j = 0#u32) := by + (∀ j, j < i.val → x'[j]! = x[j]!) ∧ + (∀ j, i.val ≤ j → j < x.length → x'[j]! = 0#u32) := by rw [zero_loop] simp sorry @@ -609,8 +611,8 @@ def zero (x : alloc.vec.Vec U32) : Result (alloc.vec.Vec U32) := Advice: do the proof of `zero_spec` first, then come back to prove this lemma. -/ theorem all_nil_impl_toInt_eq_zero - (l : List U32) (h : ∀ (j : ℕ), j < l.length → l.index j = 0#u32) : - toInt_aux l = 0 := by + (l : List U32) (h : ∀ (j : ℕ), j < l.length → l[j]! = 0#u32) : + toInt l = 0 := by /- There are two ways of proving this theorem. Either you use the induction tactic applied to `l` (*advised*): @@ -674,8 +676,8 @@ divergent def add_no_overflow_loop You can try both tactics and see their effect. -/ @[simp] -theorem toInt_aux_drop (l : List U32) (i : Nat) (h0 : i < l.length) : - toInt_aux (l.drop i) = l.index i + 2 ^ 32 * toInt_aux (l.drop (i + 1)) := by +theorem toInt_drop (l : List U32) (i : Nat) (h0 : i < l.length) : + toInt (l.drop i) = l[i]! + 2 ^ 32 * toInt (l.drop (i + 1)) := by sorry /-- You will need this lemma for the proof of `add_no_overflow_loop_spec`. @@ -692,8 +694,8 @@ theorem toInt_aux_drop (l : List U32) (i : Nat) (h0 : i < l.length) : - or go see the solution -/ @[simp] -theorem toInt_aux_update (l : List U32) (i : Nat) (x : U32) (h0 : i < l.length) : - toInt_aux (l.update i x) = toInt_aux l + 2 ^ (32 * i) * (x - l.index i) := by +theorem toInt_update (l : List U32) (i : Nat) (x : U32) (h0 : i < l.length) : + toInt (l.set i x) = toInt l + 2 ^ (32 * i) * (x - l[i]!) 
:= by sorry /-- The proof about `add_no_overflow_loop`. @@ -701,16 +703,16 @@ theorem toInt_aux_update (l : List U32) (i : Nat) (x : U32) (h0 : i < l.length) Hint: you will need to reason about non-linear arithmetic with `scalar_nf` and `scalar_eq_nf` (see above). -/ -@[pspec] +@[progress] theorem add_no_overflow_loop_spec (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) (i : Usize) (hLength : x.length = y.length) -- No overflow occurs when we add the individual thunks - (hNoOverflow : ∀ (j : Nat), i.toNat ≤ j → j < x.length → (x.val.index j).val + (y.val.index j).val ≤ U32.max) + (hNoOverflow : ∀ (j : Nat), i.val ≤ j → j < x.length → x[j]!.val + y[j]!.val ≤ U32.max) (hi : i.val ≤ x.length) : ∃ x', add_no_overflow_loop x y i = ok x' ∧ x'.length = x.length ∧ - toInt x' = toInt x + 2 ^ (32 * i.toNat) * toInt_aux (y.val.drop i.toNat) := by + toInt x' = toInt x + 2 ^ (32 * i.val) * toInt (y.val.drop i.val) := by rw [add_no_overflow_loop] simp sorry @@ -726,7 +728,7 @@ def add_no_overflow /-- The proof about `add_no_overflow` -/ theorem add_no_overflow_spec (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) (hLength : x.length = y.length) - (hNoOverflow : ∀ (j : Nat), j < x.length → (x.val.index j).val + (y.val.index j).val ≤ U32.max) : + (hNoOverflow : ∀ (j : Nat), j < x.length → x[j]!.val + y[j]!.val ≤ U32.max) : ∃ x', add_no_overflow x y = ok x' ∧ x'.length = y.length ∧ toInt x' = toInt x + toInt y := by @@ -746,15 +748,15 @@ divergent def add_with_carry_loop let i2 ← alloc.vec.Vec.index (core.slice.index.SliceIndexUsizeSliceTInst U32) x i - let i3 ← Scalar.cast .U32 c0 - let p ← core.num.U32.overflowing_add i2 i3 + let i3 := UScalar.cast .U32 c0 + let p := core.num.U32.overflowing_add i2 i3 let (sum, c1) := p let i4 ← alloc.vec.Vec.index (core.slice.index.SliceIndexUsizeSliceTInst U32) y i - let p1 ← core.num.U32.overflowing_add sum i4 + let p1 := core.num.U32.overflowing_add sum i4 let (sum1, c2) := p1 - let i5 ← Scalar.cast_bool .U8 c1 - let i6 ← Scalar.cast_bool .U8 c2 + let i5 := UScalar.cast_fromBool .U8 c1 + let i6 := UScalar.cast_fromBool .U8 c2 let c01 ← i5 + i6 let (_, index_mut_back) ← alloc.vec.Vec.index_mut @@ -765,7 +767,7 @@ divergent def add_with_carry_loop else Result.ok (c0, x) /-- The proof about `add_with_carry_loop` -/ -@[pspec] +@[progress] theorem add_with_carry_loop_spec (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) (c0 : U8) (i : Usize) (hLength : x.length = y.length) @@ -775,7 +777,7 @@ theorem add_with_carry_loop_spec x'.length = x.length ∧ c1.val ≤ 1 ∧ toInt x' + c1.val * 2 ^ (32 * x'.length) = - toInt x + 2 ^ (32 * i.toNat) * toInt_aux (y.val.drop i.toNat) + c0.val * 2 ^ (32 * i.toNat) := by + toInt x + 2 ^ (32 * i.val) * toInt (y.val.drop i.val) + c0.val * 2 ^ (32 * i.val) := by rw [add_with_carry_loop] simp sorry @@ -789,7 +791,7 @@ def add_with_carry add_with_carry_loop x y 0#u8 0#usize /-- The proof about `add_with_carry` -/ -@[pspec] +@[progress] theorem add_with_carry_spec (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) (hLength : x.length = y.length) : @@ -833,13 +835,13 @@ divergent def add_loop let yi ← get_or_zero y i let i1 ← alloc.vec.Vec.index (core.slice.index.SliceIndexUsizeSliceTInst U32) x i - let i2 ← Scalar.cast .U32 c0 - let p ← core.num.U32.overflowing_add i1 i2 + let i2 := UScalar.cast .U32 c0 + let p := core.num.U32.overflowing_add i1 i2 let (sum, c1) := p - let p1 ← core.num.U32.overflowing_add sum yi + let p1 := core.num.U32.overflowing_add sum yi let (sum1, c2) := p1 - let i3 ← Scalar.cast_bool .U8 c1 - let i4 ← Scalar.cast_bool .U8 c2 + 
let i3 := UScalar.cast_fromBool .U8 c1 + let i4 := UScalar.cast_fromBool .U8 c2 let c01 ← i3 + i4 let (_, index_mut_back) ← alloc.vec.Vec.index_mut @@ -850,7 +852,7 @@ divergent def add_loop else if c0 != 0#u8 then do - let i1 ← Scalar.cast .U32 c0 + let i1 := UScalar.cast .U32 c0 alloc.vec.Vec.push x i1 else Result.ok x diff --git a/tests/lean/Tutorial/Solutions.lean b/tests/lean/Tutorial/Solutions.lean index 0cb1451c..7531c007 100644 --- a/tests/lean/Tutorial/Solutions.lean +++ b/tests/lean/Tutorial/Solutions.lean @@ -6,6 +6,9 @@ set_option maxHeartbeats 1000000 namespace tutorial +/- This simp lemma replaces terms of the shape `l[i]!`: in the present case it is more annoying than anything -/ +attribute [-simp] List.getElem!_eq_getElem?_getD + /- # Basic tactics -/ /- Exercise 1: Version 1: -/ @@ -48,8 +51,8 @@ open CList theorem list_nth_mut1_spec {T: Type} [Inhabited T] (l : CList T) (i : U32) (h : i.val < l.toList.length) : ∃ x back, list_nth_mut1 l i = ok (x, back) ∧ - x = l.toList.index i.toNat ∧ - ∀ x', (back x').toList = l.toList.update i.toNat x' := by + x = l.toList[i.val]! ∧ + ∀ x', (back x').toList = l.toList.set i.val x' := by rw [list_nth_mut1, list_nth_mut1_loop] split . rename_i hd tl @@ -58,35 +61,34 @@ theorem list_nth_mut1_spec {T: Type} [Inhabited T] (l : CList T) (i : U32) simp split_conjs . -- Reasoning about `List.index`: - have hi : i.toNat = 0 := by scalar_tac + have hi : i.val = 0 := by scalar_tac simp only [hi] -- Without the `only`, this actually finished the goal - have hIndex := List.index_zero_cons hd tl.toList + have hIndex := @List.getElem!_cons_zero _ hd _ tl.toList simp only [hIndex] . intro x -- Reasoning about `List.update`: - have hi : i.toNat = 0 := by scalar_tac + have hi : i.val = 0 := by scalar_tac simp only [hi] -- Without the `only`, this actually finished the goal - have hUpdate := List.update_zero_cons hd tl.toList x + have hUpdate := List.set_cons_zero hd tl.toList x simp only [hUpdate] . simp at * progress as ⟨ i1, hi1 ⟩ progress as ⟨ tl1, back, htl1, hback ⟩ simp split_conjs - . have hIndex := List.index_nzero_cons hd tl.toList i.toNat (by scalar_tac) + . have hIndex := List.getElem!_cons_nzero hd tl.toList i.val (by scalar_tac) simp only [hIndex] simp only [htl1] - have hiEq : i1.toNat = i.toNat - 1 := by scalar_tac + have hiEq : i1.val = i.val - 1 := by scalar_tac simp only [hiEq] . -- Backward function intro x' simp [hback] - have hUpdate := List.update_nzero_cons hd tl.toList i.toNat x' (by scalar_tac) + have hUpdate := List.set_cons_nzero hd tl.toList i.val (by scalar_tac) x' simp only [hUpdate] - have hiEq : i1.toNat = i.toNat - 1 := by scalar_tac + have hiEq : i1.val = i.val - 1 := by scalar_tac simp only [hiEq] . simp_all - scalar_tac /-- Theorem about `list_nth_mut1`: simple version. @@ -98,8 +100,8 @@ theorem list_nth_mut1_spec {T: Type} [Inhabited T] (l : CList T) (i : U32) theorem list_nth_mut1_spec' {T: Type} [Inhabited T] (l : CList T) (i : U32) (h : i.val < l.toList.length) : ∃ x back, list_nth_mut1 l i = ok (x, back) ∧ - x = l.toList.index i.toNat ∧ - ∀ x', (back x').toList = l.toList.update i.toNat x' := by + x = l.toList[i.val]! ∧ + ∀ x', (back x').toList = l.toList.set i.val x' := by rw [list_nth_mut1, list_nth_mut1_loop] split . split @@ -118,10 +120,9 @@ theorem list_nth_mut1_spec' {T: Type} [Inhabited T] (l : CList T) (i : U32) intro x' simp [*] . 
simp_all - scalar_tac /-- Theorem about `list_tail`: verbose version -/ -@[pspec] +@[progress] theorem list_tail_spec {T : Type} (l : CList T) : ∃ back, list_tail l = ok (CList.CNil, back) ∧ ∀ tl', (back tl').toList = l.toList ++ tl'.toList := by @@ -140,7 +141,7 @@ theorem list_tail_spec {T : Type} (l : CList T) : simp /-- Theorem about `list_tail: simple version -/ -@[pspec] +@[progress] theorem list_tail_spec' {T : Type} (l : CList T) : ∃ back, list_tail l = ok (CList.CNil, back) ∧ ∀ tl', (back tl').toList = l.toList ++ tl'.toList := by @@ -154,7 +155,7 @@ theorem list_tail_spec' {T : Type} (l : CList T) : . simp /-- Theorem about `append_in_place` -/ -@[pspec] +@[progress] theorem append_in_place_spec {T : Type} (l0 l1 : CList T) : ∃ l2, append_in_place l0 l1 = ok l2 ∧ l2.toList = l0.toList ++ l1.toList := by @@ -162,7 +163,7 @@ theorem append_in_place_spec {T : Type} (l0 l1 : CList T) : progress as ⟨ tl, back ⟩ progress as ⟨ l2 ⟩ -@[pspec] +@[progress] theorem reverse_loop_spec {T : Type} (l : CList T) (out : CList T) : ∃ l', reverse_loop l out = ok l' ∧ l'.toList = l.toList.reverse ++ out.toList := by @@ -189,24 +190,21 @@ attribute [-simp] Int.reducePow Nat.reducePow -- Auxiliary definitions to interpret a vector of u32 as a mathematical integer @[simp] -def toInt_aux (l : List U32) : ℤ := +def toInt (l : List U32) : Int := match l with | [] => 0 | x :: l => - x + 2 ^ 32 * toInt_aux l - -@[reducible] -def toInt (x : alloc.vec.Vec U32) : ℤ := toInt_aux x.val + x + 2 ^ 32 * toInt l /-- The theorem about `zero_loop` -/ -@[pspec] +@[progress] theorem zero_loop_spec (x : alloc.vec.Vec U32) (i : Usize) (h : i.val ≤ x.length) : ∃ x', zero_loop x i = ok x' ∧ x'.length = x.length ∧ - (∀ j, j < i.toNat → x'.val.index j = x.val.index j) ∧ - (∀ j, i.toNat ≤ j → j < x.length → x'.val.index j = 0#u32) := by + (∀ j, j < i.val → x'[j]! = x[j]!) ∧ + (∀ j, i.val ≤ j → j < x.length → x'[j]! = 0#u32) := by rw [zero_loop] simp split @@ -219,16 +217,16 @@ theorem zero_loop_spec replace hSame := hSame j (by scalar_tac) simp_all . intro j h0 h1 - dcases j = i.toNat <;> simp_all + dcases j = i.val <;> try simp [*] have := hZero j (by scalar_tac) simp_all . simp; scalar_tac -termination_by (x.length - i.val).toNat +termination_by x.length - i.val decreasing_by scalar_decr_tac theorem all_nil_impl_toInt_eq_zero - (l : List U32) (h : ∀ (j : ℕ), j < l.length → l.index j = 0#u32) : - toInt_aux l = 0 := by + (l : List U32) (h : ∀ (j : ℕ), j < l.length → l[j]! = 0#u32) : + toInt l = 0 := by match l with | [] => simp | hd :: tl => @@ -258,29 +256,29 @@ theorem zero_spec (x : alloc.vec.Vec U32) : Advice: do the proof of `add_no_overflow_loop_spec` first, then come back to prove this lemma. -/ @[simp] -theorem toInt_aux_drop (l : List U32) (i : Nat) (h0 : i < l.length) : - toInt_aux (l.drop i) = l.index i + 2 ^ 32 * toInt_aux (l.drop (i + 1)) := by +theorem toInt_drop (l : List U32) (i : Nat) (h0 : i < l.length) : + toInt (l.drop i) = l[i]! 
+ 2 ^ 32 * toInt (l.drop (i + 1)) := by cases l with | nil => simp at * | cons hd tl => simp_all dcases i = 0 <;> simp_all - have := toInt_aux_drop tl (i - 1) (by scalar_tac) + have := toInt_drop tl (i - 1) (by scalar_tac) simp_all scalar_nf at * have : 1 + (i - 1) = i := by scalar_tac simp [*] @[simp] -theorem toInt_aux_update (l : List U32) (i : Nat) (x : U32) (h0 : i < l.length) : - toInt_aux (l.update i x) = toInt_aux l + 2 ^ (32 * i) * (x - l.index i) := by +theorem toInt_update (l : List U32) (i : Nat) (x : U32) (h0 : i < l.length) : + toInt (l.set i x) = toInt l + 2 ^ (32 * i) * (x - l[i]!) := by cases l with | nil => simp at * | cons hd tl => simp_all dcases i = 0 <;> simp_all . scalar_eq_nf - . have := toInt_aux_update tl (i - 1) x (by scalar_tac) + . have := toInt_update tl (i - 1) x (by scalar_tac) simp_all scalar_nf at * scalar_eq_nf @@ -308,16 +306,16 @@ theorem toInt_aux_update (l : List U32) (i : Nat) (x : U32) (h0 : i < l.length) scalar_eq_nf /-- The proof about `add_no_overflow_loop` -/ -@[pspec] +@[progress] theorem add_no_overflow_loop_spec (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) (i : Usize) (hLength : x.length = y.length) -- No overflow occurs when we add the individual thunks - (hNoOverflow : ∀ (j : Nat), i.toNat ≤ j → j < x.length → (x.val.index j).val + (y.val.index j).val ≤ U32.max) + (hNoOverflow : ∀ (j : Nat), i.val ≤ j → j < x.length → x[j]!.val + y[j]!.val ≤ U32.max) (hi : i.val ≤ x.length) : ∃ x', add_no_overflow_loop x y i = ok x' ∧ x'.length = x.length ∧ - toInt x' = toInt x + 2 ^ (32 * i.toNat) * toInt_aux (y.val.drop i.toNat) := by + toInt x' = toInt x + 2 ^ (32 * i.val) * toInt (y.val.drop i.val) := by rw [add_no_overflow_loop] simp split @@ -325,7 +323,7 @@ theorem add_no_overflow_loop_spec progress as ⟨ xv ⟩ progress as ⟨ sum ⟩ . -- This precondition is not proven automatically - have := hNoOverflow i.toNat (by scalar_tac) (by scalar_tac) + have := hNoOverflow i.val (by scalar_tac) (by scalar_tac) scalar_tac progress as ⟨ i' ⟩ progress as ⟨ x1 ⟩ @@ -333,43 +331,37 @@ theorem add_no_overflow_loop_spec intro j h0 h1 simp_all -- Simplifying (x.update ...).index: - have := List.index_update_neq x.val i.toNat j sum (by scalar_tac) + have := List.getElem!_set_neq x.val i.val j sum (by scalar_tac) simp [*] apply hNoOverflow j (by scalar_tac) (by scalar_tac) -- Postcondition - /- Note that you don't have to manually call the lemmas `toInt_aux_update` - and `toInt_aux_drop` below if you first do: + /- Note that you don't have to manually call the lemmas `toInt_update` + and `toInt_drop` below if you first do: ``` - have : i.toNat < x.length := by scalar_tac + have : i.val < x.length := by scalar_tac ``` (simp_all will automatically apply the lemmas and prove the precondition by using the context) -/ - simp_all [toInt] + simp_all scalar_eq_nf - -- Simplifying: toInt_aux ((↑x).update (↑i).toNat sum) - have := toInt_aux_update x.val i.toNat sum (by scalar_tac) - simp [*]; scalar_eq_nf - -- Simplifying: toInt_aux (List.drop (1 + (↑i).toNat) ↑y - have := toInt_aux_drop y.val i.toNat (by scalar_tac) - simp [*]; scalar_eq_nf . 
simp_all -termination_by (x.length - i.val).toNat +termination_by x.length - i.val decreasing_by scalar_decr_tac /-- The proof about `add_no_overflow` -/ theorem add_no_overflow_spec (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) (hLength : x.length = y.length) - (hNoOverflow : ∀ (j : Nat), j < x.length → (x.val.index j).val + (y.val.index j).val ≤ U32.max) : + (hNoOverflow : ∀ (j : Nat), j < x.length → x[j]!.val + y[j]!.val ≤ U32.max) : ∃ x', add_no_overflow x y = ok x' ∧ x'.length = y.length ∧ toInt x' = toInt x + toInt y := by rw [add_no_overflow] progress as ⟨ x' ⟩ <;> - simp_all [toInt] + simp_all /-- The proof about `add_with_carry_loop` -/ -@[pspec] +@[progress] theorem add_with_carry_loop_spec (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) (c0 : U8) (i : Usize) (hLength : x.length = y.length) @@ -379,38 +371,67 @@ theorem add_with_carry_loop_spec x'.length = x.length ∧ c1.val ≤ 1 ∧ toInt x' + c1.val * 2 ^ (32 * x'.length) = - toInt x + 2 ^ (32 * i.toNat) * toInt_aux (y.val.drop i.toNat) + c0.val * 2 ^ (32 * i.toNat) := by + toInt x + 2 ^ (32 * i.val) * toInt (y.val.drop i.val) + c0.val * 2 ^ (32 * i.val) := by rw [add_with_carry_loop] simp split . progress as ⟨ xi ⟩ progress as ⟨ c0u ⟩ - . progress as ⟨ s1, c1, hConv1 ⟩ - progress as ⟨ yi ⟩ - progress as ⟨ s2, c2, hConv2 ⟩ - progress as ⟨ c1u ⟩ - progress as ⟨ c2u ⟩ - progress as ⟨ c3 ⟩ - progress as ⟨ _ ⟩ - progress as ⟨ i1 ⟩ - progress as ⟨ c4, x1 ⟩ - -- Proving the post-condition - simp_all [toInt] - have hxUpdate := toInt_aux_update x.val i.toNat s2 (by scalar_tac) + have : c0u.val = c0.val := by scalar_tac + progress as ⟨ s1, c1, hConv1 ⟩ + progress as ⟨ yi ⟩ + progress as ⟨ s2, c2, hConv2 ⟩ + progress as ⟨ c1u, hc1u ⟩ + progress as ⟨ c2u, hc2u ⟩ + progress as ⟨ c3, hc3 ⟩ + progress as ⟨ _ ⟩ + progress as ⟨ i1 ⟩ + have : c3.val ≤ 1 := by + /- We need to make a case disjunction on hConv1 and hConv2. + This can be done with `split at hConv1 <;> ...`, but + `scalar_tac` can actually do it for us with the `+split` + option, which allows it to make a case disjunction over + the `if then else` appearing in the context. + -/ + scalar_tac +split + progress as ⟨ c4, x1, _, _, hc4 ⟩ + -- Proving the post-condition + split_conjs + . simp [*] + . simp [*] + . simp [hc4] + have hxUpdate := toInt_update x.val i.val s2 (by scalar_tac) simp [hxUpdate]; clear hxUpdate - have hyDrop := toInt_aux_drop y.val i.toNat (by scalar_tac) + have hyDrop := toInt_drop y.val i.val (by scalar_tac) simp [hyDrop]; clear hyDrop scalar_eq_nf + -- The best way is to do a case disjunction and treat each sub-case separately split at hConv1 <;> - split at hConv2 <;> - simp_all <;> - scalar_eq_nf <;> simp [U32.max] <;> scalar_eq_nf + split at hConv2 + . have hConv1' : (s1.val : Int) = xi.val + c0u.val - U32.size := by scalar_tac + have hConv2' : (s2.val : Int) = s1.val + yi.val - U32.size := by scalar_tac + simp [hConv2', hConv1'] + /- `U32.size_eq` is a lemma which allows one to simplify `U32.size`. + But you can also simply do `simp [U32.size]`, which simplifies + `U32.size` to `2^U32.numBits`, then simplifies `U32.numBits`. -/ + simp [*, U32.size_eq] + scalar_eq_nf + . have hConv1' : (s1.val : Int) = xi.val + c0u.val - U32.size := by scalar_tac + simp [hConv2, hConv1'] + simp [*, U32.size_eq] + scalar_eq_nf + . have hConv2' : (s2.val : Int) = s1.val + yi.val - U32.size := by scalar_tac + simp [hConv2', hConv1] + simp [*, U32.size_eq] + scalar_eq_nf + . simp [*, U32.size_eq] + scalar_eq_nf . 
simp_all -termination_by (x.length - i.val).toNat +termination_by x.length - i.val decreasing_by scalar_decr_tac /-- The proof about `add_with_carry` -/ -@[pspec] +@[progress] theorem add_with_carry_spec (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) (hLength : x.length = y.length) : diff --git a/tests/lean/Tutorial/Tutorial.lean b/tests/lean/Tutorial/Tutorial.lean index 2fd8e653..fcb70745 100644 --- a/tests/lean/Tutorial/Tutorial.lean +++ b/tests/lean/Tutorial/Tutorial.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [tutorial] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -14,9 +14,9 @@ def choose {T : Type} (b : Bool) (x : T) (y : T) : Result (T × (T → (T × T))) := if b then let back := fun ret => (ret, y) - Result.ok (x, back) + ok (x, back) else let back := fun ret => (x, ret) - Result.ok (y, back) + ok (y, back) /- [tutorial::mul2_add1]: Source: 'src/lib.rs', lines 9:0-11:1 -/ @@ -44,7 +44,7 @@ def use_incr : Result Unit := let x ← incr 0#u32 let x1 ← incr x let _ ← incr x1 - Result.ok () + ok () /- [tutorial::CList] Source: 'src/lib.rs', lines 30:0-33:1 -/ @@ -58,11 +58,11 @@ divergent def list_nth {T : Type} (l : CList T) (i : U32) : Result T := match l with | CList.CCons x tl => if i = 0#u32 - then Result.ok x + then ok x else do let i1 ← i - 1#u32 list_nth tl i1 - | CList.CNil => Result.fail .panic + | CList.CNil => fail panic /- [tutorial::list_nth_mut]: Source: 'src/lib.rs', lines 50:0-63:1 -/ @@ -72,15 +72,15 @@ divergent def list_nth_mut | CList.CCons x tl => if i = 0#u32 then let back := fun ret => CList.CCons ret tl - Result.ok (x, back) + ok (x, back) else do let i1 ← i - 1#u32 let (t, list_nth_mut_back) ← list_nth_mut tl i1 let back := fun ret => let tl1 := list_nth_mut_back ret CList.CCons x tl1 - Result.ok (t, back) - | CList.CNil => Result.fail .panic + ok (t, back) + | CList.CNil => fail panic /- [tutorial::list_nth1]: loop 0: Source: 'src/lib.rs', lines 66:4-74:1 -/ @@ -88,11 +88,11 @@ divergent def list_nth1_loop {T : Type} (l : CList T) (i : U32) : Result T := match l with | CList.CCons x tl => if i = 0#u32 - then Result.ok x + then ok x else do let i1 ← i - 1#u32 list_nth1_loop tl i1 - | CList.CNil => Result.fail .panic + | CList.CNil => fail panic /- [tutorial::list_nth1]: Source: 'src/lib.rs', lines 65:0-74:1 -/ @@ -104,7 +104,7 @@ def list_nth1 {T : Type} (l : CList T) (i : U32) : Result T := Source: 'src/lib.rs', lines 76:0-83:1 -/ divergent def i32_id (i : I32) : Result I32 := if i = 0#i32 - then Result.ok 0#i32 + then ok 0#i32 else do let i1 ← i - 1#i32 let i2 ← i32_id i1 @@ -114,7 +114,7 @@ divergent def i32_id (i : I32) : Result I32 := Source: 'src/lib.rs', lines 85:0-92:1 -/ mutual divergent def even (i : U32) : Result Bool := if i = 0#u32 - then Result.ok true + then ok true else do let i1 ← i - 1#u32 odd i1 @@ -123,7 +123,7 @@ mutual divergent def even (i : U32) : Result Bool := Source: 'src/lib.rs', lines 94:0-101:1 -/ divergent def odd (i : U32) : Result Bool := if i = 0#u32 - then Result.ok false + then ok false else do let i1 ← i - 1#u32 even i1 @@ -140,7 +140,7 @@ structure Counter (Self : Type) where def CounterUsize.incr (self : Usize) : Result (Usize × Usize) := do let self1 ← self + 1#usize - Result.ok (self, self1) + ok (self, self1) /- Trait implementation: [tutorial::{tutorial::Counter for usize}] Source: 'src/lib.rs', lines 109:0-115:1 -/ @@ -163,15 +163,15 @@ divergent def 
list_nth_mut1_loop | CList.CCons x tl => if i = 0#u32 then let back := fun ret => CList.CCons ret tl - Result.ok (x, back) + ok (x, back) else do let i1 ← i - 1#u32 let (t, back) ← list_nth_mut1_loop tl i1 let back1 := fun ret => let tl1 := back ret CList.CCons x tl1 - Result.ok (t, back1) - | CList.CNil => Result.fail .panic + ok (t, back1) + | CList.CNil => fail panic /- [tutorial::list_nth_mut1]: Source: 'src/lib.rs', lines 123:0-132:1 -/ @@ -190,8 +190,8 @@ divergent def list_tail_loop let (c, back) ← list_tail_loop tl let back1 := fun ret => let tl1 := back ret CList.CCons t tl1 - Result.ok (c, back1) - | CList.CNil => Result.ok (CList.CNil, fun ret => ret) + ok (c, back1) + | CList.CNil => ok (CList.CNil, fun ret => ret) /- [tutorial::list_tail]: Source: 'src/lib.rs', lines 134:0-139:1 -/ @@ -206,7 +206,7 @@ def append_in_place {T : Type} (l0 : CList T) (l1 : CList T) : Result (CList T) := do let (_, list_tail_back) ← list_tail l0 - Result.ok (list_tail_back l1) + ok (list_tail_back l1) /- [tutorial::reverse]: loop 0: Source: 'src/lib.rs', lines 148:4-152:5 -/ @@ -214,10 +214,11 @@ divergent def reverse_loop {T : Type} (l : CList T) (out : CList T) : Result (CList T) := match l with | CList.CCons hd tl => reverse_loop tl (CList.CCons hd out) - | CList.CNil => Result.ok out + | CList.CNil => ok out /- [tutorial::reverse]: Source: 'src/lib.rs', lines 146:0-154:1 -/ +@[reducible] def reverse {T : Type} (l : CList T) : Result (CList T) := reverse_loop l CList.CNil @@ -235,10 +236,11 @@ divergent def zero_loop let i2 ← i + 1#usize let x1 := index_mut_back 0#u32 zero_loop x1 i2 - else Result.ok x + else ok x /- [tutorial::zero]: Source: 'src/lib.rs', lines 162:0-168:1 -/ +@[reducible] def zero (x : alloc.vec.Vec U32) : Result (alloc.vec.Vec U32) := zero_loop x 0#usize @@ -261,10 +263,11 @@ divergent def add_no_overflow_loop let i5 ← i + 1#usize let x1 := index_mut_back i4 add_no_overflow_loop x1 y i5 - else Result.ok x + else ok x /- [tutorial::add_no_overflow]: Source: 'src/lib.rs', lines 175:0-181:1 -/ +@[reducible] def add_no_overflow (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) : Result (alloc.vec.Vec U32) @@ -283,15 +286,14 @@ divergent def add_with_carry_loop do let i2 ← alloc.vec.Vec.index (core.slice.index.SliceIndexUsizeSliceTInst U32) x i - let i3 ← Scalar.cast .U32 c0 - let p ← core.num.U32.overflowing_add i2 i3 - let (sum, c1) := p + let (i3 : U32) ← ↑(UScalar.cast .U32 c0) + let ((sum, c1) : (U32 × Bool)) ← ↑(core.num.U32.overflowing_add i2 i3) let i4 ← alloc.vec.Vec.index (core.slice.index.SliceIndexUsizeSliceTInst U32) y i - let p1 ← core.num.U32.overflowing_add sum i4 - let (sum1, c2) := p1 - let i5 ← Scalar.cast_bool .U8 c1 - let i6 ← Scalar.cast_bool .U8 c2 + let ((sum1, c2) : (U32 × Bool)) ← + ↑(core.num.U32.overflowing_add sum i4) + let (i5 : U8) ← ↑(UScalar.cast_fromBool .U8 c1) + let (i6 : U8) ← ↑(UScalar.cast_fromBool .U8 c2) let c01 ← i5 + i6 let (_, index_mut_back) ← alloc.vec.Vec.index_mut (core.slice.index.SliceIndexUsizeSliceTInst U32) @@ -299,10 +301,11 @@ divergent def add_with_carry_loop let i7 ← i + 1#usize let x1 := index_mut_back sum1 add_with_carry_loop x1 y c01 i7 - else Result.ok (c0, x) + else ok (c0, x) /- [tutorial::add_with_carry]: Source: 'src/lib.rs', lines 186:0-199:1 -/ +@[reducible] def add_with_carry (x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) : Result (U8 × (alloc.vec.Vec U32)) @@ -313,8 +316,8 @@ def add_with_carry Source: 'src/lib.rs', lines 201:0-203:1 -/ def max (x : Usize) (y : Usize) : Result Usize := if x > y - then Result.ok x 
- else Result.ok y + then ok x + else ok y /- [tutorial::get_or_zero]: Source: 'src/lib.rs', lines 205:0-207:1 -/ @@ -322,7 +325,7 @@ def get_or_zero (y : alloc.vec.Vec U32) (i : Usize) : Result U32 := let i1 := alloc.vec.Vec.len y if i < i1 then alloc.vec.Vec.index (core.slice.index.SliceIndexUsizeSliceTInst U32) y i - else Result.ok 0#u32 + else ok 0#u32 /- [tutorial::add]: loop 0: Source: 'src/lib.rs', lines 221:4-229:5 -/ @@ -337,13 +340,12 @@ divergent def add_loop let yi ← get_or_zero y i let i1 ← alloc.vec.Vec.index (core.slice.index.SliceIndexUsizeSliceTInst U32) x i - let i2 ← Scalar.cast .U32 c0 - let p ← core.num.U32.overflowing_add i1 i2 - let (sum, c1) := p - let p1 ← core.num.U32.overflowing_add sum yi - let (sum1, c2) := p1 - let i3 ← Scalar.cast_bool .U8 c1 - let i4 ← Scalar.cast_bool .U8 c2 + let (i2 : U32) ← ↑(UScalar.cast .U32 c0) + let ((sum, c1) : (U32 × Bool)) ← ↑(core.num.U32.overflowing_add i1 i2) + let ((sum1, c2) : (U32 × Bool)) ← + ↑(core.num.U32.overflowing_add sum yi) + let (i3 : U8) ← ↑(UScalar.cast_fromBool .U8 c1) + let (i4 : U8) ← ↑(UScalar.cast_fromBool .U8 c2) let c01 ← i3 + i4 let (_, index_mut_back) ← alloc.vec.Vec.index_mut (core.slice.index.SliceIndexUsizeSliceTInst U32) @@ -353,10 +355,11 @@ divergent def add_loop add_loop x1 y max1 c01 i5 else if c0 != 0#u8 - then do - let i1 ← Scalar.cast .U32 c0 - alloc.vec.Vec.push x i1 - else Result.ok x + then + do + let (i1 : U32) ← ↑(UScalar.cast .U32 c0) + alloc.vec.Vec.push x i1 + else ok x /- [tutorial::add]: Source: 'src/lib.rs', lines 214:0-235:1 -/ diff --git a/tests/lean/Vec.lean b/tests/lean/Vec.lean index d96e715b..702d25cf 100644 --- a/tests/lean/Vec.lean +++ b/tests/lean/Vec.lean @@ -1,7 +1,7 @@ -- THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS -- [vec] import Aeneas -open Aeneas.Std +open Aeneas.Std Result Error set_option linter.dupNamespace false set_option linter.hashCommand false set_option linter.unusedVariables false @@ -21,6 +21,6 @@ def use_extend_from_slice Source: 'tests/src/vec.rs', lines 9:0-11:1 -/ def use_alloc_with_capacity (T : Type) (n : Usize) : Result (alloc.vec.Vec T) := - Result.ok (alloc.vec.Vec.with_capacity T n) + ok (alloc.vec.Vec.with_capacity T n) end vec diff --git a/tests/src/as_mut.rs b/tests/src/as_mut.rs new file mode 100644 index 00000000..69ffad2d --- /dev/null +++ b/tests/src/as_mut.rs @@ -0,0 +1,8 @@ +//@ [!lean] skip +fn use_box_as_mut<T>(mut x : &mut Box<T>) -> &mut T{ + x.as_mut() +} + +fn use_as_mut<S, T : AsMut<S>>(mut x : &mut T) -> &mut S { + x.as_mut() +} diff --git a/tests/src/mutually-recursive-traits.lean.out b/tests/src/mutually-recursive-traits.lean.out index 7938481c..c97ef5e3 100644 --- a/tests/src/mutually-recursive-traits.lean.out +++ b/tests/src/mutually-recursive-traits.lean.out @@ -13,5 +13,5 @@ Called from Aeneas__Translate.extract_definitions.export_decl_group in file "Tra Called from Stdlib__List.iter in file "list.ml", line 110, characters 12-15 Called from Aeneas__Translate.extract_definitions in file "Translate.ml", line 882, characters 2-177 Called from Aeneas__Translate.extract_file in file "Translate.ml", line 1014, characters 2-36 -Called from Aeneas__Translate.translate_crate in file "Translate.ml", line 1648, characters 5-42 -Called from Dune__exe__Main in file "Main.ml", line 577, characters 11-63 +Called from Aeneas__Translate.translate_crate in file "Translate.ml", line 1647, characters 5-42 +Called from Dune__exe__Main in file "Main.ml", line 579, characters 11-63
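
As a closing illustration (not part of the generated files above), here is a minimal sketch of a hand-written definition and specification in the post-migration style used throughout this diff: the short `ok` constructor (available after `open Aeneas.Std Result Error`), the `@[progress]` attribute in place of `@[pspec]`, and `scalar_tac` for the arithmetic side conditions. The names `illustration`, `double` and `double_spec` are hypothetical, introduced only for this example, and the proof script is a plausible sketch rather than a checked proof.

import Aeneas
open Aeneas.Std Result Error

namespace illustration

/- A hand-written function in the style of the extracted code: `x + x` already
   has type `Result U32` in the Aeneas standard library, so it can be returned
   directly, and successful results are built with `ok` rather than `Result.ok`. -/
def double (x : U32) : Result U32 :=
  x + x

/- A specification in the post-migration style: `@[progress]` registers the
   theorem for the `progress` tactic, the success case is written with `ok`,
   and `scalar_tac` discharges the remaining linear arithmetic. -/
@[progress]
theorem double_spec (x : U32) (h : 2 * x.val ≤ U32.max) :
    ∃ y, double x = ok y ∧ y.val = 2 * x.val := by
  rw [double]
  progress as ⟨ y ⟩
  scalar_tac

end illustration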