diff --git a/tests/lowering/tensor_manipulation/test_unsqueeze.py b/tests/lowering/tensor_manipulation/test_unsqueeze.py index cf825c4dd..06387d4e9 100644 --- a/tests/lowering/tensor_manipulation/test_unsqueeze.py +++ b/tests/lowering/tensor_manipulation/test_unsqueeze.py @@ -17,7 +17,14 @@ def forward(self, x, y): @pytest.mark.parametrize( "input_shape, dim", - [((5, 2, 4, 3), 1)], + [ + ((5, 2, 4, 3), 1), + pytest.param( + (50, 1, 3, 1024), + 0, + marks=pytest.mark.xfail(reason="Fails if ouput is > 4D, using TILE_LAYOUT, and W dim is >= 32."), + ), + ], ) def test_unsqueeze1(device, input_shape, dim): mod = UnsqueezeModule() diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py index 3e3472828..dc8c76801 100644 --- a/torch_ttnn/passes/lowering/to_tt_pass.py +++ b/torch_ttnn/passes/lowering/to_tt_pass.py @@ -638,7 +638,8 @@ def rewrite_node(node): output_size = list(node.meta["val"].size()) - if output_size[-1] == input_size[-1] and len(output_size) <= 4: + # FIXME: Cannot reshape a 4D tensor if size[-1] >= 32. + if output_size[-1] == input_size[-1] and (input_size[-1] < 32): return g.call_function(ttnn.reshape, args=(args[0], output_size)) return None