Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add type annotations #23

Merged
merged 1 commit into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
rev: v0.5.2
hooks:
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.365
rev: v1.1.372
hooks:
- id: pyright
additional_dependencies: ["equinox", "pytest", "jax", "jaxtyping", "plum-dispatch"]
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ addopts = "--jaxtyping-packages=quax,beartype.beartype(conf=beartype.BeartypeCon
[tool.ruff.lint]
select = ["E", "F", "I001"]
ignore = ["E402", "E721", "E731", "E741", "F722"]
ignore-init-module-imports = true
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Newer ruff doesn't require it. It surfaces a warning saying that it's being deprecated.

fixable = ["I001", "F401"]

[tool.ruff.lint.isort]
Expand Down
30 changes: 19 additions & 11 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import functools as ft
import itertools as it
from collections.abc import Callable, Sequence
from typing import Any, cast, Union
from typing import Any, cast, Generic, TypeVar, Union
from typing_extensions import TypeGuard

import equinox as eqx
Expand All @@ -17,6 +17,8 @@
from jaxtyping import ArrayLike, PyTree


CT = TypeVar("CT", bound=Callable)

#
# Rules
#
Expand All @@ -25,7 +27,7 @@
_rules: dict[core.Primitive, plum.Function] = {}


def register(primitive: core.Primitive):
def register(primitive: core.Primitive) -> Callable[[CT], CT]:
"""Registers a multiple dispatch implementation for this JAX primitive.

!!! Example
Expand Down Expand Up @@ -53,7 +55,7 @@ def _(x: SomeValue, y: SomeValue):
A decorator for registering a multiple dispatch rule with the specified primitive.
"""

def _register(rule: Callable):
def _register(rule: CT) -> CT:
try:
existing_rule = _rules[primitive] # pyright: ignore
except KeyError:
Expand All @@ -80,7 +82,7 @@ def existing_rule():
class _QuaxTracer(core.Tracer):
__slots__ = ("value",)

def __init__(self, trace: "_QuaxTrace", value: "Value"):
def __init__(self, trace: "_QuaxTrace", value: "Value") -> None:
assert _is_value(value)
self._trace = trace
self.value = value
Expand Down Expand Up @@ -292,13 +294,13 @@ def _unwrap_tracer(trace, x):
return x


class _Quaxify(eqx.Module):
fn: Callable
class _Quaxify(eqx.Module, Generic[CT]):
fn: CT
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]]
dynamic: bool = eqx.field(static=True)

@property
def __wrapped__(self):
def __wrapped__(self) -> CT:
return self.fn

def __call__(self, *args, **kwargs):
Expand All @@ -320,13 +322,16 @@ def __call__(self, *args, **kwargs):
out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out)
return out

def __get__(self, instance, owner):
def __get__(self, instance: Union[object, None], owner: Any):
if instance is None:
return self
return eqx.Partial(self, instance)


def quaxify(fn, filter_spec=True):
def quaxify(
fn: CT,
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = True,
) -> _Quaxify[CT]:
"""'Quaxifies' a function, so that it understands custom array-ish objects like
[`quax.examples.lora.LoraArray`][]. When this function is called, multiple dispatch
will be performed against the types it is called with.
Expand All @@ -349,7 +354,10 @@ def quaxify(fn, filter_spec=True):
nested `quax.quaxify`. See the
[advanced tutorial](../examples/redispatch.ipynb).
"""
return eqx.module_update_wrapper(_Quaxify(fn, filter_spec, dynamic=False))
return cast(
_Quaxify[CT],
eqx.module_update_wrapper(_Quaxify(fn, filter_spec, dynamic=False)),
)


#
Expand Down Expand Up @@ -381,7 +389,7 @@ def aval(self) -> core.AbstractValue:

@staticmethod
def default(
primitive, values: Sequence[Union[ArrayLike, "Value"]], params
primitive: core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params
) -> Union[ArrayLike, "Value", Sequence[Union[ArrayLike, "Value"]]]:
"""This is the default rule for when no rule has been [`quax.register`][]'d for
a primitive.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def default(primitive, values, params):
if primitive.multiple_results:
return [Foo(x) for x in out]
else:
return Foo(out)
return Foo(cast(Array, out))

return Foo

Expand Down
Loading