Skip to content

Commit

Permalink
Fixes (#1)
Browse files Browse the repository at this point in the history
* Fixes

* Black

* Test fix

* Test fix

* workflow fixes
  • Loading branch information
patrick-kidger authored Jul 11, 2022
1 parent 207faaa commit b5c2582
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
with:
python-version: "3.8"
test-script: |
python -m pip install pytest jax jaxlib typeguard
python -m pip install pytest beartype equinox jaxlib
cp -r ${{ github.workspace }}/test ./test
pytest
pypi-token: ${{ secrets.pypi_token }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest wheel jaxlib
python -m pip install pytest wheel beartype equinox jaxlib
- name: Checks with pre-commit
uses: pre-commit/[email protected]
Expand Down
1 change: 1 addition & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ force_alphabetical_sort_within_sections=true
lines_after_imports=2
profile=black
treat_comments_as_code=true
extra_standard_library=typing_extensions
2 changes: 1 addition & 1 deletion jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@
from .pytree_type import PyTree


__version__ = "0.0.1"
__version__ = "0.0.2"
3 changes: 2 additions & 1 deletion jaxtyping/array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import functools as ft
from typing import Any, Dict, List, Literal, NoReturn, Optional, Tuple, Union
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
from typing_extensions import Literal

import jax.numpy as jnp

Expand Down
2 changes: 1 addition & 1 deletion jaxtyping/import_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
0,
ast.Attribute(
ast.Name(id="jaxtyping", ctx=ast.Load()), "jaxtyped", ast.Load()
)
),
)
if self._typechecker is not None:
# Place at the end of the decorator list, as decorators
Expand Down
22 changes: 21 additions & 1 deletion jaxtyping/pytree_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,27 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import functools as ft
from typing import Generic, TypeVar

import jax
import typeguard


_T = TypeVar("_T")


class _FakePyTree(Generic[_T]):
pass


_FakePyTree.__name__ = "PyTree"
_FakePyTree.__qualname__ = "PyTree"
# Can't do type("PyTree", (Generic[_T],), {}) because dynamic subclassing of typeforms
# isn't allowed.
# Can't do types.new_class("PyTree", (Generic[_T],), {}) because that has __module__
# "types", e.g. we get types.PyTree[int].


class _MetaPyTree(type):
def __call__(self, *args, **kwargs):
raise RuntimeError("PyTree cannot be instantiated")
Expand All @@ -32,7 +48,8 @@ def __instancecheck__(cls, obj):

@ft.lru_cache(maxsize=None)
def __getitem__(cls, item):
return _MetaSubscriptPyTree(f"PyTree[{item.__name__}]", (), {"leaftype": item})
name = str(_FakePyTree[item])
return _MetaSubscriptPyTree(name, (), {"leaftype": item})


class _MetaSubscriptPyTree(type):
Expand Down Expand Up @@ -63,3 +80,6 @@ def is_leaftype(x):


PyTree = _MetaPyTree("PyTree", (), {})
# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
# instancecheck for PyTree[foo], but we subclassing
# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@

# We use typeguard internally (in a fairly minimal way), but it's not required that
# end users make the same choice.
install_requires = ["jax>=0.3.4", "typeguard>=2.13.3"]
install_requires = ["jax>=0.3.4", "typeguard>=2.13.3", "typing_extensions>=4.2.0"]

entry_points = dict(pytest11=["jaxtyping = jaxtyping.pytest_plugin"])

Expand Down

0 comments on commit b5c2582

Please sign in to comment.