Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Feb 12, 2025
1 parent 368e43b commit a06b416
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/index/data_view_dense_index/data_view_index_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ AdaptToBaseIndexConfig(Config* cfg, PARAM_TYPE param_type, size_t dim) {
auto reorder_k = int(base_cfg->k.value() * base_cfg->refine_ratio.value());
base_cfg->k = reorder_k;
base_cfg->reorder_k = reorder_k;
base_cfg->ensure_topk_full = true;
break;
}
case PARAM_TYPE::RANGE_SEARCH: {
Expand Down
15 changes: 15 additions & 0 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,21 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSetPtr dataset, std::unique_
faiss::IVFSearchParameters base_search_params;
base_search_params.sel = id_selector;
base_search_params.nprobe = nprobe;
base_search_params.ensure_topk_full = ivf_cfg.ensure_topk_full.value();
if (base_search_params.ensure_topk_full) {
if (auto base_index_ptr = reinterpret_cast<faiss::IndexIVFPQFastScan*>(index_->base_index)) {
auto nlist = base_index_ptr->nlist;
base_search_params.nprobe = nlist;
// use max_codes to early termination
base_search_params.max_codes = (nprobe * 1.0 / nlist) * (index_->ntotal - bitset.count());
base_search_params.max_lists_num = nprobe;
} else {
throw std::runtime_error("invalid base index type of scann base index");
}
} else {
base_search_params.nprobe = nprobe;
base_search_params.max_codes = 0;
}

faiss::IndexScaNNSearchParameters scann_search_params;
scann_search_params.base_index_params = &base_search_params;
Expand Down
5 changes: 3 additions & 2 deletions tests/ut/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License

knowhere_file_glob(GLOB_RECURSE KNOWHERE_UT_SRCS *.cc)
knowhere_file_glob(GLOB_RECURSE KNOWHERE_UT_SRCS test_perf.cc)

if(NOT WITH_DISKANN)
knowhere_file_glob(GLOB_RECURSE KNOWHERE_DISKANN_TESTS test_diskann.cc)
Expand Down Expand Up @@ -46,4 +46,5 @@ target_link_libraries(knowhere_tests PRIVATE
Catch2::Catch2WithMain
atomic
stdc++fs
knowhere)
knowhere
profiler)
70 changes: 70 additions & 0 deletions tests/ut/test_data_view_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,76 @@ TEST_CASE("Test SCANN with data view refiner", "[float metrics]") {
}
}

TEST_CASE("Ensure topk test", "[float metrics]") {
using Catch::Approx;
auto version = GenTestVersionList();
if (!faiss::support_pq_fast_scan) {
SKIP("pass scann test");
}

const int64_t nb = 10000, nq = 10;
auto metric = GENERATE(as<std::string>{}, knowhere::metric::COSINE, knowhere::metric::IP, knowhere::metric::L2);
auto topk = nb;
auto dim = GENERATE(as<int64_t>{}, 120);

auto base_gen = [=]() {
knowhere::Json json;
json[knowhere::meta::DIM] = dim;
json[knowhere::meta::METRIC_TYPE] = metric;
json[knowhere::meta::TOPK] = topk;
json[knowhere::meta::RADIUS] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 10.0 : 0.99;
json[knowhere::meta::RANGE_FILTER] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 0.0 : 1.01;
return json;
};

auto scann_gen = [base_gen, topk]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 512;
json[knowhere::indexparam::NPROBE] = 1;
json[knowhere::indexparam::REFINE_RATIO] = 1.0;
json[knowhere::indexparam::SUB_DIM] = 2;
json[knowhere::indexparam::WITH_RAW_DATA] = true;
json[knowhere::indexparam::ENSURE_TOPK_FULL] = true;
return json;
};

auto rand = GENERATE(1);
const auto train_ds = GenDataSet(nb, dim, rand);
const auto query_ds = GenDataSet(nq, dim, rand + 777);

const knowhere::Json conf = {
{knowhere::meta::METRIC_TYPE, metric},
{knowhere::meta::TOPK, topk},
};
knowhere::ViewDataOp data_view = [&train_ds, data_size = sizeof(float) * dim](size_t id) {
auto data = train_ds->GetTensor();
return data + data_size * id;
};
auto data_view_pack = knowhere::Pack(data_view);
auto cfg_json = scann_gen().dump();
knowhere::Json json = knowhere::Json::parse(cfg_json);

