Skip to content

Commit

Permalink
Simplify the workaround for tenstorrent/tt-metal#12671
Browse files Browse the repository at this point in the history
  • Loading branch information
jdh8 committed Dec 19, 2024
1 parent 3a04e8c commit 09535cd
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,21 +418,19 @@ def lower_binary_eltwise(fn, args):
if node.target in TTNN_POINTWISE_UNARY_OPS:
code = TTNN_POINTWISE_UNARY_OPS[node.target]

def unsqueeze_to_2d(code, args=args, kwargs=kwargs):
ndims = len(node.meta["val"].size())
# NOTE(jdh8): Workaround for tenstorrent/tt-metal#12671
# Passing a tensor shaped `(N,)` to the kernel results in `(1, N)`.
# Reshape the tensor back to get the correct shape.
def reshape_1d(code, args=args, kwargs=kwargs):
shape = node.meta["val"].size()
result = g.call_function(code, args, kwargs)

if ndims == 1:
result = g.call_function(ttnn.to_layout, (result, TtnnRowMajorLayout()))
result = g.call_function(ttnn.squeeze, (result, 0))

return result
return result if len(shape) > 1 else g.call_function(ttnn.reshape, (result, shape))

if node.target in TTNN_POINTWISE_UNARY_OPS:
return unsqueeze_to_2d(TTNN_POINTWISE_UNARY_OPS[node.target])
return reshape_1d(TTNN_POINTWISE_UNARY_OPS[node.target])

if node.target == torch.ops.aten.round.default:
return unsqueeze_to_2d(ttnn.round, (args[0],), {"decimals": 0})
return reshape_1d(ttnn.round, (args[0],), {"decimals": 0})

if node.target == torch.ops.aten.clone.default:
arg_metadata = node.meta["val"]
Expand Down

0 comments on commit 09535cd

Please sign in to comment.