Skip to content

Commit

Permalink
IVF-PQ: low-precision coarse search
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Feb 21, 2025
1 parent 6b5b472 commit 8c3b0aa
Show file tree
Hide file tree
Showing 9 changed files with 429 additions and 58 deletions.
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ if(BUILD_SHARED_LIBS)
src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu
src/neighbors/ivf_pq_index.cpp
src/neighbors/ivf_pq_index.cu
src/neighbors/ivf_pq/ivf_pq_build_common.cu
src/neighbors/ivf_pq/ivf_pq_serialize.cu
src/neighbors/ivf_pq/ivf_pq_deserialize.cu
Expand Down
19 changes: 19 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,25 @@ void parse_search_param(const nlohmann::json& conf,
// set half as default
param.pq_param.lut_dtype = CUDA_R_16F;
}

if (conf.contains("coarse_search_dtype")) {
std::string type = conf.at("coarse_search_dtype");
if (type == "float") {
param.pq_param.coarse_search_dtype = CUDA_R_32F;
} else if (type == "half") {
param.pq_param.coarse_search_dtype = CUDA_R_16F;
} else if (type == "int8") {
param.pq_param.coarse_search_dtype = CUDA_R_8I;
} else {
throw std::runtime_error("coarse_search_dtype: '" + type +
"', should be either 'float', 'half' or 'int8'");
}
}

if (conf.contains("max_internal_batch_size")) {
param.pq_param.max_internal_batch_size = conf.at("max_internal_batch_size");
}

if (conf.contains("refine_ratio")) {
param.refine_ratio = conf.at("refine_ratio");
if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); }
Expand Down
38 changes: 38 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>

#include <optional>
#include <tuple>
#include <variant>
#include <vector>