auto scann_with_dv_refiner =
knowhere::IndexFactory::Instance()
.Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS_SCANN_DVR, version, data_view_pack)
.value();

REQUIRE(scann_with_dv_refiner.Type() == knowhere::IndexEnum::INDEX_FAISS_SCANN_DVR);
REQUIRE(scann_with_dv_refiner.Build(train_ds, json) == knowhere::Status::success);
REQUIRE(scann_with_dv_refiner.Count() == nb);
REQUIRE(scann_with_dv_refiner.Size() > 0);
REQUIRE(scann_with_dv_refiner.HasRawData(metric) == false);
REQUIRE(scann_with_dv_refiner.HasRawData(metric) ==
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(knowhere::IndexEnum::INDEX_FAISS_SCANN_DVR, version,
cfg_json));
auto scann_with_dv_refiner_results = scann_with_dv_refiner.Search(query_ds, json, nullptr);
auto res_ids = scann_with_dv_refiner_results.value()->GetIds();
// check we can get all vectors in (topk = nb, nprobe = )
for (auto i = 0; i < nq * topk; i++) {
REQUIRE(res_ids[i] != -1);
}
}

template <typename DataType>
void
BaseTest(const knowhere::DataSetPtr train_ds, const knowhere::DataSetPtr query_ds, const int64_t k,
Expand Down
1 change: 1 addition & 0 deletions tests/ut/test_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
json[knowhere::indexparam::NPROBE] = 14;
json[knowhere::indexparam::REORDER_K] = 200;
json[knowhere::indexparam::WITH_RAW_DATA] = true;
json[knowhere::indexparam::ENSURE_TOPK_FULL] = false;
return json;
};

