decouple aux function
juanitorduz committed Feb 7, 2024
1 parent 8373ead commit 2d40657
Showing 2 changed files with 10 additions and 7 deletions.
6 changes: 3 additions & 3 deletions numpyro/infer/hmc_util.py
@@ -239,11 +239,11 @@ def final_fn(state, regularize=False):
     return init_fn, update_fn, final_fn


-def _value_and_grad(f, x, argnums=0, has_aux=False, holomorphic=False, forward_mode_differentiation=False):
+def _value_and_grad(f, x, forward_mode_differentiation=False):
     if forward_mode_differentiation:
-        return f(x), jacfwd(f, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic)(x)
+        return f(x), jacfwd(f)(x)
     else:
-        return value_and_grad(f, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic)(x)
+        return value_and_grad(f)(x)


 def _kinetic_grad(kinetic_fn, inverse_mass_matrix, r):
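
For context, a minimal sketch (not part of the commit) of what the simplified hmc_util helper computes: with the argnums/has_aux/holomorphic plumbing dropped, it reduces to a choice between forward-mode jacfwd and reverse-mode value_and_grad. The loss function and input below are made up for illustration.

    import jax.numpy as jnp
    from jax import jacfwd, value_and_grad

    def _value_and_grad(f, x, forward_mode_differentiation=False):
        if forward_mode_differentiation:
            # Forward mode: evaluate f once, then take the gradient with jacfwd.
            return f(x), jacfwd(f)(x)
        else:
            # Reverse mode: one fused value-and-gradient pass.
            return value_and_grad(f)(x)

    def loss(x):  # hypothetical scalar objective, for illustration only
        return jnp.sum(x ** 2)

    x = jnp.arange(3.0)
    v, g = _value_and_grad(loss, x)  # reverse mode: v == 5.0, g == [0., 2., 4.]
    v, g = _value_and_grad(loss, x, forward_mode_differentiation=True)  # same result via forward mode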
11 changes: 7 additions & 4 deletions numpyro/optim.py
@@ -11,15 +11,13 @@
 from collections.abc import Callable
 from typing import Any, TypeVar

-from jax import lax
+from jax import jacfwd, lax, value_and_grad
 from jax.example_libraries import optimizers
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 from jax.scipy.optimize import minimize
 from jax.tree_util import register_pytree_node, tree_map

-from numpyro.infer.hmc_util import _value_and_grad
-
 __all__ = [
     "Adam",
     "Adagrad",
@@ -36,6 +34,11 @@
 _OptState = TypeVar("_OptState")
 _IterOptState = tuple[int, _OptState]

+def _value_and_grad(f, x, forward_mode_differentiation=False):
+    if forward_mode_differentiation:
+        return f(x), jacfwd(f, has_aux=True)(x)
+    else:
+        return value_and_grad(f, has_aux=True)(x)

 class _NumPyroOptim(object):
     def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
@@ -83,7 +86,7 @@ def eval_and_update(
         """
         params = self.get_params(state)
         (out, aux), grads = _value_and_grad(
-            fn, x=params, has_aux=True, forward_mode_differentiation=forward_mode_differentiation
+            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
         )
         return (out, aux), self.update(grads, state)
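
For context, a minimal sketch (not part of the commit) of the has_aux=True contract that the optim-local copy bakes in on the default reverse-mode path: value_and_grad(f, has_aux=True)(x) returns ((out, aux), grads), which is exactly the shape eval_and_update unpacks. loss_with_aux and its return values are made up for illustration.

    import jax.numpy as jnp
    from jax import value_and_grad

    def loss_with_aux(params):  # hypothetical objective returning (loss, aux)
        loss = jnp.sum(params ** 2)
        aux = {"n_params": params.shape[0]}  # side output, not differentiated
        return loss, aux

    params = jnp.ones(4)
    (out, aux), grads = value_and_grad(loss_with_aux, has_aux=True)(params)
    # out == 4.0, aux == {"n_params": 4}, grads == [2., 2., 2., 2.]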
