diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4738ea80be..f4ac2ab9e2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -47,7 +47,7 @@ repos:
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 24.1.1
     hooks:
       - id: black
         exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 0592160a4c..69cb6cecb1 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2243,7 +2243,14 @@ def tensorrt_scaled_dot_product_attention(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.attention.scaled_dot_product_attention(
-        ctx, target, SourceIR.TORCHTRT_LOWERED, name, args[0], args[1], args[2]
+        ctx,
+        target,
+        SourceIR.TORCHTRT_LOWERED,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        kwargs.get("scale", None),
     )


diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py
index 6221357ca2..7b8c99fe44 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/attention.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -17,6 +17,7 @@ def scaled_dot_product_attention(
     query: TRTTensor,
     key: TRTTensor,
     value: TRTTensor,
+    scale: Optional[float],
 ) -> TRTTensor:
     mm = impl.matmul.matrix_multiply(
         ctx,
@@ -27,16 +28,26 @@
         key,
         other_matrix_op=trt.MatrixOperation.TRANSPOSE,
     )
-    div = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        name + "_scale",
-        mm,
-        math.sqrt(query.shape[-1]),
-    )
+    if scale is None:
+        scaled = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            math.sqrt(query.shape[-1]),
+        )
+    else:
+        scaled = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            scale,
+        )
     softmax = impl.normalization.softmax(
-        ctx, target, source_ir, name + "_softmax", div, -1
+        ctx, target, source_ir, name + "_softmax", scaled, -1
     )
     out = impl.matmul.matrix_multiply(
         ctx,
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index d6e12f5215..489805cb43 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -5,8 +5,8 @@

 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
-from .lower_efficient_attention import lower_efficient_attention
 from .lower_linear import lower_linear
+from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
@@ -18,7 +18,7 @@
         remove_input_alias_fixing_clones,
         constant_fold,
         repair_input_as_output,
-        lower_efficient_attention,
+        lower_scaled_dot_product_attention,
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py
deleted file mode 100644
index 9bcf38d30c..0000000000
--- a/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import logging
-import operator
-from typing import Callable, Sequence, Tuple
-
-import torch
-from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
-    clean_up_graph_after_modifications,
-)
-
-logger = logging.getLogger(__name__)
-
-
-def lower_efficient_attention(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
-    """Replace a specific version of scaled_dot_product_attention with an equivalent
-    implementation which can be easily converted to TRT
-    """
-    orig, replacement = efficient_attention_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(
-            f"Graph after lowering _scaled_dot_product_efficient_attention:\n{gm.graph}"
-        )
-
-    return gm
-
-
-def efficient_attention_replacement() -> Tuple[
-    torch.fx.GraphModule,
-    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-]:
-    """Constructs the original and replacement functions for efficient attention"""
-
-    # Original graph
-    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
-            q, k, v, None, False
-        )
-        out = operator.getitem(outputs, 0)
-        return out
-
-    # Replacement graph consists of the functional version of scaled_dot_product_attention
-    def replacement(
-        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> torch.Tensor:
-        return torch.nn.functional.scaled_dot_product_attention(query, key, value)
-
-    return orig, replacement
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
new file mode 100644
index 0000000000..161dbbe9df
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
@@ -0,0 +1,123 @@
+import logging
+import operator
+from typing import Callable, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+REPLACEABLE_ATEN_OPS = {
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+}
+
+
+def lower_scaled_dot_product_attention(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace specific versions of scaled_dot_product_attention with an equivalent
+    implementation which can be easily converted to TRT
+    """
+    original_fns, replacement = scaled_dot_product_attention_replacement()
+    replaced_nodes = []
+
+    # For each original function, search for it in the graph and replace
+    for original in original_fns:
+        replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
+            gm,
+            original,
+            replacement,
+            ignore_literals=True,
+        )
+
+    if replaced_nodes:
+        # Repair instances which use the kwargs field (specifically the "scale" kwarg)
+        for match in replaced_nodes:
+            attention_node_replaced = None
+            # Seek the attention operator being replaced
+            for node in match.nodes_map:
+                if node.target in REPLACEABLE_ATEN_OPS:
+                    attention_node_replaced = match.nodes_map[node]
+                    break
+
+            assert attention_node_replaced is not None
+
+            # If the attention operator had keyword-args, copy them to the new node
+            if attention_node_replaced.kwargs:
+                assert len(match.replacements) == 1
+                new_attention_node = match.replacements[0]
+                assert (
+                    new_attention_node.target
+                    == torch.nn.functional.scaled_dot_product_attention
+                )
+                new_attention_node.kwargs = {**attention_node_replaced.kwargs}
+
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
+
+    return gm
+
+
+def scaled_dot_product_attention_replacement() -> Tuple[
+    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
+    """Constructs the original and replacement functions for efficient attention"""
+
+    # Efficient Attention original graph
+    def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q,
+            k,
+            v,
+            None,
+            False,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Flash Attention original graph
+    def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
+            q,
+            k,
+            v,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Efficient Attention w/Scale original graph
+    def efficient_scale(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q,
+            k,
+            v,
+            None,
+            False,
+            scale=1.0,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Flash Attention w/Scale original graph
+    def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
+            q,
+            k,
+            v,
+            scale=1.0,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Replacement graph consists of the functional version of scaled_dot_product_attention
+    def replacement(
+        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> torch.Tensor:
+        return torch.nn.functional.scaled_dot_product_attention(query, key, value)
+
+    return (efficient, flash, efficient_scale, flash_scale), replacement
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index 9070c8373f..f9a84f2db6 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -267,6 +267,123 @@ def forward(self, q, k, v):
         torch._dynamo.reset()

+
+class TestLowerFlashAttention(TestCase):
+    def test_lower_flash_attention(self):
+        class FlashAttention(torch.nn.Module):
+            def forward(self, q, k, v):
+                attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                    q,
+                    k,
+                    v,
+                    scale=0.15,
+                )
+                return attn[0]
+
+        inputs = [
+            torch.rand(8, 4, 16, 8).half().cuda(),
+            torch.rand(8, 4, 16, 8).half().cuda(),
+            torch.rand(8, 4, 16, 8).half().cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(FlashAttention())
+        expected_ops = {torch.nn.functional.scaled_dot_product_attention}
+        unexpected_ops = {torch.ops.aten._scaled_dot_product_flash_attention.default}
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        # Remove 1 decimal from the requirement for FP16
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT - 1,
+            msg=f"FlashAttention TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_flash_attention_converter(self):
+        class FlashAttention(torch.nn.Module):
+            def forward(self, q, k, v):
+                attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                    q,
+                    k,
+                    v,
+                    scale=0.25,
+                )
+                return attn[0]
+
+        inputs = [
+            torch.rand(1, 3, 6, 8).half().cuda(),
+            torch.rand(1, 3, 2, 8).half().cuda(),
+            torch.rand(1, 3, 2, 8).half().cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(FlashAttention())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        # Remove 1 decimal from the requirement for FP16
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT - 1,
+            msg=f"FlashAttention TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+

 class TestLowerLinear(TestCase):
     def test_lower_linear(self):
         class Linear(torch.nn.Module):
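
The scale handling added to impl/attention.py can be sanity-checked outside of TensorRT. The snippet below is a minimal sketch in plain PyTorch and is not part of the patch; the helper name sdpa_reference is illustrative, and it assumes a PyTorch build whose torch.nn.functional.scaled_dot_product_attention accepts the scale keyword (2.1 or newer). It mirrors the decomposition the converter emits (QK^T, then division by sqrt(d_k) when scale is None or multiplication by the explicit scale otherwise, softmax, then multiplication by V) and compares it against the functional op that the new lowering pass rewrites the aten variants into.

import math

import torch
import torch.nn.functional as F


def sdpa_reference(q, k, v, scale=None):
    # QK^T, matching impl.matmul.matrix_multiply with other_matrix_op=TRANSPOSE
    attn = torch.matmul(q, k.transpose(-2, -1))
    if scale is None:
        # Default path: divide by sqrt(d_k), as impl.elementwise.div does in the patch
        attn = attn / math.sqrt(q.shape[-1])
    else:
        # Explicit-scale path: multiply by the provided scale, as impl.elementwise.mul does
        attn = attn * scale
    # Softmax over the last dimension, then multiply by V
    return torch.matmul(torch.softmax(attn, dim=-1), v)


q, k, v = (torch.rand(8, 4, 16, 8) for _ in range(3))
for s in (None, 0.15):
    expected = F.scaled_dot_product_attention(q, k, v, scale=s)
    torch.testing.assert_close(sdpa_reference(q, k, v, scale=s), expected, rtol=1e-4, atol=1e-4)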