Skip to content

Commit

Permalink
Added TT_TORCH_ENABLE_IR_DUMP env var to save all intermediate mlir d…
Browse files Browse the repository at this point in the history
…ialect dumps - torch.FX, stablehlo, ttir, ttnn
  • Loading branch information
tapspatel committed Jan 13, 2025
1 parent 4853ccb commit e9874e6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ You can use the following environment variables to override default behaviour:
| TT_TORCH_CONSTEVAL_PARAMETERS | Extends consteval to include parameters (e.g., model weights) as well as embedded constants. | False |
| TT_TORCH_EMBEDDEDD_CONSTANTS | Remove embedded constants from the Torch FX graph and convert them to constant inputs | False |
| TT_TORCH_ENABLE_IR_PRINTING | Enables printing MLIR for all conversion steps from StableHLO to TTNN. Be warned, this forces single core compile, so is much slower. | False |
| TT_TORCH_ENABLE_IR_DUMP | Enables dumping MLIR for all dialect conversion steps : [torch.FX, stablehlo, ttir, ttnn] | False |
25 changes: 25 additions & 0 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,18 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
dump_intermediates = os.environ.get("TT_TORCH_ENABLE_IR_PRINTING")
dump_intermediates = dump_intermediates and int(dump_intermediates)

save_intermediates = os.environ.get('TT_TORCH_ENABLE_IR_DUMP')
save_intermediates = save_intermediates and int(save_intermediates)

if save_intermediates:
with open("gm_graph.txt", 'w') as file:
original_stdout = sys.stdout
sys.stdout = file
try:
gm.graph.print_tabular()
finally:
sys.stdout = original_stdout

module = import_graph(gm.graph)
if dump_intermediates:
module.dump()
Expand All @@ -483,6 +495,10 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
if dump_intermediates:
module.dump()

if save_intermediates:
with open("stablehlo.mlir", "w") as file:
print(module, file=file)

if compiler_config.profile_ops:
compiler_config.set_stablehlo_mlir_module(module.operation.get_asm())
if compiler_config.compile_depth == CompileDepth.STABLEHLO:
Expand All @@ -491,10 +507,19 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
ttir = tt_mlir.compile_stable_hlo_to_ttir(module.operation.get_asm())
if dump_intermediates:
print(ttir)

if save_intermediates:
with open("ttir.mlir", "w") as file:
print(ttir, file=file)

binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
if dump_intermediates:
print(ttnn)

if save_intermediates:
with open("ttnn.mlir", "w") as file:
print(ttnn, file=file)

executor.set_binary(binary)
return executor

Expand Down

0 comments on commit e9874e6

Please sign in to comment.