diff --git a/tests/conftest.py b/tests/conftest.py
index 692fe7e8..f75cda95 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,6 +8,7 @@
 @pytest.fixture(autouse=True)
 def run_around_tests():
+    torch._dynamo.config.inline_inbuilt_nn_modules = False
     torch.manual_seed(0)
     yield
     torch._dynamo.reset()
 
diff --git a/tt_torch/dynamo/passes.py b/tt_torch/dynamo/passes.py
index 312513ea..3fc345f2 100644
--- a/tt_torch/dynamo/passes.py
+++ b/tt_torch/dynamo/passes.py
@@ -77,19 +77,35 @@ def apply_decompositions(
     if decompositions is None:
         return gm
 
-    gm = make_fx(
-        functionalize(gm),
-        decomposition_table=decompositions,
-    )(*example_inputs)
+    with torch.no_grad():
+        decompositions = get_decompositions(decompose_ops)
+        fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(example_inputs)
+        fake_tensor_mode.allow_non_fake_inputs = True
+        gm = make_fx(
+            gm,
+            tracing_mode="symbolic",
+            _allow_non_fake_inputs=True,
+            decomposition_table=decompositions,
+        )(*example_inputs)
 
     return gm
 
 
+def bypass_redundant_getitem(gm):
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and "getitem" in node.name:
+            if isinstance(node.args[0], tuple):
+                idx = node.args[1]
+                if isinstance(idx, int):
+                    node.replace_all_uses_with(node.args[0][idx])
+    return gm
+
+
 def constant_fold(gm, example_inputs):
     gm = const_fold.split_const_subgraphs(gm)
+    gm.run_folding()
+
     graph_constants = []
-    # run the module to generate the consteval constants
-    _ = gm(*example_inputs)
     for node in gm.graph.nodes:
         if node.op == "get_attr" and node.name == "_fx_const_folded_attrs":
             gm.graph.inserting_before(node)
@@ -97,14 +113,20 @@
             if isinstance(gm._FX_CONST_FOLDED_ATTRS, torch.Tensor):
                 placeholder = gm.graph.placeholder(f"_fx_const_folded_attrs")
                 node.replace_all_uses_with(placeholder)
+                if len(gm._FX_CONST_FOLDED_ATTRS.data.shape) == 0:
+                    gm._FX_CONST_FOLDED_ATTRS.data = (
+                        gm._FX_CONST_FOLDED_ATTRS.data.unsqueeze(0)
+                    )
                 graph_constants = [gm._FX_CONST_FOLDED_ATTRS.data]
             else:
-                for idx, (key, value) in enumerate(node.users.items()):
+                for idx, key in enumerate(node.users.keys()):
                     placeholder = gm.graph.placeholder(f"_fx_const_folded_attrs_{idx}")
                     key.replace_all_uses_with(placeholder)
+                for param in gm._FX_CONST_FOLDED_ATTRS:
+                    if len(param.data.shape) == 0:
+                        param.data = param.data.unsqueeze(0)
                 graph_constants = [param.data for param in gm._FX_CONST_FOLDED_ATTRS]
 
-    gm.graph.eliminate_dead_code()
     return gm, graph_constants
 
 
@@ -115,6 +137,7 @@ def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config):
         gm, graph_constants = constant_fold(gm, example_inputs)
     else:
         graph_constants = []
+    gm = bypass_redundant_getitem(gm)
     reduce_graph(gm)
     run_shape_prop(gm, example_inputs + graph_constants)
     return gm, graph_constants
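
A minimal sketch of the retrace that apply_decompositions now performs: make_fx
re-traces a callable and rewrites every op that has an entry in the decomposition
table. The callable, the op choice (aten.silu), and the input below are
illustrative, not taken from this patch.

    # Illustrative only: make_fx + a decomposition table, as used above.
    import torch
    from torch._decomp import get_decompositions
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return torch.nn.functional.silu(x)

    # aten.silu has a registered decomposition (x * sigmoid(x)), so the
    # retraced graph contains sigmoid/mul nodes instead of a single silu op.
    decompositions = get_decompositions([torch.ops.aten.silu])
    with torch.no_grad():
        gm = make_fx(f, decomposition_table=decompositions)(torch.randn(4))
    print(gm.graph)

On top of this, the patch traces with tracing_mode="symbolic" and reuses the
ambient FakeTensorMode with _allow_non_fake_inputs=True, so the real example
inputs coming from dynamo are tolerated during tracing.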
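
bypass_redundant_getitem forwards the users of a getitem node straight to the
indexed element whenever the node indexes a literal tuple of graph nodes with a
constant int; later dead-code elimination (reduce_graph in the pipeline) drops
the bypassed node. A self-contained sketch on a hand-built FX graph, contrived
purely for illustration:

    # Illustrative only: a hand-built graph containing the redundant pattern.
    import operator
    import torch
    import torch.fx

    g = torch.fx.Graph()
    x = g.placeholder("x")
    y = g.placeholder("y")
    # getitem over a literal tuple of nodes -- the case the pass targets.
    item = g.call_function(operator.getitem, ((x, y), 0))
    out = g.call_function(torch.add, (item, y))
    g.output(out)
    gm = torch.fx.GraphModule(torch.nn.Module(), g)

    for node in gm.graph.nodes:
        if node.op == "call_function" and "getitem" in node.name:
            if isinstance(node.args[0], tuple) and isinstance(node.args[1], int):
                node.replace_all_uses_with(node.args[0][node.args[1]])
    gm.graph.eliminate_dead_code()
    gm.recompile()
    print(gm.code)  # add now consumes x directly; the getitem node is gone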
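
In constant_fold, split_const_subgraphs (torch.fx.experimental.const_fold)
splits the subgraph that depends only on constants into a submodule, and the
newly added run_folding() call executes it once, storing the results on the
module as _FX_CONST_FOLDED_ATTRS; the loop in the patch then swaps that
attribute for graph placeholders, unsqueezing 0-d tensors so every folded
constant has at least rank 1. A small sketch with an illustrative module (not
from this repo):

    # Illustrative only: folding a constant-only subexpression out of a module.
    import torch
    import torch.fx
    from torch.fx.experimental import const_fold

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.randn(4, 4))

        def forward(self, x):
            return x @ (self.w + 1.0)  # (self.w + 1.0) is constant-foldable

    gm = const_fold.split_const_subgraphs(torch.fx.symbolic_trace(M()))
    gm.run_folding()  # materializes gm._FX_CONST_FOLDED_ATTRS, as in constant_fold
    print(gm._FX_CONST_FOLDED_ATTRS.shape)  # the folded (w + 1) tensor, (4, 4)
    print(gm(torch.randn(2, 4)).shape)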