Skip to content

Commit

Permalink
ENH Array API support for euclidean_distances and rbf_kernel (scikit-…
Browse files Browse the repository at this point in the history
…learn#29433)

Co-authored-by: Olivier Grisel <[email protected]>
  • Loading branch information
OmarManzoor and ogrisel authored Jul 11, 2024
1 parent 2b2e290 commit e7af195
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 46 deletions.
18 changes: 18 additions & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ Metrics
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
- :func:`sklearn.metrics.pairwise.chi2_kernel`
- :func:`sklearn.metrics.pairwise.cosine_similarity`
- :func:`sklearn.metrics.pairwise.euclidean_distances` (see :ref:`device_support_for_float64`)
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)
- :func:`sklearn.metrics.r2_score`
- :func:`sklearn.metrics.zero_one_loss`

Expand Down Expand Up @@ -172,6 +174,8 @@ automatically skipped. Therefore it's important to run the tests with the
pip install array-api-compat # and other libraries as needed
pytest -k "array_api" -v

.. _mps_support:

Note on MPS device support
--------------------------

Expand All @@ -191,3 +195,17 @@ To enable the MPS support in PyTorch, set the environment variable

At the time of writing, all scikit-learn tests should pass; however, the
computational speed is not necessarily better than with the CPU device.

.. _device_support_for_float64:

Note on device support for ``float64``
--------------------------------------

