Skip to content

Commit

Permalink
array API support for cosine_distances (scikit-learn#29265)
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilyXinyi authored Jul 12, 2024
1 parent cc97b80 commit 1813b4a
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ Metrics
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
- :func:`sklearn.metrics.pairwise.chi2_kernel`
- :func:`sklearn.metrics.pairwise.cosine_similarity`
- :func:`sklearn.metrics.pairwise.cosine_distances`
- :func:`sklearn.metrics.pairwise.euclidean_distances` (see :ref:`device_support_for_float64`)
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)
Expand Down
1 change: 1 addition & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ See :ref:`array_api` for more details.
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`;
- :func:`sklearn.metrics.pairwise.cosine_distances` :pr:`29265` by :user:`Emily Chen <EmilyXinyi>`;
- :func:`sklearn.metrics.pairwise.euclidean_distances` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`;
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`;
- :func:`sklearn.metrics.pairwise.rbf_kernel` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`.
Expand Down
7 changes: 5 additions & 2 deletions sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
gen_even_slices,
)
from ..utils._array_api import (
_clip,
_fill_or_add_to_diagonal,
_find_matching_floating_dtype,
_is_numpy_namespace,
Expand Down Expand Up @@ -1139,15 +1140,17 @@ def cosine_distances(X, Y=None):
array([[1. , 1. ],
[0.42..., 0.18...]])
"""
xp, _ = get_namespace(X, Y)

# 1.0 - cosine_similarity(X, Y) without copy
S = cosine_similarity(X, Y)
S *= -1
S += 1
np.clip(S, 0, 2, out=S)
S = _clip(S, 0, 2, xp)
if X is Y or Y is None:
# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
np.fill_diagonal(S, 0.0)
_fill_or_add_to_diagonal(S, 0.0, xp, add_value=False)
return S


Expand Down
2 changes: 2 additions & 0 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from sklearn.metrics.pairwise import (
additive_chi2_kernel,
chi2_kernel,
cosine_distances,
cosine_similarity,
euclidean_distances,
paired_cosine_distances,
Expand Down Expand Up @@ -2016,6 +2017,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
mean_gamma_deviance: [check_array_api_regression_metric],
max_error: [check_array_api_regression_metric],
chi2_kernel: [check_array_api_metric_pairwise],
cosine_distances: [check_array_api_metric_pairwise],
euclidean_distances: [check_array_api_metric_pairwise],
rbf_kernel: [check_array_api_metric_pairwise],
}
Expand Down
13 changes: 13 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,19 @@ def _nanmax(X, axis=None, xp=None):
return X


def _clip(S, min_val, max_val, xp):
    """Clip the values of `S` to the interval `[min_val, max_val]`.

    Parameters
    ----------
    S : array
        Input array.

    min_val : scalar or None
        Lower bound. `None` disables clipping from below.

    max_val : scalar or None
        Upper bound. `None` disables clipping from above.

    xp : module
        Array namespace used for the computation.

    Returns
    -------
    clipped : array
        Array in which elements below `min_val` are replaced by `min_val`
        and elements above `max_val` are replaced by `max_val`.
    """
    # TODO: remove this function and change all usage once we move to array api
    # 2023.12
    # https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.clip.html#clip
    if _is_numpy_namespace(xp):
        return numpy.clip(S, min_val, max_val)

    # `numpy.clip` accepts `None` for a bound; mirror that here instead of
    # failing in `xp.asarray(None, ...)`, so both branches agree.
    if min_val is not None:
        # Build the bound with the dtype of `S` so `where` does not upcast.
        min_arr = xp.asarray(min_val, dtype=S.dtype)
        S = xp.where(S < min_arr, min_arr, S)
    if max_val is not None:
        max_arr = xp.asarray(max_val, dtype=S.dtype)
        S = xp.where(S > max_arr, max_arr, S)
    return S


def _asarray_with_order(
array, dtype=None, order=None, copy=None, *, xp=None, device=None
):
Expand Down

0 comments on commit 1813b4a

Please sign in to comment.