Support jax arrays (and optionally, cvxpy expressions) everywhere #574

Open
Jacob-Stevens-Haas opened this issue Oct 18, 2024 · 0 comments

@Jacob-Stevens-Haas
Member

Jacob-Stevens-Haas commented Oct 18, 2024

See #562

This was thought to be easy, because in many cases jax arrays were an
almost drop-in replacement for numpy arrays. However, they are far less
amenable to subclassing. Why does this matter?

The codebase gained a lot of readability from AxesArray, which lets arrays
dynamically know what their axes mean, even after indexing changes their
shape. However, AxesArray cannot be extended to dynamically subclass either
numpy.ndarray or jax.Array; even a static subclass of jax.Array is
impossible.
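For context, NumPy exposes subclassing hooks (`__array_finalize__`, plus an overridable `__getitem__`) that make this kind of axis tracking possible. The sketch below is a minimal illustration under stated assumptions: `NamedAxesArray` is a hypothetical class, not pysindy's actual AxesArray, and it only understands integer and slice indexing. Per the issue, jax.Array offers no equivalent subclassing route.

```python
import numpy as np


class NamedAxesArray(np.ndarray):
    """Hypothetical ndarray subclass that keeps axis names in sync with indexing.

    Illustrative only; not pysindy's AxesArray.
    """

    def __new__(cls, data, axis_names):
        obj = np.asarray(data).view(cls)
        if len(axis_names) != obj.ndim:
            raise ValueError("need exactly one name per axis")
        obj.axis_names = tuple(axis_names)
        return obj

    def __array_finalize__(self, obj):
        # Called for views, copies and ufunc outputs. Inherit names only when
        # the number of axes is unchanged; otherwise mark them unknown (None).
        names = getattr(obj, "axis_names", None)
        self.axis_names = names if names is not None and len(names) == self.ndim else None

    def __getitem__(self, key):
        result = super().__getitem__(key)
        if not isinstance(result, NamedAxesArray):
            return result  # a scalar was selected
        keys = key if isinstance(key, tuple) else (key,)
        supported = all(
            isinstance(k, slice)
            or (isinstance(k, (int, np.integer)) and not isinstance(k, (bool, np.bool_)))
            for k in keys
        )
        if not supported or self.axis_names is None:
            result.axis_names = None  # this sketch handles only ints and slices
            return result
        # An integer index removes its axis; a slice keeps it.
        dropped = {i for i, k in enumerate(keys) if not isinstance(k, slice)}
        result.axis_names = tuple(
            name for i, name in enumerate(self.axis_names) if i not in dropped
        )
        return result


x = NamedAxesArray(np.zeros((100, 3)), ["time", "coord"])
print(x[0].axis_names)      # ('coord',)
print(x[:, 1].axis_names)   # ('time',)
print(x[10:20].axis_names)  # ('time', 'coord')
```

The names survive indexing because NumPy routes every view through the subclass; that routing is exactly what is unavailable for jax.Array.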

Long term, we will need our own metadata type that carries around an array,
its array package (numpy or jax.numpy or cvxpy.numpy), a bidirectional
mapping between axis index and axis meaning, and maybe even something from
sympy. The hard part of this is already done since, after all, AxesArray
functionality only deals with the axes.
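One way to read this proposal is composition instead of inheritance: the metadata object holds the array rather than being the array. The sketch below assumes a hypothetical `LabeledArray` dataclass (not an existing pysindy type) whose `xp` field is the array module (`numpy`, or `jax.numpy` for JAX arrays). Position to name is the tuple lookup; name to position is a dict. Only numpy is exercised here.

```python
from dataclasses import dataclass, field
from types import ModuleType

import numpy as np


@dataclass
class LabeledArray:
    """Hypothetical container: an array, its array module, and axis labels."""

    data: object      # numpy.ndarray, or jax.Array when xp is jax.numpy
    xp: ModuleType    # the array package used to operate on `data`
    axis_names: tuple
    _index_of: dict = field(init=False, repr=False)

    def __post_init__(self):
        self.axis_names = tuple(self.axis_names)
        if len(self.axis_names) != self.data.ndim:
            raise ValueError("need exactly one name per axis")
        if len(set(self.axis_names)) != len(self.axis_names):
            raise ValueError("axis names must be unique")
        # Reverse direction of the bidirectional mapping: name -> position.
        self._index_of = {name: i for i, name in enumerate(self.axis_names)}

    def axis(self, name):
        """Return the positional index of a named axis."""
        return self._index_of[name]

    def take(self, name, index):
        """Select one entry along a named axis; that axis and its name are dropped."""
        ax = self.axis(name)
        new_data = self.xp.take(self.data, index, axis=ax)
        new_names = self.axis_names[:ax] + self.axis_names[ax + 1:]
        return LabeledArray(new_data, self.xp, new_names)


x = LabeledArray(np.arange(12.0).reshape(4, 3), np, ("time", "coord"))
y = x.take("time", 0)
print(y.axis_names)                               # ('coord',)
print(x.xp.sum(x.data, axis=x.axis("time")))      # [18. 22. 26.]
```

Because the wrapper never subclasses the array type, the same code path can in principle hold a numpy or a JAX array; the trade-off is that every operation has to go through the wrapper to keep the labels correct.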

Short term, we should expose our general expectations for axis definitions
as global constants. This is still error prone, since the constants are
wrong for arrays whose shape has changed due to indexing, but it will be
far more readable than magic numbers.
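A minimal sketch of the short-term approach, assuming hypothetical constant names `AX_TIME` and `AX_COORD` and a hypothetical trailing (time, coord) layout; pysindy's real names and conventions may differ. The last lines show the failure mode described above: after indexing, a constant can silently point at the wrong axis.

```python
import numpy as np

# Hypothetical module-level axis constants (names and values are illustrative).
AX_TIME = -2   # axis indexing time samples
AX_COORD = -1  # axis indexing state coordinates


def finite_difference(x, t):
    """Forward difference along AX_TIME; `t` is a 1-D time grid."""
    dt = np.diff(t)
    # dt[:, np.newaxis] has shape (n - 1, 1), so it broadcasts against the
    # trailing (time, coord) axes of np.diff(x, axis=AX_TIME).
    return np.diff(x, axis=AX_TIME) / dt[:, np.newaxis]


t = np.linspace(0, 1, 5)
x = np.stack([t, t**2], axis=-1)       # shape (5, 2): time x coord
x_dot = finite_difference(x, t)        # shape (4, 2), readable without magic numbers

# Hypothetical ensemble layout: trajectories x time x coord.
x_ens = np.zeros((3, 5, 2))
n_time = x_ens.shape[AX_TIME]          # 5, the time axis as intended
n_wrong = x_ens[..., 0].shape[AX_TIME] # 3, silently the trajectory axis
```

Nothing raises in the last line, which is why constants alone are a stopgap rather than a replacement for axis metadata that follows the array.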
