From c84d65d53da4281d31be357abd7920cd3e421e84 Mon Sep 17 00:00:00 2001
From: Po-Sheng Chang
Date: Thu, 31 Oct 2024 07:06:59 +0000
Subject: [PATCH 1/2] Add torch.ops.aten.cumsum.default lowering to ttnn.moreh_cumsum

---
 tests/lowering/misc/test_cumsum.py               | 42 +++++++++++++++++++
 .../passes/lowering/add_data_move_pass.py        |  1 +
 torch_ttnn/passes/lowering/to_tt_pass.py         |  7 ++++
 3 files changed, 50 insertions(+)
 create mode 100644 tests/lowering/misc/test_cumsum.py

diff --git a/tests/lowering/misc/test_cumsum.py b/tests/lowering/misc/test_cumsum.py
new file mode 100644
index 000000000..b979a5835
--- /dev/null
+++ b/tests/lowering/misc/test_cumsum.py
@@ -0,0 +1,42 @@
+import torch
+import torch_ttnn
+import pytest
+import ttnn
+
+from tests.utils import assert_with_pcc
+
+
+class CumsumModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, dim):
+        return torch.ops.aten.cumsum.default(input, dim=dim)
+
+
+@pytest.mark.parametrize(
+    "input_shapes, dim",
+    [
+        ((1, 32), 1),
+    ],
+)
+def test_cumsum(device, input_shapes, dim):
+    m = CumsumModule()
+    inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
+    result_before = m.forward(inputs, dim)
+
+    option = torch_ttnn.TorchTtnnOption(device=device)
+    option.gen_graphviz = False
+
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+
+    result_after = m.forward(inputs, dim)
+    option._out_fx_graphs[0].print_tabular()
+
+    # Check the graph has been rewritten and contains ttnn ops
+    nodes = [node.target for node in option._out_fx_graphs[0].nodes]
+    assert nodes.count(ttnn.moreh_cumsum) == 1
+
+    # Check inference result
+    assert_with_pcc(result_before, result_after, pcc=0.99)
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 0fee1b44a..2942ffc55 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -177,6 +177,7 @@ def is_tt_compute(node) -> bool:
             ttnn.squeeze,
             ttnn.full,
             ttnn.as_tensor,
+            ttnn.moreh_cumsum,
         ]
     )
 
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 87bef6647..5fc7397bc 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -756,6 +756,13 @@ def rewrite_node(node):
         else:
             return None
 
+        if node.target == torch.ops.aten.cumsum.default:
+            tensor, dim = args
+            input_shape = tensor.meta["val"].size()
+            rank = len(input_shape)
+            dim = (dim + rank) % rank
+            return g.call_function(ttnn.moreh_cumsum, (tensor, dim), kwargs)
+
     with g.inserting_before(node):
         new_node = rewrite_node(node)
         if new_node is not None:

From 7b4ec8be56421e77439fb6844a1b12831935b9e9 Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Fri, 1 Nov 2024 02:40:47 +0000
Subject: [PATCH 2/2] Workaround for ttnn.moreh_cumsum

---
 tests/lowering/misc/test_cumsum.py               | 31 ++++++++++++++----
 .../passes/lowering/add_data_move_pass.py        | 21 ++++---------
 torch_ttnn/passes/lowering/to_tt_pass.py         | 18 ++++++++++-
 3 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/tests/lowering/misc/test_cumsum.py b/tests/lowering/misc/test_cumsum.py
index b979a5835..c1a666e68 100644
--- a/tests/lowering/misc/test_cumsum.py
+++ b/tests/lowering/misc/test_cumsum.py
@@ -1,7 +1,6 @@
 import torch
 import torch_ttnn
 import pytest
-import ttnn
 
 from tests.utils import assert_with_pcc
 
@@ -17,7 +16,27 @@ def forward(self, input, dim):
 @pytest.mark.parametrize(
     "input_shapes, dim",
     [
-        ((1, 32), 1),
+        ((1, 32), -1),
+        ((1, 45), -1),
+        ((1, 59), 1),
+        ((1, 5), -1),
+        ((1, 60), 1),
+        ((1, 10), 1),
+        ((4, 32, 32), 0),
+        ((1, 4, 32, 32), 1),
+        ((4, 4, 32, 32), 0),
+        ((1, 23, 40), 1),
+        ((4, 32), 0),
+        pytest.param(
+            (1, 1, 32, 32),
+            3,
+            marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
+        ),
+        pytest.param(
+            (1, 23, 40),
+            2,
+            marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
+        ),
     ],
 )
 def test_cumsum(device, input_shapes, dim):
@@ -25,9 +44,7 @@ def test_cumsum(device, input_shapes, dim):
     m = CumsumModule()
     inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
     result_before = m.forward(inputs, dim)
 
-    option = torch_ttnn.TorchTtnnOption(device=device)
-    option.gen_graphviz = False
-
+    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
     # The compilation is lazy, so we need to run forward once to trigger the compilation
     m = torch.compile(m, backend=torch_ttnn.backend, options=option)
 
@@ -36,7 +53,5 @@ def test_cumsum(device, input_shapes, dim):
     result_after = m.forward(inputs, dim)
     option._out_fx_graphs[0].print_tabular()
 
     # Check the graph has been rewritten and contains ttnn ops
     nodes = [node.target for node in option._out_fx_graphs[0].nodes]
-    assert nodes.count(ttnn.moreh_cumsum) == 1
-
-    # Check inference result
+    assert nodes.count(torch.ops.aten.cumsum.default) == 0
     assert_with_pcc(result_before, result_after, pcc=0.99)
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 2942ffc55..a5de100a0 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -141,11 +141,6 @@ def is_function_call(node) -> bool:
     )
 
 
-def can_be_tilized(node):
-    size = node.meta["val"].size()
-    return len(size) >= 2 and size[-1] % 32 == 0 and size[-2] % 32 == 0
-
-
 # For operations limitations
 # See https://github.com/tenstorrent-metal/tt-metal/blob/main/ttnn/README.md?plain=1#L19
 def is_tt_compute(node) -> bool:
@@ -292,8 +287,9 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
     new_nodes = list()
     with g.inserting_before(dst_node):
         kwargs = {}
+        # TODO(#372): #322 will enable tile layout for more layout change ops
         if (
-            (dst_node.target in TTNN_LAYOUT_CHANGE_OPS and not can_be_tilized(dst_node))
+            dst_node.target in TTNN_LAYOUT_CHANGE_OPS
             or dst_node.target == ttnn.embedding
             or dst_node.target == ttnn.zeros_like
             or dst_node.target == target_wrappers.repeat
@@ -307,9 +303,8 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
         else:
            kwargs["dtype"] = TtnnBfloat16()
 
-        if (is_tt_compute(dst_node) and dst_node.target not in TTNN_LAYOUT_CHANGE_OPS) or (
-            dst_node.target in TTNN_LAYOUT_CHANGE_OPS and HasValidPageSize(src_node.meta["val"].size(), strict=True)
-        ):
+        # TODO(#372): #322 will enable device tensor for more layout change ops
+        if is_tt_compute(dst_node) and dst_node.target not in TTNN_LAYOUT_CHANGE_OPS:
             kwargs["device"] = device
 
         new_nodes.append(g.call_function(ttnn.from_torch, (src_node,), kwargs))
@@ -324,12 +319,8 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node) -> torch.fx.n
         return None
     if not is_function_call(dst_node):
         return None
-    if (
-        dst_node.target not in TTNN_LAYOUT_CHANGE_OPS
-        or dst_idx != 0
-        or not is_tt(src_node)
-        or (dst_node.target in TTNN_LAYOUT_CHANGE_OPS and can_be_tilized(dst_node))
-    ):
+    # TODO(#372): #322 will enable tile layout for more layout change ops
+    if dst_node.target not in TTNN_LAYOUT_CHANGE_OPS or dst_idx != 0 or not is_tt(src_node):
         return None
 
     g = dst_node.graph
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 5fc7397bc..e01356c90 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -760,8 +760,24 @@ def rewrite_node(node):
             tensor, dim = args
             input_shape = tensor.meta["val"].size()
             rank = len(input_shape)
+            if rank > 4:
+                return None
             dim = (dim + rank) % rank
-            return g.call_function(ttnn.moreh_cumsum, (tensor, dim), kwargs)
+            # Unsqueeze input tensor to 4D for cumsum
+            # TODO(#367): Special-case when dim is one of the inner-most 2 dims: unsqueeze (x, y) to (x, y, 1, 1) since cumsum currently only supports N and C
+            if (dim - rank) >= -2:
+                if rank <= 2:
+                    input_4d_shape = (1,) * (2 - rank) + (*input_shape, 1, 1)
+                elif rank == 3 and dim == 1:
+                    input_4d_shape = (*input_shape, 1)
+                else:
+                    return None
+            else:
+                input_4d_shape = (1,) * (4 - rank) + input_shape
+                dim += 4 - rank
+            input_4d = g.call_function(ttnn.reshape, (tensor, input_4d_shape))
+            output_4d = g.call_function(ttnn.moreh_cumsum, (input_4d, dim), kwargs)
+            return g.call_function(ttnn.reshape, (output_4d, input_shape))
 
     with g.inserting_before(node):
         new_node = rewrite_node(node)
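
Note on the workaround in PATCH 2/2: the snippet below is a standalone, PyTorch-only sketch of the dim/shape bookkeeping the aten.cumsum rewrite performs before calling ttnn.moreh_cumsum. It is not part of the patches; torch.cumsum stands in for ttnn.moreh_cumsum, and the file/function names (cumsum_4d_sketch.py, cumsum_via_4d) are hypothetical. It only demonstrates that reshaping to 4D so the cumsum dimension lands on N or C, and reshaping back, preserves the result for the supported cases.

# cumsum_4d_sketch.py -- hypothetical file name, not part of the patch.
import torch


def cumsum_via_4d(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """Mirror the shape/dim mapping of the aten.cumsum -> ttnn.moreh_cumsum rewrite."""
    input_shape = tuple(tensor.size())
    rank = len(input_shape)
    if rank > 4:
        raise NotImplementedError("rank > 4 is not lowered")
    dim = (dim + rank) % rank
    if (dim - rank) >= -2:
        # dim falls on one of the two inner-most dims: append trailing 1s so the
        # cumsum dim becomes the N or C axis of the 4D view (see TODO(#367)).
        if rank <= 2:
            input_4d_shape = (1,) * (2 - rank) + input_shape + (1, 1)
        elif rank == 3 and dim == 1:
            input_4d_shape = input_shape + (1,)
        else:
            raise NotImplementedError("inner-most 2 dims are not supported (#367)")
    else:
        # dim already maps onto N or C once leading 1s pad the shape to 4D.
        input_4d_shape = (1,) * (4 - rank) + input_shape
        dim += 4 - rank
    # torch.cumsum stands in for ttnn.moreh_cumsum on the reshaped 4D view.
    output_4d = torch.cumsum(tensor.reshape(input_4d_shape), dim)
    return output_4d.reshape(input_shape)


if __name__ == "__main__":
    for shape, dim in [((1, 32), -1), ((4, 32), 0), ((4, 32, 32), 0), ((1, 4, 32, 32), 1), ((1, 23, 40), 1)]:
        x = torch.rand(shape)
        assert torch.allclose(cumsum_via_4d(x, dim), torch.cumsum(x, dim))

Running the new test file (e.g. pytest tests/lowering/misc/test_cumsum.py) exercises the same shape/dim pairs end-to-end through the TTNN backend; that requires a TT device and the repo's device fixture.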