Skip to content

Commit

Permalink
Merge FloatSearch() and BinarySearch() into SearchOnGrowing() (milvus-io#18498)
Browse files Browse the repository at this point in the history

Signed-off-by: yudong.cai <[email protected]>
  • Loading branch information
cydrain authored Aug 3, 2022
1 parent 4edc8d3 commit 7cd37fc
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 124 deletions.
159 changes: 51 additions & 108 deletions internal/core/src/query/SearchOnGrowing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,35 @@

namespace milvus::query {

Status
FloatSearch(const segcore::SegmentGrowingImpl& segment,
const query::SearchInfo& info,
const float* query_data,
int64_t num_queries,
int64_t ins_barrier,
const BitsetView& bitset,
SearchResult& results) {
// TODO: small index is disabled, however 3 unittests still call this API, consider to remove this API
// - Query::ExecWithPredicateLoader
// - Query::ExecWithPredicate
// - Query::ExecWithoutPredicate
int32_t
FloatIndexSearch(const segcore::SegmentGrowingImpl& segment,
const query::SearchInfo& info,
const void* query_data,
int64_t num_queries,
int64_t ins_barrier,
const BitsetView& bitset,
SubSearchResult& results) {
auto& schema = segment.get_schema();
auto& indexing_record = segment.get_indexing_record();
auto& record = segment.get_insert_record();

// step 1.1: get meta
// step 1.2: get which vector field to search
auto vecfield_id = info.field_id_;
auto& field = schema[vecfield_id];

AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT");
auto dim = field.get_dim();
auto topk = info.topk_;
auto total_count = topk * num_queries;
auto metric_type = info.metric_type_;
auto round_decimal = info.round_decimal_;
// step 2: small indexing search
// std::vector<int64_t> final_uids(total_count, -1);
// std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal);
dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};
dataset::SearchDataset search_dataset{info.metric_type_, num_queries, info.topk_,
info.round_decimal_, field.get_dim(), query_data};
auto vec_ptr = record.get_field_data<FloatVector>(vecfield_id);

int current_chunk_id = 0;

if (indexing_record.is_in(vecfield_id)) {
auto max_indexed_id = indexing_record.get_finished_ack();
const auto& field_indexing = indexing_record.get_vec_field_indexing(vecfield_id);
auto search_conf = field_indexing.get_search_params(topk);
auto search_conf = field_indexing.get_search_params(info.topk_);
AssertInfo(vec_ptr->get_size_per_chunk() == field_indexing.get_size_per_chunk(),
"[FloatSearch]Chunk size of vector not equal to chunk size of field index");

Expand All @@ -72,123 +65,73 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
}
}

final_qr.merge(sub_qr);
results.merge(sub_qr);
current_chunk_id++;
}
}

// step 3: brute force search where small indexing is unavailable
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
auto max_chunk = upper_div(ins_barrier, vec_size_per_chunk);

for (int chunk_id = current_chunk_id; chunk_id < max_chunk; ++chunk_id) {
auto& chunk = vec_ptr->get_chunk(chunk_id);

auto element_begin = chunk_id * vec_size_per_chunk;
auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_size_per_chunk);
auto size_per_chunk = element_end - element_begin;

auto sub_view = bitset.subview(element_begin, size_per_chunk);
auto sub_qr = BruteForceSearch(search_dataset, chunk.data(), size_per_chunk, sub_view);

// convert chunk uid to segment uid
for (auto& x : sub_qr.mutable_seg_offsets()) {
if (x != -1) {
x += chunk_id * vec_size_per_chunk;
}
}
final_qr.merge(sub_qr);
}
results.distances_ = std::move(final_qr.mutable_distances());
results.seg_offsets_ = std::move(final_qr.mutable_seg_offsets());
results.unity_topK_ = topk;
results.total_nq_ = num_queries;

return Status::OK();
return current_chunk_id;
}

