Skip to content

Commit

Permalink
Logging atol and rtol in case AllCloseValueCheck
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrkicTT authored and kmilanovicTT committed Feb 3, 2025
1 parent 25fa915 commit abdcb6b
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions forge/test/operators/pytorch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,21 @@ def pytest_runtest_makereport(item: _pytest.python.Function, call: _pytest.runne

if report.when == "call" or (report.when == "setup" and report.skipped):
try:
log_test_vector_properties(item, report, xfail_reason)
log_test_vector_properties(
item=item,
report=report,
xfail_reason=xfail_reason,
exception=call.excinfo.value if call.excinfo is not None else None,
)
except Exception as e:
logger.error(f"Failed to log test vector properties: {e}")
logger.exception(e)
pass


def log_test_vector_properties(item: _pytest.python.Function, report: _pytest.reports.TestReport, xfail_reason: str):
def log_test_vector_properties(
item: _pytest.python.Function, report: _pytest.reports.TestReport, xfail_reason: str, exception: Exception
):
original_name = item.originalname
test_id = item.name
test_id = test_id.replace(f"{original_name}[", "")
Expand All @@ -91,3 +98,21 @@ def log_test_vector_properties(item: _pytest.python.Function, report: _pytest.re
if xfail_reason is not None:
item.user_properties.append(("xfail_reason", xfail_reason))
item.user_properties.append(("outcome", report.outcome))

if exception is not None:
error_message = f"{exception}"

if "Observed maximum relative diff" in error_message:
error_message_lines = error_message.split("\n")
observed_error_lines = [line for line in error_message_lines if "Observed maximum relative diff" in line]
if observed_error_lines:
observed_error_line = observed_error_lines[0]
# Example: "- Observed maximum relative diff: 0.0008770461427047849, maximum absolute diff: 0.0009063482284545898"
rtol = float(observed_error_line.split(",")[0].split(":")[1].strip())
atol = float(observed_error_line.split(",")[1].split(":")[1].strip())
else:
logger.error(f"Error parsing 'Observed maximum relative diff' from the exception: {error_message}")
rtol = None
atol = None
item.user_properties.append(("all_close_rtol", rtol))
item.user_properties.append(("all_close_atol", atol))

0 comments on commit abdcb6b

Please sign in to comment.