diff --git a/docs/src/controlling.md b/docs/src/controlling.md index aceed17d..38cd10b3 100644 --- a/docs/src/controlling.md +++ b/docs/src/controlling.md @@ -5,7 +5,8 @@ You can use the following environment variables to override default behaviour: | Environment Variable | Behaviour | Default | | -------------------- | --------- | -------- | TT_TORCH_COMPILE_DEPTH | Sets the maximum compile depth, see `tt_torch/tools/utils.py` for options. | `EXECUTE` | -| TT_TORCH_VERIFY_INTERMEDIATES | Sets whether to verify intermediate tensors against pytorch when running with compile depth `EXECUTE_OP_BY_OP`. | False | +| TT_TORCH_VERIFY_OP_BY_OP | Sets whether to verify the output of each compiled op against pytorch when running with compile depth `EXECUTE_OP_BY_OP`. | False | +| TT_TORCH_VERIFY_INTERMEDIATES | Sets whether to verify runtime intermediates during execution. | False | | TT_TORCH_CONSTEVAL | Enables evaluation of constant expressions (consteval) in the Torch FX graph prior to compilation. | False | | TT_TORCH_CONSTEVAL_PARAMETERS | Extends consteval to include parameters (e.g., model weights) as well as embedded constants. 
| False | | TT_TORCH_EMBEDDEDD_CONSTANTS | Remove embedded constants from the Torch FX graph and convert them to constant inputs | False | diff --git a/tt_torch/csrc/CMakeLists.txt b/tt_torch/csrc/CMakeLists.txt index ca96ad28..b8703c68 100644 --- a/tt_torch/csrc/CMakeLists.txt +++ b/tt_torch/csrc/CMakeLists.txt @@ -106,7 +106,12 @@ set(CMAKE_PREFIX_PATH ${TORCH_INSTALL_PREFIX}) find_package(Torch REQUIRED) find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") set(TARGET_NAME tt_mlir) + pybind11_add_module(${TARGET_NAME} bindings.cpp) +add_compile_definitions(TTMLIR_ENABLE_STABLEHLO=1) +if (${TT_RUNTIME_DEBUG} MATCHES "ON") + add_compile_definitions(TT_RUNTIME_DEBUG=1) +endif() add_dependencies(${TARGET_NAME} TT_TORCH_MLIR diff --git a/tt_torch/csrc/bindings.cpp b/tt_torch/csrc/bindings.cpp index 4ed62afa..71e8af7b 100644 --- a/tt_torch/csrc/bindings.cpp +++ b/tt_torch/csrc/bindings.cpp @@ -3,6 +3,12 @@ // SPDX-License-Identifier: Apache-2.0 #include "tt-mlir-interface.hpp" +#include "tt/runtime/types.h" +#include +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 +#include "tt/runtime/detail/debug.h" +#include "tt/runtime/runtime.h" +#endif #include #include @@ -217,4 +223,49 @@ PYBIND11_MODULE(tt_mlir, m) { "Get the number of available devices"); m.def("bytestream_to_json", &bytestream_to_json, "Convert the bytestream to json"); + +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 + py::class_(m, "CallbackContext"); + py::class_(m, "OpContext"); + py::class_(m, "TensorDesc") + .def_readonly("shape", &tt::runtime::TensorDesc::shape) + .def_readonly("stride", &tt::runtime::TensorDesc::stride) + .def_readonly("itemsize", &tt::runtime::TensorDesc::itemsize) + .def_readonly("dataType", &tt::runtime::TensorDesc::dataType); + m.def("get_op_output_tensor", &tt::runtime::getOpOutputTensor); + m.def("get_op_debug_str", &tt::runtime::getOpDebugString, + "Get the debug string of the op"); + m.def("get_op_loc_info", 
&tt::runtime::getOpLocInfo, + "Get the location info of the op"); + py::class_(m, "DebugHooks") + .def_static( + "get_debug_hooks", + [](py::function func) { + return tt::runtime::debug::Hooks::get( + [func](tt::runtime::Binary binary, + tt::runtime::CallbackContext programContext, + tt::runtime::OpContext opContext) { + func(binary, programContext, opContext); + }); + }, + "Get the debug hooks") + .def("__str__", [](const tt::runtime::debug::Hooks &hooks) { + std::stringstream os; + os << hooks; + return os.str(); + }); + + /** + * Cleanup code to force a well ordered destruction w.r.t. the GIL + */ + auto cleanup_callback = []() { + tt::runtime::debug::Hooks::get().unregisterHooks(); + }; + m.add_object("_cleanup", py::capsule(cleanup_callback)); + m.def("unregister_hooks", + []() { tt::runtime::debug::Hooks::get().unregisterHooks(); }); + m.def("is_runtime_debug_enabled", []() -> bool { return true; }); +#else + m.def("is_runtime_debug_enabled", []() -> bool { return false; }); +#endif } diff --git a/tt_torch/csrc/tt-mlir-interface.cpp b/tt_torch/csrc/tt-mlir-interface.cpp index e7474eb2..1e9b7c84 100644 --- a/tt_torch/csrc/tt-mlir-interface.cpp +++ b/tt_torch/csrc/tt-mlir-interface.cpp @@ -33,7 +33,6 @@ #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/transforms/Passes.h" // from @stablehlo -#define TTMLIR_ENABLE_STABLEHLO #include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h" #include "ttmlir/Dialect/TTIR/Transforms/Passes.h" #include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h" @@ -92,7 +91,7 @@ std::string compileStableHLOToTTIR(std::string_view code) { } std::string buffer; llvm::raw_string_ostream os(buffer); - mlir_module.get()->print(os); + mlir_module.get()->print(os, mlir::OpPrintingFlags().enableDebugInfo()); os.flush(); return buffer; @@ -150,7 +149,7 @@ compileTTIRToTTNN(std::string_view code) { std::string buffer; llvm::raw_string_ostream os(buffer); - mlir_module->print(os); + mlir_module->print(os, 
mlir::OpPrintingFlags().enableDebugInfo()); os.flush(); return std::make_tuple(binary, buffer); diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index ce12a750..9385bdd7 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -19,8 +19,10 @@ ) import tt_mlir -from torch_mlir.ir import Context -from torch_mlir.extras.fx_importer import FxImporter +from tt_mlir import is_runtime_debug_enabled +import torch_mlir +from torch_mlir.ir import Context, Location +from torch_mlir.extras.fx_importer import FxImporter, ContextCache from torch_mlir.dialects import torch as torch_dialect @@ -29,7 +31,7 @@ run_pipeline_with_repro_report, lower_mlir_module, ) -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Optional import os import multiprocessing as mp import time @@ -39,10 +41,18 @@ import tempfile +class TTContextCache(ContextCache): + def get_node_location(self, node: torch.fx.Node) -> Optional[Location]: + return Location.name(node.name, context=self._c) + + def import_graph(graph: torch.fx.GraphModule): context = Context() torch_dialect.register_dialect(context) importer = FxImporter(context=context) + importer._cc = TTContextCache( + importer._c, py_attr_tracker=importer._py_attr_tracker + ) importer.import_stateless_graph(graph) return importer.module @@ -105,6 +115,14 @@ def __init__( # Dictionary to keep track of the type conversion for unsupported hardware # types and use it to convert the input arguments to supported types. self.type_conversion = {torch.bool: torch.bfloat16} + self.intermediate_callbacks = {} + + def register_intermediate_callback(self, callback): + if not is_runtime_debug_enabled(): + raise RuntimeError( + "Runtime debug is required to use intermediate callbacks. Please recompile this project with -DTT_RUNTIME_DEBUG=ON." 
+            ) +        tt_mlir.DebugHooks.get_debug_hooks(callback) def set_binary(self, binary): self.binary = binary @@ -383,7 +401,7 @@ def run_gm_op_by_op(self, *inputs): raise ValueError("Failed to execute") op.compilation_status = OpCompilationStatus.EXECUTED tensor = node.target(*args, **node.kwargs) - if self.compiler_config.enable_intermediate_verification: + if self.compiler_config.verify_op_by_op: atol = calculate_atol(calculated, tensor) op.atol = atol if atol > self.required_atol: @@ -454,6 +472,16 @@ def __call__(self, *inputs): return self.gm(*inputs) +def verify_golden_callback(binary, callback_context, op_context): + # Using these parameters, we should be able to query information + # about the op described by op_context, and its output. I.e. location: + location = tt_mlir.get_op_loc_info(op_context) + # ... + + # We will need to provide the bindings necessary in this frontend. + # Those bindings will interact with the runtime API + + def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config): # Apply environment overrides at start of compilation to allow overriding what was set in the test compiler_config.apply_environment_overrides() @@ -492,12 +520,19 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config): if compiler_config.compile_depth == CompileDepth.STABLEHLO: return executor - ttir = tt_mlir.compile_stable_hlo_to_ttir(module.operation.get_asm()) + # Need to set enable_debug_info=True to get the location information for the ops in the asm string + ttir = tt_mlir.compile_stable_hlo_to_ttir( + module.operation.get_asm(enable_debug_info=True) + ) if dump_intermediates: print("TTIR module", file=sys.stderr) print(ttir, file=sys.stderr) + if compiler_config.enable_intermediate_verification: + executor.register_intermediate_callback(verify_golden_callback) + binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir) + if dump_intermediates: print("TTNN module", file=sys.stderr) print(ttnn, file=sys.stderr) diff --git 
a/tt_torch/tools/utils.py b/tt_torch/tools/utils.py index e92cc788..6c968263 100644 --- a/tt_torch/tools/utils.py +++ b/tt_torch/tools/utils.py @@ -11,6 +11,8 @@ import math import sys +from tt_mlir import is_runtime_debug_enabled + class CompileDepth(Enum): TORCH_FX = 1 @@ -129,14 +131,47 @@ def __init__(self): self.model_name = "" self.results_path = "results/models/" self.single_op_timeout = 5 - self.enable_intermediate_verification = False self.enable_consteval = False self.remove_embedded_constants = False self._consteval_parameters = False + self._enable_intermediate_verification = False + self._verify_op_by_op = False self.apply_environment_overrides() self.post_init() + @property + def verify_op_by_op(self): + return self._verify_op_by_op + + @verify_op_by_op.setter + def verify_op_by_op(self, value): + assert isinstance( + value, bool + ), "verify_op_by_op must be a boolean" + if value and self.compile_depth != CompileDepth.EXECUTE_OP_BY_OP: + print( + "WARNING: Setting verify_op_by_op to True but compile_depth is not set to EXECUTE_OP_BY_OP. This CompilerConfig flag will have no effect." + ) + self._verify_op_by_op = value + + @property + def enable_intermediate_verification(self): + return self._enable_intermediate_verification + + @enable_intermediate_verification.setter + def enable_intermediate_verification(self, value): + assert isinstance( + value, bool + ), "enable_intermediate_verification must be a boolean" + + if value and not is_runtime_debug_enabled(): + raise RuntimeError( + "attempting to set enable_intermediate_verification to True but tt_mlir was not built with runtime debug enabled. Rebuild this project with -DTT_RUNTIME_DEBUG=ON if you wish to verify intermediate results." 
+            ) + + self._enable_intermediate_verification = value + + @property def consteval_parameters(self): return self._consteval_parameters @@ -150,6 +185,9 @@ def apply_environment_overrides(self): compile_depth = os.environ.get("TT_TORCH_COMPILE_DEPTH") if compile_depth: self.compile_depth = CompileDepth[compile_depth] + verify_op_by_op = os.environ.get("TT_TORCH_VERIFY_OP_BY_OP") + if verify_op_by_op and int(verify_op_by_op): + self.verify_op_by_op = True verify_intermediates = os.environ.get("TT_TORCH_VERIFY_INTERMEDIATES") if verify_intermediates and int(verify_intermediates): self.enable_intermediate_verification = True