// NOTE(review): this span is a rendered diff without +/- markers. Lines of the removed
// BinarySearch() and lines of the new, unified SearchOnGrowing() are interleaved here, so
// the text is not compilable as shown. Confirm which side each line belongs to against
// the actual commit before relying on any of the notes below.
Status
BinarySearch(const segcore::SegmentGrowingImpl& segment,
const query::SearchInfo& info,
const uint8_t* query_data,
int64_t num_queries,
int64_t ins_barrier,
const BitsetView& bitset,
SearchResult& results) {
// SearchOnGrowing: searches one vector field of a growing segment, chunk by chunk, over the
// rows below ins_barrier and writes the merged hits into `results`.
//   segment     - supplies the schema and the insert record that holds the vector chunks
//   ins_barrier - upper row bound of the search; it also determines how many chunks are visited
//   info        - carries the field id, topk, metric type and round_decimal used for the search
//   query_data  - untyped query buffer, handed unchanged to the search dataset
//   num_queries - number of queries contained in query_data
//   bitset      - a per-chunk subview of it is passed to BruteForceSearch (its exact filter
//                 semantics are not visible in this span)
//   results     - receives distances_, seg_offsets_, unity_topK_ and total_nq_
void
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
int64_t ins_barrier,
const query::SearchInfo& info,
const void* query_data,
int64_t num_queries,
const BitsetView& bitset,
SearchResult& results) {
auto& schema = segment.get_schema();
auto& indexing_record = segment.get_indexing_record();
auto& record = segment.get_insert_record();
// step 1: binary search to find the barrier of the snapshot
// auto ins_barrier = get_barrier(record, timestamp);
auto metric_type = info.metric_type_;
// auto del_barrier = get_barrier(deleted_record_, timestamp);

// step 2.1: get meta
// step 2.2: get which vector field to search
// step 1.1: get meta
// step 1.2: get which vector field to search
auto vecfield_id = info.field_id_;
auto& field = schema[vecfield_id];
auto data_type = field.get_data_type();
// Any vector data type is accepted here; the float-only index path is selected further down.
AssertInfo(datatype_is_vector(data_type), "[SearchOnGrowing]Data type isn't vector type");

AssertInfo(field.get_data_type() == DataType::VECTOR_BINARY, "[BinarySearch]Field data type isn't VECTOR_BINARY");
auto dim = field.get_dim();
auto topk = info.topk_;
auto total_count = topk * num_queries;
auto metric_type = info.metric_type_;
auto round_decimal = info.round_decimal_;
// step 3: small indexing search
query::dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};

auto vec_ptr = record.get_field_data<BinaryVector>(vecfield_id);
auto max_indexed_id = 0;
// step 2: small indexing search
// final_qr accumulates the hits of every chunk via merge() and is moved into `results` at the end.
SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal);
dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};

// Only VECTOR_FLOAT fields go through FloatIndexSearch; its return value is used as the first
// chunk id of the brute-force loop below. For every other vector type the loop starts at chunk 0.
int32_t current_chunk_id = 0;
if (field.get_data_type() == DataType::VECTOR_FLOAT) {
current_chunk_id = FloatIndexSearch(segment, info, query_data, num_queries, ins_barrier, bitset, final_qr);
}

// step 4: brute force search where small indexing is unavailable
// step 3: brute force search where small indexing is unavailable
auto vec_ptr = record.get_field_data_base(vecfield_id);
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
// Number of chunks needed to cover ins_barrier rows (presumably a ceiling division - confirm upper_div).
auto max_chunk = upper_div(ins_barrier, vec_size_per_chunk);
SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
// NOTE(review): two loop headers follow, one starting at max_indexed_id and one at
// current_chunk_id; presumably one per side of the diff - confirm against the commit.
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
auto& chunk = vec_ptr->get_chunk(chunk_id);

for (int chunk_id = current_chunk_id; chunk_id < max_chunk; ++chunk_id) {
auto chunk_data = vec_ptr->get_chunk_data(chunk_id);

// Row range covered by this chunk; the end is clamped to ins_barrier, so the last chunk
// visited may contribute fewer than vec_size_per_chunk rows.
auto element_begin = chunk_id * vec_size_per_chunk;
auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_size_per_chunk);
auto nsize = element_end - element_begin;
auto size_per_chunk = element_end - element_begin;

auto sub_view = bitset.subview(element_begin, nsize);
auto sub_result = BruteForceSearch(search_dataset, chunk.data(), nsize, sub_view);
auto sub_view = bitset.subview(element_begin, size_per_chunk);
auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, sub_view);

