Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Incorporate ttir_builder generated TTNNs as Load/Execute Tests in Explorer #2123

Open
wants to merge 51 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
fd0dc12
Added ModuleCacher to cache modules using Action hooks.
vprajapati-tt Jan 15, 2025
435d869
Moved to Conversion/Passes.h
vprajapati-tt Jan 15, 2025
25a0bfd
Clone Op
vprajapati-tt Jan 16, 2025
b119b72
Switched to strings instead of Op Clones
vprajapati-tt Jan 16, 2025
ce1294d
Moved from map to vector
vprajapati-tt Jan 16, 2025
693e5a9
Added ModuleCache to FB Schema
vprajapati-tt Jan 17, 2025
2ca9747
Debugging ModuleCacher
vprajapati-tt Jan 17, 2025
b71cb5f
Stupid bug fix
vprajapati-tt Jan 17, 2025
df7062a
Interim
vprajapati-tt Jan 21, 2025
11ebc9d
Suggested Changes + Starting Explorer Integration
vprajapati-tt Jan 22, 2025
57e9477
Added a thing
vprajapati-tt Jan 23, 2025
d6ed7ba
small fix
vprajapati-tt Jan 24, 2025
0d650ea
TestString in Schema
vprajapati-tt Jan 24, 2025
49b90af
TestString in Schema + CPP
vprajapati-tt Jan 24, 2025
fd7a5a1
Reverted to old schema
vprajapati-tt Jan 27, 2025
a3c0fa7
Rename
vprajapati-tt Jan 27, 2025
2dbefd0
Reorder and Pray
vprajapati-tt Jan 27, 2025
5a28cc5
Added MLIR Stages support in TTIR Builder
vprajapati-tt Jan 28, 2025
2566bef
Fixed execution so Golden Results correctly propagate after run
vprajapati-tt Jan 29, 2025
98548fc
Merge branch 'main' into vprajapati/issue-1234
vprajapati-tt Jan 29, 2025
12c4c74
Removed redundant print
vprajapati-tt Jan 29, 2025
8b29e8d
Removed redundant loc formatting
vprajapati-tt Jan 29, 2025
784df87
Removed Golden Opaque types
vprajapati-tt Jan 30, 2025
b21279a
Separated Logic between FB and MLIR loading
vprajapati-tt Jan 30, 2025
3723de5
Removed PassedModuleCache
vprajapati-tt Jan 30, 2025
95c0f5c
Fixed TypeID discrepancy - Thanks Nick
vprajapati-tt Jan 31, 2025
2c19924
Merge branch 'main' into vprajapati/issue-1234
vprajapati-tt Feb 4, 2025
c42ff18
Removed redundant debug prints & assert
vprajapati-tt Feb 4, 2025
c36a949
Merge branch 'main' into vprajapati/issue-1234
vprajapati-tt Feb 5, 2025
5acfac2
Added Variable Path in test_ttir_ops
vprajapati-tt Feb 5, 2025
4934690
Merge branch 'main' into vprajapati/issue-2079
vprajapati-tt Feb 5, 2025
d86d702
Added TTNN Tests + Env Var for Explorer CI
vprajapati-tt Feb 5, 2025
9ac2c06
Merge branch 'vprajapati/issue-1234' into vprajapati/issue-2079
vprajapati-tt Feb 5, 2025
ac0edcc
Added TTNN Tests into Explorer, fixed small bugs
vprajapati-tt Feb 5, 2025
e7e9e95
Interim
vprajapati-tt Feb 7, 2025
e2ee3b4
Merge branch 'main' into vprajapati/issue-2079
vprajapati-tt Feb 7, 2025
f5fb306
Added Accuracy Overlay + Some Tests
vprajapati-tt Feb 14, 2025
bb8533e
Merge branch 'main' into vprajapati/issue-2079
vprajapati-tt Feb 14, 2025
83ad2da
Clean
vprajapati-tt Feb 14, 2025
5fc7f83
Merge branch 'main' into vprajapati/issue-2079
vprajapati-tt Feb 14, 2025
afef378
Style Changes + Added Accuracy Test
vprajapati-tt Feb 21, 2025
0351e93
Merge branch 'main' into vprajapati/issue-2079
vprajapati-tt Feb 21, 2025
81053f6
Merge branch 'main' into vprajapati/issue-2079
vprajapati-tt Feb 24, 2025
1b883fb
Working Changes to Tests
vprajapati-tt Feb 24, 2025
d34cc60
Finally got this working, please don't break in CI
vprajapati-tt Feb 24, 2025
ea13fa6
It didn't work.
vprajapati-tt Feb 24, 2025
b84f042
Removed Squeeze/Unsqueeze tests while bug is figured out
vprajapati-tt Feb 25, 2025
967b166
I give up, CI debug printing
vprajapati-tt Feb 26, 2025
bd23b89
Merge branch 'main' into vprajapati/issue-2079
vprajapati-tt Feb 26, 2025
3009b57
Made test more verbose
vprajapati-tt Feb 26, 2025
3a3c9df
small changes + removed llama test
vprajapati-tt Feb 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -730,8 +730,9 @@ jobs:
shell: bash
run: |
source env/activate
export TT_EXPLORER_GENERATED_TEST_DIR=${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN
pytest tools/explorer/test/run_tests.py
export TT_EXPLORER_GENERATED_MLIR_TEST_DIRS=${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN/n150/perf,${{ steps.strings.outputs.build-output-dir }}/test/python/golden/ttnn
export TT_EXPLORER_GENERATED_TTNN_TEST_DIRS=${{ steps.strings.outputs.build-output-dir }}/test/python/golden/ttnn
pytest -svv tools/explorer/test/run_tests.py
# collect results


Expand Down
27 changes: 27 additions & 0 deletions python/test_infra/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

TT_MLIR_HOME = os.environ.get("TT_MLIR_HOME", "")

# Default output to the current directory from where this module is being invoked
OUTPUT_PATH = ""

# ----- Static helpers used in this file only -----

Expand All @@ -32,6 +34,25 @@ def _dump_module(module: Module) -> None:


# ----- General Purpose Helpers - Could Be Used In Other Files -----
def set_output_path(path):
global OUTPUT_PATH
if not os.path.exists(path):
raise ValueError(f"The provided path '{path}' is not a valid path.")
OUTPUT_PATH = path


def get_ttnn_path(filename):
ttnn_dir = os.path.join(OUTPUT_PATH, "ttnn")
if not os.path.exists(ttnn_dir):
os.makedirs(ttnn_dir)
return os.path.join(ttnn_dir, filename)


def get_ttmetal_path(filename):
ttmetal_dir = os.path.join(OUTPUT_PATH, "ttmetal")
if not os.path.exists(ttmetal_dir):
os.makedirs(ttmetal_dir)
return os.path.join(ttmetal_dir, filename)


def compile_as_mlir_module(
Expand Down Expand Up @@ -179,6 +200,7 @@ def ttir_to_ttnn(

# Optionally dump to file.
if dump_to_file:
output_file_name = get_ttnn_path(output_file_name)
with open(output_file_name, "w") as f:
f.write(str(module))

Expand Down Expand Up @@ -224,6 +246,7 @@ def ttir_to_ttmetal(

# Optionally dump to file.
if dump_to_file:
output_file_name = get_ttmetal_path(output_file_name)
with open(output_file_name, "w") as f:
f.write(str(module))

Expand All @@ -239,6 +262,8 @@ def ttnn_to_flatbuffer(
"""

# Convert to flatbuffer file.
# Take the output_file_name and prefix with the ttnn directory
output_file_name = get_ttnn_path(output_file_name)
if module_log:
ttnn_to_flatbuffer_file(
module, output_file_name, builder.get_golden_map(), module_log
Expand All @@ -260,6 +285,8 @@ def ttmetal_to_flatbuffer(
"""

# Convert to flatbuffer file.
# Take the output_file_name and prefix with ttm directory
output_file_name = get_ttmetal_path(output_file_name)
ttmetal_to_flatbuffer_file(module, output_file_name, builder.get_golden_map())

print("`ttmetal_to_flatbuffer_file` passed successfully.")
Expand Down
2 changes: 1 addition & 1 deletion runtime/tools/python/ttrt/common/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def save_memory_report(self, memory_report_path):
def check_pcc(self):
for loc, golden_data in self.golden_report.items():
if golden_data["actual_pcc"] < golden_data["expected_pcc"]:
raise Exception(
raise PCCErrorException(
f"Failed: golden comparison failed, actual_pcc={golden_data['actual_pcc']} < expected_pcc={golden_data['expected_pcc']}"
)

Expand Down
9 changes: 7 additions & 2 deletions runtime/tools/python/ttrt/common/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,12 +528,17 @@ def signal_handler(sig, frame):

for result in test_result:
if result["result"] != "pass":
if result["result"] == "test_error":
raise TTRTTestException(str(result["exception"]))
raise Exception(f'{result["exception"]}')

except Exception as e:
result = "error"
if isinstance(e, TTRTTestException):
result = "test_error"
test_result = {
"file_path": bin.file_path,
"result": "error",
"result": result,
"exception": str(e),
"log_file": self.logger.file_name,
"artifacts": self.artifacts.artifacts_folder_path,
Expand All @@ -543,7 +548,7 @@ def signal_handler(sig, frame):
f"ERROR: test={bin.file_path} experienced an error with exception={str(e)}"
)
self.results.add_result(test_result)
bin.test_result = "error"
bin.test_result = result
traceback.print_exc()
continue

Expand Down
8 changes: 6 additions & 2 deletions runtime/tools/python/ttrt/common/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,9 +691,13 @@ def convert_input_layouts(device, inputs, fbb, program_index):
callback_runtime_config.check_memory_leak()

except Exception as e:
result = "error"
if isinstance(e, TTRTTestException):
result = "test_error"

test_result = {
"file_path": bin.file_path,
"result": "error",
"result": result,
"exception": str(e),
"log_file": self.logger.file_name,
"artifacts": self.artifacts.artifacts_folder_path,
Expand All @@ -703,7 +707,7 @@ def convert_input_layouts(device, inputs, fbb, program_index):
f"ERROR: test={bin.file_path} experienced an error with exception={str(e)}"
)
self.results.add_result(test_result)
bin.test_result = "error"
bin.test_result = result
continue
finally:
ttrt.runtime.close_device(device)
Expand Down
26 changes: 24 additions & 2 deletions runtime/tools/python/ttrt/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,22 @@ def __init__(self, logger, file_manager, file_path):
self.test_result = "pass"


class TTRTTestException(Exception):
""" "Base class for all "Test Specific" Errors in TTRT"""

pass


class PCCErrorException(TTRTTestException):
"""Class to store PCC Comparison Errors"""

pass


# Define a constant TTRT_TEST_ERROR_RETURN_CODE
TTRT_TEST_EXCEPTION_RETURN_CODE = 42


class Results:
def __init__(self, logger, file_manager):
self.logger = logger
Expand Down Expand Up @@ -750,11 +766,17 @@ def save_results(self, file_name="results.json"):
tree.write(xml_file_path, encoding="utf-8", xml_declaration=True)

def get_result_code(self):
return_code = 0
for entry in self.results:
res = entry.get("result")
if entry.get("result") != "pass":
return 1
if res == "test_error":
return_code = TTRT_TEST_EXCEPTION_RETURN_CODE
else:
# Prioritize severity of return_code 1 if any non-test errors are encountered
return 1

return 0
return return_code

def get_results(self):
return self.results
17 changes: 16 additions & 1 deletion test/python/golden/test_ttir_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import inspect

from ttmlir.test_utils import compile_to_flatbuffer
from ttmlir.test_utils import compile_to_flatbuffer, set_output_path
from ttmlir.ttir_builder import Operand, TTIRBuilder


Expand Down Expand Up @@ -139,6 +139,21 @@ def test_llama_attention(


if __name__ == "__main__":
import argparse, os

parser = argparse.ArgumentParser(description="Run TTIR Builder Model tests")
parser.add_argument(
"--path",
type=str,
help="Optional output path for the flatbuffer. Creates path if supplied path doesn't exist",
)
args = parser.parse_args()

if args.path and os.path.exists(args.path):
if not os.path.exists(args.path):
os.makedirs(args.path)
set_output_path(args.path)

test_functions = inspect.getmembers(
inspect.getmodule(inspect.currentframe()), inspect.isfunction
)
Expand Down
38 changes: 33 additions & 5 deletions test/python/golden/test_ttir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,22 @@
import inspect
import torch

from ttmlir.test_utils import compile_to_flatbuffer
from ttmlir.test_utils import compile_to_flatbuffer, set_output_path
from ttmlir.ttir_builder import Operand, TTIRBuilder, Attribute


# NOTE: This test is not valid for TTRT Perf due to weird issues with perf collection
"""
@compile_to_flatbuffer([(1, 128, 128, 1)], targets=["ttnn"])
def test_squeeze(in0: Operand, builder: TTIRBuilder):
return builder.squeeze(in0, 0)
"""


# NOTE: Same as Squeeze, this Op is not valid for TTRT Perf.
"""
@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_unsqueeze(in0: Operand, builder: TTIRBuilder):
return builder.unsqueeze(in0, 0)
"""


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
Expand Down Expand Up @@ -53,9 +57,11 @@ def test_logical_not(in0: Operand, builder: TTIRBuilder):

# NOTE: The generated flatbuffer will currently fail to run due to only floats
# being supported by the runtime. See issue #1775 for tracking
"""
@compile_to_flatbuffer([(128, 128)], inputs_types=[torch.int8], targets=["ttnn"])
def test_bitwise_not(in0: Operand, builder: TTIRBuilder):
return builder.bitwise_not(in0)
"""


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
Expand Down Expand Up @@ -217,6 +223,8 @@ def test_logical_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):

# NOTE: The generated flatbuffer will currently fail to run due to only floats
# being supported by the runtime. See issue #1775 for tracking

"""
@compile_to_flatbuffer(
[
(64, 64),
Expand All @@ -227,10 +235,12 @@ def test_logical_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
)
def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.bitwise_and(in0, in1)

"""

# NOTE: The generated flatbuffer will currently fail to run due to only floats
# being supported by the runtime. See issue #1775 for tracking

"""
@compile_to_flatbuffer(
[
(64, 64),
Expand All @@ -241,10 +251,12 @@ def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
)
def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.bitwise_or(in0, in1)

"""

# NOTE: The generated flatbuffer will currently fail to run due to only floats
# being supported by the runtime. See issue #1775 for tracking

"""
@compile_to_flatbuffer(
[
(64, 64),
Expand All @@ -255,6 +267,7 @@ def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
)
def test_bitwise_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.bitwise_xor(in0, in1)
"""


@compile_to_flatbuffer(
Expand Down Expand Up @@ -450,6 +463,21 @@ def test_arbitrary_op_chain(


if __name__ == "__main__":
import argparse, os

parser = argparse.ArgumentParser(description="Run TTIR Builder Op tests")
parser.add_argument(
"--path",
type=str,
help="Optional output path for the flatbuffer. Creates path if supplied path doesn't exist",
)
args = parser.parse_args()

if args.path and os.path.exists(args.path):
if not os.path.exists(args.path):
os.makedirs(args.path)
set_output_path(args.path)

test_functions = inspect.getmembers(
inspect.getmodule(inspect.currentframe()), inspect.isfunction
)
Expand Down
Loading
Loading