Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try lowering aten._scaled_dot_product_flash_attention #569

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions tests/lowering/misc/test_scaled_dot_product_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import torch
import torch_ttnn
import ttnn

from tests.utils import assert_with_pcc


class ScaledDotProductAttentionModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs)


@pytest.mark.parametrize(
"input_shape, is_causal",
(
((1, 16, 197, 64), False),
((1, 12, 197, 64), False),
((1, 16, 50, 64), False),
((1, 8, 4096, 40), False),
((1, 8, 1024, 80), False),
((1, 8, 256, 160), False),
((1, 8, 64, 160), False),
((1, 12, 50, 64), False),
((1, 16, 1370, 80), False),
((1, 12, 1, 64), False),
((1, 12, 4, 64), True),
Copy link
Contributor Author

@jdh8 jdh8 Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inferred 0 batch size 🤔

FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape0-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 16, 197, 64] vs list(actual_pytorch_result.shape)=[0, 16, 197, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape1-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 12, 197, 64] vs list(actual_pytorch_result.shape)=[0, 12, 197, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape2-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 16, 50, 64] vs list(actual_pytorch_result.shape)=[0, 16, 50, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape3-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 8, 4096, 40] vs list(actual_pytorch_result.shape)=[0, 8, 4096, 40]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape4-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 8, 1024, 80] vs list(actual_pytorch_result.shape)=[0, 8, 1024, 80]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape5-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 8, 256, 160] vs list(actual_pytorch_result.shape)=[0, 8, 256, 160]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape6-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 8, 64, 160] vs list(actual_pytorch_result.shape)=[0, 8, 64, 160]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape7-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 12, 50, 64] vs list(actual_pytorch_result.shape)=[0, 12, 50, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape8-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 16, 1370, 80] vs list(actual_pytorch_result.shape)=[0, 16, 1370, 80]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape9-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 12, 1, 64] vs list(actual_pytorch_result.shape)=[0, 12, 1, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape10-True] - AssertionError: list(expected_pytorch_result.shape)=[1, 12, 4, 64] vs list(actual_pytorch_result.shape)=[0, 12, 4, 64]
=================================================================== 11 failed in 14.22s ====================================================================
                 Device | INFO     | Closing user mode device drivers
    def assert_with_pcc(expected_pytorch_result, actual_pytorch_result, pcc=0.999):
>       assert list(expected_pytorch_result.shape) == list(
            actual_pytorch_result.shape
        ), f"list(expected_pytorch_result.shape)={list(expected_pytorch_result.shape)} vs list(actual_pytorch_result.shape)={list(actual_pytorch_result.shape)}"
E       AssertionError: list(expected_pytorch_result.shape)=[1, 12, 4, 64] vs list(actual_pytorch_result.shape)=[0, 12, 4, 64]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still an issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just confirmed that this issue still lingers. I filed tenstorrent/tt-metal#16021 to keep track on this.

),
)
def test_sdpa(device, input_shape, is_causal):
module = ScaledDotProductAttentionModule()
query = torch.rand(input_shape, dtype=torch.bfloat16)
key = torch.rand(input_shape, dtype=torch.bfloat16)
value = torch.rand(input_shape, dtype=torch.bfloat16)
result_before = module.forward(query, key, value, is_causal=is_causal)

option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
# The compilation is lazy, so we need to run forward once to trigger the compilation
module = torch.compile(module, backend=torch_ttnn.backend, options=option)
result_after = module.forward(query, key, value, is_causal=is_causal)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
nodes = [node.target for node in option._out_fx_graphs[0].nodes]
assert torch.ops.aten._scaled_dot_product_flash_attention.default not in nodes
assert nodes.count(ttnn.transformer.scaled_dot_product_attention) == 1
assert_with_pcc(result_before, result_after)
1 change: 1 addition & 0 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def is_tt_compute(node) -> bool:
ttnn.sum,
ttnn.typecast,
ttnn.argmax,
ttnn.transformer.scaled_dot_product_attention,
]
)

