Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tensor Subclasses] Trace transfom to interpret __torch_dispatch__ #1394

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from

Conversation

crcrpar
Copy link
Collaborator

@crcrpar crcrpar commented Nov 3, 2024

Now this pull request is split into #1583, #1584, #1585, #1591, #1592

  • add a generic proxy class to represent Tensor Wrapper Subclasses that call torch.Tensor._make_wrapper_subclass in their __new__ and define their own __torch_dispatch__
  • add a trace transform that evaluate BoundSymbols of a trace one by one so that we could make a trace free from actual tensor subclass objects as possible and write out the actual behavior that is defined by __torch_dispatch__ in a trace

@crcrpar crcrpar changed the title [do not review] ops with subclass support [do not review] ops with subclass support, on top of 1393 Nov 4, 2024
@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from fa05c82 to 519e813 Compare November 5, 2024 08:00
thunder/transforms/tensor_subclasses.py Outdated Show resolved Hide resolved
thunder/transforms/tensor_subclasses.py Outdated Show resolved Hide resolved
thunder/transforms/tensor_subclasses.py Outdated Show resolved Hide resolved
thunder/transforms/tensor_subclasses.py Outdated Show resolved Hide resolved
thunder/transforms/tensor_subclasses.py Outdated Show resolved Hide resolved
thunder/transforms/tensor_subclasses.py Outdated Show resolved Hide resolved
thunder/transforms/tensor_subclasses.py Outdated Show resolved Hide resolved
@crcrpar crcrpar force-pushed the crpa/subclss-tensor-init branch from 21c2af8 to 11fea26 Compare November 6, 2024 08:50
@crcrpar crcrpar changed the base branch from crpa/subclss-tensor-init to main November 6, 2024 13:51
@crcrpar crcrpar changed the title [do not review] ops with subclass support, on top of 1393 [Tensor Subclasses] [do not review] Trace transfom to interpret __torch_dispatch__ and get the correct output type. Depends on 1393 Nov 6, 2024
@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from 5e4b00d to 4cd7a3d Compare November 6, 2024 15:20
@crcrpar crcrpar marked this pull request as ready for review November 6, 2024 15:20
@crcrpar crcrpar changed the title [Tensor Subclasses] [do not review] Trace transfom to interpret __torch_dispatch__ and get the correct output type. Depends on 1393 [Tensor Subclasses] Trace transfom to interpret __torch_dispatch__ and get the correct output type. based on 1393 Nov 6, 2024
@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from 4cd7a3d to bed021e Compare November 7, 2024 06:14
@t-vi
Copy link
Collaborator

t-vi commented Nov 7, 2024

To my mind, we would get better error handling (eg stack traces) if we resolved the __torch_dispatch__ earlier by checking for subclasses in the torchsymbol handling logic. Also, I don't think we have much information about the outputs - e.g. are they subclasses again or the original class, shapes etc. - without doing so, so we cannot evaluate even metadata in control flow.

@crcrpar

This comment was marked as outdated.

@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from 3fa8e2d to d5fb9fe Compare November 19, 2024 06:41
@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from d5fb9fe to 15c8d12 Compare November 26, 2024 07:22
@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from 15c8d12 to 70dc6ba Compare November 28, 2024 12:31
@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from 70dc6ba to fc6d8a9 Compare December 7, 2024 07:22
@crcrpar crcrpar changed the title [Tensor Subclasses] Trace transfom to interpret __torch_dispatch__ and get the correct output type. based on 1393 [Tensor Subclasses] Trace transfom to interpret __torch_dispatch__ Dec 9, 2024
@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from fc6d8a9 to ce3edbc Compare December 12, 2024 23:23
@mruberry
Copy link
Collaborator

There's a lot going on with this PR, and it's pretty complicated. Maybe we should schedule an online sync, @crcrpar and @IvanYashchuk, to see if we can make it more incremental?

no `__torch_dispatch__` support at all.

Signed-off-by: Masaki Kozuki <[email protected]>
somehow, apparently

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the crpa/subclass-tensor-ops branch from ce3edbc to e9027ba Compare December 21, 2024 07:28
@t-vi
Copy link
Collaborator

t-vi commented Dec 21, 2024

I would still prefer to do the flattening during interpretation for the benefit of getting error messages with backtraces.

@crcrpar
Copy link
Collaborator Author

