Skip to content

Commit

Permalink
More fixes (#27)
Browse files Browse the repository at this point in the history
* Edge-case doc fixes for parameterising types with PyTrees or AbstractDtypes

* Fixes for threading

* version bump
  • Loading branch information
patrick-kidger authored Sep 20, 2022
1 parent 6202dcc commit 39439c2
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 11 deletions.
2 changes: 1 addition & 1 deletion jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ class Array:
from .pytree_type import PyTree


__version__ = "0.2.3"
__version__ = "0.2.4"
10 changes: 7 additions & 3 deletions jaxtyping/array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,15 @@ def __instancecheck__(cls, obj):
if cls.dtypes is not _any_dtype and dtype not in cls.dtypes:
return False

if len(storage.memo_stack) == 0:
temp_memo = not hasattr(storage, "memo_stack") or len(storage.memo_stack) == 0

if temp_memo:
# `isinstance` happening outside any @jaxtyped decorators, e.g. at the
# global scope. In this case just create a temporary memo, since we're not
# going to be comparing against any stored values anyway.
single_memo = {}
variadic_memo = {}
variadic_broadcast_memo = {}
temp_memo = True
else:
single_memo, variadic_memo, variadic_broadcast_memo = storage.memo_stack[-1]
# Make a copy so we don't mutate the original memo during the shape check.
Expand Down Expand Up @@ -474,7 +475,10 @@ class _Cls(AbstractDtype):

_Cls.__name__ = name
_Cls.__qualname__ = name
_Cls.__module__ = "jaxtyping"
if getattr(typing, "GENERATING_DOCUMENTATION", False):
_Cls.__module__ = "builtins"
else:
_Cls.__module__ = "jaxtyping"
return _Cls

UInt8 = _make_dtype(_uint8, "UInt8")
Expand Down
25 changes: 20 additions & 5 deletions jaxtyping/pytree_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import functools as ft
from typing import Generic, TypeVar
import typing
from typing import Generic, TYPE_CHECKING, TypeVar
from typing_extensions import Protocol

import jax
import typeguard
Expand Down Expand Up @@ -51,7 +53,10 @@ def __instancecheck__(cls, obj):
def __getitem__(cls, item):
name = str(_FakePyTree[item])
out = _MetaSubscriptPyTree(name, (), {"leaftype": item})
out.__module__ = "jaxtyping"
if getattr(typing, "GENERATING_DOCUMENTATION", False):
out.__module__ = "builtins"
else:
out.__module__ = "jaxtyping"
return out


Expand Down Expand Up @@ -82,8 +87,18 @@ def is_leaftype(x):
return all(map(is_leaftype, leaves))


PyTree = _MetaPyTree("PyTree", (), {})
PyTree.__module__ = "jaxtyping"
if TYPE_CHECKING:
# Work around pytype bug #1288
# pytype: skip-file
class PyTree(Protocol[_T]):
pass

else:
PyTree = _MetaPyTree("PyTree", (), {})
if getattr(typing, "GENERATING_DOCUMENTATION", False):
PyTree.__module__ = "builtins"
else:
PyTree.__module__ = "jaxtyping"
# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
# instancecheck for PyTree[foo], but we subclassing
# instancecheck for PyTree[foo], but subclassing
# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
29 changes: 27 additions & 2 deletions test/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,22 @@
from jaxtyping import Array, Float, jaxtyped


def test_threading():
class _ErrorableThread(threading.Thread):
def run(self):
try:
super().run()
except Exception as e:
self.exc = e
finally:
del self._target, self._args, self._kwargs

def join(self, timeout=None):
super().join(timeout)
if hasattr(self, "exc"):
raise self.exc


def test_threading_jaxtyped():
@jaxtyped
@typechecked
def add(x: Float[Array, "a b"], y: Float[Array, "a b"]) -> Float[Array, "a b"]:
Expand All @@ -36,6 +51,16 @@ def run():
b = jnp.array([[2.0, 3.0]])
add(a, b)

thread = threading.Thread(target=run)
thread = _ErrorableThread(target=run)
thread.start()
thread.join()


def test_threading_nojaxtyped():
def run():
a = jnp.array([[1.0, 2.0]])
assert isinstance(a, Float[Array, "..."])

thread = _ErrorableThread(target=run)
thread.start()
thread.join()

0 comments on commit 39439c2

Please sign in to comment.