diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index a51ee60e47e04..9afedeb7ccecb 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -123,7 +123,9 @@ Metrics - :func:`sklearn.metrics.pairwise.additive_chi2_kernel` - :func:`sklearn.metrics.pairwise.chi2_kernel` - :func:`sklearn.metrics.pairwise.cosine_similarity` +- :func:`sklearn.metrics.pairwise.euclidean_distances` (see :ref:`device_support_for_float64`) - :func:`sklearn.metrics.pairwise.paired_cosine_distances` +- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`) - :func:`sklearn.metrics.r2_score` - :func:`sklearn.metrics.zero_one_loss` @@ -172,6 +174,8 @@ automatically skipped. Therefore it's important to run the tests with the pip install array-api-compat # and other libraries as needed pytest -k "array_api" -v +.. _mps_support: + Note on MPS device support -------------------------- @@ -191,3 +195,17 @@ To enable the MPS support in PyTorch, set the environment variable At the time of writing all scikit-learn tests should pass, however, the computational speed is not necessarily better than with the CPU device. + +.. _device_support_for_float64: + +Note on device support for ``float64`` +-------------------------------------- + +Certain operations within scikit-learn will automatically perform operations +on floating-point values with `float64` precision to prevent overflows and ensure +correctness (e.g., :func:`metrics.pairwise.euclidean_distances`). However, +certain combinations of array namespaces and devices, such as `PyTorch on MPS` +(see :ref:`mps_support`) do not support the `float64` data type. In these cases, +scikit-learn will revert to using the `float32` data type instead. This can result in +different behavior (typically numerically unstable results) compared to not using array +API dispatching or using a device with `float64` support. diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index d7d3a71eba636..3971f60eb5f4b 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -43,7 +43,9 @@ See :ref:`array_api` for more details. - :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko `; - :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko `; - :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati `; -- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati `. +- :func:`sklearn.metrics.pairwise.euclidean_distances` :pr:`29433` by :user:`Omar Salman `; +- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati `; +- :func:`sklearn.metrics.pairwise.rbf_kernel` :pr:`29433` by :user:`Omar Salman `. **Classes:** diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py index f2d0ad663569a..970294efe0184 100644 --- a/sklearn/decomposition/_base.py +++ b/sklearn/decomposition/_base.py @@ -9,7 +9,7 @@ from scipy import linalg from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin -from ..utils._array_api import _add_to_diagonal, device, get_namespace +from ..utils._array_api import _fill_or_add_to_diagonal, device, get_namespace from ..utils.validation import check_is_fitted @@ -47,7 +47,7 @@ def get_covariance(self): xp.asarray(0.0, device=device(exp_var)), ) cov = (components_.T * exp_var_diff) @ components_ - _add_to_diagonal(cov, self.noise_variance_, xp) + _fill_or_add_to_diagonal(cov, self.noise_variance_, xp) return cov def get_precision(self): @@ -89,10 +89,10 @@ def get_precision(self): xp.asarray(0.0, device=device(exp_var)), ) precision = components_ @ components_.T / self.noise_variance_ - _add_to_diagonal(precision, 1.0 / exp_var_diff, xp) + _fill_or_add_to_diagonal(precision, 1.0 / exp_var_diff, xp) precision = components_.T @ linalg_inv(precision) @ components_ precision /= -(self.noise_variance_**2) - _add_to_diagonal(precision, 1.0 / self.noise_variance_, xp) + _fill_or_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp) return precision @abstractmethod diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 9382d585a5fe7..b7db4d94c4f07 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -22,9 +22,13 @@ gen_even_slices, ) from ..utils._array_api import ( + _fill_or_add_to_diagonal, _find_matching_floating_dtype, _is_numpy_namespace, + _max_precision_float_dtype, + _modify_in_place_if_numpy, get_namespace, + get_namespace_and_device, ) from ..utils._chunking import get_chunk_n_rows from ..utils._mask import _get_mask @@ -335,13 +339,14 @@ def euclidean_distances( array([[1. ], [1.41421356]]) """ + xp, _ = get_namespace(X, Y) X, Y = check_pairwise_arrays(X, Y) if X_norm_squared is not None: X_norm_squared = check_array(X_norm_squared, ensure_2d=False) original_shape = X_norm_squared.shape if X_norm_squared.shape == (X.shape[0],): - X_norm_squared = X_norm_squared.reshape(-1, 1) + X_norm_squared = xp.reshape(X_norm_squared, (-1, 1)) if X_norm_squared.shape == (1, X.shape[0]): X_norm_squared = X_norm_squared.T if X_norm_squared.shape != (X.shape[0], 1): @@ -354,7 +359,7 @@ def euclidean_distances( Y_norm_squared = check_array(Y_norm_squared, ensure_2d=False) original_shape = Y_norm_squared.shape if Y_norm_squared.shape == (Y.shape[0],): - Y_norm_squared = Y_norm_squared.reshape(1, -1) + Y_norm_squared = xp.reshape(Y_norm_squared, (1, -1)) if Y_norm_squared.shape == (Y.shape[0], 1): Y_norm_squared = Y_norm_squared.T if Y_norm_squared.shape != (1, Y.shape[0]): @@ -375,24 +380,25 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared float32, norms needs to be recomputed on upcast chunks. TODO: use a float64 accumulator in row_norms to avoid the latter. """ - if X_norm_squared is not None and X_norm_squared.dtype != np.float32: - XX = X_norm_squared.reshape(-1, 1) - elif X.dtype != np.float32: - XX = row_norms(X, squared=True)[:, np.newaxis] + xp, _, device_ = get_namespace_and_device(X, Y) + if X_norm_squared is not None and X_norm_squared.dtype != xp.float32: + XX = xp.reshape(X_norm_squared, (-1, 1)) + elif X.dtype != xp.float32: + XX = row_norms(X, squared=True)[:, None] else: XX = None if Y is X: YY = None if XX is None else XX.T else: - if Y_norm_squared is not None and Y_norm_squared.dtype != np.float32: - YY = Y_norm_squared.reshape(1, -1) - elif Y.dtype != np.float32: - YY = row_norms(Y, squared=True)[np.newaxis, :] + if Y_norm_squared is not None and Y_norm_squared.dtype != xp.float32: + YY = xp.reshape(Y_norm_squared, (1, -1)) + elif Y.dtype != xp.float32: + YY = row_norms(Y, squared=True)[None, :] else: YY = None - if X.dtype == np.float32 or Y.dtype == np.float32: + if X.dtype == xp.float32 or Y.dtype == xp.float32: # To minimize precision issues with float32, we compute the distance # matrix on chunks of X and Y upcast to float64 distances = _euclidean_distances_upcast(X, XX, Y, YY) @@ -401,14 +407,22 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared distances = -2 * safe_sparse_dot(X, Y.T, dense_output=True) distances += XX distances += YY - np.maximum(distances, 0, out=distances) + + xp_zero = xp.asarray(0, device=device_, dtype=distances.dtype) + distances = _modify_in_place_if_numpy( + xp, xp.maximum, distances, xp_zero, out=distances + ) # Ensure that distances between vectors and themselves are set to 0.0. # This may not be the case due to floating point rounding errors. if X is Y: - np.fill_diagonal(distances, 0) + _fill_or_add_to_diagonal(distances, 0, xp=xp, add_value=False) - return distances if squared else np.sqrt(distances, out=distances) + if squared: + return distances + + distances = _modify_in_place_if_numpy(xp, xp.sqrt, distances, out=distances) + return distances @validate_params( @@ -552,15 +566,20 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None): X and Y are upcast to float64 by chunks, which size is chosen to limit memory increase by approximately 10% (at least 10MiB). """ + xp, _, device_ = get_namespace_and_device(X, Y) n_samples_X = X.shape[0] n_samples_Y = Y.shape[0] n_features = X.shape[1] - distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32) + distances = xp.empty((n_samples_X, n_samples_Y), dtype=xp.float32, device=device_) if batch_size is None: - x_density = X.nnz / np.prod(X.shape) if issparse(X) else 1 - y_density = Y.nnz / np.prod(Y.shape) if issparse(Y) else 1 + x_density = ( + X.nnz / xp.prod(X.shape) if issparse(X) else xp.asarray(1, device=device_) + ) + y_density = ( + Y.nnz / xp.prod(Y.shape) if issparse(Y) else xp.asarray(1, device=device_) + ) # Allow 10% more memory than X, Y and the distance matrix take (at # least 10MiB) @@ -580,15 +599,15 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None): # Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem # xd=x_density and yd=y_density tmp = (x_density + y_density) * n_features - batch_size = (-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2 + batch_size = (-tmp + xp.sqrt(tmp**2 + 4 * maxmem)) / 2 batch_size = max(int(batch_size), 1) x_batches = gen_batches(n_samples_X, batch_size) - + xp_max_float = _max_precision_float_dtype(xp=xp, device=device_) for i, x_slice in enumerate(x_batches): - X_chunk = X[x_slice].astype(np.float64) + X_chunk = xp.astype(X[x_slice], xp_max_float) if XX is None: - XX_chunk = row_norms(X_chunk, squared=True)[:, np.newaxis] + XX_chunk = row_norms(X_chunk, squared=True)[:, None] else: XX_chunk = XX[x_slice] @@ -601,9 +620,9 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None): d = distances[y_slice, x_slice].T else: - Y_chunk = Y[y_slice].astype(np.float64) + Y_chunk = xp.astype(Y[y_slice], xp_max_float) if YY is None: - YY_chunk = row_norms(Y_chunk, squared=True)[np.newaxis, :] + YY_chunk = row_norms(Y_chunk, squared=True)[None, :] else: YY_chunk = YY[:, y_slice] @@ -611,7 +630,7 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None): d += XX_chunk d += YY_chunk - distances[x_slice, y_slice] = d.astype(np.float32, copy=False) + distances[x_slice, y_slice] = xp.astype(d, xp.float32, copy=False) return distances @@ -1549,13 +1568,15 @@ def rbf_kernel(X, Y=None, gamma=None): array([[0.71..., 0.51...], [0.51..., 0.71...]]) """ + xp, _ = get_namespace(X, Y) X, Y = check_pairwise_arrays(X, Y) if gamma is None: gamma = 1.0 / X.shape[1] K = euclidean_distances(X, Y, squared=True) K *= -gamma - np.exp(K, K) # exponentiate K in-place + # exponentiate K in-place when using numpy + K = _modify_in_place_if_numpy(xp, xp.exp, K, out=K) return K diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 6110cbd3d1d13..14e96cc9fcd98 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -55,7 +55,9 @@ additive_chi2_kernel, chi2_kernel, cosine_similarity, + euclidean_distances, paired_cosine_distances, + rbf_kernel, ) from sklearn.preprocessing import LabelBinarizer from sklearn.utils import shuffle @@ -2014,6 +2016,8 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name) mean_gamma_deviance: [check_array_api_regression_metric], max_error: [check_array_api_regression_metric], chi2_kernel: [check_array_api_metric_pairwise], + euclidean_distances: [check_array_api_metric_pairwise], + rbf_kernel: [check_array_api_metric_pairwise], } diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index a00d250ab31d2..51caacb71c9e2 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -302,6 +302,15 @@ def __eq__(self, other): def isdtype(self, dtype, kind): return isdtype(dtype, kind, xp=self._namespace) + def maximum(self, x1, x2): + # TODO: Remove when `maximum` is made compatible in `array_api_compat`, + # based on the `2023.12` specification. + # https://github.com/data-apis/array-api-compat/issues/127 + x1_np = _convert_to_numpy(x1, xp=self._namespace) + x2_np = _convert_to_numpy(x2, xp=self._namespace) + x_max = numpy.maximum(x1_np, x2_np) + return self._namespace.asarray(x_max, device=device(x1, x2)) + def _check_device_cpu(device): # noqa if device not in {"cpu", None}: @@ -566,7 +575,28 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)): - """Combination into one single function of `get_namespace` and `device`.""" + """Combination into one single function of `get_namespace` and `device`. + + Parameters + ---------- + *array_list : array objects + Array objects. + remove_none : bool, default=True + Whether to ignore None objects passed in arrays. + remove_types : tuple or list, default=(str,) + Types to ignore in the arrays. + + Returns + ------- + namespace : module + Namespace shared by array objects. If any of the `arrays` are not arrays, + the namespace defaults to NumPy. + is_array_api_compliant : bool + True if the arrays are containers that implement the Array API spec. + Always False when array_api_dispatch=False. + device : device + `device` object (see the "Device Support" section of the array API spec). + """ array_list = _remove_non_arrays( *array_list, remove_none=remove_none, remove_types=remove_types ) @@ -592,21 +622,36 @@ def _expit(X, xp=None): return 1.0 / (1.0 + xp.exp(-X)) -def _add_to_diagonal(array, value, xp): - # Workaround for the lack of support for xp.reshape(a, shape, copy=False) in - # numpy.array_api: https://github.com/numpy/numpy/issues/23410 - value = xp.asarray(value, dtype=array.dtype) - if _is_numpy_namespace(xp): - array_np = numpy.asarray(array) - array_np.flat[:: array.shape[0] + 1] += value - return xp.asarray(array_np) - elif value.ndim == 1: - for i in range(array.shape[0]): - array[i, i] += value[i] +def _fill_or_add_to_diagonal(array, value, xp, add_value=True, wrap=False): + """Implementation to facilitate adding or assigning specified values to the + diagonal of a 2-d array. + + If ``add_value`` is `True` then the values will be added to the diagonal + elements otherwise the values will be assigned to the diagonal elements. + By default, ``add_value`` is set to `True. This is currently only + supported for 2-d arrays. + + The implementation is taken from the `numpy.fill_diagonal` function: + https://github.com/numpy/numpy/blob/v2.0.0/numpy/lib/_index_tricks_impl.py#L799-L929 + """ + if array.ndim != 2: + raise ValueError( + f"array should be 2-d. Got array with shape {tuple(array.shape)}" + ) + + value = xp.asarray(value, dtype=array.dtype, device=device(array)) + end = None + # Explicit, fast formula for the common case. For 2-d arrays, we + # accept rectangular ones. + step = array.shape[1] + 1 + if not wrap: + end = array.shape[1] * array.shape[1] + + array_flat = xp.reshape(array, (-1,)) + if add_value: + array_flat[:end:step] += value else: - # scalar value - for i in range(array.shape[0]): - array[i, i] += value + array_flat[:end:step] = value def _max_precision_float_dtype(xp, device): @@ -1000,3 +1045,11 @@ def _count_nonzero(X, xp, device, axis=None, sample_weight=None): zero_scalar = xp.asarray(0, device=device, dtype=weights.dtype) return xp.sum(xp.where(X != 0, weights, zero_scalar), axis=axis) + + +def _modify_in_place_if_numpy(xp, func, *args, out=None, **kwargs): + if _is_numpy_namespace(xp): + func(*args, out=out, **kwargs) + else: + out = func(*args, **kwargs) + return out diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 717bbed76513b..7b5720473848a 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -178,6 +178,7 @@ def safe_sparse_dot(a, b, *, dense_output=False): [11, 25, 39], [17, 39, 61]]) """ + xp, _ = get_namespace(a, b) if a.ndim > 2 or b.ndim > 2: if sparse.issparse(a): # sparse is always 2D. Implies b is 3D+ @@ -193,7 +194,12 @@ def safe_sparse_dot(a, b, *, dense_output=False): ret = a_2d @ b ret = ret.reshape(*a.shape[:-1], b.shape[1]) else: - ret = np.dot(a, b) + # Alternative for `np.dot` when dealing with a or b having + # more than 2 dimensions, that works with the array api. + # If b is 1-dim then the last axis for b is taken otherwise + # if b is >= 2-dim then the second to last axis is taken. + b_axis = -1 if b.ndim == 1 else -2 + ret = xp.tensordot(a, b, axes=[-1, b_axis]) else: ret = a @ b diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 71f499f7a8dae..707304edacd11 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -15,6 +15,7 @@ _convert_to_numpy, _count_nonzero, _estimator_with_converted_arrays, + _fill_or_add_to_diagonal, _is_numpy_namespace, _isin, _max_precision_float_dtype, @@ -112,6 +113,26 @@ def test_array_api_wrapper_astype(): assert X_converted.dtype == xp.float32 +def test_array_api_wrapper_maximum(): + """Test _ArrayAPIWrapper `maximum` for ArrayAPIs other than NumPy. + + This is mainly used to test for `cupy.array_api` but since that is + not available on our coverage-enabled PR CI, we resort to using + `array-api-strict`. + """ + array_api_strict = pytest.importorskip("array_api_strict") + xp_ = _AdjustableNameAPITestWrapper(array_api_strict, "array_api_strict") + xp = _ArrayAPIWrapper(xp_) + + x1 = xp.asarray(([[1, 2, 3], [3, 9, 5]]), dtype=xp.int64) + x2 = xp.asarray(([[0, 1, 6], [8, 4, 5]]), dtype=xp.int64) + result = xp.asarray([[1, 2, 6], [8, 9, 5]], dtype=xp.int64) + + x_max = xp.maximum(x1, x2) + assert x_max.dtype == x1.dtype + assert xp.all(xp.equal(x_max, result)) + + @pytest.mark.parametrize("array_api", ["numpy", "array_api_strict"]) def test_asarray_with_order(array_api): """Test _asarray_with_order passes along order for NumPy arrays.""" @@ -624,3 +645,16 @@ def test_count_nonzero( # NumPy 2.0 has a problem with the device attribute of scalar arrays: # https://github.com/numpy/numpy/issues/26850 assert device(array_xp) == device(result) + + +@pytest.mark.parametrize( + "array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations() +) +@pytest.mark.parametrize("wrap", [True, False]) +def test_fill_or_add_to_diagonal(array_namespace, device_, dtype_name, wrap): + xp = _array_api_for_tests(array_namespace, device_) + array_np = numpy.zeros((5, 4), dtype=numpy.int64) + array_xp = xp.asarray(array_np) + _fill_or_add_to_diagonal(array_xp, value=1, xp=xp, add_value=False, wrap=wrap) + numpy.fill_diagonal(array_np, val=1, wrap=wrap) + assert_array_equal(_convert_to_numpy(array_xp, xp=xp), array_np)