Merge pull request #719 from rapidsai/branch-25.02
Forward-merge branch-25.02 into branch-25.04
GPUtester authored Feb 24, 2025
2 parents a1e0cc0 + 1591029 commit a2a6a67
Showing 16 changed files with 358 additions and 146 deletions.
1 change: 1 addition & 0 deletions conda/recipes/cuvs-bench-cpu/meta.yaml
@@ -62,6 +62,7 @@ requirements:
- pyyaml
- python
- requests
- sklearn>=1.5
about:
home: https://rapids.ai/
license: Apache-2.0
1 change: 1 addition & 0 deletions conda/recipes/cuvs-bench/meta.yaml
@@ -101,6 +101,7 @@ requirements:
- python
- requests
- rmm ={{ minor_version }}
- sklearn>=1.5
about:
home: https://rapids.ai/
license: Apache-2.0
8 changes: 5 additions & 3 deletions cpp/include/cuvs/neighbors/cagra.hpp
@@ -337,6 +337,8 @@ struct index : cuvs::neighbors::index {
using search_params_type = cagra::search_params;
using index_type = IdxT;
using value_type = T;
using dataset_index_type = int64_t;

static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
"IdxT must be able to represent all values of uint32_t");

@@ -510,14 +512,14 @@ struct index : cuvs::neighbors::index {
*/
template <typename DatasetT>
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
-> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<int64_t>, DatasetT>>
-> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<dataset_index_type>, DatasetT>>
{
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
}

template <typename DatasetT>
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<dataset_index_type>, DatasetT>>
{
dataset_ = std::move(dataset);
}
@@ -561,7 +563,7 @@ struct index : cuvs::neighbors::index {
cuvs::distance::DistanceType metric_;
raft::device_matrix<IdxT, int64_t, raft::row_major> graph_;
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
std::unique_ptr<neighbors::dataset<dataset_index_type>> dataset_;
};
/**
* @}
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/cagra/cagra_merge.cuh
@@ -43,14 +43,16 @@ index<T, IdxT> merge(raft::resources const& handle,
const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices)
{
using cagra_index_t = cuvs::neighbors::cagra::index<T, IdxT>;
using ds_idx_type = typename cagra_index_t::dataset_index_type;

std::size_t dim = 0;
std::size_t new_dataset_size = 0;
int64_t stride = -1;

for (auto index : indices) {
for (cagra_index_t* index : indices) {
RAFT_EXPECTS(index != nullptr,
"Null pointer detected in 'indices'. Ensure all elements are valid before usage.");
using ds_idx_type = decltype(index->data().n_rows());
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index->data());
strided_dset != nullptr) {
if (dim == 0) {
@@ -74,8 +76,7 @@ index<T, IdxT> merge(raft::resources const& handle,
IdxT offset = 0;

auto merge_dataset = [&](T* dst) {
for (auto index : indices) {
using ds_idx_type = decltype(index->data().n_rows());
for (cagra_index_t* index : indices) {
auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index->data());

RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst + offset * dim,
24 changes: 16 additions & 8 deletions cpp/src/neighbors/detail/nn_descent.cuh
@@ -1047,24 +1047,32 @@ void GnndGraph<Index_t>::init_random_graph()
for (size_t seg_idx = 0; seg_idx < static_cast<size_t>(num_segments); seg_idx++) {
// random sequence (range: 0~nrow)
// segment_x stores neighbors which id % num_segments == x
std::vector<Index_t> rand_seq(nrow / num_segments);
std::vector<Index_t> rand_seq((nrow + num_segments - 1) / num_segments);
std::iota(rand_seq.begin(), rand_seq.end(), 0);
auto gen = std::default_random_engine{seg_idx};
std::shuffle(rand_seq.begin(), rand_seq.end(), gen);

#pragma omp parallel for
for (size_t i = 0; i < nrow; i++) {
size_t base_idx = i * node_degree + seg_idx * segment_size;
auto h_neighbor_list = h_graph + base_idx;
auto h_dist_list = h_dists.data_handle() + base_idx;
size_t base_idx = i * node_degree + seg_idx * segment_size;
auto h_neighbor_list = h_graph + base_idx;
auto h_dist_list = h_dists.data_handle() + base_idx;
size_t idx = base_idx;
size_t self_in_this_seg = 0;
for (size_t j = 0; j < static_cast<size_t>(segment_size); j++) {
size_t idx = base_idx + j;
Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx;
if ((size_t)id == i) {
id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx;
idx++;
id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx;
self_in_this_seg = 1;
}
h_neighbor_list[j].id_with_flag() = id;
h_dist_list[j] = std::numeric_limits<DistData_t>::max();

h_neighbor_list[j].id_with_flag() =
j < (rand_seq.size() - self_in_this_seg) && size_t(id) < nrow
? id
: std::numeric_limits<Index_t>::max();
h_dist_list[j] = std::numeric_limits<DistData_t>::max();
idx++;
}
}
}
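The rewritten loop above rounds the per-segment candidate count up to ceil(nrow / num_segments), skips the row's own id, and stores an invalid sentinel for any candidate whose reconstructed id falls outside the dataset. A minimal standalone sketch of that boundary case, in Java with illustrative values (nrow, numSegments, and seg are hypothetical and not part of this commit):

// Hypothetical numbers: 10 rows split across 4 segments; look at segment 3.
int nrow = 10, numSegments = 4, seg = 3;
int perSegment = (nrow + numSegments - 1) / numSegments;  // ceil(10 / 4) = 3 candidates
for (int k = 0; k < perSegment; k++) {
  int id = k * numSegments + seg;                         // reconstructed ids: 3, 7, 11
  if (id < nrow) {
    System.out.println("valid neighbor candidate: " + id);
  } else {
    // 11 >= nrow, so the real code stores std::numeric_limits<Index_t>::max() instead
    System.out.println("out-of-range candidate " + id + ", store the invalid sentinel");
  }
}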
24 changes: 24 additions & 0 deletions cpp/tests/neighbors/ann_cagra.cuh
@@ -952,6 +952,12 @@ class AnnCagraIndexMergeTest : public ::testing::TestWithParam<AnnCagraInputs> {
(ps.k * ps.dim * 8 / 5 /*(=magic number)*/ < ps.n_rows))
GTEST_SKIP();

// Avoid splitting datasets with a size of 0
if (ps.n_rows <= 3) GTEST_SKIP();

// IVF_PQ requires `n_rows >= n_lists`.
if (ps.n_rows < 8 && ps.build_algo == graph_build_algo::IVF_PQ) GTEST_SKIP();

size_t queries_size = ps.n_queries * ps.k;
std::vector<IdxT> indices_Cagra(queries_size);
std::vector<IdxT> indices_naive(queries_size);
@@ -1161,6 +1167,24 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

// Corner cases for small datasets
inputs2 = raft::util::itertools::product<AnnCagraInputs>(
{2},
{3, 5, 31, 32, 64, 101},
{1, 10},
{2}, // k
{graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT},
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL},
{0}, // query size
{0},
{256},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{false},
{true},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

// Varying dim and build algo.
inputs2 = raft::util::itertools::product<AnnCagraInputs>(
{100},
43 changes: 30 additions & 13 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java
@@ -17,6 +17,7 @@
package com.nvidia.cuvs;

import java.util.Arrays;
import java.util.BitSet;
import java.util.List;

/**
@@ -28,7 +29,8 @@ public class BruteForceQuery {

private List<Integer> mapping;
private float[][] queryVectors;
private long[] prefilter;
private BitSet[] prefilters;
private int numDocs = -1;
private int topK;

/**
@@ -40,12 +42,15 @@ public class BruteForceQuery {
* @param topK the top k results to return
* @param prefilters the prefilter bitsets to use while searching the BRUTEFORCE
*                   index, one per query
* @param numDocs maximum number of bits in each prefilter, representing the number of
*                documents in this index. Used only when prefilters are passed.
*/
public BruteForceQuery(float[][] queryVectors, List<Integer> mapping, int topK, long[] prefilter) {
public BruteForceQuery(float[][] queryVectors, List<Integer> mapping, int topK, BitSet[] prefilters, int numDocs) {
this.queryVectors = queryVectors;
this.mapping = mapping;
this.topK = topK;
this.prefilter = prefilter;
this.prefilters = prefilters;
this.numDocs = numDocs;
}

/**
@@ -78,16 +83,25 @@ public int getTopK() {
/**
* Gets the prefilter bitsets
*
* @return a long array
* @return an array of bitsets
*/
public long[] getPrefilter() {
return prefilter;
public BitSet[] getPrefilters() {
return prefilters;
}

/**
* Gets the number of documents expected to be in this index, as used by the prefilters
*
* @return number of documents as an integer
*/
public int getNumDocs() {
return numDocs;
}

@Override
public String toString() {
return "BruteForceQuery [mapping=" + mapping + ", queryVectors=" + Arrays.toString(queryVectors) + ", prefilter="
+ Arrays.toString(prefilter) + ", topK=" + topK + "]";
+ Arrays.toString(prefilters) + ", topK=" + topK + "]";
}

/**
@@ -96,7 +110,8 @@
public static class Builder {

private float[][] queryVectors;
private long[] prefilter;
private BitSet[] prefilters;
private int numDocs;
private List<Integer> mapping;
private int topK = 2;

@@ -134,13 +149,15 @@ public Builder withTopK(int topK) {
}

/**
* Sets the prefilter data for building the {@link BruteForceQuery}.
* Sets the prefilters data for building the {@link BruteForceQuery}.
*
* @param prefilter a one-dimensional long array
* @param prefilters array of bitsets, as many as queries, each containing as
* many bits as there are vectors in the index
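* @param numDocs maximum number of bits in each prefilter, representing the number of documents in this index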
* @return an instance of this Builder
*/
public Builder withPrefilter(long[] prefilter) {
this.prefilter = prefilter;
public Builder withPrefilter(BitSet[] prefilters, int numDocs) {
this.prefilters = prefilters;
this.numDocs = numDocs;
return this;
}

@@ -150,7 +167,7 @@ public Builder withPrefilter(long[] prefilter) {
* @return an instance of {@link BruteForceQuery}
*/
public BruteForceQuery build() {
return new BruteForceQuery(queryVectors, mapping, topK, prefilter);
return new BruteForceQuery(queryVectors, mapping, topK, prefilters, numDocs);
}
}
}
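For orientation, a minimal usage sketch of the new BitSet-based prefilter API. The vector values and sizes below are hypothetical, and the withQueryVectors setter is assumed to exist on the Builder (it is not shown in this diff):

// Assumes: import java.util.BitSet; and com.nvidia.cuvs.BruteForceQuery on the classpath.
// Two queries against an index of 8 vectors; each BitSet marks the documents
// the corresponding query is allowed to match, so numDocs = 8.
BitSet filterForQuery0 = new BitSet(8);
filterForQuery0.set(0, 4);                 // query 0 may only match documents 0..3
BitSet filterForQuery1 = new BitSet(8);
filterForQuery1.set(4, 8);                 // query 1 may only match documents 4..7

BruteForceQuery query = new BruteForceQuery.Builder()
    .withQueryVectors(new float[][] { { 0.1f, 0.2f }, { 0.3f, 0.4f } })  // assumed setter
    .withTopK(2)
    .withPrefilter(new BitSet[] { filterForQuery0, filterForQuery1 }, 8)
    .build();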
BruteForceIndexImpl.java
@@ -26,8 +26,12 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Objects;
import java.util.UUID;

Expand Down Expand Up @@ -59,7 +63,7 @@ public class BruteForceIndexImpl implements BruteForceIndex{
FunctionDescriptor.of(ADDRESS, ADDRESS, C_LONG, C_LONG, ADDRESS, ADDRESS, C_INT));

private static final MethodHandle searchMethodHandle = downcallHandle("search_brute_force_index",
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_LONG, C_LONG));
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_LONG));

private static final MethodHandle destroyIndexMethodHandle = downcallHandle("destroy_brute_force_index",
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS));
@@ -169,16 +173,24 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
long numQueries = cuvsQuery.getQueryVectors().length;
long numBlocks = cuvsQuery.getTopK() * numQueries;
int vectorDimension = numQueries > 0 ? cuvsQuery.getQueryVectors()[0].length : 0;
long prefilterDataLength = cuvsQuery.getPrefilter() != null ? cuvsQuery.getPrefilter().length : 0;
long numRows = dataset != null ? dataset.length : 0;

SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_LONG);
SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_FLOAT);
MemorySegment neighborsMemorySegment = resources.getArena().allocate(neighborsSequenceLayout);
MemorySegment distancesMemorySegment = resources.getArena().allocate(distancesSequenceLayout);
MemorySegment prefilterDataMemorySegment = cuvsQuery.getPrefilter() != null
? Util.buildMemorySegment(resources.getArena(), cuvsQuery.getPrefilter())
: MemorySegment.NULL;

// prepare the prefiltering data
long prefilterDataLength = 0;
MemorySegment prefilterDataMemorySegment = MemorySegment.NULL;
BitSet[] prefilters = cuvsQuery.getPrefilters();
if (prefilters != null && prefilters.length > 0) {
BitSet concatenatedFilters = Util.concatenate(prefilters, cuvsQuery.getNumDocs());
long filters[] = concatenatedFilters.toLongArray();
prefilterDataMemorySegment = Util.buildMemorySegment(resources.getArena(), filters);
prefilterDataLength = cuvsQuery.getNumDocs() * prefilters.length;
}

MemorySegment querySeg = Util.buildMemorySegment(resources.getArena(), cuvsQuery.getQueryVectors());
try (var localArena = Arena.ofConfined()) {
MemorySegment returnValue = localArena.allocate(C_INT);
Expand All @@ -193,7 +205,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
distancesMemorySegment,
returnValue,
prefilterDataMemorySegment,
prefilterDataLength, numRows
prefilterDataLength
);
checkError(returnValue.get(C_INT, 0L), "searchMethodHandle");
}
Util.java
@@ -25,6 +25,8 @@
import java.lang.invoke.MethodHandle;
import java.lang.invoke.VarHandle;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;

import com.nvidia.cuvs.GPUInfo;
Expand Down Expand Up @@ -184,6 +186,14 @@ public static MemorySegment buildMemorySegment(Arena arena, long[] data) {
return dataMemorySegment;
}

public static MemorySegment buildMemorySegment(Arena arena, byte[] data) {
int cells = data.length;
MemoryLayout dataMemoryLayout = MemoryLayout.sequenceLayout(cells, C_CHAR);
MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout);
MemorySegment.copy(data, 0, dataMemorySegment, C_CHAR, 0, cells);
return dataMemorySegment;
}

/**
* A utility method for building a {@link MemorySegment} for a 2D float array.
*
@@ -201,4 +211,20 @@ public static MemorySegment buildMemorySegment(Arena arena, float[][] data) {
}
return dataMemorySegment;
}

public static BitSet concatenate(BitSet[] arr, int maxSizeOfEachBitSet) {
BitSet ret = new BitSet(maxSizeOfEachBitSet * arr.length);
for (int i = 0; i < arr.length; i++) {
BitSet b = arr[i];
if (b == null || b.length() == 0) {
ret.set(i * maxSizeOfEachBitSet, (i + 1) * maxSizeOfEachBitSet);
} else {
for (int j = 0; j < maxSizeOfEachBitSet; j++) {
ret.set(i * maxSizeOfEachBitSet + j, b.get(j));
}
}
}
return ret;
}

}
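A short worked example of what Util.concatenate produces, with illustrative values: each per-query filter occupies a block of numDocs bits, and a null or empty entry is expanded to an all-ones block so that query matches every document.

// Assumes: import java.util.BitSet; numDocs = 8 in this illustration.
BitSet allowFirstHalf = new BitSet(8);
allowFirstHalf.set(0, 4);                        // documents 0..3 allowed for query 0
BitSet[] filters = { allowFirstHalf, null };     // null: query 1 matches every document

BitSet packed = Util.concatenate(filters, 8);    // bits 0..3 set, bits 8..15 all set
long[] words = packed.toLongArray();             // the long[] layout handed to the native search call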
(remaining changed files not shown)