crcrpar commented Dec 22, 2024

There's a lot going on with this PR, and it's pretty complicated. Maybe we should schedule an online sync, @crcrpar and @IvanYashchuk, to see if we can make it more incremental?

We don't have a nice time slots that work all of us.
What would have you find PRs incremental?
Would it look better to you to split this into two, one to add a proxy and the other to add trace transforms?

EDIT: #1583, #1584, and #1585 are a sequence of PRs that cover this one and #1415.

I would still prefer to do the flattening during interpretation for the benefit of getting error messages with backtraces.

I embarrassingly am not familiar with interpretation implementation at all. Could you give me some pointers to add flattening so that it happens during interpretation?

EDIT: __torch_dispatch__ overrides the behavior of ops used in backward and the customized behavior could use torch ops that do not have their backward definition. Thus I do think we'd need to put the flattening AFTER forward-backward split as in #1415

@mruberry
Copy link
Collaborator

There's a lot going on with this PR, and it's pretty complicated. Maybe we should schedule an online sync, @crcrpar and @IvanYashchuk, to see if we can make it more incremental?

We don't have a nice time slots that work all of us. What would have you find PRs incremental? Would it look better to you to split this into two, one to add a proxy and the other to add trace transforms?

OK; we can try to work asynchronously. For a first incremental PR, would you create a PR adding support for aten operators? In particular, if someone were to call something like torch.ops.aten.add then how would that work? Can that operator be added to the torch operations (and be in the torch language context), or should it be added to its own aten language context that has a separate file (or set of files)?

@crcrpar
Copy link
Collaborator Author

crcrpar commented Dec 23, 2024

if someone were to call something like torch.ops.aten.add then how would that work?

So far in my implementation there's an optimistic mapping from core aten ops to ltorch ops:

for node in list_of_function_call_node:
if not hasattr(ltorch, node.target._opname):
msg = (
f"`thunder.torch` does not have corresponding op for {node.target._opname}. "
"Think about adding it to thunder/torch/default_torch_ops.py"
f"\nThe op is found while flattening the following BoundSymbol:\n{bsym}"
f"\ntorch.fx graph:\n{fx_graph.print_readable(print_output=False)}"
)
raise RuntimeError(msg)
ltorch_ops_for_node_of_ops.append(getattr(ltorch, node.target._opname))
.

By the way, if we're to cover core aten ops, then I'd say it'd be worth thinking of using thunder as a custom backend after AOTAutograd.

Can that operator be added to the torch operations (and be in the torch language context), or should it be added to its own aten language context that has a separate file (or set of files)?

Currently torchsymbol has some for core aten ops, apparently

elif hasattr(torch.ops.aten, name):
id = f"torch.ops.aten.{name}"

@torchsymbol(torch.ops.aten.embedding_backward)
and
torch.ops.aten._adaptive_avg_pool2d_backward,
are a core aten op. So I think extending thunder/torch/__init__.py would be fair.

@mruberry
Copy link
Collaborator

mruberry commented Dec 23, 2024

if someone were to call something like torch.ops.aten.add then how would that work?

So far in my implementation there's an optimistic mapping from core aten ops to ltorch ops:

for node in list_of_function_call_node:
if not hasattr(ltorch, node.target._opname):
msg = (
f"`thunder.torch` does not have corresponding op for {node.target._opname}. "
"Think about adding it to thunder/torch/default_torch_ops.py"
f"\nThe op is found while flattening the following BoundSymbol:\n{bsym}"
f"\ntorch.fx graph:\n{fx_graph.print_readable(print_output=False)}"
)
raise RuntimeError(msg)
ltorch_ops_for_node_of_ops.append(getattr(ltorch, node.target._opname))

.
By the way, if we're to cover core aten ops, then I'd say it'd be worth thinking of using thunder as a custom backend after AOTAutograd.

Can that operator be added to the torch operations (and be in the torch language context), or should it be added to its own aten language context that has a separate file (or set of files)?

Currently torchsymbol has some for core aten ops, apparently

elif hasattr(torch.ops.aten, name):
id = f"torch.ops.aten.{name}"

@torchsymbol(torch.ops.aten.embedding_backward)

and

torch.ops.aten._adaptive_avg_pool2d_backward,

are a core aten op. So I think extending thunder/torch/__init__.py would be fair.

