Skip to content

Commit

Permalink
refactor(jax.extend): update from deprecated imports
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman committed Jan 28, 2025
1 parent a23ae43 commit 23084e7
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/examples/default_rules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"\n",
" @staticmethod\n",
" def default(\n",
" primitive: jax.core.Primitive,\n",
" primitive: jax.extend.core.Primitive,\n",
" values: Sequence[Union[ArrayLike, quax.Value]],\n",
" params: dict,\n",
" ):\n",
Expand Down
15 changes: 8 additions & 7 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import jax.tree_util as jtu
import plum
from jax.custom_derivatives import SymbolicZero as SZ
from jax.extend.core import jaxpr_as_fun
from jaxtyping import ArrayLike, PyTree


Expand Down Expand Up @@ -47,8 +48,8 @@ def _(x: SomeValue, y: SomeValue):
**Arguments:**
- `primitive`: The `jax.core.Primitive` to provide a multiple dispatch
implementation for.
- `primitive`: The `jax.extend.core.Primitive` to provide a multiple
dispatch implementation for.
- `precedence`: The precedence of this rule.
See `plum.Dispatcher.dispatch` for details.
Expand Down Expand Up @@ -394,7 +395,7 @@ def default(
**Arguments:**
- `primitive`: the `jax.core.Primitive` being considered.
- `primitive`: the `jax.extend.core.Primitive` being considered.
- `values`: a sequence of what values this primitive is being called with. Each
value can either be [`quax.Value`][]s, or a normal JAX arraylike (i.e.
`bool`/`int`/`float`/`complex`/NumPy scalar/NumPy array/JAX array).
Expand Down Expand Up @@ -519,7 +520,7 @@ def aval(self) -> core.ShapedArray:
@register(jax._src.pjit.pjit_p) # pyright: ignore
def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs):
del kwargs
fun = quaxify(core.jaxpr_as_fun(jaxpr))
fun = quaxify(jaxpr_as_fun(jaxpr))
if inline:
return fun(*args)
else:
Expand All @@ -541,9 +542,9 @@ def _(
init_vals = args[cond_nconsts + body_nconsts :]

# compute jaxpr of quaxified body and condition function
quax_cond_fn = quaxify(core.jaxpr_as_fun(cond_jaxpr))
quax_cond_fn = quaxify(jaxpr_as_fun(cond_jaxpr))
quax_cond_jaxpr = jax.make_jaxpr(quax_cond_fn)(*cond_consts, *init_vals)
quax_body_fn = quaxify(core.jaxpr_as_fun(body_jaxpr))
quax_body_fn = quaxify(jaxpr_as_fun(body_jaxpr))
quax_body_jaxpr = jax.make_jaxpr(quax_body_fn)(*body_consts, *init_vals)

cond_leaves, _ = jtu.tree_flatten(cond_consts)
Expand Down Expand Up @@ -581,7 +582,7 @@ def _(

def flat_quax_call(flat_args):
args = jtu.tree_unflatten(in_tree, flat_args)
out = quaxify(core.jaxpr_as_fun(jaxpr))(*args)
out = quaxify(jaxpr_as_fun(jaxpr))(*args)
flat_out, out_tree = jtu.tree_flatten(out)
out_trees.append(out_tree)
return flat_out
Expand Down
3 changes: 2 additions & 1 deletion quax/examples/named/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import equinox as eqx
import jax.core
import jax.extend
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import ArrayLike
Expand Down Expand Up @@ -98,7 +99,7 @@ def _broadcast_axes(axes1, axes2):


def _register_elementwise_binop(
op: Callable[[Any, Any], Any], prim: jax.core.Primitive
op: Callable[[Any, Any], Any], prim: jax.extend.core.Primitive
):
quax_op = quax.quaxify(op)

Expand Down

0 comments on commit 23084e7

Please sign in to comment.