Merge pull request #2641 from pytorch/attention_converter_cherry_pick
cherry-pick: Attention converter and linting fixes
gs-olive authored Feb 5, 2024
2 parents c189b4c + 4d11385 commit 5eb323f
Showing 7 changed files with 271 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -47,7 +47,7 @@ repos:
    hooks:
      - id: ruff
  - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 24.1.1
    hooks:
      - id: black
        exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2243,7 +2243,14 @@ def tensorrt_scaled_dot_product_attention(
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.attention.scaled_dot_product_attention(
-        ctx, target, SourceIR.TORCHTRT_LOWERED, name, args[0], args[1], args[2]
+        ctx,
+        target,
+        SourceIR.TORCHTRT_LOWERED,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        kwargs.get("scale", None),
    )


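Note on the kwargs lookup above: in both the ATen attention ops and torch.nn.functional.scaled_dot_product_attention, scale is keyword-only and defaults to None, so it reaches the converter through the FX node's kwargs rather than its positional args, and kwargs.get("scale", None) preserves the default 1/sqrt(head_dim) behavior when no explicit scale was traced. A minimal FX sketch, independent of Torch-TensorRT (the module name M is invented for illustration), showing what such a node carries:

import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=0.25)


gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.target == torch.nn.functional.scaled_dot_product_attention:
        print(node.args)    # (q, k, v) placeholder nodes -> args[0], args[1], args[2]
        print(node.kwargs)  # {'scale': 0.25}; the key is absent when no scale is passed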
29 changes: 20 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -17,6 +17,7 @@ def scaled_dot_product_attention(
    query: TRTTensor,
    key: TRTTensor,
    value: TRTTensor,
+    scale: Optional[float],
) -> TRTTensor:
    mm = impl.matmul.matrix_multiply(
        ctx,
@@ -27,16 +28,26 @@ def scaled_dot_product_attention(
        key,
        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
    )
-    div = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        name + "_scale",
-        mm,
-        math.sqrt(query.shape[-1]),
-    )
+    if scale is None:
+        scaled = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            math.sqrt(query.shape[-1]),
+        )
+    else:
+        scaled = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            scale,
+        )
    softmax = impl.normalization.softmax(
-        ctx, target, source_ir, name + "_softmax", div, -1
+        ctx, target, source_ir, name + "_softmax", scaled, -1
    )
    out = impl.matmul.matrix_multiply(
        ctx,
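For reference, the branch above mirrors the standard scaled dot product attention convention: the Q·Kᵀ scores are divided by sqrt(head_dim) when scale is None and multiplied by the explicit scale otherwise, followed by a softmax over the last dimension and a matmul with V. A plain-PyTorch sketch of the same math, with no TensorRT, masking, or dropout (assumes a PyTorch build whose F.scaled_dot_product_attention accepts the scale keyword, i.e. 2.1+):

import math

import torch
import torch.nn.functional as F


def sdpa_reference(q, k, v, scale=None):
    # Same two branches as the converter: default 1/sqrt(E) scaling ("div") or an
    # explicit multiplicative scale ("mul"), then softmax and a matmul with v.
    scores = q @ k.transpose(-2, -1)
    if scale is None:
        scores = scores / math.sqrt(q.shape[-1])
    else:
        scores = scores * scale
    return torch.softmax(scores, dim=-1) @ v


q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
for s in (None, 0.15):
    torch.testing.assert_close(
        sdpa_reference(q, k, v, scale=s),
        F.scaled_dot_product_attention(q, k, v, scale=s),
    )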
@@ -5,8 +5,8 @@

from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
-from .lower_efficient_attention import lower_efficient_attention
from .lower_linear import lower_linear
+from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
@@ -18,7 +18,7 @@
        remove_input_alias_fixing_clones,
        constant_fold,
        repair_input_as_output,
-        lower_efficient_attention,
+        lower_scaled_dot_product_attention,
        lower_linear,
        fuse_prims_broadcast,
        replace_max_pool_with_indices,

py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py: This file was deleted.

123 changes: 123 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
@@ -0,0 +1,123 @@
import logging
import operator
from typing import Callable, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)
REPLACEABLE_ATEN_OPS = {
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}


def lower_scaled_dot_product_attention(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Replace specific versions of scaled_dot_product_attention with an equivalent
    implementation which can be easily converted to TRT
    """
    original_fns, replacement = scaled_dot_product_attention_replacement()
    replaced_nodes = []

    # For each original function, search for it in the graph and replace
    for original in original_fns:
        replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
            gm,
            original,
            replacement,
            ignore_literals=True,
        )

    if replaced_nodes:
        # Repair instances which use the kwargs field (specifically the "scale" kwarg)
        for match in replaced_nodes:
            attention_node_replaced = None
            # Seek the attention operator being replaced
            for node in match.nodes_map:
                if node.target in REPLACEABLE_ATEN_OPS:
                    attention_node_replaced = match.nodes_map[node]
                    break

            assert attention_node_replaced is not None

            # If the attention operator had keyword-args, copy them to the new node
            if attention_node_replaced.kwargs:
                assert len(match.replacements) == 1
                new_attention_node = match.replacements[0]
                assert (
                    new_attention_node.target
                    == torch.nn.functional.scaled_dot_product_attention
                )
                new_attention_node.kwargs = {**attention_node_replaced.kwargs}

        gm = clean_up_graph_after_modifications(gm)
        logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")

    return gm


def scaled_dot_product_attention_replacement() -> Tuple[
    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]:
"""Constructs the original and replacement functions for efficient attention"""

# Efficient Attention original graph
def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
q,
k,
v,
None,
False,
)
out = operator.getitem(outputs, 0)
return out

# Flash Attention original graph
def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
q,
k,
v,
)
out = operator.getitem(outputs, 0)
return out

# Efficient Attention w/Scale original graph
def efficient_scale(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
q,
k,
v,
None,
False,
scale=1.0,
)
out = operator.getitem(outputs, 0)
return out

# Flash Attention w/Scale original graph
def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
q,
k,
v,
scale=1.0,
)
out = operator.getitem(outputs, 0)
return out

# Replacement graph consists of the functional version of scaled_dot_product_attention
def replacement(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.scaled_dot_product_attention(query, key, value)

return (efficient, flash, efficient_scale, flash_scale), replacement
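The pass above leans on torch.fx.subgraph_rewriter: each original callable is traced into a pattern graph, every occurrence of that pattern in gm is swapped for the traced replacement graph, and the kwargs-repair loop then re-attaches any scale keyword that the pattern match ignored. As a self-contained illustration of that mechanism only (a toy matmul/softmax pattern with invented names Toy, pattern, and replacement; it uses the simpler replace_pattern API rather than replace_pattern_with_filters and makes no claim of numerical equivalence):

import torch
import torch.fx


# Toy pattern: attention-like math spelled out with basic ops.
def pattern(q, k, v):
    scores = torch.matmul(q, k.transpose(-2, -1))
    return torch.matmul(torch.softmax(scores, dim=-1), v)


# Replacement: a single fused call, analogous to what the lowering pass inserts.
def replacement(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


class Toy(torch.nn.Module):
    def forward(self, q, k, v):
        scores = torch.matmul(q, k.transpose(-2, -1))
        return torch.matmul(torch.softmax(scores, dim=-1), v)


gm = torch.fx.symbolic_trace(Toy())
matches = torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(len(matches))  # 1: the matmul/softmax/matmul subgraph was matched and rewritten
print(gm.code)       # forward now contains a single scaled_dot_product_attention call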
117 changes: 117 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -267,6 +267,123 @@ def forward(self, q, k, v):
        torch._dynamo.reset()


class TestLowerFlashAttention(TestCase):
    def test_lower_flash_attention(self):
        class FlashAttention(torch.nn.Module):
            def forward(self, q, k, v):
                attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
                    q,
                    k,
                    v,
                    scale=0.15,
                )
                return attn[0]

        inputs = [
            torch.rand(8, 4, 16, 8).half().cuda(),
            torch.rand(8, 4, 16, 8).half().cuda(),
            torch.rand(8, 4, 16, 8).half().cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(FlashAttention())
        expected_ops = {torch.nn.functional.scaled_dot_product_attention}
        unexpected_ops = {torch.ops.aten._scaled_dot_product_flash_attention.default}

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEquals(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEquals(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        # Remove 1 decimal from the requirement for FP16
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT - 1,
            msg=f"FlashAttention TRT outputs don't match with the original model.",
        )
        torch._dynamo.reset()

    def test_flash_attention_converter(self):
        class FlashAttention(torch.nn.Module):
            def forward(self, q, k, v):
                attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
                    q,
                    k,
                    v,
                    scale=0.25,
                )
                return attn[0]

        inputs = [
            torch.rand(1, 3, 6, 8).half().cuda(),
            torch.rand(1, 3, 2, 8).half().cuda(),
            torch.rand(1, 3, 2, 8).half().cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(FlashAttention())

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        # Remove 1 decimal from the requirement for FP16
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT - 1,
            msg=f"FlashAttention TRT outputs don't match with the original model.",
        )
        torch._dynamo.reset()


class TestLowerLinear(TestCase):
    def test_lower_linear(self):
        class Linear(torch.nn.Module):
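End to end, the new lowering pass plus converter mean a model that traces to the flash/efficient attention ATen ops, with or without an explicit scale, can go through the Torch-TensorRT torch_compile path, as the tests above exercise. A hedged usage sketch mirroring that setup (the module name Attn is invented; requires a CUDA GPU with TensorRT and torch_tensorrt installed):

import torch
import torch_tensorrt


class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=0.15)


inputs = [torch.rand(8, 4, 16, 8).half().cuda() for _ in range(3)]
trt_model = torch_tensorrt.compile(
    Attn().eval().cuda().half(),
    "torch_compile",   # same IR string the tests above use
    inputs,
    min_block_size=1,
)
print(trt_model(*inputs).shape)  # torch.Size([8, 4, 16, 8])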
