Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IVF-PQ: low-precision coarse search #715

Open
wants to merge 3 commits into
base: branch-25.04
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ if(BUILD_SHARED_LIBS)
src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu
src/neighbors/ivf_pq_index.cpp
src/neighbors/ivf_pq_index.cu
src/neighbors/ivf_pq/ivf_pq_build_common.cu
src/neighbors/ivf_pq/ivf_pq_serialize.cu
src/neighbors/ivf_pq/ivf_pq_deserialize.cu
Expand Down
19 changes: 19 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,25 @@ void parse_search_param(const nlohmann::json& conf,
// set half as default
param.pq_param.lut_dtype = CUDA_R_16F;
}

if (conf.contains("coarse_search_dtype")) {
std::string type = conf.at("coarse_search_dtype");
if (type == "float") {
param.pq_param.coarse_search_dtype = CUDA_R_32F;
} else if (type == "half") {
param.pq_param.coarse_search_dtype = CUDA_R_16F;
} else if (type == "int8") {
param.pq_param.coarse_search_dtype = CUDA_R_8I;
} else {
throw std::runtime_error("coarse_search_dtype: '" + type +
"', should be either 'float', 'half' or 'int8'");
}
}

if (conf.contains("max_internal_batch_size")) {
param.pq_param.max_internal_batch_size = conf.at("max_internal_batch_size");
}

if (conf.contains("refine_ratio")) {
param.refine_ratio = conf.at("refine_ratio");
if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); }
Expand Down
38 changes: 38 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>

#include <optional>
#include <tuple>
#include <variant>
#include <vector>

