Skip to content

Commit

Permalink
enhance: [bitset] extend op_find() to be able to search both 0 and 1 (m…
Browse files Browse the repository at this point in the history
…ilvus-io#39176)

issue: milvus-io#39124 

`bitset::find_first()` and `bitset::find_next()` now accept one more
parameter, which allows searching for a `0` bit instead of a `1` bit.

Signed-off-by: Alexandr Guzhva <[email protected]>
  • Loading branch information
alexanderguzhva authored Jan 14, 2025
1 parent 702347b commit 3447ff7
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 39 deletions.
4 changes: 2 additions & 2 deletions internal/core/src/bitset/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
detail/platform/x86/instruction_set.cpp
)

set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")
set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq -mavx512cd -mbmi")
set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma -mbmi")

# set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
# set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")
Expand Down
17 changes: 10 additions & 7 deletions internal/core/src/bitset/bitset.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,23 +546,26 @@ class BitsetBase {
return as_derived();
}

// Find the index of the first bit set to true.
// Find the index of the first bit set to either true (default), or false.
inline std::optional<size_t>
find_first() const {
find_first(const bool is_set = true) const {
return policy_type::op_find(
this->data(), this->offset(), this->size(), 0);
this->data(), this->offset(), this->size(), 0, is_set);
}

// Find the index of the first bit set to true, starting from a given bit index.
// Find the index of the next bit equal to `is_set` (true by default), searching strictly after the given bit index.
inline std::optional<size_t>
find_next(const size_t starting_bit_idx) const {
find_next(const size_t starting_bit_idx, const bool is_set = true) const {
const size_t size_v = this->size();
if (starting_bit_idx + 1 >= size_v) {
return std::nullopt;
}

return policy_type::op_find(
this->data(), this->offset(), this->size(), starting_bit_idx + 1);
return policy_type::op_find(this->data(),
this->offset(),
this->size(),
starting_bit_idx + 1,
is_set);
}

// Read multiple bits starting from a given bit index.
Expand Down
5 changes: 3 additions & 2 deletions internal/core/src/bitset/detail/bit_wise.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,11 @@ struct BitWiseBitsetPolicy {
op_find(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
const size_t starting_idx,
const bool is_set) {
for (size_t i = starting_idx; i < size; i++) {
const auto proxy = get_proxy(data, start + i);
if (proxy) {
if (proxy == is_set) {
return i;
}
}
Expand Down
5 changes: 3 additions & 2 deletions internal/core/src/bitset/detail/element_vectorized.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,10 @@ struct VectorizedElementWiseBitsetPolicy {
op_find(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
const size_t starting_idx,
const bool is_set) {
return ElementWiseBitsetPolicy<ElementT>::op_find(
data, start, size, starting_idx);
data, start, size, starting_idx, is_set);
}

//
Expand Down
93 changes: 89 additions & 4 deletions internal/core/src/bitset/detail/element_wise.h
Original file line number Diff line number Diff line change
Expand Up @@ -718,10 +718,10 @@ struct ElementWiseBitsetPolicy {

//
static inline std::optional<size_t>
op_find(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
op_find_1(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
if (size == 0) {
return std::nullopt;
}
Expand Down Expand Up @@ -788,6 +788,91 @@ struct ElementWiseBitsetPolicy {
return std::nullopt;
}

static inline std::optional<size_t>
op_find_0(const data_type* const data,
const size_t start,
const size_t size,
const size_t starting_idx) {
if (size == 0) {
return std::nullopt;
}

//
auto start_element = get_element(start + starting_idx);
const auto end_element = get_element(start + size);

const auto start_shift = get_shift(start + starting_idx);
const auto end_shift = get_shift(start + size);

// same element?
if (start_element == end_element) {
const data_type existing_v = ~data[start_element];

const data_type existing_mask = get_shift_mask_end(start_shift) &
get_shift_mask_begin(end_shift);

const data_type value = existing_v & existing_mask;
if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value);
return size_t(ctz) + start_element * data_bits - start;
} else {
return std::nullopt;
}
}

// process the first element
if (start_shift != 0) {
const data_type existing_v = ~data[start_element];
const data_type existing_mask = get_shift_mask_end(start_shift);

const data_type value = existing_v & existing_mask;
if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value) +
start_element * data_bits - start;
return size_t(ctz);
}

start_element += 1;
}

// process the middle
for (size_t i = start_element; i < end_element; i++) {
const data_type value = ~data[i];
if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value);
return size_t(ctz) + i * data_bits - start;
}
}

// process the last element
if (end_shift != 0) {
const data_type existing_v = ~data[end_element];
const data_type existing_mask = get_shift_mask_begin(end_shift);

const data_type value = existing_v & existing_mask;
if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value);
return size_t(ctz) + end_element * data_bits - start;
}
}

