From 7ce62f969c94598c64be7d4da3a0c66a9d1d183c Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 25 Feb 2025 21:35:25 -0800 Subject: [PATCH] Expose kmeans to python --- cpp/CMakeLists.txt | 1 + cpp/include/cuvs/cluster/kmeans.h | 201 ++++++++++++++++++++++ cpp/include/cuvs/cluster/kmeans.hpp | 47 ++++- cpp/src/cluster/kmeans.cuh | 46 +++++ cpp/src/cluster/kmeans_c.cpp | 236 ++++++++++++++++++++++++++ python/cuvs/CMakeLists.txt | 1 + python/cuvs/cuvs/tests/test_kmeans.py | 70 ++++++++ 7 files changed, 598 insertions(+), 4 deletions(-) create mode 100644 cpp/include/cuvs/cluster/kmeans.h create mode 100644 cpp/src/cluster/kmeans_c.cpp create mode 100644 python/cuvs/cuvs/tests/test_kmeans.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 869854847..c8f867bb4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -688,6 +688,7 @@ target_compile_definitions(cuvs::cuvs INTERFACE $<$:NVTX_ENAB add_library( cuvs_c SHARED src/core/c_api.cpp + src/cluster/kmeans_c.cpp src/neighbors/brute_force_c.cpp src/neighbors/ivf_flat_c.cpp src/neighbors/ivf_pq_c.cpp diff --git a/cpp/include/cuvs/cluster/kmeans.h b/cpp/include/cuvs/cluster/kmeans.h new file mode 100644 index 000000000..2719f963e --- /dev/null +++ b/cpp/include/cuvs/cluster/kmeans.h @@ -0,0 +1,201 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +enum cuvsKMeansInitMethod { + /** + * Sample the centroids using the kmeans++ strategy + */ + KMeansPlusPlus, + + /** + * Sample the centroids uniformly at random + */ + Random, + + /** + * User provides the array of initial centroids + */ + Array +}; + +/** + * @brief Hyper-parameters for the kmeans algorithm + */ +struct cuvsKMeansParams { + cuvsDistanceType metric; + + /** + * The number of clusters to form as well as the number of centroids to generate (default:8). + */ + int n_clusters; + + /** + * Method for initialization, defaults to k-means++: + * - cuvsKMeansInitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm + * to select the initial cluster centers. + * - cuvsKMeansInitMethod::Random (random): Choose 'n_clusters' observations (rows) at + * random from the input data for the initial centroids. + * - cuvsKMeansInitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. + */ + cuvsKMeansInitMethod init; + + /** + * Maximum number of iterations of the k-means algorithm for a single run. + */ + int max_iter; + + /** + * Relative tolerance with regards to inertia to declare convergence. + */ + double tol; + + /** + * Number of instance k-means algorithm will be run with different seeds. 
+   */
+  int n_init;
+
+  /**
+   * Oversampling factor for use in the k-means|| algorithm
+   */
+  double oversampling_factor;
+
+  /**
+   * batch_samples and batch_centroids are used to tile 1NN computation which is
+   * useful to optimize/control the memory footprint
+   * Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0
+   * then don't tile the centroids
+   */
+  int batch_samples;
+
+  /**
+   * if 0 then batch_centroids = n_clusters
+   */
+  int batch_centroids;
+
+  bool inertia_check;
+
+  // TODO: handle balanced kmeans
+};
+
+typedef struct cuvsKMeansParams* cuvsKMeansParams_t;
+
+/**
+ * @brief Allocate KMeans params, and populate with default values
+ *
+ * @param[in] params cuvsKMeansParams_t to allocate
+ * @return cuvsError_t
+ */
+cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params);
+
+/**
+ * @brief De-allocate KMeans params
+ *
+ * @param[in] params params to de-allocate
+ * @return cuvsError_t
+ */
+cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t params);
+
+/**
+ * @brief Find clusters with k-means algorithm.
+ *
+ * Initial centroids are chosen with k-means++ algorithm. Empty
+ * clusters are reinitialized by choosing new centroids with
+ * k-means++ algorithm.
+ *
+ * @param[in] res            opaque C handle
+ * @param[in] params         Parameters for KMeans model.
+ * @param[in] X              Training instances to cluster. The data must
+ *                           be in row-major format.
+ *                           [dim = n_samples x n_features]
+ * @param[in] sample_weight  Optional weights for each observation in X.
+ *                           [len = n_samples]
+ * @param[inout] centroids   [in] When init is cuvsKMeansInitMethod::Array, use
+ *                           centroids as the initial cluster centers.
+ *                           [out] The generated centroids from the
+ *                           kmeans algorithm are stored at the address
+ *                           pointed by 'centroids'.
+ *                           [dim = n_clusters x n_features]
+ * @param[out] inertia       Sum of squared distances of samples to their
+ *                           closest cluster center.
+ * @param[out] n_iter        Number of iterations run.
+ */
+cuvsError_t cuvsKMeansFit(cuvsResources_t res,
+                          cuvsKMeansParams_t params,
+                          DLManagedTensor* X,
+                          DLManagedTensor* sample_weight,
+                          DLManagedTensor* centroids,
+                          double* inertia,
+                          int* n_iter);
+
+/**
+ * @brief Predict the closest cluster each sample in X belongs to.
+ *
+ * @param[in] res               opaque C handle
+ * @param[in] params            Parameters for KMeans model.
+ * @param[in] X                 New data to predict.
+ *                              [dim = n_samples x n_features]
+ * @param[in] sample_weight     Optional weights for each observation in X.
+ *                              [len = n_samples]
+ * @param[in] centroids         Cluster centroids. The data must be in
+ *                              row-major format.
+ *                              [dim = n_clusters x n_features]
+ * @param[out] labels           Index of the cluster each sample in X
+ *                              belongs to.
+ *                              [len = n_samples]
+ * @param[in] normalize_weight  True if the weights should be normalized
+ * @param[out] inertia          Sum of squared distances of samples to
+ *                              their closest cluster center.
+ */
+cuvsError_t cuvsKMeansPredict(cuvsResources_t res,
+                              cuvsKMeansParams_t params,
+                              DLManagedTensor* X,
+                              DLManagedTensor* sample_weight,
+                              DLManagedTensor* centroids,
+                              DLManagedTensor* labels,
+                              bool normalize_weight,
+                              double* inertia);
+
+/**
+ * @brief Compute cluster cost
+ *
+ * @param[in] res        opaque C handle
+ * @param[in] X          Training instances to cluster. The data must
+ *                       be in row-major format.
+ *                       [dim = n_samples x n_features]
+ * @param[in] centroids  Cluster centroids. The data must be in
+ *                       row-major format.
+ * [dim = n_clusters x n_features] + * @param[out] cost Resulting cluster cost + * + */ +cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, + DLManagedTensor* X, + DLManagedTensor* centroids, + double* cost); +#ifdef __cplusplus +} +#endif diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 64ac813ab..7c6af27e1 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -104,10 +104,12 @@ struct params : base_params { */ double oversampling_factor = 2.0; - // batch_samples and batch_centroids are used to tile 1NN computation which is - // useful to optimize/control the memory footprint - // Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 - // then don't tile the centroids + /** + * batch_samples and batch_centroids are used to tile 1NN computation which is + * useful to optimize/control the memory footprint + * Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 + * then don't tile the centroids + */ int batch_samples = 1 << 15; /** @@ -1089,6 +1091,43 @@ void transform(raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_matrix_view X_new); + +/** + * @brief Compute cluster cost + * + * @param[in] handle The raft handle + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[out] cost Resulting cluster cost + * + */ +void cluster_cost(const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost); + +/** + * @brief Compute cluster cost + * + * @param[in] handle The raft handle + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. 
+ * [dim = n_clusters x n_features] + * @param[out] cost Resulting cluster cost + * + */ +void cluster_cost(const raft::resources& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost); + /** * @} */ diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 5e6d756cc..4115e4abe 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -465,6 +465,52 @@ void min_cluster_distance(raft::resources const& handle, workspace); } +template +void cluster_cost(raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost) +{ + auto stream = raft::resource::get_cuda_stream(handle); + + auto n_clusters = centroids.extent(0); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + + rmm::device_uvector workspace(n_samples * sizeof(IndexT), stream); + + rmm::device_uvector x_norms(n_samples, stream); + rmm::device_uvector centroid_norms(n_clusters, stream); + raft::linalg::rowNorm( + x_norms.data(), X.data_handle(), n_features, n_samples, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm( + centroid_norms.data(), centroids, n_features, n_clusters, raft::linalg::L2Norm, true, stream); + + rmm::device_uvector min_cluster_distance(n_samples, stream); + rmm::device_uvector l2_norm_or_distance_buffer(0, stream); + + auto metric = cuvs::distance::DistanceType::L2Expanded; + + cuvs::cluster::kmeans::min_cluster_distance(handle, + X, + centroids, + min_cluster_distance, + x_norms, + l2_norm_or_distance_buffer, + metric, + n_samples, + n_clusters, + workspace); + + rmm::device_scalar device_cost(0, stream); + cuvs::cluster::kmeans::cluster_cost(handle, + min_cluster_distance.view(), + workspace, + raft::make_device_scalar_view(device_cost.data()), + raft::add_op{}); + raft::update_host(cost.data(), device_cost.data(), 1, stream); +} + /** * @brief Calculates a pair for every sample in input 'X' where key is an * index of one of the 'centroids' (index of the nearest centroid) and 'value' diff --git a/cpp/src/cluster/kmeans_c.cpp b/cpp/src/cluster/kmeans_c.cpp new file mode 100644 index 000000000..243c3db87 --- /dev/null +++ b/cpp/src/cluster/kmeans_c.cpp @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) +{ + auto kmeans_params = cuvs::cluster::kmeans::params(); + kmeans_params.metric = params.metric; + kmeans_params.init = static_cast(params.init); + kmeans_params.n_clusters = params.n_clusters; + kmeans_params.max_iter = params.max_iter; + kmeans_params.tol = params.tol; + kmeans_params.oversampling_factor = params.oversampling_factor; + kmeans_params.batch_samples = params.batch_samples; + kmeans_params.batch_centroids = params.batch_centroids; + kmeans_params.inertia_check = params.inertia_check; + return kmeans_params; +} + +template +void _fit(cuvsResources_t res, + const cuvsKMeansParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* sample_weight_tensor, + DLManagedTensor* centroids_tensor, + double* inertia, + int* n_iter) +{ + auto X = X_tensor->dl_tensor; + auto res_ptr = reinterpret_cast(res); + + auto kmeans_params = convert_params(params); + + T inertia_temp; + IdxT n_iter_temp; + + if (cuvs::core::is_dlpack_device_compatible(X)) { + using const_mdspan_type = raft::device_matrix_view; + using mdspan_type = raft::device_matrix_view; + + std::optional> sample_weight; + if (sample_weight_tensor != NULL) { + sample_weight = + cuvs::core::from_dlpack>(sample_weight_tensor); + } + + cuvs::cluster::kmeans::fit(*res_ptr, + kmeans_params, + cuvs::core::from_dlpack(X_tensor), + sample_weight, + cuvs::core::from_dlpack(centroids_tensor), + raft::make_host_scalar_view(&inertia_temp), + raft::make_host_scalar_view(&n_iter_temp)); + } else { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + *inertia = inertia_temp; + *n_iter = n_iter_temp; +} + +template +void _predict(cuvsResources_t res, + const cuvsKMeansParams& params, + DLManagedTensor* X_tensor, + DLManagedTensor* sample_weight_tensor, + DLManagedTensor* centroids_tensor, + DLManagedTensor* labels_tensor, + bool normalize_weight, + double* inertia) +{ + auto X = X_tensor->dl_tensor; + auto res_ptr = reinterpret_cast(res); + + auto kmeans_params = convert_params(params); + T inertia_temp; + + if (cuvs::core::is_dlpack_device_compatible(X)) { + using labels_mdspan_type = raft::device_vector_view; + using const_mdspan_type = raft::device_matrix_view; + using mdspan_type = raft::device_matrix_view; + + std::optional> sample_weight; + if (sample_weight_tensor != NULL) { + sample_weight = + cuvs::core::from_dlpack>(sample_weight_tensor); + } + + cuvs::cluster::kmeans::predict(*res_ptr, + kmeans_params, + cuvs::core::from_dlpack(X_tensor), + sample_weight, + cuvs::core::from_dlpack(centroids_tensor), + cuvs::core::from_dlpack(labels_tensor), + normalize_weight, + raft::make_host_scalar_view(&inertia_temp)); + } else { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + *inertia = inertia_temp; +} + +template +void _cluster_cost(cuvsResources_t res, + DLManagedTensor* X_tensor, + DLManagedTensor* centroids_tensor, + double* cost) +{ + auto X = X_tensor->dl_tensor; + auto res_ptr = reinterpret_cast(res); + + T cost_temp; + + if (cuvs::core::is_dlpack_device_compatible(X)) { + using mdspan_type = raft::device_matrix_view; + + cuvs::cluster::kmeans::cluster_cost(*res_ptr, + cuvs::core::from_dlpack(X_tensor), + cuvs::core::from_dlpack(centroids_tensor), + raft::make_host_scalar_view(&cost_temp)); + } else { + RAFT_FAIL("X dataset must be accessible on device memory"); + } + + *cost = cost_temp; +} +} // namespace + +extern "C" 
cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) +{ + return cuvs::core::translate_exceptions([=] { + cuvs::cluster::kmeans::params cpp_params; + *params = new cuvsKMeansParams{.metric = cpp_params.metric, + .n_clusters = cpp_params.n_clusters, + .init = static_cast(cpp_params.init), + .max_iter = cpp_params.max_iter, + .tol = cpp_params.tol, + .oversampling_factor = cpp_params.oversampling_factor, + .batch_samples = cpp_params.batch_samples, + .inertia_check = cpp_params.inertia_check}; + }); +} + +extern "C" cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t params) +{ + return cuvs::core::translate_exceptions([=] { delete params; }); +} + +extern "C" cuvsError_t cuvsKMeansFit(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int* n_iter) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsKMeansPredict(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + DLManagedTensor* labels, + bool normalize_weight, + double* inertia) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _predict(res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _predict( + res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, + DLManagedTensor* X, + DLManagedTensor* centroids, + double* cost) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _cluster_cost(res, X, centroids, cost); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _cluster_cost(res, X, centroids, cost); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} diff --git a/python/cuvs/CMakeLists.txt b/python/cuvs/CMakeLists.txt index 93946cfdb..91cf0d503 100644 --- a/python/cuvs/CMakeLists.txt +++ b/python/cuvs/CMakeLists.txt @@ -57,6 +57,7 @@ target_include_directories(cuvs::cuvs INTERFACE "$= 1 + assert np.allclose(cluster_cost(X, centroids), inertia, rtol=1e-6) + + +@pytest.mark.parametrize("n_rows", [100]) +@pytest.mark.parametrize("n_cols", [5, 25]) +@pytest.mark.parametrize("n_clusters", [4, 15]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): + X = np.random.random_sample((n_rows, n_cols)).astype(dtype) + X_device = device_ndarray(X) + + centroids = X[:n_clusters] + centroids_device = device_ndarray(centroids) + + inertia = cluster_cost(X_device, centroids_device) + + # 
compute the nearest centroid to each sample
+    distances = pairwise_distance(
+        X_device, centroids_device, metric="sqeuclidean"
+    ).copy_to_host()
+    cluster_ids = np.argmin(distances, axis=1)
+
+    cluster_distances = np.take_along_axis(
+        distances, cluster_ids[:, None], axis=1
+    )
+
+    # float32 needs a looser tolerance than float64
+    tol = 1e-3 if dtype == np.float32 else 1e-6
+    assert np.allclose(inertia, cluster_distances.sum(), rtol=tol, atol=tol)
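
Usage sketch for reviewers (not part of the diff): a minimal C program driving the new cuvsKMeansFit / cuvsKMeansClusterCost entry points end to end. It assumes the existing cuvsResources_t helpers and CUVS_SUCCESS from <cuvs/core/c_api.h>, plain CUDA runtime allocation instead of the RMM helpers, and placeholder sizes and random data; most error checking is omitted for brevity.

#include <cuvs/cluster/kmeans.h>
#include <cuvs/core/c_api.h>

#include <cuda_runtime.h>
#include <dlpack/dlpack.h>

#include <stdio.h>
#include <stdlib.h>

int main(void)
{
  const int64_t n_samples = 1024, n_features = 16;
  const int n_clusters    = 8;

  cuvsResources_t res;
  cuvsResourcesCreate(&res);

  /* Fill a host buffer with placeholder data and copy it to the device. */
  size_t X_bytes = sizeof(float) * n_samples * n_features;
  float* h_X     = malloc(X_bytes);
  for (int64_t i = 0; i < n_samples * n_features; i++) {
    h_X[i] = (float)rand() / (float)RAND_MAX;
  }

  float *d_X, *d_centroids;
  cudaMalloc((void**)&d_X, X_bytes);
  cudaMalloc((void**)&d_centroids, sizeof(float) * n_clusters * n_features);
  cudaMemcpy(d_X, h_X, X_bytes, cudaMemcpyHostToDevice);

  /* Describe the device buffers as row-major float32 DLPack tensors. */
  int64_t X_shape[2]         = {n_samples, n_features};
  int64_t centroids_shape[2] = {n_clusters, n_features};

  DLManagedTensor X_tensor              = {0};
  X_tensor.dl_tensor.data               = d_X;
  X_tensor.dl_tensor.device.device_type = kDLCUDA;
  X_tensor.dl_tensor.ndim               = 2;
  X_tensor.dl_tensor.dtype.code         = kDLFloat;
  X_tensor.dl_tensor.dtype.bits         = 32;
  X_tensor.dl_tensor.dtype.lanes        = 1;
  X_tensor.dl_tensor.shape              = X_shape;

  DLManagedTensor centroids_tensor = X_tensor;
  centroids_tensor.dl_tensor.data  = d_centroids;
  centroids_tensor.dl_tensor.shape = centroids_shape;

  /* Fit with the default k-means++ initialization; sample_weight is optional
   * and may be passed as NULL. */
  cuvsKMeansParams_t params;
  cuvsKMeansParamsCreate(&params);
  params->n_clusters = n_clusters;

  double inertia = 0;
  int n_iter     = 0;
  if (cuvsKMeansFit(res, params, &X_tensor, NULL, &centroids_tensor, &inertia, &n_iter) !=
      CUVS_SUCCESS) {
    fprintf(stderr, "cuvsKMeansFit failed\n");
    return 1;
  }

  /* The cluster cost of the fitted centroids should match the reported inertia. */
  double cost = 0;
  cuvsKMeansClusterCost(res, &X_tensor, &centroids_tensor, &cost);
  printf("n_iter=%d inertia=%f cost=%f\n", n_iter, inertia, cost);

  cuvsKMeansParamsDestroy(params);
  cudaFree(d_X);
  cudaFree(d_centroids);
  free(h_X);
  cuvsResourcesDestroy(res);
  return 0;
}

cuvsKMeansPredict follows the same pattern, taking an additional labels output tensor of length n_samples; the Python test above checks the equivalent invariant by recomputing per-sample distances with pairwise_distance and comparing their sum to the returned inertia.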