diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py new file mode 100644 index 0000000000..3069352be8 --- /dev/null +++ b/thunder/core/trace_interpreter.py @@ -0,0 +1,76 @@ +from typing import Any + +from thunder.core import prims +from thunder.core.pytree import tree_map +from thunder.core.trace import VariableInterface +from thunder.core.utils import safe_map_flat, sequencify + + +# TODO: Currently we use trace.args and trace.kwargs to get the arguments +# Maybe we should use these instead +trace_interpreter_skip_list = ( + prims.PrimIDs.UNPACK_EMPTY_DICT, + prims.PrimIDs.UNPACK_KEY, + prims.PrimIDs.UNPACK_SEQUENCE, + prims.PrimIDs.UNPACK_TRIVIAL, + prims.PrimIDs.RETURN, +) + + +def interpret_trace(trace, *args, symbol_mapper=None, with_env=False, **kwargs): + """Interpret a trace. + + Args: + trace: trace to interpret + *args: arguments to interpret the trace with + symbol_mapper: function that maps a BoundSymbol to the callable used to evaluate it (a None result skips the symbol); when omitted, the symbol's own ``sym`` is called + with_env: whether to return the environment after interpreting the trace. Environment is a dictionary + that maps variable names to their values. 
+ **kwargs: keyword arguments to interpret the trace with + + Returns: + result of interpreting the trace, optionally with the environment that saves all intermediate values + """ + env = {} + + def read(x: VariableInterface | Any) -> Any: + if isinstance(x, VariableInterface): + return env[x.name] + else: + return x + + def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None: + if not isinstance(v, VariableInterface): + return + # An already-bound name keeps its first value when allow_duplicates is set; otherwise rebinding is an error + if v.name in env: + if allow_duplicates: + return + raise ValueError(f"Variable {v.name} is being overwritten this is not allowed") + env[v.name] = val + + safe_map_flat(write, list(trace.args), list(args)) + safe_map_flat(write, list(trace.kwargs.values()), list(kwargs.values())) + + for symbol in trace.bound_symbols: + if symbol.sym.id in trace_interpreter_skip_list: + continue + args = tree_map(read, symbol.args) + kwargs = tree_map(read, symbol.kwargs) + prim_func = symbol_mapper(symbol) if symbol_mapper is not None else symbol.sym + if prim_func is None: + continue + result = prim_func(*args, **kwargs) + try: + safe_map_flat(write, list(sequencify(symbol.output)), list(sequencify(result))) + except AssertionError as e: + raise RuntimeError( + f"Error while assigning the result of dispatched function {prim_func} to the output of the original symbol {symbol}." + " This is likely due to a mismatch in the number of outputs." + f" The original symbol has {len(sequencify(symbol.output))} outputs and the dispatched function has {len(sequencify(result))} outputs." 
+ ) from e + + if with_env: + return tree_map(read, trace.output), env + + return tree_map(read, trace.output) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index ff7e31d299..5b056da8b2 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -19,6 +19,7 @@ import thunder.core.utils as utils from thunder.core import dtypes, prims from thunder.core.devices import cpu, Device +from thunder.core.trace_interpreter import interpret_trace as eval_trace, trace_interpreter_skip_list from thunder.core.proxies import ( NumberProxy, Proxy, @@ -1445,74 +1446,6 @@ def symbol_to_eval(bound_symbol): return bound_symbol.sym -# TODO: Currently we use trace.args and trace.kwargs to get the arguments -# Maybe we should use these instead -transform_skip_list = ( - prims.PrimIDs.UNPACK_EMPTY_DICT, - prims.PrimIDs.UNPACK_KEY, - prims.PrimIDs.UNPACK_SEQUENCE, - prims.PrimIDs.UNPACK_TRIVIAL, - prims.PrimIDs.RETURN, -) - - -def eval_trace(trace, *args, symbol_mapper=symbol_to_eval, with_env=False, **kwargs): - """Evaluate a trace. 
- - Args: - trace: trace to evaluate - *args: arguments to evaluate the trace with - symbol_mapper: function that maps a symbol to a function that evaluates it - **kwargs: keyword arguments to evaluate the trace with - - Returns: - result of evaluating the trace - """ - env = {} - - def read(x: Variable): - if isinstance(x, Variable): - return env[x.name] - else: - return x - - def write(v: Variable, val: Any, allow_duplicates=False) -> None: - if not isinstance(v, Variable): - return - # Duplicates are allowed and overwritten - if v.name in env: - if allow_duplicates: - return - raise ValueError(f"Variable {v.name} is being overwritten this is not allowed") - env[v.name] = val - - safe_map_flat(write, list(trace.args), list(args)) - safe_map_flat(write, list(trace.kwargs.values()), list(kwargs.values())) - - for symbol in trace.bound_symbols: - if symbol.sym.id in transform_skip_list: - continue - args = tree_map(read, symbol.args) - kwargs = tree_map(read, symbol.kwargs) - prim_func = symbol_mapper(symbol) - if prim_func is None: - continue - result = prim_func(*args, **kwargs) - try: - safe_map_flat(write, list(sequencify(symbol.output)), list(sequencify(result))) - except AssertionError as e: - raise RuntimeError( - f"Error while assigning the result of dispatched function {prim_func} to the output of the original symbol {symbol}." - " This is likely due to a mismatch in the number of outputs." - f" The original symbol has {len(symbol.output)} outputs and the dispatched function has {len(sequencify(result))} outputs." 
- ) from e - - if with_env: - return tree_map(read, trace.output), env - - return tree_map(read, trace.output) - - def unwrap_one_level_of_subsymbols(trace): new_symbols_iter = ( bound_symbol.subsymbols if len(bound_symbol.subsymbols) > 0 else [bound_symbol] @@ -2266,7 +2199,7 @@ def iter_bound_symbols(bound_symbols): infrastructure """ for symbol in bound_symbols: - if symbol.sym.id in transform_skip_list: + if symbol.sym.id in trace_interpreter_skip_list: continue elif symbol.output is None: continue @@ -2559,7 +2492,7 @@ def check_bsym_for_vjp(bsym): bool: True if the bound symbol is supported by vjp, False otherwise. """ - if bsym.sym.id in transform_skip_list: + if bsym.sym.id in trace_interpreter_skip_list: return True if bsym.sym.id in backward_impls and bsym.sym.id in augmented_forward_impls: diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index c14f7065f0..3c964f5a62 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -99,7 +99,7 @@ def _generate_supported_op_list(checker): Returns: generator: A generator of operator info objects that support vjp. """ - from thunder.core.transforms import transform_skip_list + from thunder.core.transforms import trace_interpreter_skip_list for opinfo in opinfos: if opinfo not in tensor_creation_ops and opinfo.name not in op_skip: @@ -108,7 +108,7 @@ def _generate_supported_op_list(checker): samples = iter(opinfo.sample_inputs("cpu", dtypes.float64, requires_grad=True)) while (sample := next(samples, None)) is not None: trc = thunder.trace()(opinfo.op, *sample.args, **sample.kwargs) - all_skipped = all(s.sym.id in transform_skip_list for s in trc.bound_symbols) + all_skipped = all(s.sym.id in trace_interpreter_skip_list for s in trc.bound_symbols) if all_skipped: continue all_supported = all(checker(s) for s in trc.bound_symbols)