OK; expanding thunder/torch/init.py sounds good for now. Let's not "optimistically" try to map ATen operations to torch operations for the moment, but just treat them like different operations.

Would you submit a PR adding torch.ops.aten.add to the torch operations?

EDITED BELOW.

As a follow-up PR to that, what about working with a program like

# Original program
def foo(x):
  return x

# Trace
def computation(x):
  # x: "MyTensorSubclass[cuda:0 f32[12, 12]]" 
  return x

Where the initial trace shows the tensor subclass and its flattened information, and the prologue validates the subclass and its flattening. Then I'd be curious to see addition with that tensor, like this:

# Original program
def foo(x):
  return x + 1

# Trace
def computation(x):
  # x: "MyTensorSubclass[cuda:0 f32[12, 12]]" 
  t0 = MyTensorSubclass.torch.add(x, 1)  # t0: "MyTensorSubclass[cuda:0 f32[12, 12]]"
    # t1 = flatten_tensor_subclass(MyTensorSubclass, x)
    # t2 = torch.ops.aten.add(t1, 1)
      # <decomposition of aten.add into prims would go here> 
    # t0 = unflatten_tensor_subclass(MyTensorSubclass, t2)
  return t0

This can be translated for execution by PyTorch, but I think working through this will be interesting. Then the follow-up question is what the grad transform for it looks like, and how this operation should be translated for execution by nvFuser.

@crcrpar
Copy link
Collaborator Author

crcrpar commented Dec 24, 2024

IMHO, it'd sound more natural to me to register core aten ops to thunder.torch namespace after merging #1583.

Then registration comes after the aforementioned PR, before #1584 and #1585, followed by some refinement of prologue and how traces with tensor subclasses look accompanied by #1584.

the follow-up question is what the grad transform for it looks like, and how this operation should be translated for execution by nvFuser.

With the experience of #1585, I do think we'd have to let the trace get split into forward and backward before interpreting __torch_dispatch__ partly because the extended behavior of certain ops could be dependent on ops without any backward definitions.

@mruberry
Copy link
Collaborator

IMHO, it'd sound more natural to me to register core aten ops to thunder.torch namespace after merging #1583.

Maybe? I guess it just seems like an easy first step to understand the application of aten operators.

Then registration comes after the aforementioned PR, before #1584 and #1585, followed by some refinement of prologue and how traces with tensor subclasses look accompanied by #1584.

the follow-up question is what the grad transform for it looks like, and how this operation should be translated for execution by nvFuser.

With the experience of #1585, I do think we'd have to let the trace get split into forward and backward before interpreting __torch_dispatch__ partly because the extended behavior of certain ops could be dependent on ops without any backward definitions.

Could you elaborate on this? It would be helpful to see an example. I was thinking that a program like

# Trace
def computation(x):
  # x: "MyTensorSubclass[cuda:0 f32[12, 12]]" 
  t0 = MyTensorSubclass.torch.add(x, 1)  # t0: "MyTensorSubclass[cuda:0 f32[12, 12]]"
    # t1 = flatten_tensor_subclass(MyTensorSubclass, x)
    # t2 = torch.ops.aten.add(t1, 1)
      # <decomposition of aten.add into prims would go here> 
    # t0 = unflatten_tensor_subclass(MyTensorSubclass, t2)
  return t0

Could be auto-differentiated like programs today because MyTensorSubclass.torch.add would not define a grad transform, and so the grad transform would "flatten" it to its components, which would be:

# t1 = flatten_tensor_subclass(MyTensorSubclass, x)
    # t2 = torch.ops.aten.add(t1, 1)
      # <decomposition of aten.add into prims would go here> 
    # t0 = unflatten_tensor_subclass(MyTensorSubclass, t2)

Which I expect would get further flattened into a prims.add call. Maybe @IvanYashchuk has a different perspective.

Because of that I was thinking that we'd want to understand the tensor subclass operations ASAP. What are your thoughts, @crcrpar?

@crcrpar
Copy link
Collaborator Author

crcrpar commented Jan 1, 2025

I choose to apply the trace transform AFTER the forward-backward split after I worked on torchao.float8 linears.