Expand Down
22 changes: 0 additions & 22 deletions torch_ttnn/passes/lowering/to_tt_guard_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,6 @@
["Tensor<[1, 16, 1, 60]> self = ?", "Tensor<[]> other = ?"],
]
aten__log_softmax_default_blocklist = [["Tensor<[19, 256008]> self = ?", "int dim = 1", "bool half_to_float = False"]]
aten__scaled_dot_product_flash_attention_default_blocklist = [
["Tensor<[1, 16, 197, 64]> query = ?", "Tensor<[1, 16, 197, 64]> key = ?", "Tensor<[1, 16, 197, 64]> value = ?"],
["Tensor<[1, 12, 197, 64]> query = ?", "Tensor<[1, 12, 197, 64]> key = ?", "Tensor<[1, 12, 197, 64]> value = ?"],
["Tensor<[1, 16, 50, 64]> query = ?", "Tensor<[1, 16, 50, 64]> key = ?", "Tensor<[1, 16, 50, 64]> value = ?"],
["Tensor<[1, 8, 4096, 40]> query = ?", "Tensor<[1, 8, 4096, 40]> key = ?", "Tensor<[1, 8, 4096, 40]> value = ?"],
["Tensor<[1, 8, 1024, 80]> query = ?", "Tensor<[1, 8, 9, 80]> key = ?", "Tensor<[1, 8, 9, 80]> value = ?"],
["Tensor<[1, 8, 256, 160]> query = ?", "Tensor<[1, 8, 256, 160]> key = ?", "Tensor<[1, 8, 256, 160]> value = ?"],
["Tensor<[1, 8, 64, 160]> query = ?", "Tensor<[1, 8, 64, 160]> key = ?", "Tensor<[1, 8, 64, 160]> value = ?"],
["Tensor<[1, 12, 50, 64]> query = ?", "Tensor<[1, 12, 50, 64]> key = ?", "Tensor<[1, 12, 50, 64]> value = ?"],
["Tensor<[1, 16, 1370, 80]> query = ?", "Tensor<[1, 16, 1370, 80]> key = ?", "Tensor<[1, 16, 1370, 80]> value = ?"],
["Tensor<[1, 12, 1, 64]> query = ?", "Tensor<[1, 12, 1, 64]> key = ?", "Tensor<[1, 12, 1, 64]> value = ?"],
[
"Tensor<[1, 12, 4, 64]> query = ?",
"Tensor<[1, 12, 4, 64]> key = ?",
"Tensor<[1, 12, 4, 64]> value = ?",
"float dropout_p = 0.0",
"bool is_causal = True",
],
]
aten_div_Tensor_blocklist = [
["Tensor<[]> self = ?", "Tensor<[]> other = ?"],
["Tensor<[1, 23, 40, 1]> self = ?", "Tensor<[128]> other = ?"],
Expand Down Expand Up @@ -413,9 +394,6 @@ def guard_aten(blocklist, node):
torch.ops.aten.clamp.default: partial(guard_aten, aten_clamp_default_blocklist),
torch.ops.aten.maximum.default: partial(guard_aten, aten_maximum_default_blocklist),
torch.ops.aten._log_softmax.default: partial(guard_aten, aten__log_softmax_default_blocklist),
torch.ops.aten._scaled_dot_product_flash_attention.default: partial(
guard_aten, aten__scaled_dot_product_flash_attention_default_blocklist
),
torch.ops.aten.div.Tensor: partial(guard_aten, aten_div_Tensor_blocklist),
torch.ops.aten.native_layer_norm.default: partial(guard_aten, aten_native_layer_norm_default_blocklist),
torch.ops.aten.exp.default: partial(guard_aten, aten_exp_default_blocklist),
Expand Down
45 changes: 45 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,21 @@ def __init__(self, target, args, kwargs):
# TODO: no ttnn op can convert _adaptive_avg_pool2d
return self.call_function_prop_meta(target, args, kwargs)

if target == torch.ops.aten._scaled_dot_product_flash_attention.default:

def select(dropout_p=0.0, is_causal=False):
# TODO(jdh8): Add suuport for training mode
if dropout_p > 0.0:
return self.call_function_prop_meta(target, args, kwargs)

return self.call_function_prop_meta(
ttnn.transformer.scaled_dot_product_attention,
args[:3],
{"is_causal": is_causal},
)

return select(*args[3:])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to fire a ticket for unsupported cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm marking this issue as a feature request tenstorrent/tt-metal#16022. No input variation now has nonzero dropout_p yet. It's still good to keep an eye.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this logic drop the attention mask, which must be provided if is_causal == False

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aten._scaled_dot_product_flash_attention does not provide attention mask as far as I know. I have not yet found better documentation. Please correct me if I am wrong.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I am unfamiliar with the aten API. My understanding of the op comes from the functional API https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html


return self.call_function_prop_meta(target, args, kwargs)


Expand Down Expand Up @@ -1149,6 +1164,36 @@ def reshape_1d(code, args=args, kwargs=kwargs):
input_shape = list(tensor.meta["val"].size())
return g.call_function(target_wrappers.roll, (tensor, input_shape, shifts, dims))

if node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
query, key, value = args
query_shape = query.meta["val"].size()
key_shape = key.meta["val"].size()
value_shape = value.meta["val"].size()

attn_mask = kwargs.get("attn_mask")
dropout_p = kwargs.get("dropout_p", 0.0)
scale = kwargs.get("scale", 1.0 / math.sqrt(query_shape[-1]))

if kwargs.get("is_causal", False):
attn_mask = torch.ones(query_shape[-2], key_shape[-2], dtype=torch.bool).tril()

key_perm = [*range(len(key_shape))]
key_perm[-2], key_perm[-1] = key_perm[-1], key_perm[-2]
key = g.call_function(ttnn.permute, (key, key_perm))

attn_weight = g.call_function(ttnn.matmul, (query, key))
attn_weight = g.call_function(ttnn.mul, (attn_weight, scale))

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_weight = g.call_function(ttnn.where, (attn_mask, attn_weight, -math.inf))
else:
attn_weight = g.call_function(ttnn.add, (attn_weight, attn_mask))

attn_weight = g.call_function(ttnn.softmax, (attn_weight,), {"dim": -1, "numeric_stable": True})
attn_weight = g.call_function(ttnn.dropout, (attn_weight,), {"p": dropout_p})
return g.call_function(ttnn.matmul, (attn_weight, value))

# PEP 8 suggests this explicit statement
return None

Expand Down
Loading