From 08860fdbeb7f03c4fceb22e0bb64161b37418306 Mon Sep 17 00:00:00 2001 From: Sean McLaughlin Date: Mon, 6 Jan 2025 13:56:44 -0800 Subject: [PATCH 1/4] Advanced indexing We implement a common version of advanced indexing, where all arguments are (broadcastable to) int arrays. --- TensorLib/Broadcast.lean | 21 ++++++ TensorLib/Common.lean | 67 +++++++++++++++++-- TensorLib/Dtype.lean | 14 ++++ TensorLib/Index.lean | 141 ++++++++++++++++++++++++++++++++------- TensorLib/Tensor.lean | 50 +++++++++++++- 5 files changed, 262 insertions(+), 31 deletions(-) diff --git a/TensorLib/Broadcast.lean b/TensorLib/Broadcast.lean index c6583b2..33c3d92 100644 --- a/TensorLib/Broadcast.lean +++ b/TensorLib/Broadcast.lean @@ -39,6 +39,8 @@ Rule 2 A: (3, 2, 7) B: (3, 2, 7) + +Theorem to prove: If we can broadcast s1 to s2, then given an array with shape s1, then s1.reshape s2 succeeds -/ namespace TensorLib @@ -103,4 +105,23 @@ def canBroadcast (b : Broadcast) : Bool := (broadcast b).isSome broadcast b2 == broadcast b1 && broadcast b2 == .some (Shape.mk [1, 2, 3]) +def broadcastList (shapes : List Shape) : Option Shape := Id.run do + match shapes with + | [] => none + | shape :: shapes => + let mut shape := shape + for s in shapes do + let b := Broadcast.mk shape s + match b.broadcast with + | .none => return .none + | .some s => + shape := s + return shape + +#guard + let x1 := Shape.mk [1, 2, 3] + let x2 := Shape.mk [2, 3] + let x3 := Shape.mk [] + broadcastList [x1, x2, x3] == .some x1 + end Broadcast diff --git a/TensorLib/Common.lean b/TensorLib/Common.lean index dd6f57e..a454870 100644 --- a/TensorLib/Common.lean +++ b/TensorLib/Common.lean @@ -5,6 +5,7 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin -/ import Std.Tactic.BVDecide +import Batteries.Data.List namespace TensorLib @@ -27,6 +28,7 @@ def natDivCeil (num denom : Nat) : Nat := (num + denom - 1) / denom def natProd (shape : List Nat) : Nat := shape.foldl (fun x y => x * y) 1 + -- We generally have large tensors, so don't show them by default instance ByteArrayRepr : Repr ByteArray where reprPrec x _ := @@ -45,11 +47,41 @@ inductive ByteOrder where | bigEndian deriving BEq, Repr, Inhabited +namespace ByteOrder + @[simp] -def ByteOrder.isMultiByte (x : ByteOrder) : Bool := match x with +def isMultiByte (x : ByteOrder) : Bool := match x with | .oneByte => false | .littleEndian | .bigEndian => true +def bytesToInt (order : ByteOrder) (bytes : ByteArray) : Int := Id.run do + let mut n : Nat := 0 + let nbytes := bytes.size + let signByte := match order with + | .littleEndian => bytes.get! (nbytes - 1) + | .bigEndian | oneByte => bytes.get! 0 + let negative := 128 <= signByte + for i in [0:nbytes] do + let v : UInt8 := bytes.get! i + let v := if negative then UInt8.complement v else v + let p := match order with + | .oneByte => 0 -- nbytes = 1 + | .littleEndian => i + | .bigEndian => nbytes - 1 - i + n := n + Pow.pow 2 (8 * p) * v.toNat + return if 128 <= signByte then -(n + 1) else n + +#guard bytesToInt .littleEndian (ByteArray.mk #[1, 1]) == 257 +#guard bytesToInt .bigEndian (ByteArray.mk #[1, 1]) == 257 +#guard bytesToInt .littleEndian (ByteArray.mk #[0, 1]) == 256 +#guard bytesToInt .bigEndian (ByteArray.mk #[0, 1]) == 1 +#guard bytesToInt .littleEndian (ByteArray.mk #[0xFF, 0xFF]) == -1 +#guard bytesToInt .bigEndian (ByteArray.mk #[0xFF, 0xFF]) == -1 +#guard bytesToInt .bigEndian (ByteArray.mk #[0x80, 0]) == -32768 +#guard bytesToInt .littleEndian (ByteArray.mk #[0x80, 0]) == 0x80 + +end ByteOrder + /-! The strides are how many bytes you need to skip to get to the next element in that "row". For example, in an array of 8-byte data with shape 2, 3, the strides are (24, 8). @@ -144,6 +176,9 @@ deriving BEq, Repr, Inhabited namespace Shape +instance : ToString Shape where + toString := reprStr + def empty : Shape := Shape.mk [] --! The number of elements in a tensor. All that's needed is the shape for this calculation. @@ -155,6 +190,10 @@ def ndim (shape : Shape) : Nat := shape.val.length def map (shape : Shape) (f : List Nat -> List Nat) : Shape := Shape.mk (f shape.val) +def dimIndexInRange (shape : Shape) (dimIndex : DimIndex) := + shape.ndim == dimIndex.length && + (shape.val.zip dimIndex).all fun (n, i) => i < n + /-! Strides can be computed from the shape by figuring out how many elements you need to jump over to get to the next spot and mulitplying by the bytes in each @@ -205,6 +244,7 @@ def positionToDimIndex (strides : Strides) (n : Position) : DimIndex := let (_, idx) := strides.foldl foldFn (n, []) idx.reverse +-- TODO: Return `Err Offset` for when the strides and index have different lengths? def dimIndexToOffset (strides : Strides) (index : DimIndex) : Offset := dot strides (index.map Int.ofNat) #guard positionToDimIndex [3, 1] 4 == [1, 1] @@ -223,6 +263,21 @@ def allDimIndices (shape : Shape) : List DimIndex := Id.run do #guard allDimIndices (Shape.mk [5]) == [[0], [1], [2], [3], [4]] #guard allDimIndices (Shape.mk [3, 2]) == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]] +-- NumPy supports negative indices, which simply wrap around. E.g. `x[.., -1, ..] = x[.., n-1, ..]` where `n` is the +-- dimension in question. It only supports `-n` to `n`. +def intIndexToDimIndex (shape : Shape) (index : List Int) : Err DimIndex := do + if shape.ndim != index.length then .error "intsToDimIndex length mismatch" else + let conv (dim : Nat) (ind : Int) : Err Nat := + if 0 <= ind then + if ind < dim then .ok ind.toNat + else .error "index out of bounds" + else if ind < -dim then .error "index out of bounds" + else .ok (dim + ind).toNat + (shape.val.zip index).traverse (fun (dim, ind) => conv dim ind) + +#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, -1, -1] == (.ok [0, 1, 2]) +#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, 1, -2] == (.ok [0, 1, 1]) + end Shape /- @@ -278,7 +333,7 @@ Note: I tried writing this as a `do/for` loop and in this case the recursive one seems nicer. We are walking over two lists simultaneously, which is easy here but with a for loop is either quadratic or awkward. -/ -def next (iter : DimsIter) : List Nat × DimsIter := +def next (iter : DimsIter) : DimIndex × DimsIter := -- Invariant: `acc` is a list of 0s, so doesn't need to be reversed let rec loop (acc ms ns : List Nat) : List Nat := match ms, ns with @@ -293,7 +348,7 @@ def next (iter : DimsIter) : List Nat × DimsIter := let curr' := loop [] iter.dims iter.curr (iter.curr.reverse, { iter with curr := curr' }) -instance [Monad m] : ForIn m DimsIter (List Nat) where +instance [Monad m] : ForIn m DimsIter DimIndex where forIn {α} [Monad m] (iter : DimsIter) (x : α) (f : List Nat -> α -> m (ForInStep α)) : m α := do let mut iter := iter let mut res := x @@ -305,7 +360,7 @@ instance [Monad m] : ForIn m DimsIter (List Nat) where | .done k => return k return res -private def toList (iter : DimsIter) : List (List Nat) := Id.run do +private def toList (iter : DimsIter) : List DimIndex := Id.run do let mut res := [] for xs in iter do res := xs :: res @@ -321,7 +376,7 @@ private def toList (iter : DimsIter) : List (List Nat) := Id.run do #guard (DimsIter.make $ Shape.mk [1, 1, 2]).toList == [[0, 0, 0], [0, 0, 1]] #guard (DimsIter.make $ Shape.mk [3, 2]).toList == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]] -private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do +private def testBreak (iter : DimsIter) : List DimIndex := Id.run do let mut res := [] for xs in iter do res := xs :: res @@ -330,7 +385,7 @@ private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do #guard (DimsIter.make $ Shape.mk [3, 2]).testBreak == [[0, 0]] -private def testReturn (iter : DimsIter) : List (List Nat) := Id.run do +private def testReturn (iter : DimsIter) : List DimIndex := Id.run do let mut res := [] let mut i := 0 for xs in iter do diff --git a/TensorLib/Dtype.lean b/TensorLib/Dtype.lean index 784c0b2..f53406c 100644 --- a/TensorLib/Dtype.lean +++ b/TensorLib/Dtype.lean @@ -34,6 +34,16 @@ def isMultiByte (x : Name) : Bool := match x with | bool | int8 | uint8 => false | _ => true +def isInt (x : Name) : Bool := match x with +| int8 | int16 | int32 | int64 => true +| _ => false + +def isUint (x : Name) : Bool := match x with +| uint8 | uint16 | uint32 | uint64 => true +| _ => false + +def isIntLike (x : Name) : Bool := x.isInt || x.isUint + --! Number of bytes used by each element of the given dtype def itemsize (x : Name) : Nat := match x with | float64 | int64 | uint64 => 8 @@ -57,6 +67,10 @@ def itemsize (t : Dtype) := t.name.itemsize def sizedStrides (dtype : Dtype) (s : Shape) : Strides := List.map (fun x => x * dtype.itemsize) s.unitStrides +def isInt (dtype : Dtype) : Bool := dtype.name.isInt +def isUint (dtype : Dtype) : Bool := dtype.name.isUint +def isIntLike (dtype : Dtype) : Bool := dtype.isInt || dtype.isUint + def int8 : Dtype := Dtype.mk Dtype.Name.int8 ByteOrder.littleEndian def uint8 : Dtype := Dtype.mk Dtype.Name.uint8 ByteOrder.littleEndian def uint64 : Dtype := Dtype.mk Dtype.Name.uint64 ByteOrder.littleEndian diff --git a/TensorLib/Index.lean b/TensorLib/Index.lean index 51a6708..9f8b459 100644 --- a/TensorLib/Index.lean +++ b/TensorLib/Index.lean @@ -4,12 +4,19 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin -/ +import TensorLib.Broadcast import TensorLib.Common -import TensorLib.Tensor -import TensorLib.Slice import TensorLib.Npy +import TensorLib.Slice +import TensorLib.Tensor /- +There are several types of indexing in NumPy. + + https://numpy.org/doc/stable/user/basics.indexing.html + +We handle basic indexing and some types of advanced indexing. + Theorems to prove (taken from NumPy docs): 1. Basic slicing with more than one non-: entry in the slicing tuple, @@ -18,7 +25,7 @@ Theorems to prove (taken from NumPy docs): Thus, x[ind1, ..., ind2,:] acts like x[ind1][..., ind2, :] under basic slicing. 2. Advanced indices always are broadcast and iterated as one: -result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M], + result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M], ..., ind_N[i_1, ..., i_M]] 3. ...TODO... -/ @@ -251,6 +258,8 @@ private def testReturn (iter : BasicIter) : List (List Nat) := Id.run do let slice := Slice.Iter.make Slice.all 5 testReturn (get! $ BasicIter.make shape [.slice slice, .slice slice]) == [[0, 0], [0, 1], [0, 2]] +end BasicIter + def applyWithCopy (index : NumpyBasic) (arr : Tensor) : Err Tensor := do let itemsize := arr.itemsize let oldShape := arr.shape @@ -301,6 +310,94 @@ def apply (index : NumpyBasic) (arr : Tensor) : Err (Tensor × Bool) := do } return (res, false) +/- +For advanced indexing, the all-multidimensional-array case is relatively easy; +broadcast all arguments to the same shape, then select the elements of the original +array one by one. For example + +# x = np.arange(6).reshape(2, 3) +# x +array([[0, 1, 2], + [3, 4, 5]]) +# i0 = np.array([1, 0])[:, None] +# ii = np.array([1, 2, 0])[None, :] +# x[i0, i1] +array([[4, 5, 3], + [1, 2, 0]]) + +To obtain the later result, we simply walk through the [2, 3]-shaped indices +[[x[1, 1], x[1, 2], x[1, 0]], + [x[0, 1], x[0, 2], x[0, 0]], + +This also works when the dims of the index is smaller than the dims +of the array. Each x[i, j] is just an array instead of a scalar. We do not currently +implement that, but if we need it it will be clear what to do; we just copy the (contiguous) +bytes of the sub-array. + +Mixing basic and advanced indexing is complex: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing +While I can follow the individual examples, the general case is some work. +As a simple example of mixing, they give `x[..., ind, :]` where +`x.shape` is `(10, 20, 30)` and `ind` is a `(2, 5, 2)`-shaped indexing array. +The result has shape `(10, 2, 5, 2, 30)` and `result[..., i, j, k, :] = x[..., ind[i, j, k], :]`. +While this example is understandable, things get more complex, and I've not yet seen +examples we currently want to support that uses them. Therefore, we do not currently implement +mixed basic/advanced indexing. +-/ + +namespace Advanced + +def apply (indexTensors : List Tensor) (arr : Tensor) : Err Tensor := do + if indexTensors.any fun arr => !arr.isIntLike then .error "Index arrays must have an int-like type" else + if indexTensors.length != arr.ndim then .error "advanced indexing length mismatch" + -- Reshape all the input tesnsors + match Broadcast.broadcastList (indexTensors.map fun arr => arr.shape) with + | none => .error "input shapes must be broadcastable" + | some outShape => + let mut reshapedIndexTensors := [] + for indexTensor in indexTensors do + let indexTensor <- indexTensor.reshape outShape -- should never fail since broadcastList succeeds + reshapedIndexTensors := indexTensor :: reshapedIndexTensors + reshapedIndexTensors := reshapedIndexTensors.reverse + let mut res := Tensor.zeros arr.dtype outShape + -- Now we will iterate over the output shape, computing the values one-by-one from the input array + for outDimIndex in DimsIter.make outShape do + -- Get the index for each dimension in the original array from the corresponding value of the index tensors + let mut inIntIndex : List Int := [] + for indexTensor in reshapedIndexTensors do + let v <- indexTensor.intAtDimIndex outDimIndex + inIntIndex := v :: inIntIndex + let inDimIndex <- arr.shape.intIndexToDimIndex inIntIndex.reverse + let bytes <- arr.byteArrayAtDimIndex inDimIndex + res <- res.setByteArrayAtDimIndex outDimIndex bytes + return res + +/- +0 1 2 +3 4 5 +6 7 8 + +9 10 11 +12 12 14 +15 16 17 + +18 19 20 +21 22 23 +24 25 26 +-/ +#guard + let ind0 := (Tensor.Element.ofList Tensor.Element.Int8Native [1, 2, 0, 0]).reshape! (Shape.mk [2, 2]) + let ind1 := (Tensor.Element.ofList Tensor.Element.Int8Native [2, -2, 0, 1]).reshape! (Shape.mk [2, 2]) + let ind2 := (Tensor.Element.ofList Tensor.Element.Int8Native [1, 1, -1, -1]).reshape! (Shape.mk [2, 2]) + let typ := BV16 + let arr := (Tensor.Element.arange typ 27).reshape! (Shape.mk [3, 3, 3]) + let res := get! $ apply [ind0, ind1, ind2] arr + let tree := get! $ res.toTree typ + tree == Tensor.Format.Tree.node [.root [16, 22], .root [2, 5]] + +end Advanced + +section Test + #guard let tp := BV8 let tensor := Tensor.Element.arange tp 10 @@ -335,35 +432,33 @@ def apply (index : NumpyBasic) (arr : Tensor) : Err (Tensor × Bool) := do let tree' := Tensor.Format.Tree.root [19] !copied && tree == tree' -section Test - -- Testing private def numpyBasicToList (dims : List Nat) (basic : NumpyBasic) : Option (List (List Nat)) := do let shape := Shape.mk dims let (basic, _) <- (toBasic basic shape).toOption - let iter <- (make shape basic).toOption + let iter <- (BasicIter.make shape basic).toOption iter.toList #guard numpyBasicToList [] [] == some [[]] #guard numpyBasicToList [1] [.int 0] == some [[0]] #guard numpyBasicToList [2] [.int 0] == some [[0]] #guard numpyBasicToList [2] [.int 1] == some [[1]] -#guard (numpyBasicToList [2] [.int 2]) == none -#guard (numpyBasicToList [2] [.int (-1)]) == some [[1]] -#guard (numpyBasicToList [2] [.int (-3)]) == none -#guard (numpyBasicToList [4] [.slice Slice.all]) == some [[0], [1], [2], [3]] -#guard (numpyBasicToList [4] [.slice $ Slice.build! .none .none (.some 2)]) == some [[0], [2]] -#guard (numpyBasicToList [4] [.slice $ Slice.build! (.some (-1)) .none (.some (-2))]) == some [[3], [1]] -#guard (numpyBasicToList [2, 2] [.int 5]) == none -#guard (numpyBasicToList [2, 2] [.int 0]) == some [[0, 0], [0, 1]] -#guard (numpyBasicToList [2, 2] [.int 0, .int 0]) == some [[0, 0]] -#guard (numpyBasicToList [2, 2] [.int 0, .int 1]) == some [[0, 1]] -#guard (numpyBasicToList [2, 2] [.int 0, .int 2]) == none -#guard (numpyBasicToList [3, 3] [.slice Slice.all, .int 2]) == some [[0, 2], [1, 2], [2, 2]] -#guard (numpyBasicToList [3, 3] [.int 2, .slice Slice.all]) == some [[2, 0], [2, 1], [2, 2]] -#guard (numpyBasicToList [2, 2] [.slice Slice.all, .slice Slice.all]) == some [[0, 0], [0, 1], [1, 0], [1, 1]] -#guard (numpyBasicToList [2, 2] [.slice (Slice.build! .none .none (.some (-1))), .slice Slice.all]) == some [[1, 0], [1, 1], [0, 0], [0, 1]] -#guard (numpyBasicToList [4, 2] [.slice (Slice.build! .none .none (.some (-2))), .slice Slice.all]) == some [[3, 0], [3, 1], [1, 0], [1, 1]] +#guard numpyBasicToList [2] [.int 2] == none +#guard numpyBasicToList [2] [.int (-1)] == some [[1]] +#guard numpyBasicToList [2] [.int (-3)] == none +#guard numpyBasicToList [4] [.slice Slice.all] == some [[0], [1], [2], [3]] +#guard numpyBasicToList [4] [.slice $ Slice.build! .none .none (.some 2)] == some [[0], [2]] +#guard numpyBasicToList [4] [.slice $ Slice.build! (.some (-1)) .none (.some (-2))] == some [[3], [1]] +#guard numpyBasicToList [2, 2] [.int 5] == none +#guard numpyBasicToList [2, 2] [.int 0] == some [[0, 0], [0, 1]] +#guard numpyBasicToList [2, 2] [.int 0, .int 0] == some [[0, 0]] +#guard numpyBasicToList [2, 2] [.int 0, .int 1] == some [[0, 1]] +#guard numpyBasicToList [2, 2] [.int 0, .int 2] == none +#guard numpyBasicToList [3, 3] [.slice Slice.all, .int 2] == some [[0, 2], [1, 2], [2, 2]] +#guard numpyBasicToList [3, 3] [.int 2, .slice Slice.all] == some [[2, 0], [2, 1], [2, 2]] +#guard numpyBasicToList [2, 2] [.slice Slice.all, .slice Slice.all] == some [[0, 0], [0, 1], [1, 0], [1, 1]] +#guard numpyBasicToList [2, 2] [.slice (Slice.build! .none .none (.some (-1))), .slice Slice.all] == some [[1, 0], [1, 1], [0, 0], [0, 1]] +#guard numpyBasicToList [4, 2] [.slice (Slice.build! .none .none (.some (-2))), .slice Slice.all] == some [[3, 0], [3, 1], [1, 0], [1, 1]] -- Commented for easier debugging. Remove some day -- #eval do @@ -380,9 +475,7 @@ private def numpyBasicToList (dims : List Nat) (basic : NumpyBasic) : Option (Li -- -- let (ns8, iter8) <- iter7.next -- -- let (ns9, iter9) <- iter8.next -- return (basic, iter0, ns0, iter1, ns1, iter2, ns2, iter3) -- , ns4, iter4) -- , ns5, iter5, ns6, iter6, ns7, iter7, ns8, iter8, ns9, iter9) - end Test -end BasicIter end Index end TensorLib diff --git a/TensorLib/Tensor.lean b/TensorLib/Tensor.lean index 48b18ba..93f615a 100644 --- a/TensorLib/Tensor.lean +++ b/TensorLib/Tensor.lean @@ -108,6 +108,8 @@ def ones (dtype : Dtype) (shape : Shape) : Tensor := Id.run do data := data.push byte { dtype := dtype, shape := shape, data := data } +def byteOrder (arr : Tensor) : ByteOrder := arr.dtype.order + --! number of dimensions def ndim (x : Tensor) : Nat := x.shape.ndim @@ -126,11 +128,35 @@ def dimIndexToOffset (x : Tensor) (i : DimIndex) : Int := --! Get the starting byte corresponding to a DimIndex def dimIndexToPosition (x : Tensor) (i : DimIndex) : Nat := - (x.dimIndexToOffset i).toNat + (x.startIndex + (x.dimIndexToOffset i)).toNat --! number of bytes representing the entire tensor def nbytes (x : Tensor) : Nat := x.itemsize * x.size +def isIntLike (x : Tensor) : Bool := x.dtype.isIntLike + +def dimIndexInRange (arr : Tensor) (dimIndex : DimIndex) : Bool := arr.shape.dimIndexInRange dimIndex + +def byteArrayAtDimIndex (arr : Tensor) (dimIndex : DimIndex) : Err ByteArray := do + if !arr.dimIndexInRange dimIndex then .error "index is incompatible with tensor shape" else + let posn := arr.dimIndexToPosition dimIndex + .ok $ arr.data.extract posn (posn + arr.itemsize) + +def setByteArrayAtDimIndex (arr : Tensor) (dimIndex : DimIndex) (bytes : ByteArray) : Err Tensor := do + if !arr.dimIndexInRange dimIndex then .error "index is incompatible with tensor shape" else + if arr.itemsize != bytes.size then .error "byte size mismatch" else + let posn := arr.dimIndexToPosition dimIndex + .ok $ { arr with data := bytes.copySlice 0 arr.data posn bytes.size } + +/-! +Return the integer at the dimIndex. This is useful, for example, in advanced indexing +where we must have an int/uint Tensor as an argument. +-/ +def intAtDimIndex (arr : Tensor) (dimIndex : DimIndex) : Err Int := do + if !arr.isIntLike then .error "natAt expects an int tensor" else + let bytes <- byteArrayAtDimIndex arr dimIndex + .ok $ arr.byteOrder.bytesToInt bytes + /-! Copy a Tensor's data to new, contiguous storage. @@ -267,6 +293,16 @@ def setPosition [typ : Element a] (x : Tensor) (n : Nat) (v : a): Err Tensor := let posn := n * itemsize .ok { x with data := bytes.copySlice 0 x.data posn itemsize true } +def ofList (typ : Element a) (xs : List a) : Tensor := Id.run do + let arr := Tensor.zeros typ.dtype (Shape.mk [xs.length]) + let mut data := arr.data + let mut posn := 0 + for x in xs do + let v := typ.toByteArray x + data := v.copySlice 0 data posn typ.itemsize + posn := posn + arr.itemsize + { arr with data := data } + -- Since the DimIndex is independent of the dtype size, we need to recompute the strides -- TODO: Would be better to not recompute this over and over. We should find a place to store -- the 1-based default strides @@ -298,6 +334,15 @@ instance BV8Native : Element BV8 where toByteArray (x : BV8) : ByteArray := x.toByteArray fromByteArray arr startIndex := ByteArray.toBV8 arr startIndex +instance Int8Native : Element Int8 where + dtype := Dtype.mk .int8 .oneByte + itemsize := 1 + ofNat n := n.toInt8 + toByteArray (x : Int8) : ByteArray := [x.toUInt8].toByteArray + fromByteArray arr startIndex := (ByteArray.toBV8 arr startIndex).map fun b => Int8.mk b.toUInt8 + +#guard Int8Native.fromByteArray (Int8Native.toByteArray (-5)) 0 == .ok (-5) + instance BV16Little : Element BV16 where dtype := Dtype.mk .uint16 .littleEndian itemsize := 2 @@ -457,6 +502,9 @@ private def arr1 := Element.arange BV8 12 #guard (ones (Dtype.float64) $ Shape.mk [2, 2]).nbytes == 2 * 2 * 8 #guard (ones (Dtype.float64) $ Shape.mk [2, 2]).data.toList.count 1 == 2 * 2 +#guard get! ((Element.ofList Element.BV8Native [1, 2, 3]).toTree BV8) == Format.Tree.root [1, 2, 3] +#guard get! (((Element.ofList Element.BV8Native [0, 1, 2, 3, 4, 5]).reshape! (Shape.mk [2, 3])).toTree BV8) == .node [.root [0, 1, 2], .root [3, 4, 5]] + end Test end Tensor From 621d94a70dc99d6815ec7ffb2c8f85ef1db8e465 Mon Sep 17 00:00:00 2001 From: Sean McLaughlin Date: Tue, 7 Jan 2025 10:03:04 -0800 Subject: [PATCH 2/4] Update to Lean 4.15.0 This will help sync with NKL --- lake-manifest.json | 8 ++++---- lakefile.lean | 8 +------- lean-toolchain | 2 +- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/lake-manifest.json b/lake-manifest.json index 836242b..6106804 100644 --- a/lake-manifest.json +++ b/lake-manifest.json @@ -15,20 +15,20 @@ "type": "git", "subDir": null, "scope": "", - "rev": "5a0ec8588855265ade536f35bcdcf0fb24fd6030", + "rev": "79402ad9ab4be9a2286701a9880697e2351e4955", "name": "aesop", "manifestFile": "lake-manifest.json", - "inputRev": "v4.14.0", + "inputRev": "master", "inherited": false, "configFile": "lakefile.toml"}, {"url": "https://github.com/leanprover-community/batteries", "type": "git", "subDir": null, "scope": "", - "rev": "8d6c853f11a5172efa0e96b9f2be1a83d861cdd9", + "rev": "8ce422eb59adf557fac184f8b1678c75fa03075c", "name": "batteries", "manifestFile": "lake-manifest.json", - "inputRev": "v4.14.0", + "inputRev": "v4.16.0-rc1", "inherited": true, "configFile": "lakefile.toml"}], "name": "TensorLib", diff --git a/lakefile.lean b/lakefile.lean index babbe24..913da0f 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -11,14 +11,8 @@ lean_lib «TensorLib» where lean_exe "tensorlib" where root := `Main --- require aesop from git --- "https://github.com/leanprover-community/aesop" @ "v4.15.0-rc1" - --- require Cli from git --- "https://github.com/seanmcl/lean4-cli.git" @ "v4.15.0-rc1" - require aesop from git - "https://github.com/leanprover-community/aesop" @ "v4.14.0" + "https://github.com/leanprover-community/aesop" @ "master" -- 'master' rather than a tag is a workaround for segfault bug https://github.com/leanprover/lean4/issues/6518#issuecomment-2574607960 require Cli from git "https://github.com/leanprover/lean4-cli.git" @ "v2.2.0-lv4.14.0-rc1" diff --git a/lean-toolchain b/lean-toolchain index 1e70935..d0eb99f 100644 --- a/lean-toolchain +++ b/lean-toolchain @@ -1 +1 @@ -leanprover/lean4:v4.14.0 +leanprover/lean4:v4.15.0 From 5970b37e2641f1650b6bb1443ede3a8e93ae2084 Mon Sep 17 00:00:00 2001 From: Sean McLaughlin Date: Thu, 2 Jan 2025 14:16:10 -0800 Subject: [PATCH 3/4] Mgrid https://numpy.org/doc/2.1/reference/generated/numpy.mgrid.html --- TensorLib/Basic.lean | 1 + TensorLib/Mgrid.lean | 135 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 TensorLib/Mgrid.lean diff --git a/TensorLib/Basic.lean b/TensorLib/Basic.lean index e084e7c..8dcf324 100644 --- a/TensorLib/Basic.lean +++ b/TensorLib/Basic.lean @@ -11,3 +11,4 @@ import TensorLib.Broadcast import TensorLib.Slice import TensorLib.Tensor import TensorLib.Index +import TensorLib.Mgrid diff --git a/TensorLib/Mgrid.lean b/TensorLib/Mgrid.lean new file mode 100644 index 0000000..6045520 --- /dev/null +++ b/TensorLib/Mgrid.lean @@ -0,0 +1,135 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin +-/ + +import TensorLib.Common +import TensorLib.Tensor +import TensorLib.Slice +import TensorLib.Index + +namespace TensorLib + +/- +Shape of the resulting tensor is |slices| :: slices.map(fun x => x.shape) + +Slices are treated differently in NumPy when used in mgrid vs indexing. In particular, +since we don't know what the stopping point should be from the context, it can use the `start` +value as a `stop` instead. For example, + + # np.arange(10)[5:] + [5, 6, 7, 8, 9] + + # np.mgrid[5:] + [0, 1, 2, 3, 4] + +However, things work normally when there is a stopping point + + # np.mgrid[5:10] + [5, 6, 7, 8, 9] + +Since this is imo surprising, we just fail if the start or stopping point are absent. +Moreover, we require more than one slice, since `mgrid[slice] == arange(slice)`, and a single +slice behaves differently from all other quantities. + +All mgrid ints are stored as 64-bit, little-endian ints by convention. In Numpy they are stored with +native byte order, as the architecture word size. + +For eample, + +# np.mgrid[2:4:, 4:7:] +array([[[2, 2, 2], + [3, 3, 3]], + + [[4, 5, 6], + [4, 5, 6]]]) + +We have an iterator that will give us the indices [0, 0], [0, 1], [0, 2], [1, 0], ..., [1, 2] +and an interator that will give us the values [2, 4], [2, 5], [2, 6], ... , [3, 6] +These are the same length, so we can combine them to get the mgrid; +- [0, 0] of the first element gits 2, +- [0, 0] of the second element gets 4 +- [0, 1] of the first gets 2 +- [0, 1] of the second gets 5 +- ... +-/ +def mgrid (slices : List Slice) : Err Tensor := do + let sliceCount := slices.length + if sliceCount < 2 then .error "mgrid requires at least two slices" + let arbitrary : Nat := 10 -- Slice.size does not use the second argument if both start and stop are specified + let sliceSize (slice : Slice) : Nat := slice.size arbitrary + let mut slicesDims := [] + for slice in slices.reverse do + match slice.start, slice.stop with + | .none, _ => .error "Slices need an upper bound in mgrid" + | _, .none => .error "Slices need a lower bound in mgrid" + | _, _ => + let sz := sliceSize slice + slicesDims := sz :: slicesDims + let shape := Shape.mk $ sliceCount :: slicesDims + let slicesShape := Shape.mk slicesDims + let dtype := Dtype.uint64 + let mut arr := Tensor.zeros dtype shape + let basic := slices.map fun s => .slice (Slice.Iter.make s arbitrary) + let mut sliceIter <- Index.BasicIter.make slicesShape basic + let indexIter := DimsIter.make slicesShape + if sliceIter.size != indexIter.size then .error "Invariant failure: iterator size mismatch at start" + for index in indexIter do + match sliceIter.next with + | .none => .error "Invariant failure: iterator size mismatch during iteration" + | .some (values, sliceIter') => + sliceIter := sliceIter' + if values.length != sliceCount then .error "Invariant failure: value length mismatch" + for (i, v) in (List.range sliceCount).zip values do + let value := BV64.ofNat v + arr <- Tensor.Element.setDimIndex arr (i :: index) value + return arr + +section Test + +open TensorLib.Tensor.Format +open Tree + +#guard (get! (mgrid [Slice.ofStartStop 2 4, Slice.ofStartStop 4 7])).toTree BV64 == .ok ( + .node [ + .node [ + .root [2, 2, 2], .root [3, 3, 3] + ], + .node [ + .root [4, 5, 6], .root [4, 5, 6] + ] + ] +) + +#guard (get! (mgrid [Slice.ofStartStop 2 4, Slice.ofStartStop 4 7, Slice.ofStop 2])).toTree BV64 == .ok ( + .node [ + .node [ + .node [ + .root [2, 2], .root [2, 2], .root [2, 2] + ], + .node [ + .root [3, 3], .root [3, 3], .root [3, 3] + ], + ], + .node [ + .node [ + .root [4, 4], .root [5, 5], .root [6, 6] + ], + .node [ + .root [4, 4], .root [5, 5], .root [6, 6] + ] + ], + .node [ + .node [ + .root [0, 1], .root [0, 1], .root [0, 1] + ], + .node [ + .root [0, 1], .root [0, 1], .root [0, 1] + ] + ] + ] +) + +end Test +end TensorLib From e3bb7206e9f2bc4a78ba2f1389ae0d9718935c57 Mon Sep 17 00:00:00 2001 From: Sean McLaughlin Date: Thu, 2 Jan 2025 14:45:00 -0800 Subject: [PATCH 4/4] Universal functions --- TensorLib/Basic.lean | 1 + TensorLib/Ufunc.lean | 180 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 TensorLib/Ufunc.lean diff --git a/TensorLib/Basic.lean b/TensorLib/Basic.lean index 8dcf324..d813c0e 100644 --- a/TensorLib/Basic.lean +++ b/TensorLib/Basic.lean @@ -12,3 +12,4 @@ import TensorLib.Slice import TensorLib.Tensor import TensorLib.Index import TensorLib.Mgrid +import TensorLib.Ufunc diff --git a/TensorLib/Ufunc.lean b/TensorLib/Ufunc.lean new file mode 100644 index 0000000..a6df760 --- /dev/null +++ b/TensorLib/Ufunc.lean @@ -0,0 +1,180 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin +-/ + +import TensorLib.Common +import TensorLib.Tensor +import TensorLib.Broadcast + +/-! +Universal functions: https://numpy.org/doc/stable/reference/ufuncs.html +-/ + +namespace TensorLib +namespace Tensor +namespace Ufunc + +private def binop (a : Type) [Element a] (x y : Tensor) (op : a -> a -> Err a) : Err Tensor := + match Broadcast.broadcast { left := x.shape, right := y.shape } with + | .none => .error s!"Can't broadcast shapes ${x.shape} with {y.shape}" + | .some shape => + if x.dtype != y.dtype then .error s!"Casting between dtypes is not implemented yet: {repr x.dtype} <> {repr y.dtype}" else + do + let mut arr := Tensor.empty x.dtype shape + let iter := DimsIter.make shape + for idx in iter do + let v <- Element.getDimIndex x idx + let w <- Element.getDimIndex y idx + let k <- op v w + let arr' <- Element.setDimIndex arr idx k + arr := arr' + .ok arr + +def add (a : Type) [Add a] [Element a] (x y : Tensor) : Err Tensor := + binop a x y (fun x y => .ok (x + y)) + +def sub (a : Type) [Sub a] [Element a] (x y : Tensor) : Err Tensor := + binop a x y (fun x y => .ok (x - y)) + +def mul (a : Type) [Mul a] [Element a] (x y : Tensor) : Err Tensor := + binop a x y (fun x y => .ok (x * y)) + +def div (a : Type) [Div a] [Element a] (x y : Tensor) : Err Tensor := + binop a x y (fun x y => .ok (x / y)) + +/- +TODO: +- np.sum. Prove that np.sum(x, axis=(2, 4, 6)) == np.sum(np.sum(np.sum(x, axis=6), axis=4), axis=2) # and other variations +-/ + +-- Sum with no axis. Adds all the elements. +private def sum0 (a : Type) [Add a] [Zero a] [Element a] (arr : Tensor) : Err Tensor := do + let mut acc : a := 0 + let mut iter := DimsIter.make arr.shape + for index in iter do + let n : a <- Element.getDimIndex arr index + acc := Add.add acc n + return Element.arrayScalar acc + +-- Sum with a single axis. +def sum1 (a : Type) [Add a] [Zero a] [Element a] (arr : Tensor) (axis : Nat) : Err Tensor := do + if arr.ndim <= axis then .error "axis out of range" else + let oldshape := arr.shape + let (leftShape, rightShape) := oldshape.val.splitAt axis + match rightShape with + | [] => .error "Invariant failure" + | dim :: dims => + let newshape := Shape.mk $ leftShape ++ dims + let mut res := Tensor.zeros arr.dtype newshape + let mut iter := DimsIter.make newshape + for index in iter do + let mut acc : a := 0 + for i in [0:dim] do + let index' := index.insertIdx axis i + let v : a <- Element.getDimIndex arr index' + acc := acc + v + res <- Element.setDimIndex res index acc + return res + +private def uniq [BEq a] (xs : List a) : Bool := match xs with +| [] | [_] => true +| x1 :: x2 :: xs => x1 != x2 && uniq (x2 :: xs) + +def sum (a : Type) [Add a] [Zero a] [Element a] (arr : Tensor) (axes : Option (List Nat)) : Err Tensor := + match axes with + | .none => sum0 a arr + | .some axes => + if !(uniq axes) then .error "Duplicate axis elements" else + let axes := (List.mergeSort axes).reverse + match axes with + | [] => sum0 a arr + | axis :: axes => do + let mut res <- sum1 a arr axis + let rec loop (axes : List Nat) (acc : Tensor) : Err Tensor := match axes with + | [] => .ok acc + | axis :: axes => do + let acc <- sum1 a acc axis + let axes := axes.map fun n => n-1 -- When we remove an axis, all later axes point to one dimension less + loop axes acc + termination_by axes.length + loop axes res + +private def hasTree0 (a : Type) [BEq a] [Element a] (arr : Tensor) (n : a) : Bool := + arr.shape.val == [] && match Element.getPosition arr 0 with + | .error _ => false + | .ok (v : a) => v == n + +private def hasTree1 (a : Type) [Repr a] [BEq a] [Element a] (arr : Tensor) (xs : List a) : Bool := + arr.shape.val == [xs.length] && match arr.toTree a with + | .error _ => false + | .ok v => v == .root xs + +-- [[0, 1, 2, 3, 4], +-- [5, 6, 7, 8, 9]] +#guard + let typ := BV8 + let arr := get! $ (Element.arange typ 10).reshape (Shape.mk [2, 5]) + let x0 := get! $ sum typ arr .none + let x1 := get! $ sum typ arr (.some []) + let x2 := get! $ sum typ arr (.some [0]) + let x3 := get! $ sum typ arr (.some [1]) + let x4 := get! $ sum typ arr (.some [1, 0]) + let x5 := get! $ sum typ arr (.some [0, 1]) + let res := + hasTree0 typ x0 45 && + hasTree0 typ x1 45 && + hasTree1 typ x2 [5, 7, 9, 11, 13] && + hasTree1 typ x3 [10, 35] && + hasTree0 typ x4 45 && + hasTree0 typ x5 45 + res + +#guard + let typ := BV8 + let x := Element.arange typ 10 + let arr := get! $ add typ x x + hasTree1 typ arr [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + +#guard + let typ := BV8 + let x := Element.arange typ 10 + let y := Element.arrayScalar (7 : typ) + let arr := get! $ add typ x y + hasTree1 typ arr [7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + +/-! WIP example NKI kernel +""" +NKI kernel to compute element-wise addition of two input tensors + +This kernel assumes strict input/output tile-sizes, of up-to [128,512] + +Args: + a_input: a first input tensor, of shape [128,512] + b_input: a second input tensor, of shape [128,512] + c_output: an output tensor, of shape [128,512] +""" +private def nki_tensor_add_kernel_ (program_id0 program_id1 : Nat) (a_input b_input c_input : NumpyRepr) : Err Unit := do + let tp := BV64 + + -- Calculate tile offsets based on current 'program' + let offset_i_x : tp := program_id0 * 128 + let offset_i_y : tp := program_id1 * 512 + -- Generate tensor indices to index tensors a and b + let rx0 := Element.arange tp 128 + let rx <- rx0.reshape [128, 1] + let ox := Element.arrayScalar offset_i_x + let ix <- Ufunc.add tp ox rx + let ry0 := Element.arange tp 128 + let ry <- ry0.reshape [1, 512] + let oy := Element.arrayScalar offset_i_y + let iy <- Ufunc.add tp oy ry + let a_tile <- sorry -- load from a_input + let b_tile <- sorry -- load from b_input + let c_tile <- Ufunc.add tp a_tile b_tile + let () <- sorry -- store to c_input + .ok () +-/ + +end Ufunc