Skip to content

Commit

Permalink
Runtime errors inside eqx.filter_jit are now very readable.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Aug 18, 2024
1 parent 68cc26a commit 2188e01
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 122 deletions.
2 changes: 1 addition & 1 deletion equinox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
is_inexact_array_like as is_inexact_array_like,
partition as partition,
)
from ._jit import filter_jit as filter_jit
from ._jit import EquinoxRuntimeError as EquinoxRuntimeError, filter_jit as filter_jit
from ._make_jaxpr import filter_make_jaxpr as filter_make_jaxpr
from ._module import (
field as field,
Expand Down
93 changes: 34 additions & 59 deletions equinox/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
import numpy as np
from jaxtyping import Array, ArrayLike, Bool, Int, PyTree

from . import _jit
from ._ad import filter_custom_jvp
from ._config import EQX_ON_ERROR, EQX_ON_ERROR_BREAKPOINT_FRAMES
from ._doc_utils import doc_remove_args
from ._filters import combine, is_array, partition
from ._jit import filter_jit
from ._misc import currently_jitting
from ._unvmap import unvmap_any, unvmap_max

Expand Down Expand Up @@ -52,76 +52,50 @@ def _nan_like(x: Union[Array, np.ndarray]) -> Union[Array, np.ndarray]:
"""


_on_error_msg = """
---------------------------------------------------------------------------
An error occurred during the runtime of your JAX program.
---------------------------------------------------------------------------
Traceback:
{stack}
---------------------------------------------------------------------------
Error message:
{msg}
---------------------------------------------------------------------------
You have a few options to try and debug this issue.
1) Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful
way to debug such errors. This can be interacted with using most of the usual commands
for the Python debugger: `u` and `d` to move up and down frames, the name of a variable
to print its value, etc.
If taking this approach, then it is recommended to also set
`EQX_ON_ERROR_BREAKPOINT_FRAMES=<some number>`, corresponding to the number of frames to
add to the debugger.
If you get trace-time errors from JAX then try reducing the value of
`EQX_ON_ERROR_BREAKPOINT_FRAMES`. See
`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information.
2) You may also like to try setting `JAX_DISABLE_JIT=1`. This will mean that you can
(mostly) inspect the state of your program as if it was normal Python.
3) For more suggestions, see `https://docs.kidger.site/equinox/api/debug/`.
"""


_frames_msg = f"""
Opening a breakpoint with {EQX_ON_ERROR_BREAKPOINT_FRAMES} frames.
You can control this value by setting the environment variable
`EQX_ON_ERROR_BREAKPOINT_FRAMES=<some number>`.
-------------------
Note that setting large values of this number may lead to crashes at trace time; see
`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information.
Opening a breakpoint with {EQX_ON_ERROR_BREAKPOINT_FRAMES} frames. You can control this
value by setting the environment variable `EQX_ON_ERROR_BREAKPOINT_FRAMES=<some value>`.
(Note that setting large values of this number may lead to crashes at trace time; see
`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information.)
"""


# This is never actually surfaced to an end user -- it always becomes an XlaRuntimeError
class EqxRuntimeError(RuntimeError):
# The name of this is looked for in `_jit.py` in order to determine if we have a
# runtime error -- and if so then the custom reporting will engage.
#
# Note that this is *not* the class that is raised at runtime to a user: this is an
# internal implementation detail of Equinox. It is caught by `equinox.filter_jit` and
# replaced with the actual run time error. (Without any of the misleading baggage that
# XLA would otherwise attach.)
class _EquinoxRuntimeError(RuntimeError):
pass


class EquinoxTracetimeError(RuntimeError):
pass


EquinoxTracetimeError.__module__ = "equinox"


@filter_custom_jvp
def _error(x, pred, index, *, msgs, on_error, stack):
if on_error == "raise":

def raises(_index):
raise EqxRuntimeError(
_on_error_msg.format(stack=stack, msg=msgs[_index.item()])
# Sneakily smuggle out the information about the error. Inspired by
# `sys.last_value`.
_jit.last_msg = msg = msgs[_index.item()]
_jit.last_stack = stack
raise _EquinoxRuntimeError(
f"{msg}\n\n\n"
"--------------------\n"
"An error occurred during the runtime of your JAX program! "
"Unfortunately you do not appear to be using `equinox.filter_jit` "
"(perhaps you are using `jax.jit` instead?) and so further information "
"about the error cannot be displayed. (Probably you are seeing a very "
"large but uninformative error message right now.) Please wrap your "
"program with `equinox.filter_jit`.\n"
"--------------------\n"
)

def tpu_msg(_out, _index):
Expand All @@ -148,7 +122,7 @@ def handle_error(): # pyright: ignore

def display_msg(_index):
print(_frames_msg)
print(msgs[_index.item()])
print("equinox.EquinoxRuntimeError: " + msgs[_index.item()])
return _index

def to_nan(_index):
Expand Down Expand Up @@ -356,11 +330,12 @@ def branched_error_if_impl(
return x

tb = None
frames = list(traceback.walk_stack(None))
for f, lineno in reversed(frames):
for f, lineno in traceback.walk_stack(None):
if f.f_locals.get("__equinox_filter_jit__", False):
break
if traceback_util.include_frame(f):
tb = types.TracebackType(tb, f, f.f_lasti, lineno)
stack = "\n".join(traceback.format_tb(tb))
stack = "".join(traceback.format_tb(tb)).rstrip()
dynamic_x, static_x = partition(x, is_array)
flat = jtu.tree_leaves(dynamic_x)
if len(flat) == 0:
Expand All @@ -373,7 +348,7 @@ def branched_error_if_impl(

# filter_jit does some work to produce nicer runtime error messages.
# We also place it here to ensure a consistent experience when using JAX in eager mode.
branched_error_if_impl_jit = filter_jit(branched_error_if_impl)
branched_error_if_impl_jit = _jit.filter_jit(branched_error_if_impl)


def assert_dce(
Expand Down
154 changes: 92 additions & 62 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools as ft
import inspect
import sys
import warnings
from collections.abc import Callable
from typing import Any, Literal, overload, TypeVar
Expand All @@ -22,6 +23,7 @@
from ._deprecate import deprecated_0_10
from ._doc_utils import doc_remove_args
from ._filters import combine, is_array, partition
from ._misc import currently_jitting
from ._module import field, Module, module_update_wrapper, Partial, Static


Expand Down Expand Up @@ -110,38 +112,54 @@ class XlaRuntimeError(Exception):
pass


def _modify_traceback(e: Exception):
# Remove JAX's UnfilteredStackTrace, with its huge error messages.
e.__cause__ = None
# Remove _JitWrapper.__call__ and _JitWrapper._call and Method.__call__ from the
# traceback
tb = e.__traceback__ = e.__traceback__.tb_next.tb_next.tb_next # pyright: ignore
try:
# See https://github.com/google/jax/blob/69cd3ebe99ce12a9f22e50009c00803a095737c7/jax/_src/traceback_util.py#L190 # noqa: E501
jax.lib.xla_extension.replace_thread_exc_traceback(tb) # pyright: ignore
except AttributeError:
pass
# IPython ignores __tracebackhide__ directives for the frame that actually raises
# the error. We fix that here.
try:
get_ipython() # pyright: ignore
except NameError:
pass
else:
import IPython # pyright: ignore

# Check that IPython supports __tracebackhide__
if IPython.version_info[:2] >= (7, 17): # pyright: ignore
tb_stack = []
while tb is not None:
tb_stack.append(tb)
tb = tb.tb_next
for tb in reversed(tb_stack):
if not tb.tb_frame.f_locals.get("__tracebackhide__", False):
tb.tb_next = None
break
else:
e.__traceback__ = None
# This is the class we use to raise runtime errors from `eqx.error_if`.
class EquinoxRuntimeError(RuntimeError):
pass


# Magic value that means error messages are displayed as `{__qualname__}: ...` rather
# than `{__module__}.{__qualname__}`. (At least, I checked the default Python
# interpreter, the default Python REPL, ptpython, ipython, pdb, and ipdb.)
EquinoxRuntimeError.__module__ = "builtins"
# Note that we don't also override `__name__` or `__qualname__`. Suppressing the
# `equinox._jit` module bit is useful for readability, but we don't want to go so far as
# deleting the name altogether. (Or even e.g. setting it to the 'Above is the stack...'
# first section of our error message below!) The reason is that whilst that gives a
# nicer displayed error in default Python, it doesn't necessarily do as well with other
# tools, e.g. debuggers. So what we have here is a compromise.


last_msg = None
last_stack = None


_on_error_msg = """Above is the stack outside of JIT. Below is the stack inside of JIT:
{stack}
equinox.EquinoxRuntimeError: {msg}
-------------------
An error occurred during the runtime of your JAX program.
1) Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful
way to debug such errors. This can be interacted with using most of the usual commands
for the Python debugger: `u` and `d` to move up and down frames, the name of a variable
to print its value, etc.
2) You may also like to try setting `JAX_DISABLE_JIT=1`. This will mean that you can
(mostly) inspect the state of your program as if it was normal Python.
3) See `https://docs.kidger.site/equinox/api/debug/` for more suggestions.
"""


class _FilteredStderr:
def __init__(self, stderr):
self.stderr = stderr

def write(self, data: str):
if "_EquinoxRuntimeError" not in data:
self.stderr.write(data)


class _JitWrapper(Module):
Expand All @@ -160,6 +178,9 @@ def __wrapped__(self):

def _call(self, is_lower, args, kwargs):
__tracebackhide__ = True
# Used by our error messages when figuring out where to stop walking the stack.
if not currently_jitting():
__equinox_filter_jit__ = True # noqa: F841
info = (
self._signature,
self._dynamic_fun,
Expand All @@ -178,49 +199,58 @@ def _call(self, is_lower, args, kwargs):
_postprocess, # pyright: ignore
)
else:
if self.filter_warning:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="Some donated buffers were not usable*"
)
# Filter stderr to remove our default "you don't seem to be using
# `equinox.filter_jit`" message. (Which also comes with a misleading stack
# trace from XLA.)
stderr = sys.stderr
sys.stderr = _FilteredStderr(stderr)
try:
if self.filter_warning:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="Some donated buffers were not usable*"
)
out = self._cached(dynamic_donate, dynamic_nodonate, static)
else:
out = self._cached(dynamic_donate, dynamic_nodonate, static)
else:
out = self._cached(dynamic_donate, dynamic_nodonate, static)
except XlaRuntimeError as e:
# Catch Equinox's runtime errors, and re-raise them with actually useful
# information. (By default XlaRuntimeError produces a lot of terrifying
# but useless information.)
if (
last_msg is not None
and last_stack is not None
and "_EquinoxRuntimeError: " in str(e)
):
# We check `last_msg` and `last_stack` just in case. I'm not sure
# what happens in distributed/multiprocess environments. Is the
# callback necessarily executed in the same interpreter as we are in
# here?
raise EquinoxRuntimeError(
_on_error_msg.format(msg=last_msg, stack=last_stack)
) from None
# `from None` to hide the large but uninformative XlaRuntimeError.
else:
raise
finally:
sys.stderr = stderr
return _postprocess(out)

def __call__(self, /, *args, **kwargs):
__tracebackhide__ = True
try:
return self._call(False, args, kwargs)
except XlaRuntimeError as e:
# Catch Equinox's runtime errors, and strip the more intimidating parts of
# the error message.
if len(e.args) != 1 or not isinstance(e.args[0], str):
raise # No idea if this ever happens. But if it does, just bail.
(msg,) = e.args
if "EqxRuntimeError: " in msg:
_, msg = msg.split("EqxRuntimeError: ", 1)
msg, *_ = msg.rsplit("\n\nAt:\n", 1)
e.args = (msg,)
if jax.config.jax_traceback_filtering in ( # pyright: ignore
None,
"auto",
):
_modify_traceback(e)
except EquinoxRuntimeError as e:
# Use a two-part try/except here and in `_call` to delete the
# `raise EquinoxRuntimeError` line from the stack trace.
e.__traceback__ = None
raise
# I considered also catching `Exception`, and removing the terrifying-looking
# JAX exception that occurs by default.
# This ends up being difficult to get working reliably (e.g. KeyError has a
# different __str__ so modifying the `.args` is hard/undefined; JAX errors have
# a different __init__ so overwriting __str__ in a new class ends up requiring
# magic; taking a different approach and overwriting sys.excepthook is ignored
# under IPython, ...)
# All in all, not worth it.

def lower(self, /, *args, **kwargs) -> Lowered:
return self._call(True, args, kwargs)

def __get__(self, instance, owner):
del owner
if instance is None:
return self
return Partial(self, instance)
Expand Down

0 comments on commit 2188e01

Please sign in to comment.