diff --git a/README.md b/README.md
index 22d5d39..2824348 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,8 @@
 
 This is a collection of small utilities for working with AD in Futhark.
 
+[See examples here.](examples/)
+
 ## Installation
 
 ```
@@ -13,6 +15,11 @@ $ futhark pkg sync
 
 ```Futhark
 > import "lib/github.com/diku-dk/autodiff/onehot"
+> import "lib/github.com/diku-dk/autodiff/autodiff"
+> grad64 f64.tanh 2.0
+7.065082485316443e-2
+> grad64 (grad64 f64.tanh) 2.0
+-0.13621868742711296
 > onehots onehot.(arr (pair f64 f32)) : [][3](f64,f32)
 [[(1.0, 0.0), (0.0, 0.0), (0.0, 0.0)],
  [(0.0, 1.0), (0.0, 0.0), (0.0, 0.0)],
@@ -20,8 +27,6 @@ $ futhark pkg sync
  [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0)],
  [(0.0, 0.0), (0.0, 0.0), (1.0, 0.0)],
  [(0.0, 0.0), (0.0, 0.0), (0.0, 1.0)]]
-> grad onehot.(f64) f64.sin 3.14
-[-0.9999987317275395]
-> grad onehot.(arr f64) (map f64.sin) [0, 1, 2]
+> grad_rev onehot.(arr f64) (map f64.sin) [0, 1, 2]
 [[1.0, 0.0, -0.0],
  [0.0, 0.5403023058681398, -0.0],
diff --git a/examples/jax_grad.fut b/examples/jax_grad.fut
new file mode 100644
index 0000000..5326430
--- /dev/null
+++ b/examples/jax_grad.fut
@@ -0,0 +1,40 @@
+-- The following example is ported from the Jax cookbook at
+-- https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
+
+import "../lib/github.com/diku-dk/autodiff/autodiff"
+
+def dot a b = f64.sum (map2 (*) a b)
+
+def sigmoid x =
+  0.5 * (f64.tanh(x/2)+1)
+
+def predict W b inputs =
+  sigmoid(dot inputs W + b)
+
+def inputs : [][]f64 =
+  [[0.52, 1.12, 0.77],
+   [0.88, -1.08, 0.15],
+   [0.52, 0.06, -1.30],
+   [0.74, -2.49, 1.39]]
+
+def targets =
+  [true,true,false,true]
+
+def vecadd = map2 (f64.+)
+def vecmul = map2 (f64.*)
+
+def loss W b =
+  let preds = map (predict W b) inputs
+  let label_probs = (preds `vecmul` map f64.bool targets)
+                    `vecadd` (map (1-) preds `vecmul`
+                              map ((1-) <-< f64.bool) targets)
+  in -f64.sum(map f64.log label_probs)
+
+-- Not going to import random number generation just for this. These
+-- are made with 'futhark dataset'.
+def W = [0.12152684143560777f64, 0.5526745035133085f64, 0.5189896463245001f64]
+def b = 0.12152684143560777f64
+
+def Wb_grad = grad64 (\(W,b) -> loss W b) (W,b)
+def W_grad = Wb_grad.0
+def b_grad = Wb_grad.1
diff --git a/lib/github.com/diku-dk/autodiff/autodiff.fut b/lib/github.com/diku-dk/autodiff/autodiff.fut
new file mode 100644
index 0000000..e39ad25
--- /dev/null
+++ b/lib/github.com/diku-dk/autodiff/autodiff.fut
@@ -0,0 +1,21 @@
+-- | Various utilities for performing AD.
+
+import "onehot"
+
+local def singular 'a (x: onehot.gen [1] a) = onehot.onehot x 0
+
+-- | Compute the gradient of a scalar-valued function given a one-hot
+-- generator for its result.
+def grad_unit gen f x = vjp f x (singular gen)
+
+-- | Convenience function for computing the gradient of an
+-- 'f32'-valued differentiable function.
+def grad32 = grad_unit onehot.f32
+
+-- | Convenience function for computing the gradient of an
+-- 'f64'-valued differentiable function.
+def grad64 = grad_unit onehot.f64
+
+-- | Compute the gradient of an arbitrary differentiable function
+-- given a one-hot generator for its result.
+def grad_rev gen f x = map (vjp f x) (onehots gen)
diff --git a/lib/github.com/diku-dk/autodiff/onehot.fut b/lib/github.com/diku-dk/autodiff/onehot.fut
index 22998b2..371d2f7 100644
--- a/lib/github.com/diku-dk/autodiff/onehot.fut
+++ b/lib/github.com/diku-dk/autodiff/onehot.fut
@@ -120,6 +120,3 @@ module onehot : onehot = {
 -- | Generate all one-hot values possible for a given generator.
 def onehots [n] 'a (gen: onehot.gen [n] a) : [n]a =
   tabulate n (onehot.onehot gen)
-
--- | Compute the gradient of a function given a one-hot generator for its result.
-def grad gen f x = map (vjp f x) (onehots gen)
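
Below is a minimal sketch (not part of the patch itself) of how the new `grad64` could be exercised as a `futhark test` program. The repository-root placement implied by the import path and the `d_square` entry point are assumptions for illustration. Since `grad64 f x` amounts to `vjp f x 1.0`, the derivative of `\y -> y * y` at `3.0` comes out as `6.0`.

```Futhark
-- Illustrative sketch only; assumes this file sits in the repository
-- root so that the package-relative import below resolves.
-- ==
-- entry: d_square
-- input { 3.0 }
-- output { 6.0 }

import "lib/github.com/diku-dk/autodiff/autodiff"

-- grad64 f x is vjp f x 1.0, i.e. the ordinary derivative of a
-- scalar-valued function, so the derivative of x*x at 3.0 is 6.0.
entry d_square (x: f64) = grad64 (\y -> y * y) x
```

The same pattern applies to `grad32`, with `f32` values in place of `f64`.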