From 7eaff9691143f3934011cede595b5d9ca8223d86 Mon Sep 17 00:00:00 2001 From: Stephen Spencer Date: Tue, 27 Jun 2023 01:59:18 -0700 Subject: [PATCH] Add `strict` argument to `assert_trees_all_equal`. This enables using the 'strict' mode of `numpy.testing.assert_array_equal`. This mode (disabled by default) has stricter handling of scalar values, as described in https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_array_equal.html. The argument is `False` by default, so this changes allows users to enable the stricter equality checking, if they wish. The jittable implementation of `assert_trees_all_equal` has always been 'strict' in the `numpy.testing` sense. I have not implemented `strict=False` for this function, although it would probably be possible to do so. PiperOrigin-RevId: 543676343 --- chex/_src/asserts.py | 21 +++++++++++++++++---- chex/_src/asserts_test.py | 24 ++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index c6f73fd3..86e20d95 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1603,8 +1603,9 @@ def _assert_tree_all_finite_jittable( @_static_assertion -def _assert_trees_all_equal_static(*trees: ArrayTree, - ignore_nones: bool = False) -> None: +def _assert_trees_all_equal_static( + *trees: ArrayTree, ignore_nones: bool = False, strict: bool = False +) -> None: """Checks that all trees have leaves with *exactly* equal values. If you are comparing floating point numbers, an exact equality check may not @@ -1613,6 +1614,8 @@ def _assert_trees_all_equal_static(*trees: ArrayTree, Args: *trees: A sequence of (at least 2) trees with array leaves. ignore_nones: Whether to ignore `None` in the trees. + strict: If True, disable special scalar handling as described in + `np.testing.assert_array_equals` notes section. Raises: AssertionError: If the leaf values actual and desired are not exactly equal, @@ -1623,7 +1626,8 @@ def assert_fn(arr_1, arr_2): np.testing.assert_array_equal( _ai.jnp_to_np_array(arr_1), _ai.jnp_to_np_array(arr_2), - err_msg="Error in value equality check: Values not exactly equal") + err_msg="Error in value equality check: Values not exactly equal", + strict=strict) def cmp_fn(arr_1, arr_2) -> bool: try: @@ -1646,9 +1650,18 @@ def err_msg_fn(arr_1, arr_2) -> str: def _assert_trees_all_equal_jittable( - *trees: ArrayTree, ignore_nones: bool = False + *trees: ArrayTree, ignore_nones: bool = False, strict: bool = True, ) -> Array: """A jittable version of `_assert_trees_all_equal_static`.""" + if not strict: + raise NotImplementedError( + "`strict=False` is not implemented by" + " `_assert_trees_all_equal_jittable`. This is a feature of" + " `np.testing.assert_array_equal` used in the static implementation of" + " `assert_trees_all_equal` that we do not implement in the jittable" + " version." + ) + if not ignore_nones: assert_tree_no_nones(trees) diff --git a/chex/_src/asserts_test.py b/chex/_src/asserts_test.py index cac51e5c..da41f628 100644 --- a/chex/_src/asserts_test.py +++ b/chex/_src/asserts_test.py @@ -943,6 +943,30 @@ def test_assert_trees_all_equal_nones(self): with self.assertRaisesRegex(AssertionError, err_regex): asserts._assert_trees_all_equal_jittable(tree, tree, ignore_nones=False) + def test_assert_trees_all_equal_strict_mode(self): + # See 'notes' section of + # https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_array_equal.html + # for details about the 'strict' mode of `numpy.testing.assert_array_equal`. + # tldr; it has special handling for scalar values (by default). + tree1 = {'a': jnp.array([1.0], dtype=jnp.float32), 'b': 0.0} + tree2 = {'a': jnp.array(1.0, dtype=jnp.float32), 'b': 0.0} + + asserts.assert_trees_all_equal(tree1, tree2) + asserts.assert_trees_all_equal(tree1, tree2, strict=False) + err_regex = _get_err_regex(r'Trees 0 and 1 differ in leaves \'a\'') + with self.assertRaisesRegex(AssertionError, err_regex): + asserts.assert_trees_all_equal(tree1, tree2, strict=True) + + err_regex = r'Trees 0 and 1 differ in leaves' + with self.assertRaisesRegex(ValueError, err_regex): + asserts._assert_trees_all_equal_jittable(tree1, tree2, strict=True) + + # We do not implement this special scalar handling in the jittable + # assertion (it's possible, but doesn't seem worth the effort). + err_regex = r'`strict=False` is not implemented' + with self.assertRaisesRegex(NotImplementedError, err_regex): + asserts._assert_trees_all_equal_jittable(tree1, tree2, strict=False) + def test_assert_trees_all_close_passes_same_tree(self): tree = { 'a': [jnp.zeros((1,))],