From e9874e6552a10ed1c369a9ffe0a56819e6bc7c8d Mon Sep 17 00:00:00 2001 From: Tapasvi Patel Date: Mon, 13 Jan 2025 18:13:06 +0000 Subject: [PATCH] Added TT_TORCH_ENABLE_IR_DUMP env var to save all intermediate mlir dialect dumps - torch.FX, stablehlo, ttir, ttnn --- README.md | 1 + tt_torch/dynamo/backend.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/README.md b/README.md index 1fd8f03d..6d3bd191 100644 --- a/README.md +++ b/README.md @@ -40,3 +40,4 @@ You can use the following environment variables to override default behaviour: | TT_TORCH_CONSTEVAL_PARAMETERS | Extends consteval to include parameters (e.g., model weights) as well as embedded constants. | False | | TT_TORCH_EMBEDDEDD_CONSTANTS | Remove embedded constants from the Torch FX graph and convert them to constant inputs | False | | TT_TORCH_ENABLE_IR_PRINTING | Enables printing MLIR for all conversion steps from StableHLO to TTNN. Be warned, this forces single core compile, so is much slower. | False | +| TT_TORCH_ENABLE_IR_DUMP | Enables dumping MLIR for all dialect conversion steps : [torch.FX, stablehlo, ttir, ttnn] | False | diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 6c43e78c..cbc12bef 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -470,6 +470,18 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config): dump_intermediates = os.environ.get("TT_TORCH_ENABLE_IR_PRINTING") dump_intermediates = dump_intermediates and int(dump_intermediates) + save_intermediates = os.environ.get('TT_TORCH_ENABLE_IR_DUMP') + save_intermediates = save_intermediates and int(save_intermediates) + + if save_intermediates: + with open("gm_graph.txt", 'w') as file: + original_stdout = sys.stdout + sys.stdout = file + try: + gm.graph.print_tabular() + finally: + sys.stdout = original_stdout + module = import_graph(gm.graph) if dump_intermediates: module.dump() @@ -483,6 +495,10 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config): if dump_intermediates: module.dump() + if save_intermediates: + with open("stablehlo.mlir", "w") as file: + print(module, file=file) + if compiler_config.profile_ops: compiler_config.set_stablehlo_mlir_module(module.operation.get_asm()) if compiler_config.compile_depth == CompileDepth.STABLEHLO: @@ -491,10 +507,19 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config): ttir = tt_mlir.compile_stable_hlo_to_ttir(module.operation.get_asm()) if dump_intermediates: print(ttir) + + if save_intermediates: + with open("ttir.mlir", "w") as file: + print(ttir, file=file) + binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir) if dump_intermediates: print(ttnn) + if save_intermediates: + with open("ttnn.mlir", "w") as file: + print(ttnn, file=file) + executor.set_binary(binary) return executor