Skip to content

Commit

Permalink
Reformat.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Nov 2, 2024
1 parent ced1580 commit 54c3ba7
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 354 deletions.
8 changes: 4 additions & 4 deletions prelude/ad.fut
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,18 @@

-- | Jacobian-Vector Product ("forward mode"), producing also the
-- primal result as the first element of the result tuple.
let jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) =
def jvp2 'a 'b (f: a -> b) (x: a) (x': a): (b, b) =
intrinsics.jvp2 f x x'

-- | Vector-Jacobian Product ("reverse mode"), producing also the
-- primal result as the first element of the result tuple.
let vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) =
def vjp2 'a 'b (f: a -> b) (x: a) (y': b): (b, a) =
intrinsics.vjp2 f x y'

-- | Jacobian-Vector Product ("forward mode").
let jvp 'a 'b (f: a -> b) (x: a) (x': a) : b =
def jvp 'a 'b (f: a -> b) (x: a) (x': a): b =
(jvp2 f x x').1

-- | Vector-Jacobian Product ("reverse mode").
let vjp 'a 'b (f: a -> b) (x: a) (y': b) : a =
def vjp 'a 'b (f: a -> b) (x: a) (y': b): a =
(vjp2 f x y').1
56 changes: 28 additions & 28 deletions prelude/array.fut
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,31 @@ def head [n] 't (x: [n]t) = x[0]
--
-- **Complexity:** O(1).
#[inline]
def last [n] 't (x: [n]t) = x[n-1]
def last [n] 't (x: [n]t) = x[n - 1]

-- | Everything but the first element of the array.
--
-- **Complexity:** O(1).
#[inline]
def tail [n] 't (x: [n]t): [n-1]t = x[1:]
def tail [n] 't (x: [n]t) : [n - 1]t = x[1:]

-- | Everything but the last element of the array.
--
-- **Complexity:** O(1).
#[inline]
def init [n] 't (x: [n]t): [n-1]t = x[0:n-1]
def init [n] 't (x: [n]t) : [n - 1]t = x[0:n - 1]

-- | Take some number of elements from the head of the array.
--
-- **Complexity:** O(1).
#[inline]
def take [n] 't (i: i64) (x: [n]t): [i]t = x[0:i]
def take [n] 't (i: i64) (x: [n]t) : [i]t = x[0:i]

-- | Remove some number of elements from the head of the array.
--
-- **Complexity:** O(1).
#[inline]
def drop [n] 't (i: i64) (x: [n]t): [n-i]t = x[i:]
def drop [n] 't (i: i64) (x: [n]t) : [n - i]t = x[i:]

-- | Statically change the size of an array. Fail at runtime if the
-- imposed size does not match the actual size. Essentially syntactic
Expand All @@ -61,14 +61,14 @@ def sized [m] 't (n: i64) (xs: [m]t) : [n]t = xs :> [n]t
--
-- **Complexity:** O(1).
#[inline]
def split [n][m] 't (xs: [n+m]t): ([n]t, [m]t) =
(xs[0:n], xs[n:n+m] :> [m]t)
def split [n] [m] 't (xs: [n + m]t) : ([n]t, [m]t) =
(xs[0:n], xs[n:n + m] :> [m]t)

-- | Return the elements of the array in reverse order.
--
-- **Complexity:** O(1).
#[inline]
def reverse [n] 't (x: [n]t): [n]t = x[::-1]
def reverse [n] 't (x: [n]t) : [n]t = x[::-1]

-- | Concatenate two arrays. Warning: never try to perform a reduction
-- with this operator; it will not work.
Expand All @@ -77,11 +77,11 @@ def reverse [n] 't (x: [n]t): [n]t = x[::-1]
--
-- **Span:** O(1).
#[inline]
def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = intrinsics.concat xs ys
def (++) [n] [m] 't (xs: [n]t) (ys: [m]t) : *[n + m]t = intrinsics.concat xs ys

-- | An old-fashioned way of saying `++`.
#[inline]
def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = xs ++ ys
def concat [n] [m] 't (xs: [n]t) (ys: [m]t) : *[n + m]t = xs ++ ys

-- | Construct an array of consecutive integers of the given length,
-- starting at 0.
Expand All @@ -90,7 +90,7 @@ def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = xs ++ ys
--
-- **Span:** O(1).
#[inline]
def iota (n: i64): *[n]i64 =
def iota (n: i64) : *[n]i64 =
0..1..<n

-- | Construct an array comprising valid indexes into some other
Expand All @@ -116,7 +116,7 @@ def indices [n] 't (_: [n]t) : *[n]i64 =
-- operations such as `map`, in which case it is free.
#[inline]
def rotate [n] 't (r: i64) (a: [n]t) =
map (\i -> #[unsafe] a[(i+r)%n]) (iota n)
map (\i -> #[unsafe] a[(i + r) % n]) (iota n)

-- | Construct an array of the given length containing the given
-- value.
Expand All @@ -125,7 +125,7 @@ def rotate [n] 't (r: i64) (a: [n]t) =
--
-- **Span:** O(1).
#[inline]
def replicate 't (n: i64) (x: t): *[n]t =
def replicate 't (n: i64) (x: t) : *[n]t =
map (const x) (iota n)

-- | Construct an array of an inferred length containing the given
Expand All @@ -135,7 +135,7 @@ def replicate 't (n: i64) (x: t): *[n]t =
--
-- **Span:** O(1).
#[inline]
def rep 't [n] (x: t): *[n]t =
def rep 't [n] (x: t) : *[n]t =
replicate n x

-- | Copy a value. The result will not alias anything.
Expand All @@ -144,7 +144,7 @@ def rep 't [n] (x: t): *[n]t =
--
-- **Span:** O(1).
#[inline]
def copy 't (a: t): *t =
def copy 't (a: t) : *t =
([a])[0]

-- | Copy a value. The result will not alias anything. Additionally,
Expand All @@ -156,48 +156,48 @@ def copy 't (a: t): *t =
--
-- **Span:** O(1).
#[inline]
def manifest 't (a: t): *t =
def manifest 't (a: t) : *t =
intrinsics.manifest a

-- | Combines the outer two dimensions of an array.
--
-- **Complexity:** O(1).
#[inline]
def flatten [n][m] 't (xs: [n][m]t): [n*m]t =
def flatten [n] [m] 't (xs: [n][m]t) : [n * m]t =
intrinsics.flatten xs

-- | Like `flatten`, but on the outer three dimensions of an array.
#[inline]
def flatten_3d [n][m][l] 't (xs: [n][m][l]t): [n*m*l]t =
def flatten_3d [n] [m] [l] 't (xs: [n][m][l]t) : [n * m * l]t =
flatten (flatten xs)

-- | Like `flatten`, but on the outer four dimensions of an array.
#[inline]
def flatten_4d [n][m][l][k] 't (xs: [n][m][l][k]t): [n*m*l*k]t =
def flatten_4d [n] [m] [l] [k] 't (xs: [n][m][l][k]t) : [n * m * l * k]t =
flatten (flatten_3d xs)

-- | Splits the outer dimension of an array in two.
--
-- **Complexity:** O(1).
#[inline]
def unflatten 't [n][m] (xs: [n*m]t): [n][m]t =
def unflatten 't [n] [m] (xs: [n * m]t) : [n][m]t =
intrinsics.unflatten n m xs

-- | Like `unflatten`, but produces three dimensions.
#[inline]
def unflatten_3d 't [n][m][l] (xs: [n*m*l]t): [n][m][l]t =
def unflatten_3d 't [n] [m] [l] (xs: [n * m * l]t) : [n][m][l]t =
unflatten (unflatten xs)

-- | Like `unflatten`, but produces four dimensions.
#[inline]
def unflatten_4d 't [n][m][l][k] (xs: [n*m*l*k]t): [n][m][l][k]t =
def unflatten_4d 't [n] [m] [l] [k] (xs: [n * m * l * k]t) : [n][m][l][k]t =
unflatten (unflatten_3d xs)

-- | Transpose an array.
--
-- **Complexity:** O(1).
#[inline]
def transpose [n] [m] 't (a: [n][m]t): [m][n]t =
def transpose [n] [m] 't (a: [n][m]t) : [m][n]t =
intrinsics.transpose a

-- | True if all of the input elements are true. Produces true on an
Expand All @@ -221,37 +221,37 @@ def or [n] (xs: [n]bool) = any id xs
-- **Work:** O(n ✕ W(f))).
--
-- **Span:** O(n ✕ S(f)).
def foldl [n] 'a 'b (f: a -> b -> a) (acc: a) (bs: [n]b): a =
def foldl [n] 'a 'b (f: a -> b -> a) (acc: a) (bs: [n]b) : a =
loop acc for b in bs do f acc b

-- | Perform a *sequential* right-fold of an array.
--
-- **Work:** O(n ✕ W(f))).
--
-- **Span:** O(n ✕ S(f)).
def foldr [n] 'a 'b (f: b -> a -> a) (acc: a) (bs: [n]b): a =
def foldr [n] 'a 'b (f: b -> a -> a) (acc: a) (bs: [n]b) : a =
foldl (flip f) acc (reverse bs)

-- | Create a value for each point in a one-dimensional index space.
--
-- **Work:** *O(n ✕ W(f))*
--
-- **Span:** *O(S(f))*
def tabulate 'a (n: i64) (f: i64 -> a): *[n]a =
def tabulate 'a (n: i64) (f: i64 -> a) : *[n]a =
map1 f (iota n)

-- | Create a value for each point in a two-dimensional index space.
--
-- **Work:** *O(n ✕ m ✕ W(f))*
--
-- **Span:** *O(S(f))*
def tabulate_2d 'a (n: i64) (m: i64) (f: i64 -> i64 -> a): *[n][m]a =
def tabulate_2d 'a (n: i64) (m: i64) (f: i64 -> i64 -> a) : *[n][m]a =
map1 (f >-> tabulate m) (iota n)

-- | Create a value for each point in a three-dimensional index space.
--
-- **Work:** *O(n ✕ m ✕ o ✕ W(f))*
--
-- **Span:** *O(S(f))*
def tabulate_3d 'a (n: i64) (m: i64) (o: i64) (f: i64 -> i64 -> i64 -> a): *[n][m][o]a =
def tabulate_3d 'a (n: i64) (m: i64) (o: i64) (f: i64 -> i64 -> i64 -> a) : *[n][m][o]a =
map1 (f >-> tabulate_2d m o) (iota n)
3 changes: 1 addition & 2 deletions prelude/functional.fut
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def (|>) '^a '^b (x: a) (f: a -> b): b = f x
-- ```
-- filter (>0) [-1,0,1] |> length
-- ```

def (<|) '^a '^b (f: a -> b) (x: a) = f x

-- | Function composition, with values flowing from left to right.
Expand Down Expand Up @@ -76,7 +75,7 @@ def iterate 'a (n: i32) (f: a -> a) (x: a) =
-- | Keep applying `f` until `p` returns true for the input value.
-- May apply zero times. *Note*: may not terminate.
def iterate_until 'a (p: a -> bool) (f: a -> a) (x: a) =
loop x while ! (p x) do f x
loop x while !(p x) do f x

-- | Keep applying `f` while `p` returns true for the input value.
-- May apply zero times. *Note*: may not terminate.
Expand Down
Loading

0 comments on commit 54c3ba7

Please sign in to comment.