Skip to content

Commit

Permalink
add path of SubclassTensorProxy in tensorproxy
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 5, 2024
1 parent 4e05dae commit 519e813
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
3 changes: 3 additions & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses

# NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this
import torch as pytorch
Expand Down Expand Up @@ -580,6 +581,8 @@ def get_computation_and_inputs(*args, **kwargs):
if len(tensor_args_consumed_by_inplace_grouped_by_numel) > 1:
vanilla_tensor_args = set(tensor_indices)

computation_trc = flatten_tensor_subclasses(computation_trc)

if epilogue_trc is not None:
epilogue_traces = [epilogue_trc]
else:
Expand Down
27 changes: 25 additions & 2 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2104,8 +2104,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
else:
# NOTE Without tuple(t.shape) then the shape would be a torch.Size object
shape = tuple(t.shape)
return TensorProxy(
name,
ctor_kwargs = dict(
name=name,
shape=tuple(shape),
device=device,
dtype=dtype,
Expand All @@ -2116,6 +2116,29 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
thunder_fsdp_padding_size=_thunder_fsdp_padding_size,
)

if type(t) not in (torch.Tensor, torch.nn.Parameter) and isinstance(t, torch.Tensor):
baseutils.check(
hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"),
lambda: f"{t=} seems to be a tensor subclass but not traceable",
)
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = [getattr(t, name) for name in tensor_attr_names]
ctor_kwargs.update(
{
"tensors": tensors,
"non_tensors": list(metadata.values()),
"subclass_type": type(t),
}
)
p = SubclassTensorProxy(**ctor_kwargs)
for name, tensor in zip(tensor_attr_names, tensors):
setattr(p, name, tensor)
for name, value in metadata.items():
setattr(p, name, value)
return p
else:
return TensorProxy(**ctor_kwargs)


def futuretensorproxy(
t: torch.Tensor | TensorProxy | FutureTensorProxy, /, *, name: None | str, history: None | tuple = None
Expand Down

0 comments on commit 519e813

Please sign in to comment.