diff --git a/tests/lowering/creation/test_clone.py b/tests/lowering/creation/test_clone.py index 30caadd54..33fb72703 100644 --- a/tests/lowering/creation/test_clone.py +++ b/tests/lowering/creation/test_clone.py @@ -38,9 +38,6 @@ def test_clone_from_arg(device, input_shapes): result_after = m.forward(*inputs) option._out_fx_graphs[0].print_tabular() - # Check the graph has be rewritten and contain ttnn ops - nodes = list(option._out_fx_graphs[0].nodes) - assert [node.target for node in nodes].count(torch_ttnn.target_wrappers.clone) == 1 # Check inference result assert torch.allclose(result_before, result_after) @@ -63,8 +60,6 @@ def test_clone_from_node(device, input_shapes): # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) target = [node.target for node in nodes] - assert target.count(torch_ttnn.target_wrappers.clone) == 1 - clone_arg_0 = nodes[target.index(torch_ttnn.target_wrappers.clone)].args[0].target - assert isinstance(clone_arg_0, ttnn.decorators.FastOperation) or isinstance(clone_arg_0, ttnn.decorators.Operation) + assert target.count(torch.ops.aten.clone.default) == 0 # Check inference result assert torch.allclose(result_before, result_after) diff --git a/tests/lowering/creation/test_to_copy.py b/tests/lowering/creation/test_to_copy.py index 29546e679..53cef0832 100644 --- a/tests/lowering/creation/test_to_copy.py +++ b/tests/lowering/creation/test_to_copy.py @@ -18,8 +18,8 @@ class ToCopyWithOpAfterModule(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, x): - to = x.to(torch.bfloat16) + def forward(self, x, dtype): + to = x.to(dtype) return torch.add(to, to) @@ -52,23 +52,29 @@ def test_to_copy(device, input_shapes): # If there is a ttnn.from_torch that follows aten._to_copy and is casting to bfloat, then convert.
@pytest.mark.parametrize( "input_shapes", - [[(4, 4)]], + [(4, 4)], ) -def test_to_copy_with_op_after(device, input_shapes): +@pytest.mark.parametrize("dtype", ((torch.bfloat16), (torch.int64))) +def test_to_copy_with_op_after(device, input_shapes, dtype): m = ToCopyWithOpAfterModule() - inputs = [torch.rand(shape) for shape in input_shapes] - result_before = m.forward(*inputs) + inputs = torch.rand(input_shapes) + result_before = m.forward(inputs, dtype) option = torch_ttnn.TorchTtnnOption(device=device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation m = torch.compile(m, backend=torch_ttnn.backend, options=option) - result_after = m.forward(*inputs) + result_after = m.forward(inputs, dtype) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) target = [node.target for node in nodes] - assert target.count(torch.ops.aten._to_copy.default) == 0 + # try_add_data_move_out: ttnn.to_torch will be followed by a to_copy + if dtype == torch.bfloat16: + count = 0 + else: + count = 2 + assert target.count(torch.ops.aten._to_copy.default) == count assert target.count(ttnn.add) == 1 # Check inference result assert torch.allclose(result_before, result_after, rtol=0.2) @@ -78,9 +84,9 @@ class ToCopyViewModule(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, x, y, target_shape): + def forward(self, x, y, target_shape, dtype): view = torch.ops.aten.view.default(x, target_shape) - _to_copy = torch.ops.aten._to_copy.default(view, dtype=torch.bfloat16) + _to_copy = torch.ops.aten._to_copy.default(view, dtype=dtype) abs = torch.abs(y) return torch.add(_to_copy, abs) @@ -92,9 +98,9 @@ class ToCopyExpand(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, x, y, target_shape): + def forward(self, x, y, target_shape, dtype): expand = torch.ops.aten.expand.default(x, target_shape) - _to_copy = torch.ops.aten._to_copy.default(expand, dtype=torch.bfloat16) + _to_copy = torch.ops.aten._to_copy.default(expand, dtype=dtype) abs = torch.abs(y) return torch.add(_to_copy, abs) @@ -109,23 +115,29 @@ def input_shapes(self): (ToCopyExpand(), torch_ttnn.target_wrappers.repeat), ], ) -def test_reshape_test1(device, module, ttnn_op): +@pytest.mark.parametrize("dtype", ((torch.bfloat16), (torch.int64))) +def test_reshape_test1(device, module, ttnn_op, dtype): m = module input_shape1, input_shape2, target_shape = m.input_shapes() x = torch.rand(input_shape1, dtype=torch.bfloat16) y = torch.rand(input_shape2, dtype=torch.bfloat16) - result_before = m.forward(x, y, target_shape) + result_before = m.forward(x, y, target_shape, dtype) option = torch_ttnn.TorchTtnnOption(device=device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation m = torch.compile(m, backend=torch_ttnn.backend, options=option) - result_after = m.forward(x, y, target_shape) + result_after = m.forward(x, y, target_shape, dtype) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) target = [node.target for node in nodes] - assert target.count(torch.ops.aten._to_copy.default) == 0 + # try_add_data_move_out: ttnn.to_torch will be followed by a to_copy + if dtype == torch.bfloat16: + count = 0 + else: + count = 2 + assert target.count(torch.ops.aten._to_copy.default) == count assert [node.target for node in 
nodes].count(ttnn_op) == 1 # Check inference result assert_with_pcc(result_before, result_after, 0.99) diff --git a/tests/lowering/eltwise/binary/test_div.py b/tests/lowering/eltwise/binary/test_div.py index d6fd32629..6521a9b4a 100644 --- a/tests/lowering/eltwise/binary/test_div.py +++ b/tests/lowering/eltwise/binary/test_div.py @@ -13,19 +13,21 @@ def forward(self, numerator, denominator): return torch.div(numerator, denominator) +# ttnn.div does not support broadcasting some combination of shapes. Fallback to reciprocal and multiply. @pytest.mark.parametrize( - "input_shapes", + "input_shapes, use_ttnn_div", ( - ((32, 32), (32, 32)), - ((64,), (32, 64)), - ((64, 32), (64, 1)), + (((32, 32), (32, 32)), True), + (((64,), (32, 64)), False), + (((64, 32), (64, 1)), False), pytest.param( ((64, 1), (1, 64)), + False, marks=pytest.mark.xfail(reason="broadcasting issues (#64)"), ), ), ) -def test_div(device, input_shapes): +def test_div(device, input_shapes, use_ttnn_div): m = DivModule() inputs = [torch.randint(1, 15, shape).to(torch.bfloat16) for shape in input_shapes] result_before = m.forward(*inputs) @@ -39,10 +41,13 @@ def test_div(device, input_shapes): # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) target = [node.target for node in nodes] - assert target.count(ttnn.reciprocal) == 1 - assert target.count(ttnn.mul) == 1 - assert target.index(ttnn.reciprocal) < target.index(ttnn.mul) - assert nodes[target.index(ttnn.mul)].args[1].target == ttnn.reciprocal + if use_ttnn_div: + assert target.count(ttnn.div) == 1 + else: + assert target.count(ttnn.reciprocal) == 1 + assert target.count(ttnn.mul) == 1 + assert target.index(ttnn.reciprocal) < target.index(ttnn.mul) + assert nodes[target.index(ttnn.mul)].args[1].target == ttnn.reciprocal # Check inference result assert_with_pcc(result_before, result_after) @@ -50,7 +55,7 @@ def test_div(device, input_shapes): @pytest.mark.parametrize( "input_shapes", - [[(4, 4)], [(32, 32)]], + [[(4, 4)], [(32, 32)], [(1, 197, 1024)]], ) def test_div_scalar_denom(device, input_shapes): m = DivModule() @@ -66,15 +71,5 @@ def test_div_scalar_denom(device, input_shapes): # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) target = [node.target for node in nodes] - assert target.count(ttnn.full) == 1 - assert target.count(ttnn.reciprocal) == 1 - assert target.count(ttnn.mul) == 1 - assert target.index(ttnn.full) < target.index(ttnn.reciprocal) - assert target.index(ttnn.reciprocal) < target.index(ttnn.mul) - assert nodes[target.index(ttnn.mul)].args[1].target == ttnn.reciprocal - # Intermediate node meta check if preserved - for node in nodes: - if node.target == ttnn.full or node.target == ttnn.reciprocal: - assert node.meta["val"].size() == input_shapes[0] - # Check inference result + assert target.count(ttnn.div) == 1 assert_with_pcc(result_before, result_after) diff --git a/tests/lowering/eltwise/binary/test_sub.py b/tests/lowering/eltwise/binary/test_sub.py index d7c1e1672..9a6ead808 100644 --- a/tests/lowering/eltwise/binary/test_sub.py +++ b/tests/lowering/eltwise/binary/test_sub.py @@ -118,12 +118,8 @@ def test_rsub_scalar(device, input_shapes): # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) target = [node.target for node in nodes] - assert target.count(ttnn.full) == 1 - assert target.count(ttnn.sub) == 1 - assert target.index(ttnn.full) < target.index(ttnn.sub) - # Intermediate node meta check if 
preserved - for node in nodes: - if node.target == ttnn.full: - assert node.meta["val"].size() == input_shapes[0] + assert target.count(ttnn.neg) == 1 + assert target.count(ttnn.add) == 1 + assert target.index(ttnn.neg) < target.index(ttnn.add) # Check inference result assert_with_pcc(result_before, result_after, 0.998) diff --git a/tests/lowering/eltwise/unary/test_remainder.py b/tests/lowering/eltwise/unary/test_remainder.py index 69d383475..97c22d486 100644 --- a/tests/lowering/eltwise/unary/test_remainder.py +++ b/tests/lowering/eltwise/unary/test_remainder.py @@ -12,6 +12,7 @@ def forward(self, input, mod): return input % mod +@pytest.mark.skip_platform("grayskull") @pytest.mark.parametrize( "input_shape, mod, converted", [ diff --git a/tests/lowering/pool/test_avg_pool_2d.py b/tests/lowering/pool/test_avg_pool_2d.py index 02e739cbb..39c8db27a 100644 --- a/tests/lowering/pool/test_avg_pool_2d.py +++ b/tests/lowering/pool/test_avg_pool_2d.py @@ -14,6 +14,7 @@ def forward(self, input): return torch._adaptive_avg_pool2d(input, (1, 1)) +@pytest.mark.skip() @pytest.mark.parametrize( "input_shapes", [(1, 2048, 7, 7)], diff --git a/tests/lowering/tensor_manipulation/test_slice.py b/tests/lowering/tensor_manipulation/test_slice.py index d5aa0b16d..b75960466 100644 --- a/tests/lowering/tensor_manipulation/test_slice.py +++ b/tests/lowering/tensor_manipulation/test_slice.py @@ -103,6 +103,10 @@ def forward(self, input, dim, start, end): ((1, 4251, 192), 0, 0, END_MAX), ((1, 4251, 192), 1, -100, END_MAX), ((1, 4251, 192), 1, 1, -100), + # Hardnet (train) + ((1, 782, 7, 7), 1, 0, 160), + # Clip + ((1, 77), 1, 0, 7), ), ) def test_aten_slice(device, input_shape, dim, start, end, module): diff --git a/tests/lowering/tensor_manipulation/test_squeeze.py b/tests/lowering/tensor_manipulation/test_squeeze.py index ce720b619..e1e5d3a2b 100644 --- a/tests/lowering/tensor_manipulation/test_squeeze.py +++ b/tests/lowering/tensor_manipulation/test_squeeze.py @@ -17,8 +17,8 @@ def forward(self, input, dim): [ ((1, 32, 16), 0), ((1, 256, 1), -1), - ((33, 44, 1, 32, 16), 1), - ((33, 44, 1, 32, 16), 2), + pytest.param((33, 44, 1, 32, 16), 1, marks=pytest.mark.xfail(reason="Cannot reshape from 5D to 4D.")), + pytest.param((33, 44, 1, 32, 16), 2, marks=pytest.mark.xfail(reason="Cannot reshape from 5D to 4D.")), ], ) def test_squeeze_dim(device, input_shape, dim): @@ -53,10 +53,10 @@ def forward(self, input): @pytest.mark.parametrize( "input_shape", [ - ((64, 1, 32, 16, 1, 32, 32)), - ((1, 1, 55, 23, 44, 32, 32)), - ((22, 1, 55, 23, 44, 32, 1)), - ((1, 1, 55, 1, 1, 1, 1)), + pytest.param((64, 1, 32, 16, 1, 32, 32), marks=pytest.mark.xfail(reason="Does not support TILE_LAYOUT.")), + pytest.param((1, 1, 55, 23, 44, 32, 32), marks=pytest.mark.xfail(reason="Does not support TILE_LAYOUT.")), + pytest.param((22, 1, 55, 23, 44, 32, 1), marks=pytest.mark.xfail(reason="Does not support TILE_LAYOUT.")), + pytest.param((1, 1, 55, 1, 1, 1, 1), marks=pytest.mark.xfail(reason="Does not support TILE_LAYOUT.")), ], ) def test_squeeze_none_dim(device, input_shape): diff --git a/tests/lowering/tensor_manipulation/test_transpose.py b/tests/lowering/tensor_manipulation/test_transpose.py index fe60e51b0..b33f9bc41 100644 --- a/tests/lowering/tensor_manipulation/test_transpose.py +++ b/tests/lowering/tensor_manipulation/test_transpose.py @@ -20,6 +20,7 @@ def forward(self, x, dim0, dim1): # If not, this runtime error will be thrown: # RuntimeError: TT_FATAL @ ../tt_metal/impl/buffers/buffer.cpp:41: page_size % sizeof(uint32_t) == 0 ((5, 
3, 2), 0, 2), + ((1, 4150, 192), 1, 2), ((5, 3, 1), 0, 2), ((5, 3, 1), 1, 2), ((5, 3, 1), 0, 1), diff --git a/tests/lowering/tensor_manipulation/test_unsqueeze.py b/tests/lowering/tensor_manipulation/test_unsqueeze.py index cf825c4dd..903c70f10 100644 --- a/tests/lowering/tensor_manipulation/test_unsqueeze.py +++ b/tests/lowering/tensor_manipulation/test_unsqueeze.py @@ -17,7 +17,18 @@ def forward(self, x, y): @pytest.mark.parametrize( "input_shape, dim", - [((5, 2, 4, 3), 1)], + [ + pytest.param( + (5, 2, 4, 3), + 1, + marks=pytest.mark.xfail(reason="Fails if output is > 4D, using TILE_LAYOUT, and W dim is >= 32."), + ), + pytest.param( + (50, 1, 3, 1024), + 0, + marks=pytest.mark.xfail(reason="Fails if output is > 4D, using TILE_LAYOUT, and W dim is >= 32."), + ), + ], ) def test_unsqueeze1(device, input_shape, dim): mod = UnsqueezeModule() @@ -64,7 +75,13 @@ def test_unsqueeze2(device, input_shape, dim): @pytest.mark.parametrize( "input_shape, dim", - [((5, 2, 4, 3), -2)], + [ + pytest.param( + (5, 2, 4, 3), + -2, + marks=pytest.mark.xfail(reason="Fails if output is > 4D, using TILE_LAYOUT, and W dim is >= 32."), + ) + ], ) def test_unsqueeze3(device, input_shape, dim): mod = UnsqueezeModule() diff --git a/tests/lowering/tensor_manipulation/test_view.py b/tests/lowering/tensor_manipulation/test_view.py index 67725c765..eeb6939e3 100644 --- a/tests/lowering/tensor_manipulation/test_view.py +++ b/tests/lowering/tensor_manipulation/test_view.py @@ -80,6 +80,31 @@ def forward(self, x, new_shape): ((256, 4096), (1, 256, 4096)), ((1, 32, 16, 96), (1, 32, 1536)), ((1, 192, 4150), (1, 192, 50, 83)), + ((1, 100, 192), (100, 192)), + ((1, 1445, 192), (1, 1445, 3, 64)), + ((1, 1445, 192), (1445, 192)), + ((1, 1445, 3, 64), (1, 1445, 192)), + ((1, 1445, 768), (1445, 768)), + ((1, 192, 32, 42), (1, 192, 1344)), + ((1, 3, 1445, 1445), (3, 1445, 1445)), + ((1, 3, 1445, 64), (3, 1445, 64)), + ((1, 3, 64, 1445), (3, 64, 1445)), + ((100, 192), (1, 100, 192)), + ((100, 4), (1, 100, 4)), + ((100, 92), (1, 100, 92)), + ((1445, 192), (1, 1445, 192)), + ((1445, 768), (1, 1445, 768)), + ((192,), (1, 192, 1, 1)), + ((1,), (1, 1, 1, 1)), + ((3, 1445, 1445), (1, 3, 1445, 1445)), + ((3, 1445, 64), (1, 3, 1445, 64)), + ((32,), (1, 1, 32, 1)), + ((42,), (1, 1, 1, 42)), + pytest.param( + (1, 10), + (10,), + marks=pytest.mark.xfail(reason="Does not support TILE_LAYOUT."), + ), ], ) def test_reshape(device, input_shape, new_shape, module): diff --git a/tools/collect_metrics.py b/tools/collect_metrics.py index f0e369819..89b396088 100644 --- a/tools/collect_metrics.py +++ b/tools/collect_metrics.py @@ -221,8 +221,6 @@ def _join_br(str_list: list): opname = op["opname"] inputs = _join_br(op["inputs"]) self[opname][inputs] - if opname == "aten.cat.default": - print(op) # If exist, map converted ops to the original op if compiled_schema_metrics: # Hold ops that require revisiting the original dict to determine the status @@ -232,8 +230,6 @@ def _join_br(str_list: list): opname = op["opname"] original_opname = op["original_inputs"]["opname"] original_inputs = _join_br(op["original_inputs"]["inputs"]) - if opname == "aten.cat.default": - print(op) # NOTE(kevinwuTT): Some ttnn ops are wrapped, so they have no `ttnn` prefix. Should this be more strict?
if opname != original_opname: # Some aten ops are converted to other aten ops diff --git a/torch_ttnn/handle_input_aliasing.py b/torch_ttnn/handle_input_aliasing.py index 8cb87c2f4..29ae78abf 100644 --- a/torch_ttnn/handle_input_aliasing.py +++ b/torch_ttnn/handle_input_aliasing.py @@ -1,6 +1,6 @@ import torch from typing import List -from torch_ttnn.utils import GraphCleanup +from torch_ttnn.utils import graph_cleanup """ AOT Autograd has an optimization where if it determines that the storage of the @@ -25,22 +25,23 @@ # torch.fx defines a placeholder node as a function input -def get_input_nodes(gm: torch.fx.GraphModule) -> List[torch.fx.Node]: - input_nodes = [node for node in gm.graph.nodes if (node.op == "placeholder")] +def get_nodes_with_op(gm: torch.fx.GraphModule, op: str) -> List[torch.fx.Node]: + input_nodes = [node for node in gm.graph.nodes if (node.op == op)] return input_nodes # Insert aten.clone nodes after every input to prevent input aliasing def insert_clones_for_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - input_nodes = get_input_nodes(gm) + placeholders = get_nodes_with_op(gm, "placeholder") modified = False - for node in input_nodes: - """TODO(kevinwuTT): This does not work if inserting right after the node itself. - Only works if inserting after all of the input_nodes. - TypeError: forward() missing `n` required positional arguments - Somehow the argument list will get truncated. + + for node in placeholders: + """NOTE: Torch assumes placeholder nodes are laid out consecutively at this stage. + If we insert nodes in between, the list of input arguments will be truncated. + We will get this error: `TypeError: forward() missing `n` required positional arguments`. + Workaround is to insert nodes after the last placeholder node. 
""" - with gm.graph.inserting_after(input_nodes[-1]): + with gm.graph.inserting_after(placeholders[-1]): clone_node = gm.graph.call_function(torch.ops.aten.clone.default, args=(node,)) node.replace_all_uses_with( clone_node, @@ -48,8 +49,19 @@ def insert_clones_for_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.Graph ) modified = True + get_attrs = get_nodes_with_op(gm, "get_attr") + for node in get_attrs: + example_value = node.meta.get("example_value") + if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor): + with gm.graph.inserting_after(node): + clone_node = gm.graph.call_function(torch.ops.aten.clone.default, args=(node,)) + node.replace_all_uses_with( + clone_node, + delete_user_cb=lambda node: node != clone_node, + ) + modified = True if modified: - gm = GraphCleanup(gm) + gm = graph_cleanup(gm) return gm @@ -76,6 +88,6 @@ def remove_clones_for_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.Graph modified = True if modified: - gm = GraphCleanup(gm) + gm = graph_cleanup(gm) return gm diff --git a/torch_ttnn/metrics.py b/torch_ttnn/metrics.py index 253a06a91..3224e6438 100644 --- a/torch_ttnn/metrics.py +++ b/torch_ttnn/metrics.py @@ -125,6 +125,8 @@ def collect_input_variation(target, args, kwargs): inputs.append(Inputs(dtype, name, shape, value)) return InputVariation(target, inputs) + elif hasattr(target, "python_fully_qualified_name"): + return InputVariation(target.python_fully_qualified_name, []) else: return None diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py index 0050ff3a3..3edbbb1f2 100644 --- a/torch_ttnn/passes/lowering/add_data_move_pass.py +++ b/torch_ttnn/passes/lowering/add_data_move_pass.py @@ -6,7 +6,7 @@ TtnnDevice, TtnnBfloat16, TtnnUint32, - HasValidPageSize, + get_shape, ) @@ -73,6 +73,7 @@ def is_function_call(node) -> bool: TTNN_POINTWISE_BINARY_OPS = [ ttnn.add, + ttnn.div, ttnn.eqz, ttnn.gez, ttnn.ge, @@ -146,8 +147,8 @@ def is_function_call(node) -> bool: TTNN_LAYOUT_CHANGE_OPS = set( [ - ttnn.reshape, ttnn.slice, + ttnn.full, ] ) @@ -291,8 +292,8 @@ def is_target_a_user_of_curr_node(curr_node, target): if curr_node.target == target: return True - # Only trace certain nodes that support different layouts - if curr_node.target not in TTNN_LAYOUT_CHANGE_OPS: + # Only trace certain nodes that support different layouts, including reshape + if curr_node.target not in TTNN_LAYOUT_CHANGE_OPS.union(set([ttnn.reshape])): return False for user in list(curr_node.users.keys()): @@ -328,7 +329,13 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N g = dst_node.graph new_nodes = list() with g.inserting_before(dst_node): - kwargs = {"layout": TtnnTileLayout(), "device": device} + kwargs = {} + if dst_node.target == ttnn.embedding: + kwargs["layout"] = TtnnRowMajorLayout() + else: + kwargs["layout"] = TtnnTileLayout() + + kwargs["device"] = device if is_target_a_user_of_curr_node(dst_node, ttnn.embedding) and dst_idx == 0: kwargs["dtype"] = TtnnUint32() @@ -351,7 +358,6 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node, device) -> to need_from_device = False need_to_layout = False need_to_device = False - # TODO(#372): #322 will enable tile layout for more layout change ops if dst_node.target in TTNN_LAYOUT_CHANGE_OPS and dst_idx == 0 and is_tt(src_node): need_from_device = True need_to_layout = True @@ -366,6 +372,7 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node, device) -> to return None g = dst_node.graph 
+ new_nodes = [] with g.inserting_before(dst_node): new_nodes = [src_node] if need_from_device: @@ -384,23 +391,32 @@ def try_add_layout_change_after_node(src_node, dst_idx, dst_node, device) -> tor # Consider src_node is ttnn.repeat, and dst_node should be any tt_compute node that uses ttnn.repeat if not is_function_call(src_node): return None - if ( - src_node.target not in TTNN_LAYOUT_CHANGE_OPS.union(set([target_wrappers.repeat])) - or not is_tt_compute(dst_node) - or dst_node.target == ttnn.embedding - ): + + need_from_device = False + need_to_layout = False + need_to_device = False + if src_node.target in TTNN_LAYOUT_CHANGE_OPS and is_tt(dst_node): + need_to_device = True + need_to_layout = True + + # These nodes use ROW_MAJOR_LAYOUT to create tensors + if src_node.target in [ttnn.ones, target_wrappers.repeat]: + need_to_layout = True + + if not any((need_from_device, need_to_layout, need_to_device)): return None g = dst_node.graph new_nodes = [] with g.inserting_before(dst_node): - if dst_node.target != target_wrappers.repeat: - new_nodes.append( - g.call_function(ttnn.to_layout, (new_nodes[-1] if new_nodes else src_node, TtnnTileLayout())) - ) - new_nodes.append( - g.call_function(ttnn.to_device, (new_nodes[-1] if new_nodes else src_node,), {"device": device}) - ) + new_nodes = [src_node] + if need_from_device: + new_nodes.append(g.call_function(ttnn.from_device, (new_nodes[-1],))) + if need_to_layout: + new_nodes.append(g.call_function(ttnn.to_layout, (new_nodes[-1], TtnnTileLayout()))) + if need_to_device: + new_nodes.append(g.call_function(ttnn.to_device, (new_nodes[-1],), {"device": device})) + new_nodes = new_nodes[1:] insert_node_between(src_node, dst_idx, dst_node, new_nodes) diff --git a/torch_ttnn/passes/lowering/to_tt_guard.py b/torch_ttnn/passes/lowering/to_tt_guard.py index 06e5a7526..ff306e4ef 100644 --- a/torch_ttnn/passes/lowering/to_tt_guard.py +++ b/torch_ttnn/passes/lowering/to_tt_guard.py @@ -214,6 +214,30 @@ aten_mul_Tensor_blocklist += [["Tensor<[1, 1]> self = ?", "Tensor other = 50258"]] +# albert inputs that do not support TILE_LAYOUT +aten_view_default_blocklist += [ + # albert-base-v2-eval, albert-large-v2-eval + ["Tensor<[9, 30000]> self = ?", "List[int] size = [1, 9, 30000]"], + # albert-xlarge-v2-eval + ["Tensor<[9, 8192]> self = ?", "List[int] size = [1, 9, 8192]"], + ["Tensor<[1, 9, 8192]> self = ?", "List[int] size = [9, 8192]"], + # albert-xxlarge-v2-eval + ["Tensor<[9, 16384]> self = ?", "List[int] size = [1, 9, 16384]"], + ["Tensor<[1, 9, 16384]> self = ?", "List[int] size = [9, 16384]"], +] + +# Falcon inputs that do not support TILE_LAYOUT +aten_view_default_blocklist += [ + ["Tensor<[7, 18176]> self = ?", "List[int] size = [1, 7, 18176]"], + ["Tensor<[1, 7, 18176]> self = ?", "List[int] size = [7, 18176]"], +] + +# XGLM inputs that do not support TILE_LAYOUT +aten_view_default_blocklist += [ + ["Tensor<[19, 256008]> self = ?", "List[int] size = [1, 19, 256008]"], + ["Tensor<[1, 19, 256008]> self = ?", "List[int] size = [-1, 256008]"], +] + ############################################################ # EXTRA BLOCKLIST OF GPTNeo diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py index 906f0cc9b..2934e8959 100644 --- a/torch_ttnn/passes/lowering/to_tt_pass.py +++ b/torch_ttnn/passes/lowering/to_tt_pass.py @@ -2,12 +2,13 @@ import ttnn import math from torch_ttnn.utils import ( - GraphCleanup, - HasValidPageSize, + graph_cleanup, TtnnBfloat16, TtnnDevice, TtnnL1MemoryConfig, TtnnRowMajorLayout, + 
has_valid_page_size, + get_shape, TtnnTileLayout, ) import numpy as np @@ -138,6 +139,9 @@ def __init__(self, target, args, kwargs): pseudo_node = PseudoNode(target, args, kwargs) if not can_lowering_to_ttnn(pseudo_node): + # Fallback: aten.reshape is more stable if the input nodes have changed + if target == torch.ops.aten.view.default or target == torch.ops.aten._unsafe_view.default: + target = torch.ops.aten.reshape.default return self.call_function_prop_meta(target, args, kwargs) if are_args_from_int_output_ops(args) or is_target_incompatible_with_grayskull(target, self.device): @@ -394,6 +398,9 @@ def is_zero_dim(meta): # assumes output size is (1, 1) return self.call_function_prop_meta(ttnn.global_avg_pool2d, (args[0],), kwargs) + if target == torch.ops.aten.clone.default: + return args[0] + return self.call_function_prop_meta(target, args, kwargs) @@ -439,15 +446,6 @@ def rewrite_node(node): args = node.args kwargs = node.kwargs - if node.target == torch.ops.aten.clone.default: - arg_metadata = node.meta["val"] - try: - ttnn_dtype = torch_dtype_to_ttnn_dtype(arg_metadata.dtype) - except: - return None - # Add additional logic to choose the appropriate memory_config type: DRAM or L1 - return g.call_function(target_wrappers.clone, args=(args[0],)) - if node.target == torch.ops.aten.native_layer_norm.default: new_node = g.call_function( ttnn.layer_norm, @@ -490,13 +488,13 @@ def rewrite_node(node): # Instead, fill a tensor with the same size as args[0] with the scalar value using ttnn.full # NOTE(jdh8): after broadcasting support is complete, we should fill a (1,) tensor arg_metadata = node.meta["val"] - if HasValidPageSize(arg_metadata.size(), strict=True): + if has_valid_page_size(arg_metadata.size(), strict=True): new_kwargs = { "fill_value": args[1], "device": TtnnDevice(), - "layout": TtnnTileLayout(), } full_node = g.call_function(ttnn.full, args=(arg_metadata.size(),), kwargs=new_kwargs) + full_node = g.call_function(ttnn.to_layout, (full_node, TtnnTileLayout()), {}) return g.call_function( relational_scalar_ops[node.target], args=(args[0], full_node), @@ -520,9 +518,9 @@ def rewrite_node(node): new_kwargs = { "fill_value": args[1], "device": TtnnDevice(), - "layout": TtnnTileLayout(), } - return g.call_function(ttnn.full, args=(tuple(args[0]),), kwargs=new_kwargs) + full = g.call_function(ttnn.full, args=(tuple(args[0]),), kwargs=new_kwargs) + return g.call_function(ttnn.to_layout, (full, TtnnTileLayout()), {}) # Replace op with scalar for eltwise ops # TODO: Generalize this to support all eltwise ops node_users = list(node.users.keys()) @@ -569,52 +567,17 @@ def rewrite_node(node): return g.call_function(ttnn.log, (softmax_node,), kwargs) if node.target == torch.ops.aten.rsub.Scalar: - # NOTE(kevinwuTT): ttnn.sub shows error if passing a literal scalar as the first argument. 
- # Instead, fill a tensor with the same size as args[0] with the scalar value using ttnn.full - node_metadata = node.meta["val"] - shape = node_metadata.size() - # If last dim == 1, then the follow error will appear: - # Page size must be divisible by sizeof(uint32_t) because buffers hold uint32_t values - if shape[-1] != 1 and HasValidPageSize(shape): - # NOTE(kevinwuTT): Only bfloat16 seems to work for now - # TODO(kevinwuTT): Use ttnn.full instead of aten - new_kwargs = { - "fill_value": args[1], - "device": TtnnDevice(), - } - full = g.call_function( - ttnn.full, - args=(tuple(shape),), - kwargs=new_kwargs, - ) - to_layout = g.call_function(ttnn.to_layout, (full,), {"layout": TtnnTileLayout()}) - return g.call_function(ttnn.sub, args=(to_layout, args[0]), kwargs={}) - return None + # aten.rsub(tensor, scalar) = sub(scalar, tensor) + # However, ttnn.sub does not support a scalar as the first argument + # Instead: ttnn.add(ttnn.neg(tensor), scalar) + ttnn_neg = g.call_function(ttnn.neg, (args[0],)) + return g.call_function(ttnn.add, (ttnn_neg, args[1])) if node.target == torch.ops.aten.div.Tensor: - # ttnn.recip does not support scalars. Call an ttnn.full and pass that to ttnn.recip - # TODO(kevinwuTT): Use a ttnn equivalent - node_metadata = node.meta["val"] - shape = node_metadata.size() - # If last dim == 1, then the follow error will appear: - # Page size must be divisible by sizeof(uint32_t) because buffers hold uint32_t values - if shape[-1] != 1 and HasValidPageSize(shape): - if isinstance(args[1], float): - new_kwargs = { - "fill_value": args[1], - "device": TtnnDevice(), - } - full = g.call_function( - ttnn.full, - args=(tuple(shape),), - kwargs=new_kwargs, - ) - to_layout = g.call_function(ttnn.to_layout, (full,), {"layout": TtnnTileLayout()}) - recip = g.call_function(ttnn.reciprocal, (to_layout,), {}) - else: - recip = g.call_function(ttnn.reciprocal, (args[1],), {}) + if not isinstance(args[1], float) and (get_shape(args[0]) != get_shape(args[1])): + recip = g.call_function(ttnn.reciprocal, (args[1],), {}) return g.call_function(ttnn.mul, (args[0], recip), {}) - return None + return g.call_function(ttnn.div, args, {}) if node.target == torch.ops.aten.expand.default: input_tensor_shape = args[0].meta["val"].size() @@ -691,11 +654,17 @@ def rewrite_node(node): return None if node.target == torch.ops.aten.squeeze.dim or node.target == torch.ops.aten.squeeze.default: + if len(get_shape(args[0])) > 4: + return None if use_less_ttnn_op_types or node.target == torch.ops.aten.squeeze.default: # ttnn.squeeze does not support calling the OP without provided dim (torch.ops.aten.squeeze.default) # squeezing is the same as reshaping to shape of output tensor of squeeze output_size = list(node.meta["val"].size()) - return g.call_function(ttnn.reshape, args=(args[0], output_size)) + # FIXME: Reshape has issues with 1D outputs for TILE_LAYOUT + if len(output_size) > 1: + return g.call_function(ttnn.reshape, args=(args[0], output_size)) + else: + return None else: return g.call_function(ttnn.squeeze, args=(args[0], args[1])) @@ -704,6 +673,9 @@ def rewrite_node(node): if output_shape_num_element == 0: return args[0] output_size = list(node.meta["val"].size()) + # FIXME: ttnn.reshape does not support > 4D outputs with TILE_LAYOUT currently + if len(output_size) > 4: + return None return g.call_function(ttnn.reshape, args=(args[0], output_size)) if node.target in [torch.ops.aten.transpose.int, torch.ops.aten.t.default]: @@ -773,10 +745,17 @@ def rewrite_node(node): return 
g.call_function(ttnn.pad, args=(input, full_pad, value)) if node.target in [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default]: input_tensor_num_element = args[0].meta["val"].numel() output_shape_num_element = node.meta["val"].numel() - if input_tensor_num_element == 0 or output_shape_num_element == 0: - return None + # Fall back to aten.reshape if the tensor is empty, either rank is above 4, or the output rank is below 2 + if ( + input_tensor_num_element == 0 + or output_shape_num_element == 0 + or len(get_shape(args[0])) > 4 + or len(args[1]) > 4 + or len(args[1]) < 2 + ): + return g.call_function(torch.ops.aten.reshape.default, args, kwargs) return g.call_function(ttnn.reshape, (args[0], args[1]), {}) if node.target == torch.ops.aten.split.Tensor: @@ -799,15 +779,16 @@ def rewrite_node(node): return g.call_function(ttnn.split, args=new_args) if node.target == torch.ops.aten._to_copy.default: + src_dtype = node.args[0].meta["val"].dtype + dst_dtype = kwargs["dtype"] # Keep it if casting to bool type(bool may be problematic) - if kwargs["dtype"] in [torch.bool]: + if dst_dtype in [torch.bool]: return None - # Keep it if the graph output uses this op + + # Keep it if the graph output uses this op, unless it's bfloat target_users_ops = [user.target for user in node.users.keys()] if "output" in target_users_ops: return None - src_dtype = node.args[0].meta["val"].dtype - dst_dtype = kwargs["dtype"] # Some aten op need it to cast specific dtype (ex, index_select) # Keep it if casting from int to float or reverse @@ -819,8 +800,9 @@ def rewrite_node(node): if dst_dtype in [torch.int32, torch.int64] and src_dtype not in [torch.int32, torch.int64]: return None - if src_dtype in [torch.int32, torch.int64] and dst_dtype not in [torch.int32, torch.int64]: - return None + # if src_dtype in [torch.int32, torch.int64] and dst_dtype not in [torch.int32, torch.int64]: + # return None + # Essentially remove this op return node.args[0] @@ -839,12 +821,13 @@ def rewrite_node(node): multiplier = np_tensor_shp // np_mask_shp mask_bcst = g.call_function(target_wrappers.repeat, args=(mask, multiplier.tolist())) - kwargs = {"dtype": TtnnBfloat16(), "layout": TtnnTileLayout(), "device": TtnnDevice()} + kwargs = {"dtype": TtnnBfloat16(), "layout": TtnnRowMajorLayout(), "device": TtnnDevice()} ones = g.call_function(ttnn.ones, (tensor_shape,), kwargs) mask_flip = g.call_function(ttnn.subtract, (ones, mask_bcst)) tensor_masked = g.call_function(ttnn.multiply, (tensor, mask_flip)) full = g.call_function(ttnn.full, (tensor_shape, fill_value), kwargs) + full = g.call_function(ttnn.to_layout, (full, TtnnTileLayout()), {}) full_masked = g.call_function(ttnn.multiply, (mask_bcst, full)) masked_fill = g.call_function(ttnn.add, (tensor_masked, full_masked)) @@ -869,7 +852,10 @@ def rewrite_node(node): if len(output_size) == 0: return g.call_function(torch.ops.aten.squeeze.dim, args=(tensor, 0)) - return g.call_function(ttnn.reshape, args=(slice_tensor, list(output_size))) + if len(input_size) > 4 or len(output_size) > 4 or len(output_size) == 1: + return g.call_function(torch.ops.aten.reshape.default, args=(slice_tensor, list(output_size))) + else: + return g.call_function(ttnn.reshape, args=(slice_tensor, list(output_size))) if node.target == torch.ops.aten.cumsum.default: tensor, dim = args @@ -970,7 +956,7 @@ def rewrite_node(node): delete_user_cb=lambda node: node != new_node, ) - gm = GraphCleanup(gm) + gm = graph_cleanup(gm) return gm diff --git a/torch_ttnn/utils.py b/torch_ttnn/utils.py index 7a921ae6a..0e7d04ef0 
100644 --- a/torch_ttnn/utils.py +++ b/torch_ttnn/utils.py @@ -1,7 +1,7 @@ import torch -def GraphCleanup(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: +def graph_cleanup(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: gm.graph.eliminate_dead_code() gm.graph.lint() gm.recompile() @@ -9,12 +9,23 @@ def GraphCleanup(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: return gm +def get_shape(node_or_shape): + if isinstance(node_or_shape, torch.fx.node.Node): + if (val := node_or_shape.meta.get("val", None)) is not None: + return val.size() + elif isinstance(node_or_shape, torch.Size) or isinstance(node_or_shape, list): + return node_or_shape + + return None + + # Certain ops don't support certain shapes and will emit a valid_page_size error # RuntimeError: TT_FATAL @ ../tt_metal/impl/buffers/buffer.cpp:38: valid_page_size # For valid non-interleaved buffers page size 2048 must equal buffer size X. For interleaved-buffers page size should be divisible by buffer size -def HasValidPageSize(shape, strict=False): - if len(shape) >= 2 and shape[-1] > 0: - return shape[-1] % 32 == 0 or (not strict and shape[-1] < 32) +def has_valid_page_size(node_or_shape, strict=False): + if (shape := get_shape(node_or_shape)) is not None: + if len(shape) >= 2 and len(shape) <= 4 and shape[-1] > 0: + return shape[-1] % 32 == 0 or (not strict and shape[-1] < 32) return False
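
For reference, below is a minimal standalone sketch of how the renamed get_shape / has_valid_page_size helpers above behave. The function bodies mirror the definitions added to torch_ttnn/utils.py in this patch; the example shapes and asserts are illustrative assumptions (not part of the patch) and only require a stock torch install.

import torch

# Standalone copies of the helpers added in torch_ttnn/utils.py, so their
# behavior can be checked without a TT device or the rest of the package.
def get_shape(node_or_shape):
    # An fx node carries its faked output tensor under meta["val"]; plain shapes pass through.
    if isinstance(node_or_shape, torch.fx.node.Node):
        if (val := node_or_shape.meta.get("val", None)) is not None:
            return val.size()
    elif isinstance(node_or_shape, torch.Size) or isinstance(node_or_shape, list):
        return node_or_shape
    return None

def has_valid_page_size(node_or_shape, strict=False):
    # Only 2D-4D shapes with a positive last dim can satisfy the page-size constraint;
    # the last dim must be a multiple of 32, or smaller than 32 when strict is False.
    if (shape := get_shape(node_or_shape)) is not None:
        if len(shape) >= 2 and len(shape) <= 4 and shape[-1] > 0:
            return shape[-1] % 32 == 0 or (not strict and shape[-1] < 32)
    return False

# Illustrative expectations (example shapes are assumptions):
assert get_shape(torch.Size([1, 32, 16])) == torch.Size([1, 32, 16])
assert get_shape("neither a node nor a shape") is None
assert has_valid_page_size(torch.Size([4, 64]))                 # last dim % 32 == 0
assert has_valid_page_size(torch.Size([4, 4]))                  # < 32 allowed when not strict
assert not has_valid_page_size(torch.Size([4, 4]), strict=True)
assert not has_valid_page_size(torch.Size([2, 3, 4, 5, 32]))    # rank > 4 rejected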