Skip to content

Commit

Permalink
Avoiding extra copy, adding randomized testing for Brute Force with p…
Browse files Browse the repository at this point in the history
…refiltering
  • Loading branch information
Ishan Chattopadhyaya committed Feb 19, 2025
1 parent 9428394 commit 291f92e
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 49 deletions.
21 changes: 18 additions & 3 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class BruteForceQuery {
private List<Integer> mapping;
private float[][] queryVectors;
private BitSet[] prefilters;
private int numDocs = -1;
private int topK;

/**
Expand All @@ -41,12 +42,15 @@ public class BruteForceQuery {
* @param topK the top k results to return
* @param prefilter the prefilter data to use while searching the BRUTEFORCE
* index
* @param numDocs Maximum of bits in each prefilter, representing number of documents in this index.
* Used only when prefilter(s) is/are passed.
*/
public BruteForceQuery(float[][] queryVectors, List<Integer> mapping, int topK, BitSet[] prefilters) {
public BruteForceQuery(float[][] queryVectors, List<Integer> mapping, int topK, BitSet[] prefilters, int numDocs) {
this.queryVectors = queryVectors;
this.mapping = mapping;
this.topK = topK;
this.prefilters = prefilters;
this.numDocs = numDocs;
}

/**
Expand Down Expand Up @@ -85,6 +89,15 @@ public BitSet[] getPrefilters() {
return prefilters;
}

/**
* Gets the number of documents supposed to be in this index, as used for prefilters
*
* @return number of documents as an integer
*/
public int getNumDocs() {
return numDocs;
}

@Override
public String toString() {
return "BruteForceQuery [mapping=" + mapping + ", queryVectors=" + Arrays.toString(queryVectors) + ", prefilter="
Expand All @@ -98,6 +111,7 @@ public static class Builder {

private float[][] queryVectors;
private BitSet[] prefilters;
private int numDocs;
private List<Integer> mapping;
private int topK = 2;

Expand Down Expand Up @@ -141,8 +155,9 @@ public Builder withTopK(int topK) {
* many bits as there are vectors in the index
* @return an instance of this Builder
*/
public Builder withPrefilter(BitSet[] prefilters) {
public Builder withPrefilter(BitSet[] prefilters, int numDocs) {
this.prefilters = prefilters;
this.numDocs = numDocs;
return this;
}

Expand All @@ -152,7 +167,7 @@ public Builder withPrefilter(BitSet[] prefilters) {
* @return an instance of {@link BruteForceQuery}
*/
public BruteForceQuery build() {
return new BruteForceQuery(queryVectors, mapping, topK, prefilters);
return new BruteForceQuery(queryVectors, mapping, topK, prefilters, numDocs);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
Expand Down Expand Up @@ -183,15 +185,10 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
MemorySegment prefilterDataMemorySegment = MemorySegment.NULL;
BitSet[] prefilters = cuvsQuery.getPrefilters();
if (prefilters != null && prefilters.length > 0) {
BitSet concatenatedFilters = Util.concatenate(prefilters);
byte[] paddedFilterBytes = new byte[Util.roundUp(concatenatedFilters.toByteArray().length, 4)];
int pad = paddedFilterBytes.length - concatenatedFilters.toByteArray().length;
for (int i=0; i<concatenatedFilters.toByteArray().length; i++) {
paddedFilterBytes[pad+i] = concatenatedFilters.toByteArray()[i];
}
for (int i=0; i<pad; i++) paddedFilterBytes[i] = 0;
prefilterDataMemorySegment = Util.buildMemorySegment(resources.getArena(), paddedFilterBytes);
prefilterDataLength = paddedFilterBytes.length;
BitSet concatenatedFilters = Util.concatenate(prefilters, cuvsQuery.getNumDocs());
long filters[] = concatenatedFilters.toLongArray();
prefilterDataMemorySegment = Util.buildMemorySegment(resources.getArena(), filters);
prefilterDataLength = cuvsQuery.getNumDocs() * prefilters.length;
}

MemorySegment querySeg = Util.buildMemorySegment(resources.getArena(), cuvsQuery.getQueryVectors());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,30 +212,19 @@ public static MemorySegment buildMemorySegment(Arena arena, float[][] data) {
return dataMemorySegment;
}

public static BitSet concatenate(BitSet[] arr) {
// compute the size of the largest bitset
int size = -1;
for (BitSet b: arr) {
if (b == null) continue;
if (size == -1) size = Math.max(size, b.length());
}
if (size == -1) throw new RuntimeException("No valid bitset present in the array");
BitSet ret = new BitSet(size * arr.length);
public static BitSet concatenate(BitSet[] arr, int maxSizeOfEachBitSet) {
BitSet ret = new BitSet(maxSizeOfEachBitSet * arr.length);
for (int i=0; i<arr.length; i++) {
BitSet b = arr[i];
if (b == null || b.length() == 0) {
ret.set(i*size, (i+1)*size);
ret.set(i*maxSizeOfEachBitSet, (i+1)*maxSizeOfEachBitSet);
} else {
for (int j=0; j<size; j++) {
ret.set(i*size + j, b.get(j));
for (int j=0; j<maxSizeOfEachBitSet; j++) {
ret.set(i*maxSizeOfEachBitSet + j, b.get(j));
}
}
}
return ret;
}

public static int roundUp(int value, int multiplier) {
return (value + multiplier - 1) / multiplier * multiplier;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public void testIndexingAndSearchingFlow() throws Throwable {
BruteForceQuery cuvsQueryWithFiltering = new BruteForceQuery.Builder()
.withTopK(3)
.withQueryVectors(queries)
.withPrefilter(new BitSet[] {prefilter, prefilter, prefilter, prefilter})
.withPrefilter(new BitSet[] {prefilter, prefilter, prefilter, prefilter}, dataset.length)
.withMapping(map)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package com.nvidia.cuvs;

import java.lang.invoke.MethodHandles;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;

import org.junit.Before;
Expand Down Expand Up @@ -58,10 +60,23 @@ private void tmpResultsTopKWithRandomValues() throws Throwable {
int dimensions = random.nextInt(DIMENSIONS_LIMIT) + 1;
int numQueries = random.nextInt(NUM_QUERIES_LIMIT) + 1;
int topK = Math.min(random.nextInt(TOP_K_LIMIT) + 1, datasetSize);

boolean usePrefilter = random.nextBoolean();
if (datasetSize < topK)
datasetSize = topK;

BitSet[] prefilters = null;
if (usePrefilter) {
prefilters = new BitSet[numQueries];
for (int i=0; i<numQueries; i++) {
BitSet randomFilter = new BitSet(datasetSize);
for (int j=0; j<datasetSize; j++) {
randomFilter.set(j, random.nextBoolean());
}
prefilters[i] = randomFilter;
}
}


// Generate a random dataset
float[][] dataset = generateData(random, datasetSize, dimensions);

Expand Down Expand Up @@ -90,14 +105,15 @@ private void tmpResultsTopKWithRandomValues() throws Throwable {
assert topK > 0 && topK <= datasetSize : "Invalid topK value.";

// Generate expected results using brute force
List<List<Integer>> expected = generateExpectedResults(topK, dataset, queries, log);
List<List<Integer>> expected = generateExpectedResults(topK, dataset, queries, prefilters, log);

// Create CuVS index and query
try (CuVSResources resources = CuVSResources.create()) {

BruteForceQuery query = new BruteForceQuery.Builder()
.withTopK(topK)
.withQueryVectors(queries)
.withPrefilter(prefilters, dataset.length)
.build();

BruteForceIndexParams indexParams = new BruteForceIndexParams.Builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ private void tmpResultsTopKWithRandomValues() throws Throwable {
assert topK > 0 && topK <= datasetSize : "Invalid topK value.";

// Generate expected results using brute force
List<List<Integer>> expected = generateExpectedResults(topK, dataset, queries, log);
List<List<Integer>> expected = generateExpectedResults(topK, dataset, queries, null, log);

// Create CuVS index and query
try (CuVSResources resources = CuVSResources.create()) {
Expand Down
15 changes: 10 additions & 5 deletions java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
Expand Down Expand Up @@ -50,16 +51,21 @@ protected float[][] generateData(Random random, int rows, int cols) {
return data;
}

protected List<List<Integer>> generateExpectedResults(int topK, float[][] dataset, float[][] queries, Logger log) {
protected List<List<Integer>> generateExpectedResults(int topK, float[][] dataset, float[][] queries, BitSet[] prefilters, Logger log) {
List<List<Integer>> neighborsResult = new ArrayList<>();
int dimensions = dataset[0].length;

for (float[] query : queries) {
for (int q=0; q<queries.length; q++) {
float[] query = queries[q];
Map<Integer, Double> distances = new TreeMap<>();
for (int j = 0; j < dataset.length; j++) {
double distance = 0;
for (int k = 0; k < dimensions; k++) {
distance += (query[k] - dataset[j][k]) * (query[k] - dataset[j][k]);
if (prefilters != null && prefilters[q].get(j) == false) {
distance = Double.POSITIVE_INFINITY;
} else {
for (int k = 0; k < dimensions; k++) {
distance += (query[k] - dataset[j][k]) * (query[k] - dataset[j][k]);
}
}
distances.put(j, Math.sqrt(distance));
}
Expand All @@ -85,7 +91,6 @@ protected void compareResults(SearchResults results, List<List<Integer>> expecte
// actual vs. expected results
for (int i = 0; i < results.getResults().size(); i++) {
Map<Integer, Float> result = results.getResults().get(i);
assertEquals("TopK mismatch for query.", Math.min(topK, datasetSize), result.size());

// Sort result by values (distances) and extract keys
List<Integer> sortedResultKeys = result.entrySet().stream().sorted(Map.Entry.comparingByValue())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private void tmpResultsTopKWithRandomValues() throws Throwable {
assert topK > 0 && topK <= datasetSize : "Invalid topK value.";

// Generate expected results using brute force
List<List<Integer>> expected = generateExpectedResults(topK, dataset, queries, log);
List<List<Integer>> expected = generateExpectedResults(topK, dataset, queries, null, log);

// Create CuVS index and query
try (CuVSResources resources = CuVSResources.create()) {
Expand Down
16 changes: 5 additions & 11 deletions java/internal/src/cuvs_java.c
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ cuvsBruteForceIndex_t build_brute_force_index(float *dataset, long rows, long di
* @param[in] n_rows number of rows in the dataset
*/
void search_brute_force_index(cuvsBruteForceIndex_t index, float *queries, int topk, long n_queries, int dimensions,
cuvsResources_t cuvs_resources, int64_t *neighbors_h, float *distances_h, int *return_value, unsigned char *prefilter_data,
cuvsResources_t cuvs_resources, int64_t *neighbors_h, float *distances_h, int *return_value, uint32_t *prefilter_data,
long prefilter_data_length) {

cudaStream_t stream;
Expand Down Expand Up @@ -296,16 +296,10 @@ void search_brute_force_index(cuvsBruteForceIndex_t index, float *queries, int t
prefilter.addr = (uintptr_t)NULL;
} else {
// Parse the filters data
uint32_t *prefilters = (uint32_t*) malloc(prefilter_data_length);
int num_integers = prefilter_data_length / sizeof(uint32_t);
for (int i=0; i<num_integers; i++) {
*(prefilters + i) = (uint32_t)prefilter_data[4*i + 1] << 24 |
(uint32_t)prefilter_data[4*i + 0] << 16 |
(uint32_t)prefilter_data[4*i + 3] << 8 |
(uint32_t)prefilter_data[4*i + 2];
}
int64_t prefilter_shape[1] = {prefilter_data_length};
DLManagedTensor prefilter_tensor = prepare_tensor(prefilters, prefilter_shape, kDLUInt, 32, 1, kDLCUDA);
int num_integers = (prefilter_data_length+63)/64 * 2;
int extraPaddingByteExists = prefilter_data_length % 64 > 32? 0: 1;
int64_t prefilter_shape[1] = {(prefilter_data_length + 31) / 32};
DLManagedTensor prefilter_tensor = prepare_tensor(prefilter_data, prefilter_shape, kDLUInt, 32, 1, kDLCUDA);
prefilter.type = BITMAP;
prefilter.addr = (uintptr_t)&prefilter_tensor;
}
Expand Down

0 comments on commit 291f92e

Please sign in to comment.