diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py
index 8329bb5a..a0b16d53 100644
--- a/tt_torch/dynamo/backend.py
+++ b/tt_torch/dynamo/backend.py
@@ -599,7 +599,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
         verify_ir(module)
 
     if dump_info:
-        print("Torch module", file=sys.stderr)
+        print("Torch FX module", file=sys.stderr)
         module.dump()
 
     if compiler_config.profile_ops:
@@ -607,7 +607,19 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
     if compiler_config.compile_depth == CompileDepth.TORCH_MLIR:
         return executor
 
-    lower_to_stable_hlo(module, enable_ir_printing=dump_debug)
+    run_pipeline_with_repro_report(
+        module,
+        f"builtin.module(torchdynamo-export-to-torch-backend-pipeline)",
+        "Lowering TorchFX IR -> Torch Backend IR",
+        dump_debug,
+    )
+
+    if dump_info:
+        print("Torch Backend module", file=sys.stderr)
+        module.dump()
+
+    lower_mlir_module(False, OutputType.STABLEHLO, module)
+
     if dump_info:
         print("StableHLO module", file=sys.stderr)
         module.dump()
diff --git a/tt_torch/onnx_compile/onnx_compile.py b/tt_torch/onnx_compile/onnx_compile.py
index da1024bb..953d86c5 100644
--- a/tt_torch/onnx_compile/onnx_compile.py
+++ b/tt_torch/onnx_compile/onnx_compile.py
@@ -6,6 +6,8 @@
 import tt_mlir
 from torch_mlir.ir import Context
 from torch_mlir.dialects import torch as torch_dialect
+import os
+import sys
 
 from torch_mlir.compiler_utils import (
     OutputType,
@@ -25,10 +27,45 @@ def compile_onnx(module: onnx.ModelProto):
     imp = onnx_importer.NodeImporter.define_function(module_info.main_graph, module)
     imp.import_all()
 
+    dump_intermediates = os.environ.get("TT_TORCH_IR_LOG_LEVEL")
+    dump_info = False
+    dump_debug = False
+    if dump_intermediates:
+        dump_debug = dump_intermediates == "DEBUG"
+        dump_info = dump_debug or dump_intermediates == "INFO"
+
+    # Setting large_elements_limit to 0 so the console does not get flooded with the data of large tensors
+    if dump_info:
+        print("ONNX module", file=sys.stderr)
+        module.print(large_elements_limit=0)
+
     run_pipeline_with_repro_report(
         module,
         "builtin.module(torch-onnx-to-torch-backend-pipeline)",
         "Lowering Torch Onnx IR -> Torch Backend IR",
     )
+
+    if dump_info:
+        print("Torch Backend module", file=sys.stderr)
+        module.print(large_elements_limit=0)
+
     lower_mlir_module(False, OutputType.STABLEHLO, module)
-    return tt_mlir.compile(module.operation.get_asm())
+
+    if dump_info:
+        print("StableHLO module", file=sys.stderr)
+        module.print(large_elements_limit=0)
+
+    # Need to set enable_debug_info=True to get the location information for the ops in the asm string
+    ttir = tt_mlir.compile_stable_hlo_to_ttir(
+        module.operation.get_asm(enable_debug_info=True)
+    )
+    if dump_info:
+        print("TTIR module", file=sys.stderr)
+        print(ttir, file=sys.stderr)
+
+    binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
+    if dump_info:
+        print("TTNN module", file=sys.stderr)
+        print(ttnn, file=sys.stderr)
+
+    return binary