Full traces are available on https://gist.github.com/crcrpar/9b69ef83e68e306415af091d025cbf9c.
A program is (quote of https://gist.github.com/crcrpar/9b69ef83e68e306415af091d025cbf9c#file-0_torchao_fp8linear-py)

import torch.nn as nn
from torchao.float8 import convert_to_float8_training

def main():
    batch_size, in_features, out_features = 16, 32, 64
    device, dtype = torch.device("cuda"), torch.float32
    fp8_model = convert_to_float8_training(nn.Linear(in_features, out_features).to(device=device, dtype=dtype))
    jitted = thunder.jit(fp8_model)

The first computation trace has the following lines (= thunder.last_traces(jitted)[0], full trace is https://gist.github.com/crcrpar/9b69ef83e68e306415af091d025cbf9c#file-1_0_first_fwd_trace-py)

  # /path/to/site-packages/torchao/float8/float8_linear.py:57: 	        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
  tensor = manual_float8_matmul_with_args_in_float8_140062104971168_2(input_fp8, weight_fp8_t)  # tensor: "cuda:0 f32[16, 64]"
    # t91 = ltorch.reshape(input_fp8, -1, 32)  # t91: "cuda:0 f32[16, 32]"
      # t91 = prims.reshape(input_fp8, (16, 32))  # t91: "cuda:0 f32[16, 32]"
    # tensor = ltorch.spmm(t91, weight_fp8_t)  # tensor: "cuda:0 f32[16, 64]"

The corresponding lines are https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L32-L58, which is the forward of manual_float8_matmul_with_args_in_float8(torch.autograd.Function).

Here, t91 and weight_fp8_t are an object of torchao.float8.Float8Tensor.

Float8Tensor.__torch_dispatch__ defines its custom behavior for mm which calls torch._scaled_mm as in https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_python_api.py#L22.

These lines are rewritten to

  # [3/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
  t214 = ltorch.core_aten_t(t204)  # t214: "cuda:0 f8_e4m3fn[64, 32]"
    # t214 = ltorch.transpose(t204, 0, 1)  # t214: "cuda:0 f8_e4m3fn[64, 32]"
      # t214 = prims.transpose(t204, (1, 0))  # t214: "cuda:0 f8_e4m3fn[64, 32]"
  # [4/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
  t215 = ltorch.core_aten_clone(t214, memory_format=None)  # t215: "cuda:0 f8_e4m3fn[64, 32]"
    # t215 = prims.clone(t214)  # t215: "cuda:0 f8_e4m3fn[64, 32]"
  # [5/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
  t216 = ltorch.core_aten_t(t215)  # t216: "cuda:0 f8_e4m3fn[32, 64]"
    # t216 = ltorch.transpose(t215, 0, 1)  # t216: "cuda:0 f8_e4m3fn[32, 64]"
      # t216 = prims.transpose(t215, (1, 0))  # t216: "cuda:0 f8_e4m3fn[32, 64]"
  # [6/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
  t217 = ltorch.core_aten_reciprocal(scale)  # t217: "cuda:0 f32[]"
    # t217 = prims.reciprocal(scale)  # t217: "cuda:0 f32[]"
  # [7/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
  t218 = ltorch.core_aten_reciprocal(weight_scale)  # t218: "cuda:0 f32[]"
    # t218 = prims.reciprocal(weight_scale)  # t218: "cuda:0 f32[]"
  # [8/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
  t219 = ltorch.core_aten_scaled_mm(t211, t216, t217, t218, None, None, torch.float32, True)  # t219: "cuda:0 f32[16, 64]"

(quote of https://gist.github.com/crcrpar/9b69ef83e68e306415af091d025cbf9c#file-1_1_fwd_trace_torchao_fp8tensor_flattened-py-L140-L158)
Here, neither torch._scaled_mm nor torch.ops.aten._scaled_mm have their backward defined in PyTorch. Neither thunder does.
So, this trace transform would better wait on the forward-backward split as at least we know the backward of torch.mm and ltorch.spmm.

Backward traces are also available in the linked gist.

@mruberry
Copy link
Collaborator

mruberry commented Jan 2, 2025

I choose to apply the trace transform AFTER the forward-backward split after I worked on torchao.float8 linears.

Interesting!

  # /path/to/site-packages/torchao/float8/float8_linear.py:57: 	        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
  tensor = manual_float8_matmul_with_args_in_float8_140062104971168_2(input_fp8, weight_fp8_t)  # tensor: "cuda:0 f32[16, 64]"
    # t91 = ltorch.reshape(input_fp8, -1, 32)  # t91: "cuda:0 f32[16, 32]"
      # t91 = prims.reshape(input_fp8, (16, 32))  # t91: "cuda:0 f32[16, 32]"
    # tensor = ltorch.spmm(t91, weight_fp8_t)  # tensor: "cuda:0 f32[16, 64]"

The corresponding lines are pytorch/ao@fe5f11b/torchao/float8/float8_linear.py#L32-L58, which is the forward of manual_float8_matmul_with_args_in_float8(torch.autograd.Function).

Here, t91 and weight_fp8_t are an object of torchao.float8.Float8Tensor

Following a previous comment I made, it would be great if they printed an annotation that showed they were tensor subclasses.

(quote of gist.github.com/crcrpar/9b69ef83e68e306415af091d025cbf9c#file-1_1_fwd_trace_torchao_fp8tensor_flattened-py-L140-L158) Here, neither torch._scaled_mm nor torch.ops.aten._scaled_mm have their backward defined in PyTorch. Neither thunder does. So, this trace transform would better wait on the forward-backward split as at least we know the backward of torch.mm and ltorch.spmm.

I was thinking it would be a good thing to get ops like torch.ops.aten._scaled_mm and then Thunder would have to define an autograd formula for them or a decomposition for them. Isn't that the point of adding the aten operators?

@crcrpar
Copy link
Collaborator Author

crcrpar commented Jan 2, 2025

I already recreated a longer sequence of PRs: #1583, #1584, #1585, #1591, and #1592 and the last one improves the type_string.

Isn't that the point of adding the aten operators?

I don't think so. Some core aten ops don't define their backward nor decompositions like torch._scaled_mm and torch.ops.aten._scaled_mm.

@mruberry
Copy link
Collaborator

mruberry commented Jan 2, 2025

I already recreated a longer sequence of PRs: #1583, #1584, #1585, #1591, and #1592 and the last one improves the type_string.

Isn't that the point of adding the aten operators?
I don't think so. Some core aten ops don't define their backward nor decompositions like torch._scaled_mm and torch.ops.aten._scaled_mm.

How does PyTorch generate the backward for those operations on the subclass, then?

@crcrpar
Copy link
Collaborator Author

crcrpar commented Jan 3, 2025

The __torch_dispatch__ comes last as in https://pytorch.org/docs/stable/notes/extending.html#extending-torch-native-api:

For the purpose of extending torch, the important subset of the ordering for this discussion is:

vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> Backends

autograd would not see torch._scaled_mm as it's used inside Float8Tensor.__torch_dispatch__.

@mruberry
Copy link
Collaborator

mruberry commented Jan 3, 2025

The __torch_dispatch__ comes last as in pytorch.org/docs/stable/notes/extending.html#extending-torch-native-api:

For the purpose of extending torch, the important subset of the ordering for this discussion is:
vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> Backends

autograd would not see torch._scaled_mm as it's used inside Float8Tensor.__torch_dispatch__.

Interesting!

OK; so to return to the sample the initial computation would look like:

# Trace
def computation(x):
  # x: "MyTensorSubclass[cuda:0 f32[12, 12]]" 
  t0 = MyTensorSubclass.torch.add(x, 1)  # t0: "MyTensorSubclass[cuda:0 f32[12, 12]]"
  return t0

And if that was differentiated then we'd ask the executors if they have a grad rule for MyTensorSubclass.torch.add and if not then we'd differentiate it like torch.add, except all the torch operations would be MyTensorSubclass variations?

Like if we had MyTensorSubclass.torch.mul, it would differentiate into MyTensorSubclass.torch.mul in the forward and MyTensorSubclass.torch.mul in the backward?

# Fwd + Bwd Trace
def computation(x):
  # x: "MyTensorSubclass[cuda:0 f32[12, 12]]" 
  t0 = MyTensorSubclass.torch.mul(x, 1)  # t0: "MyTensorSubclass[cuda:0 f32[12, 12]]"
  return t0
  # grad g for t0 introduced here
  t1 = MyTensorSubclass.torch.mul(g, x)
  return t1

And then we'd have a new transform, like "add torch dispatch decompositions" that would add the ATen operations under these?

So, assuming the subclass maps mul to div (which is silly, but whatever):

# After add torch dispatch decompositions (adds decompositions to the subclass)
def computation(x):
  # x: "MyTensorSubclass[cuda:0 f32[12, 12]]" 
  t0 = MyTensorSubclass.torch.mul(x, 1)  # t0: "MyTensorSubclass[cuda:0 f32[12, 12]]"
    # t0 = aten.div(x, 1)
  return t0
  # grad g for t0 introduced here
  t1 = MyTensorSubclass.torch.mul(g, x) # t1: "MyTensorSubclass[cuda:0 f32[12, 12]]"
    # t1 = aten.div(g, x)
  return t1

And then we could let executors claim these operations after that?

This ordering makes sense to me, I think. Is this what you had in mind, @crcrpar?

One thing that's unfortunate about this logic is that it doesn't let executors automatically apply their custom differentiation transformation to operations they could ultimately execute. For example, let's say someone writes a tensor subclass that maps to an aten operator that transformer engine could execute. If transformer engine had a custom fwd+bwd for that operation it can no longer apply it, because autodifferentiation has already happened. A fix for this is to register the actual tensor subclass operation with transformer engine.

Does that make sense? Does that fit your thinking, @kshitij12345, @IvanYashchuk?

@crcrpar
Copy link
Collaborator Author

crcrpar commented Jan 7, 2025

From my perspective, it would not make a lot of sense to treat torch.add of MyTensorSubclass differently from torch.add of torch.Tensor.
So in my sequence of pull requests (1583, 1584, 1585, 1591, 1592), the computation trace would rather be

def computation(x):
    # x: MyTensorSubclass[...]
    t0 = ltorch.add(x, 1)  # t0 could be `torch.Tensor` depending on `MyTensorSubclass.__torch_dispatch__`
    return t0

. Then let thunder generate the backward for computation trace before my trace transform unrolls the behavior of torch.add that is extended by MyTensorSubclass.__torch_dispatch__.
After thunder generates the backward, forward and backward trace uses a trace transform for tensor wrapper subclasses with __torch_dispatch__ to update themselves so that they become one that free from tensor subclasses in any of its ops, except ctor/dtor/flattening of them.

And if that was differentiated then we'd ask the executors if they have a grad rule for MyTensorSubclass.torch.add and if not then we'd differentiate it like torch.add, except all the torch operations would be MyTensorSubclass variations?

I don't find torch.add of MyTensorSubclass different from torch.add of torch.Tensor until backward generation is done. The trace transform writes down the actual behavior of torch.add of MyTensorSubclass that would be a sequence of torch ops of torch.Tensors. Then executors would claim ops that they want to execute/fuse.

Like if we had MyTensorSubclass.torch.mul, it would differentiate into MyTensorSubclass.torch.mul in the forward and MyTensorSubclass.torch.mul in the backward?

Could you elaborate this sentence? I'm not quite following what it means

And then we'd have a new transform, like "add torch dispatch decompositions" that would add the ATen operations under these?

I think a new transform here would be close to what I've implemented in 1584.

And then we could let executors claim these operations after that?

The trace transform for tensor wrapper subclasses does what it does after forward-backward split, before transform_for_execution. So I think the answer would be yes.

One thing that's unfortunate about this logic is that it doesn't let executors automatically apply their custom differentiation transformation to operations they could ultimately execute.

If one is content with a tensor subclass that has only one actual data tensor, then I don't think writing a tensor wrapper subclass (= torch.Tensor._make_wrapper_subclass inside __init__ and __torch_dispatch__) would be the best or one of the bests.

@mruberry
Copy link
Collaborator

mruberry commented Jan 7, 2025

From my perspective, it would not make a lot of sense to treat torch.add of MyTensorSubclass differently from torch.add of torch.Tensor. So in my sequence of pull requests (1583, 1584, 1585, 1591, 1592), the computation trace would rather be

def computation(x):
    # x: MyTensorSubclass[...]
    t0 = ltorch.add(x, 1)  # t0 could be `torch.Tensor` depending on `MyTensorSubclass.__torch_dispatch__`
    return t0

. Then let thunder generate the backward for computation trace before my trace transform unrolls the behavior of torch.add that is extended by MyTensorSubclass.__torch_dispatch__. After thunder generates the backward, forward and backward trace uses a trace transform for tensor wrapper subclasses with __torch_dispatch__ to update themselves so that they become one that free from tensor subclasses in any of its ops, except ctor/dtor/flattening of them.

Some questions about this approach:

  • What if an executor wants to define a custom grad formula for a subclass tensor operation? For example, if nvFuser would like to define a custom autograd formula for an operation on dtensors?
  • How does the trace known the metadata to continue abstract interpretation with t0 if the original add is not decomposed into aten operations before the grad transform is applied?

And if that was differentiated then we'd ask the executors if they have a grad rule for MyTensorSubclass.torch.add and if not then we'd differentiate it like torch.add, except all the torch operations would be MyTensorSubclass variations?

I don't find torch.add of MyTensorSubclass different from torch.add of torch.Tensor until backward generation is done. The trace transform writes down the actual behavior of torch.add of MyTensorSubclass that would be a sequence of torch ops of torch.Tensors. Then executors would claim ops that they want to execute/fuse.

I'm not so sure about this, per my questions above. For particular tensor subclasses it seems like the add could be very different? Like for dtensor or a low precision dtype tensor or a quantized tensor? I think it's possible that an executor would like to control the grad formula for those operations, just like we support controlling the grad formula for torch operations today.

Like if we had MyTensorSubclass.torch.mul, it would differentiate into MyTensorSubclass.torch.mul in the forward and MyTensorSubclass.torch.mul in the backward?

Could you elaborate this sentence? I'm not quite following what it means

Let's work on the earlier questions first and then we can come back to this.

And then we'd have a new transform, like "add torch dispatch decompositions" that would add the ATen operations under these?

I think a new transform here would be close to what I've implemented in 1584.

Same here, let's figure out the sequence of operations together and then we can think about implementations separately.

And then we could let executors claim these operations after that?

The trace transform for tensor wrapper subclasses does what it does after forward-backward split, before transform_for_execution. So I think the answer would be yes.

One thing that's unfortunate about this logic is that it doesn't let executors automatically apply their custom differentiation transformation to operations they could ultimately execute.

If one is content with a tensor subclass that has only one actual data tensor, then I don't think writing a tensor wrapper subclass (= torch.Tensor._make_wrapper_subclass inside __init__ and __torch_dispatch__) would be the best or one of the bests.

I'm not super worried about this part for now, it's pretty niche.

@t-vi
Copy link
Collaborator

t-vi commented Jan 7, 2025

I'm not so sure about this, per my questions above. For particular tensor subclasses it seems like the add could be very different? Like for dtensor or a low precision dtype tensor or a quantized tensor? I think it's possible that an executor would like to control the grad formula for those operations, just like we support controlling the grad formula for torch operations today.

I think it is quite reasonable to expect the executor that defines a different grad to also define a forward. In that scenario, the executor should register a lookaside on the TensorSubclass.add method (and we should check that) and all is good.

To my mind, it would be a lot easier if we moved the subclasses to interpretation time, just like we do for autograd functions. We have all the instruments, we just need to use them.

@mruberry
Copy link
Collaborator

mruberry commented Jan 8, 2025

I'm not so sure about this, per my questions above. For particular tensor subclasses it seems like the add could be very different? Like for dtensor or a low precision dtype tensor or a quantized tensor? I think it's possible that an executor would like to control the grad formula for those operations, just like we support controlling the grad formula for torch operations today.

I think it is quite reasonable to expect the executor that defines a different grad to also define a forward. In that scenario, the executor should register a lookaside on the TensorSubclass.add method (and we should check that) and all is good.

To my mind, it would be a lot easier if we moved the subclasses to interpretation time, just like we do for autograd functions. We have all the instruments, we just need to use them.

That could work OK; I'd lean towards working with subclass operations being like working with any other operations if possible, but adding a new extensibility pattern just for them wouldn't be the end of the world, either.

Determining the ATen operations at interpretation time might be the thing to do, vs. making it a later transform.

@crcrpar crcrpar marked this pull request as draft January 9, 2025 09:18
@crcrpar
Copy link
Collaborator Author

crcrpar commented Jan 14, 2025

What if an executor wants to define a custom grad formula for a subclass tensor operation? For example, if nvFuser would like to define a custom autograd formula for an operation on dtensors?

IMHO it doesn't sound like a legit usage/combination of tensor wrapper subclasses with __torch_dispatch__ and custom grad rule defined in Thunder. The PyTorch doc says, e.g.

This code runs “below all features”. It is thus only responsible, like a regular backend, for generating the output value of each Tensor (and can, and should, ignore all advanced features like autograd, autocast, etc).

How does the trace known the metadata to continue abstract interpretation with t0 if the original add is not decomposed into aten operations before the grad transform is applied?

What's the metadata here? For an instance of tensor wrapper subclass with __torch_dispatch__, we set some major metadata such as dtype, shape, and device. Then each meta op can infer the output metadata of those from those metadata of the instance.

@mruberry
Copy link
Collaborator

What if an executor wants to define a custom grad formula for a subclass tensor operation? For example, if nvFuser would like to define a custom autograd formula for an operation on dtensors?

IMHO it doesn't sound like a legit usage/combination of tensor wrapper subclasses with __torch_dispatch__ and custom grad rule defined in Thunder. The PyTorch doc says, e.g.

This code runs “below all features”. It is thus only responsible, like a regular backend, for generating the output value of each Tensor (and can, and should, ignore all advanced features like autograd, autocast, etc).

I understand that PyTorch doesn't support this feature, but I think we will want to. Especially for important subclasses like DTensor. We could also implement DTensor separate from general subclass support, and it would just be a special class that does allow for intercepting the tensor subclass calls?

How does the trace known the metadata to continue abstract interpretation with t0 if the original add is not decomposed into aten operations before the grad transform is applied?

What's the metadata here? For an instance of tensor wrapper subclass with __torch_dispatch__, we set some major metadata such as dtype, shape, and device. Then each meta op can infer the output metadata of those from those metadata of the instance.

So when we do c = torch.add(a, b), and a is a tensor subclass, so the call is more conceptually like c = MyTensorSubclass.add(a, b), is the assumption that that the metadata of c can be determined by calling the meta function for torch.add on a and b, treating a as a regular tensor and not a subclass? I don't think that's always the case, right? One example that comes to mind is that some tensor subclasses change the type promotion rules for operations.

@crcrpar
Copy link
Collaborator Author

crcrpar commented Jan 17, 2025

Especially for important subclasses like DTensor. We could also implement DTensor separate from general subclass support, and it would just be a special class that does allow for intercepting the tensor subclass calls?

Currently grad rule table's key is Symbol.id thus I think it'll boil down to defining new torchsymbols for ops with those tensor subclasses and new proxy classes for them.
Thus I think it's not on the radar of my work here and the implementation for it would be more intuitive than the trace transform I've written recently.

So when we do c = torch.add(a, b), and a is a tensor subclass, so the call is more conceptually like c = MyTensorSubclass.add(a, b), is the assumption that that the metadata of c can be determined by calling the meta function for torch.add on a and b, treating a as a regular tensor and not a subclass? I don't think that's always the case, right? One example that comes to mind is that some tensor subclasses change the type promotion rules for operations.

I'm not sure about how much we should expect that kind of deviation, though then what I can do with this pull request is to apply the trace transform to the initial trace to correct the output of ops including tensor subclasses first, before the forward-backward split.

@mruberry
Copy link
Collaborator

Especially for important subclasses like DTensor. We could also implement DTensor separate from general subclass support, and it would just be a special class that does allow for intercepting the tensor subclass calls?

Currently grad rule table's key is Symbol.id thus I think it'll boil down to defining new torchsymbols for ops with those tensor subclasses and new proxy classes for them. Thus I think it's not on the radar of my work here and the implementation for it would be more intuitive than the trace transform I've written recently.

Interesting! I thought it would be nice to make them part of the same extensibility point, but I can understand if they're different.

So when we do c = torch.add(a, b), and a is a tensor subclass, so the call is more conceptually like c = MyTensorSubclass.add(a, b), is the assumption that that the metadata of c can be determined by calling the meta function for torch.add on a and b, treating a as a regular tensor and not a subclass? I don't think that's always the case, right? One example that comes to mind is that some tensor subclasses change the type promotion rules for operations.

I'm not sure about how much we should expect that kind of deviation, though then what I can do with this pull request is to apply the trace transform to the initial trace to correct the output of ops including tensor subclasses first, before the forward-backward split.

I worry about trying to "fix" the metadata of the ops after the tracing has happened. What if the program conditions on the dtype of such a tensor, for example? Then the interpreter will have recorded an invalid program. I think we will have to decompose into ATen operations to determine the appropriate metadata as the trace is being constructed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants