From 520daaef4f277b2111a0089765a2f5f34433b29b Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 19 Nov 2024 23:56:33 +0100 Subject: [PATCH] Half compatibility with typeguard v4. --- docs/api/runtime-type-checking.md | 2 +- jaxtyping/_array_types.py | 27 +++++++++++++++++++++++++++ jaxtyping/_decorator.py | 24 +++++++++++++++++++----- test/requirements.txt | 2 +- test/test_import_hook.py | 20 -------------------- 5 files changed, 48 insertions(+), 27 deletions(-) diff --git a/docs/api/runtime-type-checking.md b/docs/api/runtime-type-checking.md index ec0cd86..88fa1c0 100644 --- a/docs/api/runtime-type-checking.md +++ b/docs/api/runtime-type-checking.md @@ -6,7 +6,7 @@ Runtime type checking **synergises beautifully with `jax.jit`!** All shape check There are two approaches: either use [`jaxtyping.jaxtyped`][] to typecheck a single function, or [`jaxtyping.install_import_hook`][] to typecheck a whole codebase. -In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). (If using typeguard, then specifically the version `2.*` series should be used. Later versions -- `3` and `4` -- have some known issues.) +In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). !!! warning diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py index b6d4c53..2a27c91 100644 --- a/jaxtyping/_array_types.py +++ b/jaxtyping/_array_types.py @@ -19,6 +19,7 @@ import enum import functools as ft +import importlib.metadata import importlib.util import re import sys @@ -738,6 +739,28 @@ def __init_subclass__(cls, **kwargs): _complex128 = "complex128" +# Workaround a longstanding bug in typeguard v4, by monkeypatching their internals. +# https://stackoverflow.com/questions/79201839/hello-world-for-jaxtyping/79205145#79205145 +# https://github.com/patrick-kidger/jaxtyping/issues/80 +# https://github.com/agronholm/typeguard/issues/353 +# This is as robust as I can make it to future changes in typeguard, I think. +typeguard_v4_compat = False +try: + typeguard_distribution = importlib.metadata.distribution("typeguard") +except importlib.metadata.PackageNotFoundError: + pass +else: + if typeguard_distribution.version.split(".", 1)[0] == "4": + if importlib.util.find_spec("typeguard._transformer") is not None: + import typeguard._transformer + + if hasattr(typeguard._transformer, "annotated_names"): + annotated_names = typeguard._transformer.annotated_names + if type(annotated_names) is tuple: + if all(type(x) is str for x in annotated_names): + typeguard_v4_compat = True + + def _make_dtype(_dtypes, name): class _Cls(AbstractDtype): dtypes = _dtypes @@ -748,6 +771,10 @@ class _Cls(AbstractDtype): _Cls.__module__ = "builtins" else: _Cls.__module__ = "jaxtyping" + if typeguard_v4_compat: + typeguard._transformer.annotated_names = ( + typeguard._transformer.annotated_names + (f"jaxtyping.{name}",) + ) return _Cls diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index f382bc5..ee1b123 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -430,11 +430,13 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore module = getattr(fn, "__module__", "") # Use the same name so that typeguard warnings look correct. + # Set the line number so that typeguard v4 finds us. + lineno = getattr(getattr(fn, "__code__", None), "co_firstlineno", 1) full_fn, output_name = _make_fn_with_signature( - name, qualname, module, full_signature, output=True + name, qualname, module, full_signature, output=True, lineno=lineno ) param_fn = _make_fn_with_signature( - name, qualname, module, param_signature, output=False + name, qualname, module, param_signature, output=False, lineno=lineno ) full_fn = _apply_typechecker(typechecker, full_fn) param_fn = _apply_typechecker(typechecker, param_fn) @@ -616,13 +618,19 @@ def _check_dataclass_annotations(self, typechecker): self.__class__.__module__, signature, output=False, + lineno=1, ) f = jaxtyped(f, typechecker=typechecker) f(self, **values) def _make_fn_with_signature( - name: str, qualname: str, module: str, signature: inspect.Signature, output: bool + name: str, + qualname: str, + module: str, + signature: inspect.Signature, + output: bool, + lineno: int, ): """Dynamically creates a function `fn` with name `name` and signature `signature`. @@ -740,7 +748,8 @@ def _make_fn_with_signature( else: retstr = f"-> {name_to_annotation['return']}" - fnstr = f"def {name}({argstr}){retstr}:\n {outstr}" + newlines = "\n" * (lineno - 1) + fnstr = f"{newlines}def {name}({argstr}){retstr}:\n {outstr}" exec(fnstr, scope) fn = scope[name] del scope[name] # Avoids introducing a reference cycle. @@ -802,7 +811,12 @@ def _get_problem_arg( assert keep_annotation is not sentinel new_signature = inspect.Signature(new_parameters) fn = _make_fn_with_signature( - "check_single_arg", "check_single_arg", module, new_signature, output=False + "check_single_arg", + "check_single_arg", + module, + new_signature, + output=False, + lineno=1, ) fn = _apply_typechecker( typechecker, fn diff --git a/test/requirements.txt b/test/requirements.txt index 44fc63b..ee4b5d8 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -7,4 +7,4 @@ numpy<2 pytest pytest-asyncio tensorflow -typeguard<3 +typeguard diff --git a/test/test_import_hook.py b/test/test_import_hook.py index 29abb8d..28ef31c 100644 --- a/test/test_import_hook.py +++ b/test/test_import_hook.py @@ -32,26 +32,6 @@ _here = pathlib.Path(__file__).parent -try: - typeguard_version = importlib.metadata.version("typeguard") -except Exception as e: - raise ImportError("Could not find typeguard version") from e -else: - try: - major, _, _ = typeguard_version.split(".") - major = int(major) - except Exception as e: - raise ImportError( - f"Unexpected typeguard version {typeguard_version}; not formatted as " - "`major.minor.patch`" - ) from e -if major != 2: - raise ImportError( - "jaxtyping's tests required typeguard version 2. (Versions 3 and 4 are both " - "known to have bugs.)" - ) - - assert not hasattr(jaxtyping, "_test_import_hook_counter") jaxtyping._test_import_hook_counter = 0