Incompatibility with JAX-native register_dataclass #277

Open
kvablack opened this issue Dec 11, 2024 · 1 comment
Labels
question User queries

Comments

@kvablack

Hi @patrick-kidger,

You've discussed before how custom PyTrees should not call __init__ in their unflatten function, pointing to this section of the docs. This caused, for example, issues with flax.struct.dataclass and jaxtyping, since jaxtyping hooks into __init__ and thus throws type-checking errors when JAX uses unexpected objects for tracing.
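For reference, here is a minimal sketch of an unflatten function that avoids `__init__`. It registers the class manually with `jax.tree_util.register_pytree_node` and builds the instance with `object.__new__`. The `__post_init__` check is only a stand-in for the validation a type checker would perform, and the helper names `_flatten` and `_unflatten` are made up for illustration.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Data:
    a: jax.Array
    b: str

    def __post_init__(self):
        # Stand-in for the validation a runtime type checker would do in __init__.
        if not isinstance(self.a, jax.Array):
            raise TypeError(f"expected a jax.Array for 'a', got {type(self.a)}")


def _flatten(d):
    # Children (traced leaves) first, then static auxiliary data.
    return (d.a,), d.b


def _unflatten(b, children):
    # Bypass __init__ (and therefore __post_init__) entirely, so placeholder
    # leaves such as object() never reach the validation code.
    obj = object.__new__(Data)
    obj.a = children[0]
    obj.b = b
    return obj


jax.tree_util.register_pytree_node(Data, _flatten, _unflatten)


def f(x: Data):
    return x.a.sum()


data = Data(a=jnp.ones(10, dtype=int), b="hello")
out = jax.vmap(f)(data)  # unflattening never calls __init__
```

With this registration, validation still runs when user code constructs `Data(...)` directly, but not when JAX rebuilds the object internally.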

However, JAX now has a built-in jax.tree_util.register_dataclass function. Unfortunately, the implementation does use __init__ for unflattening (see here). This causes the same old issues as flax.struct.dataclass (which now uses jax.tree_util.register_dataclass under the hood). For example:

import dataclasses
from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Array
from jaxtyping import jaxtyped
import typeguard


@jaxtyped(typechecker=typeguard.typechecked)
@partial(jax.tree_util.register_dataclass, data_fields=["a"], meta_fields=["b"])
@dataclasses.dataclass
class Data:
    a: Array
    b: str


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(10, dtype=int), b="hello")

f(data)  # works
jax.vmap(f)(data)  # Error: Called with parameters: {'self': Data(...), 'a': <object object at 0x7030cc516760>, 'b': 'hello'}

Technically, the docs suggest using __new__ in tree_unflatten only as one possible solution. Their other recommendation is:

For this reason, the __init__ and __new__ methods of custom PyTree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases.

This would mean adding some sort of special case to jaxtyping that detects when __init__ is being called during unflattening and skips type-checking. However, I'm not sure how you would accomplish this besides simply ignoring leaves of certain types, like object, which would at least fix the issue for JAX transforms. It doesn't seem like a great solution, though. In theory, a third-party library can unflatten PyTrees with any leaf type (although you could then argue that jaxtyping should throw a type-checking error in that case; I'm not sure).
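To make the "ignore leaves of type object" idea concrete, here is a rough sketch in plain Python. It is not how jaxtyping works internally: `_is_placeholder` and `CheckedData` are hypothetical names, the `__post_init__` check stands in for real type-checking, and the unflatten function deliberately calls the constructor to mimic an `__init__`-based unflatten.

```python
import dataclasses

import jax
import jax.numpy as jnp


def _is_placeholder(leaf):
    # Hypothetical helper: treat only bare object() instances as sentinels.
    # An exact type check is used because every Python value is an `object`.
    return type(leaf) is object


@dataclasses.dataclass
class CheckedData:
    a: jax.Array
    b: str

    def __post_init__(self):
        # Skip validation for sentinel leaves used during unflattening.
        if _is_placeholder(self.a):
            return
        if not isinstance(self.a, jax.Array):
            raise TypeError(f"expected a jax.Array for 'a', got {type(self.a)}")


# Unflatten goes through the constructor on purpose, so __post_init__ runs.
jax.tree_util.register_pytree_node(
    CheckedData,
    lambda d: ((d.a,), d.b),
    lambda b, children: CheckedData(a=children[0], b=b),
)

checked = CheckedData(a=jnp.ones(10, dtype=int), b="hello")
result = jax.vmap(lambda x: x.a * 2)(checked)
```

The weakness is visible in the sketch: it only covers the one sentinel type it knows about, so any other placeholder leaf passed to unflatten would still fail validation.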

Since you're much more familiar with all of this stuff, what are your thoughts?

@patrick-kidger
Owner

Looks like a bug in JAX. I've opened jax-ml/jax#25486.

@patrick-kidger patrick-kidger added the question User queries label Dec 13, 2024