diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py index d2e3d8503..7e28d9f8d 100644 --- a/torch_ttnn/passes/lowering/to_tt_pass.py +++ b/torch_ttnn/passes/lowering/to_tt_pass.py @@ -418,21 +418,19 @@ def lower_binary_eltwise(fn, args): if node.target in TTNN_POINTWISE_UNARY_OPS: code = TTNN_POINTWISE_UNARY_OPS[node.target] - def unsqueeze_to_2d(code, args=args, kwargs=kwargs): - ndims = len(node.meta["val"].size()) + # NOTE(jdh8): Workaround for tenstorrent/tt-metal#12671 + # Passing a tensor shaped `(N,)` to the kernel results in `(1, N)`. + # Reshape the tensor back to get the correct shape. + def reshape_1d(code, args=args, kwargs=kwargs): + shape = node.meta["val"].size() result = g.call_function(code, args, kwargs) - - if ndims == 1: - result = g.call_function(ttnn.to_layout, (result, TtnnRowMajorLayout())) - result = g.call_function(ttnn.squeeze, (result, 0)) - - return result + return result if len(shape) > 1 else g.call_function(ttnn.reshape, (result, shape)) if node.target in TTNN_POINTWISE_UNARY_OPS: - return unsqueeze_to_2d(TTNN_POINTWISE_UNARY_OPS[node.target]) + return reshape_1d(TTNN_POINTWISE_UNARY_OPS[node.target]) if node.target == torch.ops.aten.round.default: - return unsqueeze_to_2d(ttnn.round, (args[0],), {"decimals": 0}) + return reshape_1d(ttnn.round, (args[0],), {"decimals": 0}) if node.target == torch.ops.aten.clone.default: arg_metadata = node.meta["val"]