Can dtypes of later arguments be annotated to follow the dtypes of earlier arguments? #276

Open
NiklasKappel opened this issue Nov 29, 2024 · 1 comment
Labels
question User queries


@NiklasKappel

Consider the function

def fill_array(array: Shaped[Array, "N"], fill_value: Any) -> Shaped[Array, "N"]:
    return jnp.full_like(array, fill_value)

In principle, the dtypes of fill_value and of the return value should be constrained to match that of the input array. For example, if I pass in an array of integers, I should not be able to pass a float as the fill_value. Is it possible to annotate this constraint?

@patrick-kidger
Owner

Yeah, I've wanted this a few times as well. Unfortunately, I don't know of a good way to do this. You could use @typing.overload if you really need to, but that is fairly heavy-handed.

Hypothetically, jaxtyping could have had an alternate design that looks like Jaxtyped[Array, "N", "dtype"], so that the dtype is dynamically bound in the same way as the shapes, but that seemed too verbose in the common case of a fixed dtype.

I'm sorry that I don't have a better answer for you!

@patrick-kidger patrick-kidger added the question User queries label Nov 29, 2024