Skip to content

Commit

Permalink
Fixed broken filter_closure_convert (#232)
Browse files Browse the repository at this point in the history
* Fixed broken filter_closure_convert

* Also fixed stateful operations breaking with new JAX
  • Loading branch information
patrick-kidger authored Nov 17, 2022
1 parent 2e00721 commit 253522c
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 49 deletions.
2 changes: 1 addition & 1 deletion equinox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@
from .vmap_pmap import filter_pmap, filter_vmap


__version__ = "0.9.1"
__version__ = "0.9.2"
40 changes: 0 additions & 40 deletions equinox/experimental/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from ..filters import is_array
from ..module import Module, static_field
from ..tree import tree_at


# So the use of a weak dictionary is a bit of wishful thinking here, really.
Expand Down Expand Up @@ -235,47 +234,10 @@ def _monkey_patch():
if not _have_monkey_patched:
_have_monkey_patched = True

_old_outside_call_impl = hcb.outside_call_p.impl
_old_outside_call_translation_rule = xla._translations[hcb.outside_call_p]
_old_outside_call_batching_rule = batching.primitive_batchers[
hcb.outside_call_p
]

#
# Overwrite impl and abstract_eval:
# Make `get_state` not actually pass `like` into the
# callback. This means we don't need to wait for `like` to be computed at
# runtime.
#

def _outside_call_impl(*arg_flat, arg_treedef, **params):
leaves = [None] * arg_treedef.num_leaves
call_type = type(jtu.tree_unflatten(arg_treedef, leaves))
# Not using isinstance for speed. (Questionable choice?)
if call_type is _GetStateArg:
arg = jtu.tree_unflatten(arg_treedef, arg_flat)
assert arg.index._state is None
token_like = jtu.tree_map(lambda _: jax.core.token, arg.like)
arg = tree_at(lambda a: a.like, arg, token_like)
arg_flat = jtu.tree_leaves(arg)
return _old_outside_call_impl(*arg_flat, arg_treedef=arg_treedef, **params)

def _outside_call_translation_rule(ctx, avals_in, *args, arg_treedef, **kwargs):
leaves = [None] * arg_treedef.num_leaves
call_type = type(jtu.tree_unflatten(arg_treedef, leaves))
if call_type is _GetStateArg:
arg_flat = avals_in[:-2]
extra_tokens = avals_in[-2:]
arg = jtu.tree_unflatten(arg_treedef, arg_flat)
assert arg.index._state is None
token_like = jtu.tree_map(lambda _: jax.core.abstract_token, arg.like)
arg = tree_at(lambda a: a.like, arg, token_like)
arg_flat = jtu.tree_leaves(arg)
avals_in = arg_flat + extra_tokens
return _old_outside_call_translation_rule(
ctx, avals_in, *args, arg_treedef=arg_treedef, **kwargs
)

#
# Overwrite batching:
# Allows us to use get_state and set_state inside vmap.
Expand Down Expand Up @@ -324,9 +286,7 @@ def _outside_call_batching_rule(
**params,
)

hcb.outside_call_p.def_impl(_outside_call_impl)
batching.primitive_batchers[hcb.outside_call_p] = _outside_call_batching_rule
xla.register_translation(hcb.outside_call_p, _outside_call_translation_rule)


def _batchify_impl(*flat, treedef, like_batch_axes, current_batch_axes):
Expand Down
9 changes: 5 additions & 4 deletions equinox/grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
from .doc_utils import doc_strip_annotations
from .filters import (
combine,
filter,
is_array,
is_inexact_array,
is_inexact_array_like,
partition,
)
from .make_jaxpr import filter_make_jaxpr
from .module import Module, module_update_wrapper, Static
from .module import Module, module_update_wrapper, Static, static_field


class _ValueAndGradWrapper(Module):
Expand Down Expand Up @@ -259,10 +260,10 @@ def diff_fun(*_diff):


class _ClosureConvert(Module):
jaxpr: jax.core.Jaxpr
jaxpr: jax.core.Jaxpr = static_field()
consts: PyTree[Array] # Captured in the PyTree structure of _ClosureConvert
out_dynamic_struct: PyTree[jax.ShapeDtypeStruct]
out_static: PyTree[Any]
out_dynamic_struct: PyTree[jax.ShapeDtypeStruct] = static_field()
out_static: PyTree[Any] = static_field()

def __call__(self, *args, **kwargs):
dynamic = filter((args, kwargs), is_array)
Expand Down
2 changes: 1 addition & 1 deletion equinox/make_jaxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def filter_make_jaxpr(fun):
A wrapped version of `fun`, that when applied to example arguments
`*args, **kwargs`, will return a 3-tuple of:
- A `ClosedJaxpr` representing the evaluation of that function on those arguments.
- A `PyTree[jax.ShapeDtypeStruct]` representing the output shape and dtype of the
result.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,32 @@ def f(x, y):
f(1.0, 1.0)


def test_closure_convert_custom_jvp():
# Deliberately not using filter_custom_jvp to check the static fields on the
# closure converted function.
@jax.custom_jvp
def call(f, x):
return f(x)

@call.defjvp
def call_jvp(primals, tangents):
f, x = primals
tf, tx = tangents
out = call(f, x)
tsum = sum(jnp.sum(x) for x in jtu.tree_leaves((tf, tx)))
tout = jtu.tree_map(lambda x: jnp.full(x.shape, tsum, x.dtype), out)
return out, tout

@jax.grad
def run(x):
x1, x2 = x
f = lambda y: x1 * y + x2
f = eqx.filter_closure_convert(f, 3.0)
return call(f, 3.0)

assert shaped_allclose(run((2.0, 4.0)), (jnp.array(1.0), jnp.array(1.0)))


def test_filter_custom_jvp():
@eqx.filter_custom_jvp
def call(fn, x):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_noinline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def test_vmap(getkey):
o2 = mlp_jit_vmap(x)
o3 = mlp_vmap_noinline(x)
o4 = mlp_jit_vmap_noinline(x)
assert shaped_allclose(o1, o2)
assert shaped_allclose(o1, o3)
assert shaped_allclose(o1, o4)
assert shaped_allclose(o1, o2, atol=1e-5)
assert shaped_allclose(o1, o3, atol=1e-5)
assert shaped_allclose(o1, o4, atol=1e-5)


def test_jvp(getkey):
Expand Down

0 comments on commit 253522c

Please sign in to comment.