diff --git a/tt_torch/csrc/bindings.cpp b/tt_torch/csrc/bindings.cpp index 9b61b409..87ebe80f 100644 --- a/tt_torch/csrc/bindings.cpp +++ b/tt_torch/csrc/bindings.cpp @@ -239,6 +239,9 @@ PYBIND11_MODULE(tt_mlir, m) { .def_readonly("stride", &tt::runtime::TensorDesc::stride) .def_readonly("itemsize", &tt::runtime::TensorDesc::itemsize) .def_readonly("dataType", &tt::runtime::TensorDesc::dataType); + + py::class_(m, "Tensor"); + m.def("get_op_output_tensor", &tt::runtime::getOpOutputTensor); m.def("get_op_debug_str", &tt::runtime::getOpDebugString, "Get the debug string of the op"); diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 8329bb5a..3a61a475 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -7,6 +7,7 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch._functorch.compile_utils import strip_overloads import operator +import pdb from tt_torch.dynamo.passes import pass_pipeline from tt_torch.tools.utils import ( @@ -18,6 +19,7 @@ calculate_atol, calculate_pcc, ) +from tt_torch.tools.verify import verify_against_golden import tt_mlir from tt_mlir import is_runtime_debug_enabled @@ -57,6 +59,50 @@ def get_node_location(self, node: torch.fx.Node) -> Optional[Location]: return Location.name(node.name, context=self._c) +class RuntimeIntermediate: + def __init__(self, node: torch.fx.Node, golden): + self.node = node + self.golden = golden # may be a tuple of tensors + + # each fxnode can be decomposed into multiple ttnn ops. 
+ # we store all their intermediate outputs here + # TODO - Need a way to uniquely reference ttnn intermediates + self.decomposed_intermediate_outputs = [] + self.pcc = None + self.atol = None + self.passed_pcc = False + self.passed_atol = False + + def calculate_metrics(self): + # calculate the metrics for the golden tensor after all decomposition steps done + + if ( + len(self.decomposed_intermediate_outputs) == 0 + and self.node.op == "call_function" + ): + return # getitem_4 has no intermediates? - if there are no intermediates for a call_function node; what to do + assert False, f"No decomposed intermediates found for {self.node.name}" + + final_decomposed_output = self.decomposed_intermediate_outputs[ + -1 + ] # could be a tuple of tensors + + # verify_against_golden expects a tuple of tensors as inputs. need to preprocess + if not isinstance(final_decomposed_output, tuple): + final_decomposed_output = (final_decomposed_output,) + if not isinstance(self.golden, tuple): + self.golden = (self.golden,) + + ( + self.passed_pcc, + self.passed_atol, + _, + _, + self.pcc, + self.atol, + ) = verify_against_golden(self.golden, final_decomposed_output, 0.99, 1e-2) + + def import_graph(graph: torch.fx.GraphModule): context = Context() torch_dialect.register_dialect(context) @@ -157,6 +203,10 @@ def __init__( self.stderror_redirected = False self.file_stderr = None + # Dictionary to track the intermediate golden values for each torchFX op + # string(fxnode_name) : runtimeCacheEntry + self.runtime_intermediate_cache = {} + def register_intermediate_callback(self, callback): if not is_runtime_debug_enabled(): raise RuntimeError( @@ -433,7 +483,7 @@ def run_op(self, binary, *inputs): return outputs, stderr_data - def run_gm_op_by_op(self, *inputs): + def run_gm_op_by_op(self, *inputs, cache_intermediate_goldens=False): node_to_tensor = {} input_index = 0 outputs = [] @@ -472,6 +522,10 @@ def run_gm_op_by_op(self, *inputs): binary = None print(f"Failed to compile 
{idx}/{num_nodes}: {node.target}: {e}") + if cache_intermediate_goldens: + golden = node.target(*args, **node.kwargs) + cache_entry = RuntimeIntermediate(node, golden) + self.runtime_intermediate_cache[node.name] = cache_entry if ( self.compiler_config.compile_depth == CompileDepth.EXECUTE_OP_BY_OP and binary is not None @@ -488,7 +542,9 @@ def run_gm_op_by_op(self, *inputs): op.compilation_status = OpCompilationStatus.EXECUTED tensor = node.target(*args, **node.kwargs) if self.compiler_config.verify_op_by_op: - atol = calculate_atol(calculated, tensor) + atol = calculate_atol( + calculated, tensor + ) # how does this work? the tensor/calculated must be unpacked properly. op.atol = atol if atol > self.required_atol: print(f"atol too high for {idx}: {atol}") @@ -504,6 +560,7 @@ def run_gm_op_by_op(self, *inputs): else: tensor = node.target(*args, **node.kwargs) node_to_tensor[node] = tensor + elif node.op == "output": args = node.args[0] output_tensors = [node_to_tensor[arg] for arg in args] @@ -529,6 +586,21 @@ def run_gm_op_by_op(self, *inputs): return outputs + def verify_intermediates_after_execution(self): + pdb.set_trace() + for _, intermediate in self.runtime_intermediate_cache.items(): + intermediate.calculate_metrics() + print( + f"Metrics for {intermediate.node.name}: pcc {intermediate.pcc}\tatol {intermediate.atol}" + ) + + if not intermediate.passed_atol: + print( + f"atol too high for {intermediate.node.name}: {intermediate.atol}" + ) + if not intermediate.passed_pcc: + print(f"pcc too low for {intermediate.node.name}: {intermediate.pcc}") + def __call__(self, *inputs): new_inputs = () for input in inputs: @@ -553,26 +625,86 @@ def __call__(self, *inputs): inputs = new_inputs + if self.compiler_config._enable_intermediate_verification: + # prepopulate golden map + self.run_gm_op_by_op( + *(inputs + self.graph_constants), cache_intermediate_goldens=True + ) + if self.compiler_config.compile_depth == CompileDepth.EXECUTE: assert self.binary is not None, 
"Binary must be set for EXECUTE mode" - return tt_mlir.run(inputs + self.graph_constants, self.binary) + ret = tt_mlir.run(inputs + self.graph_constants, self.binary) elif self.compiler_config.compile_depth in ( CompileDepth.EXECUTE_OP_BY_OP, CompileDepth.COMPILE_OP_BY_OP, ): - return self.run_gm_op_by_op(*(inputs + self.graph_constants)) + ret = self.run_gm_op_by_op(*(inputs + self.graph_constants)) else: - return self.gm(*inputs) + ret = self.gm(*inputs) + + if self.compiler_config._enable_intermediate_verification: + # run post-execution intermediate verification + self.verify_intermediates_after_execution() + + return ret + + +def create_verify_golden_callback(executor: Executor): + # Closure to capture external state in the callback. + + def verify_golden_callback(binary, callback_context, op_context): + # Using these parameters, we should be able to query information + # about the op described by op_context, and its output. I.e. location: + location = tt_mlir.get_op_loc_info( + op_context + ) # Do we care about other context? + output_intermediate_tensor: Tensor = None # Grab runtime tensor and inject here + + # format 'loc("")' + print(f"location = {location}") + + if location == "loc(unknown)": + return + location = location.split('"')[1] -def verify_golden_callback(binary, callback_context, op_context): - # Using these parameters, we should be able to query information - # about the op described by op_context, and its output. I.e. location: - location = tt_mlir.get_op_loc_info(op_context) - # ... + print("torchfx node UID raw location =", location) - # We will need to provide the bindings necesarry in this frontend. 
- # Those bindings will interact with the runtime API + intermediate_data: RuntimeIntermediate = ( + executor.runtime_intermediate_cache.get(location, None) + ) + + if intermediate_data is not None: + print(f"Found golden for op @ {intermediate_data.node.name} == {location}") + + # TESTING ONLY - Add the golden tensor as a fake output, to check that + # verification can get PCC = 1 for all cases. + intermediate_data.decomposed_intermediate_outputs.append( + output_intermediate_tensor + ) # actual output + intermediate_data.decomposed_intermediate_outputs.append( + intermediate_data.golden + ) # fake output + + # if intermediate_data.golden != None: + print( + f"Decomposition added fake tensor too. Total ct {len(intermediate_data.decomposed_intermediate_outputs)}" + ) + + # pdb.set_trace() + + # atol = calculate_atol(calculated, golden) + # if atol > executor.required_atol: + # print(f"atol too high for {location}: {atol}") + # pcc = calculate_pcc(calculated, golden) + # if pcc < executor.required_pcc: + # print(f"pcc too low for {location}: {pcc}") + + # pdb.set_trace() + # We will need to provide the bindings necessary in this frontend.
+ # Those bindings will interact with the runtime API + + return verify_golden_callback def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config): @@ -595,6 +727,8 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config): dump_debug = dump_intermediates == "DEBUG" dump_info = dump_debug or dump_intermediates == "INFO" + gm.graph.print_tabular() # run in full mode to print the entire graph (names for each op) + module = import_graph(gm.graph) verify_ir(module) @@ -626,7 +760,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config): print(ttir, file=sys.stderr) if compiler_config.enable_intermediate_verification: - executor.register_intermediate_callback(verify_golden_callback) + executor.register_intermediate_callback(create_verify_golden_callback(executor)) binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir) if dump_info: diff --git a/tt_torch/tools/verify.py b/tt_torch/tools/verify.py index 0fa225d4..f8a67f9b 100644 --- a/tt_torch/tools/verify.py +++ b/tt_torch/tools/verify.py @@ -7,7 +7,6 @@ import numpy as np import tt_mlir from tt_torch.onnx_compile import compile_onnx -from tt_torch.dynamo.backend import backend from tt_torch.tools.utils import calculate_atol, calculate_pcc @@ -118,6 +117,8 @@ def _verify_torch_module( compiler_config, do_assert, ): + from tt_torch.dynamo.backend import backend # avoid circular import + if input_data_types is None: input_data_types = [torch.float32] * ( len(input_shapes) if input_shapes is not None else len(inputs)