Skip to content

Commit

Permalink
Expose kmeans to python
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Feb 26, 2025
1 parent 49298b2 commit 7ce62f9
Show file tree
Hide file tree
Showing 7 changed files with 598 additions and 4 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ target_compile_definitions(cuvs::cuvs INTERFACE $<$<BOOL:${CUVS_NVTX}>:NVTX_ENAB
add_library(
cuvs_c SHARED
src/core/c_api.cpp
src/cluster/kmeans_c.cpp
src/neighbors/brute_force_c.cpp
src/neighbors/ivf_flat_c.cpp
src/neighbors/ivf_pq_c.cpp
Expand Down
201 changes: 201 additions & 0 deletions cpp/include/cuvs/cluster/kmeans.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuvs/core/c_api.h>
#include <cuvs/distance/distance.h>
#include <dlpack/dlpack.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/**
 * @brief Strategy used to pick the initial cluster centroids
 *
 * Declared through a typedef so that the bare name `cuvsKMeansInitMethod` is a
 * valid type in C as well as in C++: in C, `enum cuvsKMeansInitMethod { ... };`
 * on its own only introduces the tag `enum cuvsKMeansInitMethod`.
 */
typedef enum cuvsKMeansInitMethod {
  /**
   * Sample the centroids using the kmeans++ strategy
   */
  KMeansPlusPlus,

  /**
   * Sample the centroids uniformly at random
   */
  Random,

  /**
   * User provides the array of initial centroids
   */
  Array
} cuvsKMeansInitMethod;

/**
* @brief Hyper-parameters for the kmeans algorithm
*
* Instances are handled through the opaque pointer type cuvsKMeansParams_t.
*/
struct cuvsKMeansParams {
/**
* Distance metric, expressed as a cuvsDistanceType value.
* NOTE(review): which metrics kmeans actually supports is not visible in this
* header - confirm against the implementation.
*/
cuvsDistanceType metric;

/**
* The number of clusters to form as well as the number of centroids to generate (default:8).
*/
int n_clusters;

/**
* Method for initialization, defaults to k-means++:
* - cuvsKMeansInitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm
* to select the initial cluster centers.
* - cuvsKMeansInitMethod::Random (random): Choose 'n_clusters' observations (rows) at
* random from the input data for the initial centroids.
* - cuvsKMeansInitMethod::Array (ndarray): Use 'centroids' as initial cluster centers.
*/
cuvsKMeansInitMethod init;

/**
* Maximum number of iterations of the k-means algorithm for a single run.
*/
int max_iter;

/**
* Relative tolerance with regards to inertia to declare convergence.
*/
double tol;

/**
* Number of times the k-means algorithm will be run with different seeds.
*/
int n_init;

/**
* Oversampling factor for use in the k-means|| algorithm
*/
double oversampling_factor;

/**
* batch_samples and batch_centroids are used to tile 1NN computation which is
* useful to optimize/control the memory footprint
* Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0
* then don't tile the centroids
*/
int batch_samples;

/**
* if 0 then batch_centroids = n_clusters
*/
int batch_centroids;

/**
* NOTE(review): undocumented in this header; presumably toggles an inertia-based
* convergence check - confirm against the C++ kmeans params.
*/
bool inertia_check;

// TODO: handle balanced kmeans
};

/** Pointer handle through which the C API passes cuvsKMeansParams around. */
typedef struct cuvsKMeansParams* cuvsKMeansParams_t;

/**
* @brief Allocate KMeans params, and populate with default values
*
* @param[in] params cuvsKMeansParams_t to allocate
* @return cuvsError_t
*/
cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params);

/**
* @brief De-allocate KMeans params
*
* @param[in] params cuvsKMeansParams_t to de-allocate
* @return cuvsError_t
*/
cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t params);

