Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up imports of jit_ext.py #1588

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 37 additions & 83 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
@@ -1,120 +1,77 @@
import thunder
from __future__ import annotations
import math
from typing import Any, Optional, Dict, Tuple, Literal
import builtins
from typing import Any
import collections
from collections.abc import ValuesView, Iterable, Iterator
from collections.abc import Callable, Sequence
import dataclasses
import weakref
import random
from functools import partial, wraps, reduce
import linecache
import operator
import copy
from functools import wraps
import contextvars
from contextlib import contextmanager
import dis
import warnings
from enum import Enum, auto
from io import StringIO
import inspect
import time

from thunder.core.compile_data import compile_data_and_stats, get_cache_option, get_compile_data
import thunder.clang as clang
import thunder.core.transforms
import thunder.core.baseutils as baseutils
import thunder.core.codeutils as codeutils
from thunder.core.baseutils import run_once

from types import (
BuiltinMethodType,
CellType,
ClassMethodDescriptorType,
CodeType,
CoroutineType,
FrameType,
FunctionType,
MethodType,
GetSetDescriptorType,
MethodDescriptorType,
MethodType,
ModuleType,
NoneType,
BuiltinFunctionType,
BuiltinMethodType,
MethodWrapperType,
WrapperDescriptorType,
TracebackType,
CellType,
ModuleType,
CodeType,
BuiltinFunctionType,
FunctionType,
MethodType,
GetSetDescriptorType,
UnionType,
WrapperDescriptorType,
)

import torch
import torch.utils.checkpoint

import thunder
from thunder.core.compile_data import get_cache_option, get_compile_data
import thunder.clang as clang
import thunder.core.transforms
import thunder.core.baseutils as baseutils
import thunder.core.codeutils as codeutils
from thunder.core.proxies import (
DistParallelType,
proxy,
AnyProxy,
NumberProxy,
Proxy,
ProxyInterface,
ProxyTag,
AnyProxy,
NumberProxy,
StringProxy,
TensorProxy,
FutureTensorProxy,
Variable,
variableify,
unvariableify,
is_proxy_name_available,
proxy,
unvariableify,
variableify,
)
from thunder.core.trace import set_tracectx, reset_tracectx, tracectx, from_trace
from thunder.core.trace import tracectx, from_trace
from thunder.core.interpreter import (
InterpreterLogItem,
InterpreterFrame,
interpret,
_interpret_call,
CapsuleType,
default_callbacks,
INTERPRETER_CALLBACKS,
INTERPRETER_SIGNALS,
default_opcode_interpreter,
_default_lookaside_map,
InterpreterRuntimeCtx,
ProvenanceRecord,
PseudoInst,
WrappedValue,
_interpret_call,
default_callbacks,
default_lookaside,
do_raise,
get_interpreterruntimectx,
InterpreterRuntimeCtx,
interpret,
interpreter_needs_wrap,
is_opaque,
Py_NULL,
member_descriptor,
WrappedValue,
unwrap,
wrap,
wrap_const,
PseudoInst,
ProvenanceRecord,
interpreter_needs_wrap,
ThunderInterpreterObject,
)
from thunder.core.langctxs import set_langctx, reset_langctx, Languages, resolve_language
from thunder.core.baseutils import extract_callable_name
from thunder.core.codeutils import get_siginfo, SigInfo
from thunder.core.codeutils import SigInfo
import thunder.core.prims as prims
from thunder.common import transform_for_execution
from thunder.core.options import CACHE_OPTIONS, SHARP_EDGES_OPTIONS, DebugOptions
from thunder.core.symbol import Symbol, BoundSymbol, is_traceable

from thunder.extend import Executor
from thunder.common import CompileData, CompileStats
from thunder.core.symbol import Symbol
from thunder.core.trace import TraceCtx, TraceResults
from thunder.torch import _torch_to_thunder_function_map
from thunder.clang import _clang_fn_set
from thunder.core.pytree import tree_map, tree_iter
from thunder.core.compile_data import compile_data_and_stats

