From 829ae16c5862a985a1077b3c57a71e71ab9895ba Mon Sep 17 00:00:00 2001 From: Micka Date: Wed, 14 Aug 2024 21:59:42 +0200 Subject: [PATCH] [FEA] Support for Cosine distance in IVF-Flat (#179) Authors: - Micka (https://github.com/lowener) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/179 --- cpp/include/cuvs/cluster/kmeans.hpp | 1 + cpp/include/cuvs/neighbors/ivf_flat.hpp | 72 +++++ cpp/src/cluster/detail/kmeans_balanced.cuh | 115 ++++++-- cpp/src/cluster/kmeans_balanced.cuh | 47 ++-- cpp/src/neighbors/ivf_common.cuh | 1 + cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 45 ++- .../ivf_flat/ivf_flat_build_float_int64_t.cu | 69 ----- .../ivf_flat/ivf_flat_build_int8_t_int64_t.cu | 69 ----- .../ivf_flat_build_uint8_t_int64_t.cu | 69 ----- .../ivf_flat/ivf_flat_extend_float_int64_t.cu | 71 ----- .../ivf_flat_extend_int8_t_int64_t.cu | 71 ----- .../ivf_flat_extend_uint8_t_int64_t.cu | 71 ----- .../ivf_flat/ivf_flat_interleaved_scan.cuh | 263 ++++++++++++++---- .../neighbors/ivf_flat/ivf_flat_search.cuh | 32 ++- cpp/src/neighbors/ivf_flat_index.cpp | 1 + cpp/test/neighbors/ann_ivf_flat.cuh | 56 ++++ cpp/test/neighbors/naive_knn.cuh | 14 +- .../cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx | 4 +- python/cuvs/cuvs/test/test_ivf_flat.py | 3 +- 19 files changed, 544 insertions(+), 530 deletions(-) delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cu diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index c6fe4d046..75205fa4f 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -120,6 +120,7 @@ struct params : base_params { * Simple object to specify hyper-parameters to the balanced k-means algorithm. * * The following metrics are currently supported in k-means balanced: + * - CosineExpanded * - InnerProduct * - L2Expanded * - L2SqrtExpanded diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index 918fef5af..44502f942 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -304,6 +304,12 @@ struct index : cuvs::neighbors::index { /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -327,6 +333,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -351,6 +363,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -374,6 +392,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -398,6 +422,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -421,6 +451,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -445,6 +481,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -475,6 +517,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -506,6 +554,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -536,6 +590,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -567,6 +627,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -597,6 +663,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index a09b17532..34bb22e85 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -141,6 +142,53 @@ inline std::enable_if_t> predict_core( raft::compose_op, raft::key_op>()); break; } + case cuvs::distance::DistanceType::CosineExpanded: { + auto workspace = raft::make_device_mdarray( + handle, mr, raft::make_extents((sizeof(int)) * n_rows)); + + auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( + handle, mr, raft::make_extents(n_rows)); + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + thrust::fill(raft::resource::get_thrust_policy(handle), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + minClusterAndDistance.size(), + initial_value); + + auto centroidsNorm = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_clusters)); + raft::linalg::rowNorm(centroidsNorm.data_handle(), + centers, + dim, + n_clusters, + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + + cuvs::distance::fusedDistanceNNMinReduce, IdxT>( + minClusterAndDistance.data_handle(), + dataset, + centers, + dataset_norm, + centroidsNorm.data_handle(), + n_rows, + n_clusters, + dim, + (void*)workspace.data_handle(), + false, + false, + true, + params.metric, + 0.0f, + stream); + // Copy keys to output labels + thrust::transform(raft::resource::get_thrust_policy(handle), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + n_rows, + labels, + raft::compose_op, raft::key_op>()); + break; + } case cuvs::distance::DistanceType::InnerProduct: { // TODO: pass buffer rmm::device_uvector distances(n_rows * n_clusters, stream, mr); @@ -320,13 +368,14 @@ void calc_centers_and_sizes(const raft::resources& handle, } /** Computes the L2 norm of the dataset, converting to MathT if necessary */ -template +template void compute_norm(const raft::resources& handle, MathT* dataset_norm, const T* dataset, IdxT dim, IdxT n_rows, MappingOpT mapping_op, + FinOpT norm_fin_op, std::optional mr = std::nullopt) { raft::common::nvtx::range fun_scope("compute_norm"); @@ -347,7 +396,7 @@ void compute_norm(const raft::resources& handle, } raft::linalg::rowNorm( - dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream); + dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream, norm_fin_op); } /** @@ -394,7 +443,8 @@ void predict(const raft::resources& handle, std::is_same_v ? 0 : max_minibatch_size * dim, stream, mem_res); bool need_compute_norm = dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded); + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + params.metric == cuvs::distance::DistanceType::CosineExpanded); rmm::device_uvector cur_dataset_norm( need_compute_norm ? max_minibatch_size : 0, stream, mem_res); const MathT* dataset_norm_ptr = nullptr; @@ -411,8 +461,24 @@ void predict(const raft::resources& handle, // Compute the norm now if it hasn't been pre-computed. if (need_compute_norm) { - compute_norm( - handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res); + if (params.metric == cuvs::distance::DistanceType::CosineExpanded) + compute_norm(handle, + cur_dataset_norm.data(), + cur_dataset_ptr, + dim, + minibatch_size, + mapping_op, + raft::sqrt_op{}, + mr); + else + compute_norm(handle, + cur_dataset_norm.data(), + cur_dataset_ptr, + dim, + minibatch_size, + mapping_op, + raft::identity_op{}, + mr); dataset_norm_ptr = cur_dataset_norm.data(); } else if (dataset_norm != nullptr) { dataset_norm_ptr = dataset_norm + offset; @@ -904,7 +970,8 @@ auto build_fine_clusters(const raft::resources& handle, cub::TransformInputIterator mapping_itr(dataset_mptr, mapping_op); raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); if (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + params.metric == cuvs::distance::DistanceType::CosineExpanded) { thrust::gather(raft::resource::get_thrust_policy(handle), mc_trainset_ids, mc_trainset_ids + k, @@ -963,7 +1030,8 @@ void build_hierarchical(const raft::resources& handle, IdxT n_rows, MathT* cluster_centers, IdxT n_clusters, - MappingOpT mapping_op) + MappingOpT mapping_op, + const MathT* dataset_norm = nullptr) { auto stream = raft::resource::get_cuda_stream(handle); using LabelT = uint32_t; @@ -980,21 +1048,32 @@ void build_hierarchical(const raft::resources& handle, auto [max_minibatch_size, mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); - // Precompute the L2 norm of the dataset if relevant. - const MathT* dataset_norm = nullptr; + // Precompute the L2 norm of the dataset if relevant and not yet computed. rmm::device_uvector dataset_norm_buf(0, stream, device_memory); - if (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded || + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + params.metric == cuvs::distance::DistanceType::CosineExpanded)) { dataset_norm_buf.resize(n_rows, stream); for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); - compute_norm(handle, - dataset_norm_buf.data() + offset, - dataset + dim * offset, - dim, - minibatch_size, - mapping_op, - device_memory); + if (params.metric == cuvs::distance::DistanceType::CosineExpanded) + compute_norm(handle, + dataset_norm_buf.data() + offset, + dataset + dim * offset, + dim, + minibatch_size, + mapping_op, + raft::sqrt_op{}, + device_memory); + else + compute_norm(handle, + dataset_norm_buf.data() + offset, + dataset + dim * offset, + dim, + minibatch_size, + mapping_op, + raft::identity_op{}, + device_memory); } dataset_norm = (const MathT*)dataset_norm_buf.data(); } diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 040d17b36..306989891 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -71,13 +71,15 @@ namespace cuvs::cluster::kmeans_balanced { * @param[out] centroids The generated centroids [dim = n_clusters x n_features] * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic * datatype. If DataT == MathT, this must be the identity. + * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] */ template void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op()) + MappingOpT mapping_op = raft::identity_op(), + std::optional> X_norm = std::nullopt) { RAFT_EXPECTS(X.extent(1) == centroids.extent(1), "Number of features in dataset and centroids are different"); @@ -88,14 +90,16 @@ void fit(const raft::resources& handle, "The number of centroids must be strictly positive and cannot exceed the number of " "points in the training dataset."); - cuvs::cluster::kmeans::detail::build_hierarchical(handle, - params, - X.extent(1), - X.data_handle(), - X.extent(0), - centroids.data_handle(), - centroids.extent(0), - mapping_op); + cuvs::cluster::kmeans::detail::build_hierarchical( + handle, + params, + X.extent(1), + X.data_handle(), + X.extent(0), + centroids.data_handle(), + centroids.extent(0), + mapping_op, + X_norm.has_value() ? X_norm.value().data_handle() : nullptr); } /** @@ -125,6 +129,7 @@ void fit(const raft::resources& handle, * @param[out] labels The output labels [dim = n_samples] * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic * datatype. If DataT == MathT, this must be the identity. + * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] */ template X, raft::device_matrix_view centroids, raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) + MappingOpT mapping_op = raft::identity_op(), + std::optional> X_norm = std::nullopt) { RAFT_EXPECTS(X.extent(0) == labels.extent(0), "Number of rows in dataset and labels are different"); @@ -149,15 +155,18 @@ void predict(const raft::resources& handle, static_cast(std::numeric_limits::max()), "The chosen label type cannot represent all cluster labels"); - cuvs::cluster::kmeans::detail::predict(handle, - params, - centroids.data_handle(), - centroids.extent(0), - X.extent(1), - X.data_handle(), - X.extent(0), - labels.data_handle(), - mapping_op); + cuvs::cluster::kmeans::detail::predict( + handle, + params, + centroids.data_handle(), + centroids.extent(0), + X.extent(1), + X.data_handle(), + X.extent(0), + labels.data_handle(), + mapping_op, + raft::resource::get_workspace_resource(handle), + X_norm.has_value() ? X_norm.value().data_handle() : nullptr); } /** diff --git a/cpp/src/neighbors/ivf_common.cuh b/cpp/src/neighbors/ivf_common.cuh index 60d43bed6..fb73fb8a9 100644 --- a/cpp/src/neighbors/ivf_common.cuh +++ b/cpp/src/neighbors/ivf_common.cuh @@ -254,6 +254,7 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk] raft::linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); } } break; + case distance::DistanceType::CosineExpanded: case distance::DistanceType::InnerProduct: { float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor; if (factor != 1.0) { diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index e8df3e3d6..fb110d810 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -335,6 +335,37 @@ void extend(raft::resources const& handle, if (!index->center_norms().has_value()) { index->allocate_center_norms(handle); if (index->center_norms().has_value()) { + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + } else { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream); + } + RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); + } + } else if (index->center_norms().has_value() && index->adaptive_centers()) { + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + } else { raft::linalg::rowNorm(index->center_norms()->data_handle(), index->centers().data_handle(), dim, @@ -342,16 +373,7 @@ void extend(raft::resources const& handle, raft::linalg::L2Norm, true, stream); - RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } - } else if (index->center_norms().has_value() && index->adaptive_centers()) { - raft::linalg::rowNorm(index->center_norms()->data_handle(), - index->centers().data_handle(), - dim, - n_lists, - raft::linalg::L2Norm, - true, - stream); RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } } @@ -384,7 +406,8 @@ inline auto build(raft::resources const& handle, "unsupported data type"); RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); - + RAFT_EXPECTS(params.metric != cuvs::distance::DistanceType::CosineExpanded || dim > 1, + "Cosine metric requires more than one dim"); index index(handle, params, dim); utils::memzero( index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream); @@ -414,7 +437,7 @@ inline auto build(raft::resources const& handle, index.centers().data_handle(), index.n_lists(), index.dim()); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = static_cast(index.metric()); + kmeans_params.metric = index.metric(); cuvs::cluster::kmeans_balanced::fit( handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cu deleted file mode 100644 index 56bb71094..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cu +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } -CUVS_INST_IVF_FLAT_BUILD(float, int64_t); - -#undef CUVS_INST_IVF_FLAT_BUILD - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cu deleted file mode 100644 index 4803868c0..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cu +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } -CUVS_INST_IVF_FLAT_BUILD(int8_t, int64_t); - -#undef CUVS_INST_IVF_FLAT_BUILD - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cu deleted file mode 100644 index e087f94c4..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cu +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } -CUVS_INST_IVF_FLAT_BUILD(uint8_t, int64_t); - -#undef CUVS_INST_IVF_FLAT_BUILD - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cu deleted file mode 100644 index 2636067bf..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \ - auto extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } \ - auto extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } -CUVS_INST_IVF_FLAT_EXTEND(float, int64_t); - -#undef CUVS_INST_IVF_FLAT_EXTEND - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cu deleted file mode 100644 index 191cb9f39..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \ - auto extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } \ - auto extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } -CUVS_INST_IVF_FLAT_EXTEND(int8_t, int64_t); - -#undef CUVS_INST_IVF_FLAT_EXTEND - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cu deleted file mode 100644 index 29b7e7b69..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \ - auto extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } \ - auto extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } -CUVS_INST_IVF_FLAT_EXTEND(uint8_t, int64_t); - -#undef CUVS_INST_IVF_FLAT_EXTEND - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index ce29a7e7c..86ef55928 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -104,13 +104,16 @@ __device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit * values) */ -template +template struct loadAndComputeDist { Lambda compute_dist; AccT& dist; + AccT& norm_query; + AccT& norm_data; - __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(AccT& dist, Lambda op, AccT& norm_query, AccT& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -134,6 +137,10 @@ struct loadAndComputeDist { #pragma unroll for (int k = 0; k < Veclen; ++k) { compute_dist(dist, queryRegs[k], encV[k]); + if constexpr (ComputeNorm) { + norm_query += queryRegs[k] * queryRegs[k]; + norm_data += encV[k] * encV[k]; + } } } } @@ -163,7 +170,12 @@ struct loadAndComputeDist { const int d = (i * kUnroll + j) * Veclen; #pragma unroll for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, raft::shfl(queryReg, d + k, raft::WarpSize), encV[k]); + T q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, encV[k]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += encV[k] * encV[k]; + } } } } @@ -184,20 +196,28 @@ struct loadAndComputeDist { raft::ldg(enc, data + loadDataIdx); #pragma unroll for (int k = 0; k < Veclen; k++) { - compute_dist(dist, raft::shfl(queryReg, d + k, raft::WarpSize), enc[k]); + T q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, enc[k]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += enc[k] * enc[k]; + } } } } }; // This handles uint8_t 8, 16 Veclens -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(uint32_t& dist, Lambda op, uint32_t& norm_query, uint32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -220,6 +240,10 @@ struct loadAndComputeDist { #pragma unroll for (int k = 0; k < veclen_int; k++) { compute_dist(dist, queryRegs[k], encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs[k], queryRegs[k], norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } } } } @@ -244,7 +268,12 @@ struct loadAndComputeDist { const int d = (i * kUnroll + j) * veclen_int; #pragma unroll for (int k = 0; k < veclen_int; ++k) { - compute_dist(dist, raft::shfl(queryReg, d + k, raft::WarpSize), encV[k]); + uint32_t q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } } } } @@ -267,6 +296,10 @@ struct loadAndComputeDist { for (int k = 0; k < veclen_int; k++) { uint32_t q = raft::shfl(queryReg, (d / 4) + k, raft::WarpSize); compute_dist(dist, q, enc[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc[k], enc[k], norm_data); + } } } } @@ -274,13 +307,16 @@ struct loadAndComputeDist { // Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while // using above common template of int2/int4 -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(uint32_t& dist, Lambda op, uint32_t& norm_query, uint32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -294,6 +330,10 @@ struct loadAndComputeDist { uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; compute_dist(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, @@ -313,6 +353,10 @@ struct loadAndComputeDist { uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); compute_dist(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } } @@ -330,17 +374,24 @@ struct loadAndComputeDist { uint32_t enc = reinterpret_cast(data)[lane_id]; uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); compute_dist(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc, enc, norm_data); + } } } }; -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(uint32_t& dist, Lambda op, uint32_t& norm_query, uint32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -354,6 +405,10 @@ struct loadAndComputeDist { uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; compute_dist(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } @@ -374,6 +429,10 @@ struct loadAndComputeDist { uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); compute_dist(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } } @@ -391,17 +450,24 @@ struct loadAndComputeDist { uint32_t enc = reinterpret_cast(data)[lane_id]; uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); compute_dist(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc, enc, norm_data); + } } } }; -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(uint32_t& dist, Lambda op, uint32_t& norm_query, uint32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -415,6 +481,10 @@ struct loadAndComputeDist { uint32_t encV = data[loadIndex + j * kIndexGroupSize]; uint32_t queryRegs = query_shared[shmemIndex + j]; compute_dist(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query += queryRegs * queryRegs; + norm_data += encV * encV; + } } } @@ -434,6 +504,10 @@ struct loadAndComputeDist { uint32_t encV = data[lane_id + j * kIndexGroupSize]; uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); compute_dist(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += encV * encV; + } } } } @@ -451,18 +525,25 @@ struct loadAndComputeDist { uint32_t enc = data[lane_id]; uint32_t q = raft::shfl(queryReg, d, raft::WarpSize); compute_dist(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += enc * enc; + } } } }; // This device function is for int8 veclens 4, 8 and 16 -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; int32_t& dist; + int32_t& norm_query; + int32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(int32_t& dist, Lambda op, int32_t& norm_query, int32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -485,6 +566,10 @@ struct loadAndComputeDist { #pragma unroll for (int k = 0; k < veclen_int; k++) { compute_dist(dist, queryRegs[k], encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs[k], queryRegs[k], norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } } } } @@ -513,6 +598,10 @@ struct loadAndComputeDist { for (int k = 0; k < veclen_int; ++k) { int32_t q = raft::shfl(queryReg, d + k, raft::WarpSize); compute_dist(dist, q, encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } } } } @@ -531,17 +620,24 @@ struct loadAndComputeDist { for (int k = 0; k < veclen_int; k++) { int32_t q = raft::shfl(queryReg, (d / 4) + k, raft::WarpSize); // Here 4 is for 1 - int; compute_dist(dist, q, enc[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc[k], enc[k], norm_data); + } } } } }; -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + int32_t& norm_query; + int32_t& norm_data; + __device__ __forceinline__ + loadAndComputeDist(int32_t& dist, Lambda op, int32_t& norm_query, int32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, @@ -554,6 +650,10 @@ struct loadAndComputeDist { int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; compute_dist(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } @@ -574,6 +674,10 @@ struct loadAndComputeDist { int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; int32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); compute_dist(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryReg, queryReg, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } } @@ -588,16 +692,23 @@ struct loadAndComputeDist { int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; int32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); compute_dist(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc, enc, norm_data); + } } } }; -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + int32_t& norm_query; + int32_t& norm_data; + __device__ __forceinline__ + loadAndComputeDist(int32_t& dist, Lambda op, int32_t& norm_query, int32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -609,6 +720,11 @@ struct loadAndComputeDist { #pragma unroll for (int j = 0; j < kUnroll; ++j) { compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); + if constexpr (ComputeNorm) { + norm_query += int32_t{query_shared[shmemIndex + j]} * int32_t{query_shared[shmemIndex + j]}; + norm_data += int32_t{data[loadIndex + j * kIndexGroupSize]} * + int32_t{data[loadIndex + j * kIndexGroupSize]}; + } } } @@ -625,9 +741,12 @@ struct loadAndComputeDist { for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { #pragma unroll for (int j = 0; j < kUnroll; ++j) { - compute_dist(dist, - raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize), - data[lane_id + j * kIndexGroupSize]); + int32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + compute_dist(dist, q, data[lane_id + j * kIndexGroupSize]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += data[lane_id + j * kIndexGroupSize] * data[lane_id + j * kIndexGroupSize]; + } } } } @@ -638,7 +757,12 @@ struct loadAndComputeDist { const int loadDim = dimBlocks + lane_id; int32_t queryReg = loadDim < dim ? query[loadDim] : 0; for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - compute_dist(dist, raft::shfl(queryReg, d, raft::WarpSize), data[lane_id]); + int32_t q = raft::shfl(queryReg, d, raft::WarpSize); + compute_dist(dist, q, data[lane_id]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += int32_t{data[lane_id]} * int32_t{data[lane_id]}; + } } } }; @@ -691,6 +815,7 @@ using block_sort_t = typename flat_block_sort::typ template lc(dist, - compute_dist); + // Process first shm_assisted_dim dimensions (always using shared memory) + loadAndComputeDist lc( + dist, compute_dist, norm_query, norm_dataset); for (int pos = 0; pos < shm_assisted_dim; pos += raft::WarpSize, data += kIndexGroupSize * raft::WarpSize) { lc.runLoadShmemCompute(data, query_shared, lane_id, pos); } - } - if (dim > query_smem_elems) { - // The default path - using shfl ops - for dimensions beyond query_smem_elems - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += raft::WarpSize) { - lc.runLoadShflAndCompute(data, query, pos, lane_id); - } - lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); - } else { - // when shm_assisted_dim == full_warps_along_dim < dim - if (valid) { - loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); + if (dim > query_smem_elems) { + // The default path - using shfl ops - for dimensions beyond query_smem_elems + loadAndComputeDist lc( + dist, compute_dist, norm_query, norm_dataset); + for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += raft::WarpSize) { + lc.runLoadShflAndCompute(data, query, pos, lane_id); + } + lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); + } else { + // when shm_assisted_dim == full_warps_along_dim < dim + loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT, ComputeNorm> lc( + dist, compute_dist, norm_query, norm_dataset); for (int pos = full_warps_along_dim; pos < dim; pos += Veclen, data += kIndexGroupSize * Veclen) { lc.runLoadShmemCompute(data, query_shared, lane_id, pos); @@ -814,7 +940,13 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) } // Enqueue one element per thread - const float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; + float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; + + if constexpr (ComputeNorm) { + if (valid) + val = val / (raft::sqrt(static_cast(norm_query)) * + raft::sqrt(static_cast(norm_dataset))); + } if constexpr (kManageLocalTopK) { queue.add(val, sample_offset + vec_id); } else { @@ -864,6 +996,7 @@ uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMem template , raft::identity_op>({}, {}, std::forward(args)...); + case cuvs::distance::DistanceType::CosineExpanded: + // NB: "Ascending" is reversed because the post-processing step is done after that sort + return launch_kernel>( + {}, + raft::compose_op(raft::mul_const_op{-1.0f}, raft::add_const_op{1.0f}), + std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 43111a7de..b7dac3ef8 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -79,6 +79,8 @@ void search_impl(raft::resources const& handle, // also we might need additional storage for select_k rmm::device_uvector indices_tmp_dev(0, stream, search_mr); rmm::device_uvector neighbors_uint32_buf(0, stream, search_mr); + auto distance_buffer_dev_view = raft::make_device_matrix_view( + distance_buffer_dev.data(), n_queries, index.n_lists()); size_t float_query_size; if constexpr (std::is_integral_v) { @@ -122,6 +124,19 @@ void search_impl(raft::resources const& handle, RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); break; } + case cuvs::distance::DistanceType::CosineExpanded: { + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(index.dim()), + static_cast(n_queries), + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + alpha = -1.0f; + beta = 0.0f; + break; + } default: { alpha = 1.0f; beta = 0.0f; @@ -144,12 +159,25 @@ void search_impl(raft::resources const& handle, index.n_lists(), stream); + if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { + auto n_lists = index.n_lists(); + const auto* q_norm_ptr = query_norm_dev.data(); + const auto* index_center_norm_ptr = index.center_norms()->data_handle(); + raft::linalg::map_offset( + handle, + distance_buffer_dev_view, + [=] __device__(const uint32_t idx, const float dist) { + const auto query = idx / n_lists; + const auto cluster = idx % n_lists; + return dist / (q_norm_ptr[query] * index_center_norm_ptr[cluster]); + }, + raft::make_const_mdspan(distance_buffer_dev_view)); + } RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); cuvs::selection::select_k( handle, - raft::make_device_matrix_view( - distance_buffer_dev.data(), n_queries, index.n_lists()), + raft::make_const_mdspan(distance_buffer_dev_view), std::nullopt, raft::make_device_matrix_view(coarse_distances_dev.data(), n_queries, n_probes), raft::make_device_matrix_view( diff --git a/cpp/src/neighbors/ivf_flat_index.cpp b/cpp/src/neighbors/ivf_flat_index.cpp index b249a9c29..6f7d11e50 100644 --- a/cpp/src/neighbors/ivf_flat_index.cpp +++ b/cpp/src/neighbors/ivf_flat_index.cpp @@ -193,6 +193,7 @@ void index::allocate_center_norms(raft::resources const& res) case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Unexpanded: case cuvs::distance::DistanceType::L2SqrtUnexpanded: + case cuvs::distance::DistanceType::CosineExpanded: center_norms_ = raft::make_device_vector(res, n_lists()); break; default: center_norms_ = std::nullopt; diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index f907baa1f..17ec84097 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -20,7 +20,9 @@ #include "naive_knn.cuh" #include +#include #include +#include #include #include @@ -533,63 +535,109 @@ const std::vector> inputs = { // test various dims (aligned and not aligned to vector sizes) {1000, 10000, 1, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, // test dims that do not fit into kernel shared memory limits {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 2051, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2051, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 2052, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 2052, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 2053, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2053, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 2056, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2056, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, // various random combinations {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, // host input data {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false, true}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, // // host input data with prefetching for kernel copy overlapping {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::InnerProduct, true}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, // test splitting the big query batches (> max gridDim.y) into smaller batches {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, false}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, false}, {1000000, 1024, 32, 10, 256, 256, cuvs::distance::DistanceType::InnerProduct, false}, + {1000000, 1024, 32, 10, 256, 256, cuvs::distance::DistanceType::CosineExpanded, false}, {98306, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, true}, + {98306, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, true}, // test radix_sort for getting the cluster selection {1000, @@ -608,6 +656,14 @@ const std::vector> inputs = { raft::matrix::detail::select::warpsort::kMaxCapacity * 4, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, + 10000, + 16, + 10, + raft::matrix::detail::select::warpsort::kMaxCapacity * 4, + raft::matrix::detail::select::warpsort::kMaxCapacity * 4, + cuvs::distance::DistanceType::CosineExpanded, + false}, // The following two test cases should show very similar recall. // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric, adaptive_centers diff --git a/cpp/test/neighbors/naive_knn.cuh b/cpp/test/neighbors/naive_knn.cuh index 90547150f..553e667aa 100644 --- a/cpp/test/neighbors/naive_knn.cuh +++ b/cpp/test/neighbors/naive_knn.cuh @@ -41,7 +41,9 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist, if (midx >= m) return; IdxT grid_size = IdxT(blockDim.y) * IdxT(gridDim.y); for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n; nidx += grid_size) { - EvalT acc = EvalT(0); + EvalT acc = EvalT(0); + EvalT normX = EvalT(0); + EvalT normY = EvalT(0); for (IdxT i = 0; i < k; ++i) { IdxT xidx = i + midx * k; IdxT yidx = i + nidx * k; @@ -51,6 +53,11 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist, case cuvs::distance::DistanceType::InnerProduct: { acc += xv * yv; } break; + case cuvs::distance::DistanceType::CosineExpanded: { + acc += xv * yv; + normX += xv * xv; + normY += yv * yv; + } break; case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2SqrtUnexpanded: case cuvs::distance::DistanceType::L2Expanded: @@ -66,6 +73,9 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist, case cuvs::distance::DistanceType::L2SqrtUnexpanded: { acc = raft::sqrt(acc); } break; + case cuvs::distance::DistanceType::CosineExpanded: { + acc = 1 - acc / (raft::sqrt(normX) * raft::sqrt(normY)); + } default: break; } dist[midx * n + nidx] = acc; @@ -118,7 +128,7 @@ void naive_knn(raft::resources const& handle, static_cast(k), dist_topk + offset * k, indices_topk + offset * k, - type != cuvs::distance::DistanceType::InnerProduct, + cuvs::distance::is_min_close(type), mr); } RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); diff --git a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx index a3799144c..25b9b2aee 100644 --- a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx @@ -58,12 +58,14 @@ cdef class IndexParams: metric : str, default = "sqeuclidean" String denoting the metric type. Valid values for metric: ["sqeuclidean", "inner_product", - "euclidean"], where + "euclidean", "cosine"], where - sqeuclidean is the euclidean distance without the square root operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, - euclidean is the euclidean distance - inner product distance is defined as distance(a, b) = \\sum_i a_i * b_i. + - cosine distance is defined as + distance(a, b) = 1 - \\sum_i a_i * b_i / ( ||a||_2 * ||b||_2). kmeans_n_iters : int, default = 20 The number of iterations searching for kmeans centers during index building. diff --git a/python/cuvs/cuvs/test/test_ivf_flat.py b/python/cuvs/cuvs/test/test_ivf_flat.py index bb50d3573..9dd4097dc 100644 --- a/python/cuvs/cuvs/test/test_ivf_flat.py +++ b/python/cuvs/cuvs/test/test_ivf_flat.py @@ -92,6 +92,7 @@ def run_ivf_flat_build_search_test( skl_metric = { "sqeuclidean": "sqeuclidean", "inner_product": "cosine", + "cosine": "cosine", "euclidean": "euclidean", }[metric] nn_skl = NearestNeighbors( @@ -107,7 +108,7 @@ def run_ivf_flat_build_search_test( @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("dtype", [np.float32]) @pytest.mark.parametrize( - "metric", ["sqeuclidean", "inner_product", "euclidean"] + "metric", ["sqeuclidean", "inner_product", "euclidean", "cosine"] ) def test_ivf_flat(inplace, dtype, metric): run_ivf_flat_build_search_test(