-
-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implemented jax.lax.while primitive #16
Conversation
Huh! This... actually seems really simple. I was not expecting it to be this easy. I think to merge this I would like to see some tests added, if that's okay? The ones that jump out are:
|
I agree that some test cases are necessary. To this end, I extended the unitful example and added it to the examples folder to have some wrapper that allows for easy testing. Then I added some test cases including constant values, jit and vmap. |
Nice, thank you! My comments on this are:
|
This feature looks very useful! |
Sorry for the (very) late reply! This project kind of fell under the table and the comment of @nstarman reminded that the PR is still open. Regarding your comments:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comments aside this LGTM! Let's get this in :)
It seems that the pre-commit checks are failing. These can be ran locally, on every commit, before pushing remotely. I've realised that we don't seem to have a contributing guide for Quax (I usually try to put one in every repo, not sure how this one slipped through!) So I've just added one here: https://github.com/patrick-kidger/quax/blob/main/CONTRIBUTING.md
quax/_core.py
Outdated
@@ -537,4 +537,34 @@ def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs): | |||
return jax.jit(flat_fun)(leaves) # now we can call without Quax. | |||
|
|||
|
|||
@register(jax.lax.while_p) | |||
def _( | |||
*args: ArrayValue | ArrayLike, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is only Python 3.10+? For now we're on 3.9+.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is true, i locally used Python 3.11 and good this mixed up. It is changed now.
quax/_core.py
Outdated
quax_body_jaxpr = jax.make_jaxpr(quax_body_fn)(*body_consts, *init_vals) | ||
|
||
leaves, _ = jtu.tree_flatten(args) | ||
_, val_treedef = jtu.tree_flatten(init_vals) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is implicitly relying on the flattening here occuring in the same way / in the same order, as in the make_jaxpr
calls above? I could see that potentially going wrong.
My guess is that the 'correct' thing to would be to call, in order:
flatten
make_jaxpr
quaxify
unflatten
jaxpr_as_fun
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, i split the tree flatten now into three different calls for the cond-constants, body-constants and while variables.
quax/examples/unitful/_core.py
Outdated
import jax | ||
import jax.core as core | ||
import jax.numpy as jnp | ||
from jaxtyping import ArrayLike # https://github.com/patrick-kidger/quax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo in the comment ;)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
There is still one error remaining in pyright, which originates from the Zero-Example. I think this may have been introduced in the last update of quax? I made no changes to this part of the code to keep this PR more or less self-contained. |
The checks look to have passed! FWIW @nstarman fixed up some pyright stuff recently, looks like things have been resolved there. So anyway: LGTM, merged -- thank you for contributing, I'm happy to have this in! |
This PR adds an implementation of the jax.lax.while primitive
Closes #15