Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert aten.cumsum to ttnn.moreh_cumsum #370

Merged
merged 4 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/lowering/misc/test_cumsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
import torch_ttnn
import pytest

from tests.utils import assert_with_pcc


class CumsumModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, dim):
return torch.ops.aten.cumsum.default(input, dim=dim)


@pytest.mark.parametrize(
"input_shapes, dim",
[
((1, 32), -1),
((1, 45), -1),
((1, 59), 1),
((1, 5), -1),
((1, 60), 1),
((1, 10), 1),
((4, 32, 32), 0),
((1, 4, 32, 32), 1),
((4, 4, 32, 32), 0),
((1, 23, 40), 1),
((4, 32), 0),
pytest.param(
(1, 1, 32, 32),
3,
marks=pytest.mark.xfail(reson="inner-most 2 dims are not supported (#367)"),
),
pytest.param(
(1, 23, 40),
2,
marks=pytest.mark.xfail(reson="inner-most 2 dims are not supported (#367)"),
),
],
)
def test_cumsum(device, input_shapes, dim):
m = CumsumModule()
inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
result_before = m.forward(inputs, dim)

option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)

result_after = m.forward(inputs, dim)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
nodes = [node.target for node in option._out_fx_graphs[0].nodes]
assert nodes.count(torch.ops.aten.cumsum.default) == 0
assert_with_pcc(result_before, result_after, pcc=0.99)
9 changes: 3 additions & 6 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,6 @@ def is_function_call(node) -> bool:
)


def can_be_tilized(node):
size = node.meta["val"].size()
return len(size) >= 2 and size[-1] % 32 == 0 and size[-2] % 32 == 0


# For operations limitations
# See https://github.com/tenstorrent-metal/tt-metal/blob/main/ttnn/README.md?plain=1#L19
def is_tt_compute(node) -> bool:
Expand Down Expand Up @@ -178,6 +173,7 @@ def is_tt_compute(node) -> bool:
ttnn.squeeze,
ttnn.full,
ttnn.as_tensor,
ttnn.moreh_cumsum,
]
)

Expand Down Expand Up @@ -335,7 +331,8 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node, device) -> to
need_from_device = False
need_to_layout = False
need_to_device = False
if dst_node.target in TTNN_LAYOUT_CHANGE_OPS and dst_idx == 0 and is_tt(src_node) and not can_be_tilized(dst_node):
# TODO(#372): #322 will enable tile layout for more layout change ops
if dst_node.target in TTNN_LAYOUT_CHANGE_OPS and dst_idx == 0 and is_tt(src_node):
need_from_device = True
need_to_layout = True

Expand Down
23 changes: 23 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,29 @@ def rewrite_node(node):
else:
return None

if node.target == torch.ops.aten.cumsum.default:
tensor, dim = args
input_shape = tensor.meta["val"].size()
rank = len(input_shape)
if rank > 4:
return None
dim = (dim + rank) % rank
# Unsqueeze input tensor to 4D for cumsum
# TODO(#367): Special case if dim is inner-most 2 dim. Unsqueeze (x, y) to (x, y, 1, 1) as cumsum currently only support N and C
if (dim - rank) >= -2:
if rank <= 2:
input_4d_shape = (1,) * (2 - rank) + (*input_shape, 1, 1)
elif rank == 3 and dim == 1:
input_4d_shape = (*input_shape, 1)
else:
return None
else:
input_4d_shape = (1,) * (4 - rank) + input_shape
dim += 4 - rank
input_4d = g.call_function(ttnn.reshape, (tensor, input_4d_shape))
output_4d = g.call_function(ttnn.moreh_cumsum, (input_4d, dim), kwargs)
return g.call_function(ttnn.reshape, (output_4d, input_shape))

with g.inserting_before(node):
new_node = rewrite_node(node)
if new_node is not None:
Expand Down
Loading