Issues with torch.compile #196

Open
botev opened this issue Mar 26, 2024 · 5 comments

@botev

botev commented Mar 26, 2024

We are very happy that jaxtyping supports PyTorch as well, but we are currently hitting some kind of weird error or edge case and were hoping you could give some suggestions.
When we compile a module and try to run it, we get this stack trace:

  File "/build/work/cfc8a89b76634373e85beb2a59a94e9e781a/google3/runfiles/google3/third_party/py/torch/_dynamo/bytecode_transformation.py", line 646, in compute_exception_table
    keys_sorted = sorted(exn_dict.keys(), key=lambda t: (t[0], -t[1]))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.InternalTorchDynamoError: '<' not supported between instances of 'NoneType' and 'int'

from user code:
   File "/build/work/cfc8a89b76634373e85beb2a59a94e9e781a/google3/runfiles/google3/third_party/py/jaxtyping/_decorator.py", line 411, in wrapped_fn
    bound = param_signature.bind(*args, **kwargs)
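The final error is an ordinary Python `TypeError` raised while Dynamo sorts its exception-table keys: a tuple whose first element is `None` cannot be ordered against one whose first element is an `int`. The snippet below reproduces just that message in plain Python. It is a sketch of the failure mode, assuming one key has a `None` start offset; the keys are made up, and it says nothing about why Dynamo produced such an entry.

```python
# Made-up (start, end) keys, one of which has start=None, standing in for
# the contents of exn_dict in Dynamo's compute_exception_table.
keys = [(2, 3), (None, 1)]

message = None
try:
    # Same key function as the line shown in the traceback.
    sorted(keys, key=lambda t: (t[0], -t[1]))
except TypeError as err:
    # Comparing None with an int inside the key tuples fails here.
    message = str(err)

print(message)
```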
@patrick-kidger
Owner

I'd suggest raising this with the PyTorch folks, including a minimal working example (MWE). This is likely an instance of hitting something torch.compile doesn't support yet. Another example came up recently at pytorch/pytorch#122093

@botev
Author

botev commented Mar 27, 2024

Hmm, I managed to fix a few things and rearrange the code, but now I get:

The problem arose whilst typechecking parameter 'self'.
Actual value: MyModuel(....)
Expected type: <class 'inspect._empty'>.

which I would guess is because torch.compile rewrites the forward pass as a pure function?
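The expected type in that message, `<class 'inspect._empty'>`, is how Python's `inspect` module represents a missing annotation. The stdlib-only sketch below (with a hypothetical `MyModule` stand-in and no PyTorch involved) shows where that sentinel comes from, using the same `signature.bind(*args, **kwargs)` call that appears in jaxtyping's `wrapped_fn` above. One plausible reading of the error is that, under compilation, `self` ends up being checked against this sentinel instead of being skipped; that is an assumption, not something confirmed in this thread.

```python
import inspect


class MyModule:
    # Hypothetical stand-in for an nn.Module; `self` is unannotated, as usual.
    def forward(self, x: int) -> int:
        return x


sig = inspect.signature(MyModule.forward)
self_annotation = sig.parameters["self"].annotation

# An unannotated parameter reports the sentinel inspect.Parameter.empty,
# which is the class printed as <class 'inspect._empty'>.
print(self_annotation is inspect.Parameter.empty)
print(self_annotation)

# Bind positional arguments to parameter names, as a runtime checker would.
module = MyModule()
bound = sig.bind(module, 3)
print(bound.arguments["self"] is module)

# If the sentinel were treated as a real type, no actual value would match it.
print(isinstance(module, self_annotation))
```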

@patrick-kidger
Owner

I'm not familiar with the details of torch.compile, so it's hard for me to speculate, I'm afraid.

@Chrixtar

Hi @botev, I'm facing exactly the same (original) problem.
Did you find any solution?

@botev
Author

botev commented Apr 30, 2024

Unfortunately no, I just disabled the guard for PyTorch.
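The workaround here was to switch the check off for PyTorch entirely. A more selective variant, sketched below in plain Python, keeps the runtime check in eager mode and skips it only while compiling. Everything in it is hypothetical: `check_unless_compiling`, `positive_only`, and the `state` flag are illustrative names, and a real setup would need to supply its own "are we compiling?" predicate from whatever the installed PyTorch version provides. Whether torch.compile traces the skipped branch cleanly is not something this sketch verifies.

```python
import functools
from typing import Any, Callable


def check_unless_compiling(
    check: Callable[..., None], is_compiling: Callable[[], bool]
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical helper: run `check` on the call's arguments before the
    wrapped function, except when `is_compiling()` returns True."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not is_compiling():
                check(*args, **kwargs)  # runtime guard; skipped while compiling
            return fn(*args, **kwargs)

        return wrapper

    return decorator


# A plain flag stands in for a real "are we compiling?" query (assumption).
state = {"compiling": False}


def positive_only(x: int) -> None:
    # Hypothetical check standing in for a type/shape check.
    if x <= 0:
        raise TypeError(f"expected a positive int, got {x!r}")


@check_unless_compiling(positive_only, lambda: state["compiling"])
def double(x: int) -> int:
    return 2 * x
```

Calling `double(-1)` raises while `state["compiling"]` is False and returns `-2` once it is set to True, so the guard stays active everywhere except the path you choose to exempt.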


3 participants