Skip to content

Commit

Permalink
fix(ops): Fix ops.argmin() handling of subnormal float values in Kera…
Browse files Browse the repository at this point in the history
…s backends (#20812)

- Update JAX and NumPy backends to handle subnormal float comparisons

- Add test case to verify subnormal float value handling
  • Loading branch information
harshaljanjani authored Jan 26, 2025
1 parent aee7dce commit 734cd03
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
15 changes: 14 additions & 1 deletion keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,20 @@ def argmax(x, axis=None, keepdims=False):


def argmin(x, axis=None, keepdims=False):
return jnp.argmin(x, axis=axis, keepdims=keepdims)
x_64 = jnp.asarray(x, dtype=jnp.float64)
if axis is not None:
min_mask = x_64 == jnp.min(x_64, axis=axis, keepdims=True)
indices = jnp.argmin(
jnp.where(min_mask, x_64, jnp.inf), axis=axis, keepdims=keepdims
).astype("int32")
else:
min_mask = (x_64 < x_64.min()) | (
(x_64 == x_64.min()) & (jnp.signbit(x_64))
)
indices = jnp.argmin(
jnp.where(min_mask, x_64, jnp.inf), axis=axis, keepdims=keepdims
).astype("int32")
return indices


def argsort(x, axis=-1):
Expand Down
15 changes: 14 additions & 1 deletion keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,20 @@ def argmax(x, axis=None, keepdims=False):

def argmin(x, axis=None, keepdims=False):
axis = standardize_axis_for_numpy(axis)
return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32")
x_64 = np.asarray(x, dtype=np.float64)
if axis is not None:
min_mask = x_64 == np.min(x_64, axis=axis, keepdims=True)
indices = np.argmin(
np.where(min_mask, x_64, np.inf), axis=axis, keepdims=keepdims
).astype("int32")
else:
min_mask = (x_64 < x_64.min()) | (
(x_64 == x_64.min()) & (np.signbit(x_64))
)
indices = np.argmin(
np.where(min_mask, x_64, np.inf), axis=axis, keepdims=keepdims
).astype("int32")
return indices


def argsort(x, axis=-1):
Expand Down
22 changes: 22 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,28 @@ def test_argmax_negative_zero(self):
)
self.assertEqual(knp.argmax(input_data), 2)

@pytest.mark.skipif(
keras.config.backend() == "openvino"
or keras.config.backend() == "tensorflow",
reason="""
OpenVINO and TensorFlow don't support this
change, TensorFlow behavior for this case is under
evaluation and may change within this PR
""",
)
def test_argmin_negative_zero(self):
input_data = np.array(
[
0.0,
1.1754943508222875e-38,
-1.401298464324817e-45,
0.0,
459367.0,
],
dtype=np.float32,
)
self.assertEqual(knp.argmin(input_data), 2)

def test_argmin(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.argmin(x).shape, ())
Expand Down

0 comments on commit 734cd03

Please sign in to comment.