Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented jax.lax.while primitive #16

Merged
merged 5 commits into from
Jul 21, 2024

Conversation

ymahlau
Copy link
Contributor

@ymahlau ymahlau commented Apr 20, 2024

This PR adds an implementation of the jax.lax.while primitive

Closes #15

@patrick-kidger
Copy link
Owner

Huh! This... actually seems really simple. I was not expecting it to be this easy.

I think to merge this I would like to see some tests added, if that's okay? The ones that jump out are:

  • a basic test with quaxified carry.
  • something where the cond_fun and body_fun close over additional constant values.
  • combining this with jit/grad/vmap.

@ymahlau
Copy link
Contributor Author

ymahlau commented May 4, 2024

I agree that some test cases are necessary. To this end, I extended the unitful example and added it to the examples folder to have some wrapper that allows for easy testing. Then I added some test cases including constant values, jit and vmap.

@patrick-kidger
Copy link
Owner

Nice, thank you! My comments on this are:

  1. it looks like the pre-commit hooks are failing (formatting etc.). These should be easy enough to fix: pre-commit install; pre-commit run --all-files; git add -u; git commit.
  2. I have a very specific request: can you add some a test in which we (a) close over a variable in the body function whilst simultaneously (b) differentiating with respect to it? This is something which I have come to know can easily go wrong... so I'd like to be sure we get it right!

@nstarman
Copy link
Contributor

This feature looks very useful!

@ymahlau
Copy link
Contributor Author

ymahlau commented Jul 14, 2024

Sorry for the (very) late reply! This project kind of fell under the table and the comment of @nstarman reminded that the PR is still open. Regarding your comments:

  1. formatting should be fine now.
  2. I added an example with closure over the body function and grad. In practice this is a bit finicky, because the differentiation needs to be inside quaxify such that the closure is correct. Then the tangent array needs to be passed through the quaxify boundary as well. If you want to change anything about this example feel free to suggest improvements.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments aside this LGTM! Let's get this in :)

It seems that the pre-commit checks are failing. These can be ran locally, on every commit, before pushing remotely. I've realised that we don't seem to have a contributing guide for Quax (I usually try to put one in every repo, not sure how this one slipped through!) So I've just added one here: https://github.com/patrick-kidger/quax/blob/main/CONTRIBUTING.md

quax/_core.py Outdated
@@ -537,4 +537,34 @@ def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs):
return jax.jit(flat_fun)(leaves) # now we can call without Quax.


@register(jax.lax.while_p)
def _(
*args: ArrayValue | ArrayLike,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is only Python 3.10+? For now we're on 3.9+.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is true, i locally used Python 3.11 and good this mixed up. It is changed now.

quax/_core.py Outdated
quax_body_jaxpr = jax.make_jaxpr(quax_body_fn)(*body_consts, *init_vals)

leaves, _ = jtu.tree_flatten(args)
_, val_treedef = jtu.tree_flatten(init_vals)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is implicitly relying on the flattening here occuring in the same way / in the same order, as in the make_jaxpr calls above? I could see that potentially going wrong.

My guess is that the 'correct' thing to would be to call, in order:

flatten
make_jaxpr
quaxify
unflatten
jaxpr_as_fun

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, i split the tree flatten now into three different calls for the cond-constants, body-constants and while variables.

import jax
import jax.core as core
import jax.numpy as jnp
from jaxtyping import ArrayLike # https://github.com/patrick-kidger/quax
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in the comment ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@ymahlau
Copy link
Contributor Author

ymahlau commented Jul 20, 2024

There is still one error remaining in pyright, which originates from the Zero-Example. I think this may have been introduced in the last update of quax? I made no changes to this part of the code to keep this PR more or less self-contained.

@patrick-kidger
Copy link
Owner

patrick-kidger commented Jul 21, 2024

The checks look to have passed! FWIW @nstarman fixed up some pyright stuff recently, looks like things have been resolved there.

So anyway: LGTM, merged -- thank you for contributing, I'm happy to have this in!

@patrick-kidger patrick-kidger merged commit dcbadb6 into patrick-kidger:main Jul 21, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

How to implement jax.lax.while with quax
3 participants