Skip to content

Commit

Permalink
add range_search() to IndexRefine
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandr Guzhva <[email protected]>
  • Loading branch information
alexanderguzhva committed Dec 20, 2024
1 parent 5637bb8 commit 9b9b023
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
39 changes: 39 additions & 0 deletions faiss/IndexRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,45 @@ void IndexRefine::search(
}
}

void IndexRefine::range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params_in) const {
const IndexRefineSearchParameters* params = nullptr;
if (params_in) {
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
FAISS_THROW_IF_NOT_MSG(
params, "IndexRefine params have incorrect type");
}

SearchParameters* base_index_params =
(params != nullptr) ? params->base_index_params : nullptr;

base_index->range_search(n, x, radius, result, base_index_params);

#pragma omp parallel if (n > 1)
{
std::unique_ptr<DistanceComputer> dc(
refine_index->get_distance_computer());

#pragma omp for
for (idx_t i = 0; i < n; i++) {
dc->set_query(x + i * d);

// reevaluate distances
const size_t idx_start = result->lims[i];
const size_t idx_end = result->lims[i + 1];

for (size_t j = idx_start; j < idx_end; j++) {
const auto label = result->labels[j];
result->distances[j] = (*dc)(label);
}
}
}
}

void IndexRefine::reconstruct(idx_t key, float* recons) const {
refine_index->reconstruct(key, recons);
}
Expand Down
7 changes: 7 additions & 0 deletions faiss/IndexRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ struct IndexRefine : Index {
idx_t* labels,
const SearchParameters* params = nullptr) const override;

void range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params = nullptr) const override;

// reconstruct is routed to the refine_index
void reconstruct(idx_t key, float* recons) const override;

Expand Down

0 comments on commit 9b9b023

Please sign in to comment.