Skip to content

Commit

Permalink
Refine conversion to ttnn.repeat
Browse files Browse the repository at this point in the history
  • Loading branch information
jdh8 committed Sep 13, 2024
1 parent d74ae89 commit c9e292a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
15 changes: 8 additions & 7 deletions tests/lowering/tensor_manipulation/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@ def forward(self, x, sizes):
@pytest.mark.parametrize(
"input_shape, sizes",
(
pytest.param(
(4, 4),
(3, 2),
marks=pytest.mark.xfail(
reason="Current repeat implementation requires aligned last dim when repeating on last dim"
),
),
((1, 1, 1), (1, 1, 1)),
((1, 1, 2048, 2048), (1, 1, 1, 1)),
((1, 1, 256), (1, 1, 1)),
Expand All @@ -38,6 +31,14 @@ def forward(self, x, sizes):
((6, 2), (4, 1)),
((6, 2), (400, 1)),
((6, 2), (9, 1)),
pytest.param(
(4, 4),
(3, 2),
marks=pytest.mark.xfail(
reason="Current repeat implementation requires aligned last dim when repeating on last dim"
),
),
((5, 16), (2, 3)),
),
)
def test_repeat(device, input_shape, sizes):
Expand Down
8 changes: 4 additions & 4 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,11 +560,11 @@ def rewrite_node(node):
if np.prod(sizes) == 1:
return tensor

# Repeat fails if last dimension of input is 1
if shape[-1] == 1:
return None
# Current repeat implementation requires aligned last dim when repeating on last dim
if sizes[-1] == 1 or shape[-1] % 16 == 0:
return g.call_function(target_wrappers.repeat, args)

return g.call_function(target_wrappers.repeat, args=(tensor, sizes))
return None

if node.target == torch.ops.aten.unsqueeze.default:
output_size = node.meta["val"].size()
Expand Down

0 comments on commit c9e292a

Please sign in to comment.