Optimize out aten.clone ops
Since `aten.clone` does not alter the tensor's layout, it is safe to skip these ops entirely instead of lowering them to `ttnn.clone`.
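For context, here is a minimal standalone sketch of the same idea, written against plain `torch.fx` rather than this repository's passes: in a graph captured at the aten level, a `clone` that preserves layout and dtype can be replaced by its input and erased. The `strip_clone` helper and `f` below are illustrative names, not torch-ttnn code.

```python
# Standalone sketch (not torch-ttnn code): elide aten.clone at the FX level.
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def strip_clone(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Replace every aten.clone node with its input and erase it."""
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.ops.aten.clone.default:
            # Consumers now read the original tensor directly.
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


def f(x):
    return x.clone() + 1


gm = make_fx(f)(torch.randn(2, 3))
strip_clone(gm)
gm.graph.print_tabular()  # no aten.clone.default node remains
```

The actual change in `to_tt_pass.py` below achieves the same effect one step earlier: `call_function` returns `args[0]` when it sees `torch.ops.aten.clone.default`, so no replacement node is ever created and the `target_wrappers.clone` helper becomes dead code.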
jdh8 committed Sep 19, 2024
1 parent 9b717ea commit 47cd4b4
Showing 4 changed files with 5 additions and 19 deletions.
7 changes: 1 addition & 6 deletions tests/lowering/creation/test_clone.py
@@ -38,9 +38,6 @@ def test_clone_from_arg(device, input_shapes):
result_after = m.forward(*inputs)
option._out_fx_graphs[0].print_tabular()

# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(torch_ttnn.target_wrappers.clone) == 1
# Check inference result
assert torch.allclose(result_before, result_after)

@@ -63,8 +60,6 @@ def test_clone_from_node(device, input_shapes):
# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
target = [node.target for node in nodes]
assert target.count(torch_ttnn.target_wrappers.clone) == 1
clone_arg_0 = nodes[target.index(torch_ttnn.target_wrappers.clone)].args[0].target
assert isinstance(clone_arg_0, ttnn.decorators.FastOperation) or isinstance(clone_arg_0, ttnn.decorators.Operation)
assert target.count("call_function") == 0
# Check inference result
assert torch.allclose(result_before, result_after)
3 changes: 1 addition & 2 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -33,7 +33,6 @@ def is_function_call(node) -> bool:
ttnn.atan,
ttnn.atan2, # binary
ttnn.atanh,
# ttnn.clone, in target_wrappers
ttnn.cos,
ttnn.cosh,
ttnn.erf,
@@ -124,7 +123,7 @@ def is_function_call(node) -> bool:
ttnn.to_layout,
]

TTNN_TARGET_WRAPPERS = [target_wrappers.clone, target_wrappers.repeat]
TTNN_TARGET_WRAPPERS = [target_wrappers.repeat]

TTNN_NORM_OPS = [
ttnn.group_norm,
5 changes: 0 additions & 5 deletions torch_ttnn/passes/lowering/target_wrappers.py
@@ -2,11 +2,6 @@
import torch


@torch.fx.wrap
def clone(t):
return ttnn.clone(t, memory_config=t.memory_config(), dtype=t.dtype)


@torch.fx.wrap
def repeat(t, sizes):
return ttnn.repeat(t, ttnn.Shape(sizes))
9 changes: 3 additions & 6 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -292,6 +292,9 @@ def call_function(self, target, args, kwargs):
############################################################
# Data movement
############################################################
if target == torch.ops.aten.clone.default:
return args[0]

if target == torch.ops.aten.permute.default:
return self.call_function_prop_meta(ttnn.permute, args, kwargs)

@@ -364,12 +367,6 @@ def rewrite_node(node):
args = node.args
kwargs = node.kwargs

if node.target == torch.ops.aten.clone.default:
arg_metadata = node.meta["val"]
ttnn_dtype = torch_dtype_to_ttnn_dtype(arg_metadata.dtype)
# Add additional logic to choose the appropriate memory_config type: DRAM or L1
return g.call_function(target_wrappers.clone, args=(args[0],))

if node.target == torch.ops.aten.native_layer_norm.default:
new_node = g.call_function(
ttnn.layer_norm,