namespace cuvs::neighbors::ivf_pq {

/**
Expand Down Expand Up @@ -181,6 +186,21 @@ struct search_params : cuvs::neighbors::search_params {
* performance if tweaked incorrectly.
*/
double preferred_shmem_carveout = 1.0;
/**
* [Experimental] The data type to use as the GEMM element type when searching the clusters to
* probe.
*
* Possible values: [CUDA_R_8I, CUDA_R_16F, CUDA_R_32F].
*
* - Legacy default: CUDA_R_32F (float)
* - Recommended for performance: CUDA_R_16F (half)
* - Experimental/low-precision: CUDA_R_8I (int8_t)
*/
cudaDataType_t coarse_search_dtype = CUDA_R_32F;
/**
* Set the internal batch size to improve GPU utilization at the cost of larger memory footprint.
*/
uint32_t max_internal_batch_size = 4096;
};
/**
* @}
Expand Down Expand Up @@ -427,6 +447,11 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<float, uint32_t, raft::row_major> rotation_matrix() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> rotation_matrix() const noexcept;

raft::device_matrix_view<const int8_t, uint32_t, raft::row_major> rotation_matrix_int8(
const raft::resources& res) const;
raft::device_matrix_view<const half, uint32_t, raft::row_major> rotation_matrix_half(
const raft::resources& res) const;

/**
* Accumulated list sizes, sorted in descending order [n_lists + 1].
* The last value contains the total length of the index.
Expand All @@ -447,6 +472,11 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<float, uint32_t, raft::row_major> centers() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> centers() const noexcept;

raft::device_matrix_view<const int8_t, uint32_t, raft::row_major> centers_int8(
const raft::resources& res) const;
raft::device_matrix_view<const half, uint32_t, raft::row_major> centers_half(
const raft::resources& res) const;

/** Cluster centers corresponding to the lists in the rotated space [n_lists, rot_dim] */
raft::device_matrix_view<float, uint32_t, raft::row_major> centers_rot() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> centers_rot() const noexcept;
Expand Down Expand Up @@ -485,6 +515,14 @@ struct index : cuvs::neighbors::index {
raft::device_matrix<float, uint32_t, raft::row_major> centers_rot_;
raft::device_matrix<float, uint32_t, raft::row_major> rotation_matrix_;

// Lazy-initialized low-precision variants of index members - for low-precision coarse search.
// These are never serialized and not touched during build/extend.
mutable std::optional<raft::device_matrix<int8_t, uint32_t, raft::row_major>> centers_int8_;
mutable std::optional<raft::device_matrix<half, uint32_t, raft::row_major>> centers_half_;
mutable std::optional<raft::device_matrix<int8_t, uint32_t, raft::row_major>>
rotation_matrix_int8_;
mutable std::optional<raft::device_matrix<half, uint32_t, raft::row_major>> rotation_matrix_half_;

// Computed members for accelerating search.
raft::device_vector<uint8_t*, uint32_t, raft::row_major> data_ptrs_;
raft::device_vector<IdxT*, uint32_t, raft::row_major> inds_ptrs_;
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/neighbors/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,22 @@ struct mapping {
/** @} */
};

template <>
template <>
HDI constexpr auto mapping<int8_t>::operator()(const uint8_t& x) const -> int8_t
{
  // Halve the input so the full uint8_t range [0, 255] maps into the int8_t
  // range [0, 127]; a plain uint8_t -> int8_t cast would overflow for x > 127.
  return static_cast<int8_t>(x / 2);
}

template <>
template <>
HDI constexpr auto mapping<int8_t>::operator()(const float& x) const -> int8_t
{
  // Scale by 128 (assumes x is roughly normalized to [-1, 1] — confirm at call sites),
  // then clamp into [-128, 127] so the float -> int8_t narrowing is always in range.
  const float scaled  = x * 128.0f;
  const float bounded = scaled < -128.0f ? -128.0f : (127.0f < scaled ? 127.0f : scaled);
  return static_cast<int8_t>(bounded);
}

/**
* @brief Sets the first num bytes of the block of memory pointed by ptr to the specified value.
*
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/neighbors/detail/cagra/cagra_build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ ivf_pq_params::ivf_pq_params(raft::matrix_extent<int64_t> dataset_extents,
search_params.n_probes = std::max<uint32_t>(10, build_params.n_lists * 0.01);
search_params.lut_dtype = CUDA_R_16F;
search_params.internal_distance_dtype = CUDA_R_16F;
search_params.coarse_search_dtype = CUDA_R_16F;
search_params.max_internal_batch_size = 128 * 1024;

refinement_rate = 2;
refinement_rate = 1;
}
} // namespace cuvs::neighbors::cagra::graph_build_params
24 changes: 12 additions & 12 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,13 @@ void build_knn_graph(
const auto num_queries = dataset.extent(0);

// Use the same maximum batch size as the ivf_pq::search to avoid allocating more than needed.
constexpr uint32_t kMaxQueries = 4096;
const uint32_t max_queries = pq.search_params.max_internal_batch_size;

// Heuristic: the build_knn_graph code should use only a fraction of the workspace memory; the
// rest should be used by the ivf_pq::search. Here we say that the workspace size should be a good
// multiple of what is required for the I/O batching below.
constexpr size_t kMinWorkspaceRatio = 5;
auto desired_workspace_size = kMaxQueries * kMinWorkspaceRatio *
auto desired_workspace_size = max_queries * kMinWorkspaceRatio *
(sizeof(DataT) * dataset.extent(1) // queries (dataset batch)
+ sizeof(float) * gpu_top_k // distances
+ sizeof(int64_t) * gpu_top_k // neighbors
Expand All @@ -189,21 +189,21 @@ void build_knn_graph(
node_degree,
top_k,
gpu_top_k,
kMaxQueries,
max_queries,
pq.search_params.n_probes);

auto distances = raft::make_device_mdarray<float>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, gpu_top_k));
res, workspace_mr, raft::make_extents<int64_t>(max_queries, gpu_top_k));
auto neighbors = raft::make_device_mdarray<int64_t>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, gpu_top_k));
res, workspace_mr, raft::make_extents<int64_t>(max_queries, gpu_top_k));
auto refined_distances = raft::make_device_mdarray<float>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, top_k));
res, workspace_mr, raft::make_extents<int64_t>(max_queries, top_k));
auto refined_neighbors = raft::make_device_mdarray<int64_t>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, top_k));
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(kMaxQueries, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(kMaxQueries, dataset.extent(1));
auto refined_neighbors_host = raft::make_host_matrix<int64_t, int64_t>(kMaxQueries, top_k);
auto refined_distances_host = raft::make_host_matrix<float, int64_t>(kMaxQueries, top_k);
res, workspace_mr, raft::make_extents<int64_t>(max_queries, top_k));
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_queries, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(max_queries, dataset.extent(1));
auto refined_neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_queries, top_k);
auto refined_distances_host = raft::make_host_matrix<float, int64_t>(max_queries, top_k);

// TODO(tfeher): batched search with multiple GPUs
std::size_t num_self_included = 0;
Expand All @@ -214,7 +214,7 @@ void build_knn_graph(
dataset.data_handle(),
dataset.extent(0),
dataset.extent(1),
static_cast<int64_t>(kMaxQueries),
static_cast<int64_t>(max_queries),
raft::resource::get_cuda_stream(res),
workspace_mr);

Expand Down
4 changes: 4 additions & 0 deletions cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ void set_centers(raft::resources const& handle, index<IdxT>* index, const float*
auto stream = raft::resource::get_cuda_stream(handle);
auto* device_memory = raft::resource::get_workspace_resource(handle);

// Make sure to have trailing zeroes between dim and dim_ext;
// We rely on this to enable padded tensor gemm kernels during coarse search.
cuvs::spatial::knn::detail::utils::memzero(
index->centers().data_handle(), index->centers().size(), stream);
// combine cluster_centers and their norms
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(),
sizeof(float) * index->dim_ext(),
Expand Down
Loading

0 comments on commit 8c3b0aa

Please sign in to comment.