Skip to content

Commit

Permalink
[Feat] Add support of logical merge in Cagra
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Feb 20, 2025
1 parent 3c9e745 commit cb8f16b
Show file tree
Hide file tree
Showing 12 changed files with 585 additions and 61 deletions.
204 changes: 200 additions & 4 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ struct extend_params {
/**
* @brief Determines the strategy for merging CAGRA graphs.
*
* @note Both the PHYSICAL and LOGICAL strategies are defined below.
*/
enum MergeStrategy {
/**
Expand All @@ -286,9 +285,16 @@ enum MergeStrategy {
* This is expensive to build but does not impact search latency or quality.
* Preferred for many smaller CAGRA graphs.
*
* @note See LOGICAL below for the alternative, wrapper-based strategy.
*/
PHYSICAL
PHYSICAL,
/**
* @brief Logical merge: Wraps a new index structure around existing CAGRA graphs
* and broadcasts the query to each of them.
*
* This is a fast merge but incurs a small hit in search latency.
* Preferred for fewer larger CAGRA graphs.
*/
LOGICAL
};

/**
Expand Down Expand Up @@ -563,6 +569,82 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
};
/**
* @}
*/

/**
* @defgroup cagra_cpp_composite_index CAGRA composite index type
* @{
*/

/**
 * @brief Lightweight composite kNN index for CAGRA.
 *
 * This class logically aggregates multiple CAGRA indices into a single composite index,
 * providing a unified interface for kNN search. It is a lightweight structure
 * that does not own or manage the lifecycle of the underlying indices; instead,
 * it holds non-owning pointers to them. The caller must therefore keep every
 * sub-index alive for as long as the composite index (or any copy of it) is in use.
 *
 * All sub-indices within the composite index **must share the same distance metric
 * and dimensionality**, and the combined number of vectors must be representable in IdxT.
 *
 * @tparam T Data element type.
 * @tparam IdxT Index type representing dataset.extent(0), used for vector indices.
 */
template <typename T, typename IdxT>
struct composite_index {
  /**
   * @brief Build a composite index over the given sub-indices.
   *
   * The pointer list is taken by value (callers may pass an lvalue to copy it or an rvalue to
   * move it). A plain by-value parameter is used instead of a forwarding reference so that this
   * constructor never competes with the copy/move constructors during overload resolution.
   *
   * Fails (via RAFT_EXPECTS) if the list is empty, contains a null pointer, mixes metrics or
   * dimensionalities, or if the summed size does not fit in IdxT.
   *
   * @param[in] indices non-owning pointers to the CAGRA indices to aggregate
   */
  explicit composite_index(std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> indices)
    : sub_indices(std::move(indices))
  {
    RAFT_EXPECTS(!sub_indices.empty(), "composite_index requires at least one sub-index.");

    // Validate all pointers first so the loops below can dereference unconditionally.
    for (const auto* idx : sub_indices) {
      RAFT_EXPECTS(idx != nullptr, "sub_indices contains a null pointer.");
    }

    const auto& first_index = *sub_indices.front();
    metric_                 = first_index.metric();
    dim_                    = first_index.dim();

    // Accumulate in 64 bits: summing directly in IdxT could wrap around silently.
    uint64_t total_size = 0;
    for (const auto* idx : sub_indices) {
      RAFT_EXPECTS(idx->metric() == metric_, "All sub-indices must have the same metric.");
      RAFT_EXPECTS(idx->dim() == dim_, "All sub-indices must have the same dim.");
      total_size += static_cast<uint64_t>(idx->size());
    }

    // Round-trip through IdxT to detect a total that IdxT cannot represent.
    RAFT_EXPECTS(static_cast<uint64_t>(static_cast<IdxT>(total_size)) == total_size,
                 "The total size of all sub-indices does not fit in the index type.");
    size_ = static_cast<IdxT>(total_size);
  }

  composite_index(const composite_index& other)            = default;
  composite_index& operator=(const composite_index& other) = default;

  composite_index(composite_index&& other) noexcept            = default;
  composite_index& operator=(composite_index&& other) noexcept = default;

  /** Distance metric shared by all sub-indices. */
  constexpr inline auto metric() const noexcept -> cuvs::distance::DistanceType { return metric_; }

  /** Total number of vectors across all sub-indices (computed at construction). */
  constexpr inline auto size() const noexcept -> IdxT { return size_; }

  /** Dimensionality shared by all sub-indices. */
  constexpr inline auto dim() const noexcept -> uint32_t { return dim_; }

  /**
   * Graph degree of the first sub-index.
   * Precondition: sub_indices is non-empty (guaranteed after construction unless the public
   * member was modified or this object was moved from).
   */
  constexpr inline auto graph_degree() const noexcept -> uint32_t
  {
    return sub_indices.front()->graph_degree();
  }

  /** Number of sub-indices currently referenced. */
  constexpr inline auto num_indices() const noexcept -> uint32_t
  {
    return static_cast<uint32_t>(sub_indices.size());
  }

  /** Non-owning pointers to the aggregated CAGRA indices. */
  std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> sub_indices;

 private:
  cuvs::distance::DistanceType metric_{};
  IdxT size_{0};
  uint32_t dim_{0};
};

/**
* @}
*/
Expand Down Expand Up @@ -1123,7 +1205,6 @@ void extend(
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/

void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
Expand Down Expand Up @@ -1207,7 +1288,105 @@ void search(raft::resources const& res,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @brief Search ANN using the composite cagra index (float data).
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* NOTE(review): the templated composite_index search in cagra.cuh accepts only
* none_sample_filter and raises an error for any other filter type; presumably this
* entry point forwards to it - confirm.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index composite cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::composite_index<float, uint32_t>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the composite cagra index (half-precision data).
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* NOTE(review): the templated composite_index search in cagra.cuh accepts only
* none_sample_filter and raises an error for any other filter type; presumably this
* entry point forwards to it - confirm.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index composite cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::composite_index<half, uint32_t>& index,
raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the composite cagra index (int8_t data).
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* NOTE(review): the templated composite_index search in cagra.cuh accepts only
* none_sample_filter and raises an error for any other filter type; presumably this
* entry point forwards to it - confirm.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index composite cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::composite_index<int8_t, uint32_t>& index,
raft::device_matrix_view<const int8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the composite cagra index (uint8_t data).
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* NOTE(review): the templated composite_index search in cagra.cuh accepts only
* none_sample_filter and raises an error for any other filter type; presumably this
* entry point forwards to it - confirm.
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index composite cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::composite_index<uint8_t, uint32_t>& index,
raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @}
*/
Expand Down Expand Up @@ -1983,6 +2162,23 @@ auto merge(raft::resources const& res,
const cuvs::neighbors::cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<uint8_t, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::index<uint8_t, uint32_t>;

/**
* @brief Create a composite index that wraps existing CAGRA indices (logical merge).
*
* The returned composite_index stores non-owning pointers to the given indices, so those
* indices must outlive it. The composite_index constructor rejects an empty list, null
* pointers, and indices whose metric or dimensionality differ.
*
* NOTE(review): the template in cagra.cuh only logs a warning when params.strategy is not
* MergeStrategy::LOGICAL and still builds the composite index - confirm this is intended.
*
* @param[in] params merge parameters (strategy is expected to be MergeStrategy::LOGICAL)
* @param indices pointers to the CAGRA indices to aggregate (float data)
* @return a composite index referencing the given indices
*/
auto make_composite_index(const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<float, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::composite_index<float, uint32_t>;

/** @brief Same as the float overload of make_composite_index, for half-precision data. */
auto make_composite_index(const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<half, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::composite_index<half, uint32_t>;

/** @brief Same as the float overload of make_composite_index, for int8_t data. */
auto make_composite_index(const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<int8_t, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::composite_index<int8_t, uint32_t>;

/** @brief Same as the float overload of make_composite_index, for uint8_t data. */
auto make_composite_index(const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<uint8_t, uint32_t>*>& indices)
-> cuvs::neighbors::cagra::composite_index<uint8_t, uint32_t>;

/**
* @}
*/
Expand Down
32 changes: 31 additions & 1 deletion cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ void search(raft::resources const& res,
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, none_filter_type>(
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
return;
} catch (const std::bad_cast&) {
}

Expand All @@ -369,6 +368,27 @@ void search(raft::resources const& res,
}
}

/**
 * @brief Search a composite CAGRA index.
 *
 * Only `none_sample_filter` is supported: the filter is checked up front and any other
 * filter type fails with "Unsupported sample filter type by composite_index".
 *
 * The filter type is probed with the pointer form of dynamic_cast rather than by catching
 * std::bad_cast around the whole call, so a std::bad_cast raised inside the search itself is
 * no longer misreported as an unsupported filter.
 *
 * @param[in] res raft resources
 * @param[in] params search parameters
 * @param[in] idx composite index to search
 * @param[in] queries device matrix view of the queries
 * @param[out] neighbors device matrix view receiving the neighbor indices
 * @param[out] distances device matrix view receiving the neighbor distances
 * @param[in] sample_filter_ref sample filter; must be a none_sample_filter
 */
template <typename T, typename IdxT>
void search(raft::resources const& res,
            const search_params& params,
            const composite_index<T, IdxT>& idx,
            raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
            raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
            raft::device_matrix_view<float, int64_t, raft::row_major> distances,
            const cuvs::neighbors::filtering::base_filter& sample_filter_ref)
{
  using expected_filter_t = cuvs::neighbors::filtering::none_sample_filter;

  // Pointer-form dynamic_cast yields nullptr on mismatch instead of throwing.
  const auto* sample_filter = dynamic_cast<const expected_filter_t*>(&sample_filter_ref);
  if (sample_filter == nullptr) {
    RAFT_FAIL("Unsupported sample filter type by composite_index");
  }

  // Pass a mutable local copy, as the original code did.
  auto sample_filter_copy = *sample_filter;
  cagra::detail::search_on_composite_index<T, IdxT, expected_filter_t>(
    res, params, idx, queries, neighbors, distances, sample_filter_copy);
}

template <class T, class IdxT, class Accessor>
void extend(
raft::resources const& handle,
Expand All @@ -389,6 +409,16 @@ index<T, IdxT> merge(raft::resources const& handle,
return cagra::detail::merge<T, IdxT>(handle, params, indices);
}

/**
 * @brief Create a composite index referencing the given CAGRA indices.
 *
 * Logs a warning (but still proceeds) when `params.strategy` is not MergeStrategy::LOGICAL.
 *
 * The pointer list is copied into the composite index. `indices` is a caller-owned vector
 * received by lvalue reference, so it must not be moved from: doing so would silently leave
 * the caller's vector empty.
 *
 * @param[in] params merge parameters
 * @param[in] indices pointers to the indices to aggregate; left unmodified
 * @return composite index holding non-owning pointers to the given indices
 */
template <class T, class IdxT>
composite_index<T, IdxT> make_composite_index(const cagra::merge_params& params,
                                              std::vector<index<T, IdxT>*>& indices)
{
  if (params.strategy != cagra::MergeStrategy::LOGICAL) {
    RAFT_LOG_WARN("Merge strategy should be LOGICAL.");
  }
  return composite_index<T, IdxT>(indices);
}

/** @} */ // end group cagra

} // namespace cuvs::neighbors::cagra
20 changes: 13 additions & 7 deletions cpp/src/neighbors/cagra_merge_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@

namespace cuvs::neighbors::cagra {

#define RAFT_INST_CAGRA_MERGE(T, IdxT) \
auto merge(raft::resources const& handle, \
const cuvs::neighbors::cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices); \
/*
 * Defines the non-template `merge` and `make_composite_index` overloads for one (T, IdxT)
 * pair; each simply forwards to the function template of the same name.
 * No ';' follows the function bodies: a semicolon there is an empty declaration that
 * triggers extra-semicolon warnings.
 */
#define RAFT_INST_CAGRA_MERGE(T, IdxT)                                              \
  auto merge(raft::resources const& handle,                                         \
             const cuvs::neighbors::cagra::merge_params& params,                    \
             std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices)         \
    ->cuvs::neighbors::cagra::index<T, IdxT>                                        \
  {                                                                                 \
    return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices);         \
  }                                                                                 \
  auto make_composite_index(                                                        \
    const cagra::merge_params& params,                                              \
    std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices)                  \
    ->cuvs::neighbors::cagra::composite_index<T, IdxT>                              \
  {                                                                                 \
    return cuvs::neighbors::cagra::make_composite_index<T, IdxT>(params, indices);  \
  }

RAFT_INST_CAGRA_MERGE(float, uint32_t);
Expand Down
20 changes: 13 additions & 7 deletions cpp/src/neighbors/cagra_merge_half.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@

namespace cuvs::neighbors::cagra {

#define RAFT_INST_CAGRA_MERGE(T, IdxT) \
auto merge(raft::resources const& handle, \
const cuvs::neighbors::cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices); \
/*
 * Defines the non-template `merge` and `make_composite_index` overloads for one (T, IdxT)
 * pair; each simply forwards to the function template of the same name.
 * No ';' follows the function bodies: a semicolon there is an empty declaration that
 * triggers extra-semicolon warnings.
 */
#define RAFT_INST_CAGRA_MERGE(T, IdxT)                                              \
  auto merge(raft::resources const& handle,                                         \
             const cuvs::neighbors::cagra::merge_params& params,                    \
             std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices)         \
    ->cuvs::neighbors::cagra::index<T, IdxT>                                        \
  {                                                                                 \
    return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices);         \
  }                                                                                 \
  auto make_composite_index(                                                        \
    const cagra::merge_params& params,                                              \
    std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices)                  \
    ->cuvs::neighbors::cagra::composite_index<T, IdxT>                              \
  {                                                                                 \
    return cuvs::neighbors::cagra::make_composite_index<T, IdxT>(params, indices);  \
  }

RAFT_INST_CAGRA_MERGE(half, uint32_t);
Expand Down
20 changes: 13 additions & 7 deletions cpp/src/neighbors/cagra_merge_int8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@

namespace cuvs::neighbors::cagra {

#define RAFT_INST_CAGRA_MERGE(T, IdxT) \
auto merge(raft::resources const& handle, \
const cuvs::neighbors::cagra::merge_params& params, \
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices) \
->cuvs::neighbors::cagra::index<T, IdxT> \
{ \
return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices); \
/*
 * Defines the non-template `merge` and `make_composite_index` overloads for one (T, IdxT)
 * pair; each simply forwards to the function template of the same name.
 * No ';' follows the function bodies: a semicolon there is an empty declaration that
 * triggers extra-semicolon warnings.
 */
#define RAFT_INST_CAGRA_MERGE(T, IdxT)                                              \
  auto merge(raft::resources const& handle,                                         \
             const cuvs::neighbors::cagra::merge_params& params,                    \
             std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices)         \
    ->cuvs::neighbors::cagra::index<T, IdxT>                                        \
  {                                                                                 \
    return cuvs::neighbors::cagra::merge<T, IdxT>(handle, params, indices);         \
  }                                                                                 \
  auto make_composite_index(                                                        \
    const cagra::merge_params& params,                                              \
    std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices)                  \
    ->cuvs::neighbors::cagra::composite_index<T, IdxT>                              \
  {                                                                                 \
    return cuvs::neighbors::cagra::make_composite_index<T, IdxT>(params, indices);  \
  }

RAFT_INST_CAGRA_MERGE(int8_t, uint32_t);
Expand Down
Loading

0 comments on commit cb8f16b

Please sign in to comment.