Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Universal functions #23

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions TensorLib/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ import TensorLib.Broadcast
import TensorLib.Slice
import TensorLib.Tensor
import TensorLib.Index
import TensorLib.Mgrid
import TensorLib.Ufunc
21 changes: 21 additions & 0 deletions TensorLib/Broadcast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ Rule 2

A: (3, 2, 7)
B: (3, 2, 7)

Theorem to prove: if we can broadcast s1 to s2, then for any array with shape s1, s1.reshape s2 succeeds
-/

namespace TensorLib
Expand Down Expand Up @@ -103,4 +105,23 @@ def canBroadcast (b : Broadcast) : Bool := (broadcast b).isSome
broadcast b2 == broadcast b1 &&
broadcast b2 == .some (Shape.mk [1, 2, 3])

--! Broadcast a whole list of shapes together, left to right.
--! Yields `none` for an empty list, or as soon as any pairwise broadcast fails.
def broadcastList (shapes : List Shape) : Option Shape :=
  match shapes with
  | [] => none
  | first :: rest =>
    -- Folding in the `Option` monad stops at the first incompatible pair.
    rest.foldlM (fun acc next => (Broadcast.mk acc next).broadcast) first

-- Three shapes of decreasing rank: the combined shape is expected to equal the largest one, `x1`.
#guard
let x1 := Shape.mk [1, 2, 3]
let x2 := Shape.mk [2, 3]
let x3 := Shape.mk []
broadcastList [x1, x2, x3] == .some x1

end Broadcast
67 changes: 61 additions & 6 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Std.Tactic.BVDecide
import Batteries.Data.List

namespace TensorLib

Expand All @@ -27,6 +28,7 @@ def natDivCeil (num denom : Nat) : Nat := (num + denom - 1) / denom

--! Product of all entries of `shape`; the empty list gives 1.
def natProd (shape : List Nat) : Nat :=
  shape.foldl (· * ·) 1


-- We generally have large tensors, so don't show them by default
instance ByteArrayRepr : Repr ByteArray where
reprPrec x _ :=
Expand All @@ -45,11 +47,41 @@ inductive ByteOrder where
| bigEndian
deriving BEq, Repr, Inhabited

namespace ByteOrder

--! `false` for `.oneByte`, `true` for `.littleEndian` and `.bigEndian`.
-- Declared once, relative to the enclosing `ByteOrder` namespace; the older fully
-- qualified header (`def ByteOrder.isMultiByte`) must not be kept alongside it.
@[simp]
def isMultiByte (x : ByteOrder) : Bool := match x with
  | .oneByte => false
  | .littleEndian | .bigEndian => true

-- Decode `bytes` as a two's-complement signed integer whose bytes are laid out in `order`.
--
-- * `order`: `.littleEndian` puts the least significant byte first, `.bigEndian` puts it last;
--   `.oneByte` gives every byte weight 2^0 (it is meant for single-byte input).
-- * `bytes`: the raw bytes. An empty array decodes to 0.
--
-- The sign is read from the top bit of the most significant byte. For a negative value every
-- byte is complemented before accumulation and the result is `-(n + 1)`, which is the
-- two's-complement identity `x = -(~x + 1)`.
def bytesToInt (order : ByteOrder) (bytes : ByteArray) : Int := Id.run do
  let nbytes := bytes.size
  -- No bytes means no sign byte: answer 0 directly instead of indexing `get! 0` out of bounds.
  if nbytes == 0 then return 0
  let signByte := match order with
    | .littleEndian => bytes.get! (nbytes - 1)
    | .bigEndian | .oneByte => bytes.get! 0
  let negative := 128 <= signByte
  let mut n : Nat := 0
  for i in [0:nbytes] do
    let v : UInt8 := bytes.get! i
    let v := if negative then UInt8.complement v else v
    -- `p` is the position of byte `i` counted from the least significant byte.
    let p := match order with
      | .oneByte => 0 -- nbytes = 1
      | .littleEndian => i
      | .bigEndian => nbytes - 1 - i
    n := n + Pow.pow 2 (8 * p) * v.toNat
  return if negative then -(n + 1) else n

