From 25cfeae0a796fc90954f7ba16114b8beabc84ff1 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 21:30:48 +0100 Subject: [PATCH 01/16] allow forward pass --- numpyro/infer/hmc_util.py | 6 +++--- numpyro/optim.py | 22 +++++++++++++++++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index f3b2d81ee..a921a0b58 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -239,11 +239,11 @@ def final_fn(state, regularize=False): return init_fn, update_fn, final_fn -def _value_and_grad(f, x, forward_mode_differentiation=False): +def _value_and_grad(f, x, argnums=0, has_aux=False, holomorphic=False, forward_mode_differentiation=False): if forward_mode_differentiation: - return f(x), jacfwd(f)(x) + return f(x), jacfwd(f, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic)(x) else: - return value_and_grad(f)(x) + return value_and_grad(f, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic)(x) def _kinetic_grad(kinetic_fn, inverse_mass_matrix, r): diff --git a/numpyro/optim.py b/numpyro/optim.py index 8a3b78149..a6069880f 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -11,13 +11,15 @@ from collections.abc import Callable from typing import Any, TypeVar -from jax import lax, value_and_grad +from jax import lax from jax.example_libraries import optimizers from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.scipy.optimize import minimize from jax.tree_util import register_pytree_node, tree_map +from numpyro.infer.hmc_util import _value_and_grad + __all__ = [ "Adam", "Adagrad", @@ -61,7 +63,9 @@ def update(self, g: _Params, state: _IterOptState) -> _IterOptState: opt_state = self.update_fn(i, g, opt_state) return i + 1, opt_state - def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState): + def eval_and_update( + self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation: bool = False + ): """ Performs an optimization step for the objective function `fn`. For most optimizers, the update is performed based on the gradient @@ -74,13 +78,18 @@ def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState): is a scalar loss function to be differentiated and the second item is an auxiliary output. :param state: current optimizer state. + :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. :return: a pair of the output of objective function and the new optimizer state. """ params = self.get_params(state) - (out, aux), grads = value_and_grad(fn, has_aux=True)(params) + (out, aux), grads = _value_and_grad( + fn, has_aux=True, forward_mode_differentiation=forward_mode_differentiation + )(params) return (out, aux), self.update(grads, state) - def eval_and_stable_update(self, fn: Callable[[Any], tuple], state: _IterOptState): + def eval_and_stable_update( + self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation: bool = False + ): """ Like :meth:`eval_and_update` but when the value of the objective function or the gradients are not finite, we will not update the input `state` @@ -88,10 +97,13 @@ def eval_and_stable_update(self, fn: Callable[[Any], tuple], state: _IterOptStat :param fn: objective function. :param state: current optimizer state. + :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. :return: a pair of the output of objective function and the new optimizer state. """ params = self.get_params(state) - (out, aux), grads = value_and_grad(fn, has_aux=True)(params) + (out, aux), grads = _value_and_grad( + fn, has_aux=True, forward_mode_differentiation=forward_mode_differentiation + )(params) out, state = lax.cond( jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(), lambda _: (out, self.update(grads, state)), From b574f1320a119be01aa094bda6bd39efb99371bd Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 21:39:10 +0100 Subject: [PATCH 02/16] fix params --- numpyro/optim.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/numpyro/optim.py b/numpyro/optim.py index a6069880f..0f94d46df 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -83,8 +83,8 @@ def eval_and_update( """ params = self.get_params(state) (out, aux), grads = _value_and_grad( - fn, has_aux=True, forward_mode_differentiation=forward_mode_differentiation - )(params) + fn, x=params, has_aux=True, forward_mode_differentiation=forward_mode_differentiation + ) return (out, aux), self.update(grads, state) def eval_and_stable_update( @@ -102,8 +102,8 @@ def eval_and_stable_update( """ params = self.get_params(state) (out, aux), grads = _value_and_grad( - fn, has_aux=True, forward_mode_differentiation=forward_mode_differentiation - )(params) + fn, x=params, has_aux=True, forward_mode_differentiation=forward_mode_differentiation + ) out, state = lax.cond( jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(), lambda _: (out, self.update(grads, state)), From 449df7cabd619c5db4151506f8d527bea03651bc Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 21:57:36 +0100 Subject: [PATCH 03/16] add missing param in docstring --- numpyro/infer/hmc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index f3b31adf0..19b407a66 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -265,6 +265,7 @@ def init_kernel( `d2` is the max tree depth during post warmup phase. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. + ::param bool forward_mode_differentiation: flag indicating whether to use forward mode differentiation. :param bool regularize_mass_matrix: whether or not to regularize the estimated mass matrix for numerical stability during warmup phase. Defaults to True. This flag does not take effect if ``adapt_mass_matrix == False``. From 346d71bbd64fc6ada8e8fee19cfc710a61044ae2 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 21:58:02 +0100 Subject: [PATCH 04/16] add flag to svi --- numpyro/infer/svi.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index 4b99302cc..6636ad118 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -256,12 +256,14 @@ def get_params(self, svi_state): params = self.constrain_fn(self.optim.get_params(svi_state.optim_state)) return params - def update(self, svi_state, *args, **kwargs): + def update(self, svi_state, forward_mode_differentiation=False, *args, **kwargs): """ Take a single step of SVI (possibly on a batch / minibatch of data), using the optimizer. :param svi_state: current state of SVI. + :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. + Defaults to False. :param args: arguments to the model / guide (these can possibly vary during the course of fitting). :param kwargs: keyword arguments to the model / guide (these can possibly vary @@ -281,16 +283,18 @@ def update(self, svi_state, *args, **kwargs): mutable_state=svi_state.mutable_state, ) (loss_val, mutable_state), optim_state = self.optim.eval_and_update( - loss_fn, svi_state.optim_state + loss_fn, svi_state.optim_state, forward_mode_differentiation ) return SVIState(optim_state, mutable_state, rng_key), loss_val - def stable_update(self, svi_state, *args, **kwargs): + def stable_update(self, svi_state, forward_mode_differentiation=False, *args, **kwargs): """ Similar to :meth:`update` but returns the current state if the the loss or the new state contains invalid values. :param svi_state: current state of SVI. + :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. + Defaults to False. :param args: arguments to the model / guide (these can possibly vary during the course of fitting). :param kwargs: keyword arguments to the model / guide (these can possibly vary @@ -310,7 +314,7 @@ def stable_update(self, svi_state, *args, **kwargs): mutable_state=svi_state.mutable_state, ) (loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update( - loss_fn, svi_state.optim_state + loss_fn, svi_state.optim_state, forward_mode_differentiation ) return SVIState(optim_state, mutable_state, rng_key), loss_val @@ -321,6 +325,7 @@ def run( *args, progress_bar=True, stable_update=False, + forward_mode_differentiation=False, init_state=None, init_params=None, **kwargs, @@ -342,6 +347,8 @@ def run( ``True``. :param bool stable_update: whether to use :meth:`stable_update` to update the state. Defaults to False. + :param bool forward_mode_differentiation: flag indicating whether to use forward mode differentiation. + Defaults to False. :param SVIState init_state: if not None, begin SVI from the final state of previous SVI run. Usage:: @@ -365,9 +372,9 @@ def run( def body_fn(svi_state, _): if stable_update: - svi_state, loss = self.stable_update(svi_state, *args, **kwargs) + svi_state, loss = self.stable_update(svi_state, forward_mode_differentiation, *args, **kwargs) else: - svi_state, loss = self.update(svi_state, *args, **kwargs) + svi_state, loss = self.update(svi_state, forward_mode_differentiation, *args, **kwargs) return svi_state, loss if init_state is None: From a833f541bc5c1cf7ac8ddda08cfe82f0dd565dec Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 22:02:10 +0100 Subject: [PATCH 05/16] typo docs --- numpyro/infer/hmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index 19b407a66..e9cdb51de 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -265,7 +265,7 @@ def init_kernel( `d2` is the max tree depth during post warmup phase. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. - ::param bool forward_mode_differentiation: flag indicating whether to use forward mode differentiation. + :param bool forward_mode_differentiation: flag indicating whether to use forward mode differentiation. :param bool regularize_mass_matrix: whether or not to regularize the estimated mass matrix for numerical stability during warmup phase. Defaults to True. This flag does not take effect if ``adapt_mass_matrix == False``. From 779b412378bd5d282dfeb83ba1d130b8f84eafdf Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 22:16:21 +0100 Subject: [PATCH 06/16] reorder arguments --- numpyro/infer/svi.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index 6636ad118..6605c50ff 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -256,16 +256,16 @@ def get_params(self, svi_state): params = self.constrain_fn(self.optim.get_params(svi_state.optim_state)) return params - def update(self, svi_state, forward_mode_differentiation=False, *args, **kwargs): + def update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs): """ Take a single step of SVI (possibly on a batch / minibatch of data), using the optimizer. :param svi_state: current state of SVI. - :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. - Defaults to False. :param args: arguments to the model / guide (these can possibly vary during the course of fitting). + :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. + Defaults to False. :param kwargs: keyword arguments to the model / guide (these can possibly vary during the course of fitting). :return: tuple of `(svi_state, loss)`. @@ -287,16 +287,16 @@ def update(self, svi_state, forward_mode_differentiation=False, *args, **kwargs) ) return SVIState(optim_state, mutable_state, rng_key), loss_val - def stable_update(self, svi_state, forward_mode_differentiation=False, *args, **kwargs): + def stable_update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs): """ Similar to :meth:`update` but returns the current state if the the loss or the new state contains invalid values. :param svi_state: current state of SVI. - :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. - Defaults to False. :param args: arguments to the model / guide (these can possibly vary during the course of fitting). + :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. + Defaults to False. :param kwargs: keyword arguments to the model / guide (these can possibly vary during the course of fitting). :return: tuple of `(svi_state, loss)`. From 2424178930c6c6be7c9deb01c66c6a2086e48353 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 22:29:15 +0100 Subject: [PATCH 07/16] order args --- numpyro/infer/svi.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index 6605c50ff..0b9db9474 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -372,9 +372,13 @@ def run( def body_fn(svi_state, _): if stable_update: - svi_state, loss = self.stable_update(svi_state, forward_mode_differentiation, *args, **kwargs) + svi_state, loss = self.stable_update( + svi_state, *args, forward_mode_differentiation=forward_mode_differentiation, **kwargs + ) else: - svi_state, loss = self.update(svi_state, forward_mode_differentiation, *args, **kwargs) + svi_state, loss = self.update( + svi_state, *args, forward_mode_differentiation=forward_mode_differentiation, **kwargs + ) return svi_state, loss if init_state is None: From 67e4c169c36532396f8e760fc24d2309a8b13211 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 22:34:52 +0100 Subject: [PATCH 08/16] kw argument internal function --- numpyro/infer/svi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index 0b9db9474..860873bd7 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -283,7 +283,7 @@ def update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs) mutable_state=svi_state.mutable_state, ) (loss_val, mutable_state), optim_state = self.optim.eval_and_update( - loss_fn, svi_state.optim_state, forward_mode_differentiation + loss_fn, svi_state.optim_state, forward_mode_differentiation=forward_mode_differentiation ) return SVIState(optim_state, mutable_state, rng_key), loss_val @@ -314,7 +314,7 @@ def stable_update(self, svi_state, *args, forward_mode_differentiation=False, ** mutable_state=svi_state.mutable_state, ) (loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update( - loss_fn, svi_state.optim_state, forward_mode_differentiation + loss_fn, svi_state.optim_state, forward_mode_differentiation=forward_mode_differentiation ) return SVIState(optim_state, mutable_state, rng_key), loss_val From 8373ead733ff0db8219609872dbda6102018c67a Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 22:44:22 +0100 Subject: [PATCH 09/16] add arg to minimize --- numpyro/optim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/numpyro/optim.py b/numpyro/optim.py index 0f94d46df..d68097328 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -278,7 +278,9 @@ def __init__(self, method="BFGS", **kwargs): self._method = method self._kwargs = kwargs - def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState): + def eval_and_update( + self, fn: Callable[[Any], tuple], state: _IterOptState, forward_mode_differentiation=False + ): i, (flat_params, unravel_fn) = state def loss_fn(x): From 2d4065785794c11efef2b0920e055ba9bc00caa9 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 23:44:39 +0100 Subject: [PATCH 10/16] decouple aux function --- numpyro/infer/hmc_util.py | 6 +++--- numpyro/optim.py | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index a921a0b58..f3b2d81ee 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -239,11 +239,11 @@ def final_fn(state, regularize=False): return init_fn, update_fn, final_fn -def _value_and_grad(f, x, argnums=0, has_aux=False, holomorphic=False, forward_mode_differentiation=False): +def _value_and_grad(f, x, forward_mode_differentiation=False): if forward_mode_differentiation: - return f(x), jacfwd(f, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic)(x) + return f(x), jacfwd(f)(x) else: - return value_and_grad(f, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic)(x) + return value_and_grad(f)(x) def _kinetic_grad(kinetic_fn, inverse_mass_matrix, r): diff --git a/numpyro/optim.py b/numpyro/optim.py index d68097328..613140c09 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -11,15 +11,13 @@ from collections.abc import Callable from typing import Any, TypeVar -from jax import lax +from jax import jacfwd, lax, value_and_grad from jax.example_libraries import optimizers from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.scipy.optimize import minimize from jax.tree_util import register_pytree_node, tree_map -from numpyro.infer.hmc_util import _value_and_grad - __all__ = [ "Adam", "Adagrad", @@ -36,6 +34,11 @@ _OptState = TypeVar("_OptState") _IterOptState = tuple[int, _OptState] +def _value_and_grad(f, x, forward_mode_differentiation=False): + if forward_mode_differentiation: + return f(x), jacfwd(f, has_aux=True)(x) + else: + return value_and_grad(f, has_aux=True)(x) class _NumPyroOptim(object): def __init__(self, optim_fn: Callable, *args, **kwargs) -> None: @@ -83,7 +86,7 @@ def eval_and_update( """ params = self.get_params(state) (out, aux), grads = _value_and_grad( - fn, x=params, has_aux=True, forward_mode_differentiation=forward_mode_differentiation + fn, x=params, forward_mode_differentiation=forward_mode_differentiation ) return (out, aux), self.update(grads, state) From f86a8c695af6d1a83178e5f6deb73823b9350bc8 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 8 Feb 2024 07:27:26 +0100 Subject: [PATCH 11/16] rm kw argument unused --- numpyro/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/optim.py b/numpyro/optim.py index 613140c09..65cb2be8e 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -105,7 +105,7 @@ def eval_and_stable_update( """ params = self.get_params(state) (out, aux), grads = _value_and_grad( - fn, x=params, has_aux=True, forward_mode_differentiation=forward_mode_differentiation + fn, x=params, forward_mode_differentiation=forward_mode_differentiation ) out, state = lax.cond( jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(), From 0ab0034364373bf6dcc40eed50609392148141e3 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 8 Feb 2024 10:40:02 +0100 Subject: [PATCH 12/16] nicer doctrings --- numpyro/infer/hmc.py | 8 +++++++- numpyro/infer/svi.py | 9 +++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index e9cdb51de..aa2fcd802 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -265,7 +265,13 @@ def init_kernel( `d2` is the max tree depth during post warmup phase. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. - :param bool forward_mode_differentiation: flag indicating whether to use forward mode differentiation. + :param bool forward_mode_differentiation: whether to use forward-mode differentiation + or reverse-mode differentiation. By default, we use reverse mode but the forward + mode can be useful in some cases to improve the performance. In addition, some + control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop` + only supports forward-mode differentiation. See + `JAX's The Autodiff Cookbook `_ + for more information. :param bool regularize_mass_matrix: whether or not to regularize the estimated mass matrix for numerical stability during warmup phase. Defaults to True. This flag does not take effect if ``adapt_mass_matrix == False``. diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index 860873bd7..77f05255f 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -347,8 +347,13 @@ def run( ``True``. :param bool stable_update: whether to use :meth:`stable_update` to update the state. Defaults to False. - :param bool forward_mode_differentiation: flag indicating whether to use forward mode differentiation. - Defaults to False. + :param bool forward_mode_differentiation: whether to use forward-mode differentiation + or reverse-mode differentiation. By default, we use reverse mode but the forward + mode can be useful in some cases to improve the performance. In addition, some + control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop` + only supports forward-mode differentiation. See + `JAX's The Autodiff Cookbook `_ + for more information. :param SVIState init_state: if not None, begin SVI from the final state of previous SVI run. Usage:: From 532fb5b186d445fda626c3eb682a3d8b666dfb21 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 8 Feb 2024 10:40:18 +0100 Subject: [PATCH 13/16] simple test --- test/infer/test_svi.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py index 0e62bd395..65a1751ff 100644 --- a/test/infer/test_svi.py +++ b/test/infer/test_svi.py @@ -8,7 +8,7 @@ import pytest import jax -from jax import jit, random, value_and_grad +from jax import jit, lax, random, value_and_grad from jax.example_libraries import optimizers import jax.numpy as jnp from jax.tree_util import tree_all, tree_map @@ -757,3 +757,20 @@ def guide(): params = svi_results.params assert_allclose(params["loc"], actual_loc, rtol=0.1) assert_allclose(params["scale"], actual_scale, rtol=0.1) + + +def test_forward_mode_differentiation(): + def model(): + x = numpyro.sample("x", dist.Normal(0, 1)) + y = lax.while_loop(lambda x: x < 10, lambda x: x + 1, x) + numpyro.sample("obs", dist.Normal(y, 1), obs=1.0) + + def guide(): + loc = numpyro.param("loc", 0.) + scale = numpyro.param("scale", 1., constraint=dist.constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale)) + + # this fails in reverse mode + optimizer = numpyro.optim.Adam(step_size=0.01) + svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) + svi.run(random.PRNGKey(0), 1000, forward_mode_differentiation=True) From ef3a3f76ba53fe984b50f4adfb410c71913d3b4b Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 8 Feb 2024 12:45:14 +0100 Subject: [PATCH 14/16] add wrapper --- numpyro/optim.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/numpyro/optim.py b/numpyro/optim.py index 65cb2be8e..37d16dc4a 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -36,7 +36,11 @@ def _value_and_grad(f, x, forward_mode_differentiation=False): if forward_mode_differentiation: - return f(x), jacfwd(f, has_aux=True)(x) + def _wrapper(h, x): + out, aux = h(x) + return out, (out, aux) + grads, (out, aux) = _wrapper(jacfwd(f, has_aux=True), x) + return (out, aux), grads else: return value_and_grad(f, has_aux=True)(x) From a64c092f2e630b99fc4833927e05292775519251 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 8 Feb 2024 15:41:55 +0100 Subject: [PATCH 15/16] fix wrapper order --- numpyro/optim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/optim.py b/numpyro/optim.py index 37d16dc4a..5ad451bab 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -36,10 +36,10 @@ def _value_and_grad(f, x, forward_mode_differentiation=False): if forward_mode_differentiation: - def _wrapper(h, x): - out, aux = h(x) + def _wrapper(x): + out, aux = f(x) return out, (out, aux) - grads, (out, aux) = _wrapper(jacfwd(f, has_aux=True), x) + grads, (out, aux) = jacfwd(_wrapper, has_aux=True)(x) return (out, aux), grads else: return value_and_grad(f, has_aux=True)(x) From 1259db7b79ac1aecce97fa87a984200906c50e1e Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 8 Feb 2024 16:07:20 +0100 Subject: [PATCH 16/16] add wrapper trick to hmc --- numpyro/infer/hmc_util.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index f3b2d81ee..07f36a81a 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -241,9 +241,13 @@ def final_fn(state, regularize=False): def _value_and_grad(f, x, forward_mode_differentiation=False): if forward_mode_differentiation: - return f(x), jacfwd(f)(x) + def _wrapper(x): + out = f(x) + return out, out + grads, out = jacfwd(_wrapper, has_aux=True)(x) + return out, grads else: - return value_and_grad(f)(x) + return value_and_grad(f, has_aux=False)(x) def _kinetic_grad(kinetic_fn, inverse_mass_matrix, r):