#
# jit_ext.py implements extensions of thunder's interpreter
Expand Down Expand Up @@ -554,7 +511,7 @@ def _general_jit_object_setattr_lookaside(obj: Any, name: str, value: Any):
return d
d.provenance.ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
ud = unwrap(d)
assert type(ud) == dict
assert type(ud) is dict
res = _interpret_call(ud.__setitem__, name, value)
return res

Expand All @@ -565,7 +522,6 @@ def _general_jit_setattr_lookaside(obj: Any, name: str, value: Any):
assert setattr_lookaside is not None

uobj = unwrap(obj)
uname = unwrap(name)

if isinstance(uobj, torch.nn.Module):
# 1) populate the wrappeers for the member dicts
Expand Down Expand Up @@ -705,8 +661,6 @@ def _convert_pytorchfunc_to_thundertrace(
trace.bound_symbols.extend(bsyms)
func_result = unwrap(wrapped_func_result)
if shallow_copy_output and not bsyms:
from thunder.core.baseutils import sequencify

out_to_shallow_copy: dict[Variable, TensorProxy] = {}
for a in sequencify(func_result):
shallow_copy_of_a = prims.shallow_copy.meta(a)
Expand Down Expand Up @@ -1360,7 +1314,7 @@ def _general_jit_wrap_callback(value):
value.provenance.ext_flag |= EXT_FLAG_IS_MODULE
elif isinstance(uvalue, torch.Tensor):
# we always want to proxy torch.Tensor, even const
p = ctx.proxify(value)
ctx.proxify(value)
elif value.provenance.inst is PseudoInst.CONSTANT:
value.provenance.ext_flag |= EXT_FLAG_IS_PROXY_DERIVED
elif callable(uvalue):
Expand All @@ -1376,7 +1330,7 @@ def _general_jit_wrap_callback(value):
value.provenance.ext_flag |= EXT_FLAG_IS_PROXY_DERIVED
value.provenance.ext_flag |= EXT_FLAG_IS_CONSTRAINABLE_INPUT
# we follow the caching mechanisms of the eager_unpack_interpreter
p = ctx.proxify(value)
ctx.proxify(value)
else:
return _general_jit_sharp_edge(
f"We are using a (non-const) value of type {type(uvalue).__name__}, which is not identified as an input.",
Expand Down Expand Up @@ -1845,14 +1799,14 @@ def process_recorded_modifications(ctx, epilogue_trace):
and modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT
and modified_object.provenance.inputs[1].value == "_buffers"
):
assert isinstance(value.value, (Proxy, int, tuple, NoneType)) ## todo: better criterion
assert isinstance(value.value, (Proxy, int, tuple, NoneType)) # todo: better criterion
typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
modified_object.provenance.inputs[0]
)
assert typ == "_modules"
root_module_proxy = root_for_provenances.get(root_module_provenance)
if root_module_proxy is None:
## we want this to created in the compute trace context for namespace...
# we want this to created in the compute trace context for namespace...
root_module_proxy = Proxy(history=root_module_provenance)
epilogue_trace.add_name(root_module_proxy.name)
root_for_provenances[root_module_provenance] = root_module_proxy
Expand All @@ -1869,7 +1823,7 @@ def process_recorded_modifications(ctx, epilogue_trace):
name = k
setattr_obj_provenance = modified_object.provenance.inputs[0]
if hasattr(setattr_obj_provenance, "proxy"):
assert isinstance(value.value, (Proxy, int, tuple, NoneType)) ## todo: better criterion
assert isinstance(value.value, (Proxy, int, tuple, NoneType)) # todo: better criterion
setattr_obj_proxy = setattr_obj_provenance.proxy
with tracectx(epilogue_trace):
bsym = prims.pack_attr.bind(setattr_obj_proxy, name, value.value, output=None)
Expand All @@ -1879,7 +1833,7 @@ def process_recorded_modifications(ctx, epilogue_trace):
else:
raise NotImplementedError(f"Modifications {inst} on dicts are not supported")
else:
raise NotImplementedError(f"Modifications of {type(uvalue).__name__} objects are not supported")
raise NotImplementedError(f"Modifications of {type(umodified_object).__name__} objects are not supported")


def bind_inputs(name, trace, input_vars, input_proxies):
Expand Down
Loading