Custom ContextCache for setting location info
Add bindings for the runtime API.

Ensure that location information is dumped into the module string if
TT_RUNTIME_DEBUG=ON.

Add cleanup bindings to ensure well-ordered destruction.

Register callbacks via the Executor.
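
A minimal sketch of the intended flow, assuming a build configured with -DTT_RUNTIME_DEBUG=ON (`executor` stands in for the Executor created in tt_torch/dynamo/backend.py; the callback signature mirrors the pybind lambda in bindings.cpp below):

```python
import tt_mlir

def my_callback(binary, program_context, op_context):
    # Fires once per executed op while debug hooks are registered.
    print(tt_mlir.get_op_loc_info(op_context))

# Raises RuntimeError on builds without runtime debug enabled.
executor.register_intermediate_callback(my_callback)
```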
LPanosTT committed Jan 27, 2025
1 parent 2256f17 commit 5a3d70c
Showing 6 changed files with 139 additions and 10 deletions.
3 changes: 2 additions & 1 deletion docs/src/controlling.md
@@ -5,7 +5,8 @@ You can use the following environment variables to override default behaviour:
| Environment Variable | Behaviour | Default |
| -------------------- | --------- | --------
| TT_TORCH_COMPILE_DEPTH | Sets the maximum compile depth, see `tt_torch/tools/utils.py` for options. | `EXECUTE` |
| TT_TORCH_VERIFY_INTERMEDIATES | Sets whether to verify intermediate tensors against pytorch when running with compile depth `EXECUTE_OP_BY_OP`. | False |
| TT_TORCH_VERIFY_OP_BY_OP | Sets whether to verify the output of each compiled op against pytorch when running with compile depth `EXECUTE_OP_BY_OP`. | False |
| TT_TORCH_VERIFY_INTERMEDIATES | Sets whether to verify runtime intermediates during execution. | False |
| TT_TORCH_CONSTEVAL | Enables evaluation of constant expressions (consteval) in the Torch FX graph prior to compilation. | False |
| TT_TORCH_CONSTEVAL_PARAMETERS | Extends consteval to include parameters (e.g., model weights) as well as embedded constants. | False |
| TT_TORCH_EMBEDDEDD_CONSTANTS | Remove embedded constants from the Torch FX graph and convert them to constant inputs | False |
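
For illustration, these toggles are plain environment variables; a hypothetical way to set them from Python before invoking the backend (the 1/0 parsing matches `apply_environment_overrides` in `tt_torch/tools/utils.py` below):

```python
import os

os.environ["TT_TORCH_COMPILE_DEPTH"] = "EXECUTE_OP_BY_OP"
os.environ["TT_TORCH_VERIFY_OP_BY_OP"] = "1"       # per-op check against PyTorch
os.environ["TT_TORCH_VERIFY_INTERMEDIATES"] = "1"  # requires a TT_RUNTIME_DEBUG=ON build
```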
5 changes: 5 additions & 0 deletions tt_torch/csrc/CMakeLists.txt
@@ -106,7 +106,12 @@ set(CMAKE_PREFIX_PATH ${TORCH_INSTALL_PREFIX})
find_package(Torch REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib")
set(TARGET_NAME tt_mlir)

pybind11_add_module(${TARGET_NAME} bindings.cpp)
add_compile_definitions(TTMLIR_ENABLE_STABLEHLO=1)
if (${TT_RUNTIME_DEBUG} MATCHES "ON")
    add_compile_definitions(TT_RUNTIME_DEBUG=1)
endif()

add_dependencies(${TARGET_NAME}
    TT_TORCH_MLIR
51 changes: 51 additions & 0 deletions tt_torch/csrc/bindings.cpp
@@ -3,6 +3,12 @@
// SPDX-License-Identifier: Apache-2.0

#include "tt-mlir-interface.hpp"
#include "tt/runtime/types.h"
#include <optional>
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
#include "tt/runtime/detail/debug.h"
#include "tt/runtime/runtime.h"
#endif
#include <pybind11/pybind11.h>
#include <torch/extension.h>

@@ -217,4 +223,49 @@ PYBIND11_MODULE(tt_mlir, m) {
"Get the number of available devices");
m.def("bytestream_to_json", &bytestream_to_json,
"Convert the bytestream to json");

#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
py::class_<tt::runtime::CallbackContext>(m, "CallbackContext");
py::class_<tt::runtime::OpContext>(m, "OpContext");
py::class_<tt::runtime::TensorDesc>(m, "TensorDesc")
.def_readonly("shape", &tt::runtime::TensorDesc::shape)
.def_readonly("stride", &tt::runtime::TensorDesc::stride)
.def_readonly("itemsize", &tt::runtime::TensorDesc::itemsize)
.def_readonly("dataType", &tt::runtime::TensorDesc::dataType);
m.def("get_op_output_tensor", &tt::runtime::getOpOutputTensor);
m.def("get_op_debug_str", &tt::runtime::getOpDebugString,
"Get the debug string of the op");
m.def("get_op_loc_info", &tt::runtime::getOpLocInfo,
"Get the location info of the op");
py::class_<tt::runtime::debug::Hooks>(m, "DebugHooks")
.def_static(
"get_debug_hooks",
[](py::function func) {
return tt::runtime::debug::Hooks::get(
[func](tt::runtime::Binary binary,
tt::runtime::CallbackContext programContext,
tt::runtime::OpContext opContext) {
func(binary, programContext, opContext);
});
},
"Get the debug hooks")
.def("__str__", [](const tt::runtime::debug::Hooks &hooks) {
std::stringstream os;
os << hooks;
return os.str();
});

/**
* Cleanup code to force a well ordered destruction w.r.t. the GIL
*/
auto cleanup_callback = []() {
tt::runtime::debug::Hooks::get().unregisterHooks();
};
m.add_object("_cleanup", py::capsule(cleanup_callback));
m.def("unregister_hooks",
[]() { tt::runtime::debug::Hooks::get().unregisterHooks(); });
m.def("is_runtime_debug_enabled", []() -> bool { return true; });
#else
m.def("is_runtime_debug_enabled", []() -> bool { return false; });
#endif
}
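
With these bindings, a debug-enabled build can register and tear down hooks explicitly from Python; a minimal sketch (`get_op_debug_str` and `get_op_loc_info` take only the op context, mirroring the C++ signatures bound above):

```python
import tt_mlir

if tt_mlir.is_runtime_debug_enabled():
    # The hook receives (binary, program_context, op_context) for each executed op.
    tt_mlir.DebugHooks.get_debug_hooks(
        lambda binary, pc, oc: print(tt_mlir.get_op_debug_str(oc))
    )
    # ... run the compiled program ...
    tt_mlir.unregister_hooks()  # same teardown the _cleanup capsule forces at interpreter exit
```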
5 changes: 2 additions & 3 deletions tt_torch/csrc/tt-mlir-interface.cpp
@@ -33,7 +33,6 @@
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "stablehlo/transforms/Passes.h" // from @stablehlo

#define TTMLIR_ENABLE_STABLEHLO
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
@@ -92,7 +91,7 @@ std::string compileStableHLOToTTIR(std::string_view code) {
  }
  std::string buffer;
  llvm::raw_string_ostream os(buffer);
  mlir_module.get()->print(os);
  mlir_module.get()->print(os, mlir::OpPrintingFlags().enableDebugInfo());
  os.flush();

  return buffer;
@@ -150,7 +149,7 @@ compileTTIRToTTNN(std::string_view code) {

  std::string buffer;
  llvm::raw_string_ostream os(buffer);
  mlir_module->print(os);
  mlir_module->print(os, mlir::OpPrintingFlags().enableDebugInfo());
  os.flush();

  return std::make_tuple(binary, buffer);
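Printing with `enableDebugInfo()` is what lets the FX node names injected by `TTContextCache` (below) survive into the dumped IR string. A purely illustrative before/after; the op and location names here are hypothetical:

```python
# Without enableDebugInfo():  %0 = "ttir.add"(%arg0, %arg1) : ...
# With enableDebugInfo():     %0 = "ttir.add"(%arg0, %arg1) : ... loc("add_1")
# where "add_1" is the torch.fx node name set via Location.name in backend.py.
```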
45 changes: 40 additions & 5 deletions tt_torch/dynamo/backend.py
@@ -19,8 +19,10 @@
)

import tt_mlir
from torch_mlir.ir import Context
from torch_mlir.extras.fx_importer import FxImporter
from tt_mlir import is_runtime_debug_enabled
import torch_mlir
from torch_mlir.ir import Context, Location
from torch_mlir.extras.fx_importer import FxImporter, ContextCache

from torch_mlir.dialects import torch as torch_dialect

@@ -29,7 +31,7 @@
    run_pipeline_with_repro_report,
    lower_mlir_module,
)
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Optional
import os
import multiprocessing as mp
import time
@@ -39,10 +41,18 @@
import tempfile


class TTContextCache(ContextCache):
    def get_node_location(self, node: torch.fx.Node) -> Optional[Location]:
        return Location.name(node.name, context=self._c)


def import_graph(graph: torch.fx.GraphModule):
    context = Context()
    torch_dialect.register_dialect(context)
    importer = FxImporter(context=context)
    importer._cc = TTContextCache(
        importer._c, py_attr_tracker=importer._py_attr_tracker
    )
    importer.import_stateless_graph(graph)
    return importer.module

@@ -105,6 +115,14 @@ def __init__(
        # Dictionary to keep track of the type conversion for unsupported hardware
        # types and use it to convert the input arguments to supported types.
        self.type_conversion = {torch.bool: torch.bfloat16}
        self.intermediate_callbacks = {}

    def register_intermediate_callback(self, callback):
        if not is_runtime_debug_enabled():
            raise RuntimeError(
                "Runtime debug is required to use intermediate callbacks. Please recompile this project with -DTT_RUNTIME_DEBUG=ON."
            )
        tt_mlir.DebugHooks.get_debug_hooks(callback)

    def set_binary(self, binary):
        self.binary = binary
@@ -383,7 +401,7 @@ def run_gm_op_by_op(self, *inputs):
                raise ValueError("Failed to execute")
            op.compilation_status = OpCompilationStatus.EXECUTED
            tensor = node.target(*args, **node.kwargs)
            if self.compiler_config.enable_intermediate_verification:
            if self.compiler_config.verify_op_by_op:
                atol = calculate_atol(calculated, tensor)
                op.atol = atol
                if atol > self.required_atol:
@@ -454,6 +472,16 @@ def __call__(self, *inputs):
        return self.gm(*inputs)


def verify_golden_callback(binary, callback_context, op_context):
    # Using these parameters, we should be able to query information
    # about the op described by op_context, and its output, e.g. the location:
    location = tt_mlir.get_op_loc_info(op_context)
    # ...

    # We will need to provide the necessary bindings in this frontend.
    # Those bindings will interact with the runtime API.
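
A sketch of where this placeholder could head, using only bindings added in this commit; the argument order of `get_op_output_tensor` and the `golden_tensors` lookup are assumptions, not wired up here:

```python
golden_tensors = {}  # hypothetical {location string: torch.Tensor}, populated at compile time

def verify_golden_callback_sketch(binary, callback_context, op_context):
    loc = tt_mlir.get_op_loc_info(op_context)
    golden = golden_tensors.get(loc)
    if golden is None:
        return
    # Assumed argument order, mirroring the C++ runtime signature.
    output = tt_mlir.get_op_output_tensor(op_context, callback_context)
    # ... compare output against golden, e.g. with calculate_atol(output, golden)
```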


def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
    # Apply environment overrides at start of compilation to allow overriding what was set in the test
    compiler_config.apply_environment_overrides()
@@ -492,12 +520,19 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
    if compiler_config.compile_depth == CompileDepth.STABLEHLO:
        return executor

    ttir = tt_mlir.compile_stable_hlo_to_ttir(module.operation.get_asm())
    # Need to set enable_debug_info=True to get the location information for the ops in the asm string
    ttir = tt_mlir.compile_stable_hlo_to_ttir(
        module.operation.get_asm(enable_debug_info=True)
    )
    if dump_intermediates:
        print("TTIR module", file=sys.stderr)
        print(ttir, file=sys.stderr)

    if compiler_config.enable_intermediate_verification:
        executor.register_intermediate_callback(verify_golden_callback)

    binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)

    if dump_intermediates:
        print("TTNN module", file=sys.stderr)
        print(ttnn, file=sys.stderr)
40 changes: 39 additions & 1 deletion tt_torch/tools/utils.py
@@ -11,6 +11,8 @@
import math
import sys

from tt_mlir import is_runtime_debug_enabled


class CompileDepth(Enum):
    TORCH_FX = 1
@@ -129,14 +131,47 @@ def __init__(self):
        self.model_name = ""
        self.results_path = "results/models/"
        self.single_op_timeout = 5
        self.enable_intermediate_verification = False
        self.enable_consteval = False
        self.remove_embedded_constants = False
        self._consteval_parameters = False
        self._enable_intermediate_verification = False
        self._verify_op_by_op = False

        self.apply_environment_overrides()
        self.post_init()

    @property
    def verify_op_by_op(self):
        return self._verify_op_by_op

    @verify_op_by_op.setter
    def verify_op_by_op(self, value):
        assert isinstance(value, bool), "verify_op_by_op must be a boolean"
        if value and self.compile_depth != CompileDepth.EXECUTE_OP_BY_OP:
            print(
                "WARNING: Setting verify_op_by_op to True but compile_depth is not set to EXECUTE_OP_BY_OP. This CompilerConfig flag will have no effect."
            )
        self._verify_op_by_op = value

    @property
    def enable_intermediate_verification(self):
        return self._enable_intermediate_verification

    @enable_intermediate_verification.setter
    def enable_intermediate_verification(self, value):
        assert isinstance(
            value, bool
        ), "enable_intermediate_verification must be a boolean"

        if value and not is_runtime_debug_enabled():
            raise RuntimeError(
                "Attempting to set enable_intermediate_verification to True but tt_mlir was not built with runtime debug enabled. Rebuild this project with -DTT_RUNTIME_DEBUG=ON if you wish to verify intermediate results."
            )

        self._enable_intermediate_verification = value

    @property
    def consteval_parameters(self):
        return self._consteval_parameters
@@ -150,6 +185,9 @@ def apply_environment_overrides(self):
        compile_depth = os.environ.get("TT_TORCH_COMPILE_DEPTH")
        if compile_depth:
            self.compile_depth = CompileDepth[compile_depth]
        verify_op_by_op = os.environ.get("TT_TORCH_VERIFY_OP_BY_OP")
        if verify_op_by_op and int(verify_op_by_op):
            self.verify_op_by_op = True
        verify_intermediates = os.environ.get("TT_TORCH_VERIFY_INTERMEDIATES")
        if verify_intermediates and int(verify_intermediates):
            self.enable_intermediate_verification = True
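
Equivalently, the new flags can be set directly on a `CompilerConfig`; a short sketch of the behaviour encoded by the setters above:

```python
cc = CompilerConfig()
cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP
cc.verify_op_by_op = True                   # warns if compile_depth is not EXECUTE_OP_BY_OP
cc.enable_intermediate_verification = True  # raises unless built with -DTT_RUNTIME_DEBUG=ON
```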
