Skip to content

Commit

Permalink
finish cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi committed Jan 21, 2025
1 parent cf53463 commit 24b6bae
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@
import torch.utils.checkpoint

import thunder
from thunder.core.compile_data import compile_data_and_stats, get_cache_option, get_compile_data
from thunder.core.compile_data import get_cache_option, get_compile_data
import thunder.clang as clang
import thunder.core.transforms
import thunder.core.baseutils as baseutils
import thunder.core.codeutils as codeutils
from thunder.core.proxies import (
AnyProxy,
DistParallelType,
NumberProxy,
Proxy,
ProxyInterface,
Expand All @@ -45,14 +44,13 @@
unvariableify,
variableify,
)
from thunder.core.trace import set_tracectx, reset_tracectx, tracectx, from_trace
from thunder.core.trace import tracectx, from_trace
from thunder.core.interpreter import (
INTERPRETER_CALLBACKS,
INTERPRETER_SIGNALS,
InterpreterRuntimeCtx,
ProvenanceRecord,
PseudoInst,
ThunderInterpreterObject,
WrappedValue,
_interpret_call,
default_callbacks,
Expand All @@ -66,7 +64,6 @@
wrap,
wrap_const,
)
from thunder.core.langctxs import set_langctx, reset_langctx, Languages, resolve_language
from thunder.core.codeutils import SigInfo
import thunder.core.prims as prims
from thunder.core.options import CACHE_OPTIONS, SHARP_EDGES_OPTIONS, DebugOptions
Expand Down Expand Up @@ -514,7 +511,7 @@ def _general_jit_object_setattr_lookaside(obj: Any, name: str, value: Any):
return d
d.provenance.ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
ud = unwrap(d)
assert type(ud) == dict
assert type(ud) is dict
res = _interpret_call(ud.__setitem__, name, value)
return res

Expand All @@ -525,7 +522,6 @@ def _general_jit_setattr_lookaside(obj: Any, name: str, value: Any):
assert setattr_lookaside is not None

uobj = unwrap(obj)
uname = unwrap(name)

if isinstance(uobj, torch.nn.Module):
# 1) populate the wrappeers for the member dicts
Expand Down Expand Up @@ -665,8 +661,6 @@ def _convert_pytorchfunc_to_thundertrace(
trace.bound_symbols.extend(bsyms)
func_result = unwrap(wrapped_func_result)
if shallow_copy_output and not bsyms:
from thunder.core.baseutils import sequencify

out_to_shallow_copy: dict[Variable, TensorProxy] = {}
for a in sequencify(func_result):
shallow_copy_of_a = prims.shallow_copy.meta(a)
Expand Down Expand Up @@ -1320,7 +1314,7 @@ def _general_jit_wrap_callback(value):
value.provenance.ext_flag |= EXT_FLAG_IS_MODULE
elif isinstance(uvalue, torch.Tensor):
# we always want to proxy torch.Tensor, even const
p = ctx.proxify(value)
ctx.proxify(value)
elif value.provenance.inst is PseudoInst.CONSTANT:
value.provenance.ext_flag |= EXT_FLAG_IS_PROXY_DERIVED
elif callable(uvalue):
Expand All @@ -1336,7 +1330,7 @@ def _general_jit_wrap_callback(value):
value.provenance.ext_flag |= EXT_FLAG_IS_PROXY_DERIVED
value.provenance.ext_flag |= EXT_FLAG_IS_CONSTRAINABLE_INPUT
# we follow the caching mechanisms of the eager_unpack_interpreter
p = ctx.proxify(value)
ctx.proxify(value)
else:
return _general_jit_sharp_edge(
f"We are using a (non-const) value of type {type(uvalue).__name__}, which is not identified as an input.",
Expand Down Expand Up @@ -1805,14 +1799,14 @@ def process_recorded_modifications(ctx, epilogue_trace):
and modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT
and modified_object.provenance.inputs[1].value == "_buffers"
):
assert isinstance(value.value, (Proxy, int, tuple, NoneType)) ## todo: better criterion
assert isinstance(value.value, (Proxy, int, tuple, NoneType)) # todo: better criterion
typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
modified_object.provenance.inputs[0]
)
assert typ == "_modules"
root_module_proxy = root_for_provenances.get(root_module_provenance)
if root_module_proxy is None:
## we want this to created in the compute trace context for namespace...
# we want this to created in the compute trace context for namespace...
root_module_proxy = Proxy(history=root_module_provenance)
epilogue_trace.add_name(root_module_proxy.name)
root_for_provenances[root_module_provenance] = root_module_proxy
Expand All @@ -1829,7 +1823,7 @@ def process_recorded_modifications(ctx, epilogue_trace):
name = k
setattr_obj_provenance = modified_object.provenance.inputs[0]
if hasattr(setattr_obj_provenance, "proxy"):
assert isinstance(value.value, (Proxy, int, tuple, NoneType)) ## todo: better criterion
assert isinstance(value.value, (Proxy, int, tuple, NoneType)) # todo: better criterion
setattr_obj_proxy = setattr_obj_provenance.proxy
with tracectx(epilogue_trace):
bsym = prims.pack_attr.bind(setattr_obj_proxy, name, value.value, output=None)
Expand All @@ -1839,7 +1833,7 @@ def process_recorded_modifications(ctx, epilogue_trace):
else:
raise NotImplementedError(f"Modifications {inst} on dicts are not supported")
else:
raise NotImplementedError(f"Modifications of {type(uvalue).__name__} objects are not supported")
raise NotImplementedError(f"Modifications of {type(umodified_object).__name__} objects are not supported")


def bind_inputs(name, trace, input_vars, input_proxies):
Expand Down

0 comments on commit 24b6bae

Please sign in to comment.