return std::nullopt;
}

//
// Searches the bit range that begins at absolute bit `start` of `data` and
// is `size` bits long for the first bit equal to `is_set`, looking only at
// positions at or after `starting_idx`.
//
// Dispatches to op_find_1() when `is_set` is true and to op_find_0()
// otherwise.
//
// Returns the result of the selected helper (a position relative to `start`
// in op_find_0(); presumably the same for op_find_1()), or std::nullopt when
// there is nothing to scan or no matching bit exists.
static inline std::optional<size_t>
op_find(const data_type* const data,
        const size_t start,
        const size_t size,
        const size_t starting_idx,
        const bool is_set) {
    // An empty scan window cannot contain a match. Rejecting it here keeps
    // the helpers from indexing `data` at or beyond the element that holds
    // bit `start + size`, and from reporting a position >= size.
    if (starting_idx >= size) {
        return std::nullopt;
    }

    return is_set ? op_find_1(data, start, size, starting_idx)
                  : op_find_0(data, start, size, starting_idx);
}

//
template <typename T, typename U, CompareOpType Op>
static inline void
Expand Down
2 changes: 2 additions & 0 deletions internal/core/src/bitset/detail/proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#pragma once

#include <cstddef>

namespace milvus {
namespace bitset {
namespace detail {
Expand Down
56 changes: 34 additions & 22 deletions internal/core/unittest/test_bitset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ from_i32(const int32_t i) {
//
template <typename BitsetT>
void
TestFindImpl(BitsetT& bitset, const size_t max_v) {
TestFindImpl(BitsetT& bitset, const size_t max_v, const bool is_set) {
const size_t n = bitset.size();

std::default_random_engine rng(123);
Expand All @@ -361,9 +361,13 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) {
}
}

if (!is_set) {
bitset.flip();
}

StopWatch sw;

auto bit_idx = bitset.find_first();
auto bit_idx = bitset.find_first(is_set);
if (!bit_idx.has_value()) {
ASSERT_EQ(one_pos.size(), 0);
return;
Expand All @@ -372,7 +376,7 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) {
for (size_t i = 0; i < one_pos.size(); i++) {
ASSERT_TRUE(bit_idx.has_value()) << n << ", " << max_v;
ASSERT_EQ(bit_idx.value(), one_pos[i]) << n << ", " << max_v;
bit_idx = bitset.find_next(bit_idx.value());
bit_idx = bitset.find_next(bit_idx.value(), is_set);
}

ASSERT_FALSE(bit_idx.has_value())
Expand All @@ -387,32 +391,40 @@ template <typename BitsetT>
void
TestFindImpl() {
for (const size_t n : typical_sizes) {
for (const size_t pr : {1, 100}) {
BitsetT bitset(n);
bitset.reset();

if (print_log) {
printf("Testing bitset, n=%zd, pr=%zd\n", n, pr);
}

TestFindImpl(bitset, pr);

for (const size_t offset : typical_offsets) {
if (offset >= n) {
continue;
}

for (const bool is_set : {true, false}) {
for (const size_t pr : {1, 100}) {
BitsetT bitset(n);
bitset.reset();
auto view = bitset.view(offset);

if (print_log) {
printf("Testing bitset view, n=%zd, offset=%zd, pr=%zd\n",
printf("Testing bitset, n=%zd, is_set=%d, pr=%zd\n",
n,
offset,
(is_set) ? 1 : 0,
pr);
}

TestFindImpl(view, pr);
TestFindImpl(bitset, pr, is_set);

for (const size_t offset : typical_offsets) {
if (offset >= n) {
continue;
}

bitset.reset();
auto view = bitset.view(offset);

if (print_log) {
printf(
"Testing bitset view, n=%zd, offset=%zd, "
"is_set=%d, pr=%zd\n",
n,
offset,
(is_set) ? 1 : 0,
pr);
}

TestFindImpl(view, pr, is_set);
}
}
}
}
Expand Down

0 comments on commit 3447ff7

Please sign in to comment.