diff --git a/prelude/ad.fut b/prelude/ad.fut index 4daac5a2cd..6a512d40b9 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -95,18 +95,18 @@ -- | Jacobian-Vector Product ("forward mode"), producing also the -- primal result as the first element of the result tuple. -let jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) = +def jvp2 'a 'b (f: a -> b) (x: a) (x': a): (b, b) = intrinsics.jvp2 f x x' -- | Vector-Jacobian Product ("reverse mode"), producing also the -- primal result as the first element of the result tuple. -let vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) = +def vjp2 'a 'b (f: a -> b) (x: a) (y': b): (b, a) = intrinsics.vjp2 f x y' -- | Jacobian-Vector Product ("forward mode"). -let jvp 'a 'b (f: a -> b) (x: a) (x': a) : b = +def jvp 'a 'b (f: a -> b) (x: a) (x': a): b = (jvp2 f x x').1 -- | Vector-Jacobian Product ("reverse mode"). -let vjp 'a 'b (f: a -> b) (x: a) (y': b) : a = +def vjp 'a 'b (f: a -> b) (x: a) (y': b): a = (vjp2 f x y').1 diff --git a/prelude/array.fut b/prelude/array.fut index dfb8252196..c1c3977c7a 100644 --- a/prelude/array.fut +++ b/prelude/array.fut @@ -25,31 +25,31 @@ def head [n] 't (x: [n]t) = x[0] -- -- **Complexity:** O(1). #[inline] -def last [n] 't (x: [n]t) = x[n-1] +def last [n] 't (x: [n]t) = x[n - 1] -- | Everything but the first element of the array. -- -- **Complexity:** O(1). #[inline] -def tail [n] 't (x: [n]t): [n-1]t = x[1:] +def tail [n] 't (x: [n]t) : [n - 1]t = x[1:] -- | Everything but the last element of the array. -- -- **Complexity:** O(1). #[inline] -def init [n] 't (x: [n]t): [n-1]t = x[0:n-1] +def init [n] 't (x: [n]t) : [n - 1]t = x[0:n - 1] -- | Take some number of elements from the head of the array. -- -- **Complexity:** O(1). #[inline] -def take [n] 't (i: i64) (x: [n]t): [i]t = x[0:i] +def take [n] 't (i: i64) (x: [n]t) : [i]t = x[0:i] -- | Remove some number of elements from the head of the array. -- -- **Complexity:** O(1). #[inline] -def drop [n] 't (i: i64) (x: [n]t): [n-i]t = x[i:] +def drop [n] 't (i: i64) (x: [n]t) : [n - i]t = x[i:] -- | Statically change the size of an array. Fail at runtime if the -- imposed size does not match the actual size. Essentially syntactic @@ -61,14 +61,14 @@ def sized [m] 't (n: i64) (xs: [m]t) : [n]t = xs :> [n]t -- -- **Complexity:** O(1). #[inline] -def split [n][m] 't (xs: [n+m]t): ([n]t, [m]t) = - (xs[0:n], xs[n:n+m] :> [m]t) +def split [n] [m] 't (xs: [n + m]t) : ([n]t, [m]t) = + (xs[0:n], xs[n:n + m] :> [m]t) -- | Return the elements of the array in reverse order. -- -- **Complexity:** O(1). #[inline] -def reverse [n] 't (x: [n]t): [n]t = x[::-1] +def reverse [n] 't (x: [n]t) : [n]t = x[::-1] -- | Concatenate two arrays. Warning: never try to perform a reduction -- with this operator; it will not work. @@ -77,11 +77,11 @@ def reverse [n] 't (x: [n]t): [n]t = x[::-1] -- -- **Span:** O(1). #[inline] -def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = intrinsics.concat xs ys +def (++) [n] [m] 't (xs: [n]t) (ys: [m]t) : *[n + m]t = intrinsics.concat xs ys -- | An old-fashioned way of saying `++`. #[inline] -def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = xs ++ ys +def concat [n] [m] 't (xs: [n]t) (ys: [m]t) : *[n + m]t = xs ++ ys -- | Construct an array of consecutive integers of the given length, -- starting at 0. @@ -90,7 +90,7 @@ def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = xs ++ ys -- -- **Span:** O(1). #[inline] -def iota (n: i64): *[n]i64 = +def iota (n: i64) : *[n]i64 = 0..1.. #[unsafe] a[(i+r)%n]) (iota n) + map (\i -> #[unsafe] a[(i + r) % n]) (iota n) -- | Construct an array of the given length containing the given -- value. @@ -125,7 +125,7 @@ def rotate [n] 't (r: i64) (a: [n]t) = -- -- **Span:** O(1). #[inline] -def replicate 't (n: i64) (x: t): *[n]t = +def replicate 't (n: i64) (x: t) : *[n]t = map (const x) (iota n) -- | Construct an array of an inferred length containing the given @@ -135,7 +135,7 @@ def replicate 't (n: i64) (x: t): *[n]t = -- -- **Span:** O(1). #[inline] -def rep 't [n] (x: t): *[n]t = +def rep 't [n] (x: t) : *[n]t = replicate n x -- | Copy a value. The result will not alias anything. @@ -144,7 +144,7 @@ def rep 't [n] (x: t): *[n]t = -- -- **Span:** O(1). #[inline] -def copy 't (a: t): *t = +def copy 't (a: t) : *t = ([a])[0] -- | Copy a value. The result will not alias anything. Additionally, @@ -156,48 +156,48 @@ def copy 't (a: t): *t = -- -- **Span:** O(1). #[inline] -def manifest 't (a: t): *t = +def manifest 't (a: t) : *t = intrinsics.manifest a -- | Combines the outer two dimensions of an array. -- -- **Complexity:** O(1). #[inline] -def flatten [n][m] 't (xs: [n][m]t): [n*m]t = +def flatten [n] [m] 't (xs: [n][m]t) : [n * m]t = intrinsics.flatten xs -- | Like `flatten`, but on the outer three dimensions of an array. #[inline] -def flatten_3d [n][m][l] 't (xs: [n][m][l]t): [n*m*l]t = +def flatten_3d [n] [m] [l] 't (xs: [n][m][l]t) : [n * m * l]t = flatten (flatten xs) -- | Like `flatten`, but on the outer four dimensions of an array. #[inline] -def flatten_4d [n][m][l][k] 't (xs: [n][m][l][k]t): [n*m*l*k]t = +def flatten_4d [n] [m] [l] [k] 't (xs: [n][m][l][k]t) : [n * m * l * k]t = flatten (flatten_3d xs) -- | Splits the outer dimension of an array in two. -- -- **Complexity:** O(1). #[inline] -def unflatten 't [n][m] (xs: [n*m]t): [n][m]t = +def unflatten 't [n] [m] (xs: [n * m]t) : [n][m]t = intrinsics.unflatten n m xs -- | Like `unflatten`, but produces three dimensions. #[inline] -def unflatten_3d 't [n][m][l] (xs: [n*m*l]t): [n][m][l]t = +def unflatten_3d 't [n] [m] [l] (xs: [n * m * l]t) : [n][m][l]t = unflatten (unflatten xs) -- | Like `unflatten`, but produces four dimensions. #[inline] -def unflatten_4d 't [n][m][l][k] (xs: [n*m*l*k]t): [n][m][l][k]t = +def unflatten_4d 't [n] [m] [l] [k] (xs: [n * m * l * k]t) : [n][m][l][k]t = unflatten (unflatten_3d xs) -- | Transpose an array. -- -- **Complexity:** O(1). #[inline] -def transpose [n] [m] 't (a: [n][m]t): [m][n]t = +def transpose [n] [m] 't (a: [n][m]t) : [m][n]t = intrinsics.transpose a -- | True if all of the input elements are true. Produces true on an @@ -221,7 +221,7 @@ def or [n] (xs: [n]bool) = any id xs -- **Work:** O(n ✕ W(f))). -- -- **Span:** O(n ✕ S(f)). -def foldl [n] 'a 'b (f: a -> b -> a) (acc: a) (bs: [n]b): a = +def foldl [n] 'a 'b (f: a -> b -> a) (acc: a) (bs: [n]b) : a = loop acc for b in bs do f acc b -- | Perform a *sequential* right-fold of an array. @@ -229,7 +229,7 @@ def foldl [n] 'a 'b (f: a -> b -> a) (acc: a) (bs: [n]b): a = -- **Work:** O(n ✕ W(f))). -- -- **Span:** O(n ✕ S(f)). -def foldr [n] 'a 'b (f: b -> a -> a) (acc: a) (bs: [n]b): a = +def foldr [n] 'a 'b (f: b -> a -> a) (acc: a) (bs: [n]b) : a = foldl (flip f) acc (reverse bs) -- | Create a value for each point in a one-dimensional index space. @@ -237,7 +237,7 @@ def foldr [n] 'a 'b (f: b -> a -> a) (acc: a) (bs: [n]b): a = -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(S(f))* -def tabulate 'a (n: i64) (f: i64 -> a): *[n]a = +def tabulate 'a (n: i64) (f: i64 -> a) : *[n]a = map1 f (iota n) -- | Create a value for each point in a two-dimensional index space. @@ -245,7 +245,7 @@ def tabulate 'a (n: i64) (f: i64 -> a): *[n]a = -- **Work:** *O(n ✕ m ✕ W(f))* -- -- **Span:** *O(S(f))* -def tabulate_2d 'a (n: i64) (m: i64) (f: i64 -> i64 -> a): *[n][m]a = +def tabulate_2d 'a (n: i64) (m: i64) (f: i64 -> i64 -> a) : *[n][m]a = map1 (f >-> tabulate m) (iota n) -- | Create a value for each point in a three-dimensional index space. @@ -253,5 +253,5 @@ def tabulate_2d 'a (n: i64) (m: i64) (f: i64 -> i64 -> a): *[n][m]a = -- **Work:** *O(n ✕ m ✕ o ✕ W(f))* -- -- **Span:** *O(S(f))* -def tabulate_3d 'a (n: i64) (m: i64) (o: i64) (f: i64 -> i64 -> i64 -> a): *[n][m][o]a = +def tabulate_3d 'a (n: i64) (m: i64) (o: i64) (f: i64 -> i64 -> i64 -> a) : *[n][m][o]a = map1 (f >-> tabulate_2d m o) (iota n) diff --git a/prelude/functional.fut b/prelude/functional.fut index d34018412a..a44e9dac3f 100644 --- a/prelude/functional.fut +++ b/prelude/functional.fut @@ -23,7 +23,6 @@ def (|>) '^a '^b (x: a) (f: a -> b): b = f x -- ``` -- filter (>0) [-1,0,1] |> length -- ``` - def (<|) '^a '^b (f: a -> b) (x: a) = f x -- | Function composition, with values flowing from left to right. @@ -76,7 +75,7 @@ def iterate 'a (n: i32) (f: a -> a) (x: a) = -- | Keep applying `f` until `p` returns true for the input value. -- May apply zero times. *Note*: may not terminate. def iterate_until 'a (p: a -> bool) (f: a -> a) (x: a) = - loop x while ! (p x) do f x + loop x while !(p x) do f x -- | Keep applying `f` while `p` returns true for the input value. -- May apply zero times. *Note*: may not terminate. diff --git a/prelude/math.fut b/prelude/math.fut index c960e57d9c..cc20f491d1 100644 --- a/prelude/math.fut +++ b/prelude/math.fut @@ -71,6 +71,7 @@ module type numeric = { -- | Returns `lowest` on empty input. val maximum [n]: [n]t -> t + -- | Returns `highest` on empty input. val minimum [n]: [n]t -> t } @@ -83,14 +84,17 @@ module type integral = { -- | Like `/`@term, but rounds towards zero. This only matters when -- one of the operands is negative. May be more efficient. val //: t -> t -> t + -- | Like `%`@term, but rounds towards zero. This only matters when -- one of the operands is negative. May be more efficient. val %%: t -> t -> t -- | Bitwise and. val &: t -> t -> t + -- | Bitwise or. val |: t -> t -> t + -- | Bitwise xor. val ^: t -> t -> t @@ -99,8 +103,10 @@ module type integral = { -- | Left shift; inserting zeroes. val <<: t -> t -> t + -- | Arithmetic right shift, using sign extension for the leftmost bits. val >>: t -> t -> t + -- | Logical right shift, inserting zeroes for the leftmost bits. val >>>: t -> t -> t @@ -143,6 +149,7 @@ module type real = { -- | Square root. val sqrt: t -> t + -- | Cube root. val cbrt: t -> t val exp: t -> t @@ -175,13 +182,15 @@ module type real = { -- | The true Gamma function. val gamma: t -> t + -- | The natural logarithm of the absolute value of `gamma`@term. val lgamma: t -> t -- | The error function. - val erf : t -> t + val erf: t -> t + -- | The complementary error function. - val erfc : t -> t + val erfc: t -> t -- | Linear interpolation. The third argument must be in the range -- `[0,1]` or the results are unspecified. @@ -189,33 +198,39 @@ module type real = { -- | Natural logarithm. val log: t -> t + -- | Base-2 logarithm. val log2: t -> t + -- | Base-10 logarithm. val log10: t -> t + -- | Compute `log (1 + x)` accurately even when `x` is very small. val log1p: t -> t -- | Round towards infinity. - val ceil : t -> t + val ceil: t -> t + -- | Round towards negative infinity. - val floor : t -> t + val floor: t -> t + -- | Round towards zero. - val trunc : t -> t + val trunc: t -> t + -- | Round to the nearest integer, with halfway cases rounded to the -- nearest even integer. Note that this differs from `round()` in -- C, but matches more modern languages. - val round : t -> t + val round: t -> t -- | Computes `a*b+c`. Depending on the compiler backend, this may -- be fused into a single operation that is faster but less -- accurate. Do not confuse it with `fma`@term. - val mad : (a: t) -> (b: t) -> (c: t) -> t + val mad: (a: t) -> (b: t) -> (c: t) -> t -- | Computes `a*b+c`, with `a*b` being rounded with infinite -- precision. Rounding of intermediate products shall not -- occur. Edge case behavior is per the IEEE 754-2008 standard. - val fma : (a: t) -> (b: t) -> (c: t) -> t + val fma: (a: t) -> (b: t) -> (c: t) -> t val isinf: t -> bool val isnan: t -> bool @@ -253,13 +268,13 @@ module type float = { -- | Produces the next representable number from `x` in the -- direction of `y`. - val nextafter : (x: t) -> (y: t) -> t + val nextafter: (x: t) -> (y: t) -> t -- | Multiplies floating-point value by 2 raised to an integer power. - val ldexp : t -> i32 -> t + val ldexp: t -> i32 -> t -- | Compose a floating-point value with the magnitude of `x` and the sign of `y`. - val copysign : (x: t) -> (y: t) -> t + val copysign: (x: t) -> (y: t) -> t } -- | Boolean numbers. When converting from a number to `bool`, 0 is @@ -267,12 +282,12 @@ module type float = { module bool: from_prim with t = bool = { type t = bool - def i8 = intrinsics.itob_i8_bool + def i8 = intrinsics.itob_i8_bool def i16 = intrinsics.itob_i16_bool def i32 = intrinsics.itob_i32_bool def i64 = intrinsics.itob_i64_bool - def u8 (x: u8) = intrinsics.itob_i8_bool (intrinsics.sign_i8 x) + def u8 (x: u8) = intrinsics.itob_i8_bool (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.itob_i16_bool (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.itob_i32_bool (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.itob_i64_bool (intrinsics.sign_i64 x) @@ -287,30 +302,30 @@ module bool: from_prim with t = bool = { module i8: (integral with t = i8) = { type t = i8 - def (x: i8) + (y: i8) = intrinsics.add8 (x, y) - def (x: i8) - (y: i8) = intrinsics.sub8 (x, y) - def (x: i8) * (y: i8) = intrinsics.mul8 (x, y) - def (x: i8) / (y: i8) = intrinsics.sdiv8 (x, y) - def (x: i8) ** (y: i8) = intrinsics.pow8 (x, y) - def (x: i8) % (y: i8) = intrinsics.smod8 (x, y) - def (x: i8) // (y: i8) = intrinsics.squot8 (x, y) - def (x: i8) %% (y: i8) = intrinsics.srem8 (x, y) - - def (x: i8) & (y: i8) = intrinsics.and8 (x, y) - def (x: i8) | (y: i8) = intrinsics.or8 (x, y) - def (x: i8) ^ (y: i8) = intrinsics.xor8 (x, y) + def (+) (x: i8) (y: i8) = intrinsics.add8 (x, y) + def (-) (x: i8) (y: i8) = intrinsics.sub8 (x, y) + def (*) (x: i8) (y: i8) = intrinsics.mul8 (x, y) + def (/) (x: i8) (y: i8) = intrinsics.sdiv8 (x, y) + def (**) (x: i8) (y: i8) = intrinsics.pow8 (x, y) + def (%) (x: i8) (y: i8) = intrinsics.smod8 (x, y) + def (//) (x: i8) (y: i8) = intrinsics.squot8 (x, y) + def (%%) (x: i8) (y: i8) = intrinsics.srem8 (x, y) + + def (&) (x: i8) (y: i8) = intrinsics.and8 (x, y) + def (|) (x: i8) (y: i8) = intrinsics.or8 (x, y) + def (^) (x: i8) (y: i8) = intrinsics.xor8 (x, y) def not (x: i8) = intrinsics.complement8 x - def (x: i8) << (y: i8) = intrinsics.shl8 (x, y) - def (x: i8) >> (y: i8) = intrinsics.ashr8 (x, y) - def (x: i8) >>> (y: i8) = intrinsics.lshr8 (x, y) + def (<<) (x: i8) (y: i8) = intrinsics.shl8 (x, y) + def (>>) (x: i8) (y: i8) = intrinsics.ashr8 (x, y) + def (>>>) (x: i8) (y: i8) = intrinsics.lshr8 (x, y) - def i8 (x: i8) = intrinsics.sext_i8_i8 x + def i8 (x: i8) = intrinsics.sext_i8_i8 x def i16 (x: i16) = intrinsics.sext_i16_i8 x def i32 (x: i32) = intrinsics.sext_i32_i8 x def i64 (x: i64) = intrinsics.sext_i64_i8 x - def u8 (x: u8) = intrinsics.zext_i8_i8 (intrinsics.sign_i8 x) + def u8 (x: u8) = intrinsics.zext_i8_i8 (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.zext_i16_i8 (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.zext_i32_i8 (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.zext_i64_i8 (intrinsics.sign_i64 x) @@ -321,15 +336,15 @@ module i8: (integral with t = i8) = { def bool = intrinsics.btoi_bool_i8 - def to_i32(x: i8) = intrinsics.sext_i8_i32 x - def to_i64(x: i8) = intrinsics.sext_i8_i64 x + def to_i32 (x: i8) = intrinsics.sext_i8_i32 x + def to_i64 (x: i8) = intrinsics.sext_i8_i64 x - def (x: i8) == (y: i8) = intrinsics.eq_i8 (x, y) - def (x: i8) < (y: i8) = intrinsics.slt8 (x, y) - def (x: i8) > (y: i8) = intrinsics.slt8 (y, x) - def (x: i8) <= (y: i8) = intrinsics.sle8 (x, y) - def (x: i8) >= (y: i8) = intrinsics.sle8 (y, x) - def (x: i8) != (y: i8) = !(x == y) + def (==) (x: i8) (y: i8) = intrinsics.eq_i8 (x, y) + def (<) (x: i8) (y: i8) = intrinsics.slt8 (x, y) + def (>) (x: i8) (y: i8) = intrinsics.slt8 (y, x) + def (<=) (x: i8) (y: i8) = intrinsics.sle8 (x, y) + def (>=) (x: i8) (y: i8) = intrinsics.sle8 (y, x) + def (!=) (x: i8) (y: i8) = !(x == y) def sgn (x: i8) = intrinsics.ssignum8 x def abs (x: i8) = intrinsics.abs8 x @@ -343,8 +358,10 @@ module i8: (integral with t = i8) = { def num_bits = 8i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) + def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) + def popc = intrinsics.popc8 def mul_hi a b = intrinsics.smul_hi8 (i8 a, i8 b) def mad_hi a b c = intrinsics.smad_hi8 (i8 a, i8 b, i8 c) @@ -360,30 +377,30 @@ module i8: (integral with t = i8) = { module i16: (integral with t = i16) = { type t = i16 - def (x: i16) + (y: i16) = intrinsics.add16 (x, y) - def (x: i16) - (y: i16) = intrinsics.sub16 (x, y) - def (x: i16) * (y: i16) = intrinsics.mul16 (x, y) - def (x: i16) / (y: i16) = intrinsics.sdiv16 (x, y) - def (x: i16) ** (y: i16) = intrinsics.pow16 (x, y) - def (x: i16) % (y: i16) = intrinsics.smod16 (x, y) - def (x: i16) // (y: i16) = intrinsics.squot16 (x, y) - def (x: i16) %% (y: i16) = intrinsics.srem16 (x, y) - - def (x: i16) & (y: i16) = intrinsics.and16 (x, y) - def (x: i16) | (y: i16) = intrinsics.or16 (x, y) - def (x: i16) ^ (y: i16) = intrinsics.xor16 (x, y) + def (+) (x: i16) (y: i16) = intrinsics.add16 (x, y) + def (-) (x: i16) (y: i16) = intrinsics.sub16 (x, y) + def (*) (x: i16) (y: i16) = intrinsics.mul16 (x, y) + def (/) (x: i16) (y: i16) = intrinsics.sdiv16 (x, y) + def (**) (x: i16) (y: i16) = intrinsics.pow16 (x, y) + def (%) (x: i16) (y: i16) = intrinsics.smod16 (x, y) + def (//) (x: i16) (y: i16) = intrinsics.squot16 (x, y) + def (%%) (x: i16) (y: i16) = intrinsics.srem16 (x, y) + + def (&) (x: i16) (y: i16) = intrinsics.and16 (x, y) + def (|) (x: i16) (y: i16) = intrinsics.or16 (x, y) + def (^) (x: i16) (y: i16) = intrinsics.xor16 (x, y) def not (x: i16) = intrinsics.complement16 x - def (x: i16) << (y: i16) = intrinsics.shl16 (x, y) - def (x: i16) >> (y: i16) = intrinsics.ashr16 (x, y) - def (x: i16) >>> (y: i16) = intrinsics.lshr16 (x, y) + def (<<) (x: i16) (y: i16) = intrinsics.shl16 (x, y) + def (>>) (x: i16) (y: i16) = intrinsics.ashr16 (x, y) + def (>>>) (x: i16) (y: i16) = intrinsics.lshr16 (x, y) - def i8 (x: i8) = intrinsics.sext_i8_i16 x + def i8 (x: i8) = intrinsics.sext_i8_i16 x def i16 (x: i16) = intrinsics.sext_i16_i16 x def i32 (x: i32) = intrinsics.sext_i32_i16 x def i64 (x: i64) = intrinsics.sext_i64_i16 x - def u8 (x: u8) = intrinsics.zext_i8_i16 (intrinsics.sign_i8 x) + def u8 (x: u8) = intrinsics.zext_i8_i16 (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.zext_i16_i16 (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.zext_i32_i16 (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.zext_i64_i16 (intrinsics.sign_i64 x) @@ -394,15 +411,15 @@ module i16: (integral with t = i16) = { def bool = intrinsics.btoi_bool_i16 - def to_i32(x: i16) = intrinsics.sext_i16_i32 x - def to_i64(x: i16) = intrinsics.sext_i16_i64 x + def to_i32 (x: i16) = intrinsics.sext_i16_i32 x + def to_i64 (x: i16) = intrinsics.sext_i16_i64 x - def (x: i16) == (y: i16) = intrinsics.eq_i16 (x, y) - def (x: i16) < (y: i16) = intrinsics.slt16 (x, y) - def (x: i16) > (y: i16) = intrinsics.slt16 (y, x) - def (x: i16) <= (y: i16) = intrinsics.sle16 (x, y) - def (x: i16) >= (y: i16) = intrinsics.sle16 (y, x) - def (x: i16) != (y: i16) = !(x == y) + def (==) (x: i16) (y: i16) = intrinsics.eq_i16 (x, y) + def (<) (x: i16) (y: i16) = intrinsics.slt16 (x, y) + def (>) (x: i16) (y: i16) = intrinsics.slt16 (y, x) + def (<=) (x: i16) (y: i16) = intrinsics.sle16 (x, y) + def (>=) (x: i16) (y: i16) = intrinsics.sle16 (y, x) + def (!=) (x: i16) (y: i16) = !(x == y) def sgn (x: i16) = intrinsics.ssignum16 x def abs (x: i16) = intrinsics.abs16 x @@ -416,8 +433,10 @@ module i16: (integral with t = i16) = { def num_bits = 16i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) + def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) + def popc = intrinsics.popc16 def mul_hi a b = intrinsics.smul_hi16 (i16 a, i16 b) def mad_hi a b c = intrinsics.smad_hi16 (i16 a, i16 b, i16 c) @@ -436,30 +455,30 @@ module i32: (integral with t = i32) = { def sign (x: u32) = intrinsics.sign_i32 x def unsign (x: i32) = intrinsics.unsign_i32 x - def (x: i32) + (y: i32) = intrinsics.add32 (x, y) - def (x: i32) - (y: i32) = intrinsics.sub32 (x, y) - def (x: i32) * (y: i32) = intrinsics.mul32 (x, y) - def (x: i32) / (y: i32) = intrinsics.sdiv32 (x, y) - def (x: i32) ** (y: i32) = intrinsics.pow32 (x, y) - def (x: i32) % (y: i32) = intrinsics.smod32 (x, y) - def (x: i32) // (y: i32) = intrinsics.squot32 (x, y) - def (x: i32) %% (y: i32) = intrinsics.srem32 (x, y) - - def (x: i32) & (y: i32) = intrinsics.and32 (x, y) - def (x: i32) | (y: i32) = intrinsics.or32 (x, y) - def (x: i32) ^ (y: i32) = intrinsics.xor32 (x, y) + def (+) (x: i32) (y: i32) = intrinsics.add32 (x, y) + def (-) (x: i32) (y: i32) = intrinsics.sub32 (x, y) + def (*) (x: i32) (y: i32) = intrinsics.mul32 (x, y) + def (/) (x: i32) (y: i32) = intrinsics.sdiv32 (x, y) + def (**) (x: i32) (y: i32) = intrinsics.pow32 (x, y) + def (%) (x: i32) (y: i32) = intrinsics.smod32 (x, y) + def (//) (x: i32) (y: i32) = intrinsics.squot32 (x, y) + def (%%) (x: i32) (y: i32) = intrinsics.srem32 (x, y) + + def (&) (x: i32) (y: i32) = intrinsics.and32 (x, y) + def (|) (x: i32) (y: i32) = intrinsics.or32 (x, y) + def (^) (x: i32) (y: i32) = intrinsics.xor32 (x, y) def not (x: i32) = intrinsics.complement32 x - def (x: i32) << (y: i32) = intrinsics.shl32 (x, y) - def (x: i32) >> (y: i32) = intrinsics.ashr32 (x, y) - def (x: i32) >>> (y: i32) = intrinsics.lshr32 (x, y) + def (<<) (x: i32) (y: i32) = intrinsics.shl32 (x, y) + def (>>) (x: i32) (y: i32) = intrinsics.ashr32 (x, y) + def (>>>) (x: i32) (y: i32) = intrinsics.lshr32 (x, y) - def i8 (x: i8) = intrinsics.sext_i8_i32 x + def i8 (x: i8) = intrinsics.sext_i8_i32 x def i16 (x: i16) = intrinsics.sext_i16_i32 x def i32 (x: i32) = intrinsics.sext_i32_i32 x def i64 (x: i64) = intrinsics.sext_i64_i32 x - def u8 (x: u8) = intrinsics.zext_i8_i32 (intrinsics.sign_i8 x) + def u8 (x: u8) = intrinsics.zext_i8_i32 (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.zext_i16_i32 (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.zext_i32_i32 (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.zext_i64_i32 (intrinsics.sign_i64 x) @@ -470,15 +489,15 @@ module i32: (integral with t = i32) = { def bool = intrinsics.btoi_bool_i32 - def to_i32(x: i32) = intrinsics.sext_i32_i32 x - def to_i64(x: i32) = intrinsics.sext_i32_i64 x + def to_i32 (x: i32) = intrinsics.sext_i32_i32 x + def to_i64 (x: i32) = intrinsics.sext_i32_i64 x - def (x: i32) == (y: i32) = intrinsics.eq_i32 (x, y) - def (x: i32) < (y: i32) = intrinsics.slt32 (x, y) - def (x: i32) > (y: i32) = intrinsics.slt32 (y, x) - def (x: i32) <= (y: i32) = intrinsics.sle32 (x, y) - def (x: i32) >= (y: i32) = intrinsics.sle32 (y, x) - def (x: i32) != (y: i32) = !(x == y) + def (==) (x: i32) (y: i32) = intrinsics.eq_i32 (x, y) + def (<) (x: i32) (y: i32) = intrinsics.slt32 (x, y) + def (>) (x: i32) (y: i32) = intrinsics.slt32 (y, x) + def (<=) (x: i32) (y: i32) = intrinsics.sle32 (x, y) + def (>=) (x: i32) (y: i32) = intrinsics.sle32 (y, x) + def (!=) (x: i32) (y: i32) = !(x == y) def sgn (x: i32) = intrinsics.ssignum32 x def abs (x: i32) = intrinsics.abs32 x @@ -492,8 +511,10 @@ module i32: (integral with t = i32) = { def num_bits = 32i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) + def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) + def popc = intrinsics.popc32 def mul_hi a b = intrinsics.smul_hi32 (i32 a, i32 b) def mad_hi a b c = intrinsics.smad_hi32 (i32 a, i32 b, i32 c) @@ -512,30 +533,30 @@ module i64: (integral with t = i64) = { def sign (x: u64) = intrinsics.sign_i64 x def unsign (x: i64) = intrinsics.unsign_i64 x - def (x: i64) + (y: i64) = intrinsics.add64 (x, y) - def (x: i64) - (y: i64) = intrinsics.sub64 (x, y) - def (x: i64) * (y: i64) = intrinsics.mul64 (x, y) - def (x: i64) / (y: i64) = intrinsics.sdiv64 (x, y) - def (x: i64) ** (y: i64) = intrinsics.pow64 (x, y) - def (x: i64) % (y: i64) = intrinsics.smod64 (x, y) - def (x: i64) // (y: i64) = intrinsics.squot64 (x, y) - def (x: i64) %% (y: i64) = intrinsics.srem64 (x, y) - - def (x: i64) & (y: i64) = intrinsics.and64 (x, y) - def (x: i64) | (y: i64) = intrinsics.or64 (x, y) - def (x: i64) ^ (y: i64) = intrinsics.xor64 (x, y) + def (+) (x: i64) (y: i64) = intrinsics.add64 (x, y) + def (-) (x: i64) (y: i64) = intrinsics.sub64 (x, y) + def (*) (x: i64) (y: i64) = intrinsics.mul64 (x, y) + def (/) (x: i64) (y: i64) = intrinsics.sdiv64 (x, y) + def (**) (x: i64) (y: i64) = intrinsics.pow64 (x, y) + def (%) (x: i64) (y: i64) = intrinsics.smod64 (x, y) + def (//) (x: i64) (y: i64) = intrinsics.squot64 (x, y) + def (%%) (x: i64) (y: i64) = intrinsics.srem64 (x, y) + + def (&) (x: i64) (y: i64) = intrinsics.and64 (x, y) + def (|) (x: i64) (y: i64) = intrinsics.or64 (x, y) + def (^) (x: i64) (y: i64) = intrinsics.xor64 (x, y) def not (x: i64) = intrinsics.complement64 x - def (x: i64) << (y: i64) = intrinsics.shl64 (x, y) - def (x: i64) >> (y: i64) = intrinsics.ashr64 (x, y) - def (x: i64) >>> (y: i64) = intrinsics.lshr64 (x, y) + def (<<) (x: i64) (y: i64) = intrinsics.shl64 (x, y) + def (>>) (x: i64) (y: i64) = intrinsics.ashr64 (x, y) + def (>>>) (x: i64) (y: i64) = intrinsics.lshr64 (x, y) - def i8 (x: i8) = intrinsics.sext_i8_i64 x + def i8 (x: i8) = intrinsics.sext_i8_i64 x def i16 (x: i16) = intrinsics.sext_i16_i64 x def i32 (x: i32) = intrinsics.sext_i32_i64 x def i64 (x: i64) = intrinsics.sext_i64_i64 x - def u8 (x: u8) = intrinsics.zext_i8_i64 (intrinsics.sign_i8 x) + def u8 (x: u8) = intrinsics.zext_i8_i64 (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.zext_i16_i64 (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.zext_i32_i64 (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.zext_i64_i64 (intrinsics.sign_i64 x) @@ -546,15 +567,15 @@ module i64: (integral with t = i64) = { def bool = intrinsics.btoi_bool_i64 - def to_i32(x: i64) = intrinsics.sext_i64_i32 x - def to_i64(x: i64) = intrinsics.sext_i64_i64 x + def to_i32 (x: i64) = intrinsics.sext_i64_i32 x + def to_i64 (x: i64) = intrinsics.sext_i64_i64 x - def (x: i64) == (y: i64) = intrinsics.eq_i64 (x, y) - def (x: i64) < (y: i64) = intrinsics.slt64 (x, y) - def (x: i64) > (y: i64) = intrinsics.slt64 (y, x) - def (x: i64) <= (y: i64) = intrinsics.sle64 (x, y) - def (x: i64) >= (y: i64) = intrinsics.sle64 (y, x) - def (x: i64) != (y: i64) = !(x == y) + def (==) (x: i64) (y: i64) = intrinsics.eq_i64 (x, y) + def (<) (x: i64) (y: i64) = intrinsics.slt64 (x, y) + def (>) (x: i64) (y: i64) = intrinsics.slt64 (y, x) + def (<=) (x: i64) (y: i64) = intrinsics.sle64 (x, y) + def (>=) (x: i64) (y: i64) = intrinsics.sle64 (y, x) + def (!=) (x: i64) (y: i64) = !(x == y) def sgn (x: i64) = intrinsics.ssignum64 x def abs (x: i64) = intrinsics.abs64 x @@ -568,8 +589,10 @@ module i64: (integral with t = i64) = { def num_bits = 64i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) + def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | intrinsics.zext_i32_i64 (b intrinsics.<< bit)) + def popc = intrinsics.popc64 def mul_hi a b = intrinsics.smul_hi64 (i64 a, i64 b) def mad_hi a b c = intrinsics.smad_hi64 (i64 a, i64 b, i64 c) @@ -588,30 +611,30 @@ module u8: (integral with t = u8) = { def sign (x: u8) = intrinsics.sign_i8 x def unsign (x: i8) = intrinsics.unsign_i8 x - def (x: u8) + (y: u8) = unsign (intrinsics.add8 (sign x, sign y)) - def (x: u8) - (y: u8) = unsign (intrinsics.sub8 (sign x, sign y)) - def (x: u8) * (y: u8) = unsign (intrinsics.mul8 (sign x, sign y)) - def (x: u8) / (y: u8) = unsign (intrinsics.udiv8 (sign x, sign y)) - def (x: u8) ** (y: u8) = unsign (intrinsics.pow8 (sign x, sign y)) - def (x: u8) % (y: u8) = unsign (intrinsics.umod8 (sign x, sign y)) - def (x: u8) // (y: u8) = unsign (intrinsics.udiv8 (sign x, sign y)) - def (x: u8) %% (y: u8) = unsign (intrinsics.umod8 (sign x, sign y)) - - def (x: u8) & (y: u8) = unsign (intrinsics.and8 (sign x, sign y)) - def (x: u8) | (y: u8) = unsign (intrinsics.or8 (sign x, sign y)) - def (x: u8) ^ (y: u8) = unsign (intrinsics.xor8 (sign x, sign y)) + def (+) (x: u8) (y: u8) = unsign (intrinsics.add8 (sign x, sign y)) + def (-) (x: u8) (y: u8) = unsign (intrinsics.sub8 (sign x, sign y)) + def (*) (x: u8) (y: u8) = unsign (intrinsics.mul8 (sign x, sign y)) + def (/) (x: u8) (y: u8) = unsign (intrinsics.udiv8 (sign x, sign y)) + def (**) (x: u8) (y: u8) = unsign (intrinsics.pow8 (sign x, sign y)) + def (%) (x: u8) (y: u8) = unsign (intrinsics.umod8 (sign x, sign y)) + def (//) (x: u8) (y: u8) = unsign (intrinsics.udiv8 (sign x, sign y)) + def (%%) (x: u8) (y: u8) = unsign (intrinsics.umod8 (sign x, sign y)) + + def (&) (x: u8) (y: u8) = unsign (intrinsics.and8 (sign x, sign y)) + def (|) (x: u8) (y: u8) = unsign (intrinsics.or8 (sign x, sign y)) + def (^) (x: u8) (y: u8) = unsign (intrinsics.xor8 (sign x, sign y)) def not (x: u8) = unsign (intrinsics.complement8 (sign x)) - def (x: u8) << (y: u8) = unsign (intrinsics.shl8 (sign x, sign y)) - def (x: u8) >> (y: u8) = unsign (intrinsics.ashr8 (sign x, sign y)) - def (x: u8) >>> (y: u8) = unsign (intrinsics.lshr8 (sign x, sign y)) + def (<<) (x: u8) (y: u8) = unsign (intrinsics.shl8 (sign x, sign y)) + def (>>) (x: u8) (y: u8) = unsign (intrinsics.ashr8 (sign x, sign y)) + def (>>>) (x: u8) (y: u8) = unsign (intrinsics.lshr8 (sign x, sign y)) - def u8 (x: u8) = unsign (i8.u8 x) + def u8 (x: u8) = unsign (i8.u8 x) def u16 (x: u16) = unsign (i8.u16 x) def u32 (x: u32) = unsign (i8.u32 x) def u64 (x: u64) = unsign (i8.u64 x) - def i8 (x: i8) = unsign (intrinsics.zext_i8_i8 x) + def i8 (x: i8) = unsign (intrinsics.zext_i8_i8 x) def i16 (x: i16) = unsign (intrinsics.zext_i16_i8 x) def i32 (x: i32) = unsign (intrinsics.zext_i32_i8 x) def i64 (x: i64) = unsign (intrinsics.zext_i64_i8 x) @@ -622,15 +645,15 @@ module u8: (integral with t = u8) = { def bool x = unsign (intrinsics.btoi_bool_i8 x) - def to_i32(x: u8) = intrinsics.zext_i8_i32 (sign x) - def to_i64(x: u8) = intrinsics.zext_i8_i64 (sign x) + def to_i32 (x: u8) = intrinsics.zext_i8_i32 (sign x) + def to_i64 (x: u8) = intrinsics.zext_i8_i64 (sign x) - def (x: u8) == (y: u8) = intrinsics.eq_i8 (sign x, sign y) - def (x: u8) < (y: u8) = intrinsics.ult8 (sign x, sign y) - def (x: u8) > (y: u8) = intrinsics.ult8 (sign y, sign x) - def (x: u8) <= (y: u8) = intrinsics.ule8 (sign x, sign y) - def (x: u8) >= (y: u8) = intrinsics.ule8 (sign y, sign x) - def (x: u8) != (y: u8) = !(x == y) + def (==) (x: u8) (y: u8) = intrinsics.eq_i8 (sign x, sign y) + def (<) (x: u8) (y: u8) = intrinsics.ult8 (sign x, sign y) + def (>) (x: u8) (y: u8) = intrinsics.ult8 (sign y, sign x) + def (<=) (x: u8) (y: u8) = intrinsics.ule8 (sign x, sign y) + def (>=) (x: u8) (y: u8) = intrinsics.ule8 (sign y, sign x) + def (!=) (x: u8) (y: u8) = !(x == y) def sgn (x: u8) = unsign (intrinsics.usignum8 (sign x)) def abs (x: u8) = x @@ -644,8 +667,10 @@ module u8: (integral with t = u8) = { def num_bits = 8i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) + def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) + def popc x = intrinsics.popc8 (sign x) def mul_hi a b = unsign (intrinsics.umul_hi8 (sign a, sign b)) def mad_hi a b c = unsign (intrinsics.umad_hi8 (sign a, sign b, sign c)) @@ -664,30 +689,30 @@ module u16: (integral with t = u16) = { def sign (x: u16) = intrinsics.sign_i16 x def unsign (x: i16) = intrinsics.unsign_i16 x - def (x: u16) + (y: u16) = unsign (intrinsics.add16 (sign x, sign y)) - def (x: u16) - (y: u16) = unsign (intrinsics.sub16 (sign x, sign y)) - def (x: u16) * (y: u16) = unsign (intrinsics.mul16 (sign x, sign y)) - def (x: u16) / (y: u16) = unsign (intrinsics.udiv16 (sign x, sign y)) - def (x: u16) ** (y: u16) = unsign (intrinsics.pow16 (sign x, sign y)) - def (x: u16) % (y: u16) = unsign (intrinsics.umod16 (sign x, sign y)) - def (x: u16) // (y: u16) = unsign (intrinsics.udiv16 (sign x, sign y)) - def (x: u16) %% (y: u16) = unsign (intrinsics.umod16 (sign x, sign y)) - - def (x: u16) & (y: u16) = unsign (intrinsics.and16 (sign x, sign y)) - def (x: u16) | (y: u16) = unsign (intrinsics.or16 (sign x, sign y)) - def (x: u16) ^ (y: u16) = unsign (intrinsics.xor16 (sign x, sign y)) + def (+) (x: u16) (y: u16) = unsign (intrinsics.add16 (sign x, sign y)) + def (-) (x: u16) (y: u16) = unsign (intrinsics.sub16 (sign x, sign y)) + def (*) (x: u16) (y: u16) = unsign (intrinsics.mul16 (sign x, sign y)) + def (/) (x: u16) (y: u16) = unsign (intrinsics.udiv16 (sign x, sign y)) + def (**) (x: u16) (y: u16) = unsign (intrinsics.pow16 (sign x, sign y)) + def (%) (x: u16) (y: u16) = unsign (intrinsics.umod16 (sign x, sign y)) + def (//) (x: u16) (y: u16) = unsign (intrinsics.udiv16 (sign x, sign y)) + def (%%) (x: u16) (y: u16) = unsign (intrinsics.umod16 (sign x, sign y)) + + def (&) (x: u16) (y: u16) = unsign (intrinsics.and16 (sign x, sign y)) + def (|) (x: u16) (y: u16) = unsign (intrinsics.or16 (sign x, sign y)) + def (^) (x: u16) (y: u16) = unsign (intrinsics.xor16 (sign x, sign y)) def not (x: u16) = unsign (intrinsics.complement16 (sign x)) - def (x: u16) << (y: u16) = unsign (intrinsics.shl16 (sign x, sign y)) - def (x: u16) >> (y: u16) = unsign (intrinsics.ashr16 (sign x, sign y)) - def (x: u16) >>> (y: u16) = unsign (intrinsics.lshr16 (sign x, sign y)) + def (<<) (x: u16) (y: u16) = unsign (intrinsics.shl16 (sign x, sign y)) + def (>>) (x: u16) (y: u16) = unsign (intrinsics.ashr16 (sign x, sign y)) + def (>>>) (x: u16) (y: u16) = unsign (intrinsics.lshr16 (sign x, sign y)) - def u8 (x: u8) = unsign (i16.u8 x) + def u8 (x: u8) = unsign (i16.u8 x) def u16 (x: u16) = unsign (i16.u16 x) def u32 (x: u32) = unsign (i16.u32 x) def u64 (x: u64) = unsign (i16.u64 x) - def i8 (x: i8) = unsign (intrinsics.zext_i8_i16 x) + def i8 (x: i8) = unsign (intrinsics.zext_i8_i16 x) def i16 (x: i16) = unsign (intrinsics.zext_i16_i16 x) def i32 (x: i32) = unsign (intrinsics.zext_i32_i16 x) def i64 (x: i64) = unsign (intrinsics.zext_i64_i16 x) @@ -698,15 +723,15 @@ module u16: (integral with t = u16) = { def bool x = unsign (intrinsics.btoi_bool_i16 x) - def to_i32(x: u16) = intrinsics.zext_i16_i32 (sign x) - def to_i64(x: u16) = intrinsics.zext_i16_i64 (sign x) + def to_i32 (x: u16) = intrinsics.zext_i16_i32 (sign x) + def to_i64 (x: u16) = intrinsics.zext_i16_i64 (sign x) - def (x: u16) == (y: u16) = intrinsics.eq_i16 (sign x, sign y) - def (x: u16) < (y: u16) = intrinsics.ult16 (sign x, sign y) - def (x: u16) > (y: u16) = intrinsics.ult16 (sign y, sign x) - def (x: u16) <= (y: u16) = intrinsics.ule16 (sign x, sign y) - def (x: u16) >= (y: u16) = intrinsics.ule16 (sign y, sign x) - def (x: u16) != (y: u16) = !(x == y) + def (==) (x: u16) (y: u16) = intrinsics.eq_i16 (sign x, sign y) + def (<) (x: u16) (y: u16) = intrinsics.ult16 (sign x, sign y) + def (>) (x: u16) (y: u16) = intrinsics.ult16 (sign y, sign x) + def (<=) (x: u16) (y: u16) = intrinsics.ule16 (sign x, sign y) + def (>=) (x: u16) (y: u16) = intrinsics.ule16 (sign y, sign x) + def (!=) (x: u16) (y: u16) = !(x == y) def sgn (x: u16) = unsign (intrinsics.usignum16 (sign x)) def abs (x: u16) = x @@ -720,8 +745,10 @@ module u16: (integral with t = u16) = { def num_bits = 16i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) + def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) + def popc x = intrinsics.popc16 (sign x) def mul_hi a b = unsign (intrinsics.umul_hi16 (sign a, sign b)) def mad_hi a b c = unsign (intrinsics.umad_hi16 (sign a, sign b, sign c)) @@ -740,30 +767,30 @@ module u32: (integral with t = u32) = { def sign (x: u32) = intrinsics.sign_i32 x def unsign (x: i32) = intrinsics.unsign_i32 x - def (x: u32) + (y: u32) = unsign (intrinsics.add32 (sign x, sign y)) - def (x: u32) - (y: u32) = unsign (intrinsics.sub32 (sign x, sign y)) - def (x: u32) * (y: u32) = unsign (intrinsics.mul32 (sign x, sign y)) - def (x: u32) / (y: u32) = unsign (intrinsics.udiv32 (sign x, sign y)) - def (x: u32) ** (y: u32) = unsign (intrinsics.pow32 (sign x, sign y)) - def (x: u32) % (y: u32) = unsign (intrinsics.umod32 (sign x, sign y)) - def (x: u32) // (y: u32) = unsign (intrinsics.udiv32 (sign x, sign y)) - def (x: u32) %% (y: u32) = unsign (intrinsics.umod32 (sign x, sign y)) - - def (x: u32) & (y: u32) = unsign (intrinsics.and32 (sign x, sign y)) - def (x: u32) | (y: u32) = unsign (intrinsics.or32 (sign x, sign y)) - def (x: u32) ^ (y: u32) = unsign (intrinsics.xor32 (sign x, sign y)) + def (+) (x: u32) (y: u32) = unsign (intrinsics.add32 (sign x, sign y)) + def (-) (x: u32) (y: u32) = unsign (intrinsics.sub32 (sign x, sign y)) + def (*) (x: u32) (y: u32) = unsign (intrinsics.mul32 (sign x, sign y)) + def (/) (x: u32) (y: u32) = unsign (intrinsics.udiv32 (sign x, sign y)) + def (**) (x: u32) (y: u32) = unsign (intrinsics.pow32 (sign x, sign y)) + def (%) (x: u32) (y: u32) = unsign (intrinsics.umod32 (sign x, sign y)) + def (//) (x: u32) (y: u32) = unsign (intrinsics.udiv32 (sign x, sign y)) + def (%%) (x: u32) (y: u32) = unsign (intrinsics.umod32 (sign x, sign y)) + + def (&) (x: u32) (y: u32) = unsign (intrinsics.and32 (sign x, sign y)) + def (|) (x: u32) (y: u32) = unsign (intrinsics.or32 (sign x, sign y)) + def (^) (x: u32) (y: u32) = unsign (intrinsics.xor32 (sign x, sign y)) def not (x: u32) = unsign (intrinsics.complement32 (sign x)) - def (x: u32) << (y: u32) = unsign (intrinsics.shl32 (sign x, sign y)) - def (x: u32) >> (y: u32) = unsign (intrinsics.ashr32 (sign x, sign y)) - def (x: u32) >>> (y: u32) = unsign (intrinsics.lshr32 (sign x, sign y)) + def (<<) (x: u32) (y: u32) = unsign (intrinsics.shl32 (sign x, sign y)) + def (>>) (x: u32) (y: u32) = unsign (intrinsics.ashr32 (sign x, sign y)) + def (>>>) (x: u32) (y: u32) = unsign (intrinsics.lshr32 (sign x, sign y)) - def u8 (x: u8) = unsign (i32.u8 x) + def u8 (x: u8) = unsign (i32.u8 x) def u16 (x: u16) = unsign (i32.u16 x) def u32 (x: u32) = unsign (i32.u32 x) def u64 (x: u64) = unsign (i32.u64 x) - def i8 (x: i8) = unsign (intrinsics.zext_i8_i32 x) + def i8 (x: i8) = unsign (intrinsics.zext_i8_i32 x) def i16 (x: i16) = unsign (intrinsics.zext_i16_i32 x) def i32 (x: i32) = unsign (intrinsics.zext_i32_i32 x) def i64 (x: i64) = unsign (intrinsics.zext_i64_i32 x) @@ -774,15 +801,15 @@ module u32: (integral with t = u32) = { def bool x = unsign (intrinsics.btoi_bool_i32 x) - def to_i32(x: u32) = intrinsics.zext_i32_i32 (sign x) - def to_i64(x: u32) = intrinsics.zext_i32_i64 (sign x) + def to_i32 (x: u32) = intrinsics.zext_i32_i32 (sign x) + def to_i64 (x: u32) = intrinsics.zext_i32_i64 (sign x) - def (x: u32) == (y: u32) = intrinsics.eq_i32 (sign x, sign y) - def (x: u32) < (y: u32) = intrinsics.ult32 (sign x, sign y) - def (x: u32) > (y: u32) = intrinsics.ult32 (sign y, sign x) - def (x: u32) <= (y: u32) = intrinsics.ule32 (sign x, sign y) - def (x: u32) >= (y: u32) = intrinsics.ule32 (sign y, sign x) - def (x: u32) != (y: u32) = !(x == y) + def (==) (x: u32) (y: u32) = intrinsics.eq_i32 (sign x, sign y) + def (<) (x: u32) (y: u32) = intrinsics.ult32 (sign x, sign y) + def (>) (x: u32) (y: u32) = intrinsics.ult32 (sign y, sign x) + def (<=) (x: u32) (y: u32) = intrinsics.ule32 (sign x, sign y) + def (>=) (x: u32) (y: u32) = intrinsics.ule32 (sign y, sign x) + def (!=) (x: u32) (y: u32) = !(x == y) def sgn (x: u32) = unsign (intrinsics.usignum32 (sign x)) def abs (x: u32) = x @@ -796,8 +823,10 @@ module u32: (integral with t = u32) = { def num_bits = 32i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) + def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) + def popc x = intrinsics.popc32 (sign x) def mul_hi a b = unsign (intrinsics.umul_hi32 (sign a, sign b)) def mad_hi a b c = unsign (intrinsics.umad_hi32 (sign a, sign b, sign c)) @@ -816,30 +845,30 @@ module u64: (integral with t = u64) = { def sign (x: u64) = intrinsics.sign_i64 x def unsign (x: i64) = intrinsics.unsign_i64 x - def (x: u64) + (y: u64) = unsign (intrinsics.add64 (sign x, sign y)) - def (x: u64) - (y: u64) = unsign (intrinsics.sub64 (sign x, sign y)) - def (x: u64) * (y: u64) = unsign (intrinsics.mul64 (sign x, sign y)) - def (x: u64) / (y: u64) = unsign (intrinsics.udiv64 (sign x, sign y)) - def (x: u64) ** (y: u64) = unsign (intrinsics.pow64 (sign x, sign y)) - def (x: u64) % (y: u64) = unsign (intrinsics.umod64 (sign x, sign y)) - def (x: u64) // (y: u64) = unsign (intrinsics.udiv64 (sign x, sign y)) - def (x: u64) %% (y: u64) = unsign (intrinsics.umod64 (sign x, sign y)) - - def (x: u64) & (y: u64) = unsign (intrinsics.and64 (sign x, sign y)) - def (x: u64) | (y: u64) = unsign (intrinsics.or64 (sign x, sign y)) - def (x: u64) ^ (y: u64) = unsign (intrinsics.xor64 (sign x, sign y)) + def (+) (x: u64) (y: u64) = unsign (intrinsics.add64 (sign x, sign y)) + def (-) (x: u64) (y: u64) = unsign (intrinsics.sub64 (sign x, sign y)) + def (*) (x: u64) (y: u64) = unsign (intrinsics.mul64 (sign x, sign y)) + def (/) (x: u64) (y: u64) = unsign (intrinsics.udiv64 (sign x, sign y)) + def (**) (x: u64) (y: u64) = unsign (intrinsics.pow64 (sign x, sign y)) + def (%) (x: u64) (y: u64) = unsign (intrinsics.umod64 (sign x, sign y)) + def (//) (x: u64) (y: u64) = unsign (intrinsics.udiv64 (sign x, sign y)) + def (%%) (x: u64) (y: u64) = unsign (intrinsics.umod64 (sign x, sign y)) + + def (&) (x: u64) (y: u64) = unsign (intrinsics.and64 (sign x, sign y)) + def (|) (x: u64) (y: u64) = unsign (intrinsics.or64 (sign x, sign y)) + def (^) (x: u64) (y: u64) = unsign (intrinsics.xor64 (sign x, sign y)) def not (x: u64) = unsign (intrinsics.complement64 (sign x)) - def (x: u64) << (y: u64) = unsign (intrinsics.shl64 (sign x, sign y)) - def (x: u64) >> (y: u64) = unsign (intrinsics.ashr64 (sign x, sign y)) - def (x: u64) >>> (y: u64) = unsign (intrinsics.lshr64 (sign x, sign y)) + def (<<) (x: u64) (y: u64) = unsign (intrinsics.shl64 (sign x, sign y)) + def (>>) (x: u64) (y: u64) = unsign (intrinsics.ashr64 (sign x, sign y)) + def (>>>) (x: u64) (y: u64) = unsign (intrinsics.lshr64 (sign x, sign y)) - def u8 (x: u8) = unsign (i64.u8 x) + def u8 (x: u8) = unsign (i64.u8 x) def u16 (x: u16) = unsign (i64.u16 x) def u32 (x: u32) = unsign (i64.u32 x) def u64 (x: u64) = unsign (i64.u64 x) - def i8 (x: i8) = unsign (intrinsics.zext_i8_i64 x) + def i8 (x: i8) = unsign (intrinsics.zext_i8_i64 x) def i16 (x: i16) = unsign (intrinsics.zext_i16_i64 x) def i32 (x: i32) = unsign (intrinsics.zext_i32_i64 x) def i64 (x: i64) = unsign (intrinsics.zext_i64_i64 x) @@ -850,15 +879,15 @@ module u64: (integral with t = u64) = { def bool x = unsign (intrinsics.btoi_bool_i64 x) - def to_i32(x: u64) = intrinsics.zext_i64_i32 (sign x) - def to_i64(x: u64) = intrinsics.zext_i64_i64 (sign x) + def to_i32 (x: u64) = intrinsics.zext_i64_i32 (sign x) + def to_i64 (x: u64) = intrinsics.zext_i64_i64 (sign x) - def (x: u64) == (y: u64) = intrinsics.eq_i64 (sign x, sign y) - def (x: u64) < (y: u64) = intrinsics.ult64 (sign x, sign y) - def (x: u64) > (y: u64) = intrinsics.ult64 (sign y, sign x) - def (x: u64) <= (y: u64) = intrinsics.ule64 (sign x, sign y) - def (x: u64) >= (y: u64) = intrinsics.ule64 (sign y, sign x) - def (x: u64) != (y: u64) = !(x == y) + def (==) (x: u64) (y: u64) = intrinsics.eq_i64 (sign x, sign y) + def (<) (x: u64) (y: u64) = intrinsics.ult64 (sign x, sign y) + def (>) (x: u64) (y: u64) = intrinsics.ult64 (sign y, sign x) + def (<=) (x: u64) (y: u64) = intrinsics.ule64 (sign x, sign y) + def (>=) (x: u64) (y: u64) = intrinsics.ule64 (sign y, sign x) + def (!=) (x: u64) (y: u64) = !(x == y) def sgn (x: u64) = unsign (intrinsics.usignum64 (sign x)) def abs (x: u64) = x @@ -872,8 +901,10 @@ module u64: (integral with t = u64) = { def num_bits = 64i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) + def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) + def popc x = intrinsics.popc64 (sign x) def mul_hi a b = unsign (intrinsics.umul_hi64 (sign a, sign b)) def mad_hi a b c = unsign (intrinsics.umad_hi64 (sign a, sign b, sign c)) @@ -893,14 +924,14 @@ module f64: (float with t = f64 with int_t = u64) = { module i64m = i64 module u64m = u64 - def (x: f64) + (y: f64) = intrinsics.fadd64 (x, y) - def (x: f64) - (y: f64) = intrinsics.fsub64 (x, y) - def (x: f64) * (y: f64) = intrinsics.fmul64 (x, y) - def (x: f64) / (y: f64) = intrinsics.fdiv64 (x, y) - def (x: f64) % (y: f64) = intrinsics.fmod64 (x, y) - def (x: f64) ** (y: f64) = intrinsics.fpow64 (x, y) + def (+) (x: f64) (y: f64) = intrinsics.fadd64 (x, y) + def (-) (x: f64) (y: f64) = intrinsics.fsub64 (x, y) + def (*) (x: f64) (y: f64) = intrinsics.fmul64 (x, y) + def (/) (x: f64) (y: f64) = intrinsics.fdiv64 (x, y) + def (%) (x: f64) (y: f64) = intrinsics.fmod64 (x, y) + def (**) (x: f64) (y: f64) = intrinsics.fpow64 (x, y) - def u8 (x: u8) = intrinsics.uitofp_i8_f64 (i8.u8 x) + def u8 (x: u8) = intrinsics.uitofp_i8_f64 (i8.u8 x) def u16 (x: u16) = intrinsics.uitofp_i16_f64 (i16.u16 x) def u32 (x: u32) = intrinsics.uitofp_i32_f64 (i32.u32 x) def u64 (x: u64) = intrinsics.uitofp_i64_f64 (i64.u64 x) @@ -920,15 +951,15 @@ module f64: (float with t = f64 with int_t = u64) = { def to_i64 (x: f64) = intrinsics.fptosi_f64_i64 x def to_f64 (x: f64) = x - def (x: f64) == (y: f64) = intrinsics.eq_f64 (x, y) - def (x: f64) < (y: f64) = intrinsics.lt64 (x, y) - def (x: f64) > (y: f64) = intrinsics.lt64 (y, x) - def (x: f64) <= (y: f64) = intrinsics.le64 (x, y) - def (x: f64) >= (y: f64) = intrinsics.le64 (y, x) - def (x: f64) != (y: f64) = !(x == y) + def (==) (x: f64) (y: f64) = intrinsics.eq_f64 (x, y) + def (<) (x: f64) (y: f64) = intrinsics.lt64 (x, y) + def (>) (x: f64) (y: f64) = intrinsics.lt64 (y, x) + def (<=) (x: f64) (y: f64) = intrinsics.le64 (x, y) + def (>=) (x: f64) (y: f64) = intrinsics.le64 (y, x) + def (!=) (x: f64) (y: f64) = !(x == y) def neg (x: t) = -x - def recip (x: t) = 1/x + def recip (x: t) = 1 / x def max (x: t) (y: t) = intrinsics.fmax64 (x, y) def min (x: t) (y: t) = intrinsics.fmin64 (x, y) @@ -962,19 +993,19 @@ module f64: (float with t = f64 with int_t = u64) = { def erf = intrinsics.erf64 def erfc = intrinsics.erfc64 - def lerp v0 v1 t = intrinsics.lerp64 (v0,v1,t) - def fma a b c = intrinsics.fma64 (a,b,c) - def mad a b c = intrinsics.mad64 (a,b,c) + def lerp v0 v1 t = intrinsics.lerp64 (v0, v1, t) + def fma a b c = intrinsics.fma64 (a, b, c) + def mad a b c = intrinsics.mad64 (a, b, c) def ceil = intrinsics.ceil64 def floor = intrinsics.floor64 - def trunc (x: f64) : f64 = i64 (i64m.f64 x) + def trunc (x: f64): f64 = i64 (i64m.f64 x) def round = intrinsics.round64 - def nextafter x y = intrinsics.nextafter64 (x,y) - def ldexp x y = intrinsics.ldexp64 (x,y) - def copysign x y = intrinsics.copysign64 (x,y) + def nextafter x y = intrinsics.nextafter64 (x, y) + def ldexp x y = intrinsics.ldexp64 (x, y) + def copysign x y = intrinsics.copysign64 (x, y) def to_bits (x: f64): u64 = u64m.i64 (intrinsics.to_bits64 x) def from_bits (x: u64): f64 = intrinsics.from_bits64 (intrinsics.sign_i64 x) @@ -1010,14 +1041,14 @@ module f32: (float with t = f32 with int_t = u32) = { module u32m = u32 module f64m = f64 - def (x: f32) + (y: f32) = intrinsics.fadd32 (x, y) - def (x: f32) - (y: f32) = intrinsics.fsub32 (x, y) - def (x: f32) * (y: f32) = intrinsics.fmul32 (x, y) - def (x: f32) / (y: f32) = intrinsics.fdiv32 (x, y) - def (x: f32) % (y: f32) = intrinsics.fmod32 (x, y) - def (x: f32) ** (y: f32) = intrinsics.fpow32 (x, y) + def (+) (x: f32) (y: f32) = intrinsics.fadd32 (x, y) + def (-) (x: f32) (y: f32) = intrinsics.fsub32 (x, y) + def (*) (x: f32) (y: f32) = intrinsics.fmul32 (x, y) + def (/) (x: f32) (y: f32) = intrinsics.fdiv32 (x, y) + def (%) (x: f32) (y: f32) = intrinsics.fmod32 (x, y) + def (**) (x: f32) (y: f32) = intrinsics.fpow32 (x, y) - def u8 (x: u8) = intrinsics.uitofp_i8_f32 (i8.u8 x) + def u8 (x: u8) = intrinsics.uitofp_i8_f32 (i8.u8 x) def u16 (x: u16) = intrinsics.uitofp_i16_f32 (i16.u16 x) def u32 (x: u32) = intrinsics.uitofp_i32_f32 (i32.u32 x) def u64 (x: u64) = intrinsics.uitofp_i64_f32 (i64.u64 x) @@ -1037,15 +1068,15 @@ module f32: (float with t = f32 with int_t = u32) = { def to_i64 (x: f32) = intrinsics.fptosi_f32_i64 x def to_f64 (x: f32) = intrinsics.fpconv_f32_f64 x - def (x: f32) == (y: f32) = intrinsics.eq_f32 (x, y) - def (x: f32) < (y: f32) = intrinsics.lt32 (x, y) - def (x: f32) > (y: f32) = intrinsics.lt32 (y, x) - def (x: f32) <= (y: f32) = intrinsics.le32 (x, y) - def (x: f32) >= (y: f32) = intrinsics.le32 (y, x) - def (x: f32) != (y: f32) = !(x == y) + def (==) (x: f32) (y: f32) = intrinsics.eq_f32 (x, y) + def (<) (x: f32) (y: f32) = intrinsics.lt32 (x, y) + def (>) (x: f32) (y: f32) = intrinsics.lt32 (y, x) + def (<=) (x: f32) (y: f32) = intrinsics.le32 (x, y) + def (>=) (x: f32) (y: f32) = intrinsics.le32 (y, x) + def (!=) (x: f32) (y: f32) = !(x == y) def neg (x: t) = -x - def recip (x: t) = 1/x + def recip (x: t) = 1 / x def max (x: t) (y: t) = intrinsics.fmax32 (x, y) def min (x: t) (y: t) = intrinsics.fmin32 (x, y) @@ -1079,19 +1110,19 @@ module f32: (float with t = f32 with int_t = u32) = { def erf = intrinsics.erf32 def erfc = intrinsics.erfc32 - def lerp v0 v1 t = intrinsics.lerp32 (v0,v1,t) - def fma a b c = intrinsics.fma32 (a,b,c) - def mad a b c = intrinsics.mad32 (a,b,c) + def lerp v0 v1 t = intrinsics.lerp32 (v0, v1, t) + def fma a b c = intrinsics.fma32 (a, b, c) + def mad a b c = intrinsics.mad32 (a, b, c) def ceil = intrinsics.ceil32 def floor = intrinsics.floor32 - def trunc (x: f32) : f32 = i32 (i32m.f32 x) + def trunc (x: f32): f32 = i32 (i32m.f32 x) def round = intrinsics.round32 - def nextafter x y = intrinsics.nextafter32 (x,y) - def ldexp x y = intrinsics.ldexp32 (x,y) - def copysign x y = intrinsics.copysign32 (x,y) + def nextafter x y = intrinsics.nextafter32 (x, y) + def ldexp x y = intrinsics.ldexp32 (x, y) + def copysign x y = intrinsics.copysign32 (x, y) def to_bits (x: f32): u32 = u32m.i32 (intrinsics.to_bits32 x) def from_bits (x: u32): f32 = intrinsics.from_bits32 (intrinsics.sign_i32 x) @@ -1131,14 +1162,14 @@ module f16: (float with t = f16 with int_t = u16) = { module u16m = u16 module f64m = f64 - def (x: f16) + (y: f16) = intrinsics.fadd16 (x, y) - def (x: f16) - (y: f16) = intrinsics.fsub16 (x, y) - def (x: f16) * (y: f16) = intrinsics.fmul16 (x, y) - def (x: f16) / (y: f16) = intrinsics.fdiv16 (x, y) - def (x: f16) % (y: f16) = intrinsics.fmod16 (x, y) - def (x: f16) ** (y: f16) = intrinsics.fpow16 (x, y) + def (+) (x: f16) (y: f16) = intrinsics.fadd16 (x, y) + def (-) (x: f16) (y: f16) = intrinsics.fsub16 (x, y) + def (*) (x: f16) (y: f16) = intrinsics.fmul16 (x, y) + def (/) (x: f16) (y: f16) = intrinsics.fdiv16 (x, y) + def (%) (x: f16) (y: f16) = intrinsics.fmod16 (x, y) + def (**) (x: f16) (y: f16) = intrinsics.fpow16 (x, y) - def u8 (x: u8) = intrinsics.uitofp_i8_f16 (i8.u8 x) + def u8 (x: u8) = intrinsics.uitofp_i8_f16 (i8.u8 x) def u16 (x: u16) = intrinsics.uitofp_i16_f16 (i16.u16 x) def u32 (x: u32) = intrinsics.uitofp_i32_f16 (i32.u32 x) def u64 (x: u64) = intrinsics.uitofp_i64_f16 (i64.u64 x) @@ -1158,15 +1189,15 @@ module f16: (float with t = f16 with int_t = u16) = { def to_i64 (x: f16) = intrinsics.fptosi_f16_i64 x def to_f64 (x: f16) = intrinsics.fpconv_f16_f64 x - def (x: f16) == (y: f16) = intrinsics.eq_f16 (x, y) - def (x: f16) < (y: f16) = intrinsics.lt16 (x, y) - def (x: f16) > (y: f16) = intrinsics.lt16 (y, x) - def (x: f16) <= (y: f16) = intrinsics.le16 (x, y) - def (x: f16) >= (y: f16) = intrinsics.le16 (y, x) - def (x: f16) != (y: f16) = !(x == y) + def (==) (x: f16) (y: f16) = intrinsics.eq_f16 (x, y) + def (<) (x: f16) (y: f16) = intrinsics.lt16 (x, y) + def (>) (x: f16) (y: f16) = intrinsics.lt16 (y, x) + def (<=) (x: f16) (y: f16) = intrinsics.le16 (x, y) + def (>=) (x: f16) (y: f16) = intrinsics.le16 (y, x) + def (!=) (x: f16) (y: f16) = !(x == y) def neg (x: t) = -x - def recip (x: t) = 1/x + def recip (x: t) = 1 / x def max (x: t) (y: t) = intrinsics.fmax16 (x, y) def min (x: t) (y: t) = intrinsics.fmin16 (x, y) @@ -1200,19 +1231,19 @@ module f16: (float with t = f16 with int_t = u16) = { def erf = intrinsics.erf16 def erfc = intrinsics.erfc16 - def lerp v0 v1 t = intrinsics.lerp16 (v0,v1,t) - def fma a b c = intrinsics.fma16 (a,b,c) - def mad a b c = intrinsics.mad16 (a,b,c) + def lerp v0 v1 t = intrinsics.lerp16 (v0, v1, t) + def fma a b c = intrinsics.fma16 (a, b, c) + def mad a b c = intrinsics.mad16 (a, b, c) def ceil = intrinsics.ceil16 def floor = intrinsics.floor16 - def trunc (x: f16) : f16 = i16 (i16m.f16 x) + def trunc (x: f16): f16 = i16 (i16m.f16 x) def round = intrinsics.round16 - def nextafter x y = intrinsics.nextafter16 (x,y) - def ldexp x y = intrinsics.ldexp16 (x,y) - def copysign x y = intrinsics.copysign16 (x,y) + def nextafter x y = intrinsics.nextafter16 (x, y) + def ldexp x y = intrinsics.ldexp16 (x, y) + def copysign x y = intrinsics.copysign16 (x, y) def to_bits (x: f16): u16 = u16m.i16 (intrinsics.to_bits16 x) def from_bits (x: u16): f16 = intrinsics.from_bits16 (intrinsics.sign_i16 x) diff --git a/prelude/soacs.fut b/prelude/soacs.fut index 310fad5421..02576f09a9 100644 --- a/prelude/soacs.fut +++ b/prelude/soacs.fut @@ -47,7 +47,7 @@ import "zip" -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(S(f))* -def map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = +def map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = intrinsics.map f as -- | Apply the given function to each element of a single array. @@ -55,7 +55,7 @@ def map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(S(f))* -def map1 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = +def map1 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = map f as -- | As `map1`@term, but with one more array. @@ -63,7 +63,7 @@ def map1 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(S(f))* -def map2 'a 'b [n] 'x (f: a -> b -> x) (as: [n]a) (bs: [n]b): *[n]x = +def map2 'a 'b [n] 'x (f: a -> b -> x) (as: [n]a) (bs: [n]b) : *[n]x = map (\(a, b) -> f a b) (zip2 as bs) -- | As `map2`@term, but with one more array. @@ -71,7 +71,7 @@ def map2 'a 'b [n] 'x (f: a -> b -> x) (as: [n]a) (bs: [n]b): *[n]x = -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(S(f))* -def map3 'a 'b 'c [n] 'x (f: a -> b -> c -> x) (as: [n]a) (bs: [n]b) (cs: [n]c): *[n]x = +def map3 'a 'b 'c [n] 'x (f: a -> b -> c -> x) (as: [n]a) (bs: [n]b) (cs: [n]c) : *[n]x = map (\(a, b, c) -> f a b c) (zip3 as bs cs) -- | As `map3`@term, but with one more array. @@ -79,7 +79,7 @@ def map3 'a 'b 'c [n] 'x (f: a -> b -> c -> x) (as: [n]a) (bs: [n]b) (cs: [n]c): -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(S(f))* -def map4 'a 'b 'c 'd [n] 'x (f: a -> b -> c -> d -> x) (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d): *[n]x = +def map4 'a 'b 'c 'd [n] 'x (f: a -> b -> c -> d -> x) (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) : *[n]x = map (\(a, b, c, d) -> f a b c d) (zip4 as bs cs ds) -- | As `map3`@term, but with one more array. @@ -87,7 +87,7 @@ def map4 'a 'b 'c 'd [n] 'x (f: a -> b -> c -> d -> x) (as: [n]a) (bs: [n]b) (cs -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(S(f))* -def map5 'a 'b 'c 'd 'e [n] 'x (f: a -> b -> c -> d -> e -> x) (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e): *[n]x = +def map5 'a 'b 'c 'd 'e [n] 'x (f: a -> b -> c -> d -> e -> x) (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e) : *[n]x = map (\(a, b, c, d, e) -> f a b c d e) (zip5 as bs cs ds es) -- | Reduce the array `as` with `op`, with `ne` as the neutral @@ -103,7 +103,7 @@ def map5 'a 'b 'c 'd 'e [n] 'x (f: a -> b -> c -> d -> e -> x) (as: [n]a) (bs: [ -- -- Note that the complexity implies that parallelism in the combining -- operator will *not* be exploited. -def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = +def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a) : a = intrinsics.reduce op ne as -- | As `reduce`, but the operator must also be commutative. This is @@ -114,7 +114,7 @@ def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = -- **Work:** *O(n ✕ W(op))* -- -- **Span:** *O(log(n) ✕ W(op))* -def reduce_comm [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = +def reduce_comm [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a) : a = intrinsics.reduce_comm op ne as -- | `h = hist op ne k is as` computes a generalised `k`-bin histogram @@ -142,15 +142,15 @@ def hist 'a [n] (op: a -> a -> a) (ne: a) (k: i64) (is: [n]i64) (as: [n]a) : *[k -- position), but *O(W(op))* in the best case. -- -- In practice, linear span only occurs if *k* is also very large. -def reduce_by_index 'a [k] [n] (dest : *[k]a) (f : a -> a -> a) (ne : a) (is : [n]i64) (as : [n]a) : *[k]a = +def reduce_by_index 'a [k] [n] (dest: *[k]a) (f: a -> a -> a) (ne: a) (is: [n]i64) (as: [n]a) : *[k]a = intrinsics.hist_1d 1 dest f ne is as -- | As `reduce_by_index`, but with two-dimensional indexes. -def reduce_by_index_2d 'a [k] [n] [m] (dest : *[k][m]a) (f : a -> a -> a) (ne : a) (is : [n](i64,i64)) (as : [n]a) : *[k][m]a = +def reduce_by_index_2d 'a [k] [n] [m] (dest: *[k][m]a) (f: a -> a -> a) (ne: a) (is: [n](i64, i64)) (as: [n]a) : *[k][m]a = intrinsics.hist_2d 1 dest f ne is as -- | As `reduce_by_index`, but with three-dimensional indexes. -def reduce_by_index_3d 'a [k] [n] [m] [l] (dest : *[k][m][l]a) (f : a -> a -> a) (ne : a) (is : [n](i64,i64,i64)) (as : [n]a) : *[k][m][l]a = +def reduce_by_index_3d 'a [k] [n] [m] [l] (dest: *[k][m][l]a) (f: a -> a -> a) (ne: a) (is: [n](i64, i64, i64)) (as: [n]a) : *[k][m][l]a = intrinsics.hist_3d 1 dest f ne is as -- | Inclusive prefix scan. Has the same caveats with respect to @@ -159,7 +159,7 @@ def reduce_by_index_3d 'a [k] [n] [m] [l] (dest : *[k][m][l]a) (f : a -> a -> a) -- **Work:** *O(n ✕ W(op))* -- -- **Span:** *O(log(n) ✕ W(op))* -def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): *[n]a = +def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a) : *[n]a = intrinsics.scan op ne as -- | Split an array into those elements that satisfy the given @@ -168,7 +168,7 @@ def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): *[n]a = -- **Work:** *O(n ✕ W(p))* -- -- **Span:** *O(log(n) ✕ W(p))* -def partition [n] 'a (p: a -> bool) (as: [n]a): ?[k].([k]a, [n-k]a) = +def partition [n] 'a (p: a -> bool) (as: [n]a) : ?[k].([k]a, [n - k]a) = let p' x = if p x then 0 else 1 let (as', is) = intrinsics.partition 2 p' as in (as'[0:is[0]], as'[is[0]:n]) @@ -178,12 +178,13 @@ def partition [n] 'a (p: a -> bool) (as: [n]a): ?[k].([k]a, [n-k]a) = -- **Work:** *O(n ✕ (W(p1) + W(p2)))* -- -- **Span:** *O(log(n) ✕ (W(p1) + W(p2)))* -def partition2 [n] 'a (p1: a -> bool) (p2: a -> bool) (as: [n]a): ?[k][l].([k]a, [l]a, [n-k-l]a) = +def partition2 [n] 'a (p1: a -> bool) (p2: a -> bool) (as: [n]a) : ?[k][l].([k]a, [l]a, [n - k - l]a) = let p' x = if p1 x then 0 else if p2 x then 1 else 2 let (as', is) = intrinsics.partition 3 p' as - in (as'[0:is[0]], - as'[is[0]:is[0]+is[1]] :> [is[1]]a, - as'[is[0]+is[1]:n] :> [n-is[0]-is[1]]a) + in ( as'[0:is[0]] + , as'[is[0]:is[0] + is[1]] :> [is[1]]a + , as'[is[0] + is[1]:n] :> [n - is[0] - is[1]]a + ) -- | Return `true` if the given function returns `true` for all -- elements in the array. @@ -191,7 +192,7 @@ def partition2 [n] 'a (p1: a -> bool) (p2: a -> bool) (as: [n]a): ?[k][l].([k]a, -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(log(n) + S(f))* -def all [n] 'a (f: a -> bool) (as: [n]a): bool = +def all [n] 'a (f: a -> bool) (as: [n]a) : bool = reduce (&&) true (map f as) -- | Return `true` if the given function returns `true` for any @@ -200,7 +201,7 @@ def all [n] 'a (f: a -> bool) (as: [n]a): bool = -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(log(n) + S(f))* -def any [n] 'a (f: a -> bool) (as: [n]a): bool = +def any [n] 'a (f: a -> bool) (as: [n]a) : bool = reduce (||) false (map f as) -- | `r = spread k x is vs` produces an array `r` such that `r[i] = @@ -214,7 +215,7 @@ def any [n] 'a (f: a -> bool) (as: [n]a): bool = -- **Work:** *O(k + n)* -- -- **Span:** *O(1)* -def spread 't [n] (k: i64) (x: t) (is: [n]i64) (vs: [n]t): *[k]t = +def spread 't [n] (k: i64) (x: t) (is: [n]i64) (vs: [n]t) : *[k]t = intrinsics.scatter (map (\_ -> x) (0..1.. bool) (as: [n]a): *[]a = +def filter [n] 'a (p: a -> bool) (as: [n]a) : *[]a = let flags = map (\x -> if p x then 1 else 0) as let offsets = scan (+) 0 flags - let m = if n == 0 then 0 else offsets[n-1] + let m = if n == 0 then 0 else offsets[n - 1] in scatter (map (\x -> x) as[:m]) - (map2 (\f o -> if f==1 then o-1 else -1) flags offsets) + (map2 (\f o -> if f == 1 then o - 1 else -1) flags offsets) as diff --git a/prelude/zip.fut b/prelude/zip.fut index 1171820307..fdd0abbfe5 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -6,55 +6,54 @@ -- The main reason this module exists is that we need it to define -- SOACs like `map2`. - -- We need a map to define some of the zip variants, but this file is -- depended upon by soacs.fut. So we just define a quick-and-dirty -- internal one here that uses the intrinsic version. -local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = +local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = intrinsics.map f as -- | Construct an array of pairs from two arrays. -def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = +def zip [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = intrinsics.zip as bs -- | Construct an array of pairs from two arrays. -def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = +def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = zip as bs -- | As `zip2`@term, but with one more array. -def zip3 [n] 'a 'b 'c (as: [n]a) (bs: [n]b) (cs: [n]c): *[n](a,b,c) = - internal_map (\(a,(b,c)) -> (a,b,c)) (zip as (zip2 bs cs)) +def zip3 [n] 'a 'b 'c (as: [n]a) (bs: [n]b) (cs: [n]c) : *[n](a, b, c) = + internal_map (\(a, (b, c)) -> (a, b, c)) (zip as (zip2 bs cs)) -- | As `zip3`@term, but with one more array. -def zip4 [n] 'a 'b 'c 'd (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d): *[n](a,b,c,d) = - internal_map (\(a,(b,c,d)) -> (a,b,c,d)) (zip as (zip3 bs cs ds)) +def zip4 [n] 'a 'b 'c 'd (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) : *[n](a, b, c, d) = + internal_map (\(a, (b, c, d)) -> (a, b, c, d)) (zip as (zip3 bs cs ds)) -- | As `zip4`@term, but with one more array. -def zip5 [n] 'a 'b 'c 'd 'e (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e): *[n](a,b,c,d,e) = - internal_map (\(a,(b,c,d,e)) -> (a,b,c,d,e)) (zip as (zip4 bs cs ds es)) +def zip5 [n] 'a 'b 'c 'd 'e (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e) : *[n](a, b, c, d, e) = + internal_map (\(a, (b, c, d, e)) -> (a, b, c, d, e)) (zip as (zip4 bs cs ds es)) -- | Turn an array of pairs into two arrays. -def unzip [n] 'a 'b (xs: [n](a,b)): ([n]a, [n]b) = +def unzip [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) = intrinsics.unzip xs -- | Turn an array of pairs into two arrays. -def unzip2 [n] 'a 'b (xs: [n](a,b)): ([n]a, [n]b) = +def unzip2 [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) = unzip xs -- | As `unzip2`@term, but with one more array. -def unzip3 [n] 'a 'b 'c (xs: [n](a,b,c)): ([n]a, [n]b, [n]c) = - let (as, bcs) = unzip (internal_map (\(a,b,c) -> (a,(b,c))) xs) +def unzip3 [n] 'a 'b 'c (xs: [n](a, b, c)) : ([n]a, [n]b, [n]c) = + let (as, bcs) = unzip (internal_map (\(a, b, c) -> (a, (b, c))) xs) let (bs, cs) = unzip bcs in (as, bs, cs) -- | As `unzip3`@term, but with one more array. -def unzip4 [n] 'a 'b 'c 'd (xs: [n](a,b,c,d)): ([n]a, [n]b, [n]c, [n]d) = - let (as, bs, cds) = unzip3 (internal_map (\(a,b,c,d) -> (a,b,(c,d))) xs) +def unzip4 [n] 'a 'b 'c 'd (xs: [n](a, b, c, d)) : ([n]a, [n]b, [n]c, [n]d) = + let (as, bs, cds) = unzip3 (internal_map (\(a, b, c, d) -> (a, b, (c, d))) xs) let (cs, ds) = unzip cds in (as, bs, cs, ds) -- | As `unzip4`@term, but with one more array. -def unzip5 [n] 'a 'b 'c 'd 'e (xs: [n](a,b,c,d,e)): ([n]a, [n]b, [n]c, [n]d, [n]e) = - let (as, bs, cs, des) = unzip4 (internal_map (\(a,b,c,d,e) -> (a,b,c,(d,e))) xs) +def unzip5 [n] 'a 'b 'c 'd 'e (xs: [n](a, b, c, d, e)) : ([n]a, [n]b, [n]c, [n]d, [n]e) = + let (as, bs, cs, des) = unzip4 (internal_map (\(a, b, c, d, e) -> (a, b, c, (d, e))) xs) let (ds, es) = unzip des in (as, bs, cs, ds, es)