Try lowering aten._scaled_dot_product_flash_attention
#569
base: main
Conversation
ttnn seems to already have the corresponding API https://docs.tenstorrent.com/ttnn/latest/ttnn/api/ttnn.transformer.scaled_dot_product_attention.html, can it be used here?
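A minimal sketch of what such a lowering could look like, assuming the ttnn op accepts (B, H, S, E) query/key/value tensors; the helper name, keyword arguments, and layout choices here are guesses based on the linked docs page, not this repo's actual pass code:

import ttnn

def lower_sdpa_to_ttnn(query, key, value, is_causal=False, scale=None, *, device):
    # Assumption: ttnn matmul-based ops generally want bfloat16 tensors in tile layout.
    q = ttnn.from_torch(query, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    k = ttnn.from_torch(key, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    v = ttnn.from_torch(value, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    out = ttnn.transformer.scaled_dot_product_attention(
        q, k, v, is_causal=is_causal, scale=scale
    )
    return ttnn.to_torch(out)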
@jdh8 please let us know whatever questions you have; I will help connect you with the right stakeholder to resolve this fast.
Force-pushed from 0f79229 to 7df6ee4
((1, 12, 50, 64), False),
((1, 16, 1370, 80), False),
((1, 12, 1, 64), False),
((1, 12, 4, 64), True),
Inferred 0 batch size 🤔
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape0-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 16, 197, 64] vs list(actual_pytorch_result.shape)=[0, 16, 197, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape1-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 12, 197, 64] vs list(actual_pytorch_result.shape)=[0, 12, 197, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape2-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 16, 50, 64] vs list(actual_pytorch_result.shape)=[0, 16, 50, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape3-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 8, 4096, 40] vs list(actual_pytorch_result.shape)=[0, 8, 4096, 40]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape4-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 8, 1024, 80] vs list(actual_pytorch_result.shape)=[0, 8, 1024, 80]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape5-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 8, 256, 160] vs list(actual_pytorch_result.shape)=[0, 8, 256, 160]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape6-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 8, 64, 160] vs list(actual_pytorch_result.shape)=[0, 8, 64, 160]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape7-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 12, 50, 64] vs list(actual_pytorch_result.shape)=[0, 12, 50, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape8-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 16, 1370, 80] vs list(actual_pytorch_result.shape)=[0, 16, 1370, 80]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape9-False] - AssertionError: list(expected_pytorch_result.shape)=[1, 12, 1, 64] vs list(actual_pytorch_result.shape)=[0, 12, 1, 64]
FAILED tests/lowering/misc/test_scaled_dot_product_attention.py::test_sdpa[input_shape10-True] - AssertionError: list(expected_pytorch_result.shape)=[1, 12, 4, 64] vs list(actual_pytorch_result.shape)=[0, 12, 4, 64]
=================================================================== 11 failed in 14.22s ====================================================================
Device | INFO | Closing user mode device drivers
def assert_with_pcc(expected_pytorch_result, actual_pytorch_result, pcc=0.999):
> assert list(expected_pytorch_result.shape) == list(
actual_pytorch_result.shape
), f"list(expected_pytorch_result.shape)={list(expected_pytorch_result.shape)} vs list(actual_pytorch_result.shape)={list(actual_pytorch_result.shape)}"
E AssertionError: list(expected_pytorch_result.shape)=[1, 12, 4, 64] vs list(actual_pytorch_result.shape)=[0, 12, 4, 64]
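The traceback only shows the shape guard of assert_with_pcc; the helper presumably goes on to compare values by Pearson correlation, roughly like this sketch (the continuation below is an assumption, and the real helper in this repo may differ):

import torch

def assert_with_pcc(expected_pytorch_result, actual_pytorch_result, pcc=0.999):
    # Shape guard shown in the traceback above.
    assert list(expected_pytorch_result.shape) == list(
        actual_pytorch_result.shape
    ), f"list(expected_pytorch_result.shape)={list(expected_pytorch_result.shape)} vs list(actual_pytorch_result.shape)={list(actual_pytorch_result.shape)}"
    # Assumed continuation: Pearson correlation over the flattened tensors.
    expected = expected_pytorch_result.flatten().to(torch.float32)
    actual = actual_pytorch_result.flatten().to(torch.float32)
    computed_pcc = torch.corrcoef(torch.stack([expected, actual]))[0, 1].item()
    assert computed_pcc >= pcc, f"PCC {computed_pcc} is below the threshold {pcc}"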
Is this still an issue?
I just confirmed that this issue still lingers. I filed tenstorrent/tt-metal#16021 to keep track of it.
{"is_causal": is_causal}, | ||
) | ||
|
||
return select(*args[3:]) |
Do we need to file a ticket for the unsupported cases?
I filed tenstorrent/tt-metal#16022 as a feature request. No input variation currently has a nonzero dropout_p, but it's still worth keeping an eye on.
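A sketch of the kind of guard that could keep such unsupported variations on the fallback path until the ticket is resolved (the node-matching style is illustrative, not this pass's exact code):

import torch

def should_lower_sdpa(node: torch.fx.Node) -> bool:
    # aten._scaled_dot_product_flash_attention(query, key, value,
    # dropout_p=0.0, is_causal=False, ...): bail out of the lowering
    # whenever dropout_p is nonzero, since the ttnn path has no dropout.
    if node.target != torch.ops.aten._scaled_dot_product_flash_attention.default:
        return False
    dropout_p = node.args[3] if len(node.args) > 3 else node.kwargs.get("dropout_p", 0.0)
    return dropout_p == 0.0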
Doesn't this logic drop the attention mask, which must be provided if is_causal == False?
aten._scaled_dot_product_flash_attention does not take an attention mask as far as I know. I have not yet found better documentation. Please correct me if I am wrong.
I see. I am unfamiliar with the aten API. My understanding of the op comes from the functional API https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
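For what it's worth, the overload schemas can settle this locally: the flash variant exposes dropout_p and is_causal but no mask argument, while the functional front end does accept attn_mask (and picks a non-flash backend when one is supplied). Output format varies by PyTorch version:

import torch
import torch.nn.functional as F

# The flash overload takes (query, key, value, dropout_p, is_causal, ...)
# with no attention-mask argument:
print(torch.ops.aten._scaled_dot_product_flash_attention.default._schema)

# The functional front end accepts a mask; an all-True boolean mask means
# "attend to everything".
q = k = v = torch.randn(1, 8, 16, 64)
mask = torch.ones(16, 16, dtype=torch.bool)
print(F.scaled_dot_product_attention(q, k, v, attn_mask=mask).shape)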
Force-pushed from 7df6ee4 to ea1841f
Force-pushed from ea1841f to 866703f
…ult` by removing its blocklist
Ticket
Problem description
Convert aten._scaled_dot_product_flash_attention to a series of ops. A future goal might be to implement it as a composite kernel op instead. The source op is functionally equivalent to its high-level counterpart: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
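For reference, a decomposition of the kind described above, written against the functional counterpart with dropout_p == 0; the actual lowering targets ttnn ops, but the math is the same:

import math
import torch

def sdpa_decomposed(query, key, value, is_causal=False, scale=None):
    # Scaled dot-product attention as a series of ops; a reference
    # decomposition, not this PR's exact op sequence.
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
    scores = query @ key.transpose(-2, -1) * scale
    if is_causal:
        L, S = query.size(-2), key.size(-2)
        causal = torch.tril(torch.ones(L, S, dtype=torch.bool, device=query.device))
        scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ value

# Quick check against the functional reference:
q, k, v = (torch.randn(1, 12, 4, 64) for _ in range(3))
reference = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(sdpa_decomposed(q, k, v, is_causal=True), reference, atol=1e-5)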
What's changed