diff --git a/thunder/__init__.py b/thunder/__init__.py index 104af7734f..5af42c14ed 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -72,6 +72,7 @@ from thunder.core.interpreter import print_interpreter_log, print_to_log from thunder.core.jit_ext import thunder_general_jit from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction +from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses # NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this import torch as pytorch @@ -580,6 +581,8 @@ def get_computation_and_inputs(*args, **kwargs): if len(tensor_args_consumed_by_inplace_grouped_by_numel) > 1: vanilla_tensor_args = set(tensor_indices) + computation_trc = flatten_tensor_subclasses(computation_trc) + if epilogue_trc is not None: epilogue_traces = [epilogue_trc] else: diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 8cf179e263..3c1ae4024e 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -2104,8 +2104,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = else: # NOTE Without tuple(t.shape) then the shape would be a torch.Size object shape = tuple(t.shape) - return TensorProxy( - name, + ctor_kwargs = dict( + name=name, shape=tuple(shape), device=device, dtype=dtype, @@ -2116,6 +2116,29 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = thunder_fsdp_padding_size=_thunder_fsdp_padding_size, ) + if type(t) not in (torch.Tensor, torch.nn.Parameter) and isinstance(t, torch.Tensor): + baseutils.check( + hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"), + lambda: f"{t=} seems to be a tensor subclass but not traceable", + ) + tensor_attr_names, metadata = t.__tensor_flatten__() + tensors = [getattr(t, name) for name in tensor_attr_names] + ctor_kwargs.update( + { + "tensors": tensors, + "non_tensors": list(metadata.values()), + "subclass_type": type(t), + } + ) + p = SubclassTensorProxy(**ctor_kwargs) + for name, tensor in zip(tensor_attr_names, tensors): + setattr(p, name, tensor) + for name, value in metadata.items(): + setattr(p, name, value) + return p + else: + return TensorProxy(**ctor_kwargs) + def futuretensorproxy( t: torch.Tensor | TensorProxy | FutureTensorProxy, /, *, name: None | str, history: None | tuple = None