Skip to content

Commit

Permalink
Add strict argument to assert_trees_all_equal.
Browse files Browse the repository at this point in the history
This enables using the 'strict' mode of `numpy.testing.assert_array_equal`. This
mode (disabled by default) has stricter handling of scalar values, as described
in
https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_array_equal.html. The
argument is `False` by default, so this changes allows users to enable the
stricter equality checking, if they wish.

The jittable implementation of `assert_trees_all_equal` has always been 'strict'
in the `numpy.testing` sense. I have not implemented `strict=False` for this
function, although it would probably be possible to do so.

PiperOrigin-RevId: 543676343
  • Loading branch information
stompchicken authored and ChexDev committed Jun 27, 2023
1 parent 149068a commit 7eaff96
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
21 changes: 17 additions & 4 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,8 +1603,9 @@ def _assert_tree_all_finite_jittable(


@_static_assertion
def _assert_trees_all_equal_static(*trees: ArrayTree,
ignore_nones: bool = False) -> None:
def _assert_trees_all_equal_static(
*trees: ArrayTree, ignore_nones: bool = False, strict: bool = False
) -> None:
"""Checks that all trees have leaves with *exactly* equal values.
If you are comparing floating point numbers, an exact equality check may not
Expand All @@ -1613,6 +1614,8 @@ def _assert_trees_all_equal_static(*trees: ArrayTree,
Args:
*trees: A sequence of (at least 2) trees with array leaves.
ignore_nones: Whether to ignore `None` in the trees.
strict: If True, disable special scalar handling as described in
`np.testing.assert_array_equals` notes section.
Raises:
AssertionError: If the leaf values actual and desired are not exactly equal,
Expand All @@ -1623,7 +1626,8 @@ def assert_fn(arr_1, arr_2):
np.testing.assert_array_equal(
_ai.jnp_to_np_array(arr_1),
_ai.jnp_to_np_array(arr_2),
err_msg="Error in value equality check: Values not exactly equal")
err_msg="Error in value equality check: Values not exactly equal",
strict=strict)

def cmp_fn(arr_1, arr_2) -> bool:
try:
Expand All @@ -1646,9 +1650,18 @@ def err_msg_fn(arr_1, arr_2) -> str:


def _assert_trees_all_equal_jittable(
*trees: ArrayTree, ignore_nones: bool = False
*trees: ArrayTree, ignore_nones: bool = False, strict: bool = True,
) -> Array:
"""A jittable version of `_assert_trees_all_equal_static`."""
if not strict:
raise NotImplementedError(
"`strict=False` is not implemented by"
" `_assert_trees_all_equal_jittable`. This is a feature of"
" `np.testing.assert_array_equal` used in the static implementation of"
" `assert_trees_all_equal` that we do not implement in the jittable"
" version."
)

if not ignore_nones:
assert_tree_no_nones(trees)

Expand Down
24 changes: 24 additions & 0 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,30 @@ def test_assert_trees_all_equal_nones(self):
with self.assertRaisesRegex(AssertionError, err_regex):
asserts._assert_trees_all_equal_jittable(tree, tree, ignore_nones=False)

def test_assert_trees_all_equal_strict_mode(self):
# See 'notes' section of
# https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_array_equal.html
# for details about the 'strict' mode of `numpy.testing.assert_array_equal`.
# tldr; it has special handling for scalar values (by default).
tree1 = {'a': jnp.array([1.0], dtype=jnp.float32), 'b': 0.0}
tree2 = {'a': jnp.array(1.0, dtype=jnp.float32), 'b': 0.0}

asserts.assert_trees_all_equal(tree1, tree2)
asserts.assert_trees_all_equal(tree1, tree2, strict=False)
err_regex = _get_err_regex(r'Trees 0 and 1 differ in leaves \'a\'')
with self.assertRaisesRegex(AssertionError, err_regex):
asserts.assert_trees_all_equal(tree1, tree2, strict=True)

err_regex = r'Trees 0 and 1 differ in leaves'
with self.assertRaisesRegex(ValueError, err_regex):
asserts._assert_trees_all_equal_jittable(tree1, tree2, strict=True)

# We do not implement this special scalar handling in the jittable
# assertion (it's possible, but doesn't seem worth the effort).
err_regex = r'`strict=False` is not implemented'
with self.assertRaisesRegex(NotImplementedError, err_regex):
asserts._assert_trees_all_equal_jittable(tree1, tree2, strict=False)

def test_assert_trees_all_close_passes_same_tree(self):
tree = {
'a': [jnp.zeros((1,))],
Expand Down

0 comments on commit 7eaff96

Please sign in to comment.