Skip to content

Commit

Permalink
Better import hook (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger authored Sep 26, 2022
1 parent 1650657 commit d246e21
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
2 changes: 1 addition & 1 deletion jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ class Array:
from .pytree_type import PyTree


__version__ = "0.2.6"
__version__ = "0.2.7"
34 changes: 25 additions & 9 deletions jaxtyping/import_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def _call_with_frames_removed(f, *args, **kwargs):


def _optimized_cache_from_source(path, debug_override=None):
return cache_from_source(path, debug_override, optimization="jaxtyping")
# Version 2: change the position of the `@jaxtyped` decorator, so need a
# different name to avoid hitting old __pycache__
return cache_from_source(path, debug_override, optimization="jaxtyping2")


class _JaxtypingTransformer(ast.NodeVisitor):
Expand Down Expand Up @@ -99,10 +101,16 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
has_annotated_args = any(arg for arg in node.args.args if arg.annotation)
has_annotated_return = bool(node.returns)
if has_annotated_args or has_annotated_return:
# Place at the start of the decorator list, in case a typechecking
# annotation has been manually applied; we need to be above that.
node.decorator_list.insert(
0,
# Place at the end of the decorator list, as otherwise we wrap e.g.
# `jax.custom_{jvp,vjp}` and lose the ability to `defjvp` etc.
#
# Note that the counter-argument here is that we'd like to place this
# at the start of the decorator list, in case a typechecking annotation
# has been manually applied, and we'd need to be above that. In this
# case we're just going to have to need to ask the user to remove their
# typechecking annotation (and let this decorator do it instead).
# It's more important we be compatible with normal JAX code.
node.decorator_list.append(
ast.Attribute(
ast.Name(id="jaxtyping", ctx=ast.Load()), "jaxtyped", ast.Load()
),
Expand Down Expand Up @@ -230,8 +238,16 @@ def install_import_hook(
- `typechecker`: the module and function of the typechecker you want to use, as a
2-tuple of strings. For example `typechecker=("typeguard", "typechecked")` or
`typechecker=("beartype", "beartype")`. You may pass `typechecker=None` if you
do not want to automatically decorate with a typechecker as well; e.g. if you
have a codebase that already has these decorators.
do not want to automatically decorate with a typechecker as well.
If the function already has any decorators on it, then both the `@jaxtyped` and the
typechecker decorators will go at the bottom of the decorator list, e.g.
```python
@some_other_decorator
@jaxtyped
@beartype.beartype
def foo(...): ...
```
**Returns:**
Expand All @@ -243,8 +259,8 @@ def install_import_hook(
```python
# entry_point.py
from jaxtyped import install_import_hook
install_import_hook("main", ("beartype", "beartype"))
import main
with install_import_hook("main", ("beartype", "beartype"))
import main
... # do whatever you're doing
# main.py
Expand Down

0 comments on commit d246e21

Please sign in to comment.