
Commit

Merge pull request #184 from patrick-kidger/dev
Dev
patrick-kidger authored Aug 30, 2022
2 parents bc3a8d9 + 48c6928 commit 18d260d
Showing 9 changed files with 231 additions and 61 deletions.
2 changes: 1 addition & 1 deletion equinox/__init__.py
@@ -24,4 +24,4 @@
from .vmap_pmap import filter_pmap, filter_vmap


__version__ = "0.6.0"
__version__ = "0.7.0"
6 changes: 4 additions & 2 deletions equinox/experimental/batch_norm.py
@@ -152,7 +152,9 @@ def _stats(y):
if inference is None:
inference = self.inference
if inference:
running_mean, running_var = get_state(self.state_index, like=batch_state)
running_mean, running_var = get_state(
self.state_index, like=lax.stop_gradient(batch_state)
)
else:
first_time = get_state(self.first_time_index, like=jnp.array(False))
running_state = lax.cond(
@@ -170,7 +172,7 @@ def _stats(y):
1 - self.momentum
) * batch_mean + self.momentum * running_mean
running_var = (1 - self.momentum) * batch_var + self.momentum * running_var
set_state(self.state_index, (running_mean, running_var))
set_state(self.state_index, lax.stop_gradient((running_mean, running_var)))

def _norm(y, m, v, w, b):
out = (y - m) / jnp.sqrt(v + self.eps)
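The batch_norm change detaches the running statistics on both the read path (the `like` argument of `get_state`) and the write path (the tuple passed to `set_state`), so no gradient trace is smuggled into the stateful host callback, which cannot be differentiated. A minimal sketch of the write-path pattern, written as a standalone function with hypothetical names rather than the `BatchNorm` method itself:

import jax.lax as lax

def update_running_stats(batch_mean, batch_var, running_mean, running_var, momentum):
    # Exponential moving average of the per-batch statistics, as in the diff above.
    new_mean = (1 - momentum) * batch_mean + momentum * running_mean
    new_var = (1 - momentum) * batch_var + momentum * running_var
    # Detach before the result is handed to `set_state`, since stored state
    # cannot carry gradient information.
    return lax.stop_gradient((new_mean, new_var))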
6 changes: 3 additions & 3 deletions equinox/experimental/spectral_norm.py
@@ -164,11 +164,11 @@ def __jax_array__(self):
v_like = self.weight[-1]
u, v = get_state(self.uv_index, (u_like, v_like))
if not self.inference:
weight = lax.stop_gradient(self.weight)
eps = lax.stop_gradient(self.eps)
for _ in range(self.num_power_iterations):
u, v = _power_iteration(self.weight, u, v, self.eps)
u, v = _power_iteration(weight, u, v, eps)
set_state(self.uv_index, (u, v))
u = lax.stop_gradient(u)
v = lax.stop_gradient(v)
σ = jnp.einsum("i,ij,j->", u, self.weight, v)
return jnp.reshape(self.weight / σ, self.weight_shape)

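In the spectral_norm change, `stop_gradient` moves from the power-iteration outputs `u`, `v` to its inputs `weight` and `eps`: the iterates are then already gradient-free when they are stored with `set_state`, while `σ` is still computed from the un-detached `self.weight`, so the normalised weight remains differentiable. The sketch below shows what one power-iteration step for the leading singular value typically looks like; it is an assumption about `_power_iteration`, not the library's exact implementation:

import jax.numpy as jnp

def power_iteration_step(weight, u, v, eps):
    # One step of power iteration on a 2D `weight` matrix. `u` approximates the
    # leading left singular vector, `v` the leading right singular vector.
    u = weight @ v
    u = u / jnp.maximum(jnp.linalg.norm(u), eps)
    v = weight.T @ u
    v = v / jnp.maximum(jnp.linalg.norm(v), eps)
    return u, v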
56 changes: 34 additions & 22 deletions equinox/experimental/stateful.py
@@ -7,7 +7,6 @@
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.interpreters.xla as xla
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu

@@ -205,6 +204,22 @@ def tree_flatten(self):
__hash__ = None


def _delete_smuggled_state(x: StateIndex) -> StateIndex:
# `x._state` may have a gradient on it, which would mean we hit the JVP rule for
# `host_callback.call`, which doesn't exist. Simplest thing to do is just to
# delete it, provided we don't need it.
#
# We don't use `tree_at` because `tree_at(where, pytree, ...)` checks that
# `where(pytree)` doesn't depend on the values of the leaves of `pytree`. This
# involves a flatten. Meanwhile `StateIndex` sneakily modifies its structure
# under flatten, and this trips a false positive.

leaves, treedef = jax.tree_flatten(x)
x = jax.tree_unflatten(treedef, leaves)
object.__setattr__(x, "_state", None)
return x


class _Leaf: # Not a PyTree
def __init__(self, value):
self.value = value
@@ -227,7 +242,7 @@ def _monkey_patch():

#
# Overwrite impl and abstract_eval:
# Make `get_state` not actually pass `index._state` or `like` into the
# Make `get_state` not actually pass `like` into the
# callback. This means we don't need to wait for `like` to be computed at
# runtime.
#
@@ -238,13 +253,9 @@ def _outside_call_impl(*arg_flat, arg_treedef, **params):
# Not using isinstance for speed. (Questionable choice?)
if call_type is _GetStateArg:
arg = jtu.tree_unflatten(arg_treedef, arg_flat)
token_index = jtu.tree_map(lambda _: jax.core.token, arg.index)
assert arg.index._state is None
token_like = jtu.tree_map(lambda _: jax.core.token, arg.like)
arg = tree_at(
lambda a: jtu.tree_leaves((a.index, a.like)),
arg,
jtu.tree_leaves((token_index, token_like)),
)
arg = tree_at(lambda a: a.like, arg, token_like)
arg_flat = jtu.tree_leaves(arg)
return _old_outside_call_impl(*arg_flat, arg_treedef=arg_treedef, **params)

@@ -255,13 +266,9 @@ def _outside_call_translation_rule(ctx, avals_in, *args, arg_treedef, **kwargs):
arg_flat = avals_in[:-2]
extra_tokens = avals_in[-2:]
arg = jtu.tree_unflatten(arg_treedef, arg_flat)
token_index = jtu.tree_map(lambda _: jax.core.abstract_token, arg.index)
assert arg.index._state is None
token_like = jtu.tree_map(lambda _: jax.core.abstract_token, arg.like)
arg = tree_at(
lambda a: jtu.tree_leaves((a.index, a.like)),
arg,
jtu.tree_leaves((token_index, token_like)),
)
arg = tree_at(lambda a: a.like, arg, token_like)
arg_flat = jtu.tree_leaves(arg)
avals_in = arg_flat + extra_tokens
return _old_outside_call_translation_rule(
@@ -313,7 +320,7 @@ def _outside_call_batching_rule(
batch_axes_flat,
arg_treedef=arg_treedef,
result_treedef=result_treedef,
**params
**params,
)

hcb.outside_call_p.def_impl(_outside_call_impl)
@@ -349,7 +356,7 @@ def _batchify_batching_rule(
*flat,
treedef=treedef,
like_batch_axes=like_batch_axes + like_batch_axis,
current_batch_axes=current_batch_axes
current_batch_axes=current_batch_axes,
),
like_batch_axis,
)
@@ -493,11 +500,12 @@ def get_state(index: StateIndex, like: PyTree[Array]) -> PyTree[Array]:
*flat,
treedef=treedef,
like_batch_axes=[],
current_batch_axes=current_batch_axes
current_batch_axes=current_batch_axes,
)
return jtu.tree_unflatten(_treedef, out)
else:
_monkey_patch()
index = _delete_smuggled_state(index)
return _get_state(index, like, [])


@@ -526,7 +534,8 @@ def _set_state_hcb(arg: _SetStateArg) -> None:
state_shape = jax.eval_shape(lambda: state)
if current_state_shape != state_shape:
raise RuntimeError(
"New state and old state have different shape, dtype, or PyTree structure"
"New state and old state have different shape, dtype, or PyTree "
f"structure. New: {current_state_shape}. Old: {state_shape}."
)
if current_batch_axes != batch_axes:
raise RuntimeError("New state and old state have different batch axes")
@@ -548,21 +557,24 @@ def set_state(index: StateIndex, state: PyTree[Array]) -> None:
**Raises:**
A `TypeError` at trace time if `state` is not a PyTree of JAX arrays.
A `RuntimeError` at run time if this `index` has previously been used to save a
`state` with a different shape, dtype, PyTree structure, or batch axes.
A `RuntimeError` at trace time if `index.inference` is truthy.
A `TypeError` at trace time if `state` is not a PyTree of JAX arrays.
A `NotImplementedError` at trace time if trying to compute a gradient through
`state`.
!!! info
The same `index` can be used multiple times, to overwrite a previously saved
value. The new and old `state` must both have the same PyTree structure, however.
!!! warning
Note that gradient information in `state` will not preserved.
Note that `state` cannot be differentiated.
!!! warning
@@ -582,7 +594,7 @@ def set_state(index: StateIndex, state: PyTree[Array]) -> None:
if any(not is_array(x) for x in jtu.tree_leaves(state)):
raise TypeError("`state` must be a PyTree containing only JAX arrays")
_monkey_patch()
state = jtu.tree_map(lax.stop_gradient, state)
index = _delete_smuggled_state(index)
_set_state(index, state, [])
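To make the documented behaviour concrete, here is a hedged usage sketch of the stateful API from this file (`StateIndex`, `set_state`, `get_state`); it assumes `StateIndex()` can be constructed with its default arguments, and the array values are illustrative only:

import equinox.experimental as eqxe
import jax.numpy as jnp

# Store a value at an index, then read it back by passing a `like` pytree with
# matching shape, dtype, and PyTree structure.
index = eqxe.StateIndex()
eqxe.set_state(index, jnp.array([1.0, 2.0]))
value = eqxe.get_state(index, like=jnp.zeros(2))

Reusing the same `index` with a `state` of a different shape, dtype, PyTree structure, or batch axes raises the `RuntimeError` described above, and the stored value carries no gradient information.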

