Skip to content

Commit

Permalink
Move eval_trace from transforms.py and rename it to interpret_trace (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYashchuk authored Sep 7, 2024
1 parent f7bdd5a commit 92a3099
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 72 deletions.
76 changes: 76 additions & 0 deletions thunder/core/trace_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Any

from thunder.core import prims
from thunder.core.pytree import tree_map
from thunder.core.trace import VariableInterface
from thunder.core.utils import safe_map_flat, sequencify


# TODO: Currently we use trace.args and trace.kwargs to get the arguments
# Maybe we should use these instead
trace_interpreter_skip_list = (
prims.PrimIDs.UNPACK_EMPTY_DICT,
prims.PrimIDs.UNPACK_KEY,
prims.PrimIDs.UNPACK_SEQUENCE,
prims.PrimIDs.UNPACK_TRIVIAL,
prims.PrimIDs.RETURN,
)


def interpret_trace(trace, *args, symbol_mapper=None, with_env=False, **kwargs):
"""Interpret a trace.
Args:
trace: trace to interpret
*args: arguments to interpret the trace with
symbol_mapper: function that redirects the evaluation of a BoundSymbol to a different function
with_env: whether to return the environment after interpreting the trace. Environment is a dictionary
that maps VariableInterface objects to their values.
**kwargs: keyword arguments to interpret the trace with
Returns:
result of interpreting the trace, optionally with the environment that saves all intermediate values
"""
env = {}

def read(x: VariableInterface | Any) -> Any:
if isinstance(x, VariableInterface):
return env[x.name]
else:
return x

def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None:
if not isinstance(v, VariableInterface):
return
# Duplicates are allowed and overwritten
if v.name in env:
if allow_duplicates:
return
raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
env[v.name] = val

safe_map_flat(write, list(trace.args), list(args))
safe_map_flat(write, list(trace.kwargs.values()), list(kwargs.values()))

for symbol in trace.bound_symbols:
if symbol.sym.id in trace_interpreter_skip_list:
continue
args = tree_map(read, symbol.args)
kwargs = tree_map(read, symbol.kwargs)
prim_func = symbol_mapper(symbol) if symbol_mapper is not None else symbol.sym
if prim_func is None:
continue
result = prim_func(*args, **kwargs)
try:
safe_map_flat(write, list(sequencify(symbol.output)), list(sequencify(result)))
except AssertionError as e:
raise RuntimeError(
f"Error while assigning the result of dispatched function {prim_func} to the output of the original symbol {symbol}."
" This is likely due to a mismatch in the number of outputs."
f" The original symbol has {len(symbol.output)} outputs and the dispatched function has {len(sequencify(result))} outputs."
) from e

if with_env:
return tree_map(read, trace.output), env

return tree_map(read, trace.output)
73 changes: 3 additions & 70 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import thunder.core.utils as utils
from thunder.core import dtypes, prims
from thunder.core.devices import cpu, Device
from thunder.core.trace_interpreter import interpret_trace as eval_trace, trace_interpreter_skip_list
from thunder.core.proxies import (
NumberProxy,
Proxy,
Expand Down Expand Up @@ -1445,74 +1446,6 @@ def symbol_to_eval(bound_symbol):
return bound_symbol.sym


# TODO: Currently we use trace.args and trace.kwargs to get the arguments
# Maybe we should use these instead
transform_skip_list = (
prims.PrimIDs.UNPACK_EMPTY_DICT,
prims.PrimIDs.UNPACK_KEY,
prims.PrimIDs.UNPACK_SEQUENCE,
prims.PrimIDs.UNPACK_TRIVIAL,
prims.PrimIDs.RETURN,
)


def eval_trace(trace, *args, symbol_mapper=symbol_to_eval, with_env=False, **kwargs):
"""Evaluate a trace.
Args:
trace: trace to evaluate
*args: arguments to evaluate the trace with
symbol_mapper: function that maps a symbol to a function that evaluates it
**kwargs: keyword arguments to evaluate the trace with
Returns:
result of evaluating the trace
"""
env = {}

def read(x: Variable):
if isinstance(x, Variable):
return env[x.name]
else:
return x

def write(v: Variable, val: Any, allow_duplicates=False) -> None:
if not isinstance(v, Variable):
return
# Duplicates are allowed and overwritten
if v.name in env:
if allow_duplicates:
return
raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
env[v.name] = val

safe_map_flat(write, list(trace.args), list(args))
safe_map_flat(write, list(trace.kwargs.values()), list(kwargs.values()))

for symbol in trace.bound_symbols:
if symbol.sym.id in transform_skip_list:
continue
args = tree_map(read, symbol.args)
kwargs = tree_map(read, symbol.kwargs)
prim_func = symbol_mapper(symbol)
if prim_func is None:
continue
result = prim_func(*args, **kwargs)
try:
safe_map_flat(write, list(sequencify(symbol.output)), list(sequencify(result)))
except AssertionError as e:
raise RuntimeError(
f"Error while assigning the result of dispatched function {prim_func} to the output of the original symbol {symbol}."
" This is likely due to a mismatch in the number of outputs."
f" The original symbol has {len(symbol.output)} outputs and the dispatched function has {len(sequencify(result))} outputs."
) from e

if with_env:
return tree_map(read, trace.output), env

return tree_map(read, trace.output)


def unwrap_one_level_of_subsymbols(trace):
new_symbols_iter = (
bound_symbol.subsymbols if len(bound_symbol.subsymbols) > 0 else [bound_symbol]
Expand Down Expand Up @@ -2266,7 +2199,7 @@ def iter_bound_symbols(bound_symbols):
infrastructure
"""
for symbol in bound_symbols:
if symbol.sym.id in transform_skip_list:
if symbol.sym.id in trace_interpreter_skip_list:
continue
elif symbol.output is None:
continue
Expand Down Expand Up @@ -2559,7 +2492,7 @@ def check_bsym_for_vjp(bsym):
bool: True if the bound symbol is supported by vjp, False otherwise.
"""

if bsym.sym.id in transform_skip_list:
if bsym.sym.id in trace_interpreter_skip_list:
return True

if bsym.sym.id in backward_impls and bsym.sym.id in augmented_forward_impls:
Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _generate_supported_op_list(checker):
Returns:
generator: A generator of operator info objects that support vjp.
"""
from thunder.core.transforms import transform_skip_list
from thunder.core.transforms import trace_interpreter_skip_list

for opinfo in opinfos:
if opinfo not in tensor_creation_ops and opinfo.name not in op_skip:
Expand All @@ -108,7 +108,7 @@ def _generate_supported_op_list(checker):
samples = iter(opinfo.sample_inputs("cpu", dtypes.float64, requires_grad=True))
while (sample := next(samples, None)) is not None:
trc = thunder.trace()(opinfo.op, *sample.args, **sample.kwargs)
all_skipped = all(s.sym.id in transform_skip_list for s in trc.bound_symbols)
all_skipped = all(s.sym.id in trace_interpreter_skip_list for s in trc.bound_symbols)
if all_skipped:
continue
all_supported = all(checker(s) for s in trc.bound_symbols)
Expand Down

0 comments on commit 92a3099

Please sign in to comment.