/**
* @brief Find clusters with k-means algorithm.
*
* Initial centroids are chosen with k-means++ algorithm. Empty
* clusters are reinitialized by choosing new centroids with
* k-means++ algorithm.
*
* @param[in] res opaque C handle
* @param[in] params Parameters for KMeans model.
* @param[in] X Training instances to cluster. The data must
* be in row-major format.
* [dim = n_samples x n_features]
* @param[in] sample_weight Optional weights for each observation in X.
* [len = n_samples]
* @param[inout] centroids [in] When params->init is
* cuvsKMeansInitMethod::Array, use centroids
* as the initial cluster centers.
* [out] The generated centroids from the
* kmeans algorithm are stored at the address
* pointed by 'centroids'.
* [dim = n_clusters x n_features]
* @param[out] inertia Sum of squared distances of samples to their
* closest cluster center.
* @param[out] n_iter Number of iterations run.
* @return cuvsError_t
*/
cuvsError_t cuvsKMeansFit(cuvsResources_t res,
cuvsKMeansParams_t params,
DLManagedTensor* X,
DLManagedTensor* sample_weight,
DLManagedTensor* centroids,
double* inertia,
int* n_iter);

/**
* @brief Predict the closest cluster each sample in X belongs to.
*
* @param[in] res opaque C handle
* @param[in] params Parameters for KMeans model.
* @param[in] X New data to predict.
* [dim = n_samples x n_features]
* @param[in] sample_weight Optional weights for each observation in X.
* [len = n_samples]
* @param[in] centroids Cluster centroids. The data must be in
* row-major format.
* [dim = n_clusters x n_features]
* @param[out] labels Index of the cluster each sample in X
* belongs to.
* [len = n_samples]
* @param[in] normalize_weight True if the weights should be normalized
* @param[out] inertia Sum of squared distances of samples to
* their closest cluster center.
* @return cuvsError_t
*/
cuvsError_t cuvsKMeansPredict(cuvsResources_t res,
cuvsKMeansParams_t params,
DLManagedTensor* X,
DLManagedTensor* sample_weight,
DLManagedTensor* centroids,
DLManagedTensor* labels,
bool normalize_weight,
double* inertia);

/**
* @brief Compute cluster cost
*
* @param[in] res opaque C handle
* @param[in] X Training instances to cluster. The data must
* be in row-major format.
* [dim = n_samples x n_features]
* @param[in] centroids Cluster centroids. The data must be in
* row-major format.
* [dim = n_clusters x n_features]
* @param[out] cost Resulting cluster cost
* @return cuvsError_t
*/
cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res,
DLManagedTensor* X,
DLManagedTensor* centroids,
double* cost);
#ifdef __cplusplus
}
#endif
47 changes: 43 additions & 4 deletions cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,12 @@ struct params : base_params {
*/
double oversampling_factor = 2.0;

// batch_samples and batch_centroids are used to tile 1NN computation which is
// useful to optimize/control the memory footprint
// Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0
// then don't tile the centroids
/**
* batch_samples and batch_centroids are used to tile 1NN computation which is
* useful to optimize/control the memory footprint
* Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0
* then don't tile the centroids
*/
int batch_samples = 1 << 15;

/**
Expand Down Expand Up @@ -1089,6 +1091,43 @@ void transform(raft::resources const& handle,
raft::device_matrix_view<const double, int> X,
raft::device_matrix_view<const double, int> centroids,
raft::device_matrix_view<double, int> X_new);

/**
* @brief Compute cluster cost (single-precision overload)
*
* @param[in] handle The raft handle
* @param[in] X Training instances to cluster. The data must
* be in row-major format.
* [dim = n_samples x n_features]
* @param[in] centroids Cluster centroids. The data must be in
* row-major format.
* [dim = n_clusters x n_features]
* @param[out] cost Resulting cluster cost, written through a host
* scalar view
*
*/
void cluster_cost(const raft::resources& handle,
raft::device_matrix_view<const float, int> X,
raft::device_matrix_view<const float, int> centroids,
raft::host_scalar_view<float> cost);

/**
* @brief Compute cluster cost (double-precision overload)
*
* @param[in] handle The raft handle
* @param[in] X Training instances to cluster. The data must
* be in row-major format.
* [dim = n_samples x n_features]
* @param[in] centroids Cluster centroids. The data must be in
* row-major format.
* [dim = n_clusters x n_features]
* @param[out] cost Resulting cluster cost, written through a host
* scalar view
*
*/
void cluster_cost(const raft::resources& handle,
raft::device_matrix_view<const double, int> X,
raft::device_matrix_view<const double, int> centroids,
raft::host_scalar_view<double> cost);

/**
* @}
*/
Expand Down
46 changes: 46 additions & 0 deletions cpp/src/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,52 @@ void min_cluster_distance(raft::resources const& handle,
workspace);
}

