Commit ed450f1
Add autodiff module and some examples.
athas committed Oct 19, 2024
1 parent 1ee9301 commit ed450f1
Showing 4 changed files with 68 additions and 5 deletions.
9 changes: 7 additions & 2 deletions README.md
@@ -3,6 +3,8 @@
This is a collection of small utilities for working with AD in
Futhark.

[See examples here.](examples/)

## Installation

```
@@ -13,15 +15,18 @@
$ futhark pkg sync
```

```Futhark
> import "lib/github.com/diku-dk/autodiff/onehot"
> import "lib/github.com/diku-dk/autodiff/autodiff"
> grad64 f64.tanh 2.0
7.065082485316443e-2
> grad64 (grad64 f64.tanh) 2.0
-0.13621868742711296
> onehots onehot.(arr (pair f64 f32)) : [][3](f64,f32)
[[(1.0, 0.0), (0.0, 0.0), (0.0, 0.0)],
[(0.0, 1.0), (0.0, 0.0), (0.0, 0.0)],
[(0.0, 0.0), (1.0, 0.0), (0.0, 0.0)],
[(0.0, 0.0), (0.0, 1.0), (0.0, 0.0)],
[(0.0, 0.0), (0.0, 0.0), (1.0, 0.0)],
[(0.0, 0.0), (0.0, 0.0), (0.0, 1.0)]]
> grad onehot.(f64) f64.sin 3.14
[-0.9999987317275395]
> grad onehot.(arr f64) (map f64.sin) [0, 1, 2]
[[1.0, 0.0, -0.0],
[0.0, 0.5403023058681398, -0.0],
40 changes: 40 additions & 0 deletions examples/jax_grad.fut
@@ -0,0 +1,40 @@
-- The following example is ported from the Jax cookbook at
-- https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

import "../lib/github.com/diku-dk/autodiff/autodiff"

def dot a b = f64.sum (map2 (*) a b)

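-- Logistic sigmoid expressed via tanh: 0.5*(tanh(x/2)+1) = 1/(1+exp(-x)).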
def sigmoid x =
0.5 * (f64.tanh(x/2)+1)

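-- Predicted probability that a single input row has a true label.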
def predict W b inputs =
sigmoid(dot inputs W + b)

def inputs : [][]f64 =
[[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]]

def targets =
[true,true,false,true]

def vecadd = map2 (f64.+)
def vecmul = map2 (f64.*)

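-- Negative log-likelihood (binary cross-entropy) of the targets under
-- the predictions.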
def loss W b =
let preds = map (predict W b) inputs
let label_probs = (preds `vecmul` map f64.bool targets)
`vecadd` (map (1-) preds `vecmul`
map ((1-) <-< f64.bool) targets)
in -f64.sum(map f64.log label_probs)

-- Not going to import random number generation just for this. These
-- are made with 'futhark dataset'.
def W = [0.12152684143560777f64, 0.5526745035133085f64, 0.5189896463245001f64]
def b = 0.12152684143560777f64

def Wb_grad = grad64 (\(W,b) -> loss W b) (W,b)
def W_grad = Wb_grad.0
def b_grad = Wb_grad.1
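To run the example end to end, one could append an entry point along these lines (a hypothetical addition, not part of the committed file; the name `gradients` is made up):

```Futhark
-- Hypothetical entry point exposing the computed gradients, so the file
-- can be compiled (e.g. with 'futhark c') or loaded in 'futhark repl'.
entry gradients = (W_grad, b_grad)
```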
21 changes: 21 additions & 0 deletions lib/github.com/diku-dk/autodiff/autodiff.fut
@@ -0,0 +1,21 @@
-- | Various utilities for performing AD.

import "onehot"

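-- The single one-hot value of a generator that has exactly one,
-- used as the 'vjp' seed in 'grad_unit'.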
local def singular 'a (x: onehot.gen [1] a) = onehot.onehot x 0

-- | Compute the gradient of a scalar-valued function given a one-hot
-- generator for its result.
def grad_unit gen f x = vjp f x (singular gen)

-- | Convenience function for computing the gradient of an
-- 'f32'-valued differentiable function.
def grad32 = grad_unit onehot.f32

-- | Convenience function for computing the gradient of an
-- 'f64'-valued differentiable function.
def grad64 = grad_unit onehot.f64

-- | Compute the gradient of an arbitrary differentiable function
-- given a one-hot generator for its result.
def grad_rev gen f x = map (vjp f x) (onehots gen)
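A hedged usage sketch of `grad64` (the function `sumsq`, its input, and the gradient stated in the comment are illustrative, derived analytically rather than taken from this commit):

```Futhark
import "lib/github.com/diku-dk/autodiff/autodiff"

-- A scalar-valued function of a vector: the sum of squares.
def sumsq (xs: []f64) : f64 = f64.sum (map (\x -> x * x) xs)

-- grad64 seeds vjp with the single one-hot value for the f64 result,
-- so this is the full gradient; analytically it is 2*x = [2.0, 4.0, 6.0].
def sumsq_grad = grad64 sumsq [1.0, 2.0, 3.0]
```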
3 changes: 0 additions & 3 deletions lib/github.com/diku-dk/autodiff/onehot.fut
@@ -120,6 +120,3 @@ module onehot : onehot = {
-- | Generate all one-hot values possible for a given generator.
def onehots [n] 'a (gen: onehot.gen [n] a) : [n]a =
tabulate n (onehot.onehot gen)

-- | Compute the gradient of a function given a one-hot generator for its result.
def grad gen f x = map (vjp f x) (onehots gen)
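And a similar sketch of `onehots` driving `grad_rev` for a vector-valued result (illustrative input; the Jacobian noted in the comment follows from d sin/dx = cos x and matches the README example above):

```Futhark
import "lib/github.com/diku-dk/autodiff/onehot"
import "lib/github.com/diku-dk/autodiff/autodiff"

-- One vjp pass per one-hot seed of the result, yielding the Jacobian of
-- 'map f64.sin' row by row: diag(cos 0, cos 1, cos 2).
def sin_jacobian = grad_rev onehot.(arr f64) (map f64.sin) [0.0, 1.0, 2.0]
```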
