diff --git a/tests/lowering/misc/test_scaled_dot_product_attention.py b/tests/lowering/misc/test_scaled_dot_product_attention.py
new file mode 100644
index 000000000..83e714aee
--- /dev/null
+++ b/tests/lowering/misc/test_scaled_dot_product_attention.py
@@ -0,0 +1,50 @@
+import pytest
+import torch
+import torch_ttnn
+import ttnn
+
+from tests.utils import assert_with_pcc
+
+
+class ScaledDotProductAttentionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, *args, **kwargs):
+        return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs)
+
+
+@pytest.mark.parametrize(
+    "input_shape, is_causal",
+    (
+        ((1, 16, 197, 64), False),
+        ((1, 12, 197, 64), False),
+        ((1, 16, 50, 64), False),
+        ((1, 8, 4096, 40), False),
+        ((1, 8, 1024, 80), False),
+        ((1, 8, 256, 160), False),
+        ((1, 8, 64, 160), False),
+        ((1, 12, 50, 64), False),
+        ((1, 16, 1370, 80), False),
+        ((1, 12, 1, 64), False),
+        ((1, 12, 4, 64), True),
+    ),
+)
+def test_sdpa(device, input_shape, is_causal):
+    module = ScaledDotProductAttentionModule()
+    query = torch.rand(input_shape, dtype=torch.bfloat16)
+    key = torch.rand(input_shape, dtype=torch.bfloat16)
+    value = torch.rand(input_shape, dtype=torch.bfloat16)
+    result_before = module.forward(query, key, value, is_causal=is_causal)
+
+    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    module = torch.compile(module, backend=torch_ttnn.backend, options=option)
+    result_after = module.forward(query, key, value, is_causal=is_causal)
+    option._out_fx_graphs[0].print_tabular()
+
+    # Check that the graph has been rewritten and contains ttnn ops
+    nodes = [node.target for node in option._out_fx_graphs[0].nodes]
+    assert torch.ops.aten._scaled_dot_product_flash_attention.default not in nodes
+    assert nodes.count(ttnn.transformer.scaled_dot_product_attention) == 1
+    assert_with_pcc(result_before, result_after)
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index c55797545..398ae6d0e 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -199,6 +199,7 @@ def is_tt_compute(node) -> bool:
             ttnn.sum,
             ttnn.typecast,
             ttnn.argmax,
+            ttnn.transformer.scaled_dot_product_attention,
         ]
     )
diff --git a/torch_ttnn/passes/lowering/to_tt_guard_autogen.py b/torch_ttnn/passes/lowering/to_tt_guard_autogen.py
index e2f6dd8ba..f5598f766 100644
--- a/torch_ttnn/passes/lowering/to_tt_guard_autogen.py
+++ b/torch_ttnn/passes/lowering/to_tt_guard_autogen.py
@@ -62,25 +62,6 @@
     ["Tensor<[1, 16, 1, 60]> self = ?", "Tensor<[]> other = ?"],
 ]
 aten__log_softmax_default_blocklist = [["Tensor<[19, 256008]> self = ?", "int dim = 1", "bool half_to_float = False"]]
-aten__scaled_dot_product_flash_attention_default_blocklist = [
-    ["Tensor<[1, 16, 197, 64]> query = ?", "Tensor<[1, 16, 197, 64]> key = ?", "Tensor<[1, 16, 197, 64]> value = ?"],
-    ["Tensor<[1, 12, 197, 64]> query = ?", "Tensor<[1, 12, 197, 64]> key = ?", "Tensor<[1, 12, 197, 64]> value = ?"],
-    ["Tensor<[1, 16, 50, 64]> query = ?", "Tensor<[1, 16, 50, 64]> key = ?", "Tensor<[1, 16, 50, 64]> value = ?"],
-    ["Tensor<[1, 8, 4096, 40]> query = ?", "Tensor<[1, 8, 4096, 40]> key = ?", "Tensor<[1, 8, 4096, 40]> value = ?"],
-    ["Tensor<[1, 8, 1024, 80]> query = ?", "Tensor<[1, 8, 9, 80]> key = ?", "Tensor<[1, 8, 9, 80]> value = ?"],
-    ["Tensor<[1, 8, 256, 160]> query = ?", "Tensor<[1, 8, 256, 160]> key = ?", "Tensor<[1, 8, 256, 160]> value = ?"],
-    ["Tensor<[1, 8, 64, 160]> query = ?", "Tensor<[1, 8, 64, 160]> key = ?", "Tensor<[1, 8, 64, 160]> value = ?"],
-    ["Tensor<[1, 12, 50, 64]> query = ?", "Tensor<[1, 12, 50, 64]> key = ?", "Tensor<[1, 12, 50, 64]> value = ?"],
-    ["Tensor<[1, 16, 1370, 80]> query = ?", "Tensor<[1, 16, 1370, 80]> key = ?", "Tensor<[1, 16, 1370, 80]> value = ?"],
-    ["Tensor<[1, 12, 1, 64]> query = ?", "Tensor<[1, 12, 1, 64]> key = ?", "Tensor<[1, 12, 1, 64]> value = ?"],
-    [
-        "Tensor<[1, 12, 4, 64]> query = ?",
-        "Tensor<[1, 12, 4, 64]> key = ?",
-        "Tensor<[1, 12, 4, 64]> value = ?",
-        "float dropout_p = 0.0",
-        "bool is_causal = True",
-    ],
-]
 aten_div_Tensor_blocklist = [
     ["Tensor<[]> self = ?", "Tensor<[]> other = ?"],
     ["Tensor<[1, 23, 40, 1]> self = ?", "Tensor<[128]> other = ?"],
@@ -413,9 +394,6 @@ def guard_aten(blocklist, node):
     torch.ops.aten.clamp.default: partial(guard_aten, aten_clamp_default_blocklist),
     torch.ops.aten.maximum.default: partial(guard_aten, aten_maximum_default_blocklist),
     torch.ops.aten._log_softmax.default: partial(guard_aten, aten__log_softmax_default_blocklist),
-    torch.ops.aten._scaled_dot_product_flash_attention.default: partial(
-        guard_aten, aten__scaled_dot_product_flash_attention_default_blocklist
-    ),
     torch.ops.aten.div.Tensor: partial(guard_aten, aten_div_Tensor_blocklist),
     torch.ops.aten.native_layer_norm.default: partial(guard_aten, aten_native_layer_norm_default_blocklist),
     torch.ops.aten.exp.default: partial(guard_aten, aten_exp_default_blocklist),
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 0117f323d..999c00c74 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -323,6 +323,21 @@ def __init__(self, target, args, kwargs):
             # TODO: no ttnn op can convert _adaptive_avg_pool2d
             return self.call_function_prop_meta(target, args, kwargs)

+        if target == torch.ops.aten._scaled_dot_product_flash_attention.default:
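+            # Only inference-mode SDPA is lowered to the fused ttnn op;
+            # calls that request dropout (training) keep the original aten op for now.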
+
+            def select(dropout_p=0.0, is_causal=False):
+                # TODO(jdh8): Add support for training mode
+                if dropout_p > 0.0:
+                    return self.call_function_prop_meta(target, args, kwargs)
+
+                return self.call_function_prop_meta(
+                    ttnn.transformer.scaled_dot_product_attention,
+                    args[:3],
+                    {"is_causal": is_causal},
+                )
+
+            return select(*args[3:])
+
         return self.call_function_prop_meta(target, args, kwargs)
@@ -1149,6 +1164,36 @@ def reshape_1d(code, args=args, kwargs=kwargs):
             input_shape = list(tensor.meta["val"].size())
             return g.call_function(target_wrappers.roll, (tensor, input_shape, shifts, dims))

+        if node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
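+            # Decompose SDPA into primitive ttnn ops: scores = Q @ K^T * scale,
+            # then masking (boolean mask -> where, additive mask -> add), softmax,
+            # dropout, and a final matmul with V. is_causal builds a lower-triangular
+            # boolean mask over the query/key sequence dimensions.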
+            query, key, value = args
+            query_shape = query.meta["val"].size()
+            key_shape = key.meta["val"].size()
+            value_shape = value.meta["val"].size()
+
+            attn_mask = kwargs.get("attn_mask")
+            dropout_p = kwargs.get("dropout_p", 0.0)
+            scale = kwargs.get("scale", 1.0 / math.sqrt(query_shape[-1]))
+
+            if kwargs.get("is_causal", False):
+                attn_mask = torch.ones(query_shape[-2], key_shape[-2], dtype=torch.bool).tril()
+
+            key_perm = [*range(len(key_shape))]
+            key_perm[-2], key_perm[-1] = key_perm[-1], key_perm[-2]
+            key = g.call_function(ttnn.permute, (key, key_perm))
+
+            attn_weight = g.call_function(ttnn.matmul, (query, key))
+            attn_weight = g.call_function(ttnn.mul, (attn_weight, scale))
+
+            if attn_mask is not None:
+                if attn_mask.dtype == torch.bool:
+                    attn_weight = g.call_function(ttnn.where, (attn_mask, attn_weight, -math.inf))
+                else:
+                    attn_weight = g.call_function(ttnn.add, (attn_weight, attn_mask))
+
+            attn_weight = g.call_function(ttnn.softmax, (attn_weight,), {"dim": -1, "numeric_stable": True})
+            attn_weight = g.call_function(ttnn.dropout, (attn_weight,), {"p": dropout_p})
+            return g.call_function(ttnn.matmul, (attn_weight, value))
+
         # PEP 8 suggests this explicit statement
         return None