diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index 310d4e7a6..dffc94f06 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -1047,24 +1047,32 @@ void GnndGraph::init_random_graph() for (size_t seg_idx = 0; seg_idx < static_cast(num_segments); seg_idx++) { // random sequence (range: 0~nrow) // segment_x stores neighbors which id % num_segments == x - std::vector rand_seq(nrow / num_segments); + std::vector rand_seq((nrow + num_segments - 1) / num_segments); std::iota(rand_seq.begin(), rand_seq.end(), 0); auto gen = std::default_random_engine{seg_idx}; std::shuffle(rand_seq.begin(), rand_seq.end(), gen); #pragma omp parallel for for (size_t i = 0; i < nrow; i++) { - size_t base_idx = i * node_degree + seg_idx * segment_size; - auto h_neighbor_list = h_graph + base_idx; - auto h_dist_list = h_dists.data_handle() + base_idx; + size_t base_idx = i * node_degree + seg_idx * segment_size; + auto h_neighbor_list = h_graph + base_idx; + auto h_dist_list = h_dists.data_handle() + base_idx; + size_t idx = base_idx; + size_t self_in_this_seg = 0; for (size_t j = 0; j < static_cast(segment_size); j++) { - size_t idx = base_idx + j; Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; if ((size_t)id == i) { - id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx; + idx++; + id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; + self_in_this_seg = 1; } - h_neighbor_list[j].id_with_flag() = id; - h_dist_list[j] = std::numeric_limits::max(); + + h_neighbor_list[j].id_with_flag() = + j < (rand_seq.size() - self_in_this_seg) && size_t(id) < nrow + ? id + : std::numeric_limits::max(); + h_dist_list[j] = std::numeric_limits::max(); + idx++; } } } diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index 1e695f9a8..e9408264e 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -952,6 +952,12 @@ class AnnCagraIndexMergeTest : public ::testing::TestWithParam { (ps.k * ps.dim * 8 / 5 /*(=magic number)*/ < ps.n_rows)) GTEST_SKIP(); + // Avoid splitting datasets with a size of 0 + if (ps.n_rows <= 3) GTEST_SKIP(); + + // IVF_PQ requires the `n_rows >= n_lists`. + if (ps.n_rows < 8 && ps.build_algo == graph_build_algo::IVF_PQ) GTEST_SKIP(); + size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); std::vector indices_naive(queries_size); @@ -1161,6 +1167,24 @@ inline std::vector generate_inputs() {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + // Corner cases for small datasets + inputs2 = raft::util::itertools::product( + {2}, + {3, 5, 31, 32, 64, 101}, + {1, 10}, + {2}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, + {0}, // query size + {0}, + {256}, + {1}, + {cuvs::distance::DistanceType::L2Expanded}, + {false}, + {true}, + {0.995}); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + // Varying dim and build algo. inputs2 = raft::util::itertools::product( {100},