From 291f92e18557e17f81b36fd0251fe84f47640766 Mon Sep 17 00:00:00 2001 From: Ishan Chattopadhyaya Date: Wed, 19 Feb 2025 22:53:27 +0530 Subject: [PATCH] Avoiding extra copy, adding randomized testing for Brute Force with prefiltering --- .../java/com/nvidia/cuvs/BruteForceQuery.java | 21 ++++++++++++++++--- .../cuvs/internal/BruteForceIndexImpl.java | 15 ++++++------- .../com/nvidia/cuvs/internal/common/Util.java | 21 +++++-------------- .../nvidia/cuvs/BruteForceAndSearchIT.java | 2 +- .../nvidia/cuvs/BruteForceRandomizedIT.java | 20 ++++++++++++++++-- .../com/nvidia/cuvs/CagraRandomizedIT.java | 2 +- .../java/com/nvidia/cuvs/CuVSTestCase.java | 15 ++++++++----- .../com/nvidia/cuvs/HnswRandomizedIT.java | 2 +- java/internal/src/cuvs_java.c | 16 +++++--------- 9 files changed, 65 insertions(+), 49 deletions(-) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java index 76c3c2502..019e27dcd 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java @@ -30,6 +30,7 @@ public class BruteForceQuery { private List mapping; private float[][] queryVectors; private BitSet[] prefilters; + private int numDocs = -1; private int topK; /** @@ -41,12 +42,15 @@ public class BruteForceQuery { * @param topK the top k results to return * @param prefilter the prefilter data to use while searching the BRUTEFORCE * index + * @param numDocs Maximum of bits in each prefilter, representing number of documents in this index. + * Used only when prefilter(s) is/are passed. */ - public BruteForceQuery(float[][] queryVectors, List mapping, int topK, BitSet[] prefilters) { + public BruteForceQuery(float[][] queryVectors, List mapping, int topK, BitSet[] prefilters, int numDocs) { this.queryVectors = queryVectors; this.mapping = mapping; this.topK = topK; this.prefilters = prefilters; + this.numDocs = numDocs; } /** @@ -85,6 +89,15 @@ public BitSet[] getPrefilters() { return prefilters; } + /** + * Gets the number of documents supposed to be in this index, as used for prefilters + * + * @return number of documents as an integer + */ + public int getNumDocs() { + return numDocs; + } + @Override public String toString() { return "BruteForceQuery [mapping=" + mapping + ", queryVectors=" + Arrays.toString(queryVectors) + ", prefilter=" @@ -98,6 +111,7 @@ public static class Builder { private float[][] queryVectors; private BitSet[] prefilters; + private int numDocs; private List mapping; private int topK = 2; @@ -141,8 +155,9 @@ public Builder withTopK(int topK) { * many bits as there are vectors in the index * @return an instance of this Builder */ - public Builder withPrefilter(BitSet[] prefilters) { + public Builder withPrefilter(BitSet[] prefilters, int numDocs) { this.prefilters = prefilters; + this.numDocs = numDocs; return this; } @@ -152,7 +167,7 @@ public Builder withPrefilter(BitSet[] prefilters) { * @return an instance of {@link BruteForceQuery} */ public BruteForceQuery build() { - return new BruteForceQuery(queryVectors, mapping, topK, prefilters); + return new BruteForceQuery(queryVectors, mapping, topK, prefilters, numDocs); } } } diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java index 7b39579ca..decb13133 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java @@ -26,6 +26,8 @@ import java.lang.foreign.MemorySegment; import java.lang.foreign.SequenceLayout; import java.lang.invoke.MethodHandle; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; @@ -183,15 +185,10 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable { MemorySegment prefilterDataMemorySegment = MemorySegment.NULL; BitSet[] prefilters = cuvsQuery.getPrefilters(); if (prefilters != null && prefilters.length > 0) { - BitSet concatenatedFilters = Util.concatenate(prefilters); - byte[] paddedFilterBytes = new byte[Util.roundUp(concatenatedFilters.toByteArray().length, 4)]; - int pad = paddedFilterBytes.length - concatenatedFilters.toByteArray().length; - for (int i=0; i 0 && topK <= datasetSize : "Invalid topK value."; // Generate expected results using brute force - List> expected = generateExpectedResults(topK, dataset, queries, log); + List> expected = generateExpectedResults(topK, dataset, queries, prefilters, log); // Create CuVS index and query try (CuVSResources resources = CuVSResources.create()) { @@ -98,6 +113,7 @@ private void tmpResultsTopKWithRandomValues() throws Throwable { BruteForceQuery query = new BruteForceQuery.Builder() .withTopK(topK) .withQueryVectors(queries) + .withPrefilter(prefilters, dataset.length) .build(); BruteForceIndexParams indexParams = new BruteForceIndexParams.Builder() diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedIT.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedIT.java index 811866dc6..34665d3eb 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedIT.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedIT.java @@ -91,7 +91,7 @@ private void tmpResultsTopKWithRandomValues() throws Throwable { assert topK > 0 && topK <= datasetSize : "Invalid topK value."; // Generate expected results using brute force - List> expected = generateExpectedResults(topK, dataset, queries, log); + List> expected = generateExpectedResults(topK, dataset, queries, null, log); // Create CuVS index and query try (CuVSResources resources = CuVSResources.create()) { diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java index 844be1985..e3db20ff0 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java @@ -21,6 +21,7 @@ import java.lang.invoke.MethodHandles; import java.util.ArrayList; +import java.util.BitSet; import java.util.List; import java.util.Map; import java.util.Random; @@ -50,16 +51,21 @@ protected float[][] generateData(Random random, int rows, int cols) { return data; } - protected List> generateExpectedResults(int topK, float[][] dataset, float[][] queries, Logger log) { + protected List> generateExpectedResults(int topK, float[][] dataset, float[][] queries, BitSet[] prefilters, Logger log) { List> neighborsResult = new ArrayList<>(); int dimensions = dataset[0].length; - for (float[] query : queries) { + for (int q=0; q distances = new TreeMap<>(); for (int j = 0; j < dataset.length; j++) { double distance = 0; - for (int k = 0; k < dimensions; k++) { - distance += (query[k] - dataset[j][k]) * (query[k] - dataset[j][k]); + if (prefilters != null && prefilters[q].get(j) == false) { + distance = Double.POSITIVE_INFINITY; + } else { + for (int k = 0; k < dimensions; k++) { + distance += (query[k] - dataset[j][k]) * (query[k] - dataset[j][k]); + } } distances.put(j, Math.sqrt(distance)); } @@ -85,7 +91,6 @@ protected void compareResults(SearchResults results, List> expecte // actual vs. expected results for (int i = 0; i < results.getResults().size(); i++) { Map result = results.getResults().get(i); - assertEquals("TopK mismatch for query.", Math.min(topK, datasetSize), result.size()); // Sort result by values (distances) and extract keys List sortedResultKeys = result.entrySet().stream().sorted(Map.Entry.comparingByValue()) diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/HnswRandomizedIT.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/HnswRandomizedIT.java index 6d367efc4..07033391b 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/HnswRandomizedIT.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/HnswRandomizedIT.java @@ -97,7 +97,7 @@ private void tmpResultsTopKWithRandomValues() throws Throwable { assert topK > 0 && topK <= datasetSize : "Invalid topK value."; // Generate expected results using brute force - List> expected = generateExpectedResults(topK, dataset, queries, log); + List> expected = generateExpectedResults(topK, dataset, queries, null, log); // Create CuVS index and query try (CuVSResources resources = CuVSResources.create()) { diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c index 220169b8b..de29249ae 100644 --- a/java/internal/src/cuvs_java.c +++ b/java/internal/src/cuvs_java.c @@ -266,7 +266,7 @@ cuvsBruteForceIndex_t build_brute_force_index(float *dataset, long rows, long di * @param[in] n_rows number of rows in the dataset */ void search_brute_force_index(cuvsBruteForceIndex_t index, float *queries, int topk, long n_queries, int dimensions, - cuvsResources_t cuvs_resources, int64_t *neighbors_h, float *distances_h, int *return_value, unsigned char *prefilter_data, + cuvsResources_t cuvs_resources, int64_t *neighbors_h, float *distances_h, int *return_value, uint32_t *prefilter_data, long prefilter_data_length) { cudaStream_t stream; @@ -296,16 +296,10 @@ void search_brute_force_index(cuvsBruteForceIndex_t index, float *queries, int t prefilter.addr = (uintptr_t)NULL; } else { // Parse the filters data - uint32_t *prefilters = (uint32_t*) malloc(prefilter_data_length); - int num_integers = prefilter_data_length / sizeof(uint32_t); - for (int i=0; i 32? 0: 1; + int64_t prefilter_shape[1] = {(prefilter_data_length + 31) / 32}; + DLManagedTensor prefilter_tensor = prepare_tensor(prefilter_data, prefilter_shape, kDLUInt, 32, 1, kDLCUDA); prefilter.type = BITMAP; prefilter.addr = (uintptr_t)&prefilter_tensor; }