-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Try lowering aten._scaled_dot_product_flash_attention
#569
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import pytest | ||
import torch | ||
import torch_ttnn | ||
import ttnn | ||
|
||
from tests.utils import assert_with_pcc | ||
|
||
|
||
class ScaledDotProductAttentionModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, *args, **kwargs): | ||
return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shape, is_causal", | ||
( | ||
((1, 16, 197, 64), False), | ||
((1, 12, 197, 64), False), | ||
((1, 16, 50, 64), False), | ||
((1, 8, 4096, 40), False), | ||
((1, 8, 1024, 80), False), | ||
((1, 8, 256, 160), False), | ||
((1, 8, 64, 160), False), | ||
((1, 12, 50, 64), False), | ||
((1, 16, 1370, 80), False), | ||
((1, 12, 1, 64), False), | ||
((1, 12, 4, 64), True), | ||
), | ||
) | ||
def test_sdpa(device, input_shape, is_causal): | ||
module = ScaledDotProductAttentionModule() | ||
query = torch.rand(input_shape, dtype=torch.bfloat16) | ||
key = torch.rand(input_shape, dtype=torch.bfloat16) | ||
value = torch.rand(input_shape, dtype=torch.bfloat16) | ||
result_before = module.forward(query, key, value, is_causal=is_causal) | ||
|
||
option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False) | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
module = torch.compile(module, backend=torch_ttnn.backend, options=option) | ||
result_after = module.forward(query, key, value, is_causal=is_causal) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = [node.target for node in option._out_fx_graphs[0].nodes] | ||
assert torch.ops.aten._scaled_dot_product_flash_attention.default not in nodes | ||
assert nodes.count(ttnn.transformer.scaled_dot_product_attention) == 1 | ||
assert_with_pcc(result_before, result_after) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -323,6 +323,21 @@ def __init__(self, target, args, kwargs): | |
# TODO: no ttnn op can convert _adaptive_avg_pool2d | ||
return self.call_function_prop_meta(target, args, kwargs) | ||
|
||
if target == torch.ops.aten._scaled_dot_product_flash_attention.default: | ||
|
||
def select(dropout_p=0.0, is_causal=False): | ||
# TODO(jdh8): Add suuport for training mode | ||
if dropout_p > 0.0: | ||
return self.call_function_prop_meta(target, args, kwargs) | ||
|
||
return self.call_function_prop_meta( | ||
ttnn.transformer.scaled_dot_product_attention, | ||
args[:3], | ||
{"is_causal": is_causal}, | ||
) | ||
|
||
return select(*args[3:]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need to fire a ticket for unsupported cases? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm marking this issue as a feature request tenstorrent/tt-metal#16022. No input variation now has nonzero There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't this logic drop the attention mask, which must be provided if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. I am unfamiliar with the aten API. My understanding of the op comes from the functional API https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html |
||
|
||
return self.call_function_prop_meta(target, args, kwargs) | ||
|
||
|
||
|
@@ -1149,6 +1164,36 @@ def reshape_1d(code, args=args, kwargs=kwargs): | |
input_shape = list(tensor.meta["val"].size()) | ||
return g.call_function(target_wrappers.roll, (tensor, input_shape, shifts, dims)) | ||
|
||
if node.target == torch.ops.aten._scaled_dot_product_flash_attention.default: | ||
query, key, value = args | ||
query_shape = query.meta["val"].size() | ||
key_shape = key.meta["val"].size() | ||
value_shape = value.meta["val"].size() | ||
|
||
attn_mask = kwargs.get("attn_mask") | ||
dropout_p = kwargs.get("dropout_p", 0.0) | ||
scale = kwargs.get("scale", 1.0 / math.sqrt(query_shape[-1])) | ||
|
||
if kwargs.get("is_causal", False): | ||
attn_mask = torch.ones(query_shape[-2], key_shape[-2], dtype=torch.bool).tril() | ||
|
||
key_perm = [*range(len(key_shape))] | ||
key_perm[-2], key_perm[-1] = key_perm[-1], key_perm[-2] | ||
key = g.call_function(ttnn.permute, (key, key_perm)) | ||
|
||
attn_weight = g.call_function(ttnn.matmul, (query, key)) | ||
attn_weight = g.call_function(ttnn.mul, (attn_weight, scale)) | ||
|
||
if attn_mask is not None: | ||
if attn_mask.dtype == torch.bool: | ||
attn_weight = g.call_function(ttnn.where, (attn_mask, attn_weight, -math.inf)) | ||
else: | ||
attn_weight = g.call_function(ttnn.add, (attn_weight, attn_mask)) | ||
|
||
attn_weight = g.call_function(ttnn.softmax, (attn_weight,), {"dim": -1, "numeric_stable": True}) | ||
attn_weight = g.call_function(ttnn.dropout, (attn_weight,), {"p": dropout_p}) | ||
return g.call_function(ttnn.matmul, (attn_weight, value)) | ||
|
||
# PEP 8 suggests this explicit statement | ||
return None | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inferred 0 batch size 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this still an issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just confirmed that this issue still lingers. I filed tenstorrent/tt-metal#16021 to keep track on this.