Certain operations within scikit-learn automatically perform their computations
on floating-point values with ``float64`` precision to prevent overflows and ensure
correctness (e.g., :func:`metrics.pairwise.euclidean_distances`). However,
certain combinations of array namespaces and devices, such as PyTorch on MPS
(see :ref:`mps_support`), do not support the ``float64`` data type. In these cases,
scikit-learn will fall back to using the ``float32`` data type instead. This can result in
different behavior (typically numerically unstable results) compared to not using array
API dispatching or using a device with ``float64`` support.
4 changes: 3 additions & 1 deletion doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ See :ref:`array_api` for more details.
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`;
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`.
- :func:`sklearn.metrics.pairwise.euclidean_distances` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`;
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`;
- :func:`sklearn.metrics.pairwise.rbf_kernel` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`.

**Classes:**

Expand Down
8 changes: 4 additions & 4 deletions sklearn/decomposition/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scipy import linalg

from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
from ..utils._array_api import _add_to_diagonal, device, get_namespace
from ..utils._array_api import _fill_or_add_to_diagonal, device, get_namespace
from ..utils.validation import check_is_fitted


Expand Down Expand Up @@ -47,7 +47,7 @@ def get_covariance(self):
xp.asarray(0.0, device=device(exp_var)),
)
cov = (components_.T * exp_var_diff) @ components_
_add_to_diagonal(cov, self.noise_variance_, xp)
_fill_or_add_to_diagonal(cov, self.noise_variance_, xp)
return cov

def get_precision(self):
Expand Down Expand Up @@ -89,10 +89,10 @@ def get_precision(self):
xp.asarray(0.0, device=device(exp_var)),
)
precision = components_ @ components_.T / self.noise_variance_
_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
_fill_or_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
precision = components_.T @ linalg_inv(precision) @ components_
precision /= -(self.noise_variance_**2)
_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
_fill_or_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
return precision

@abstractmethod
Expand Down
71 changes: 46 additions & 25 deletions sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
gen_even_slices,
)
from ..utils._array_api import (
_fill_or_add_to_diagonal,
_find_matching_floating_dtype,
_is_numpy_namespace,
_max_precision_float_dtype,
_modify_in_place_if_numpy,
get_namespace,
get_namespace_and_device,
)
from ..utils._chunking import get_chunk_n_rows
from ..utils._mask import _get_mask
Expand Down Expand Up @@ -335,13 +339,14 @@ def euclidean_distances(
array([[1. ],
[1.41421356]])
"""
xp, _ = get_namespace(X, Y)
X, Y = check_pairwise_arrays(X, Y)

if X_norm_squared is not None:
X_norm_squared = check_array(X_norm_squared, ensure_2d=False)
original_shape = X_norm_squared.shape
if X_norm_squared.shape == (X.shape[0],):
X_norm_squared = X_norm_squared.reshape(-1, 1)
X_norm_squared = xp.reshape(X_norm_squared, (-1, 1))
if X_norm_squared.shape == (1, X.shape[0]):
X_norm_squared = X_norm_squared.T
if X_norm_squared.shape != (X.shape[0], 1):
Expand All @@ -354,7 +359,7 @@ def euclidean_distances(
Y_norm_squared = check_array(Y_norm_squared, ensure_2d=False)
original_shape = Y_norm_squared.shape
if Y_norm_squared.shape == (Y.shape[0],):
Y_norm_squared = Y_norm_squared.reshape(1, -1)
Y_norm_squared = xp.reshape(Y_norm_squared, (1, -1))
if Y_norm_squared.shape == (Y.shape[0], 1):
Y_norm_squared = Y_norm_squared.T
if Y_norm_squared.shape != (1, Y.shape[0]):
Expand All @@ -375,24 +380,25 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
float32, norms needs to be recomputed on upcast chunks.
TODO: use a float64 accumulator in row_norms to avoid the latter.
"""
if X_norm_squared is not None and X_norm_squared.dtype != np.float32:
XX = X_norm_squared.reshape(-1, 1)
elif X.dtype != np.float32:
XX = row_norms(X, squared=True)[:, np.newaxis]
xp, _, device_ = get_namespace_and_device(X, Y)
if X_norm_squared is not None and X_norm_squared.dtype != xp.float32:
XX = xp.reshape(X_norm_squared, (-1, 1))
elif X.dtype != xp.float32:
XX = row_norms(X, squared=True)[:, None]
else:
XX = None

if Y is X:
YY = None if XX is None else XX.T
else:
if Y_norm_squared is not None and Y_norm_squared.dtype != np.float32:
YY = Y_norm_squared.reshape(1, -1)
elif Y.dtype != np.float32:
YY = row_norms(Y, squared=True)[np.newaxis, :]
if Y_norm_squared is not None and Y_norm_squared.dtype != xp.float32:
YY = xp.reshape(Y_norm_squared, (1, -1))
elif Y.dtype != xp.float32:
YY = row_norms(Y, squared=True)[None, :]
else:
YY = None

if X.dtype == np.float32 or Y.dtype == np.float32:
if X.dtype == xp.float32 or Y.dtype == xp.float32:
# To minimize precision issues with float32, we compute the distance
# matrix on chunks of X and Y upcast to float64
distances = _euclidean_distances_upcast(X, XX, Y, YY)
Expand All @@ -401,14 +407,22 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
distances = -2 * safe_sparse_dot(X, Y.T, dense_output=True)
distances += XX
distances += YY
np.maximum(distances, 0, out=distances)

xp_zero = xp.asarray(0, device=device_, dtype=distances.dtype)
distances = _modify_in_place_if_numpy(
xp, xp.maximum, distances, xp_zero, out=distances
)

# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
if X is Y:
np.fill_diagonal(distances, 0)
_fill_or_add_to_diagonal(distances, 0, xp=xp, add_value=False)

return distances if squared else np.sqrt(distances, out=distances)
if squared:
return distances

distances = _modify_in_place_if_numpy(xp, xp.sqrt, distances, out=distances)
return distances


@validate_params(
Expand Down Expand Up @@ -552,15 +566,20 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
X and Y are upcast to float64 by chunks, which size is chosen to limit
memory increase by approximately 10% (at least 10MiB).
"""
xp, _, device_ = get_namespace_and_device(X, Y)
n_samples_X = X.shape[0]
n_samples_Y = Y.shape[0]
n_features = X.shape[1]

distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32)
distances = xp.empty((n_samples_X, n_samples_Y), dtype=xp.float32, device=device_)

if batch_size is None:
x_density = X.nnz / np.prod(X.shape) if issparse(X) else 1
y_density = Y.nnz / np.prod(Y.shape) if issparse(Y) else 1
x_density = (
X.nnz / xp.prod(X.shape) if issparse(X) else xp.asarray(1, device=device_)
)
y_density = (
Y.nnz / xp.prod(Y.shape) if issparse(Y) else xp.asarray(1, device=device_)
)

# Allow 10% more memory than X, Y and the distance matrix take (at
# least 10MiB)
Expand All @@ -580,15 +599,15 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
# Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem
# xd=x_density and yd=y_density
tmp = (x_density + y_density) * n_features
batch_size = (-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2
batch_size = (-tmp + xp.sqrt(tmp**2 + 4 * maxmem)) / 2
batch_size = max(int(batch_size), 1)

x_batches = gen_batches(n_samples_X, batch_size)

xp_max_float = _max_precision_float_dtype(xp=xp, device=device_)
for i, x_slice in enumerate(x_batches):
X_chunk = X[x_slice].astype(np.float64)
X_chunk = xp.astype(X[x_slice], xp_max_float)
if XX is None:
XX_chunk = row_norms(X_chunk, squared=True)[:, np.newaxis]
XX_chunk = row_norms(X_chunk, squared=True)[:, None]
else:
XX_chunk = XX[x_slice]

Expand All @@ -601,17 +620,17 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
d = distances[y_slice, x_slice].T

else:
Y_chunk = Y[y_slice].astype(np.float64)
Y_chunk = xp.astype(Y[y_slice], xp_max_float)
if YY is None:
YY_chunk = row_norms(Y_chunk, squared=True)[np.newaxis, :]
YY_chunk = row_norms(Y_chunk, squared=True)[None, :]
else:
YY_chunk = YY[:, y_slice]

d = -2 * safe_sparse_dot(X_chunk, Y_chunk.T, dense_output=True)
d += XX_chunk
d += YY_chunk

distances[x_slice, y_slice] = d.astype(np.float32, copy=False)
distances[x_slice, y_slice] = xp.astype(d, xp.float32, copy=False)

return distances

Expand Down Expand Up @@ -1549,13 +1568,15 @@ def rbf_kernel(X, Y=None, gamma=None):
array([[0.71..., 0.51...],
[0.51..., 0.71...]])
"""
xp, _ = get_namespace(X, Y)
X, Y = check_pairwise_arrays(X, Y)
if gamma is None:
gamma = 1.0 / X.shape[1]

K = euclidean_distances(X, Y, squared=True)
K *= -gamma
np.exp(K, K) # exponentiate K in-place
# exponentiate K in-place when using numpy
K = _modify_in_place_if_numpy(xp, xp.exp, K, out=K)
return K


Expand Down
4 changes: 4 additions & 0 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@
additive_chi2_kernel,
chi2_kernel,
cosine_similarity,
euclidean_distances,
paired_cosine_distances,
rbf_kernel,
)
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle
Expand Down Expand Up @@ -2014,6 +2016,8 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
mean_gamma_deviance: [check_array_api_regression_metric],
max_error: [check_array_api_regression_metric],
chi2_kernel: [check_array_api_metric_pairwise],
euclidean_distances: [check_array_api_metric_pairwise],
rbf_kernel: [check_array_api_metric_pairwise],
}


Expand Down
83 changes: 68 additions & 15 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ def __eq__(self, other):
def isdtype(self, dtype, kind):
    # Inside the method body the bare name ``isdtype`` resolves to the
    # free function of the same name (not to this method); the call is
    # forwarded to it with the wrapped namespace passed as ``xp``.
    return isdtype(dtype, kind, xp=self._namespace)

def maximum(self, x1, x2):
    """Return the element-wise maximum of ``x1`` and ``x2``.

    Both operands are converted to NumPy, combined with
    ``numpy.maximum`` and the result is converted back into an array of
    the wrapped namespace, placed on the device reported by
    ``device(x1, x2)``.
    """
    # TODO: Remove when `maximum` is made compatible in `array_api_compat`,
    # based on the `2023.12` specification.
    # https://github.com/data-apis/array-api-compat/issues/127
    namespace = self._namespace
    operands_np = [
        _convert_to_numpy(operand, xp=namespace) for operand in (x1, x2)
    ]
    return namespace.asarray(
        numpy.maximum(*operands_np), device=device(x1, x2)
    )


def _check_device_cpu(device): # noqa
if device not in {"cpu", None}:
Expand Down Expand Up @@ -566,7 +575,28 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):


def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)):
"""Combination into one single function of `get_namespace` and `device`."""
"""Combination into one single function of `get_namespace` and `device`.
Parameters
----------
*array_list : array objects
Array objects.
remove_none : bool, default=True
Whether to ignore None objects passed in arrays.
remove_types : tuple or list, default=(str,)
Types to ignore in the arrays.
Returns
-------
namespace : module
Namespace shared by array objects. If any of the `arrays` are not arrays,
the namespace defaults to NumPy.
is_array_api_compliant : bool
True if the arrays are containers that implement the Array API spec.
Always False when array_api_dispatch=False.
device : device
`device` object (see the "Device Support" section of the array API spec).
"""
array_list = _remove_non_arrays(
*array_list, remove_none=remove_none, remove_types=remove_types
)
Expand All @@ -592,21 +622,36 @@ def _expit(X, xp=None):
return 1.0 / (1.0 + xp.exp(-X))


def _add_to_diagonal(array, value, xp):
# Workaround for the lack of support for xp.reshape(a, shape, copy=False) in
# numpy.array_api: https://github.com/numpy/numpy/issues/23410
value = xp.asarray(value, dtype=array.dtype)
if _is_numpy_namespace(xp):
array_np = numpy.asarray(array)
array_np.flat[:: array.shape[0] + 1] += value
return xp.asarray(array_np)
elif value.ndim == 1:
for i in range(array.shape[0]):
array[i, i] += value[i]
def _fill_or_add_to_diagonal(array, value, xp, add_value=True, wrap=False):
    """Add values to, or assign values to, the diagonal of a 2-d array in place.

    If ``add_value`` is `True` then the values will be added to the diagonal
    elements otherwise the values will be assigned to the diagonal elements.
    By default, ``add_value`` is set to `True`. This is currently only
    supported for 2-d arrays.

    The implementation is taken from the `numpy.fill_diagonal` function:
    https://github.com/numpy/numpy/blob/v2.0.0/numpy/lib/_index_tricks_impl.py#L799-L929

    Parameters
    ----------
    array : array of shape (n_rows, n_columns)
        Array whose diagonal is modified.
    value : scalar or 1-d array-like
        Value(s) to add or assign. It is converted with ``xp.asarray`` to the
        dtype and device of ``array`` before use.
    xp : module
        Array namespace used for the conversion and the flattening.
    add_value : bool, default=True
        Whether to add ``value`` to the diagonal (`True`) or to overwrite the
        diagonal with it (`False`).
    wrap : bool, default=False
        When `False`, only flat positions below ``n_columns * n_columns`` are
        touched, so the diagonal stops after ``n_columns`` rows. When `True`,
        the strided selection runs over the whole flattened array.

    Raises
    ------
    ValueError
        If ``array`` is not 2-dimensional.
    """
    if array.ndim != 2:
        raise ValueError(
            f"array should be 2-d. Got array with shape {tuple(array.shape)}"
        )

    value = xp.asarray(value, dtype=array.dtype, device=device(array))
    # Explicit, fast formula for the common case. For 2-d arrays, we
    # accept rectangular ones.
    step = array.shape[1] + 1
    end = None if wrap else array.shape[1] * array.shape[1]

    if _is_numpy_namespace(xp):
        # `ndarray.flat` always writes through to `array`, whereas
        # `reshape` returns a copy for non C-contiguous inputs, which would
        # silently discard the update.
        array_flat = array.flat
    else:
        # NOTE(review): relies on `xp.reshape` returning a view of `array`
        # for this namespace; confirm for non-contiguous inputs.
        array_flat = xp.reshape(array, (-1,))

    if add_value:
        array_flat[:end:step] += value
    else:
        array_flat[:end:step] = value


def _max_precision_float_dtype(xp, device):
Expand Down Expand Up @@ -1000,3 +1045,11 @@ def _count_nonzero(X, xp, device, axis=None, sample_weight=None):

zero_scalar = xp.asarray(0, device=device, dtype=weights.dtype)
return xp.sum(xp.where(X != 0, weights, zero_scalar), axis=axis)


def _modify_in_place_if_numpy(xp, func, *args, out=None, **kwargs):
    """Call ``func``, writing into ``out`` only for NumPy namespaces.

    Parameters
    ----------
    xp : module
        Array namespace; tested with ``_is_numpy_namespace``.
    func : callable
        Function to apply. For a NumPy namespace it is called as
        ``func(*args, out=out, **kwargs)``; for any other namespace it is
        called as ``func(*args, **kwargs)``.
    *args
        Positional arguments forwarded to ``func``.
    out : array or None, default=None
        Output buffer forwarded to ``func`` for NumPy namespaces. It is not
        passed to ``func`` for other namespaces.
    **kwargs
        Keyword arguments forwarded to ``func``.

    Returns
    -------
    result : array
        ``out`` when a NumPy namespace is used and ``out`` is provided,
        otherwise the value returned by ``func``.
    """
    if not _is_numpy_namespace(xp):
        return func(*args, **kwargs)

    result = func(*args, out=out, **kwargs)
    # Without an explicit buffer the only handle on the computed values is
    # the return value of `func`; returning `out` would hand back None.
    return result if out is None else out
Loading

0 comments on commit e7af195

Please sign in to comment.