Fix a crash when using different threads (#680)
Using a pool memory manager was causing a crash when different threads were used. Modified a test so that it sometimes runs in parallel.

Co-authored-by: Vivek Narang <[email protected]>
2 people authored and rhdong committed Feb 19, 2025
1 parent d4366ce commit 88b018d
Showing 2 changed files with 117 additions and 64 deletions.
179 changes: 117 additions & 62 deletions java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java
@@ -21,29 +21,38 @@

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.lang.invoke.MethodHandles;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.carrotsearch.randomizedtesting.RandomizedRunner;
import com.nvidia.cuvs.CagraIndexParams.CagraGraphBuildAlgo;
import com.nvidia.cuvs.CagraIndexParams.CuvsDistanceType;

@RunWith(RandomizedRunner.class)
public class CagraBuildAndSearchIT extends CuVSTestCase {

private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

@Before
public void setup() {
assumeTrue("not supported on " + System.getProperty("os.name"), isLinuxAmd64());
initializeRandom();
log.info("Random context initialized for test.");
}

/**
@@ -60,14 +69,14 @@ public void testIndexingAndSearchingFlow() throws Throwable {
{ 0.03902049f, 0.9689629f },
{ 0.92514056f, 0.4463501f },
{ 0.6673192f, 0.10993068f }
};
List<Integer> map = List.of(0, 1, 2, 3);
float[][] queries = {
{ 0.48216683f, 0.0428398f },
{ 0.5084142f, 0.6545497f },
{ 0.51260436f, 0.2643005f },
{ 0.05198065f, 0.5789965f }
};

// Expected search results
List<Map<Integer, Float>> expectedResults = Arrays.asList(
@@ -76,68 +85,114 @@
Map.of(3, 0.047766715f, 2, 0.20332818f, 0, 0.48305473f),
Map.of(1, 0.15224178f, 0, 0.59063464f, 3, 0.5986642f));

for (int j = 0; j < 10; j++) {

try (CuVSResources resources = CuVSResources.create()) {

// Configure index parameters
CagraIndexParams indexParams = new CagraIndexParams.Builder()
.withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT)
.withGraphDegree(1)
.withIntermediateGraphDegree(2)
.withNumWriterThreads(32)
.withMetric(CuvsDistanceType.L2Expanded)
.build();

// Create the index with the dataset
CagraIndex index = CagraIndex.newBuilder(resources)
.withDataset(dataset)
.withIndexParams(indexParams)
.build();

// Saving the index on to the disk.
String indexFileName = UUID.randomUUID().toString() + ".cag";
index.serialize(new FileOutputStream(indexFileName));

// Loading a CAGRA index from disk.
File indexFile = new File(indexFileName);
InputStream inputStream = new FileInputStream(indexFile);
CagraIndex loadedIndex = CagraIndex.newBuilder(resources)
.from(inputStream)
.build();

// Configure search parameters
CagraSearchParams searchParams = new CagraSearchParams.Builder(resources)
.build();

// Create a query object with the query vectors
CagraQuery cuvsQuery = new CagraQuery.Builder()
.withTopK(3)
.withSearchParams(searchParams)
.withQueryVectors(queries)
.withMapping(map)
.build();

// Perform the search
SearchResults results = index.search(cuvsQuery);

// Check results
log.info(results.getResults().toString());
assertEquals(expectedResults, results.getResults());

// Search from deserialized index
results = loadedIndex.search(cuvsQuery);

// Check results
log.info(results.getResults().toString());
assertEquals(expectedResults, results.getResults());

// Cleanup
if (indexFile.exists()) {
indexFile.delete();
int numTestsRuns = 10;

try (CuVSResources resources = CuVSResources.create()) {
// sometimes run this test using different threads?
boolean runTestInDifferentThreads = random.nextBoolean();
// if running in different threads, run concurrently or one after the other?
boolean runConcurrently = runTestInDifferentThreads ? random.nextBoolean() : false;

log.info("Running in different threads? " + runTestInDifferentThreads);
log.info("Running concurrently? " + runConcurrently);

ExecutorService parallelExecutor = runConcurrently ? Executors.newFixedThreadPool(numTestsRuns) : null;

for (int j = 0; j < numTestsRuns; j++) {
Runnable testLogic = indexAndQueryOnce(dataset, map, queries, expectedResults, resources);
if (runTestInDifferentThreads) {
if (runConcurrently) {
parallelExecutor.submit(testLogic);
} else {
ExecutorService singleExecutor = Executors.newSingleThreadExecutor();
singleExecutor.submit(testLogic);
singleExecutor.shutdown();
singleExecutor.awaitTermination(2000, TimeUnit.SECONDS);
}
} else {
// run the test logic in the main thread
testLogic.run();
}
index.destroyIndex();
}
if (parallelExecutor != null) {
parallelExecutor.shutdown();
parallelExecutor.awaitTermination(2000, TimeUnit.SECONDS);
}

}
}

private Runnable indexAndQueryOnce(float[][] dataset, List<Integer> map, float[][] queries,
List<Map<Integer, Float>> expectedResults, CuVSResources resources) throws Throwable, FileNotFoundException {

Runnable thread = new Runnable() {

@Override
public void run() {
try {

// Configure index parameters
CagraIndexParams indexParams = new CagraIndexParams.Builder()
.withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT)
.withGraphDegree(1)
.withIntermediateGraphDegree(2)
.withNumWriterThreads(32)
.withMetric(CuvsDistanceType.L2Expanded)
.build();

// Create the index with the dataset
CagraIndex index = CagraIndex.newBuilder(resources)
.withDataset(dataset)
.withIndexParams(indexParams)
.build();

// Saving the index on to the disk.
String indexFileName = UUID.randomUUID().toString() + ".cag";
index.serialize(new FileOutputStream(indexFileName));

// Loading a CAGRA index from disk.
File indexFile = new File(indexFileName);
InputStream inputStream = new FileInputStream(indexFile);
CagraIndex loadedIndex = CagraIndex.newBuilder(resources)
.from(inputStream)
.build();

// Configure search parameters
CagraSearchParams searchParams = new CagraSearchParams.Builder(resources)
.build();

// Create a query object with the query vectors
CagraQuery cuvsQuery = new CagraQuery.Builder()
.withTopK(3)
.withSearchParams(searchParams)
.withQueryVectors(queries)
.withMapping(map)
.build();

// Perform the search
SearchResults results = index.search(cuvsQuery);

// Check results
log.info(results.getResults().toString());
assertEquals(expectedResults, results.getResults());

// Search from deserialized index
results = loadedIndex.search(cuvsQuery);

// Check results
log.info(results.getResults().toString());
assertEquals(expectedResults, results.getResults());

// Cleanup
if (indexFile.exists()) {
indexFile.delete();
}
index.destroyIndex();
} catch (Throwable ex) {
throw new RuntimeException("Exception during indexing/querying", ex);
}
}
};
return thread;
}
}
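
One detail of the new concurrent path worth keeping in mind: ExecutorService.submit(Runnable) captures anything the task throws inside the returned Future, so an assertion failure in a worker run does not by itself fail the test. Below is a minimal, hypothetical sketch of how the submitted runs could be verified by collecting the futures and calling get() on each; the names pool and futures and the placeholder task body are illustrative and not part of this patch.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

public class ConcurrentRunCheckSketch {

  public static void main(String[] args) throws Exception {
    ExecutorService pool = Executors.newFixedThreadPool(4);
    List<Future<?>> futures = new ArrayList<>();

    for (int i = 0; i < 10; i++) {
      // Placeholder for the Runnable returned by indexAndQueryOnce(...).
      Runnable testLogic = () -> { /* build index, search, assert results */ };
      futures.add(pool.submit(testLogic));
    }

    pool.shutdown();
    pool.awaitTermination(2000, TimeUnit.SECONDS);

    // Future.get() rethrows whatever the task threw, so a failed run surfaces
    // here instead of being silently swallowed by the executor.
    for (Future<?> f : futures) {
      try {
        f.get();
      } catch (ExecutionException e) {
        throw new AssertionError("worker run failed", e.getCause());
      }
    }
  }
}

ExecutorService.invokeAll would achieve the same effect by returning all futures in a single call.
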
2 changes: 0 additions & 2 deletions java/internal/src/cuvs_java.c
@@ -96,7 +96,6 @@ cuvsCagraIndex_t build_cagra_index(float *dataset, long rows, long dimensions, c
cuvsStreamGet(cuvs_resources, &stream);

omp_set_num_threads(n_writer_threads);
cuvsRMMPoolMemoryResourceEnable(95, 95, false);

int64_t dataset_shape[2] = {rows, dimensions};
DLManagedTensor dataset_tensor = prepare_tensor(dataset, dataset_shape, kDLFloat, 32, 2, kDLCUDA);
@@ -226,7 +225,6 @@ cuvsBruteForceIndex_t build_brute_force_index(float *dataset, long rows, long di
int *return_value, int n_writer_threads) {

omp_set_num_threads(n_writer_threads);
cuvsRMMPoolMemoryResourceEnable(95, 95, false);

cudaStream_t stream;
cuvsStreamGet(cuvs_resources, &stream);
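
On the native side, the fix removes the cuvsRMMPoolMemoryResourceEnable(95, 95, false) call from both build functions, since re-enabling the RMM pool memory resource on every build, potentially from several threads at once, matches the crash described in the commit message. If a pool memory resource were still wanted, a minimal sketch under stated assumptions would be to enable it exactly once per process, before any worker thread builds an index. The header path, return type, and parameter meanings below are assumptions; the (95, 95, false) arguments are simply copied from the removed lines, and whether a single process-wide pool would avoid the crash is not established by this commit.

#include <pthread.h>
#include <stdbool.h>

#include <cuvs/core/c_api.h>   /* assumed header for cuvsRMMPoolMemoryResourceEnable */

/* Guard so the pool is enabled at most once, even if several threads race here. */
static pthread_once_t pool_once = PTHREAD_ONCE_INIT;

static void enable_pool_once(void) {
  /* 95% initial / 95% maximum pool size, no managed memory (assumed meanings,
     copied from the removed calls). */
  cuvsRMMPoolMemoryResourceEnable(95, 95, false);
}

/* Each build entry point would call this instead of enabling the pool directly. */
static void ensure_pool_enabled(void) {
  pthread_once(&pool_once, enable_pool_once);
}
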
