From 0433592987f8565594a31a8231738ce8d645624d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 17 Oct 2024 23:20:53 +0200 Subject: [PATCH] Add grad function. --- README.md | 7 ++++++- lib/github.com/diku-dk/autodiff/onehot.fut | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index fd754ef..3ebbb34 100644 --- a/README.md +++ b/README.md @@ -20,5 +20,10 @@ $ futhark pkg sync [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0)], [(0.0, 0.0), (0.0, 0.0), (1.0, 0.0)], [(0.0, 0.0), (0.0, 0.0), (0.0, 1.0)]] - +> grad (onehot.(f64)) f64.sin 3.14 +[-0.9999987317275395] +> grad (onehot.(arr f64)) (map f64.sin) [0, 1, 2] +[[1.0, 0.0, -0.0], + [0.0, 0.5403023058681398, -0.0], + [0.0, 0.0, -0.4161468365471424]] ``` diff --git a/lib/github.com/diku-dk/autodiff/onehot.fut b/lib/github.com/diku-dk/autodiff/onehot.fut index 224c54f..9539568 100644 --- a/lib/github.com/diku-dk/autodiff/onehot.fut +++ b/lib/github.com/diku-dk/autodiff/onehot.fut @@ -107,5 +107,9 @@ module onehot : onehot = { } } +-- | Generate all one-hot values possible for a given generator. def onehots [n] 'a (gen: onehot.gen [n] a) : [n]a = tabulate n (onehot.onehot gen) + +-- | Compute the gradient of a function given a one-hot generator for its result. +def grad gen f x = map (vjp f x) (onehots gen)