Skip to content

Commit

Permalink
synchronize threading
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Feb 25, 2025
1 parent a2a6a67 commit 1e4d2e7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 27 deletions.
14 changes: 6 additions & 8 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ class hnsw_lib : public algo<T> {
// Parameters controlling HNSW graph construction.
struct build_param {
  int m;                // max number of graph connections per node (hnswlib "M")
  int ef_construction;  // size of the candidate list during index build
  // Default to one build thread per available OpenMP thread.
  // (Post-commit state: the stale duplicate `omp_get_num_procs()` line,
  // a diff-rendering artifact, is removed — a struct cannot declare the
  // same member twice.)
  int num_threads = omp_get_max_threads();
};

using search_param_base = typename algo<T>::search_param;
// Parameters controlling HNSW search; extends the algo-wide base params.
struct search_param : public search_param_base {
  int ef;  // size of the candidate list during search (hnswlib "ef")
  // Default to one search thread per available OpenMP thread.
  // (Post-commit state: the stale duplicate `num_threads = 1` line, a
  // diff-rendering artifact, is removed — a struct cannot declare the
  // same member twice.)
  int num_threads = omp_get_max_threads();
};

hnsw_lib(Metric metric, int dim, const build_param& param);
Expand Down Expand Up @@ -175,12 +175,7 @@ void hnsw_lib<T>::set_search_param(const search_param_base& param_, const void*
auto param = dynamic_cast<const search_param&>(param_);
appr_alg_->ef_ = param.ef;
num_threads_ = param.num_threads;
// bench_mode_ = param.metric_objective;
bench_mode_ = Mode::kLatency; // TODO(achirkin): pass the benchmark mode in the algo parameters

// Create a pool if multiple query threads have been set and the pool hasn't been created already
bool create_pool = (bench_mode_ == Mode::kLatency && num_threads_ > 1 && !thread_pool_);
if (create_pool) { thread_pool_ = std::make_shared<fixed_thread_pool>(num_threads_); }
}

template <typename T>
Expand All @@ -192,7 +187,10 @@ void hnsw_lib<T>::search(
get_search_knn_results(query + i * dim_, k, indices + i * k, distances + i * k);
};
if (bench_mode_ == Mode::kLatency && num_threads_ > 1) {
thread_pool_->submit(f, batch_size);
#pragma omp parallel for num_threads(num_threads_)
for (int i = 0; i < batch_size; i++) {
f(i);
}
} else {
for (int i = 0; i < batch_size; i++) {
f(i);
Expand Down
25 changes: 7 additions & 18 deletions cpp/src/neighbors/detail/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,26 +489,15 @@ void search(raft::resources const& res,
auto const* hnswlib_index =
reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*>(
idx.get_index());
auto num_threads = params.num_threads == 0 ? omp_get_max_threads() : params.num_threads;

// when num_threads == 0, automatically maximize parallelism
if (params.num_threads) {
#pragma omp parallel for num_threads(params.num_threads)
for (int64_t i = 0; i < queries.extent(0); ++i) {
get_search_knn_results(hnswlib_index,
queries.data_handle() + i * queries.extent(1),
neighbors.extent(1),
neighbors.data_handle() + i * neighbors.extent(1),
distances.data_handle() + i * distances.extent(1));
}
} else {
#pragma omp parallel for
for (int64_t i = 0; i < queries.extent(0); ++i) {
get_search_knn_results(hnswlib_index,
queries.data_handle() + i * queries.extent(1),
neighbors.extent(1),
neighbors.data_handle() + i * neighbors.extent(1),
distances.data_handle() + i * distances.extent(1));
}
for (int64_t i = 0; i < queries.extent(0); ++i) {
get_search_knn_results(hnswlib_index,
queries.data_handle() + i * queries.extent(1),
neighbors.extent(1),
neighbors.data_handle() + i * neighbors.extent(1),
distances.data_handle() + i * distances.extent(1));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ groups:
graph_degree: [32, 64, 96, 128]
intermediate_graph_degree: [32, 64, 96, 128]
graph_build_algo: ["NN_DESCENT"]
hierarchy: ["none", "cpu", "gpu"]
ef_construction: [64, 128, 256, 512]
search:
ef: [10, 20, 40, 60, 80, 120, 200, 400, 600, 800]
Expand Down

0 comments on commit 1e4d2e7

Please sign in to comment.