From 2d4065785794c11efef2b0920e055ba9bc00caa9 Mon Sep 17 00:00:00 2001
From: juanitorduz
Date: Wed, 7 Feb 2024 23:44:39 +0100
Subject: [PATCH] decouple aux function

---
 numpyro/infer/hmc_util.py |  6 +++---
 numpyro/optim.py          | 11 +++++++----
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py
index a921a0b58..f3b2d81ee 100644
--- a/numpyro/infer/hmc_util.py
+++ b/numpyro/infer/hmc_util.py
@@ -239,11 +239,11 @@ def final_fn(state, regularize=False):
     return init_fn, update_fn, final_fn
 
 
-def _value_and_grad(f, x, argnums=0, has_aux=False, holomorphic=False, forward_mode_differentiation=False):
+def _value_and_grad(f, x, forward_mode_differentiation=False):
     if forward_mode_differentiation:
-        return f(x), jacfwd(f, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic)(x)
+        return f(x), jacfwd(f)(x)
     else:
-        return value_and_grad(f, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic)(x)
+        return value_and_grad(f)(x)
 
 
 def _kinetic_grad(kinetic_fn, inverse_mass_matrix, r):
diff --git a/numpyro/optim.py b/numpyro/optim.py
index d68097328..613140c09 100644
--- a/numpyro/optim.py
+++ b/numpyro/optim.py
@@ -11,15 +11,13 @@
 from collections.abc import Callable
 from typing import Any, TypeVar
 
-from jax import lax
+from jax import jacfwd, lax, value_and_grad
 from jax.example_libraries import optimizers
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 from jax.scipy.optimize import minimize
 from jax.tree_util import register_pytree_node, tree_map
 
-from numpyro.infer.hmc_util import _value_and_grad
-
 __all__ = [
     "Adam",
     "Adagrad",
@@ -36,6 +34,11 @@
 _OptState = TypeVar("_OptState")
 _IterOptState = tuple[int, _OptState]
 
+def _value_and_grad(f, x, forward_mode_differentiation=False):
+    if forward_mode_differentiation:
+        return f(x), jacfwd(f, has_aux=True)(x)
+    else:
+        return value_and_grad(f, has_aux=True)(x)
 
 class _NumPyroOptim(object):
     def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
@@ -83,7 +86,7 @@ def eval_and_update(
         """
         params = self.get_params(state)
         (out, aux), grads = _value_and_grad(
-            fn, x=params, has_aux=True, forward_mode_differentiation=forward_mode_differentiation
+            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
         )
         return (out, aux), self.update(grads, state)
 
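
Note (not part of the patch): a minimal sketch of the calling conventions the two decoupled helpers assume after this change. The `objective` and `potential` functions below are hypothetical examples; only `jax.value_and_grad` and `jax.jacfwd` are real JAX APIs. The copy kept in numpyro/optim.py hard-codes `has_aux=True` because `eval_and_update` differentiates an objective that returns an `(output, aux)` pair, while the hmc_util helper now only handles a plain scalar function.

    import jax
    import jax.numpy as jnp

    # Hypothetical objective of the kind eval_and_update receives:
    # a scalar loss plus an auxiliary payload (e.g. diagnostics).
    def objective(params):
        loss = jnp.sum(params**2)
        return loss, {"loss_value": loss}

    params = jnp.array([1.0, 2.0])

    # Reverse mode, as in the optim helper's default branch:
    # value_and_grad(..., has_aux=True) returns ((loss, aux), grads).
    (loss, aux), grads = jax.value_and_grad(objective, has_aux=True)(params)

    # Hypothetical plain scalar function of the kind the hmc_util helper
    # now expects (no auxiliary output).
    def potential(q):
        return jnp.sum(q**2)

    # Forward mode, as in hmc_util._value_and_grad: evaluate once for the
    # value, then take the gradient with jacfwd.
    value, grad_fwd = potential(params), jax.jacfwd(potential)(params)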