Skip to content

Commit

Permalink
Fix unsqueeze
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinwuTT committed Nov 11, 2024
1 parent 6979151 commit 02252cb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
9 changes: 8 additions & 1 deletion tests/lowering/tensor_manipulation/test_unsqueeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@ def forward(self, x, y):

@pytest.mark.parametrize(
"input_shape, dim",
[((5, 2, 4, 3), 1)],
[
((5, 2, 4, 3), 1),
pytest.param(
(50, 1, 3, 1024),
0,
marks=pytest.mark.xfail(reason="Fails if ouput is > 4D, using TILE_LAYOUT, and W dim is >= 32."),
),
],
)
def test_unsqueeze1(device, input_shape, dim):
mod = UnsqueezeModule()
Expand Down
3 changes: 2 additions & 1 deletion torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,8 @@ def rewrite_node(node):

output_size = list(node.meta["val"].size())

if output_size[-1] == input_size[-1] and len(output_size) <= 4:
# FIXME: Cannot reshape a 4D tensor if size[-1] >= 32.
if output_size[-1] == input_size[-1] and (input_size[-1] < 32):
return g.call_function(ttnn.reshape, args=(args[0], output_size))
return None

Expand Down

0 comments on commit 02252cb

Please sign in to comment.