Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable refine for FAISS HNSW indices by default, if available #932

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 38 additions & 25 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -551,24 +551,32 @@ class FaissHnswIterator : public IndexIterator {
// wrap a sign, if needed
workspace.qdis = std::unique_ptr<faiss::DistanceComputer>(storage_distance_computer(index_hnsw));

// a tricky point here.
// Basically, if out hnsw index's storage is HasInverseL2Norms, then
// this is a cosine index. But because refine always keeps original
// data, then we need to use a wrapper over a distance computer
const faiss::HasInverseL2Norms* has_l2_norms =
dynamic_cast<const faiss::HasInverseL2Norms*>(index_hnsw->storage);
if (has_l2_norms != nullptr) {
// add a cosine wrapper over it
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(new faiss::WithCosineNormDistanceComputer(
has_l2_norms->get_inverse_l2_norms(), index->d,
std::unique_ptr<faiss::DistanceComputer>(index_refine->refine_index->get_distance_computer())));
if (refine_ratio != 0) {
// the refine is needed

// a tricky point here.
// Basically, if out hnsw index's storage is HasInverseL2Norms, then
// this is a cosine index. But because refine always keeps original
// data, then we need to use a wrapper over a distance computer
const faiss::HasInverseL2Norms* has_l2_norms =
dynamic_cast<const faiss::HasInverseL2Norms*>(index_hnsw->storage);
if (has_l2_norms != nullptr) {
// add a cosine wrapper over it
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(new faiss::WithCosineNormDistanceComputer(
has_l2_norms->get_inverse_l2_norms(), index->d,
std::unique_ptr<faiss::DistanceComputer>(
index_refine->refine_index->get_distance_computer())));
} else {
// use it as is
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(index_refine->refine_index->get_distance_computer());
}
} else {
// use it as is
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(index_refine->refine_index->get_distance_computer());
// the refine is not needed
workspace.qdis_refine = nullptr;
}
} else {
const faiss::IndexHNSW* index_hnsw = dynamic_cast<const faiss::IndexHNSW*>(index.get());
Expand Down Expand Up @@ -882,9 +890,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
return expected<DataSetPtr>::Err(Status::invalid_args, "k parameter is missing");
}

// whether a user wants a refine
const bool whether_to_enable_refine = hnsw_cfg.refine_k.has_value();

// set up an index wrapper
auto [index_wrapper, is_refined] =
create_conditional_hnsw_wrapper(index.get(), hnsw_cfg, whether_bf_search.value_or(false));
auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper(
index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine);

if (index_wrapper == nullptr) {
return expected<DataSetPtr>::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW");
Expand Down Expand Up @@ -1011,9 +1022,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
return expected<DataSetPtr>::Err(Status::invalid_args, "ef parameter is missing");
}

// whether a user wants a refine
const bool whether_to_enable_refine = true;

// set up an index wrapper
auto [index_wrapper, is_refined] =
create_conditional_hnsw_wrapper(index.get(), hnsw_cfg, whether_bf_search.value_or(false));
auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper(
index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine);

if (index_wrapper == nullptr) {
return expected<DataSetPtr>::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW");
Expand Down Expand Up @@ -1229,11 +1243,10 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
throw;
}

//
const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr);

const float iterator_refine_ratio =
(dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr)
? hnsw_cfg.iterator_refine_ratio.value_or(0.5)
: 0;
should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0;

