Skip to content

Commit

Permalink
Merge pull request #372 from daskol:switch-unittest-api
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714902802
  • Loading branch information
ChexDev committed Jan 13, 2025
2 parents 8af2c9e + 057142a commit 6f59425
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions chex/_src/variants_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
DEFAULT_NDARRAY_PARAMS_SHAPE = (5, 7)
DEFAULT_NAMED_PARAMS = (('case_0', 1, 2, 1), ('case_1', 4, 6, 2))

make_suite = unittest.defaultTestLoader.loadTestsFromTestCase


# Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
def setUpModule():
Expand Down Expand Up @@ -256,7 +258,7 @@ def setUp(self):
super().setUp()
self.chex_info = str(variants.ChexVariantType.WITHOUT_JIT)
self.res = unittest.TestResult()
ts = unittest.makeSuite(self.FailedTest) # pytype: disable=module-attr
ts = make_suite(self.FailedTest) # pytype: disable=module-attr
ts.run(self.res)

def test_useful_failures(self):
Expand Down Expand Up @@ -291,7 +293,7 @@ def test_useful_failure(self):
unexpected_info = str(variants.ChexVariantType.WITH_DEVICE)

res = unittest.TestResult()
ts = unittest.makeSuite(self.MaybeFailedTest) # pytype: disable=module-attr
ts = make_suite(self.MaybeFailedTest) # pytype: disable=module-attr
ts.run(res)
self.assertLen(res.failures, 1)

Expand All @@ -311,7 +313,7 @@ def test_failure(self):

def test_wrong_base_class(self):
res = unittest.TestResult()
ts = unittest.makeSuite(self.InnerTest) # pytype: disable=module-attr
ts = make_suite(self.InnerTest) # pytype: disable=module-attr
ts.run(res)
self.assertLen(res.errors, 1)

Expand Down Expand Up @@ -347,7 +349,7 @@ def test_inheritance(self, base_classes):
test_class = self.generate_test_class(*base_classes)
for base_class in base_classes:
self.assertTrue(issubclass(test_class, base_class))
ts = unittest.makeSuite(test_class) # pytype: disable=module-attr
ts = make_suite(test_class) # pytype: disable=module-attr
ts.run(res)
self.assertEqual(res.testsRun, 8)
self.assertEmpty(res.errors or res.failures)
Expand All @@ -366,7 +368,7 @@ def test_should_pass(self, arg_0, arg_1, expected):

def test_should_pass(self):
res = unittest.TestResult()
ts = unittest.makeSuite(self.InnerTest) # pytype: disable=module-attr
ts = make_suite(self.InnerTest) # pytype: disable=module-attr
ts.run(res)
self.assertEqual(res.testsRun, 8)
self.assertEmpty(res.errors or res.failures)
Expand Down Expand Up @@ -423,7 +425,7 @@ def test_noop(self):

def test_unused_variant(self):
res = unittest.TestResult()
ts = unittest.makeSuite(self.InnerTest) # pytype: disable=module-attr
ts = make_suite(self.InnerTest) # pytype: disable=module-attr
ts.run(res)
self.assertLen(res.errors, 4)
for _, msg in res.errors:
Expand Down Expand Up @@ -456,7 +458,7 @@ def test_arg(self):

def test_unknown_argument(self):
res = unittest.TestResult()
ts = unittest.makeSuite(self.InnerTest) # pytype: disable=module-attr
ts = make_suite(self.InnerTest) # pytype: disable=module-attr
ts.run(res)
self.assertLen(res.errors, 4)
for _, msg in res.errors:
Expand All @@ -475,7 +477,7 @@ def test_var_type(self):
self.var_types.add(self.variant.type)

def test_var_type_fetch(self):
ts = unittest.makeSuite(self.InnerTest) # pytype: disable=module-attr
ts = make_suite(self.InnerTest) # pytype: disable=module-attr
ts.run(unittest.TestResult())
expected_types = set(variants.ChexVariantType)
if jax.device_count() == 1:
Expand Down Expand Up @@ -521,7 +523,7 @@ def test_4(self):

def test_counters(self):
res = unittest.TestResult()
ts = unittest.makeSuite(self.InnerTest) # pytype: disable=module-attr
ts = make_suite(self.InnerTest) # pytype: disable=module-attr
ts.run(res)

active_pmap = int(jax.device_count() > 1)
Expand Down

0 comments on commit 6f59425

Please sign in to comment.