Skip to content

Commit

Permalink
Improve multi-CTA algorithm (#492)
Browse files Browse the repository at this point in the history
It has been reported that when the number of search results is large, for example 100, using the multi-CTA algorithm can cause a decrease in recall. This PR is intended to alleviate this low recall issue.

close #208

Authors:
  - Akira Naruse (https://github.com/anaruse)
  - Tamas Bela Feher (https://github.com/tfeher)
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - tsuki (https://github.com/enp1s0)
  - Artem M. Chirkin (https://github.com/achirkin)

URL: #492
  • Loading branch information
anaruse authored Jan 30, 2025
1 parent 9489d0c commit 836183e
Show file tree
Hide file tree
Showing 13 changed files with 532 additions and 257 deletions.
62 changes: 53 additions & 9 deletions cpp/src/neighbors/detail/cagra/add_nodes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,31 @@ void add_node_core(
raft::resource::get_cuda_stream(handle));
raft::resource::sync_stream(handle);

// Check search results
constexpr int max_warnings = 3;
int num_warnings = 0;
for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) {
std::uint32_t invalid_edges = 0;
for (std::uint32_t i = 0; i < base_degree; i++) {
if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; }
}
if (invalid_edges > 0) {
if (num_warnings < max_warnings) {
RAFT_LOG_WARN(
"Invalid edges found in search results "
"(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)",
(uint64_t)vec_i,
(uint64_t)invalid_edges,
(uint64_t)degree,
(uint64_t)base_degree);
}
num_warnings += 1;
}
}
if (num_warnings > max_warnings) {
RAFT_LOG_WARN("The number of queries that contain invalid search results: %d", num_warnings);
}

// Step 2: rank-based reordering
#pragma omp parallel
{
Expand All @@ -147,9 +172,16 @@ void add_node_core(
for (std::uint32_t i = 0; i < base_degree; i++) {
std::uint32_t detourable_node_count = 0;
const auto a_id = host_neighbor_indices(vec_i, i);
if (a_id >= idx.size()) {
// If the node ID is not valid, its detour count is set to a value
// greater than the maximum, so that the edge to that node is selected
// as rarely as possible.
detourable_node_count_list[i] = std::make_pair(a_id, base_degree + 1);
continue;
}
for (std::uint32_t j = 0; j < i; j++) {
const auto b0_id = host_neighbor_indices(vec_i, j);
assert(b0_id < idx.size());
if (b0_id >= idx.size()) { continue; }
for (std::uint32_t k = 0; k < degree; k++) {
const auto b1_id = updated_graph(b0_id, k);
if (a_id == b1_id) {
Expand All @@ -160,6 +192,7 @@ void add_node_core(
}
detourable_node_count_list[i] = std::make_pair(a_id, detourable_node_count);
}

std::sort(detourable_node_count_list.begin(),
detourable_node_count_list.end(),
[&](const std::pair<IdxT, std::size_t> a, const std::pair<IdxT, std::size_t> b) {
Expand All @@ -181,13 +214,18 @@ void add_node_core(
const auto target_new_node_id = old_size + batch.offset() + vec_i;
for (std::size_t i = 0; i < num_rev_edges; i++) {
const auto target_node_id = updated_graph(old_size + batch.offset() + vec_i, i);

if (target_node_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", target_node_id);
}
IdxT replace_id = new_size;
IdxT replace_id_j = 0;
std::size_t replace_num_incoming_edges = 0;
for (std::int32_t j = degree - 1; j >= static_cast<std::int32_t>(rev_edge_search_range);
j--) {
const auto neighbor_id = updated_graph(target_node_id, j);
const auto neighbor_id = updated_graph(target_node_id, j);
if (neighbor_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", neighbor_id);
}
const std::size_t num_incoming_edges = host_num_incoming_edges(neighbor_id);
if (num_incoming_edges > replace_num_incoming_edges) {
// Check duplication
Expand All @@ -206,10 +244,6 @@ void add_node_core(
replace_id_j = j;
}
}
if (replace_id >= new_size) {
std::fprintf(stderr, "Invalid rev edge index (%u)\n", replace_id);
return;
}
updated_graph(target_node_id, replace_id_j) = target_new_node_id;
rev_edges[i] = replace_id;
}
Expand All @@ -221,13 +255,15 @@ void add_node_core(
const auto rank_based_list_ptr =
updated_graph.data_handle() + (old_size + batch.offset() + vec_i) * degree;
const auto rev_edges_return_list_ptr = rev_edges.data();
while (num_add < degree) {
while ((num_add < degree) &&
((rank_base_i < degree) || (rev_edges_return_i < num_rev_edges))) {
const auto node_list_ptr =
interleave_switch == 0 ? rank_based_list_ptr : rev_edges_return_list_ptr;
auto& node_list_index = interleave_switch == 0 ? rank_base_i : rev_edges_return_i;
const auto max_node_list_index = interleave_switch == 0 ? degree : num_rev_edges;
for (; node_list_index < max_node_list_index; node_list_index++) {
const auto candidate = node_list_ptr[node_list_index];
if (candidate >= new_size) { continue; }
// Check duplication
bool dup = false;
for (std::uint32_t j = 0; j < num_add; j++) {
Expand All @@ -244,6 +280,12 @@ void add_node_core(
}
interleave_switch = 1 - interleave_switch;
}
if (num_add < degree) {
RAFT_FAIL("Number of edges is not enough (target_new_node_id:%lu, num_add:%lu, degree:%lu)",
(uint64_t)target_new_node_id,
(uint64_t)num_add,
(uint64_t)degree);
}
for (std::uint32_t i = 0; i < degree; i++) {
updated_graph(target_new_node_id, i) = temp[i];
}
Expand All @@ -259,7 +301,9 @@ void add_graph_nodes(
raft::host_matrix_view<IdxT, std::int64_t> updated_graph_view,
const cagra::extend_params& params)
{
assert(input_updated_dataset_view.extent(0) >= index.size());
if (input_updated_dataset_view.extent(0) < index.size()) {
RAFT_FAIL("Updated dataset must be not smaller than the previous index state.");
}

const std::size_t initial_dataset_size = index.size();
const std::size_t new_dataset_size = input_updated_dataset_view.extent(0);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res,
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DataT, IndexT, DistanceT, CagraSampleFilterT_s>> plan =
factory<DataT, IndexT, DistanceT, CagraSampleFilterT_s>::create(
res, params, dataset_desc, queries.extent(1), graph.extent(1), topk);
res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk);

plan->check(topk);

Expand Down
64 changes: 47 additions & 17 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
const IndexT* __restrict__ seed_ptr, // [num_seeds]
const uint32_t num_seeds,
IndexT* __restrict__ visited_hash_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hash_ptr,
const uint32_t traversed_hash_bitlen,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand Down Expand Up @@ -145,19 +147,29 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(

const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u);
if (valid_i && lane_id == 0) {
if (best_index_team_local != raft::upper_bound<IndexT>() &&
hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) {
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
} else {
result_distances_ptr[i] = raft::upper_bound<DistanceT>();
result_indices_ptr[i] = raft::upper_bound<IndexT>();
if (best_index_team_local != raft::upper_bound<IndexT>()) {
if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
} else if ((traversed_hash_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) {
// Deactivate this entry as it has been already used by others.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
}
}
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
}
}
}

template <typename IndexT, typename DistanceT, typename DATASET_DESCRIPTOR_T>
template <typename IndexT,
typename DistanceT,
typename DATASET_DESCRIPTOR_T,
int STATIC_RESULT_POSITION = 1>
RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
IndexT* __restrict__ result_child_indices_ptr,
DistanceT* __restrict__ result_child_distances_ptr,
Expand All @@ -168,13 +180,17 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const uint32_t knn_k,
// hashmap
IndexT* __restrict__ visited_hashmap_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hashmap_ptr,
const uint32_t traversed_hash_bitlen,
const IndexT* __restrict__ parent_indices,
const IndexT* __restrict__ internal_topk_list,
const uint32_t search_width)
const uint32_t search_width,
int* __restrict__ result_position = nullptr,
const int max_result_position = 0)
{
constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask<IndexT>::value;
constexpr IndexT invalid_index = raft::upper_bound<IndexT>();
constexpr IndexT invalid_index = ~static_cast<IndexT>(0);

// Read child indices of parents from knn graph and check if the distance
// computation is necessary.
Expand All @@ -186,11 +202,22 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
child_id = knn_graph[(i % knn_k) + (static_cast<int64_t>(knn_k) * parent_id)];
}
if (child_id != invalid_index) {
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) {
if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
child_id = invalid_index;
} else if ((traversed_hashmap_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) {
// Deactivate this entry as this has been already used by others.
child_id = invalid_index;
}
}
result_child_indices_ptr[i] = child_id;
if (STATIC_RESULT_POSITION) {
result_child_indices_ptr[i] = child_id;
} else if (child_id != invalid_index) {
int j = atomicSub(result_position, 1) - 1;
result_child_indices_ptr[j] = child_id;
}
}
__syncthreads();

Expand All @@ -201,9 +228,11 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const auto compute_distance = dataset_desc.compute_distance_impl;
const auto args = dataset_desc.args.load();
const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0;
const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0];
for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) {
const bool valid_i = i < num_k;
const auto child_id = valid_i ? result_child_indices_ptr[i] : invalid_index;
const auto j = i + ofst;
const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position);
const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index;

// We should be calling `dataset_desc.compute_distance(..)` here as follows:
// > const auto child_dist = dataset_desc.compute_distance(child_id, child_id != invalid_index);
Expand All @@ -213,9 +242,10 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
(child_id != invalid_index) ? compute_distance(args, child_id)
: (lead_lane ? raft::upper_bound<DistanceT>() : 0),
team_size_bits);
__syncwarp();

// Store the distance
if (valid_i && lead_lane) { result_child_distances_ptr[i] = child_dist; }
if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; }
}
}

Expand Down
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ class factory {
search_params const& params,
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
int64_t dim,
int64_t dataset_size,
int64_t graph_degree,
uint32_t topk)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk);
return dispatch_kernel(res, plan, dataset_desc);
}

Expand All @@ -56,15 +57,15 @@ class factory {
if (plan.algo == search_algo::SINGLE_CTA) {
return std::make_unique<
single_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::make_unique<
multi_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else {
return std::make_unique<
multi_kernel_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
}
}
};
Expand Down
Loading

0 comments on commit 836183e

Please sign in to comment.