// create an iterator and initialize it
auto it =
Expand Down
8 changes: 6 additions & 2 deletions src/index/hnsw/faiss_hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class FaissHnswConfig : public BaseHnswConfig {
.for_static();
KNOWHERE_CONFIG_DECLARE_FIELD(refine_k)
.description("refine k")
.allow_empty_without_default()
.set_default(1)
.set_range(1, std::numeric_limits<CFG_FLOAT::value_type>::max())
.for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(refine_type)
Expand Down Expand Up @@ -83,7 +83,7 @@ class FaissHnswFlatConfig : public FaissHnswConfig {
// check our parameters
if (param_type == PARAM_TYPE::TRAIN) {
// prohibit refine
if (refine.value_or(false) || refine_type.has_value() || refine_k.has_value()) {
if (refine.value_or(false) || refine_type.has_value()) {
if (err_msg) {
*err_msg = "refine is not supported for this index";
LOG_KNOWHERE_ERROR_ << *err_msg;
Expand Down Expand Up @@ -189,6 +189,8 @@ class FaissHnswPqConfig : public FaissHnswConfig {
}
}
}
default:
break;
}
return Status::success;
}
Expand Down Expand Up @@ -232,6 +234,8 @@ class FaissHnswPrqConfig : public FaissHnswConfig {
}
}
}
default:
break;
}
return Status::success;
}
Expand Down
10 changes: 7 additions & 3 deletions src/index/hnsw/impl/IndexConditionalWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,13 @@ WhetherPerformBruteForceRangeSearch(const faiss::Index* index, const FaissHnswCo
return false;
}

// returns nullptr in case of invalid index
// returns nullptr in case of invalid index.
//
// `whether_to_enable_refine` allows to enable the refine for the search if the
// index was trained with the refine.
std::tuple<std::unique_ptr<faiss::Index>, bool>
create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search) {
create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search,
const bool whether_to_enable_refine) {
const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), knowhere::metric::COSINE);

// check if we have a refine available.
Expand Down Expand Up @@ -126,7 +130,7 @@ create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw
}

// check if a user wants a refined result
if (hnsw_cfg.refine_k.has_value()) {
if (whether_to_enable_refine) {
// yes, a user wants to perform a refine

// thus, we need to define a new refine index and pass
Expand Down
8 changes: 6 additions & 2 deletions src/index/hnsw/impl/IndexConditionalWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ std::optional<bool>
WhetherPerformBruteForceRangeSearch(const faiss::Index* index, const FaissHnswConfig& cfg, const BitsetView& bitset);

// first return arg: returns nullptr in case of invalid index
// second return arg: returns whether an index does refine
// second return arg: returns whether an index does the refine
//
// `whether_to_enable_refine` allows to enable the refine for the search if the
// index was trained with the refine.
std::tuple<std::unique_ptr<faiss::Index>, bool>
create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search);
create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search,
const bool whether_to_enable_refine);

} // namespace knowhere
29 changes: 11 additions & 18 deletions tests/ut/test_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,21 +197,18 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search using iterator") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
Expand Down Expand Up @@ -286,20 +283,18 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
Expand Down Expand Up @@ -334,20 +329,18 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator insufficient results") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
Expand Down
41 changes: 41 additions & 0 deletions thirdparty/faiss/faiss/IndexRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,47 @@ void IndexRefine::search(
}
}

void IndexRefine::range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params_in) const
{
const IndexRefineSearchParameters* params = nullptr;
if (params_in) {
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
FAISS_THROW_IF_NOT_MSG(
params, "IndexRefine params have incorrect type");
}

SearchParameters* base_index_params =
(params != nullptr) ? params->base_index_params : nullptr;

base_index->range_search(
n, x, radius, result, base_index_params);

#pragma omp parallel if (n > 1)
{
std::unique_ptr<DistanceComputer> dc(
refine_index->get_distance_computer());

#pragma omp for
for (idx_t i = 0; i < n; i++) {
dc->set_query(x + i * d);

// reevaluate distances
const size_t idx_start = result->lims[i];
const size_t idx_end = result->lims[i + 1];

for (size_t j = idx_start; j < idx_end; j++) {
const auto label = result->labels[j];
result->distances[j] = (*dc)(label);
}
}
}
}

void IndexRefine::reconstruct(idx_t key, float* recons) const {
refine_index->reconstruct(key, recons);
}
Expand Down
7 changes: 7 additions & 0 deletions thirdparty/faiss/faiss/IndexRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ struct IndexRefine : Index {
idx_t* labels,
const SearchParameters* params = nullptr) const override;

void range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params = nullptr) const override;

// reconstruct is routed to the refine_index
void reconstruct(idx_t key, float* recons) const override;

Expand Down
Loading