Skip to content

Commit

Permalink
Add grad function.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Oct 17, 2024
1 parent 6f3c286 commit 0433592
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,10 @@ $ futhark pkg sync
[(0.0, 0.0), (0.0, 1.0), (0.0, 0.0)],
[(0.0, 0.0), (0.0, 0.0), (1.0, 0.0)],
[(0.0, 0.0), (0.0, 0.0), (0.0, 1.0)]]
> grad (onehot.(f64)) f64.sin 3.14
[-0.9999987317275395]
> grad (onehot.(arr f64)) (map f64.sin) [0, 1, 2]
[[1.0, 0.0, -0.0],
[0.0, 0.5403023058681398, -0.0],
[0.0, 0.0, -0.4161468365471424]]
```
4 changes: 4 additions & 0 deletions lib/github.com/diku-dk/autodiff/onehot.fut
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,9 @@ module onehot : onehot = {
}
}

-- | Generate all one-hot values possible for a given generator.
def onehots [n] 'a (gen: onehot.gen [n] a) : [n]a =
tabulate n (onehot.onehot gen)

-- | Compute the gradient of a function given a one-hot generator for its result.
def grad gen f x = map (vjp f x) (onehots gen)

0 comments on commit 0433592

Please sign in to comment.