Skip to content

Commit

Permalink
feat: make <num>#<term> bitvector literal notation global (#4260)
Browse files Browse the repository at this point in the history
The bitvector literal notation (e.g., `0x1#4`) is currently scoped, but
it is used when pretty-printing. This has led to confusion among users
who find that the notation does not work in their code unless they
include `open BitVec`.

Users are often puzzled when they try to use the bitvector literal
notation directly in their code, only to find it doesn't work without
explicitly opening `BitVec`. This behavior seems counterintuitive and
has been a source of frustration. Additionally, this notation is popular
among users for its compact and expressive representation.

- Alternative Notations:
- Using polymorphic numerals (e.g., `1`) loses valuable bitvector size
information.
- Using the verbose form (e.g., `(1 : BitVec 4)`) is cumbersome and less
readable.

@kmill and @semorrison suggested the updated syntax definition for
bitvector literals:

```lean
scoped syntax:max num noWs "#" noWs term:max : term
macro_rules | `($i:num#$n) => `(BitVec.ofNat $n $i)
```

This change ensures that:
- There is no lexical conflict with existing Mathlib notations,
particularly with cardinality and Finset notations.

- The notation remains intuitive and easy to use.

- Users will no longer need to remember to open `BitVec` to use the
notation, making it more intuitive and less error-prone.

- Pretty-printing BitVec-heavy goals, such as those in the SSFT24
tutorial, will be more readable and less verbose.

- We can still write `0#n` to denote `BitVec.ofNat n 0`

---------

Co-authored-by: Kim Morrison <[email protected]>
  • Loading branch information
leodemoura and kim-em authored Jun 6, 2024
1 parent faea7f9 commit ff69c28
Show file tree
Hide file tree
Showing 51 changed files with 68,040 additions and 56,190 deletions.
8 changes: 4 additions & 4 deletions src/Init/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ end Int
section Syntax

/-- Notation for bit vector literals. `i#n` is a shorthand for `BitVec.ofNat n i`. -/
scoped syntax:max term:max noWs "#" noWs term:max : term
macro_rules | `($i#$n) => `(BitVec.ofNat $n $i)
syntax:max num noWs "#" noWs term:max : term
macro_rules | `($i:num#$n) => `(BitVec.ofNat $n $i)

/-- Unexpander for bit vector literals. -/
@[app_unexpander BitVec.ofNat] def unexpandBitVecOfNat : Lean.PrettyPrinter.Unexpander
| `($(_) $n $i) => `($i#$n)
| `($(_) $n $i:num) => `($i:num#$n)
| _ => throw ()

/-- Notation for bit vector literals without truncation. `i#'lt` is a shorthand for `BitVec.ofNatLt i lt`. -/
Expand Down Expand Up @@ -504,7 +504,7 @@ equivalent to `a * 2^s`, modulo `2^n`.
SMT-Lib name: `bvshl` except this operator uses a `Nat` shift value.
-/
protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := (a.toNat <<< s)#n
protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (a.toNat <<< s)
instance : HShiftLeft (BitVec w) Nat (BitVec w) := ⟨.shiftLeft⟩

/--
Expand Down
32 changes: 16 additions & 16 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ theorem ofBool_eq_iff_eq : ∀(b b' : Bool), BitVec.ofBool b = BitVec.ofBool b'
getLsb (x#'lt) i = x.testBit i := by
simp [getLsb, BitVec.ofNatLt]

@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (x#w).toNat = x % 2^w := by
@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (BitVec.ofNat w x).toNat = x % 2^w := by
simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat']

@[simp] theorem toFin_ofNat (x : Nat) : toFin x#w = Fin.ofNat' x (Nat.two_pow_pos w) := rfl
@[simp] theorem toFin_ofNat (x : Nat) : toFin (BitVec.ofNat w x) = Fin.ofNat' x (Nat.two_pow_pos w) := rfl

-- Remark: we don't use `[simp]` here because simproc` subsumes it for literals.
-- If `x` and `n` are not literals, applying this theorem eagerly may not be a good idea.
theorem getLsb_ofNat (n : Nat) (x : Nat) (i : Nat) :
getLsb (x#n) i = (i < n && x.testBit i) := by
getLsb (BitVec.ofNat n x) i = (i < n && x.testBit i) := by
simp [getLsb, BitVec.ofNat, Fin.val_ofNat']

@[simp, deprecated toNat_ofNat (since := "2024-02-22")]
Expand Down Expand Up @@ -316,19 +316,19 @@ theorem zeroExtend'_eq {x : BitVec w} (h : w ≤ v) : x.zeroExtend' h = x.zeroEx
let ⟨x, lt_n⟩ := x
simp [truncate, zeroExtend]

@[simp] theorem zeroExtend_zero (m n : Nat) : zeroExtend m (0#n) = 0#m := by
@[simp] theorem zeroExtend_zero (m n : Nat) : zeroExtend m 0#n = 0#m := by
apply eq_of_toNat_eq
simp [toNat_zeroExtend]

@[simp] theorem truncate_eq (x : BitVec n) : truncate n x = x := zeroExtend_eq x

@[simp] theorem ofNat_toNat (m : Nat) (x : BitVec n) : x.toNat#m = truncate m x := by
@[simp] theorem ofNat_toNat (m : Nat) (x : BitVec n) : BitVec.ofNat m x.toNat = truncate m x := by
apply eq_of_toNat_eq
simp

/-- Moves one-sided left toNat equality to BitVec equality. -/
theorem toNat_eq_nat (x : BitVec w) (y : Nat)
: (x.toNat = y) ↔ (y < 2^w ∧ (x = y#w)) := by
: (x.toNat = y) ↔ (y < 2^w ∧ (x = BitVec.ofNat w y)) := by
apply Iff.intro
· intro eq
simp at eq
Expand All @@ -340,7 +340,7 @@ theorem toNat_eq_nat (x : BitVec w) (y : Nat)

/-- Moves one-sided right toNat equality to BitVec equality. -/
theorem nat_eq_toNat (x : BitVec w) (y : Nat)
: (y = x.toNat) ↔ (y < 2^w ∧ (x = y#w)) := by
: (y = x.toNat) ↔ (y < 2^w ∧ (x = BitVec.ofNat w y)) := by
rw [@eq_comm _ _ x.toNat]
apply toNat_eq_nat

Expand Down Expand Up @@ -416,7 +416,7 @@ protected theorem extractLsb_ofFin {n} (x : Fin (2^n)) (hi lo : Nat) :

@[simp]
protected theorem extractLsb_ofNat (x n : Nat) (hi lo : Nat) :
extractLsb hi lo x#n = .ofNat (hi - lo + 1) ((x % 2^n) >>> lo) := by
extractLsb hi lo (BitVec.ofNat n x) = .ofNat (hi - lo + 1) ((x % 2^n) >>> lo) := by
apply eq_of_getLsb_eq
intro ⟨i, _lt⟩
simp [BitVec.ofNat]
Expand Down Expand Up @@ -1008,10 +1008,10 @@ Definition of bitvector addition as a nat.
@[simp] theorem add_ofFin (x : BitVec n) (y : Fin (2^n)) :
x + .ofFin y = .ofFin (x.toFin + y) := rfl

theorem ofNat_add {n} (x y : Nat) : (x + y)#n = x#n + y#n := by
theorem ofNat_add {n} (x y : Nat) : BitVec.ofNat n (x + y) = BitVec.ofNat n x + BitVec.ofNat n y := by
apply eq_of_toNat_eq ; simp [BitVec.ofNat]

theorem ofNat_add_ofNat {n} (x y : Nat) : x#n + y#n = (x + y)#n :=
theorem ofNat_add_ofNat {n} (x y : Nat) : BitVec.ofNat n x + BitVec.ofNat n y = BitVec.ofNat n (x + y) :=
(ofNat_add x y).symm

protected theorem add_assoc (x y z : BitVec n) : x + y + z = x + (y + z) := by
Expand Down Expand Up @@ -1057,10 +1057,10 @@ theorem sub_def {n} (x y : BitVec n) : x - y = .ofNat n (x.toNat + (2^n - y.toNa
rfl
-- Remark: we don't use `[simp]` here because simproc` subsumes it for literals.
-- If `x` and `n` are not literals, applying this theorem eagerly may not be a good idea.
theorem ofNat_sub_ofNat {n} (x y : Nat) : x#n - y#n = .ofNat n (x + (2^n - y % 2^n)) := by
theorem ofNat_sub_ofNat {n} (x y : Nat) : BitVec.ofNat n x - BitVec.ofNat n y = .ofNat n (x + (2^n - y % 2^n)) := by
apply eq_of_toNat_eq ; simp [BitVec.ofNat]

@[simp] protected theorem sub_zero (x : BitVec n) : x - (0#n) = x := by apply eq_of_toNat_eq ; simp
@[simp] protected theorem sub_zero (x : BitVec n) : x - 0#n = x := by apply eq_of_toNat_eq ; simp

@[simp] protected theorem sub_self (x : BitVec n) : x - x = 0#n := by
apply eq_of_toNat_eq
Expand All @@ -1080,7 +1080,7 @@ theorem sub_toAdd {n} (x y : BitVec n) : x - y = x + - y := by
apply eq_of_toNat_eq
simp

@[simp] theorem neg_zero (n:Nat) : -0#n = 0#n := by apply eq_of_toNat_eq ; simp
@[simp] theorem neg_zero (n:Nat) : -BitVec.ofNat n 0 = BitVec.ofNat n 0 := by apply eq_of_toNat_eq ; simp

theorem add_sub_cancel (x y : BitVec w) : x + y - y = x := by
apply eq_of_toNat_eq
Expand Down Expand Up @@ -1157,7 +1157,7 @@ theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) =
x ≤ BitVec.ofFin y ↔ x.toFin ≤ y := Iff.rfl
@[simp] theorem ofFin_le (x : Fin (2^n)) (y : BitVec n) :
BitVec.ofFin x ≤ y ↔ x ≤ y.toFin := Iff.rfl
@[simp] theorem ofNat_le_ofNat {n} (x y : Nat) : (x#n) ≤ (y#n) ↔ x % 2^n ≤ y % 2^n := by
@[simp] theorem ofNat_le_ofNat {n} (x y : Nat) : (BitVec.ofNat n x) ≤ (BitVec.ofNat n y) ↔ x % 2^n ≤ y % 2^n := by
simp [le_def]

@[bv_toNat] theorem lt_def (x y : BitVec n) :
Expand All @@ -1167,7 +1167,7 @@ theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) =
x < BitVec.ofFin y ↔ x.toFin < y := Iff.rfl
@[simp] theorem ofFin_lt (x : Fin (2^n)) (y : BitVec n) :
BitVec.ofFin x < y ↔ x < y.toFin := Iff.rfl
@[simp] theorem ofNat_lt_ofNat {n} (x y : Nat) : (x#n) < (y#n) ↔ x % 2^n < y % 2^n := by
@[simp] theorem ofNat_lt_ofNat {n} (x y : Nat) : BitVec.ofNat n x < BitVec.ofNat n y ↔ x % 2^n < y % 2^n := by
simp [lt_def]

protected theorem lt_of_le_ne (x y : BitVec n) (h1 : x <= y) (h2 : ¬ x = y) : x < y := by
Expand All @@ -1180,7 +1180,7 @@ protected theorem lt_of_le_ne (x y : BitVec n) (h1 : x <= y) (h2 : ¬ x = y) : x
/-! ### intMax -/

/-- The bitvector of width `w` that has the largest value when interpreted as an integer. -/
def intMax (w : Nat) : BitVec w := (2^w - 1)#w
def intMax (w : Nat) : BitVec w := BitVec.ofNat w (2^w - 1)

theorem getLsb_intMax_eq (w : Nat) : (intMax w).getLsb i = decide (i < w) := by
simp [intMax, getLsb]
Expand Down
Loading

0 comments on commit ff69c28

Please sign in to comment.