ValueError is raised when using two models #84
Additional info:
PyTreeDef(CustomNode(PhysicsModelState[()], [*, *, *, *, *, *])) == PyTreeDef(CustomNode(PhysicsModelState[()], [*, *, *, *, *, *]))
>>> True
PyTreeDef(CustomNode(GroundContact[(array([20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 33, 33, 33,
33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 34, 34]),)], [*])) == PyTreeDef(CustomNode(GroundContact[(array([20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 33, 33, 33,
33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 34, 34]),)], [*]))
>>> ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
I verified that each child had the same …
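The first comparison succeeds because PhysicsModelState carries empty static data (()), while GroundContact's static data is a tuple holding a NumPy index array. The minimal sketch below uses plain NumPy, not jax_dataclasses itself, to reproduce the same error. It assumes, as the message suggests, that comparing the two PyTreeDefs ends up comparing their static data with ==. The variable names are illustrative.

```python
import numpy as np

# Static data shaped like the GroundContact treedef above:
# a one-element tuple holding an integer index array.
indices = np.array([20] * 16 + [21] * 16 + [33] * 8 + [34] * 8)

aux_model_1 = (indices,)
aux_model_2 = (indices.copy(),)  # same values, but a different array object

# Comparing a tuple with itself succeeds: CPython's element-wise tuple
# comparison short-circuits when both elements are the same object.
assert aux_model_1 == aux_model_1

# With two distinct arrays, ndarray.__eq__ returns an element-wise boolean
# array, and the tuple comparison then needs its truth value -> ValueError.
try:
    aux_model_1 == aux_model_2
    raised = False
except ValueError as err:
    raised = True
    print(f"ValueError: {err}")
```

If both treedefs referenced the very same array object, the comparison would pass. That would be consistent with the error only appearing once a second model, holding its own copy of the indices, is involved.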
When calling the same method on two different high_level.Model instances, a ValueError is raised.
Minimal reproducible example:
Throws:
I initially thought it was related to jax-ml/jax#4717, but apparently that issue has been solved.
Another possibility is that the error is raised here, in jax_dataclasses, as in the following snippet:
Throws:
For some reason, the two models do not appear to be treated as separate entities, and at some point a comparison between them is performed inside jax_dataclasses.
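One possible workaround is to wrap the index array in a small class whose __eq__ always returns a plain bool and whose __hash__ is consistent with it. This is a sketch under the assumption that the static field can hold any object with value-based equality and hashing. HashableArray is a hypothetical helper, not part of jax_dataclasses or JAX, and it is intended for integer index arrays such as this one.

```python
import numpy as np


class HashableArray:
    """Hypothetical wrapper giving a static NumPy array value-based == and hash."""

    def __init__(self, array):
        # Keep a private, read-only copy so the hash cannot go stale.
        self.array = np.array(array, copy=True)
        self.array.setflags(write=False)

    def __eq__(self, other):
        if not isinstance(other, HashableArray):
            return NotImplemented
        # Compare shape, dtype and values; the result is always a plain bool.
        return (
            self.array.shape == other.array.shape
            and self.array.dtype == other.array.dtype
            and bool(np.array_equal(self.array, other.array))
        )

    def __hash__(self):
        # Consistent with __eq__ for integer arrays (for floats, 0.0 and -0.0
        # compare equal but have different bytes, so this would not hold).
        return hash((self.array.shape, self.array.dtype.str, self.array.tobytes()))


indices = np.array([20] * 16 + [21] * 16 + [33] * 8 + [34] * 8)
aux_model_1 = (HashableArray(indices),)
aux_model_2 = (HashableArray(indices.copy()),)

# Two distinct but equal arrays now compare without raising ValueError.
assert aux_model_1 == aux_model_2
assert hash(aux_model_1) == hash(aux_model_2)
```

Hashing the raw bytes together with shape and dtype keeps __hash__ aligned with __eq__, and the read-only copy prevents the hash from changing if the caller later mutates the original array.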