// convert chunk uid to segment uid
// Offsets equal to -1 are left untouched (presumably unfilled result slots - confirm).
for (auto& x : sub_result.mutable_seg_offsets()) {
for (auto& x : sub_qr.mutable_seg_offsets()) {
if (x != -1) {
x += chunk_id * vec_size_per_chunk;
}
}
final_result.merge(sub_result);
final_qr.merge(sub_qr);
}

// NOTE(review): round_values() is called on final_result only; no matching call on final_qr
// is visible in this span. Confirm whether rounding is still applied on the merged path.
final_result.round_values();
results.distances_ = std::move(final_result.mutable_distances());
results.seg_offsets_ = std::move(final_result.mutable_seg_offsets());
results.distances_ = std::move(final_qr.mutable_distances());
results.seg_offsets_ = std::move(final_qr.mutable_seg_offsets());
results.unity_topK_ = topk;
results.total_nq_ = num_queries;

return Status::OK();
}

// Type-dispatching front end for searching a growing segment.
//
// Looks up the data type of the field named by info.field_id_, asserts that it is a
// vector type, and forwards the untyped query buffer as a typed pointer:
//   - VECTOR_FLOAT            -> FloatSearch()  with a `const float*`
//   - every other vector type -> BinarySearch() with a `const uint8_t*`
// All remaining arguments are passed through unchanged. The value returned by the
// chosen routine is not inspected.
// TODO: refactor and merge this into one
void
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
                int64_t ins_barrier,
                const query::SearchInfo& info,
                const void* query_data,
                int64_t num_queries,
                const BitsetView& bitset,
                SearchResult& results) {
    // TODO: add data_type to info
    const auto& schema = segment.get_schema();
    // The local keeps the name `data_type` in case AssertInfo echoes the checked
    // expression in its failure message.
    auto data_type = schema[info.field_id_].get_data_type();
    AssertInfo(datatype_is_vector(data_type), "[SearchOnGrowing]Data type isn't vector type");

    // Guard clause: anything that is not a float vector takes the binary path.
    if (data_type != DataType::VECTOR_FLOAT) {
        const auto* binary_query = static_cast<const uint8_t*>(query_data);
        BinarySearch(segment, info, binary_query, num_queries, ins_barrier, bitset, results);
        return;
    }

    const auto* float_query = static_cast<const float*>(query_data);
    FloatSearch(segment, info, float_query, num_queries, ins_barrier, bitset, results);
}

} // namespace milvus::query
22 changes: 6 additions & 16 deletions internal/core/src/segcore/SegmentSealedImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,8 @@ SegmentSealedImpl::vector_search(int64_t vec_count,
PanicInfo("Field Data is not loaded");
}

query::dataset::SearchDataset dataset;
dataset.query_data = query_data;
dataset.num_queries = query_count;
// if(field_meta.is)
dataset.metric_type = search_info.metric_type_;
dataset.topk = search_info.topk_;
dataset.dim = field_meta.get_dim();
dataset.round_decimal = search_info.round_decimal_;

query::dataset::SearchDataset dataset{search_info.metric_type_, query_count, search_info.topk_,
search_info.round_decimal_, field_meta.get_dim(), query_data};
AssertInfo(get_bit(field_data_ready_bitset_, field_id),
"Can't get bitset element at " + std::to_string(field_id.get()));
AssertInfo(row_count_opt_.has_value(), "Can't get row count value");
Expand All @@ -388,13 +381,10 @@ SegmentSealedImpl::vector_search(int64_t vec_count,
auto chunk_data = vec_data->get_chunk_data(0);
auto sub_qr = query::BruteForceSearch(dataset, chunk_data, row_count, bitset);

SearchResult results;
results.distances_ = std::move(sub_qr.mutable_distances());
results.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
results.unity_topK_ = dataset.topk;
results.total_nq_ = dataset.num_queries;

output = std::move(results);
output.distances_ = std::move(sub_qr.mutable_distances());
output.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
output.unity_topK_ = dataset.topk;
output.total_nq_ = dataset.num_queries;
}

void
Expand Down

0 comments on commit 7cd37fc

Please sign in to comment.