Hi @patrick-kidger,

You've discussed before how custom PyTrees should not call `__init__` in their unflatten function, pointing to this section of the docs. This caused issues with, for example, `flax.struct.dataclass` and `jaxtyping`. Because `jaxtyping` hooks into `__init__`, it throws type-checking errors when JAX passes unexpected objects as leaves during tracing.
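The `__new__`-based pattern from the docs can be sketched as follows. This is a minimal sketch that assumes a hypothetical `Validated` class whose `__init__` validates its input. Its `tree_unflatten` allocates the instance with `object.__new__` and assigns the attributes directly, so JAX's placeholder leaves and tracers never reach the validation.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Validated:
    """Hypothetical pytree whose constructor validates its input."""

    def __init__(self, a, label: str):
        # Input validation of the kind the docs warn about.
        if not isinstance(a, jax.Array):
            raise TypeError(f"a: expected jax.Array, got {type(a).__name__}")
        self.a = a
        self.label = label

    def tree_flatten(self):
        # `a` is the only leaf; `label` is static auxiliary data.
        return (self.a,), self.label

    @classmethod
    def tree_unflatten(cls, label, children):
        # Skip __init__ (and its validation): allocate a bare instance
        # with __new__ and assign the attributes directly.
        obj = object.__new__(cls)
        (obj.a,) = children
        obj.label = label
        return obj


def f(x: Validated) -> jax.Array:
    return x.a * 2


v = Validated(jnp.ones(10, dtype=int), "hello")
out = jax.vmap(f)(v)  # placeholders and tracers bypass __init__
```

Validation still runs when users construct `Validated` directly. The cost is that every field must be set by hand in `tree_unflatten`.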
However, JAX now has a built-in `jax.tree_util.register_dataclass` function. Unfortunately, its implementation does use `__init__` for unflattening (see here). This causes the same old issues as `flax.struct.dataclass`, which now uses `jax.tree_util.register_dataclass` under the hood. For example:
```python
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Array
from jaxtyping import jaxtyped
import typeguard


@jaxtyped(typechecker=typeguard.typechecked)
@partial(jax.tree_util.register_dataclass, data_fields=["a"], meta_fields=["b"])
@dataclasses.dataclass
class Data:
    a: Array
    b: str


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(10, dtype=int), b="hello")

f(data)  # works
jax.vmap(f)(data)  # Error: Called with parameters: {'self': Data(...), 'a': <object object at 0x7030cc516760>, 'b': 'hello'}
```
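The `<object object at ...>` in that error suggests that `vmap` rebuilds the tree with bare `object()` placeholders as leaves. The following sketch reproduces this directly with `tree_unflatten`. It assumes the `__init__`-based unflattening described above and uses a hypothetical `__post_init__` hook (called by the dataclass-generated `__init__`) that logs every value of `a` into a hypothetical `seen_a` list.

```python
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp

seen_a = []  # hypothetical log of every `a` that reaches __init__


@partial(jax.tree_util.register_dataclass, data_fields=["a"], meta_fields=["b"])
@dataclasses.dataclass
class Data:
    a: jax.Array
    b: str

    def __post_init__(self):
        # The dataclass-generated __init__ calls this hook, so it observes
        # every construction, including any performed during unflattening.
        seen_a.append(self.a)


data = Data(a=jnp.ones(10, dtype=int), b="hello")  # logs the real array

# Rebuild the same tree structure with a bare object() as the leaf.
sentinel = object()
treedef = jax.tree_util.tree_structure(data)
rebuilt = jax.tree_util.tree_unflatten(treedef, [sentinel])

# True if unflattening went through __init__ with the placeholder.
print(any(x is sentinel for x in seen_a))
```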
Technically, the docs suggest using `__new__` in `tree_unflatten` as only one possible solution. Their other recommendation is:
> For this reason, the `__init__` and `__new__` methods of custom PyTree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases.
This would mean adding some sort of special case to `jaxtyping` that detects when `__init__` is being called during unflattening and skips type-checking. However, I'm not sure how you would accomplish this other than by ignoring leaves of certain types, like `object`. That would at least fix the issue for JAX transforms, but it doesn't seem like a great solution. In theory, a third-party library can unflatten PyTrees with any leaf type (although you might then argue that `jaxtyping` should throw a type-checking error in that case? Not sure).
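A minimal sketch of the "ignore `object` leaves" special case might look like the following. It assumes a hypothetical `_is_placeholder` helper and a hand-written `__post_init__` check that stands in for a runtime type checker hooked into `__init__`. This is not how `jaxtyping` actually works internally.

```python
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp


def _is_placeholder(x) -> bool:
    # Hypothetical heuristic: treat a bare `object()` instance (type exactly
    # `object`, not a subclass) as a JAX placeholder leaf.
    return type(x) is object


@partial(jax.tree_util.register_dataclass, data_fields=["a"], meta_fields=["b"])
@dataclasses.dataclass
class Data:
    a: jax.Array
    b: str

    def __post_init__(self):
        # Stand-in for a runtime type check hooked into __init__.
        if _is_placeholder(self.a):
            return  # probably unflattening inside a JAX transform: skip
        if not isinstance(self.a, jax.Array):
            raise TypeError(f"a: expected jax.Array, got {type(self.a).__name__}")


def f(x: Data) -> jax.Array:
    return x.a * 2


data = Data(a=jnp.ones(10, dtype=int), b="hello")
out = jax.vmap(f)(data)  # placeholders are skipped, tracers pass the check

# The drawback: a user-supplied bare object() slips through unchecked too.
unchecked = Data(a=object(), b="oops")
```

This sketch shows both concerns. The heuristic silently accepts a bare `object()` passed by a user, and a library that unflattens with any other placeholder type would still trip the check.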
Since you're much more familiar with all of this stuff, what are your thoughts?