Equinox v0.10.0
Highlights
- A dramatically simplified API for equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap}. This is a backward-incompatible change.
- equinox.internal.while_loop, which is a reverse-mode autodifferentiable while loop, using recursive checkpointing.
Full change list
New features
Some relatively minor new features are available in this release.
- Added support for donating buffers when using eqx.{filter_jit, filter_pmap}. (Thanks @uuirs in #235!)
- Added eqx.nn.PRelu. (Thanks @enver1323 in #249!)
- Added eqx.tree_pprint.
- Added eqx.module_update_wrapper.
- eqx.filter_custom_jvp now supports keyword arguments (which are always treated as nondifferentiable).
New internal features
Introducing a slew of new features for the advanced JAX user. These are all available in the equinox.internal namespace (abbreviated as eqxi below). Note that these come without stability guarantees, as they often depend on functionality that JAX doesn't make fully public.
- eqxi.abstractattribute, for marking abstract instance attributes of abstract Equinox modules.
- eqxi.tree_pp, for producing a pretty-print doc of an object. (This is what is then formatted to a particular width in e.g. eqx.tree_pformat.) In addition, classes can now have custom pretty-printing behaviour when used with eqx.{tree_pp, tree_pformat, tree_pprint}, by setting a __tree_pp__ method.
- eqxi.if_mapped, as an alternative to the usual eqx.if_array passed to eqx.{filter_vmap, filter_pmap}(out_axes=...).
- eqxi.{finalise_jaxpr, finalise_fn}, for tracing through the impl rules of custom primitives (so that the custom primitive no longer appears in the jaxpr). This is useful for replacing such custom primitives prior to offloading a jaxpr to some other IR, e.g. via jax2tf.
- eqxi.{nonbatchable, nondifferentiable, nondifferentiable_backward, nontraceable}, for asserting that an operation is never batched, differentiated, or subject to any transform at all.
- eqxi.to_onnx, for exporting to ONNX.
- eqxi.while_loop, for reverse-mode autodifferentiable while loops, in particular making use of recursive checkpointing. (À la treeverse.)
Backward-incompatible changes
- The API for equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap} has been dramatically simplified. If you were using the extra arguments to these functions (i.e. not just calling @eqx.filter_jit etc. directly), then this is a backward-incompatible change; see the discussion below for more details.
- Removed equinox.nn.{AvgPool1D, AvgPool2D, AvgPool3D, MaxPool1D, MaxPool2D, MaxPool3D}. Use AvgPool1d etc. (lower-case "d") instead. (These were backward-compatibility stubs that have now been removed.)
- Removed equinox.Module.{tree_flatten, tree_unflatten}. These were never technically public API; use jax.tree_util.{tree_flatten, tree_unflatten} instead.
- equinox.filter_closure_convert now asserts that you call it with arguments compatible with those it was closure-converted with.
- Dropped support for Python 3.7.
Other
- The Python overhead when crossing a filter_jit or filter_pmap boundary should now be much reduced.
- eqx.tree_inference now runs faster. (Thanks @uuirs in #233!)
- Lots of documentation improvements; in particular a new "Tricks" section for some advanced notes. (Thanks @carlosgmartin in #239!)
Filtered transformation API changes (AKA: "my code isn't working any more?")
These APIs have been simplified and made much easier to understand. No functionality has been lost; existing code might just need some tweaking.
filter_jit
This previously took default, args, kwargs, out and fn arguments, for controlling what should be traced and what should be held static.
In practice, all JAX arrays and NumPy arrays always had to be traced, and everything that wasn't a JAXable type (JAX array, NumPy array, bool, int, float, complex) had to be held static. So these arguments just weren't that useful: pretty much the only thing you could do with them was to specify that you'd like to trace a bool/int/float/complex.
This minor use-case wasn't worth complicating such an important API for, which is why these arguments have been removed.
If after this change you still want to trace with respect to bool/int/float/complex, then do so simply by wrapping them into JAX arrays or NumPy arrays first: np.asarray(x).
filter_grad and filter_value_and_grad
These previously took an arg argument, for controlling what parts of the first argument should be differentiated.
This was occasionally useful (e.g. when freezing parts of a layer), but in practice it still wasn't used that often. As such, this argument has been removed for the sake of simplicity.
If after this change you want to replicate the previous behaviour, then it is simple to do so using partition and combine:
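# Before
@eqx.filter_grad(arg=foo)
def loss(first_arg, *args):
    ...

loss(bar, *args)

# After
@eqx.filter_grad
def loss(diff_first_arg, static_first_arg, *args):
    # Recombine the differentiated and non-differentiated halves.
    first_arg = eqx.combine(diff_first_arg, static_first_arg)
    ...

# Split `bar` according to the filter `foo`; only the first half is differentiated.
diff_bar, static_bar = eqx.partition(bar, foo)
loss(diff_bar, static_bar, *args)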
See also the updated frozen layer example for a demonstration.
filter_vmap
This previously took default, args, kwargs, out and fn arguments, for controlling what axes should be vectorised over.
In practice this API was just a bit more complicated than it really needed to be. The only useful feature relative to jax.vmap was kwargs, for easily specifying just a few named arguments that should behave differently.
The new API instead accepts in_axes and out_axes arguments, just like jax.vmap. To replace kwargs, one extra feature is supported: in_axes may be a dictionary of named arguments, e.g.
@eqx.filter_vmap(in_axes=dict(bar=None))
def fn(foo, bar):
    ...
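All arguments not named in in_axes will have the default value eqx.if_array(0) applied to them: every array leaf is vectorised over axis 0, and every non-array leaf is left unmapped (as if given None), i.e. x -> 0 if eqx.is_array(x) else None.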
On which note, a new eqx.if_array(i) now exists, to make it easier to specify values for in_axes and out_axes.
If you were using the old fn argument, then this can be replicated by instead decorating a function that accepts the callable:
# Before
eqx.filter_vmap(foo, fn=bar)(x, y)

# After
@eqx.filter_vmap(in_axes=dict(fn=bar))
def accepts_foo(fn, x, y):
    return fn(x, y)

accepts_foo(foo, x, y)
filter_pmap
This previously took default, args, kwargs, out and fn arguments, for controlling what axes should be parallelised over, and which arguments should be traced vs static.
This was a fiendishly complicated API merging together both the filter_jit and filter_vmap APIs.
The JIT part of it is now handled automatically, as with filter_jit: all arrays are traced and everything else is held static.
The vmap part of it is now handled in the same way as filter_vmap, using in_axes and out_axes arguments.
New Contributors
- @carlosgmartin made their first contribution in #239
- @enver1323 made their first contribution in #249
Full Changelog: v0.9.2...v0.10.0