Expand Down
1 change: 1 addition & 0 deletions thirdparty/faiss/faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct SearchParametersIVF : SearchParameters {
///< to minimize code change, when users only use nprobe to search, this config does not take affect since we will first retrieve the nearest nprobe buckets
///< it is a bit heavy to further retrieve more buckets
///< therefore to make sure we get topk results, use nprobe=nlist and use max_codes to narrow down the search range
size_t max_lists_num = 0; ///< select min{scanned number of (max_codes), scanned number of (max_lists_num) to return.}
bool ensure_topk_full = false;

///< during IVF range search, if reach 'max_empty_result_buckets' num of
Expand Down
26 changes: 22 additions & 4 deletions thirdparty/faiss/faiss/IndexIVFFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ void IndexIVFFastScan::search_preassigned(
IndexIVFStats* stats) const {
size_t nprobe = this->nprobe;
if (params) {
FAISS_THROW_IF_NOT(params->max_codes == 0);
nprobe = params->nprobe;
}

Expand Down Expand Up @@ -591,7 +590,7 @@ void IndexIVFFastScan::search_dispatch_implem(
int impl = implem;

if (impl == 0) {
if (bbs == 32) {
if (bbs == 32 && !params->ensure_topk_full) {
impl = 12;
} else {
impl = 10;
Expand Down Expand Up @@ -671,7 +670,7 @@ void IndexIVFFastScan::search_dispatch_implem(
)
);
search_implem_10(
n, x, *handler.get(), cq,
n, x, k, *handler.get(), cq,
&ndis, &nlist_visited, scaler, params);
}
// clang-format on
Expand Down Expand Up @@ -704,7 +703,7 @@ void IndexIVFFastScan::search_dispatch_implem(
cq_i, &ndis, &nlist_visited, scaler, params);
} else {
search_implem_10(
i1 - i0, x + i0 * d, *handler.get(),
i1 - i0, x + i0 * d,k, *handler.get(),
cq_i, &ndis, &nlist_visited, scaler, params);
}
// clang-format on
Expand Down Expand Up @@ -1021,12 +1020,27 @@ void IndexIVFFastScan::search_implem_2(
void IndexIVFFastScan::search_implem_10(
idx_t n,
const float* x,
idx_t k,
SIMDResultHandlerToFloat& handler,
const CoarseQuantized& cq,
size_t* ndis_out,
size_t* nlist_out,
const NormTableScaler* scaler,
const IVFSearchParameters* params) const {
// const size_t nprobe = params ? params->nprobe : this->nprobe;
const bool ensure_topk_full = params ? params->ensure_topk_full : false;
size_t max_codes = params ? params->max_codes : this->max_codes;
size_t max_lists_num = params ? params->max_lists_num : nlist;
FAISS_THROW_IF_NOT_MSG(
n == 1 || !ensure_topk_full,
"ensure_topk_full can't be true if queries number larger than 1.");
if (max_codes == 0) {
max_codes = std::numeric_limits<idx_t>::max();
}
if (max_lists_num == 0) {
max_lists_num = nlist;
}

size_t dim12 = ksub * M2;
AlignedTable<uint8_t> dis_tables;
AlignedTable<uint16_t> biases;
Expand Down Expand Up @@ -1085,6 +1099,10 @@ void IndexIVFFastScan::search_implem_10(
scaler);

ndis++;
auto nscan = handler.count_scanned_rows();
if ((nscan >= max_codes ||j >= max_lists_num)&& (!ensure_topk_full || nscan >= (size_t)k)) {
break;
}
}
}

Expand Down
1 change: 1 addition & 0 deletions thirdparty/faiss/faiss/IndexIVFFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ struct IndexIVFFastScan : IndexIVF {
void search_implem_10(
idx_t n,
const float* x,
idx_t k,
SIMDResultHandlerToFloat& handler,
const CoarseQuantized& cq,
size_t* ndis_out,
Expand Down
18 changes: 18 additions & 0 deletions thirdparty/faiss/faiss/impl/simd_result_handlers.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,25 @@ struct SIMDResultHandlerToFloat : SIMDResultHandler {
nullptr; // table of biases to add to each query (for IVF L2 search)
const float* normalizers = nullptr; // size 2 * nq, to convert

size_t scan_cnt = 0; // scanned vector number (except filtered)

SIMDResultHandlerToFloat(size_t nq, size_t ntotal) : nq(nq), ntotal(ntotal) {}

virtual void begin(const float* norms) {
normalizers = norms;
scan_cnt = 0;
}

// called at end of search to convert int16 distances to float, before
// normalizers are deallocated
virtual void end() {
normalizers = nullptr;
scan_cnt = 0;
}

// Get the number of scanned vectors
size_t count_scanned_rows() {
return scan_cnt;
}
};

Expand Down Expand Up @@ -293,6 +302,7 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
auto real_idx = this->adjust_id(b, j);
lt_mask -= 1 << j;
if (this->sel->is_member(real_idx)) {
this->scan_cnt++;
T d = d32tab[j];
if (C::cmp(idis[q], d)) {
idis[q] = d;
Expand All @@ -310,6 +320,7 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
lt_mask -= 1 << j;
T d = d32tab[j];
if (C::cmp(idis[q], d)) {
this->scan_cnt++;
idis[q] = d;
ids[q] = this->adjust_id(b, j);

Expand All @@ -329,6 +340,7 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
dis[q] = b + idis[q] * one_a;
}
}
this->scan_cnt = 0;
}
};

Expand Down Expand Up @@ -388,6 +400,7 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
auto real_idx = this->adjust_id(b, j);
lt_mask -= 1 << j;
if (this->sel->is_member(real_idx)) {
this->scan_cnt++;
T dis = d32tab[j];
if (C::cmp(heap_dis[0], dis)) {
heap_replace_top<C>(k, heap_dis, heap_ids, dis, real_idx);
Expand All @@ -404,6 +417,7 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
lt_mask -= 1 << j;
T dis = d32tab[j];
if (C::cmp(heap_dis[0], dis)) {
this->scan_cnt++;
int64_t idx = this->adjust_id(b, j);
heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);

Expand Down Expand Up @@ -431,6 +445,7 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
heap_ids[j] = heap_ids_in[j];
}
}
this->scan_cnt = 0;
}
};

Expand Down Expand Up @@ -582,6 +597,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
auto real_idx = this->adjust_id(b, j);
lt_mask -= 1 << j;
if (this->sel->is_member(real_idx)) {
this->scan_cnt++;
T dis = d32tab[j];
res.add(dis, real_idx);

Expand All @@ -595,6 +611,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
int j = __builtin_ctz(lt_mask);
lt_mask -= 1 << j;
T dis = d32tab[j];
this->scan_cnt++;
res.add(dis, this->adjust_id(b, j));

this->in_range_num += 1;
Expand Down Expand Up @@ -639,6 +656,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
// possibly add empty results
heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
}
this->scan_cnt = 0;
}
};

Expand Down

0 comments on commit a06b416

Please sign in to comment.