diff --git a/lib/github.com/diku-dk/autodiff/onehot.fut b/lib/github.com/diku-dk/autodiff/onehot.fut index 9539568..04d7b22 100644 --- a/lib/github.com/diku-dk/autodiff/onehot.fut +++ b/lib/github.com/diku-dk/autodiff/onehot.fut @@ -45,6 +45,11 @@ module type onehot = { val f32 : gen [1] f32 val f64 : gen [1] f64 + -- | A generator for a fixed value that (as far as generation is + -- concerned) contains no other values. This is useful for input or + -- output that is known not to contribute to the derivative. + val fixed 'a : a -> gen [0] a + -- | Produce a generator for pairs based on generators for the -- components. val pair [n][m] 'a 'b : gen [n] a -> gen [m] b -> gen [n+m] (a,b) @@ -86,6 +91,8 @@ module onehot : onehot = { def f32 = point 1f32 0f32 def f64 = point 1f64 0f64 + def fixed a = { size = witness 0, gen = const a } + def pair [n][m] 'a 'b (x: gen[n]a) (y: gen[m]b) = { size = witness (n+m), gen = \i -> (onehot x i,