/**
 * @brief Compute the cluster cost of X with respect to a set of centroids
 *
 * Delegates to min_cluster_distance (with the L2Expanded metric) to fill a
 * per-sample buffer, reduces that buffer on the device with raft::add_op, and
 * copies the resulting scalar into `cost`. The stream is synchronized before
 * returning, so `cost` holds the final value once this function returns.
 *
 * @tparam DataT  element type of the data and of the resulting cost
 * @tparam IndexT index type of the matrix views
 *
 * @param[in]  handle    The raft handle
 * @param[in]  X         Samples, [dim = n_samples x n_features]
 * @param[in]  centroids Cluster centroids, [dim = n_clusters x n_features]
 * @param[out] cost      Host scalar receiving the resulting cluster cost
 */
template <typename DataT, typename IndexT>
void cluster_cost(raft::resources const& handle,
                  raft::device_matrix_view<const DataT, IndexT> X,
                  raft::device_matrix_view<const DataT, IndexT> centroids,
                  raft::host_scalar_view<DataT> cost)
{
  auto stream = raft::resource::get_cuda_stream(handle);

  auto n_clusters = centroids.extent(0);
  auto n_samples  = X.extent(0);
  auto n_features = X.extent(1);

  rmm::device_uvector<char> workspace(n_samples * sizeof(IndexT), stream);

  // Row norms of X. The centroid norms computed by the previous revision were never
  // read afterwards, so that allocation and kernel launch have been dropped.
  rmm::device_uvector<DataT> x_norms(n_samples, stream);
  raft::linalg::rowNorm(
    x_norms.data(), X.data_handle(), n_features, n_samples, raft::linalg::L2Norm, true, stream);

  rmm::device_uvector<DataT> min_distances(n_samples, stream);
  rmm::device_uvector<DataT> l2_norm_or_distance_buffer(0, stream);

  // rmm::device_uvector has no view() member: wrap the raw buffers in mdspan views.
  auto x_norms_view = raft::make_device_vector_view<DataT, IndexT>(x_norms.data(), n_samples);
  auto min_distances_view =
    raft::make_device_vector_view<DataT, IndexT>(min_distances.data(), n_samples);

  // NOTE(review): min_cluster_distance's signature is outside this hunk. It is assumed
  // to take device_vector_views for the distance / norm buffers and a mutable centroid
  // view - confirm. The template arguments are spelled out so the call still converts
  // cleanly if the centroid parameter is a view over const data.
  auto centroids_view = raft::make_device_matrix_view<DataT, IndexT>(
    const_cast<DataT*>(centroids.data_handle()), n_clusters, n_features);

  auto metric = cuvs::distance::DistanceType::L2Expanded;

  cuvs::cluster::kmeans::min_cluster_distance<DataT, IndexT>(handle,
                                                             X,
                                                             centroids_view,
                                                             min_distances_view,
                                                             x_norms_view,
                                                             l2_norm_or_distance_buffer,
                                                             metric,
                                                             n_samples,
                                                             n_clusters,
                                                             workspace);

  rmm::device_scalar<DataT> device_cost(0, stream);
  cuvs::cluster::kmeans::cluster_cost(handle,
                                      min_distances_view,
                                      workspace,
                                      raft::make_device_scalar_view<DataT>(device_cost.data()),
                                      raft::add_op{});

  raft::update_host(cost.data_handle(), device_cost.data(), 1, stream);
  // update_host only enqueues the copy on `stream`; wait for it so the caller never
  // reads `cost` before the value has landed (device_cost is also destroyed on return).
  raft::resource::sync_stream(handle);
}

/**
* @brief Calculates a <key, value> pair for every sample in input 'X' where key is an
* index of one of the 'centroids' (index of the nearest centroid) and 'value'
Expand Down
Loading

0 comments on commit 7ce62f9

Please sign in to comment.