From 9b9b0230d15ad0a0424b16d052105f2c827b5092 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 7 Nov 2024 20:14:31 -0500 Subject: [PATCH] add range_search() to IndexRefine Signed-off-by: Alexandr Guzhva --- faiss/IndexRefine.cpp | 39 +++++++++++++++++++++++++++++++++++++++ faiss/IndexRefine.h | 7 +++++++ 2 files changed, 46 insertions(+) diff --git a/faiss/IndexRefine.cpp b/faiss/IndexRefine.cpp index 8bc429a5e9..6f1f588e2e 100644 --- a/faiss/IndexRefine.cpp +++ b/faiss/IndexRefine.cpp @@ -166,6 +166,45 @@ void IndexRefine::search( } } +void IndexRefine::range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params_in) const { + const IndexRefineSearchParameters* params = nullptr; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG( + params, "IndexRefine params have incorrect type"); + } + + SearchParameters* base_index_params = + (params != nullptr) ? params->base_index_params : nullptr; + + base_index->range_search(n, x, radius, result, base_index_params); + +#pragma omp parallel if (n > 1) + { + std::unique_ptr dc( + refine_index->get_distance_computer()); + +#pragma omp for + for (idx_t i = 0; i < n; i++) { + dc->set_query(x + i * d); + + // reevaluate distances + const size_t idx_start = result->lims[i]; + const size_t idx_end = result->lims[i + 1]; + + for (size_t j = idx_start; j < idx_end; j++) { + const auto label = result->labels[j]; + result->distances[j] = (*dc)(label); + } + } + } +} + void IndexRefine::reconstruct(idx_t key, float* recons) const { refine_index->reconstruct(key, recons); } diff --git a/faiss/IndexRefine.h b/faiss/IndexRefine.h index 9ad4e4be29..255271695f 100644 --- a/faiss/IndexRefine.h +++ b/faiss/IndexRefine.h @@ -54,6 +54,13 @@ struct IndexRefine : Index { idx_t* labels, const SearchParameters* params = nullptr) const override; + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params = nullptr) const override; + // reconstruct is routed to the refine_index void reconstruct(idx_t key, float* recons) const override;