From 47cd4b41005efeaafaae439394baca5707b3d5b8 Mon Sep 17 00:00:00 2001
From: Chen-Pang He
Date: Thu, 19 Sep 2024 20:59:43 +0000
Subject: [PATCH] Optimize out `aten.clone` ops

Since we do not alter its layout, it is safe to skip `aten.clone` ops.
---
 tests/lowering/creation/test_clone.py            | 7 +------
 torch_ttnn/passes/lowering/add_data_move_pass.py | 3 +--
 torch_ttnn/passes/lowering/target_wrappers.py    | 5 -----
 torch_ttnn/passes/lowering/to_tt_pass.py         | 9 +++------
 4 files changed, 5 insertions(+), 19 deletions(-)

diff --git a/tests/lowering/creation/test_clone.py b/tests/lowering/creation/test_clone.py
index 30caadd54..33fb72703 100644
--- a/tests/lowering/creation/test_clone.py
+++ b/tests/lowering/creation/test_clone.py
@@ -38,9 +38,6 @@ def test_clone_from_arg(device, input_shapes):
     result_after = m.forward(*inputs)
     option._out_fx_graphs[0].print_tabular()
 
-    # Check the graph has be rewritten and contain ttnn ops
-    nodes = list(option._out_fx_graphs[0].nodes)
-    assert [node.target for node in nodes].count(torch_ttnn.target_wrappers.clone) == 1
     # Check inference result
     assert torch.allclose(result_before, result_after)
 
@@ -63,8 +60,6 @@ def test_clone_from_node(device, input_shapes):
     # Check the graph has be rewritten and contain ttnn ops
     nodes = list(option._out_fx_graphs[0].nodes)
     target = [node.target for node in nodes]
-    assert target.count(torch_ttnn.target_wrappers.clone) == 1
-    clone_arg_0 = nodes[target.index(torch_ttnn.target_wrappers.clone)].args[0].target
-    assert isinstance(clone_arg_0, ttnn.decorators.FastOperation) or isinstance(clone_arg_0, ttnn.decorators.Operation)
+    assert target.count("call_function") == 0
     # Check inference result
     assert torch.allclose(result_before, result_after)
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 7add48c36..0cd70d174 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -33,7 +33,6 @@ def is_function_call(node) -> bool:
     ttnn.atan,
     ttnn.atan2,  # binary
     ttnn.atanh,
-    # ttnn.clone, in target_wrappers
     ttnn.cos,
     ttnn.cosh,
     ttnn.erf,
@@ -124,7 +123,7 @@ def is_function_call(node) -> bool:
     ttnn.to_layout,
 ]
 
-TTNN_TARGET_WRAPPERS = [target_wrappers.clone, target_wrappers.repeat]
+TTNN_TARGET_WRAPPERS = [target_wrappers.repeat]
 
 TTNN_NORM_OPS = [
     ttnn.group_norm,
diff --git a/torch_ttnn/passes/lowering/target_wrappers.py b/torch_ttnn/passes/lowering/target_wrappers.py
index 77f26921b..581562dc7 100644
--- a/torch_ttnn/passes/lowering/target_wrappers.py
+++ b/torch_ttnn/passes/lowering/target_wrappers.py
@@ -2,11 +2,6 @@
 import torch
 
 
-@torch.fx.wrap
-def clone(t):
-    return ttnn.clone(t, memory_config=t.memory_config(), dtype=t.dtype)
-
-
 @torch.fx.wrap
 def repeat(t, sizes):
     return ttnn.repeat(t, ttnn.Shape(sizes))
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index b0bb2c8f2..d9b4e247b 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -292,6 +292,9 @@ def call_function(self, target, args, kwargs):
         ############################################################
         # Data movement
         ############################################################
+        if target == torch.ops.aten.clone.default:
+            return args[0]
+
         if target == torch.ops.aten.permute.default:
             return self.call_function_prop_meta(ttnn.permute, args, kwargs)
 
@@ -364,12 +367,6 @@ def rewrite_node(node):
         args = node.args
         kwargs = node.kwargs
 
-        if node.target == torch.ops.aten.clone.default:
-            arg_metadata = node.meta["val"]
-            ttnn_dtype = torch_dtype_to_ttnn_dtype(arg_metadata.dtype)
-            # Add additional logic to choose the appropriate memory_config type: DRAM or L1
-            return g.call_function(target_wrappers.clone, args=(args[0],))
-
         if node.target == torch.ops.aten.native_layer_norm.default:
             new_node = g.call_function(
                 ttnn.layer_norm,
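
For context (not part of the patch itself), below is a minimal standalone sketch of the rewrite this change performs: a clone node is simply replaced by its own input, so no clone call survives lowering. The toy module CloneOnly and the helper skip_clones are illustrative assumptions, not torch-ttnn code; the real pass achieves the same effect by returning `args[0]` from `call_function` in to_tt_pass.py.

import torch
from torch.fx import symbolic_trace


class CloneOnly(torch.nn.Module):
    # Hypothetical toy module: its traced graph contains a clone node followed by an add.
    def forward(self, x):
        return torch.clone(x) + 1


def skip_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Same idea as the new branch in to_tt_pass.py: replace every clone node
    # with its input, so the rewritten graph contains no clone call at all.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target in (torch.clone, torch.ops.aten.clone.default):
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = skip_clones(symbolic_trace(CloneOnly()))
x = torch.randn(2, 3)
assert torch.allclose(gm(x), CloneOnly()(x))  # numerics are unchanged
assert not any(n.target is torch.clone for n in gm.graph.nodes)  # clone is gone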