diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index ba2d0759d6..d118c7990b 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -730,8 +730,9 @@ jobs: shell: bash run: | source env/activate - export TT_EXPLORER_GENERATED_TEST_DIR=${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN - pytest tools/explorer/test/run_tests.py + export TT_EXPLORER_GENERATED_MLIR_TEST_DIRS=${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN/n150/perf,${{ steps.strings.outputs.build-output-dir }}/test/python/golden/ttnn + export TT_EXPLORER_GENERATED_TTNN_TEST_DIRS=${{ steps.strings.outputs.build-output-dir }}/test/python/golden/ttnn + pytest -svv tools/explorer/test/run_tests.py # collect results diff --git a/python/test_infra/test_utils.py b/python/test_infra/test_utils.py index d0d97a7bf0..6f0bc349d6 100644 --- a/python/test_infra/test_utils.py +++ b/python/test_infra/test_utils.py @@ -22,6 +22,8 @@ TT_MLIR_HOME = os.environ.get("TT_MLIR_HOME", "") +# Default output to the current directory from where this module is being invoked +OUTPUT_PATH = "" # ----- Static helpers used in this file only ----- @@ -32,6 +34,25 @@ def _dump_module(module: Module) -> None: # ----- General Purpose Helpers - Could Be Used In Other Files ----- +def set_output_path(path): + global OUTPUT_PATH + if not os.path.exists(path): + raise ValueError(f"The provided path '{path}' is not a valid path.") + OUTPUT_PATH = path + + +def get_ttnn_path(filename): + ttnn_dir = os.path.join(OUTPUT_PATH, "ttnn") + if not os.path.exists(ttnn_dir): + os.makedirs(ttnn_dir) + return os.path.join(ttnn_dir, filename) + + +def get_ttmetal_path(filename): + ttmetal_dir = os.path.join(OUTPUT_PATH, "ttmetal") + if not os.path.exists(ttmetal_dir): + os.makedirs(ttmetal_dir) + return os.path.join(ttmetal_dir, filename) def compile_as_mlir_module( @@ -179,6 +200,7 @@ def ttir_to_ttnn( # Optionally dump to file. if dump_to_file: + output_file_name = get_ttnn_path(output_file_name) with open(output_file_name, "w") as f: f.write(str(module)) @@ -224,6 +246,7 @@ def ttir_to_ttmetal( # Optionally dump to file. if dump_to_file: + output_file_name = get_ttmetal_path(output_file_name) with open(output_file_name, "w") as f: f.write(str(module)) @@ -239,6 +262,8 @@ def ttnn_to_flatbuffer( """ # Convert to flatbuffer file. + # Take the output_file_name and prefix with the ttnn directory + output_file_name = get_ttnn_path(output_file_name) if module_log: ttnn_to_flatbuffer_file( module, output_file_name, builder.get_golden_map(), module_log @@ -260,6 +285,8 @@ def ttmetal_to_flatbuffer( """ # Convert to flatbuffer file. 
+ # Take the output_file_name and prefix with the ttmetal directory + output_file_name = get_ttmetal_path(output_file_name) ttmetal_to_flatbuffer_file(module, output_file_name, builder.get_golden_map()) print("`ttmetal_to_flatbuffer_file` passed successfully.") diff --git a/runtime/tools/python/ttrt/common/callback.py b/runtime/tools/python/ttrt/common/callback.py index 93a9af267b..627ad17823 100644 --- a/runtime/tools/python/ttrt/common/callback.py +++ b/runtime/tools/python/ttrt/common/callback.py @@ -65,7 +65,7 @@ def save_memory_report(self, memory_report_path): def check_pcc(self): for loc, golden_data in self.golden_report.items(): if golden_data["actual_pcc"] < golden_data["expected_pcc"]: - raise Exception( + raise PCCErrorException( f"Failed: golden comparison failed, actual_pcc={golden_data['actual_pcc']} < expected_pcc={golden_data['expected_pcc']}" ) diff --git a/runtime/tools/python/ttrt/common/perf.py b/runtime/tools/python/ttrt/common/perf.py index 18791a777b..fd347c2f12 100644 --- a/runtime/tools/python/ttrt/common/perf.py +++ b/runtime/tools/python/ttrt/common/perf.py @@ -528,12 +528,17 @@ def signal_handler(sig, frame): for result in test_result: if result["result"] != "pass": + if result["result"] == "test_error": + raise TTRTTestException(str(result["exception"])) raise Exception(f'{result["exception"]}') except Exception as e: + result = "error" + if isinstance(e, TTRTTestException): + result = "test_error" test_result = { "file_path": bin.file_path, - "result": "error", + "result": result, "exception": str(e), "log_file": self.logger.file_name, "artifacts": self.artifacts.artifacts_folder_path, @@ -543,7 +548,7 @@ def signal_handler(sig, frame): f"ERROR: test={bin.file_path} experienced an error with exception={str(e)}" ) self.results.add_result(test_result) - bin.test_result = "error" + bin.test_result = result traceback.print_exc() continue diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py index 84a06a9c4a..e830bdadfc 100644 --- a/runtime/tools/python/ttrt/common/run.py +++ b/runtime/tools/python/ttrt/common/run.py @@ -691,9 +691,13 @@ def convert_input_layouts(device, inputs, fbb, program_index): callback_runtime_config.check_memory_leak() except Exception as e: + result = "error" + if isinstance(e, TTRTTestException): + result = "test_error" + test_result = { "file_path": bin.file_path, - "result": "error", + "result": result, "exception": str(e), "log_file": self.logger.file_name, "artifacts": self.artifacts.artifacts_folder_path, @@ -703,7 +707,7 @@ def convert_input_layouts(device, inputs, fbb, program_index): f"ERROR: test={bin.file_path} experienced an error with exception={str(e)}" ) self.results.add_result(test_result) - bin.test_result = "error" + bin.test_result = result continue finally: ttrt.runtime.close_device(device) diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index 77c558a760..00e4499664 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -698,6 +698,22 @@ def __init__(self, logger, file_manager, file_path): self.test_result = "pass" +class TTRTTestException(Exception): + """Base class for all "Test Specific" errors in TTRT""" + + pass + + +class PCCErrorException(TTRTTestException): + """Raised when a golden PCC comparison fails""" + + pass + + +# Return code used to signal test-specific ("test_error") failures +TTRT_TEST_EXCEPTION_RETURN_CODE = 42 + + class Results: def __init__(self, logger, file_manager): self.logger = logger
@@ -750,11 +766,17 @@ def save_results(self, file_name="results.json"): tree.write(xml_file_path, encoding="utf-8", xml_declaration=True) def get_result_code(self): + return_code = 0 for entry in self.results: + res = entry.get("result") if entry.get("result") != "pass": - return 1 + if res == "test_error": + return_code = TTRT_TEST_EXCEPTION_RETURN_CODE + else: + # Non-test errors take priority, so return 1 immediately + return 1 - return 0 + return return_code def get_results(self): return self.results diff --git a/test/python/golden/test_ttir_models.py b/test/python/golden/test_ttir_models.py index 33f49b76f4..a516d39d1b 100644 --- a/test/python/golden/test_ttir_models.py +++ b/test/python/golden/test_ttir_models.py @@ -6,7 +6,7 @@ import inspect -from ttmlir.test_utils import compile_to_flatbuffer +from ttmlir.test_utils import compile_to_flatbuffer, set_output_path from ttmlir.ttir_builder import Operand, TTIRBuilder @@ -139,6 +139,21 @@ def test_llama_attention( if __name__ == "__main__": + import argparse, os + + parser = argparse.ArgumentParser(description="Run TTIR Builder Model tests") + parser.add_argument( + "--path", + type=str, + help="Optional output path for the flatbuffer. Creates path if supplied path doesn't exist", + ) + args = parser.parse_args() + + if args.path: + if not os.path.exists(args.path): + os.makedirs(args.path) + set_output_path(args.path) + test_functions = inspect.getmembers( inspect.getmodule(inspect.currentframe()), inspect.isfunction ) diff --git a/test/python/golden/test_ttir_ops.py b/test/python/golden/test_ttir_ops.py index 8ecb1434ec..8e7bffb5b1 100644 --- a/test/python/golden/test_ttir_ops.py +++ b/test/python/golden/test_ttir_ops.py @@ -7,18 +7,22 @@ import inspect import torch -from ttmlir.test_utils import compile_to_flatbuffer +from ttmlir.test_utils import compile_to_flatbuffer, set_output_path from ttmlir.ttir_builder import Operand, TTIRBuilder, Attribute - +# NOTE: This test is not valid for TTRT Perf due to weird issues with perf collection +""" @compile_to_flatbuffer([(1, 128, 128, 1)], targets=["ttnn"]) def test_squeeze(in0: Operand, builder: TTIRBuilder): return builder.squeeze(in0, 0) +""" - +# NOTE: Same as Squeeze, this Op is not valid for TTRT Perf. +""" @compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) def test_unsqueeze(in0: Operand, builder: TTIRBuilder): return builder.unsqueeze(in0, 0) +""" @compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) @@ -53,9 +57,11 @@ def test_logical_not(in0: Operand, builder: TTIRBuilder): # NOTE: The generated flatbuffer will currently fail to run due to only floats # being supported by the runtime. See issue #1775 for tracking +""" @compile_to_flatbuffer([(128, 128)], inputs_types=[torch.int8], targets=["ttnn"]) def test_bitwise_not(in0: Operand, builder: TTIRBuilder): return builder.bitwise_not(in0) +""" @compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) @@ -217,6 +223,8 @@ def test_logical_xor(in0: Operand, in1: Operand, builder: TTIRBuilder): # NOTE: The generated flatbuffer will currently fail to run due to only floats # being supported by the runtime.
See issue #1775 for tracking + +""" @compile_to_flatbuffer( [ (64, 64), @@ -227,10 +235,12 @@ def test_logical_xor(in0: Operand, in1: Operand, builder: TTIRBuilder): ) def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder): return builder.bitwise_and(in0, in1) - +""" # NOTE: The generated flatbuffer will currently fail to run due to only floats # being supported by the runtime. See issue #1775 for tracking + +""" @compile_to_flatbuffer( [ (64, 64), @@ -241,10 +251,12 @@ def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder): ) def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder): return builder.bitwise_or(in0, in1) - +""" # NOTE: The generated flatbuffer will currently fail to run due to only floats # being supported by the runtime. See issue #1775 for tracking + +""" @compile_to_flatbuffer( [ (64, 64), @@ -255,6 +267,7 @@ def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder): ) def test_bitwise_xor(in0: Operand, in1: Operand, builder: TTIRBuilder): return builder.bitwise_xor(in0, in1) +""" @compile_to_flatbuffer( @@ -450,6 +463,21 @@ def test_arbitrary_op_chain( if __name__ == "__main__": + import argparse, os + + parser = argparse.ArgumentParser(description="Run TTIR Builder Op tests") + parser.add_argument( + "--path", + type=str, + help="Optional output path for the flatbuffer. Creates path if supplied path doesn't exist", + ) + args = parser.parse_args() + + if args.path: + if not os.path.exists(args.path): + os.makedirs(args.path) + set_output_path(args.path) + test_functions = inspect.getmembers( inspect.getmodule(inspect.currentframe()), inspect.isfunction ) diff --git a/tools/explorer/test/run_tests.py b/tools/explorer/test/run_tests.py index 9fea9ee9cd..d834444eef 100644 --- a/tools/explorer/test/run_tests.py +++ b/tools/explorer/test/run_tests.py @@ -9,47 +9,64 @@ import pytest import glob import os +import logging HOST = "localhost" PORT = 8002 COMMAND_URL = "http://" + HOST + ":" + str(PORT) + "/apipost/v1/send_command" TEST_LOAD_MODEL_PATHS = [ - "test/ttmlir/Dialect/TTNN/optimizer/mnist_sharding.mlir", "test/ttmlir/Explorer/**/*.mlir", - "test/ttmlir/Silicon/TTNN/**/*.mlir", + "test/ttmlir/Silicon/TTNN/n150/perf/**/*.mlir", ] MNIST_SHARDING_PATH = "test/ttmlir/Silicon/TTNN/n150/optimizer/mnist_sharding.mlir" TEST_EXECUTE_MODEL_PATHS = [ MNIST_SHARDING_PATH, ] -if "TT_EXPLORER_GENERATED_TEST_DIR" in os.environ: - TEST_LOAD_MODEL_PATHS.append( - os.environ["TT_EXPLORER_GENERATED_TEST_DIR"] + "/**/*.mlir" - ) +if "TT_EXPLORER_GENERATED_MLIR_TEST_DIRS" in os.environ: + for path in os.environ["TT_EXPLORER_GENERATED_MLIR_TEST_DIRS"].split(","): + if os.path.exists(path): + TEST_LOAD_MODEL_PATHS.append(path + "/**/*.mlir") + else: + logging.error( + "Path %s provided in TT_EXPLORER_GENERATED_MLIR_TEST_DIRS doesn't exist. Tests not added.", + path, + ) + +if "TT_EXPLORER_GENERATED_TTNN_TEST_DIRS" in os.environ: + for path in os.environ["TT_EXPLORER_GENERATED_TTNN_TEST_DIRS"].split(","): + if os.path.exists(path): + TEST_LOAD_MODEL_PATHS.append(path + "/**/*.ttnn") + TEST_EXECUTE_MODEL_PATHS.append(path + "/**/*.ttnn") + else: + logging.error( + "Path %s provided in TT_EXPLORER_GENERATED_TTNN_TEST_DIRS doesn't exist. Tests not added.", + path, + ) + +FILTERED_TESTS = [ + # This test is way too large to fit reasonably in CI.
+ "test_llama_attention.ttnn", +] def get_test_files(paths): files = [] for path in paths: files.extend(glob.glob(path, recursive=True)) - return files + files = [ + file for file in files if all(not file.endswith(x) for x in FILTERED_TESTS) + ] -def execute_command(model_path, settings): - cmd = { - "extensionId": "tt_adapter", - "cmdId": "execute", - "modelPath": model_path, - "deleteAfterConversion": False, - "settings": settings, - } + return files - result = requests.post(COMMAND_URL, json=cmd) - assert result.ok - if "error" in result.json(): - print(result.json()) - assert False + +def GET_TTNN_TEST(): + for test in get_test_files(TEST_EXECUTE_MODEL_PATHS): + if test.endswith("test_mnist.ttnn"): + return test + return None @pytest.fixture(scope="function", autouse=True) @@ -88,13 +105,6 @@ def server_shutdown(): request.addfinalizer(server_shutdown) -def get_test_files(paths): - files = [] - for path in paths: - files.extend(glob.glob(path)) - return files - - def send_command(command, model_path, settings={}): cmd = { "extensionId": "tt_adapter", @@ -157,7 +167,7 @@ def test_load_model(model_path): @pytest.mark.parametrize("model_path", get_test_files(TEST_EXECUTE_MODEL_PATHS)) def test_execute_model(model_path): execute_command_and_wait( - model_path, {"optimizationPolicy": "DF Sharding"}, timeout=300 + model_path, {"optimizationPolicy": "Optimizer Disabled"}, timeout=300 ) convert_command_and_assert(model_path) @@ -171,10 +181,10 @@ def test_execute_mnist_l1_interleaved(): convert_command_and_assert(MNIST_SHARDING_PATH) -def test_execute_mnist_optimizer_disabled(): +def test_execute_mnist_df_sharding(): execute_command_and_wait( MNIST_SHARDING_PATH, - {"optimizationPolicy": "Optimizer Disabled"}, + {"optimizationPolicy": "DF Sharding"}, timeout=300, ) convert_command_and_assert(MNIST_SHARDING_PATH) @@ -208,7 +218,7 @@ def test_execute_and_check_perf_data_exists(): timeout=300, ) result = convert_command_and_assert(MNIST_SHARDING_PATH) - assert "perf_data" in result["graphs"][0] + assert "perf_data" in result["graphs"][0]["overlays"] def test_execute_model_invalid_policy(): @@ -218,3 +228,20 @@ def test_execute_model_invalid_policy(): {"optimizationPolicy": "Invalid Policy"}, timeout=300, ) + + +def test_execute_and_check_accuracy_data_exists(): + # Get the test_mnist path + test_mnist_path = GET_TTNN_TEST() + + # Key Decision: Make Test Fail or just provide error message and skip? + assert ( + test_mnist_path is not None + ), "Couldn't find test_mnist.ttnn in GENERATED_TTNN_TEST_DIRS" + execute_command_and_wait( + test_mnist_path, {"optimizationPolicy": "Optimizer Disabled"}, timeout=300 + ) + result = convert_command_and_assert(test_mnist_path) + if "accuracy_data" not in result["graphs"][0]["overlays"]: + print(result) + assert "accuracy_data" in str(result) diff --git a/tools/explorer/tt_adapter/src/tt_adapter/main.py b/tools/explorer/tt_adapter/src/tt_adapter/main.py index 3c5371ebbf..50128ff5be 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/main.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/main.py @@ -95,15 +95,16 @@ def convert( logging.info(f"Using optimized model: {optimized_model_path}") # Get performance results. 
perf_trace = self.model_runner.get_perf_trace(model_path) + golden_results = self.model_runner.get_golden_results(model_path) with open(optimized_model_path, "r") as model_file: module = utils.parse_mlir_str(model_file.read()) # Convert TTIR to Model Explorer Graphs and Display/Return - graph, perf_data = mlir.build_graph(module, perf_trace) - if perf_data: - # TODO(odjuricic) We should replace the perf_data with overlays once this is fixed on FE. - graph = utils.add_to_dataclass(graph, "perf_data", perf_data.graphsData) + graph, overlays = mlir.build_graph(module, perf_trace, golden_results) + + if overlays: + graph = utils.add_to_dataclass(graph, "overlays", overlays) if overrides := self.model_runner.get_overrides(model_path): graph = utils.add_to_dataclass(graph, "overrides", overrides) diff --git a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py index 7050ab94a4..3396f61768 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py @@ -305,14 +305,16 @@ def parse_memory_config(attr): value="x".join(map(str, memory_config.shard_spec.shard_shape.shape)), ) ) + result.append( graph_builder.KeyValue( key="tensor-memory-layout", value=str( - ttnn.TensorMemoryLayout(memory_config.tensor_memory_layout.value) + ttnn.TensorMemoryLayout(int(memory_config.tensor_memory_layout.value)) ), ) ) + return result @@ -580,7 +582,7 @@ def make_constant_node(self, constant_name): ] -def build_graph(module, perf_trace=None): +def build_graph(module, perf_trace=None, golden_results=None): output_connections = defaultdict(int) graph = graph_builder.Graph(id="tt-graph") @@ -598,6 +600,17 @@ def build_graph(module, perf_trace=None): if loc: loc_to_perf[loc] = row["DEVICE FW DURATION [ns]"] + # Parse Golden Results for Overlay + accuracy_node_data = {} + loc_to_accuracy = {} + if golden_results is not None: + for loc, res in golden_results.items(): + loc = parse_loc_string(loc) + assert loc not in loc_to_accuracy + if loc: + # Store the full result, keyed by the parsed loc + loc_to_accuracy[loc] = res + module_op = OpHandler(module.operation) module_attrs = module_op.get_attributes() module_attrs = dict((attr.key, attr.value) for attr in module_attrs) @@ -622,6 +635,17 @@ def build_graph(module, perf_trace=None): loc_to_perf[operation.named_location] ) + if ( + operation.named_location in loc_to_accuracy + and operation.op.name not in EMPTY_OPS + ): + accuracy_node_data[ + operation.id + ] = node_data_builder.NodeDataResult( + loc_to_accuracy[operation.named_location]["actual_pcc"] + - loc_to_accuracy[operation.named_location]["expected_pcc"] + ) + if op.name not in FILTERED_OPS and op.name in EMPTY_OPS: append_later.append(graph_node) elif op.name not in FILTERED_OPS: @@ -708,8 +732,8 @@ def build_graph(module, perf_trace=None): ) output_connections[source_node.id] += 1 + overlays = {} # Add performance data to the graph color overlay, if it exists - overlay_data = None if perf_node_data: gradient = [ node_data_builder.GradientItem(stop=0, bgColor="yellow"), @@ -718,10 +742,24 @@ def build_graph(module, perf_trace=None): graph_node_data = node_data_builder.GraphNodeData( results=perf_node_data, gradient=gradient ) - overlay_data = node_data_builder.ModelNodeData( + overlays["perf_data"] = node_data_builder.ModelNodeData( graphsData={"tt-graph": graph_node_data} + ).graphsData + + if accuracy_node_data: + thres = [ + # Show Red if ActualPCC - ExpectedPCC is 0 and below (ActualPCC <=
ExpectedPCC) + node_data_builder.ThresholdItem(value=0, bgColor="red"), + # Show Green if ActualPCC - ExpectedPCC is above 0 (ActualPCC > ExpectedPCC) + node_data_builder.ThresholdItem(value=1, bgColor="green"), + ] + graph_node_data = node_data_builder.GraphNodeData( + results=accuracy_node_data, thresholds=thres ) + overlays["accuracy_data"] = node_data_builder.ModelNodeData( + graphsData={"tt-graph": graph_node_data} + ).graphsData graph.groupNodeAttributes = group_node_attrs OpHandler.schedule = 0 - return graph, overlay_data + return graph, overlays diff --git a/tools/explorer/tt_adapter/src/tt_adapter/runner.py b/tools/explorer/tt_adapter/src/tt_adapter/runner.py index e388c820ed..767e1b4b79 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/runner.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/runner.py @@ -13,6 +13,7 @@ import pandas as pd import threading import queue +import json class ExplorerRunException(Exception): @@ -140,6 +141,17 @@ def get_perf_trace(self, model_path): return pd.read_csv(op_perf_file) + def get_golden_results(self, model_path): + accuracy_res = f"{self.model_state[model_path].model_output_dir}/run/program_0/golden_results.json" + + if not os.path.exists(accuracy_res): + raise FileNotFoundError(f"Golden results not found @ {accuracy_res}") + + with open(accuracy_res, "r") as f: + res = json.load(f) + + return res + def run_in_subprocess(self, command): self.log(f"Running command:\n{' '.join(command)}\n") @@ -304,9 +316,16 @@ def compile_and_run(self, model_path, overrides_string): ttrt_process = self.run_in_subprocess(ttrt_perf_command) if ttrt_process.returncode != 0: - error = "Error while running TTRT perf" - self.log(error, severity=logging.error) - raise ExplorerRunException(error) + # 42 (TTRT_TEST_EXCEPTION_RETURN_CODE) signals a test-specific error, e.g. a PCC check failure, rather than a TTRT failure + if ttrt_process.returncode == 42: + error = ( + "Error while running TTRT Tests... Continuing Explorer Execution" + ) + self.log(error, severity=logging.error) + else: + error = "Error while running TTRT perf" + self.log(error, severity=logging.error) + raise ExplorerRunException(error) perf = self.get_perf_trace(model_path) columns = [
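
Usage note (outside the patch itself): a minimal sketch of how the new output-path plumbing in python/test_infra/test_utils.py is expected to be driven. Only set_output_path and get_ttnn_path come from this change; the directory name below and the importability of get_ttnn_path from ttmlir.test_utils are illustrative assumptions.

import os
from ttmlir.test_utils import set_output_path, get_ttnn_path  # get_ttnn_path assumed exported alongside set_output_path

out_dir = "build/test/python/golden"  # hypothetical example location, not mandated by the patch
os.makedirs(out_dir, exist_ok=True)   # set_output_path raises ValueError for a non-existent path
set_output_path(out_dir)

# ttir_to_ttnn(..., dump_to_file=True) and ttnn_to_flatbuffer(...) now resolve their output
# names through get_ttnn_path, i.e. under <out_dir>/ttnn/ (ttmetal outputs go to <out_dir>/ttmetal/):
print(get_ttnn_path("test_mnist.ttnn"))  # -> build/test/python/golden/ttnn/test_mnist.ttnn

The resulting ttnn/ directory is what the CI step above exports via TT_EXPLORER_GENERATED_TTNN_TEST_DIRS, so run_tests.py can glob the generated *.ttnn flatbuffers.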
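
For review convenience, a stand-alone restatement (not the actual Results method) of the exit-code precedence that get_result_code now implements: any non-test error forces exit code 1, a run whose only failures are test-specific ("test_error", e.g. a PCC mismatch) exits with TTRT_TEST_EXCEPTION_RETURN_CODE, and an all-pass run exits with 0.

TTRT_TEST_EXCEPTION_RETURN_CODE = 42  # value defined in ttrt/common/util.py by this patch

def result_code(results):
    # Mirror of Results.get_result_code, for illustration only.
    return_code = 0
    for entry in results:
        res = entry.get("result")
        if res != "pass":
            if res == "test_error":
                return_code = TTRT_TEST_EXCEPTION_RETURN_CODE
            else:
                return 1  # non-test errors take priority and short-circuit
    return return_code

assert result_code([{"result": "pass"}]) == 0
assert result_code([{"result": "pass"}, {"result": "test_error"}]) == 42
assert result_code([{"result": "test_error"}, {"result": "error"}]) == 1

The explorer runner keys off exactly this contract: an exit code of 42 from ttrt perf is logged and execution continues, while any other non-zero code still raises ExplorerRunException.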
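
Finally, a hedged sketch of the data feeding the new accuracy overlay. The exact schema of golden_results.json is not shown in this diff; the structure below is assumed from how check_pcc reads the golden report and how build_graph consumes get_golden_results() (per-loc dicts carrying at least "expected_pcc" and "actual_pcc").

# Assumed shape of run/program_0/golden_results.json; the loc key format is
# illustrative and must be something parse_loc_string can handle.
golden_results = {
    'loc("matmul_1")': {"expected_pcc": 0.99, "actual_pcc": 0.9971},
}

# build_graph stores actual_pcc - expected_pcc per node and colors it with the
# thresholds above: a difference <= 0 renders red, anything above 0 renders green.
for loc, res in golden_results.items():
    print(loc, res["actual_pcc"] - res["expected_pcc"])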