Workaround for ttnn.moreh_cumsum
jerrysky3 committed Nov 1, 2024
1 parent c84d65d commit b50cf63
Showing 3 changed files with 46 additions and 24 deletions.
31 changes: 23 additions & 8 deletions tests/lowering/misc/test_cumsum.py
@@ -1,7 +1,6 @@
import torch
import torch_ttnn
import pytest
import ttnn

from tests.utils import assert_with_pcc

@@ -17,17 +16,35 @@ def forward(self, input, dim):
@pytest.mark.parametrize(
"input_shapes, dim",
[
((1, 32), 1),
((1, 32), -1),
((1, 45), -1),
((1, 59), 1),
((1, 5), -1),
((1, 60), 1),
((1, 10), 1),
((4, 32, 32), 0),
((1, 4, 32, 32), 1),
((4, 4, 32, 32), 0),
((1, 23, 40), 1),
((4, 32), 0),
pytest.param(
(1, 1, 32, 32),
3,
marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
),
pytest.param(
(1, 23, 40),
2,
marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
),
],
)
def test_cumsum(device, input_shapes, dim):
m = CumsumModule()
inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
result_before = m.forward(inputs, dim)

option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = False

option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)

@@ -36,7 +53,5 @@ def test_cumsum(device, input_shapes, dim):

# Check the graph has been rewritten and contains ttnn ops
nodes = [node.target for node in option._out_fx_graphs[0].nodes]
assert nodes.count(ttnn.moreh_cumsum) == 1

# Check inference result
assert nodes.count(torch.ops.aten.cumsum.default) == 0
assert_with_pcc(result_before, result_after, pcc=0.99)
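For readers skimming the fragmentary diff above, here is a self-contained sketch of the flow this test exercises. It assumes `CumsumModule.forward` simply calls `torch.cumsum` (the class body is elided in the diff) and mirrors the updated option construction and graph assertion; it is an illustration, not code from the commit:

```python
import torch
import torch_ttnn


class CumsumModule(torch.nn.Module):
    # Assumed body; the diff only shows the forward(self, input, dim) signature.
    def forward(self, input, dim):
        return torch.cumsum(input, dim)


def run_cumsum_example(device):
    m = CumsumModule()
    inputs = torch.rand((4, 32, 32), dtype=torch.bfloat16)
    golden = m(inputs, 0)

    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result = m(inputs, 0)  # compilation is lazy; the first call triggers it

    # With the workaround in place, no aten.cumsum nodes should remain in the lowered graph.
    nodes = [node.target for node in option._out_fx_graphs[0].nodes]
    assert nodes.count(torch.ops.aten.cumsum.default) == 0
    return golden, result
```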
21 changes: 6 additions & 15 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -141,11 +141,6 @@ def is_function_call(node) -> bool:
)


def can_be_tilized(node):
size = node.meta["val"].size()
return len(size) >= 2 and size[-1] % 32 == 0 and size[-2] % 32 == 0


# For operation limitations
# See https://github.com/tenstorrent-metal/tt-metal/blob/main/ttnn/README.md?plain=1#L19
def is_tt_compute(node) -> bool:
@@ -292,8 +287,9 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
new_nodes = list()
with g.inserting_before(dst_node):
kwargs = {}
# TODO(#322): #322 will enable tile layout for more layout change ops
if (
(dst_node.target in TTNN_LAYOUT_CHANGE_OPS and not can_be_tilized(dst_node))
dst_node.target in TTNN_LAYOUT_CHANGE_OPS
or dst_node.target == ttnn.embedding
or dst_node.target == ttnn.zeros_like
or dst_node.target == target_wrappers.repeat
@@ -307,9 +303,8 @@ else:
else:
kwargs["dtype"] = TtnnBfloat16()

if (is_tt_compute(dst_node) and dst_node.target not in TTNN_LAYOUT_CHANGE_OPS) or (
dst_node.target in TTNN_LAYOUT_CHANGE_OPS and HasValidPageSize(src_node.meta["val"].size(), strict=True)
):
# TODO(#322): #322 will enable device tensor for more layout change ops
if is_tt_compute(dst_node) and dst_node.target not in TTNN_LAYOUT_CHANGE_OPS:
kwargs["device"] = device

new_nodes.append(g.call_function(ttnn.from_torch, (src_node,), kwargs))
@@ -324,12 +319,8 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node) -> torch.fx.n
return None
if not is_function_call(dst_node):
return None
if (
dst_node.target not in TTNN_LAYOUT_CHANGE_OPS
or dst_idx != 0
or not is_tt(src_node)
or (dst_node.target in TTNN_LAYOUT_CHANGE_OPS and can_be_tilized(dst_node))
):
# TODO(#322): #322 will enable tile layout for more layout change ops
if dst_node.target not in TTNN_LAYOUT_CHANGE_OPS or dst_idx != 0 or not is_tt(src_node):
return None

g = dst_node.graph
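In effect, removing `can_be_tilized` and the `HasValidPageSize` branch means a layout-change op no longer receives a device tensor from `ttnn.from_torch`, whatever its shape. A condensed sketch of the old versus new gating, with plain booleans standing in for the real `is_tt_compute`, `TTNN_LAYOUT_CHANGE_OPS`, and `HasValidPageSize` checks (sketch only, not code from the diff):

```python
def gets_device_tensor_before(is_compute, is_layout_change, has_valid_page_size):
    # Old rule: a layout-change op could still be placed on device when the
    # source tensor had a valid page size.
    return (is_compute and not is_layout_change) or (is_layout_change and has_valid_page_size)


def gets_device_tensor_after(is_compute, is_layout_change):
    # New rule (until #322): layout-change ops always stay on host.
    return is_compute and not is_layout_change
```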
18 changes: 17 additions & 1 deletion torch_ttnn/passes/lowering/to_tt_pass.py
@@ -760,8 +760,24 @@ def rewrite_node(node):
tensor, dim = args
input_shape = tensor.meta["val"].size()
rank = len(input_shape)
if rank > 4:
return None
dim = (dim + rank) % rank
return g.call_function(ttnn.moreh_cumsum, (tensor, dim), kwargs)
# Unsqueeze input tensor to 4D for cumsum
# TODO(#367): Special-case when dim is one of the inner-most 2 dims. Unsqueeze (x, y) to (x, y, 1, 1) since cumsum currently only supports the N and C dims
if (dim - rank) >= -2:
if rank <= 2:
input_4d_shape = (1,) * (2 - rank) + (*input_shape, 1, 1)
elif rank == 3 and dim == 1:
input_4d_shape = (*input_shape, 1)
else:
return None
else:
input_4d_shape = (1,) * (4 - rank) + input_shape
dim += 4 - rank
input_4d = g.call_function(ttnn.reshape, (tensor, input_4d_shape))
output_4d = g.call_function(ttnn.moreh_cumsum, (input_4d, dim), kwargs)
return g.call_function(ttnn.reshape, (output_4d, input_shape))

with g.inserting_before(node):
new_node = rewrite_node(node)
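The shape handling added above is the heart of the workaround: the input is reshaped to 4D so that the cumulated dimension lands on N or C (the only dims `ttnn.moreh_cumsum` handles today), and the result is reshaped back to the original shape. A standalone sketch of just the shape/dim mapping, using a hypothetical helper name and shapes taken from the updated tests (illustration only, not code from the commit):

```python
# Compute the 4D view used by the moreh_cumsum workaround.
# Returns (shape_4d, dim_4d), or None for the cases still tracked by #367.
def cumsum_4d_view(input_shape, dim):
    rank = len(input_shape)
    if rank > 4:
        return None
    dim = (dim + rank) % rank  # normalize negative dims
    if (dim - rank) >= -2:
        # dim is one of the inner-most 2 dims: pad with trailing 1s so the
        # cumulated dim stays on N or C.
        if rank <= 2:
            shape_4d = (1,) * (2 - rank) + (*input_shape, 1, 1)
        elif rank == 3 and dim == 1:
            shape_4d = (*input_shape, 1)
        else:
            return None  # not handled yet (#367)
    else:
        # dim already maps onto N or C after left-padding with 1s.
        shape_4d = (1,) * (4 - rank) + tuple(input_shape)
        dim += 4 - rank
    return shape_4d, dim


assert cumsum_4d_view((1, 32), 1) == ((1, 32, 1, 1), 1)
assert cumsum_4d_view((4, 32, 32), 0) == ((1, 4, 32, 32), 1)
assert cumsum_4d_view((1, 23, 40), 1) == ((1, 23, 40, 1), 1)
assert cumsum_4d_view((1, 23, 40), 2) is None  # xfail case in the tests (#367)
```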
