Enable TT_TORCH_IR_LOG_LEVEL to dump intermediate IRs when compiling onnx models.
LPanosTT committed Feb 27, 2025
1 parent 186ad28 commit f726e7a
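
A minimal usage sketch of the new variable on the ONNX path (not part of this commit); the import path and the model file name below are assumptions, and all IR dumps go to stderr:

import os
os.environ["TT_TORCH_IR_LOG_LEVEL"] = "INFO"  # "DEBUG" implies the INFO-level dumps as well

import onnx
from tt_torch.onnx_compile.onnx_compile import compile_onnx  # assumed import path, per the file changed below

model = onnx.load("model.onnx")  # placeholder ONNX model
binary = compile_onnx(model)     # prints the ONNX, Torch Backend, StableHLO, TTIR and TTNN modules to stderr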
Showing 2 changed files with 52 additions and 3 deletions.
16 changes: 14 additions & 2 deletions tt_torch/dynamo/backend.py
@@ -599,15 +599,27 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
     verify_ir(module)
 
     if dump_info:
-        print("Torch module", file=sys.stderr)
+        print("Torch FX module", file=sys.stderr)
         module.dump()
 
     if compiler_config.profile_ops:
         compiler_config.set_torch_mlir_module(module.operation.get_asm())
     if compiler_config.compile_depth == CompileDepth.TORCH_MLIR:
         return executor
 
-    lower_to_stable_hlo(module, enable_ir_printing=dump_debug)
+    run_pipeline_with_repro_report(
+        module,
+        f"builtin.module(torchdynamo-export-to-torch-backend-pipeline)",
+        "Lowering TorchFX IR -> Torch Backend IR",
+        dump_debug,
+    )
+
+    if dump_info:
+        print("Torch Backend module", file=sys.stderr)
+        module.dump()
+
+    lower_mlir_module(False, OutputType.STABLEHLO, module)
+
     if dump_info:
         print("StableHLO module", file=sys.stderr)
         module.dump()
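
For the torch.compile path touched in backend.py, the same variable gates the Torch FX, Torch Backend and StableHLO dumps. A hedged sketch, assuming tt_torch/dynamo/backend.py exposes a public backend entry point (only the environment-variable behaviour comes from this commit):

import os
os.environ["TT_TORCH_IR_LOG_LEVEL"] = "DEBUG"  # INFO dumps the IRs; DEBUG additionally enables IR printing in the lowering pipeline

import torch
from tt_torch.dynamo.backend import backend  # assumed public wrapper around _base_backend

model = torch.nn.Linear(32, 32)                   # toy model for illustration
compiled = torch.compile(model, backend=backend)
compiled(torch.randn(4, 32))                      # first call compiles and dumps the intermediate IRs to stderr
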
39 changes: 38 additions & 1 deletion tt_torch/onnx_compile/onnx_compile.py
@@ -6,6 +6,8 @@
 import tt_mlir
 from torch_mlir.ir import Context
 from torch_mlir.dialects import torch as torch_dialect
+import os
+import sys
 
 from torch_mlir.compiler_utils import (
     OutputType,
@@ -25,10 +27,45 @@ def compile_onnx(module: onnx.ModelProto):
     imp = onnx_importer.NodeImporter.define_function(module_info.main_graph, module)
     imp.import_all()
 
+    dump_intermediates = os.environ.get("TT_TORCH_IR_LOG_LEVEL")
+    dump_info = False
+    dump_debug = False
+    if dump_intermediates:
+        dump_debug = dump_intermediates == "DEBUG"
+        dump_info = dump_debug or dump_intermediates == "INFO"
+
+    # Setting large_elements_limit to 0 so the console does not get flooded with the data of large tensors
+    if dump_info:
+        print("ONNX module", file=sys.stderr)
+        module.print(large_elements_limit=0)
+
     run_pipeline_with_repro_report(
         module,
         "builtin.module(torch-onnx-to-torch-backend-pipeline)",
         "Lowering Torch Onnx IR -> Torch Backend IR",
     )
+
+    if dump_info:
+        print("Torch Backend module", file=sys.stderr)
+        module.print(large_elements_limit=0)
+
     lower_mlir_module(False, OutputType.STABLEHLO, module)
-    return tt_mlir.compile(module.operation.get_asm())
+
+    if dump_info:
+        print("StableHLO module", file=sys.stderr)
+        module.print(large_elements_limit=0)
+
+    # Need to set enable_debug_info=True to get the location information for the ops in the asm string
+    ttir = tt_mlir.compile_stable_hlo_to_ttir(
+        module.operation.get_asm(enable_debug_info=True)
+    )
+    if dump_info:
+        print("TTIR module", file=sys.stderr)
+        print(ttir, file=sys.stderr)
+
+    binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
+    if dump_info:
+        print("TTNN module", file=sys.stderr)
+        print(ttnn, file=sys.stderr)
+
+    return binary
