Skip to content

Commit

Permalink
Improve test id parsing (#1085)
Browse files Browse the repository at this point in the history
Support valid negative numbers in test id
Fix math fidelity HiFi4 for duplicated tests
  • Loading branch information
vbrkicTT authored Jan 23, 2025
1 parent fbec928 commit 4383a7d
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions forge/test/operators/utils/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,8 +621,22 @@ def test_id_to_test_vector(cls, test_id: str) -> TestVector:

test_id = test_id.replace("no_device-", "")

# Split by '-' but not by ' -'
parts = re.split(r"(?<! )-", test_id)
# Split by '-' but not by ' -' and not by '(-'
# Explanation: Valid negative numbers can appear in kwargs or shapes or potentially other tuples
# * Space ' ' before '-' is for separating parameters
# Example: reshape-FROM_HOST-{'shape': (8, -1)}-(2, 2, 2, 2)-None-None
# * Open bracket '(' before '-' is for opening tuples
# Example: reshape-FROM_HOST-{'shape': (-1, 15)}-(3, 4, 5)-None-None

# Replace - with |
test_id = test_id.replace("-", "|")
# Replace ' |' with ' -' (revert previous replacement for valid negative numbers)
test_id = test_id.replace(" |", " -")
# Replace '(|' with '(-' (revert previous replacement for valid negative numbers)
test_id = test_id.replace("(|", "(-")

parts = test_id.split("|")

assert len(parts) == 6 or len(parts) == 7, f"Invalid test id: {test_id} / {parts}"
if len(parts) == 6:
dev_data_format_index = 4
Expand All @@ -644,11 +658,8 @@ def test_id_to_test_vector(cls, test_id: str) -> TestVector:
math_fidelity_part = parts[math_fidelity_index]
if math_fidelity_part == "None":
math_fidelity_part = None
# TODO remove hardcoded values here
if math_fidelity_part in (
"HiFi40",
"HiFi41",
):
# As last parameter in test id is math fidelity, in case of duplicated tests numeric suffix should be ignored
if math_fidelity_part is not None and math_fidelity_part.startswith("HiFi4"):
math_fidelity_part = "HiFi4"
math_fidelity = eval(f"forge._C.{math_fidelity_part}") if math_fidelity_part is not None else None

Expand Down

0 comments on commit 4383a7d

Please sign in to comment.