Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: leanprover/TensorLib
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 621d94a70dc99d6815ec7ffb2c8f85ef1db8e465
Choose a base ref
..
head repository: leanprover/TensorLib
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: a3757f22528c9ec370e1922456ddd8a11e63b84f
Choose a head ref
Showing with 5 additions and 9 deletions.
  1. +1 −2 TensorLib/Broadcast.lean
  2. +1 −3 TensorLib/Common.lean
  3. +3 −4 TensorLib/Tensor.lean
3 changes: 1 addition & 2 deletions TensorLib/Broadcast.lean
Original file line number Diff line number Diff line change
@@ -5,7 +5,6 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Aesop
import Batteries.Data.List
import TensorLib.Common

/-!
@@ -83,7 +82,7 @@ private def matchPairs (b : Broadcast) : Option Shape :=
else if x == 1 then some y
else if y == 1 then some x
else none
let dims := (b.left.val.zip b.right.val).traverse f
let dims := (b.left.val.zip b.right.val).mapM f
dims.map Shape.mk

--! Returns the shape resulting from broadcast the arguments
4 changes: 1 addition & 3 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
@@ -5,8 +5,6 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Std.Tactic.BVDecide
import Batteries.Data.List

namespace TensorLib

--! The error monad for TensorLib
@@ -273,7 +271,7 @@ def intIndexToDimIndex (shape : Shape) (index : List Int) : Err DimIndex := do
else .error "index out of bounds"
else if ind < -dim then .error "index out of bounds"
else .ok (dim + ind).toNat
(shape.val.zip index).traverse (fun (dim, ind) => conv dim ind)
(shape.val.zip index).mapM (fun (dim, ind) => conv dim ind)

#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, -1, -1] == (.ok [0, 1, 2])
#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, 1, -2] == (.ok [0, 1, 1])
7 changes: 3 additions & 4 deletions TensorLib/Tensor.lean
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Batteries.Data.List
import Batteries.Data.List -- for `toChunks`
import TensorLib.Common
import TensorLib.Dtype
import TensorLib.Npy
@@ -320,8 +320,7 @@ def setDimIndex [Element a] (x : Tensor) (index : DimIndex) (v : a): Err Tensor

-- TODO: remove `Err` by proving all indices are within range
def toList (a : Type) [Tensor.Element a] (x : Tensor) : Err (List a) :=
let traverseFn ind : Err a := getDimIndex x ind
x.shape.allDimIndices.traverse traverseFn
x.shape.allDimIndices.mapM (getDimIndex x)

def toList! (a : Type) [Tensor.Element a] (x : Tensor) : List a := match toList a x with
| .error _ => []
@@ -389,7 +388,7 @@ private def toTree {a : Type} (x : List a) (strides : Strides) : Err (Tree a) :=
| [_] => .error "not a unit stride"
| stride :: strides => do
let chunks := x.toChunks stride.toNat
let res <- chunks.traverse (fun x => toTree x strides)
let res <- chunks.mapM (fun x => toTree x strides)
return .node res

private def toTree! {a : Type} (x : List a) (strides : Strides) : Tree a := match toTree x strides with