-- Two-byte spot checks in both byte orders: small positive values, all-ones (-1), and
-- 0x80 placed in either the most or the least significant position.
#guard bytesToInt .littleEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToInt .bigEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToInt .littleEndian (ByteArray.mk #[0, 1]) == 256
#guard bytesToInt .bigEndian (ByteArray.mk #[0, 1]) == 1
#guard bytesToInt .littleEndian (ByteArray.mk #[0xFF, 0xFF]) == -1
#guard bytesToInt .bigEndian (ByteArray.mk #[0xFF, 0xFF]) == -1
#guard bytesToInt .bigEndian (ByteArray.mk #[0x80, 0]) == -32768
#guard bytesToInt .littleEndian (ByteArray.mk #[0x80, 0]) == 0x80

end ByteOrder

/-!
The strides are how many bytes you need to skip to get to the next element in that
"row". For example, in an array of 8-byte data with shape 2, 3, the strides are (24, 8).
Expand Down Expand Up @@ -144,6 +176,9 @@ deriving BEq, Repr, Inhabited

namespace Shape

-- Render a shape as text by reusing its `Repr` output.
instance : ToString Shape := ⟨reprStr⟩

--! The shape with no dimensions.
def empty : Shape := ⟨[]⟩

--! The number of elements in a tensor. All that's needed is the shape for this calculation.
Expand All @@ -155,6 +190,10 @@ def ndim (shape : Shape) : Nat := shape.val.length

--! Apply `f` to the underlying dimension list and wrap the result as a new shape.
def map (shape : Shape) (f : List Nat -> List Nat) : Shape :=
  ⟨f shape.val⟩

--! `true` when `dimIndex` has one entry per dimension of `shape` and every entry is
--! strictly below the size of its dimension.
def dimIndexInRange (shape : Shape) (dimIndex : DimIndex) : Bool :=
  let sameRank := shape.ndim == dimIndex.length
  let withinBounds := (shape.val.zip dimIndex).all fun (extent, idx) => idx < extent
  sameRank && withinBounds

/-!
Strides can be computed from the shape by figuring out how many elements you
need to jump over to get to the next spot and multiplying by the bytes in each
Expand Down Expand Up @@ -205,6 +244,7 @@ def positionToDimIndex (strides : Strides) (n : Position) : DimIndex :=
let (_, idx) := strides.foldl foldFn (n, [])
idx.reverse

-- TODO: Return `Err Offset` for when the strides and index have different lengths?
--! Offset of `index`: each component is converted to `Int` and combined with `strides` via `dot`.
def dimIndexToOffset (strides : Strides) (index : DimIndex) : Offset :=
  let coords := index.map Int.ofNat
  dot strides coords

#guard positionToDimIndex [3, 1] 4 == [1, 1]
Expand All @@ -223,6 +263,21 @@ def allDimIndices (shape : Shape) : List DimIndex := Id.run do
#guard allDimIndices (Shape.mk [5]) == [[0], [1], [2], [3], [4]]
#guard allDimIndices (Shape.mk [3, 2]) == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]

-- NumPy supports negative indices, which simply wrap around. E.g. `x[.., -1, ..] = x[.., n-1, ..]` where `n` is the
-- size of the dimension in question. The valid range for such a dimension is `-n` through `n - 1`, inclusive.
--
-- Converts a list of possibly negative indices into a `DimIndex` for `shape`.
-- Fails when `index` does not have exactly one entry per dimension, or when any entry is outside its valid range.
def intIndexToDimIndex (shape : Shape) (index : List Int) : Err DimIndex := do
  if shape.ndim != index.length then
    .error "intIndexToDimIndex length mismatch"
  else
    -- Normalise one index against the size of its dimension.
    let conv (dim : Nat) (ind : Int) : Err Nat :=
      if 0 <= ind then
        if ind < dim then .ok ind.toNat
        else .error "index out of bounds"
      else if ind < -dim then .error "index out of bounds"
      else .ok (dim + ind).toNat -- negative: count back from the end
    (shape.val.zip index).traverse (fun (dim, ind) => conv dim ind)

-- Negative entries count back from the end of their own dimension.
#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, -1, -1] == (.ok [0, 1, 2])
#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, 1, -2] == (.ok [0, 1, 1])

end Shape

