[Tensor Subclasses] Trace transform to interpret `__torch_dispatch__` #1394
Conversation
Trace transform to interpret `__torch_dispatch__` and get the correct output type. Depends on #1393.
To my mind, we would get better error handling (e.g., stack traces) if we resolved the |
There's a lot going on with this PR, and it's pretty complicated. Maybe we should schedule an online sync, @crcrpar and @IvanYashchuk, to see if we can make it more incremental? |
I would still prefer to do the flattening during interpretation for the benefit of getting error messages with backtraces. |
We don't have a nice time slot that works for all of us. EDIT: #1583, #1584, and #1585 are a sequence of PRs that cover this one and #1415.
I'm embarrassingly not familiar with the interpretation implementation at all. Could you give me some pointers on adding flattening so that it happens during interpretation? EDIT: |
OK; we can try to work asynchronously. For a first incremental PR, would you create a PR adding support for aten operators? In particular, if someone were to call something like |
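Presumably something along these lines (an illustrative sketch; this is not the example that was in the original comment):

```python
import torch
import thunder


def f(x, y):
    # Calling a core ATen overload directly; the proposed PR would let
    # thunder trace operations written this way.
    return torch.ops.aten.add.Tensor(x, y)


# Hypothetical usage once aten operators are supported:
jf = thunder.jit(f)
jf(torch.randn(2), torch.randn(2))
```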
So far in my implementation there's an optimistic mapping from core aten ops to ltorch ops: see thunder/transforms/tensor_wrapper_subclass.py, lines 339 to 348 (at 515d425).
By the way, if we're to cover core aten ops, then I'd say it'd be worth considering using thunder as a custom backend after AOTAutograd.
Currently: thunder/torch/__init__.py lines 172 to 173, line 4700, and line 4308 (at 9d79b8d). Extending thunder/torch/__init__.py would be fair.
|
OK; expanding thunder/torch/__init__.py sounds good for now. Let's not "optimistically" try to map ATen operations to torch operations for the moment, but just treat them like different operations. Would you submit a PR adding `torch.ops.aten.add` to the torch operations? EDITED BELOW. As a follow-up PR to that, what about working with a program like
Where the initial trace shows the tensor subclass and its flattened information, and the prologue validates the subclass and its flattening. Then I'd be curious to see addition with that tensor, like this:
This can be translated for execution by PyTorch, but I think working through this will be interesting. Then the follow-up question is what the grad transform for it looks like, and how this operation should be translated for execution by nvFuser. |
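For reference, a minimal sketch of the kind of wrapper subclass such a program might use; `ScaleTensor` and `f` below are illustrative assumptions, not code from this PR:

```python
import torch


class ScaleTensor(torch.Tensor):
    # A toy wrapper subclass: one inner tensor plus a scale.

    @staticmethod
    def __new__(cls, inner: torch.Tensor, scale: float):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner: torch.Tensor, scale: float):
        self.inner = inner
        self.scale = scale

    def __tensor_flatten__(self):
        # Names of inner tensor attributes plus constant context.
        return ["inner"], self.scale

    @staticmethod
    def __tensor_unflatten__(inner_tensors, scale, outer_size, outer_stride):
        return ScaleTensor(inner_tensors["inner"], scale)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            # Materialize the wrapper into a plain tensor.
            return t.inner * t.scale if isinstance(t, ScaleTensor) else t

        # Run the ATen op on plain tensors; the result is a plain torch.Tensor.
        return func(*[unwrap(a) for a in args], **kwargs)


def f(x: torch.Tensor) -> torch.Tensor:
    y = ScaleTensor(x, 0.5)
    return y + 1  # dispatches through ScaleTensor.__torch_dispatch__
```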
IMHO, it'd sound more natural to me to register core aten ops to
Then registration comes after the aforementioned PR, before #1584 and #1585, followed by some refinement of the prologue and how traces with tensor subclasses look, accompanied by #1584.
With the experience of #1585, I do think we'd have to let the trace get split into forward and backward before interpreting |
Maybe? I guess it just seems like an easy first step to understand the application of aten operators.
Could you elaborate on this? It would be helpful to see an example. I was thinking that a program like
Could be auto-differentiated like programs today because
Which I expect would get further flattened into a
Because of that, I was thinking that we'd want to understand the tensor subclass operations ASAP. What are your thoughts, @crcrpar? |
I chose to apply the trace transform AFTER the forward-backward split after I worked on torchao.float8 linears. Full traces are available at https://gist.github.com/crcrpar/9b69ef83e68e306415af091d025cbf9c.

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

import thunder


def main():
    batch_size, in_features, out_features = 16, 32, 64
    device, dtype = torch.device("cuda"), torch.float32
    fp8_model = convert_to_float8_training(nn.Linear(in_features, out_features).to(device=device, dtype=dtype))
    jitted = thunder.jit(fp8_model)
```

The first computation trace has the following lines:

```python
# /path/to/site-packages/torchao/float8/float8_linear.py:57:           res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
tensor = manual_float8_matmul_with_args_in_float8_140062104971168_2(input_fp8, weight_fp8_t)  # tensor: "cuda:0 f32[16, 64]"
  # t91 = ltorch.reshape(input_fp8, -1, 32)  # t91: "cuda:0 f32[16, 32]"
    # t91 = prims.reshape(input_fp8, (16, 32))  # t91: "cuda:0 f32[16, 32]"
  # tensor = ltorch.spmm(t91, weight_fp8_t)  # tensor: "cuda:0 f32[16, 64]"
```

The corresponding lines are https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L32-L58, which is the forward of `manual_float8_matmul_with_args_in_float8`. These lines are rewritten to:

```python
# [3/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
t214 = ltorch.core_aten_t(t204)  # t214: "cuda:0 f8_e4m3fn[64, 32]"
  # t214 = ltorch.transpose(t204, 0, 1)  # t214: "cuda:0 f8_e4m3fn[64, 32]"
    # t214 = prims.transpose(t204, (1, 0))  # t214: "cuda:0 f8_e4m3fn[64, 32]"
# [4/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
t215 = ltorch.core_aten_clone(t214, memory_format=None)  # t215: "cuda:0 f8_e4m3fn[64, 32]"
  # t215 = prims.clone(t214)  # t215: "cuda:0 f8_e4m3fn[64, 32]"
# [5/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
t216 = ltorch.core_aten_t(t215)  # t216: "cuda:0 f8_e4m3fn[32, 64]"
  # t216 = ltorch.transpose(t215, 0, 1)  # t216: "cuda:0 f8_e4m3fn[32, 64]"
    # t216 = prims.transpose(t215, (1, 0))  # t216: "cuda:0 f8_e4m3fn[32, 64]"
# [6/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
t217 = ltorch.core_aten_reciprocal(scale)  # t217: "cuda:0 f32[]"
  # t217 = prims.reciprocal(scale)  # t217: "cuda:0 f32[]"
# [7/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
t218 = ltorch.core_aten_reciprocal(weight_scale)  # t218: "cuda:0 f32[]"
  # t218 = prims.reciprocal(weight_scale)  # t218: "cuda:0 f32[]"
# [8/8] unrolled `__torch_dispatch__` of `torch.spmm(t212, t205)`
t219 = ltorch.core_aten_scaled_mm(t211, t216, t217, t218, None, None, torch.float32, True)  # t219: "cuda:0 f32[16, 64]"
```

(quoted from https://gist.github.com/crcrpar/9b69ef83e68e306415af091d025cbf9c#file-1_1_fwd_trace_torchao_fp8tensor_flattened-py-L140-L158) Backward traces are also available in the linked gist. |
Interesting!
Following a previous comment I made, it would be great if they printed an annotation that showed they were tensor subclasses.
I was thinking it would be a good thing to get ops like |
I already recreated a longer sequence of PRs: #1583, #1584, #1585, #1591, and #1592, and the last one improves the
I don't think so. Some core aten ops don't define their backward or decompositions, like
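As an illustrative aside, whether a given ATen overload has a registered decomposition can be checked via `torch._decomp`; `_scaled_mm` is used here only as a plausible example of an op without one:

```python
import torch
from torch._decomp import decomposition_table

# aten._scaled_mm is backed by cuBLAS and, to my knowledge, has no
# registered decomposition, so this is expected to print False.
print(torch.ops.aten._scaled_mm.default in decomposition_table)
```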
How does PyTorch generate the backward for those operations on the subclass, then? |
The
autograd would not see |
Interesting! OK; so to return to the sample, the initial computation would look like:
And if that was differentiated, then we'd ask the executors if they have a grad rule for
Like if we had
And then we'd have a new transform, like "add torch dispatch decompositions", that would add the ATen operations under these? So, assuming the subclass maps
And then we could let executors claim these operations after that? This ordering makes sense to me, I think. Is this what you had in mind, @crcrpar?
One thing that's unfortunate about this logic is that it doesn't let executors automatically apply their custom differentiation transformation to operations they could ultimately execute. For example, let's say someone writes a tensor subclass that maps to an aten operator that transformer engine could execute. If transformer engine had a custom fwd+bwd for that operation, it could no longer apply it, because autodifferentiation has already happened. A fix for this is to register the actual tensor subclass operation with transformer engine. Does that make sense? Does that fit your thinking, @kshitij12345, @IvanYashchuk? |
From my perspective, it would not make a lot of sense to treat

```python
def computation(x):
    # x: MyTensorSubclass[...]
    t0 = ltorch.add(x, 1)  # t0 could be `torch.Tensor` depending on `MyTensorSubclass.__torch_dispatch__`
    return t0
```

Then let thunder generate the backward for the computation trace before my trace transform unrolls the behavior of
I don't find
Could you elaborate on this sentence? I'm not quite following what it means.
I think a new transform here would be close to what I've implemented in #1584.
The trace transform for tensor wrapper subclasses does what it does after the forward-backward split, before
If one is content with a tensor subclass that has only one actual data tensor, then I don't think writing a tensor wrapper subclass (= |
Some questions about this approach:
I'm not so sure about this, per my questions above. For particular tensor subclasses it seems like the add could be very different? Like for DTensor, a low-precision dtype tensor, or a quantized tensor? I think it's possible that an executor would like to control the grad formula for those operations, just like we support controlling the grad formula for torch operations today.
Let's work on the earlier questions first and then we can come back to this.
Same here, let's figure out the sequence of operations together and then we can think about implementations separately.
I'm not super worried about this part for now, it's pretty niche. |
I think it is quite reasonable to expect the executor that defines a different grad to also define a forward. In that scenario, the executor should register a lookaside on the TensorSubclass.add method (and we should check that) and all is good. To my mind, it would be a lot easier if we moved the subclasses to interpretation time, just like we do for autograd functions. We have all the instruments, we just need to use them. |
That could work OK; I'd lean towards working with subclass operations being like working with any other operations if possible, but adding a new extensibility pattern just for them wouldn't be the end of the world, either. Determining the ATen operations at interpretation time might be the thing to do, vs. making it a later transform. |
IMHO it doesn't sound like a legit usage/combination of tensor wrapper subclasses with
What's the metadata here? For an instance of tensor wrapper subclass with |
I understand that PyTorch doesn't support this feature, but I think we will want to, especially for important subclasses like DTensor. We could also implement DTensor separately from general subclass support, and it would just be a special class that does allow for intercepting the tensor subclass calls?
So when we do |
Currently the grad rule table's key is
I'm not sure how much we should expect that kind of deviation, though. Then what I can do with this pull request is apply the trace transform to the initial trace first, to correct the output of ops involving tensor subclasses, before the forward-backward split. |
Interesting! I thought it would be nice to make them part of the same extensibility point, but I can understand if they're different.
I worry about trying to "fix" the metadata of the ops after the tracing has happened. What if the program conditions on the dtype of such a tensor, for example? Then the interpreter will have recorded an invalid program. I think we will have to decompose into ATen operations to determine the appropriate metadata as the trace is being constructed. |
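A small illustration of that concern (the function and dtypes here are hypothetical):

```python
import torch


def f(x):
    # x: some float8-style wrapper subclass
    y = x + 1  # __torch_dispatch__ may actually produce a plain fp32 tensor
    if y.dtype == torch.float32:
        # If the recorded metadata for y still claims float8, the interpreter
        # records the wrong branch of this conditional.
        return y * 2
    return y
```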
Now this pull request is split into #1583, #1584, #1585, #1591, and #1592.
This targets tensor wrapper subclasses that call `torch.Tensor._make_wrapper_subclass` in their `__new__` and define their own `__torch_dispatch__`. The transform goes through the `BoundSymbol`s of a trace one by one so that we could make a trace as free from actual tensor subclass objects as possible and write out the actual behavior that is defined by `__torch_dispatch__` in a trace.
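For context, a minimal sketch of walking a trace's `BoundSymbol`s, which is the granularity the transform operates at; the toy function is illustrative:

```python
import torch
import thunder


def f(x):
    return x + 1


jf = thunder.jit(f)
jf(torch.randn(3))

# Walk the BoundSymbols of the last computation trace one by one.
trc = thunder.last_traces(jf)[-1]
for bsym in trc.bound_symbols:
    print(bsym.sym.name)
```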