Jameszianxu/runtime intermediates #383

Draft · wants to merge 5 commits into base: main
3 changes: 3 additions & 0 deletions tt_torch/csrc/bindings.cpp
@@ -239,6 +239,9 @@ PYBIND11_MODULE(tt_mlir, m) {
.def_readonly("stride", &tt::runtime::TensorDesc::stride)
.def_readonly("itemsize", &tt::runtime::TensorDesc::itemsize)
.def_readonly("dataType", &tt::runtime::TensorDesc::dataType);

py::class_<tt::runtime::Tensor>(m, "Tensor");

m.def("get_op_output_tensor", &tt::runtime::getOpOutputTensor);
m.def("get_op_debug_str", &tt::runtime::getOpDebugString,
"Get the debug string of the op");
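
For context, a minimal sketch (not part of the diff) of how the newly exposed opaque Tensor class might be consumed from Python; the argument order of get_op_output_tensor is an assumption here, not something this PR confirms:

import tt_mlir

def example_callback(binary, callback_context, op_context):
    # With py::class_<tt::runtime::Tensor> registered, runtime APIs can hand
    # opaque tt_mlir.Tensor handles back to Python callbacks.
    out = tt_mlir.get_op_output_tensor(op_context, callback_context)  # assumed argument order
    print(type(out))  # expected to be the opaque tt_mlir.Tensor handle
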
160 changes: 147 additions & 13 deletions tt_torch/dynamo/backend.py
@@ -7,6 +7,7 @@
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
import operator
import pdb

from tt_torch.dynamo.passes import pass_pipeline
from tt_torch.tools.utils import (
@@ -18,6 +19,7 @@
calculate_atol,
calculate_pcc,
)
from tt_torch.tools.verify import verify_against_golden

import tt_mlir
from tt_mlir import is_runtime_debug_enabled
@@ -57,6 +59,50 @@ def get_node_location(self, node: torch.fx.Node) -> Optional[Location]:
return Location.name(node.name, context=self._c)


class RuntimeIntermediate:
def __init__(self, node: torch.fx.Node, golden):
self.node = node
self.golden = golden # may be a tuple of tensors

# Each FX node can be decomposed into multiple TTNN ops;
# we store all of their intermediate outputs here.
# TODO - Need a way to uniquely reference TTNN intermediates
self.decomposed_intermediate_outputs = []
self.pcc = None
self.atol = None
self.passed_pcc = False
self.passed_atol = False

def calculate_metrics(self):
# Calculate metrics against the golden tensor after all decomposition steps are done.

if (
len(self.decomposed_intermediate_outputs) == 0
and self.node.op == "call_function"
):
# Some call_function nodes (e.g. getitem) may have no decomposed intermediates.
return
assert (
len(self.decomposed_intermediate_outputs) > 0
), f"No decomposed intermediates found for {self.node.name}"

final_decomposed_output = self.decomposed_intermediate_outputs[
-1
] # could be a tuple of tensors

# verify_against_golden expects tuples of tensors as inputs, so wrap single tensors here.
if not isinstance(final_decomposed_output, tuple):
final_decomposed_output = (final_decomposed_output,)
if not isinstance(self.golden, tuple):
self.golden = (self.golden,)

(
self.passed_pcc,
self.passed_atol,
_,
_,
self.pcc,
self.atol,
) = verify_against_golden(self.golden, final_decomposed_output, 0.99, 1e-2)
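
A minimal usage sketch (not part of the diff) of the RuntimeIntermediate lifecycle defined above; the FX node and tensors are hypothetical placeholders:

import torch

entry = RuntimeIntermediate(node, golden=torch.ones(2, 2))      # node: a hypothetical torch.fx.Node
entry.decomposed_intermediate_outputs.append(torch.ones(2, 2))  # output of the last decomposed TTNN op
entry.calculate_metrics()                                       # fills pcc/atol and the passed_* flags
print(entry.pcc, entry.atol, entry.passed_pcc, entry.passed_atol)
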


def import_graph(graph: torch.fx.GraphModule):
context = Context()
torch_dialect.register_dialect(context)
@@ -157,6 +203,10 @@ def __init__(
self.stderror_redirected = False
self.file_stderr = None

# Dictionary tracking the intermediate golden value for each torch FX op,
# keyed as str(fx_node_name) -> RuntimeIntermediate
self.runtime_intermediate_cache = {}

def register_intermediate_callback(self, callback):
if not is_runtime_debug_enabled():
raise RuntimeError(
@@ -433,7 +483,7 @@ def run_op(self, binary, *inputs):

return outputs, stderr_data

def run_gm_op_by_op(self, *inputs):
def run_gm_op_by_op(self, *inputs, cache_intermediate_goldens=False):
node_to_tensor = {}
input_index = 0
outputs = []
@@ -472,6 +522,10 @@ def run_gm_op_by_op(self, *inputs):
binary = None
print(f"Failed to compile {idx}/{num_nodes}: {node.target}: {e}")

if cache_intermediate_goldens:
golden = node.target(*args, **node.kwargs)
cache_entry = RuntimeIntermediate(node, golden)
self.runtime_intermediate_cache[node.name] = cache_entry
if (
self.compiler_config.compile_depth == CompileDepth.EXECUTE_OP_BY_OP
and binary is not None
@@ -488,7 +542,9 @@ def run_gm_op_by_op(self, *inputs):
op.compilation_status = OpCompilationStatus.EXECUTED
tensor = node.target(*args, **node.kwargs)
if self.compiler_config.verify_op_by_op:
atol = calculate_atol(calculated, tensor)
atol = calculate_atol(
calculated, tensor
)  # TODO: calculated/tensor may be tuples and need to be unpacked before comparison.
op.atol = atol
if atol > self.required_atol:
print(f"atol too high for {idx}: {atol}")
@@ -504,6 +560,7 @@
else:
tensor = node.target(*args, **node.kwargs)
node_to_tensor[node] = tensor

elif node.op == "output":
args = node.args[0]
output_tensors = [node_to_tensor[arg] for arg in args]
@@ -529,6 +586,21 @@

return outputs

def verify_intermediates_after_execution(self):
# pdb.set_trace()
for _, intermediate in self.runtime_intermediate_cache.items():
intermediate.calculate_metrics()
print(
f"Metrics for {intermediate.node.name}: pcc {intermediate.pcc}\tatol {intermediate.atol}"
)

if not intermediate.passed_atol:
print(
f"atol too high for {intermediate.node.name}: {intermediate.atol}"
)
if not intermediate.passed_pcc:
print(f"pcc too low for {intermediate.node.name}: {intermediate.pcc}")

def __call__(self, *inputs):
new_inputs = ()
for input in inputs:
@@ -553,26 +625,86 @@ def __call__(self, *inputs):

inputs = new_inputs

if self.compiler_config._enable_intermediate_verification:
# prepopulate golden map
self.run_gm_op_by_op(
*(inputs + self.graph_constants), cache_intermediate_goldens=True
)

if self.compiler_config.compile_depth == CompileDepth.EXECUTE:
assert self.binary is not None, "Binary must be set for EXECUTE mode"
return tt_mlir.run(inputs + self.graph_constants, self.binary)
ret = tt_mlir.run(inputs + self.graph_constants, self.binary)
elif self.compiler_config.compile_depth in (
CompileDepth.EXECUTE_OP_BY_OP,
CompileDepth.COMPILE_OP_BY_OP,
):
return self.run_gm_op_by_op(*(inputs + self.graph_constants))
ret = self.run_gm_op_by_op(*(inputs + self.graph_constants))
else:
return self.gm(*inputs)
ret = self.gm(*inputs)

if self.compiler_config._enable_intermediate_verification:
# run post-execution intermediate verification
self.verify_intermediates_after_execution()

return ret


def create_verify_golden_callback(executor: Executor):
# Closure to capture external state in the callback.

def verify_golden_callback(binary, callback_context, op_context):
# Using these parameters, we should be able to query information
# about the op described by op_context and its output, e.g. its location:
location = tt_mlir.get_op_loc_info(
op_context
) # Do we care about other context?
output_intermediate_tensor: Tensor = None # Grab runtime tensor and inject here

# format 'loc("<torchfx node UID>")'
print(f"location = {location}")

if location == "loc(unknown)":
return

location = location.split('"')[1]

def verify_golden_callback(binary, callback_context, op_context):
# Using these parameters, we should be able to query information
# about the op described by op_context, and its output. I.e. location:
location = tt_mlir.get_op_loc_info(op_context)
# ...
print("torchfx node UID raw location =", location)

# We will need to provide the necessary bindings in this frontend.
# Those bindings will interact with the runtime API.
intermediate_data: RuntimeIntermediate = (
executor.runtime_intermediate_cache.get(location, None)
)

if intermediate_data is not None:
print(f"Found golden for op @ {intermediate_data.node.name} == {location}")

# TESTING ONLY - Add the golden tensor as a fake output, to check that
# verification can get PCC = 1 for all cases.
intermediate_data.decomposed_intermediate_outputs.append(
output_intermediate_tensor
) # actual output
intermediate_data.decomposed_intermediate_outputs.append(
intermediate_data.golden
) # fake output

# if intermediate_data.golden != None:
print(
f"Decomposition added fake tensor too. Total ct {len(intermediate_data.decomposed_intermediate_outputs)}"
)

# pdb.set_trace()

# atol = calculate_atol(calculated, golden)
# if atol > executor.required_atol:
# print(f"atol too high for {location}: {atol}")
# pcc = calculate_pcc(calculated, golden)
# if pcc < executor.required_pcc:
# print(f"pcc too low for {location}: {pcc}")

# pdb.set_trace()
# We will need to provide the necessary bindings in this frontend.
# Those bindings will interact with the runtime API.

return verify_golden_callback
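
For reference, a small sketch (not part of the diff) of the location parsing the callback relies on, using the 'loc("<torchfx node UID>")' format noted above; the node name "add_1" is only an illustration:

raw = 'loc("add_1")'              # as returned by tt_mlir.get_op_loc_info for a named op
if raw != "loc(unknown)":
    name = raw.split('"')[1]      # -> "add_1", the FX node name used as the cache key
    entry = executor.runtime_intermediate_cache.get(name, None)  # mirrors the lookup in the callback
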


def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
@@ -595,6 +727,8 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
dump_debug = dump_intermediates == "DEBUG"
dump_info = dump_debug or dump_intermediates == "INFO"

gm.graph.print_tabular()  # Print the entire FX graph (node names for each op) when running in full mode.

module = import_graph(gm.graph)
verify_ir(module)

@@ -626,7 +760,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
print(ttir, file=sys.stderr)

if compiler_config.enable_intermediate_verification:
executor.register_intermediate_callback(verify_golden_callback)
executor.register_intermediate_callback(create_verify_golden_callback(executor))

binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
if dump_info:
3 changes: 2 additions & 1 deletion tt_torch/tools/verify.py
@@ -7,7 +7,6 @@
import numpy as np
import tt_mlir
from tt_torch.onnx_compile import compile_onnx
from tt_torch.dynamo.backend import backend
from tt_torch.tools.utils import calculate_atol, calculate_pcc


@@ -118,6 +117,8 @@ def _verify_torch_module(
compiler_config,
do_assert,
):
from tt_torch.dynamo.backend import backend # avoid circular import

if input_data_types is None:
input_data_types = [torch.float32] * (
len(input_shapes) if input_shapes is not None else len(inputs)