From 08860fdbeb7f03c4fceb22e0bb64161b37418306 Mon Sep 17 00:00:00 2001 From: Sean McLaughlin Date: Mon, 6 Jan 2025 13:56:44 -0800 Subject: [PATCH] Advanced indexing We implement a common version of advanced indexing, where all arguments are (broadcastable to) int arrays. --- TensorLib/Broadcast.lean | 21 ++++++ TensorLib/Common.lean | 67 +++++++++++++++++-- TensorLib/Dtype.lean | 14 ++++ TensorLib/Index.lean | 141 ++++++++++++++++++++++++++++++++------- TensorLib/Tensor.lean | 50 +++++++++++++- 5 files changed, 262 insertions(+), 31 deletions(-) diff --git a/TensorLib/Broadcast.lean b/TensorLib/Broadcast.lean index c6583b2..33c3d92 100644 --- a/TensorLib/Broadcast.lean +++ b/TensorLib/Broadcast.lean @@ -39,6 +39,8 @@ Rule 2 A: (3, 2, 7) B: (3, 2, 7) + +Theorem to prove: If we can broadcast s1 to s2, then given an array with shape s1, then s1.reshape s2 succeeds -/ namespace TensorLib @@ -103,4 +105,23 @@ def canBroadcast (b : Broadcast) : Bool := (broadcast b).isSome broadcast b2 == broadcast b1 && broadcast b2 == .some (Shape.mk [1, 2, 3]) +def broadcastList (shapes : List Shape) : Option Shape := Id.run do + match shapes with + | [] => none + | shape :: shapes => + let mut shape := shape + for s in shapes do + let b := Broadcast.mk shape s + match b.broadcast with + | .none => return .none + | .some s => + shape := s + return shape + +#guard + let x1 := Shape.mk [1, 2, 3] + let x2 := Shape.mk [2, 3] + let x3 := Shape.mk [] + broadcastList [x1, x2, x3] == .some x1 + end Broadcast diff --git a/TensorLib/Common.lean b/TensorLib/Common.lean index dd6f57e..a454870 100644 --- a/TensorLib/Common.lean +++ b/TensorLib/Common.lean @@ -5,6 +5,7 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin -/ import Std.Tactic.BVDecide +import Batteries.Data.List namespace TensorLib @@ -27,6 +28,7 @@ def natDivCeil (num denom : Nat) : Nat := (num + denom - 1) / denom def natProd (shape : List Nat) : Nat := shape.foldl (fun x y => x * y) 1 + -- We generally have large tensors, so don't show them by default instance ByteArrayRepr : Repr ByteArray where reprPrec x _ := @@ -45,11 +47,41 @@ inductive ByteOrder where | bigEndian deriving BEq, Repr, Inhabited +namespace ByteOrder + @[simp] -def ByteOrder.isMultiByte (x : ByteOrder) : Bool := match x with +def isMultiByte (x : ByteOrder) : Bool := match x with | .oneByte => false | .littleEndian | .bigEndian => true +def bytesToInt (order : ByteOrder) (bytes : ByteArray) : Int := Id.run do + let mut n : Nat := 0 + let nbytes := bytes.size + let signByte := match order with + | .littleEndian => bytes.get! (nbytes - 1) + | .bigEndian | oneByte => bytes.get! 0 + let negative := 128 <= signByte + for i in [0:nbytes] do + let v : UInt8 := bytes.get! i + let v := if negative then UInt8.complement v else v + let p := match order with + | .oneByte => 0 -- nbytes = 1 + | .littleEndian => i + | .bigEndian => nbytes - 1 - i + n := n + Pow.pow 2 (8 * p) * v.toNat + return if 128 <= signByte then -(n + 1) else n + +#guard bytesToInt .littleEndian (ByteArray.mk #[1, 1]) == 257 +#guard bytesToInt .bigEndian (ByteArray.mk #[1, 1]) == 257 +#guard bytesToInt .littleEndian (ByteArray.mk #[0, 1]) == 256 +#guard bytesToInt .bigEndian (ByteArray.mk #[0, 1]) == 1 +#guard bytesToInt .littleEndian (ByteArray.mk #[0xFF, 0xFF]) == -1 +#guard bytesToInt .bigEndian (ByteArray.mk #[0xFF, 0xFF]) == -1 +#guard bytesToInt .bigEndian (ByteArray.mk #[0x80, 0]) == -32768 +#guard bytesToInt .littleEndian (ByteArray.mk #[0x80, 0]) == 0x80 + +end ByteOrder + /-! The strides are how many bytes you need to skip to get to the next element in that "row". For example, in an array of 8-byte data with shape 2, 3, the strides are (24, 8). @@ -144,6 +176,9 @@ deriving BEq, Repr, Inhabited namespace Shape +instance : ToString Shape where + toString := reprStr + def empty : Shape := Shape.mk [] --! The number of elements in a tensor. All that's needed is the shape for this calculation. @@ -155,6 +190,10 @@ def ndim (shape : Shape) : Nat := shape.val.length def map (shape : Shape) (f : List Nat -> List Nat) : Shape := Shape.mk (f shape.val) +def dimIndexInRange (shape : Shape) (dimIndex : DimIndex) := + shape.ndim == dimIndex.length && + (shape.val.zip dimIndex).all fun (n, i) => i < n + /-! Strides can be computed from the shape by figuring out how many elements you need to jump over to get to the next spot and mulitplying by the bytes in each @@ -205,6 +244,7 @@ def positionToDimIndex (strides : Strides) (n : Position) : DimIndex := let (_, idx) := strides.foldl foldFn (n, []) idx.reverse +-- TODO: Return `Err Offset` for when the strides and index have different lengths? def dimIndexToOffset (strides : Strides) (index : DimIndex) : Offset := dot strides (index.map Int.ofNat) #guard positionToDimIndex [3, 1] 4 == [1, 1] @@ -223,6 +263,21 @@ def allDimIndices (shape : Shape) : List DimIndex := Id.run do #guard allDimIndices (Shape.mk [5]) == [[0], [1], [2], [3], [4]] #guard allDimIndices (Shape.mk [3, 2]) == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]] +-- NumPy supports negative indices, which simply wrap around. E.g. `x[.., -1, ..] = x[.., n-1, ..]` where `n` is the +-- dimension in question. It only supports `-n` to `n`. +def intIndexToDimIndex (shape : Shape) (index : List Int) : Err DimIndex := do + if shape.ndim != index.length then .error "intsToDimIndex length mismatch" else + let conv (dim : Nat) (ind : Int) : Err Nat := + if 0 <= ind then + if ind < dim then .ok ind.toNat + else .error "index out of bounds" + else if ind < -dim then .error "index out of bounds" + else .ok (dim + ind).toNat + (shape.val.zip index).traverse (fun (dim, ind) => conv dim ind) + +#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, -1, -1] == (.ok [0, 1, 2]) +#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, 1, -2] == (.ok [0, 1, 1]) + end Shape /- @@ -278,7 +333,7 @@ Note: I tried writing this as a `do/for` loop and in this case the recursive one seems nicer. We are walking over two lists simultaneously, which is easy here but with a for loop is either quadratic or awkward. -/ -def next (iter : DimsIter) : List Nat × DimsIter := +def next (iter : DimsIter) : DimIndex × DimsIter := -- Invariant: `acc` is a list of 0s, so doesn't need to be reversed let rec loop (acc ms ns : List Nat) : List Nat := match ms, ns with @@ -293,7 +348,7 @@ def next (iter : DimsIter) : List Nat × DimsIter := let curr' := loop [] iter.dims iter.curr (iter.curr.reverse, { iter with curr := curr' }) -instance [Monad m] : ForIn m DimsIter (List Nat) where +instance [Monad m] : ForIn m DimsIter DimIndex where forIn {α} [Monad m] (iter : DimsIter) (x : α) (f : List Nat -> α -> m (ForInStep α)) : m α := do let mut iter := iter let mut res := x @@ -305,7 +360,7 @@ instance [Monad m] : ForIn m DimsIter (List Nat) where | .done k => return k return res -private def toList (iter : DimsIter) : List (List Nat) := Id.run do +private def toList (iter : DimsIter) : List DimIndex := Id.run do let mut res := [] for xs in iter do res := xs :: res @@ -321,7 +376,7 @@ private def toList (iter : DimsIter) : List (List Nat) := Id.run do #guard (DimsIter.make $ Shape.mk [1, 1, 2]).toList == [[0, 0, 0], [0, 0, 1]] #guard (DimsIter.make $ Shape.mk [3, 2]).toList == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]] -private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do +private def testBreak (iter : DimsIter) : List DimIndex := Id.run do let mut res := [] for xs in iter do res := xs :: res @@ -330,7 +385,7 @@ private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do #guard (DimsIter.make $ Shape.mk [3, 2]).testBreak == [[0, 0]] -private def testReturn (iter : DimsIter) : List (List Nat) := Id.run do +private def testReturn (iter : DimsIter) : List DimIndex := Id.run do let mut res := [] let mut i := 0 for xs in iter do diff --git a/TensorLib/Dtype.lean b/TensorLib/Dtype.lean index 784c0b2..f53406c 100644 --- a/TensorLib/Dtype.lean +++ b/TensorLib/Dtype.lean @@ -34,6 +34,16 @@ def isMultiByte (x : Name) : Bool := match x with | bool | int8 | uint8 => false | _ => true +def isInt (x : Name) : Bool := match x with +| int8 | int16 | int32 | int64 => true +| _ => false + +def isUint (x : Name) : Bool := match x with +| uint8 | uint16 | uint32 | uint64 => true +| _ => false + +def isIntLike (x : Name) : Bool := x.isInt || x.isUint + --! Number of bytes used by each element of the given dtype def itemsize (x : Name) : Nat := match x with | float64 | int64 | uint64 => 8 @@ -57,6 +67,10 @@ def itemsize (t : Dtype) := t.name.itemsize def sizedStrides (dtype : Dtype) (s : Shape) : Strides := List.map (fun x => x * dtype.itemsize) s.unitStrides +def isInt (dtype : Dtype) : Bool := dtype.name.isInt +def isUint (dtype : Dtype) : Bool := dtype.name.isUint +def isIntLike (dtype : Dtype) : Bool := dtype.isInt || dtype.isUint + def int8 : Dtype := Dtype.mk Dtype.Name.int8 ByteOrder.littleEndian def uint8 : Dtype := Dtype.mk Dtype.Name.uint8 ByteOrder.littleEndian def uint64 : Dtype := Dtype.mk Dtype.Name.uint64 ByteOrder.littleEndian diff --git a/TensorLib/Index.lean b/TensorLib/Index.lean index 51a6708..9f8b459 100644 --- a/TensorLib/Index.lean +++ b/TensorLib/Index.lean @@ -4,12 +4,19 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin -/ +import TensorLib.Broadcast import TensorLib.Common -import TensorLib.Tensor -import TensorLib.Slice import TensorLib.Npy +import TensorLib.Slice +import TensorLib.Tensor /- +There are several types of indexing in NumPy. + + https://numpy.org/doc/stable/user/basics.indexing.html + +We handle basic indexing and some types of advanced indexing. + Theorems to prove (taken from NumPy docs): 1. Basic slicing with more than one non-: entry in the slicing tuple, @@ -18,7 +25,7 @@ Theorems to prove (taken from NumPy docs): Thus, x[ind1, ..., ind2,:] acts like x[ind1][..., ind2, :] under basic slicing. 2. Advanced indices always are broadcast and iterated as one: -result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M], + result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M], ..., ind_N[i_1, ..., i_M]] 3. ...TODO... -/ @@ -251,6 +258,8 @@ private def testReturn (iter : BasicIter) : List (List Nat) := Id.run do let slice := Slice.Iter.make Slice.all 5 testReturn (get! $ BasicIter.make shape [.slice slice, .slice slice]) == [[0, 0], [0, 1], [0, 2]] +end BasicIter + def applyWithCopy (index : NumpyBasic) (arr : Tensor) : Err Tensor := do let itemsize := arr.itemsize let oldShape := arr.shape @@ -301,6 +310,94 @@ def apply (index : NumpyBasic) (arr : Tensor) : Err (Tensor × Bool) := do } return (res, false) +/- +For advanced indexing, the all-multidimensional-array case is relatively easy; +broadcast all arguments to the same shape, then select the elements of the original +array one by one. For example + +# x = np.arange(6).reshape(2, 3) +# x +array([[0, 1, 2], + [3, 4, 5]]) +# i0 = np.array([1, 0])[:, None] +# ii = np.array([1, 2, 0])[None, :] +# x[i0, i1] +array([[4, 5, 3], + [1, 2, 0]]) + +To obtain the later result, we simply walk through the [2, 3]-shaped indices +[[x[1, 1], x[1, 2], x[1, 0]], + [x[0, 1], x[0, 2], x[0, 0]], + +This also works when the dims of the index is smaller than the dims +of the array. Each x[i, j] is just an array instead of a scalar. We do not currently +implement that, but if we need it it will be clear what to do; we just copy the (contiguous) +bytes of the sub-array. + +Mixing basic and advanced indexing is complex: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing +While I can follow the individual examples, the general case is some work. +As a simple example of mixing, they give `x[..., ind, :]` where +`x.shape` is `(10, 20, 30)` and `ind` is a `(2, 5, 2)`-shaped indexing array. +The result has shape `(10, 2, 5, 2, 30)` and `result[..., i, j, k, :] = x[..., ind[i, j, k], :]`. +While this example is understandable, things get more complex, and I've not yet seen +examples we currently want to support that uses them. Therefore, we do not currently implement +mixed basic/advanced indexing. +-/ + +namespace Advanced + +def apply (indexTensors : List Tensor) (arr : Tensor) : Err Tensor := do + if indexTensors.any fun arr => !arr.isIntLike then .error "Index arrays must have an int-like type" else + if indexTensors.length != arr.ndim then .error "advanced indexing length mismatch" + -- Reshape all the input tesnsors + match Broadcast.broadcastList (indexTensors.map fun arr => arr.shape) with + | none => .error "input shapes must be broadcastable" + | some outShape => + let mut reshapedIndexTensors := [] + for indexTensor in indexTensors do + let indexTensor <- indexTensor.reshape outShape -- should never fail since broadcastList succeeds + reshapedIndexTensors := indexTensor :: reshapedIndexTensors + reshapedIndexTensors := reshapedIndexTensors.reverse + let mut res := Tensor.zeros arr.dtype outShape + -- Now we will iterate over the output shape, computing the values one-by-one from the input array + for outDimIndex in DimsIter.make outShape do + -- Get the index for each dimension in the original array from the corresponding value of the index tensors + let mut inIntIndex : List Int := [] + for indexTensor in reshapedIndexTensors do + let v <- indexTensor.intAtDimIndex outDimIndex + inIntIndex := v :: inIntIndex + let inDimIndex <- arr.shape.intIndexToDimIndex inIntIndex.reverse + let bytes <- arr.byteArrayAtDimIndex inDimIndex + res <- res.setByteArrayAtDimIndex outDimIndex bytes + return res + +/- +0 1 2 +3 4 5 +6 7 8 + +9 10 11 +12 12 14 +15 16 17 + +18 19 20 +21 22 23 +24 25 26 +-/ +#guard + let ind0 := (Tensor.Element.ofList Tensor.Element.Int8Native [1, 2, 0, 0]).reshape! (Shape.mk [2, 2]) + let ind1 := (Tensor.Element.ofList Tensor.Element.Int8Native [2, -2, 0, 1]).reshape! (Shape.mk [2, 2]) + let ind2 := (Tensor.Element.ofList Tensor.Element.Int8Native [1, 1, -1, -1]).reshape! (Shape.mk [2, 2]) + let typ := BV16 + let arr := (Tensor.Element.arange typ 27).reshape! (Shape.mk [3, 3, 3]) + let res := get! $ apply [ind0, ind1, ind2] arr + let tree := get! $ res.toTree typ + tree == Tensor.Format.Tree.node [.root [16, 22], .root [2, 5]] + +end Advanced + +section Test + #guard let tp := BV8 let tensor := Tensor.Element.arange tp 10 @@ -335,35 +432,33 @@ def apply (index : NumpyBasic) (arr : Tensor) : Err (Tensor × Bool) := do let tree' := Tensor.Format.Tree.root [19] !copied && tree == tree' -section Test - -- Testing private def numpyBasicToList (dims : List Nat) (basic : NumpyBasic) : Option (List (List Nat)) := do let shape := Shape.mk dims let (basic, _) <- (toBasic basic shape).toOption - let iter <- (make shape basic).toOption + let iter <- (BasicIter.make shape basic).toOption iter.toList #guard numpyBasicToList [] [] == some [[]] #guard numpyBasicToList [1] [.int 0] == some [[0]] #guard numpyBasicToList [2] [.int 0] == some [[0]] #guard numpyBasicToList [2] [.int 1] == some [[1]] -#guard (numpyBasicToList [2] [.int 2]) == none -#guard (numpyBasicToList [2] [.int (-1)]) == some [[1]] -#guard (numpyBasicToList [2] [.int (-3)]) == none -#guard (numpyBasicToList [4] [.slice Slice.all]) == some [[0], [1], [2], [3]] -#guard (numpyBasicToList [4] [.slice $ Slice.build! .none .none (.some 2)]) == some [[0], [2]] -#guard (numpyBasicToList [4] [.slice $ Slice.build! (.some (-1)) .none (.some (-2))]) == some [[3], [1]] -#guard (numpyBasicToList [2, 2] [.int 5]) == none -#guard (numpyBasicToList [2, 2] [.int 0]) == some [[0, 0], [0, 1]] -#guard (numpyBasicToList [2, 2] [.int 0, .int 0]) == some [[0, 0]] -#guard (numpyBasicToList [2, 2] [.int 0, .int 1]) == some [[0, 1]] -#guard (numpyBasicToList [2, 2] [.int 0, .int 2]) == none -#guard (numpyBasicToList [3, 3] [.slice Slice.all, .int 2]) == some [[0, 2], [1, 2], [2, 2]] -#guard (numpyBasicToList [3, 3] [.int 2, .slice Slice.all]) == some [[2, 0], [2, 1], [2, 2]] -#guard (numpyBasicToList [2, 2] [.slice Slice.all, .slice Slice.all]) == some [[0, 0], [0, 1], [1, 0], [1, 1]] -#guard (numpyBasicToList [2, 2] [.slice (Slice.build! .none .none (.some (-1))), .slice Slice.all]) == some [[1, 0], [1, 1], [0, 0], [0, 1]] -#guard (numpyBasicToList [4, 2] [.slice (Slice.build! .none .none (.some (-2))), .slice Slice.all]) == some [[3, 0], [3, 1], [1, 0], [1, 1]] +#guard numpyBasicToList [2] [.int 2] == none +#guard numpyBasicToList [2] [.int (-1)] == some [[1]] +#guard numpyBasicToList [2] [.int (-3)] == none +#guard numpyBasicToList [4] [.slice Slice.all] == some [[0], [1], [2], [3]] +#guard numpyBasicToList [4] [.slice $ Slice.build! .none .none (.some 2)] == some [[0], [2]] +#guard numpyBasicToList [4] [.slice $ Slice.build! (.some (-1)) .none (.some (-2))] == some [[3], [1]] +#guard numpyBasicToList [2, 2] [.int 5] == none +#guard numpyBasicToList [2, 2] [.int 0] == some [[0, 0], [0, 1]] +#guard numpyBasicToList [2, 2] [.int 0, .int 0] == some [[0, 0]] +#guard numpyBasicToList [2, 2] [.int 0, .int 1] == some [[0, 1]] +#guard numpyBasicToList [2, 2] [.int 0, .int 2] == none +#guard numpyBasicToList [3, 3] [.slice Slice.all, .int 2] == some [[0, 2], [1, 2], [2, 2]] +#guard numpyBasicToList [3, 3] [.int 2, .slice Slice.all] == some [[2, 0], [2, 1], [2, 2]] +#guard numpyBasicToList [2, 2] [.slice Slice.all, .slice Slice.all] == some [[0, 0], [0, 1], [1, 0], [1, 1]] +#guard numpyBasicToList [2, 2] [.slice (Slice.build! .none .none (.some (-1))), .slice Slice.all] == some [[1, 0], [1, 1], [0, 0], [0, 1]] +#guard numpyBasicToList [4, 2] [.slice (Slice.build! .none .none (.some (-2))), .slice Slice.all] == some [[3, 0], [3, 1], [1, 0], [1, 1]] -- Commented for easier debugging. Remove some day -- #eval do @@ -380,9 +475,7 @@ private def numpyBasicToList (dims : List Nat) (basic : NumpyBasic) : Option (Li -- -- let (ns8, iter8) <- iter7.next -- -- let (ns9, iter9) <- iter8.next -- return (basic, iter0, ns0, iter1, ns1, iter2, ns2, iter3) -- , ns4, iter4) -- , ns5, iter5, ns6, iter6, ns7, iter7, ns8, iter8, ns9, iter9) - end Test -end BasicIter end Index end TensorLib diff --git a/TensorLib/Tensor.lean b/TensorLib/Tensor.lean index 48b18ba..93f615a 100644 --- a/TensorLib/Tensor.lean +++ b/TensorLib/Tensor.lean @@ -108,6 +108,8 @@ def ones (dtype : Dtype) (shape : Shape) : Tensor := Id.run do data := data.push byte { dtype := dtype, shape := shape, data := data } +def byteOrder (arr : Tensor) : ByteOrder := arr.dtype.order + --! number of dimensions def ndim (x : Tensor) : Nat := x.shape.ndim @@ -126,11 +128,35 @@ def dimIndexToOffset (x : Tensor) (i : DimIndex) : Int := --! Get the starting byte corresponding to a DimIndex def dimIndexToPosition (x : Tensor) (i : DimIndex) : Nat := - (x.dimIndexToOffset i).toNat + (x.startIndex + (x.dimIndexToOffset i)).toNat --! number of bytes representing the entire tensor def nbytes (x : Tensor) : Nat := x.itemsize * x.size +def isIntLike (x : Tensor) : Bool := x.dtype.isIntLike + +def dimIndexInRange (arr : Tensor) (dimIndex : DimIndex) : Bool := arr.shape.dimIndexInRange dimIndex + +def byteArrayAtDimIndex (arr : Tensor) (dimIndex : DimIndex) : Err ByteArray := do + if !arr.dimIndexInRange dimIndex then .error "index is incompatible with tensor shape" else + let posn := arr.dimIndexToPosition dimIndex + .ok $ arr.data.extract posn (posn + arr.itemsize) + +def setByteArrayAtDimIndex (arr : Tensor) (dimIndex : DimIndex) (bytes : ByteArray) : Err Tensor := do + if !arr.dimIndexInRange dimIndex then .error "index is incompatible with tensor shape" else + if arr.itemsize != bytes.size then .error "byte size mismatch" else + let posn := arr.dimIndexToPosition dimIndex + .ok $ { arr with data := bytes.copySlice 0 arr.data posn bytes.size } + +/-! +Return the integer at the dimIndex. This is useful, for example, in advanced indexing +where we must have an int/uint Tensor as an argument. +-/ +def intAtDimIndex (arr : Tensor) (dimIndex : DimIndex) : Err Int := do + if !arr.isIntLike then .error "natAt expects an int tensor" else + let bytes <- byteArrayAtDimIndex arr dimIndex + .ok $ arr.byteOrder.bytesToInt bytes + /-! Copy a Tensor's data to new, contiguous storage. @@ -267,6 +293,16 @@ def setPosition [typ : Element a] (x : Tensor) (n : Nat) (v : a): Err Tensor := let posn := n * itemsize .ok { x with data := bytes.copySlice 0 x.data posn itemsize true } +def ofList (typ : Element a) (xs : List a) : Tensor := Id.run do + let arr := Tensor.zeros typ.dtype (Shape.mk [xs.length]) + let mut data := arr.data + let mut posn := 0 + for x in xs do + let v := typ.toByteArray x + data := v.copySlice 0 data posn typ.itemsize + posn := posn + arr.itemsize + { arr with data := data } + -- Since the DimIndex is independent of the dtype size, we need to recompute the strides -- TODO: Would be better to not recompute this over and over. We should find a place to store -- the 1-based default strides @@ -298,6 +334,15 @@ instance BV8Native : Element BV8 where toByteArray (x : BV8) : ByteArray := x.toByteArray fromByteArray arr startIndex := ByteArray.toBV8 arr startIndex +instance Int8Native : Element Int8 where + dtype := Dtype.mk .int8 .oneByte + itemsize := 1 + ofNat n := n.toInt8 + toByteArray (x : Int8) : ByteArray := [x.toUInt8].toByteArray + fromByteArray arr startIndex := (ByteArray.toBV8 arr startIndex).map fun b => Int8.mk b.toUInt8 + +#guard Int8Native.fromByteArray (Int8Native.toByteArray (-5)) 0 == .ok (-5) + instance BV16Little : Element BV16 where dtype := Dtype.mk .uint16 .littleEndian itemsize := 2 @@ -457,6 +502,9 @@ private def arr1 := Element.arange BV8 12 #guard (ones (Dtype.float64) $ Shape.mk [2, 2]).nbytes == 2 * 2 * 8 #guard (ones (Dtype.float64) $ Shape.mk [2, 2]).data.toList.count 1 == 2 * 2 +#guard get! ((Element.ofList Element.BV8Native [1, 2, 3]).toTree BV8) == Format.Tree.root [1, 2, 3] +#guard get! (((Element.ofList Element.BV8Native [0, 1, 2, 3, 4, 5]).reshape! (Shape.mk [2, 3])).toTree BV8) == .node [.root [0, 1, 2], .root [3, 4, 5]] + end Test end Tensor