Skip to content

Commit

Permalink
Update jax triage / testing environment for greater completeness
Browse files Browse the repository at this point in the history
This includes a triage_jaxtest.py program that prints the total failures,
passing, and known failures (with test counts). This is so that known
failures can highlight which tests they are blocking.
  • Loading branch information
rsuderman committed May 9, 2023
1 parent defbb95 commit fad797b
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 10 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/run_jaxtests_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,22 @@ jobs:
export PJRT_NAMES_AND_LIBRARY_PATHS="iree_cpu:/home/runner/work/openxla-pjrt-plugin/openxla-pjrt-plugin/pjrt_plugins/pjrt_plugin_iree_cpu.so"
JAX_PLATFORMS=iree_cpu python test/test_simple.py
JAX_PLATFORMS=iree_cpu python test/test_jax.py /home/runner/work/openxla-pjrt-plugin/jax/tests/nn_test.py \
JAX_PLATFORMS=iree_cpu python test/test_jax.py \
/home/runner/work/openxla-pjrt-plugin/jax/tests/nn_test.py \
--passing ${PASSING_ARTIFACT} \
--failing ${FAILING_ARTIFACT} \
--expected ${GOLDEN_ARTIFACT}
--expected ${GOLDEN_ARTIFACT} \
--logdir jax_testsuite
# If we passed we can update the golden.
if [ $? -eq 0 ]; then
cp ${PASSING_ARTIFACT} ${GOLDEN_ARTIFACT}
fi
- name: "Run JAX Testsuite Triage"
run: |
python test/triage_jaxtest.py --logdir jax_testsuite
- uses: actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce # v3.1.2
if: always()
with:
Expand Down
18 changes: 10 additions & 8 deletions test/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@


def get_tests(tests):
testlist = []
fulltestlist = []
for test in sorted(tests):
print("Fetching from:", test)
stdout = subprocess.run(PYTEST_CMD + ["--setup-only", test],
capture_output=True)
testlist += re.findall('::[^ ]*::[^ ]*', str(stdout))
testlist = [test + func for func in testlist]
return testlist
testlist = re.findall('::[^ ]*::[^ ]*', str(stdout))
fulltestlist += [test + func for func in testlist]
return fulltestlist


def generate_test_commands(tests, timeout=False):
Expand All @@ -49,10 +50,11 @@ def generate_test_commands(tests, timeout=False):


def exec_test(command):
result = subprocess.run(command,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL)
sys.stdout.write(".")
result = subprocess.run(command, capture_output=True)
if result.returncode == 0:
sys.stdout.write(".")
else:
sys.stdout.write("f")
sys.stdout.flush()
return result.returncode

Expand Down
130 changes: 130 additions & 0 deletions test/triage_jaxtest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import jaxlib.mlir.ir as mlir_ir
import jax._src.interpreters.mlir as mlir
import os
import re

parser = argparse.ArgumentParser(prog='triage_jaxtest.py',
description='Triage the jax tests')
parser.add_argument('-l', '--logdir', default="/tmp/jaxtest")
parser.add_argument('-d', '--delete', default=False)
args = parser.parse_args()

tests = set(os.listdir(args.logdir))


def filter_to_failures(tests):
failures = list()
for test in tests:
files = os.listdir(f"{args.logdir}/{test}")
if "error.txt" in files or "CRASH_MARKER" in files:
failures.append(test)
failures = sorted(failures)
return failures


def check_custom_call(errortxt, _):
return "mhlo.custom_call" in errortxt


def check_uint(errortxt, _):
return "uint" in errortxt


def check_degenerate_tensor(_, mlirbc):
return "tensor<0x" in mlirbc or "x0x" in mlirbc


def check_complex(errortxt, _):
return "complex<" in errortxt


def check_truncsfbf2(errortxt, _):
return "__truncsfbf2" in errortxt


def check_scatter_i1(errortxt, _):
return "'iree_linalg_ext.scatter' op mismatch in argument 0 of region 'i1' and element type of update value 'i8'" in errortxt


def check_dot_i1(_, mlirbc):
for line in mlirbc.split("\n"):
has_i1 = re.search("tensor<([0-9]*x)*i1>", line)
has_dot = re.search("stablehlo.dot", line)
if has_i1 and has_dot:
return True
return False


KnownChecks = {
"https://github.com/openxla/iree/issues/12410 (custom call)":
check_custom_call,
"https://github.com/openxla/iree/issues/12665 (unsigned) ":
check_uint,
"https://github.com/openxla/iree/issues/13347 (0-length) ":
check_degenerate_tensor,
"https://github.com/openxla/iree/issues/12747 (complex) ":
check_complex,
"https://github.com/openxla/iree/issues/13499 (truncsfbf2) ":
check_truncsfbf2,
"https://github.com/openxla/iree/issues/13427 (scatter i1) ":
check_scatter_i1,
"https://github.com/openxla/iree/issues/13493 (dot i1) ":
check_dot_i1,
"Untriaged":
lambda _, __: True,
}


def filter_error_mapping(tests):
error_mapping = {}
for test in tests:
files = sorted(os.listdir(f'{args.logdir}/{test}'))
# Load the error.txt if it is available.
error = ""
if "error.txt" in files:
with open(f'{args.logdir}/{test}/error.txt') as f:
error = "".join(f.readlines())

# Load the last bytecode file that was attempted to be compiled:
mlirbc_count = len(
[f for f in os.listdir(f'{args.logdir}/{test}') if "mlirbc" in f])
mlirbc_name = f'{mlirbc_count - 1}-program.mlirbc'

with mlir.make_ir_context() as ctx:
with open(f'{args.logdir}/{test}/{mlirbc_name}', 'rb') as f:
mlirbc = f.read()
mlirbc = str(mlir_ir.Module.parse(mlirbc))

error_mapping[test] = "unknown error"
for checkname in KnownChecks:
if KnownChecks[checkname](error, mlirbc):
error_mapping[test] = checkname
break
return error_mapping


def generate_summary(mapping):
summary = {}
for err in KnownChecks.keys():
summary[err] = []
for test in mapping:
summary[mapping[test]].append(test)
return summary


def print_summary(summary):
for error in summary:
print(f'{error} : {len(summary[error])}')


failing = filter_to_failures(tests)
mapping = filter_error_mapping(failing)
summary = generate_summary(mapping)
print_summary(summary)
print("Passing:", len(tests) - len(failing))
print("Failing:", len(failing))

0 comments on commit fad797b

Please sign in to comment.