From 8c3b0aad315717c957996103096727dce9931e63 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 21 Feb 2025 09:08:14 +0100 Subject: [PATCH] IVF-PQ: low-precision coarse search --- cpp/CMakeLists.txt | 2 +- .../src/cuvs/cuvs_ann_bench_param_parser.h | 19 ++ cpp/include/cuvs/neighbors/ivf_pq.hpp | 38 +++ cpp/src/neighbors/detail/ann_utils.cuh | 16 + .../neighbors/detail/cagra/cagra_build.cpp | 4 +- .../neighbors/detail/cagra/cagra_build.cuh | 24 +- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 4 + cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh | 302 +++++++++++++++--- .../{ivf_pq_index.cpp => ivf_pq_index.cu} | 78 +++++ 9 files changed, 429 insertions(+), 58 deletions(-) rename cpp/src/neighbors/{ivf_pq_index.cpp => ivf_pq_index.cu} (76%) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 945a18157..d2fa30fbf 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -426,7 +426,7 @@ if(BUILD_SHARED_LIBS) src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu - src/neighbors/ivf_pq_index.cpp + src/neighbors/ivf_pq_index.cu src/neighbors/ivf_pq/ivf_pq_build_common.cu src/neighbors/ivf_pq/ivf_pq_serialize.cu src/neighbors/ivf_pq/ivf_pq_deserialize.cu diff --git a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h index 7617bfa66..48a0f39ab 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h @@ -154,6 +154,25 @@ void parse_search_param(const nlohmann::json& conf, // set half as default param.pq_param.lut_dtype = CUDA_R_16F; } + + if (conf.contains("coarse_search_dtype")) { + std::string type = conf.at("coarse_search_dtype"); + if (type == "float") { + param.pq_param.coarse_search_dtype = CUDA_R_32F; + } else if (type == "half") { + param.pq_param.coarse_search_dtype = CUDA_R_16F; + } else if (type == "int8") { + param.pq_param.coarse_search_dtype = CUDA_R_8I; + } else { + throw std::runtime_error("coarse_search_dtype: '" + type + + "', should be either 'float', 'half' or 'int8'"); + } + } + + if (conf.contains("max_internal_batch_size")) { + param.pq_param.max_internal_batch_size = conf.at("max_internal_batch_size"); + } + if (conf.contains("refine_ratio")) { param.refine_ratio = conf.at("refine_ratio"); if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); } diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index d85753b7f..05ba6e01d 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -27,6 +27,11 @@ #include #include +#include +#include +#include +#include + namespace cuvs::neighbors::ivf_pq { /** @@ -181,6 +186,21 @@ struct search_params : cuvs::neighbors::search_params { * performance if tweaked incorrectly. */ double preferred_shmem_carveout = 1.0; + /** + * [Experimental] The data type to use as the GEMM element type when searching the clusters to + * probe. + * + * Possible values: [CUDA_R_8I, CUDA_R_16F, CUDA_R_32F]. + * + * - Legacy default: CUDA_R_32F (float) + * - Recommended for performance: CUDA_R_16F (half) + * - Experimental/low-precision: CUDA_R_8I (int8_t) + */ + cudaDataType_t coarse_search_dtype = CUDA_R_32F; + /** + * Set the internal batch size to improve GPU utilization at the cost of larger memory footprint. + */ + uint32_t max_internal_batch_size = 4096; }; /** * @} @@ -427,6 +447,11 @@ struct index : cuvs::neighbors::index { raft::device_matrix_view rotation_matrix() noexcept; raft::device_matrix_view rotation_matrix() const noexcept; + raft::device_matrix_view rotation_matrix_int8( + const raft::resources& res) const; + raft::device_matrix_view rotation_matrix_half( + const raft::resources& res) const; + /** * Accumulated list sizes, sorted in descending order [n_lists + 1]. * The last value contains the total length of the index. @@ -447,6 +472,11 @@ struct index : cuvs::neighbors::index { raft::device_matrix_view centers() noexcept; raft::device_matrix_view centers() const noexcept; + raft::device_matrix_view centers_int8( + const raft::resources& res) const; + raft::device_matrix_view centers_half( + const raft::resources& res) const; + /** Cluster centers corresponding to the lists in the rotated space [n_lists, rot_dim] */ raft::device_matrix_view centers_rot() noexcept; raft::device_matrix_view centers_rot() const noexcept; @@ -485,6 +515,14 @@ struct index : cuvs::neighbors::index { raft::device_matrix centers_rot_; raft::device_matrix rotation_matrix_; + // Lazy-initialized low-precision variants of index members - for low-precision coarse search. + // These are never serialized and not touched during build/extend. + mutable std::optional> centers_int8_; + mutable std::optional> centers_half_; + mutable std::optional> + rotation_matrix_int8_; + mutable std::optional> rotation_matrix_half_; + // Computed members for accelerating search. raft::device_vector data_ptrs_; raft::device_vector inds_ptrs_; diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index 149eea3f1..4c62d5bac 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -195,6 +195,22 @@ struct mapping { /** @} */ }; +template <> +template <> +HDI constexpr auto mapping::operator()(const uint8_t& x) const -> int8_t +{ + // Avoid overflows when converting uint8_t -> int_8 + return static_cast(x >> 1); +} + +template <> +template <> +HDI constexpr auto mapping::operator()(const float& x) const -> int8_t +{ + // Carefully clamp floats if out-of-bounds. + return static_cast(std::clamp(x * 128.0f, -128.0f, 127.0f)); +} + /** * @brief Sets the first num bytes of the block of memory pointed by ptr to the specified value. * diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cpp b/cpp/src/neighbors/detail/cagra/cagra_build.cpp index 490dc0f30..0fdfd1bcf 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cpp +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cpp @@ -29,7 +29,9 @@ ivf_pq_params::ivf_pq_params(raft::matrix_extent dataset_extents, search_params.n_probes = std::max(10, build_params.n_lists * 0.01); search_params.lut_dtype = CUDA_R_16F; search_params.internal_distance_dtype = CUDA_R_16F; + search_params.coarse_search_dtype = CUDA_R_16F; + search_params.max_internal_batch_size = 128 * 1024; - refinement_rate = 2; + refinement_rate = 1; } } // namespace cuvs::neighbors::cagra::graph_build_params diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 4d559a662..352ba8d99 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -164,13 +164,13 @@ void build_knn_graph( const auto num_queries = dataset.extent(0); // Use the same maximum batch size as the ivf_pq::search to avoid allocating more than needed. - constexpr uint32_t kMaxQueries = 4096; + const uint32_t max_queries = pq.search_params.max_internal_batch_size; // Heuristic: the build_knn_graph code should use only a fraction of the workspace memory; the // rest should be used by the ivf_pq::search. Here we say that the workspace size should be a good // multiple of what is required for the I/O batching below. constexpr size_t kMinWorkspaceRatio = 5; - auto desired_workspace_size = kMaxQueries * kMinWorkspaceRatio * + auto desired_workspace_size = max_queries * kMinWorkspaceRatio * (sizeof(DataT) * dataset.extent(1) // queries (dataset batch) + sizeof(float) * gpu_top_k // distances + sizeof(int64_t) * gpu_top_k // neighbors @@ -189,21 +189,21 @@ void build_knn_graph( node_degree, top_k, gpu_top_k, - kMaxQueries, + max_queries, pq.search_params.n_probes); auto distances = raft::make_device_mdarray( - res, workspace_mr, raft::make_extents(kMaxQueries, gpu_top_k)); + res, workspace_mr, raft::make_extents(max_queries, gpu_top_k)); auto neighbors = raft::make_device_mdarray( - res, workspace_mr, raft::make_extents(kMaxQueries, gpu_top_k)); + res, workspace_mr, raft::make_extents(max_queries, gpu_top_k)); auto refined_distances = raft::make_device_mdarray( - res, workspace_mr, raft::make_extents(kMaxQueries, top_k)); + res, workspace_mr, raft::make_extents(max_queries, top_k)); auto refined_neighbors = raft::make_device_mdarray( - res, workspace_mr, raft::make_extents(kMaxQueries, top_k)); - auto neighbors_host = raft::make_host_matrix(kMaxQueries, gpu_top_k); - auto queries_host = raft::make_host_matrix(kMaxQueries, dataset.extent(1)); - auto refined_neighbors_host = raft::make_host_matrix(kMaxQueries, top_k); - auto refined_distances_host = raft::make_host_matrix(kMaxQueries, top_k); + res, workspace_mr, raft::make_extents(max_queries, top_k)); + auto neighbors_host = raft::make_host_matrix(max_queries, gpu_top_k); + auto queries_host = raft::make_host_matrix(max_queries, dataset.extent(1)); + auto refined_neighbors_host = raft::make_host_matrix(max_queries, top_k); + auto refined_distances_host = raft::make_host_matrix(max_queries, top_k); // TODO(tfeher): batched search with multiple GPUs std::size_t num_self_included = 0; @@ -214,7 +214,7 @@ void build_knn_graph( dataset.data_handle(), dataset.extent(0), dataset.extent(1), - static_cast(kMaxQueries), + static_cast(max_queries), raft::resource::get_cuda_stream(res), workspace_mr); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 21a4b3185..5b6ff827e 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -239,6 +239,10 @@ void set_centers(raft::resources const& handle, index* index, const float* auto stream = raft::resource::get_cuda_stream(handle); auto* device_memory = raft::resource::get_workspace_resource(handle); + // Make sure to have trailing zeroes between dim and dim_ext; + // We rely on this to enable padded tensor gemm kernels during coarse search. + cuvs::spatial::knn::detail::utils::memzero( + index->centers().data_handle(), index->centers().size(), stream); // combine cluster_centers and their norms RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(), sizeof(float) * index->dim_ext(), diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh index 05bb99353..5067f9878 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -131,19 +132,17 @@ void select_clusters(raft::resources const& handle, handle, float_queries_view, [queries, dim, dim_ext, norm_factor] __device__(uint32_t ix) { uint32_t col = ix % dim_ext; uint32_t row = ix / dim_ext; - return col < dim ? utils::mapping{}(queries[col + dim * row]) : norm_factor; + if (col < dim) { return utils::mapping{}(queries[col + dim * row]); } + return col == dim ? norm_factor : 0.0f; }); float alpha; float beta; - uint32_t gemm_k = dim; switch (metric) { case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Expanded: { - alpha = -2.0; - beta = 0.0; - gemm_k = dim + 1; - RAFT_EXPECTS(gemm_k <= dim_ext, "unexpected gemm_k or dim_ext"); + alpha = -2.0; + beta = 0.0; } break; case cuvs::distance::DistanceType::CosineExpanded: case cuvs::distance::DistanceType::InnerProduct: { @@ -158,7 +157,7 @@ void select_clusters(raft::resources const& handle, false, n_lists, n_queries, - gemm_k, + dim_ext, &alpha, cluster_centers, dim_ext, @@ -180,6 +179,177 @@ void select_clusters(raft::resources const& handle, true); } +template +void select_clusters(raft::resources const& handle, + uint32_t* clusters_to_probe, // [n_queries, n_probes] + int8_t* float_queries, // [n_queries, dim_ext] + uint32_t n_queries, + uint32_t n_probes, + uint32_t n_lists, + uint32_t dim, + uint32_t dim_ext, + cuvs::distance::DistanceType metric, + const T* queries, // [n_queries, dim] + const int8_t* cluster_centers, // [n_lists, dim_ext] + rmm::mr::device_memory_resource* mr) +{ + raft::common::nvtx::range fun_scope( + "ivf_pq::search::select_clusters(n_probes = %u, n_queries = %u, n_lists = %u, dim = %u)", + n_probes, + n_queries, + n_lists, + dim); + auto stream = raft::resource::get_cuda_stream(handle); + int8_t norm_factor; + switch (metric) { + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: norm_factor = -128; break; + case cuvs::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::InnerProduct: norm_factor = 0; break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + auto float_queries_view = + raft::make_device_vector_view(float_queries, dim_ext * n_queries); + raft::linalg::map_offset( + handle, float_queries_view, [queries, dim, dim_ext, norm_factor] __device__(uint32_t ix) { + uint32_t col = ix % dim_ext; + uint32_t row = ix / dim_ext; + if (col < dim) { return utils::mapping{}(queries[col + dim * row]); } + auto m = dim_ext - dim; + if (m == 1 || col > dim) { return norm_factor; } + return static_cast(1 - m); + }); + + using dist_type = int32_t; + dist_type alpha; + dist_type beta; + switch (metric) { + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: { + alpha = -2; + beta = 0; + } break; + case cuvs::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::InnerProduct: { + alpha = -1; + beta = 0; + } break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + rmm::device_uvector qc_distances(n_queries * n_lists, stream, mr); + raft::linalg::gemm(handle, + true, + false, + n_lists, + n_queries, + dim_ext, + &alpha, + cluster_centers, + dim_ext, + float_queries, + dim_ext, + &beta, + qc_distances.data(), + n_lists, + stream); + + // Select neighbor clusters for each query. + rmm::device_uvector cluster_dists(n_queries * n_probes, stream, mr); + // cuvs::selection::select_k lacks uint32_t-as-a-value support at the moment + raft::matrix::select_k( + handle, + raft::make_device_matrix_view( + qc_distances.data(), n_queries, n_lists), + std::nullopt, + raft::make_device_matrix_view(cluster_dists.data(), n_queries, n_probes), + raft::make_device_matrix_view(clusters_to_probe, n_queries, n_probes), + true); +} + +template +void select_clusters(raft::resources const& handle, + uint32_t* clusters_to_probe, // [n_queries, n_probes] + half* float_queries, // [n_queries, dim_ext] + uint32_t n_queries, + uint32_t n_probes, + uint32_t n_lists, + uint32_t dim, + uint32_t dim_ext, + cuvs::distance::DistanceType metric, + const T* queries, // [n_queries, dim] + const half* cluster_centers, // [n_lists, dim_ext] + rmm::mr::device_memory_resource* mr) +{ + raft::common::nvtx::range fun_scope( + "ivf_pq::search::select_clusters(n_probes = %u, n_queries = %u, n_lists = %u, dim = %u)", + n_probes, + n_queries, + n_lists, + dim); + auto stream = raft::resource::get_cuda_stream(handle); + half norm_factor; + switch (metric) { + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: norm_factor = 1.0 / -2.0; break; + case cuvs::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::InnerProduct: norm_factor = 0; break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + auto float_queries_view = + raft::make_device_vector_view(float_queries, dim_ext * n_queries); + raft::linalg::map_offset( + handle, float_queries_view, [queries, dim, dim_ext, norm_factor] __device__(uint32_t ix) { + uint32_t col = ix % dim_ext; + uint32_t row = ix / dim_ext; + if (col < dim) { return utils::mapping{}(queries[col + dim * row]); } + return col == dim ? norm_factor : half(0); + }); + + using dist_type = half; + dist_type alpha; + dist_type beta; + switch (metric) { + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: { + alpha = -2.0; + beta = 0.0; + } break; + case cuvs::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::InnerProduct: { + alpha = -1.0; + beta = 0.0; + } break; + default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); + } + rmm::device_uvector qc_distances(n_queries * n_lists, stream, mr); + raft::linalg::gemm(handle, + true, + false, + n_lists, + n_queries, + dim_ext, + &alpha, + cluster_centers, + dim_ext, + float_queries, + dim_ext, + &beta, + qc_distances.data(), + n_lists, + stream); + + // Select neighbor clusters for each query. + rmm::device_uvector cluster_dists(n_queries * n_probes, stream, mr); + cuvs::selection::select_k( + handle, + raft::make_device_matrix_view( + qc_distances.data(), n_queries, n_lists), + std::nullopt, + raft::make_device_matrix_view(cluster_dists.data(), n_queries, n_probes), + raft::make_device_matrix_view(clusters_to_probe, n_queries, n_probes), + true); +} + /** * An approximation to the number of times each cluster appears in a batched sample. * @@ -607,8 +777,23 @@ inline auto get_max_batch_size(raft::resources const& res, return max_batch_size; } -/** Maximum number of queries ivf_pq::search can process in one batch. */ -constexpr uint32_t kMaxQueries = 4096; +template +inline auto get_rotation_matrix(const raft::resources& res, const index& index) + -> raft::device_matrix_view +{ + if constexpr (std::is_same_v) { return index.rotation_matrix(); } + if constexpr (std::is_same_v) { return index.rotation_matrix_half(res); } + if constexpr (std::is_same_v) { return index.rotation_matrix_int8(res); } +} + +template +inline auto get_centers(const raft::resources& res, const index& index) + -> raft::device_matrix_view +{ + if constexpr (std::is_same_v) { return index.centers(); } + if constexpr (std::is_same_v) { return index.centers_half(res); } + if constexpr (std::is_same_v) { return index.centers_int8(res); } +} /** See raft::spatial::knn::ivf_pq::search docs */ template (handle, index).extent(1) + : index.dim_ext(); auto n_probes = std::min(params.n_probes, index.n_lists()); uint32_t max_samples = 0; @@ -678,10 +866,24 @@ inline void search(raft::resources const& handle, auto mr = raft::resource::get_workspace_resource(handle); // Maximum number of query vectors to search at the same time. - const auto max_queries = std::min(std::max(n_queries, 1), kMaxQueries); - auto max_batch_size = get_max_batch_size(handle, k, n_probes, max_queries, max_samples); - - rmm::device_uvector float_queries(max_queries * dim_ext, stream, mr); + const auto max_queries = + std::min(std::max(n_queries, 1), params.max_internal_batch_size); + auto max_batch_size = get_max_batch_size(handle, k, n_probes, max_queries, max_samples); + + using some_query_t = std:: + variant, rmm::device_uvector, rmm::device_uvector>; + some_query_t gemm_queries( + params.coarse_search_dtype == CUDA_R_32F + ? std::move(some_query_t{ + std::in_place_type_t>{}, max_queries * dim_ext, stream, mr}) + : params.coarse_search_dtype == CUDA_R_16F + ? std::move(some_query_t{ + std::in_place_type_t>{}, max_queries * dim_ext, stream, mr}) + : params.coarse_search_dtype == CUDA_R_8I + ? std::move(some_query_t{ + std::in_place_type_t>{}, max_queries * dim_ext, stream, mr}) + : throw raft::logic_error("Unsupported sparse coarse_search_dtype (only CUDA_R_32F, " + "CUDA_R_16F, and CUDA_R_8I are supported)")); rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); @@ -694,37 +896,49 @@ inline void search(raft::resources const& handle, raft::common::nvtx::range batch_scope( "ivf_pq::search-batch(queries: %u - %u)", offset_q, offset_q + queries_batch); - select_clusters(handle, - clusters_to_probe.data(), - float_queries.data(), - queries_batch, - n_probes, - index.n_lists(), - dim, - dim_ext, - index.metric(), - queries + static_cast(dim) * offset_q, - index.centers().data_handle(), - mr); + std::visit( + [&](auto&& gemm_qs) { + using gemm_type = std::remove_reference_t; + using value_type = std::remove_cv_t; + return select_clusters(handle, + clusters_to_probe.data(), + gemm_qs.data(), + queries_batch, + n_probes, + index.n_lists(), + dim, + dim_ext, + index.metric(), + queries + static_cast(dim) * offset_q, + get_centers(handle, index).data_handle(), + mr); + }, + gemm_queries); // Rotate queries - float alpha = 1.0; - float beta = 0.0; - raft::linalg::gemm(handle, - true, - false, - index.rot_dim(), - queries_batch, - dim, - &alpha, - index.rotation_matrix().data_handle(), - dim, - float_queries.data(), - dim_ext, - &beta, - rot_queries.data(), - index.rot_dim(), - stream); + std::visit( + [&](auto&& gemm_qs) { + using gemm_type = std::remove_reference_t; + using value_type = std::remove_cv_t; + float alpha = std::is_same_v ? 1.0 / 128.0 / 128.0 : 1.0; + float beta = 0.0; + raft::linalg::gemm(handle, + true, + false, + index.rot_dim(), + queries_batch, + dim, + &alpha, + get_rotation_matrix(handle, index).data_handle(), + dim, + gemm_qs.data(), + dim_ext, + &beta, + rot_queries.data(), + index.rot_dim(), + stream); + }, + gemm_queries); if (index.metric() == distance::DistanceType::CosineExpanded) { auto rot_queries_view = raft::make_device_matrix_view( rot_queries.data(), max_queries, index.rot_dim()); diff --git a/cpp/src/neighbors/ivf_pq_index.cpp b/cpp/src/neighbors/ivf_pq_index.cu similarity index 76% rename from cpp/src/neighbors/ivf_pq_index.cpp rename to cpp/src/neighbors/ivf_pq_index.cu index 8f4e5b331..3376f38b1 100644 --- a/cpp/src/neighbors/ivf_pq_index.cpp +++ b/cpp/src/neighbors/ivf_pq_index.cu @@ -16,6 +16,14 @@ #include +#include "detail/ann_utils.cuh" + +#include +#include +#include + +#include + namespace cuvs::neighbors::ivf_pq { index_params index_params::from_dataset(raft::matrix_extent dataset, cuvs::distance::DistanceType metric) @@ -339,6 +347,76 @@ uint32_t index::calculate_pq_dim(uint32_t dim) return r; } +template +raft::device_matrix_view index::rotation_matrix_int8( + const raft::resources& res) const +{ + if (!rotation_matrix_int8_.has_value()) { + rotation_matrix_int8_.emplace( + raft::make_device_mdarray(res, rotation_matrix().extents())); + raft::linalg::map(res, + rotation_matrix_int8_->view(), + cuvs::spatial::knn::detail::utils::mapping{}, + rotation_matrix()); + } + return rotation_matrix_int8_->view(); +} + +template +raft::device_matrix_view index::centers_int8( + const raft::resources& res) const +{ + if (!centers_int8_.has_value()) { + uint32_t n_lists = this->n_lists(); + uint32_t dim = this->dim(); + uint32_t dim_ext = this->dim_ext(); + uint32_t dim_ext_int8 = raft::round_up_safe(dim + 2, 16u); + centers_int8_.emplace(raft::make_device_matrix(res, n_lists, dim_ext_int8)); + auto* inputs = centers().data_handle(); + + // NB: we use all available spare slots between dim and dim_ext to improve precision + raft::linalg::map_offset( + res, centers_int8_->view(), [dim, dim_ext, dim_ext_int8, inputs] __device__(uint32_t ix) { + uint32_t col = ix % dim_ext_int8; + uint32_t row = ix / dim_ext_int8; + if (col < dim) { + return static_cast( + std::clamp(inputs[col + row * dim_ext] * 128.0f, -128.0f, 127.f)); + } + auto x = inputs[row * dim_ext + dim]; + auto c = 64.0f / static_cast(dim_ext_int8 - dim - 1); + auto y = std::clamp(x * c, -128.0f, 127.f); + auto z = std::clamp((y - std::round(y)) * 128.0f, -128.0f, 127.f); + if (col > dim) { return static_cast(std::round(y)); } + return static_cast(z); + }); + } + return centers_int8_->view(); +} + +template +raft::device_matrix_view index::rotation_matrix_half( + const raft::resources& res) const +{ + if (!rotation_matrix_half_.has_value()) { + rotation_matrix_half_.emplace( + raft::make_device_mdarray(res, rotation_matrix().extents())); + raft::linalg::map(res, rotation_matrix_half_->view(), raft::cast_op{}, rotation_matrix()); + } + return rotation_matrix_half_->view(); +} + +template +raft::device_matrix_view index::centers_half( + const raft::resources& res) const +{ + if (!centers_half_.has_value()) { + centers_half_.emplace(raft::make_device_mdarray(res, centers().extents())); + raft::linalg::map(res, centers_half_->view(), raft::cast_op{}, centers()); + } + return centers_half_->view(); +} + template struct index; } // namespace cuvs::neighbors::ivf_pq