namespace cuvs::neighbors::ivf_pq {

/**
Expand Down Expand Up @@ -181,6 +186,21 @@ struct search_params : cuvs::neighbors::search_params {
* performance if tweaked incorrectly.
*/
double preferred_shmem_carveout = 1.0;
/**
* [Experimental] The data type to use as the GEMM element type when searching the clusters to
* probe.
*
* Possible values: [CUDA_R_8I, CUDA_R_16F, CUDA_R_32F].
*
* - Legacy default: CUDA_R_32F (float)
* - Recommended for performance: CUDA_R_16F (half)
* - Experimental/low-precision: CUDA_R_8I (int8_t)
*/
cudaDataType_t coarse_search_dtype = CUDA_R_32F;
/**
* Set the internal batch size to improve GPU utilization at the cost of larger memory footprint.
*/
uint32_t max_internal_batch_size = 4096;
};
/**
* @}
Expand Down Expand Up @@ -427,6 +447,11 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<float, uint32_t, raft::row_major> rotation_matrix() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> rotation_matrix() const noexcept;

raft::device_matrix_view<const int8_t, uint32_t, raft::row_major> rotation_matrix_int8(
const raft::resources& res) const;
raft::device_matrix_view<const half, uint32_t, raft::row_major> rotation_matrix_half(
const raft::resources& res) const;

/**
* Accumulated list sizes, sorted in descending order [n_lists + 1].
* The last value contains the total length of the index.
Expand All @@ -447,6 +472,11 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<float, uint32_t, raft::row_major> centers() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> centers() const noexcept;

raft::device_matrix_view<const int8_t, uint32_t, raft::row_major> centers_int8(
const raft::resources& res) const;
raft::device_matrix_view<const half, uint32_t, raft::row_major> centers_half(
const raft::resources& res) const;

/** Cluster centers corresponding to the lists in the rotated space [n_lists, rot_dim] */
raft::device_matrix_view<float, uint32_t, raft::row_major> centers_rot() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> centers_rot() const noexcept;
Expand Down Expand Up @@ -485,6 +515,14 @@ struct index : cuvs::neighbors::index {
raft::device_matrix<float, uint32_t, raft::row_major> centers_rot_;
raft::device_matrix<float, uint32_t, raft::row_major> rotation_matrix_;

// Lazy-initialized low-precision variants of index members - for low-precision coarse search.
// These are never serialized and not touched during build/extend.
mutable std::optional<raft::device_matrix<int8_t, uint32_t, raft::row_major>> centers_int8_;
mutable std::optional<raft::device_matrix<half, uint32_t, raft::row_major>> centers_half_;
mutable std::optional<raft::device_matrix<int8_t, uint32_t, raft::row_major>>
rotation_matrix_int8_;
mutable std::optional<raft::device_matrix<half, uint32_t, raft::row_major>> rotation_matrix_half_;

// Computed members for accelerating search.
raft::device_vector<uint8_t*, uint32_t, raft::row_major> data_ptrs_;
raft::device_vector<IdxT*, uint32_t, raft::row_major> inds_ptrs_;
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/neighbors/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,22 @@ struct mapping {
/** @} */
};

template <>
template <>
HDI constexpr auto mapping<int8_t>::operator()(const uint8_t& x) const -> int8_t
{
  // Halve the unsigned value so it always fits the signed range
  // (uint8_t max 255 -> 127), preventing overflow on uint8_t -> int8_t.
  return static_cast<int8_t>(x / 2);
}

template <>
template <>
HDI constexpr auto mapping<int8_t>::operator()(const float& x) const -> int8_t
{
  // Scale into the int8_t range, then clamp before the cast: converting an
  // out-of-range float to a signed integer is undefined behavior.
  const float scaled = x * 128.0f;
  return static_cast<int8_t>(std::clamp(scaled, -128.0f, 127.0f));
}

/**
* @brief Sets the first num bytes of the block of memory pointed by ptr to the specified value.
*
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/neighbors/detail/cagra/cagra_build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ ivf_pq_params::ivf_pq_params(raft::matrix_extent<int64_t> dataset_extents,
search_params.n_probes = std::max<uint32_t>(10, build_params.n_lists * 0.01);
search_params.lut_dtype = CUDA_R_16F;
search_params.internal_distance_dtype = CUDA_R_16F;
search_params.coarse_search_dtype = CUDA_R_16F;
search_params.max_internal_batch_size = 128 * 1024;

refinement_rate = 2;
refinement_rate = 1;
}
} // namespace cuvs::neighbors::cagra::graph_build_params
24 changes: 12 additions & 12 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,13 @@ void build_knn_graph(
const auto num_queries = dataset.extent(0);

// Use the same maximum batch size as the ivf_pq::search to avoid allocating more than needed.
constexpr uint32_t kMaxQueries = 4096;
const uint32_t max_queries = pq.search_params.max_internal_batch_size;

// Heuristic: the build_knn_graph code should use only a fraction of the workspace memory; the
// rest should be used by the ivf_pq::search. Here we say that the workspace size should be a good
// multiple of what is required for the I/O batching below.
constexpr size_t kMinWorkspaceRatio = 5;
auto desired_workspace_size = kMaxQueries * kMinWorkspaceRatio *
auto desired_workspace_size = max_queries * kMinWorkspaceRatio *
(sizeof(DataT) * dataset.extent(1) // queries (dataset batch)
+ sizeof(float) * gpu_top_k // distances
+ sizeof(int64_t) * gpu_top_k // neighbors
Expand All @@ -189,21 +189,21 @@ void build_knn_graph(
node_degree,
top_k,
gpu_top_k,
kMaxQueries,
max_queries,
pq.search_params.n_probes);

auto distances = raft::make_device_mdarray<float>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, gpu_top_k));
res, workspace_mr, raft::make_extents<int64_t>(max_queries, gpu_top_k));
auto neighbors = raft::make_device_mdarray<int64_t>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, gpu_top_k));
res, workspace_mr, raft::make_extents<int64_t>(max_queries, gpu_top_k));
auto refined_distances = raft::make_device_mdarray<float>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, top_k));
res, workspace_mr, raft::make_extents<int64_t>(max_queries, top_k));
auto refined_neighbors = raft::make_device_mdarray<int64_t>(
res, workspace_mr, raft::make_extents<int64_t>(kMaxQueries, top_k));
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(kMaxQueries, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(kMaxQueries, dataset.extent(1));
auto refined_neighbors_host = raft::make_host_matrix<int64_t, int64_t>(kMaxQueries, top_k);
auto refined_distances_host = raft::make_host_matrix<float, int64_t>(kMaxQueries, top_k);
res, workspace_mr, raft::make_extents<int64_t>(max_queries, top_k));
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_queries, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(max_queries, dataset.extent(1));
auto refined_neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_queries, top_k);
auto refined_distances_host = raft::make_host_matrix<float, int64_t>(max_queries, top_k);

// TODO(tfeher): batched search with multiple GPUs
std::size_t num_self_included = 0;
Expand All @@ -214,7 +214,7 @@ void build_knn_graph(
dataset.data_handle(),
dataset.extent(0),
dataset.extent(1),
static_cast<int64_t>(kMaxQueries),
static_cast<int64_t>(max_queries),
raft::resource::get_cuda_stream(res),
workspace_mr);

Expand Down
4 changes: 4 additions & 0 deletions cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ void set_centers(raft::resources const& handle, index<IdxT>* index, const float*
auto stream = raft::resource::get_cuda_stream(handle);
auto* device_memory = raft::resource::get_workspace_resource(handle);

// Make sure to have trailing zeroes between dim and dim_ext;
// We rely on this to enable padded tensor gemm kernels during coarse search.
cuvs::spatial::knn::detail::utils::memzero(
index->centers().data_handle(), index->centers().size(), stream);
// combine cluster_centers and their norms
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(),
sizeof(float) * index->dim_ext(),
Expand Down
Loading