/-
Expand Down Expand Up @@ -278,7 +333,7 @@ Note: I tried writing this as a `do/for` loop and in this case the recursive
one seems nicer. We are walking over two lists simultaneously, which is easy
here but with a for loop is either quadratic or awkward.
-/
def next (iter : DimsIter) : List Nat × DimsIter :=
def next (iter : DimsIter) : DimIndex × DimsIter :=
-- Invariant: `acc` is a list of 0s, so doesn't need to be reversed
let rec loop (acc ms ns : List Nat) : List Nat :=
match ms, ns with
Expand All @@ -293,7 +348,7 @@ def next (iter : DimsIter) : List Nat × DimsIter :=
let curr' := loop [] iter.dims iter.curr
(iter.curr.reverse, { iter with curr := curr' })

instance [Monad m] : ForIn m DimsIter (List Nat) where
instance [Monad m] : ForIn m DimsIter DimIndex where
forIn {α} [Monad m] (iter : DimsIter) (x : α) (f : List Nat -> α -> m (ForInStep α)) : m α := do
let mut iter := iter
let mut res := x
Expand All @@ -305,7 +360,7 @@ instance [Monad m] : ForIn m DimsIter (List Nat) where
| .done k => return k
return res

private def toList (iter : DimsIter) : List (List Nat) := Id.run do
private def toList (iter : DimsIter) : List DimIndex := Id.run do
let mut res := []
for xs in iter do
res := xs :: res
Expand All @@ -321,7 +376,7 @@ private def toList (iter : DimsIter) : List (List Nat) := Id.run do
#guard (DimsIter.make $ Shape.mk [1, 1, 2]).toList == [[0, 0, 0], [0, 0, 1]]
#guard (DimsIter.make $ Shape.mk [3, 2]).toList == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]

private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do
private def testBreak (iter : DimsIter) : List DimIndex := Id.run do
let mut res := []
for xs in iter do
res := xs :: res
Expand All @@ -330,7 +385,7 @@ private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do

#guard (DimsIter.make $ Shape.mk [3, 2]).testBreak == [[0, 0]]

private def testReturn (iter : DimsIter) : List (List Nat) := Id.run do
private def testReturn (iter : DimsIter) : List DimIndex := Id.run do
let mut res := []
let mut i := 0
for xs in iter do
Expand Down
14 changes: 14 additions & 0 deletions TensorLib/Dtype.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def isMultiByte (x : Name) : Bool := match x with
| bool | int8 | uint8 => false
| _ => true

--! `true` exactly for `int8`, `int16`, `int32` and `int64`.
def isInt (x : Name) : Bool :=
  match x with
  | int8 => true
  | int16 => true
  | int32 => true
  | int64 => true
  | _ => false

--! `true` exactly for `uint8`, `uint16`, `uint32` and `uint64`.
def isUint (x : Name) : Bool :=
  match x with
  | uint8 => true
  | uint16 => true
  | uint32 => true
  | uint64 => true
  | _ => false

--! `true` when the name satisfies `isInt` or `isUint`.
def isIntLike (x : Name) : Bool :=
  if x.isInt then true else x.isUint

--! Number of bytes used by each element of the given dtype
def itemsize (x : Name) : Nat := match x with
| float64 | int64 | uint64 => 8
Expand All @@ -57,6 +67,10 @@ def itemsize (t : Dtype) := t.name.itemsize

--! The unit strides of `s`, each multiplied by the item size of `dtype`.
def sizedStrides (dtype : Dtype) (s : Shape) : Strides :=
  let scale := dtype.itemsize
  List.map (fun stride => stride * scale) s.unitStrides

--! Whether the dtype's name satisfies `Name.isInt`.
def isInt (dtype : Dtype) : Bool :=
  Name.isInt dtype.name
--! Whether the dtype's name satisfies `Name.isUint`.
def isUint (dtype : Dtype) : Bool :=
  Name.isUint dtype.name
--! Whether the dtype satisfies `isInt` or `isUint`.
def isIntLike (dtype : Dtype) : Bool :=
  if dtype.isInt then true else dtype.isUint

-- Ready-made dtypes; each pairs a `Name` with the little-endian byte order.
def int8 : Dtype := ⟨Name.int8, .littleEndian⟩
def uint8 : Dtype := ⟨Name.uint8, .littleEndian⟩
def uint64 : Dtype := ⟨Name.uint64, .littleEndian⟩
Expand Down
Loading