diff --git a/tests/lowering/misc/test_cumsum.py b/tests/lowering/misc/test_cumsum.py
index b979a5835..c1a666e68 100644
--- a/tests/lowering/misc/test_cumsum.py
+++ b/tests/lowering/misc/test_cumsum.py
@@ -1,7 +1,6 @@
 import torch
 import torch_ttnn
 import pytest
-import ttnn
 
 from tests.utils import assert_with_pcc
 
@@ -17,7 +16,27 @@ def forward(self, input, dim):
 @pytest.mark.parametrize(
     "input_shapes, dim",
     [
-        ((1, 32), 1),
+        ((1, 32), -1),
+        ((1, 45), -1),
+        ((1, 59), 1),
+        ((1, 5), -1),
+        ((1, 60), 1),
+        ((1, 10), 1),
+        ((4, 32, 32), 0),
+        ((1, 4, 32, 32), 1),
+        ((4, 4, 32, 32), 0),
+        ((1, 23, 40), 1),
+        ((4, 32), 0),
+        pytest.param(
+            (1, 1, 32, 32),
+            3,
+            marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
+        ),
+        pytest.param(
+            (1, 23, 40),
+            2,
+            marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
+        ),
     ],
 )
 def test_cumsum(device, input_shapes, dim):
@@ -25,9 +44,7 @@ def test_cumsum(device, input_shapes, dim):
     inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
     result_before = m.forward(inputs, dim)
 
-    option = torch_ttnn.TorchTtnnOption(device=device)
-    option.gen_graphviz = False
-
+    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
     # The compilation is lazy, so we need to run forward once to trigger the compilation
     m = torch.compile(m, backend=torch_ttnn.backend, options=option)
 
@@ -36,7 +53,5 @@ def test_cumsum(device, input_shapes, dim):
     # Check the graph has be rewritten and contain ttnn ops
     nodes = [node.target for node in option._out_fx_graphs[0].nodes]
-    assert nodes.count(ttnn.moreh_cumsum) == 1
-
-    # Check inference result
+    assert nodes.count(torch.ops.aten.cumsum.default) == 0
     assert_with_pcc(result_before, result_after, pcc=0.99)
 
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 2942ffc55..882a8a304 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -141,11 +141,6 @@ def is_function_call(node) -> bool:
     )
 
 
-def can_be_tilized(node):
-    size = node.meta["val"].size()
-    return len(size) >= 2 and size[-1] % 32 == 0 and size[-2] % 32 == 0
-
-
 # For operations limitations
 # See https://github.com/tenstorrent-metal/tt-metal/blob/main/ttnn/README.md?plain=1#L19
 def is_tt_compute(node) -> bool:
@@ -292,8 +287,9 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
     new_nodes = list()
     with g.inserting_before(dst_node):
         kwargs = {}
+        # TODO(#322): Enable tile layout for more layout change ops
         if (
-            (dst_node.target in TTNN_LAYOUT_CHANGE_OPS and not can_be_tilized(dst_node))
+            dst_node.target in TTNN_LAYOUT_CHANGE_OPS
             or dst_node.target == ttnn.embedding
             or dst_node.target == ttnn.zeros_like
             or dst_node.target == target_wrappers.repeat
@@ -307,9 +303,8 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
         else:
             kwargs["dtype"] = TtnnBfloat16()
 
-        if (is_tt_compute(dst_node) and dst_node.target not in TTNN_LAYOUT_CHANGE_OPS) or (
-            dst_node.target in TTNN_LAYOUT_CHANGE_OPS and HasValidPageSize(src_node.meta["val"].size(), strict=True)
-        ):
+        # TODO(#322): Enable device tensor for more layout change ops
+        if is_tt_compute(dst_node) and dst_node.target not in TTNN_LAYOUT_CHANGE_OPS:
             kwargs["device"] = device
 
         new_nodes.append(g.call_function(ttnn.from_torch, (src_node,), kwargs))
@@ -324,12 +319,8 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node) -> torch.fx.n
         return None
     if not is_function_call(dst_node):
         return None
-    if (
-        dst_node.target not in TTNN_LAYOUT_CHANGE_OPS
-        or dst_idx != 0
-        or not is_tt(src_node)
-        or (dst_node.target in TTNN_LAYOUT_CHANGE_OPS and can_be_tilized(dst_node))
-    ):
+    # TODO(#322): Enable tile layout for more layout change ops
+    if dst_node.target not in TTNN_LAYOUT_CHANGE_OPS or dst_idx != 0 or not is_tt(src_node):
         return None
 
     g = dst_node.graph
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 5fc7397bc..e01356c90 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -760,8 +760,24 @@ def rewrite_node(node):
                tensor, dim = args
                input_shape = tensor.meta["val"].size()
                rank = len(input_shape)
+                if rank > 4:
+                    return None
                dim = (dim + rank) % rank
-                return g.call_function(ttnn.moreh_cumsum, (tensor, dim), kwargs)
+                # Unsqueeze input tensor to 4D for cumsum
+                # TODO(#367): Special-case dims on the inner-most 2 dims. Unsqueeze (x, y) to (x, y, 1, 1) as cumsum currently only supports N and C
+                if (dim - rank) >= -2:
+                    if rank <= 2:
+                        input_4d_shape = (1,) * (2 - rank) + (*input_shape, 1, 1)
+                    elif rank == 3 and dim == 1:
+                        input_4d_shape = (*input_shape, 1)
+                    else:
+                        return None
+                else:
+                    input_4d_shape = (1,) * (4 - rank) + input_shape
+                    dim += 4 - rank
+                input_4d = g.call_function(ttnn.reshape, (tensor, input_4d_shape))
+                output_4d = g.call_function(ttnn.moreh_cumsum, (input_4d, dim), kwargs)
+                return g.call_function(ttnn.reshape, (output_4d, input_shape))
 
        with g.inserting_before(node):
            new_node = rewrite_node(node)
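
Note on the new cumsum lowering in `to_tt_pass.py`: per the TODO above, `ttnn.moreh_cumsum` currently only handles the N and C dims of a 4D tensor, so the pass reshapes the input to 4D, runs the cumsum there, and reshapes the result back. The sketch below mirrors that shape/dim mapping with plain tuples; `map_cumsum_to_4d` is only an illustrative helper, not part of the patch, which builds `ttnn.reshape`/`ttnn.moreh_cumsum` graph nodes rather than computing shapes eagerly.

```python
def map_cumsum_to_4d(input_shape, dim):
    """Mirror the 4D shape/dim mapping used by the lowering; None means unsupported."""
    rank = len(input_shape)
    if rank > 4:
        return None
    dim = (dim + rank) % rank
    if (dim - rank) >= -2:
        # dim falls on one of the inner-most 2 dims: append trailing 1s so the
        # cumsum axis lands on N or C; the remaining cases are the xfails (#367).
        if rank <= 2:
            shape_4d = (1,) * (2 - rank) + tuple(input_shape) + (1, 1)
        elif rank == 3 and dim == 1:
            shape_4d = tuple(input_shape) + (1,)
        else:
            return None
    else:
        # dim is already an outer dim: prepend leading 1s and shift dim accordingly.
        shape_4d = (1,) * (4 - rank) + tuple(input_shape)
        dim += 4 - rank
    return shape_4d, dim


# A few of the parametrized test cases, traced through the mapping:
assert map_cumsum_to_4d((1, 32), -1) == ((1, 32, 1, 1), 1)
assert map_cumsum_to_4d((4, 32, 32), 0) == ((1, 4, 32, 32), 1)
assert map_cumsum_to_4d((1, 23, 40), 1) == ((1, 23, 40, 1), 1)
assert map_cumsum_to_4d((1, 1, 32, 32), 3) is None  # xfail: inner-most 2 dims (#367)
```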