Skip to content

Commit

Permalink
Initial support of half in IVF-FLAT
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Feb 26, 2025
1 parent a2a6a67 commit 5e48f79
Show file tree
Hide file tree
Showing 13 changed files with 860 additions and 55 deletions.
3 changes: 3 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,16 @@ if(BUILD_SHARED_LIBS)
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/hnsw.cpp>
src/neighbors/ivf_flat_index.cpp
src/neighbors/ivf_flat/ivf_flat_build_extend_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_build_extend_half_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_build_extend_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_build_extend_uint8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_helpers.cu
src/neighbors/ivf_flat/ivf_flat_search_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_search_half_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_half_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu
src/neighbors/ivf_pq_index.cpp
Expand Down
565 changes: 562 additions & 3 deletions cpp/include/cuvs/neighbors/ivf_flat.hpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions cpp/src/neighbors/ivf_flat/generate_ivf_flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@

types = dict(
float_int64_t=("float", "int64_t"),
half_int64_t=("half", "int64_t"),
int8_t_int64_t=("int8_t", "int64_t"),
uint8_t_int64_t=("uint8_t", "int64_t"),
)
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ inline auto build(raft::resources const& handle,
auto stream = raft::resource::get_cuda_stream(handle);
cuvs::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope(
"ivf_flat::build(%zu, %u)", size_t(n_rows), dim);
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, uint8_t> ||
std::is_same_v<T, int8_t>,
"unsupported data type");
RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset");
RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists");
Expand Down
103 changes: 103 additions & 0 deletions cpp/src/neighbors/ivf_flat/ivf_flat_build_extend_half_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is generated by generate_ivf_flat.py
*
* Make changes there and run in this directory:
*
* > python generate_ivf_flat.py
*
*/

#include <cuvs/neighbors/ivf_flat.hpp>

#include "ivf_flat_build.cuh"

namespace cuvs::neighbors::ivf_flat {

#define CUVS_INST_IVF_FLAT_BUILD_EXTEND(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::ivf_flat::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset) \
->cuvs::neighbors::ivf_flat::index<T, IdxT> \
{ \
return cuvs::neighbors::ivf_flat::index<T, IdxT>( \
std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \
} \
\
void build(raft::resources const& handle, \
const cuvs::neighbors::ivf_flat::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
cuvs::neighbors::ivf_flat::index<T, IdxT>& idx) \
{ \
cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \
} \
auto build(raft::resources const& handle, \
const cuvs::neighbors::ivf_flat::index_params& params, \
raft::host_matrix_view<const T, IdxT, raft::row_major> dataset) \
->cuvs::neighbors::ivf_flat::index<T, IdxT> \
{ \
return cuvs::neighbors::ivf_flat::index<T, IdxT>( \
std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \
} \
\
void build(raft::resources const& handle, \
const cuvs::neighbors::ivf_flat::index_params& params, \
raft::host_matrix_view<const T, IdxT, raft::row_major> dataset, \
cuvs::neighbors::ivf_flat::index<T, IdxT>& idx) \
{ \
cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \
} \
auto extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const cuvs::neighbors::ivf_flat::index<T, IdxT>& orig_index) \
->cuvs::neighbors::ivf_flat::index<T, IdxT> \
{ \
return cuvs::neighbors::ivf_flat::index<T, IdxT>(std::move( \
cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \
} \
\
void extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
cuvs::neighbors::ivf_flat::index<T, IdxT>* idx) \
{ \
cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \
} \
auto extend(raft::resources const& handle, \
raft::host_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices, \
const cuvs::neighbors::ivf_flat::index<T, IdxT>& orig_index) \
->cuvs::neighbors::ivf_flat::index<T, IdxT> \
{ \
return cuvs::neighbors::ivf_flat::index<T, IdxT>(std::move( \
cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \
} \
\
void extend(raft::resources const& handle, \
raft::host_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices, \
cuvs::neighbors::ivf_flat::index<T, IdxT>* idx) \
{ \
cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \
}
CUVS_INST_IVF_FLAT_BUILD_EXTEND(half, int64_t);

#undef CUVS_INST_IVF_FLAT_BUILD_EXTEND

} // namespace cuvs::neighbors::ivf_flat
82 changes: 42 additions & 40 deletions cpp/src/neighbors/ivf_flat/ivf_flat_helpers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ void pack(raft::resources const& res,
detail::pack<float, int64_t>(res, codes, veclen, offset, list_data);
}

void pack(raft::resources const& res,
raft::device_matrix_view<const half, uint32_t, raft::row_major> codes,
uint32_t veclen,
uint32_t offset,
raft::device_mdspan<half,
typename list_spec<uint32_t, half, int64_t>::list_extents,
raft::row_major> list_data)
{
detail::pack<half, int64_t>(res, codes, veclen, offset, list_data);
}

void pack(raft::resources const& res,
raft::device_matrix_view<const int8_t, uint32_t, raft::row_major> codes,
uint32_t veclen,
Expand Down Expand Up @@ -68,6 +79,17 @@ void unpack(raft::resources const& res,
detail::unpack<float, int64_t>(res, list_data, veclen, offset, codes);
}

void unpack(raft::resources const& res,
raft::device_mdspan<const half,
typename list_spec<uint32_t, half, int64_t>::list_extents,
raft::row_major> list_data,
uint32_t veclen,
uint32_t offset,
raft::device_matrix_view<half, uint32_t, raft::row_major> codes)
{
detail::unpack<half, int64_t>(res, list_data, veclen, offset, codes);
}

void unpack(raft::resources const& res,
raft::device_mdspan<const int8_t,
typename list_spec<uint32_t, int8_t, int64_t>::list_extents,
Expand Down Expand Up @@ -95,6 +117,11 @@ void pack_1(const float* flat_code, float* block, uint32_t dim, uint32_t veclen,
detail::pack_1<float>(flat_code, block, dim, veclen, offset);
}

void pack_1(const half* flat_code, half* block, uint32_t dim, uint32_t veclen, uint32_t offset)
{
detail::pack_1<half>(flat_code, block, dim, veclen, offset);
}

void pack_1(const int8_t* flat_code, int8_t* block, uint32_t dim, uint32_t veclen, uint32_t offset)
{
detail::pack_1<int8_t>(flat_code, block, dim, veclen, offset);
Expand All @@ -111,6 +138,11 @@ void unpack_1(const float* block, float* flat_code, uint32_t dim, uint32_t vecle
detail::unpack_1<float>(block, flat_code, dim, veclen, offset);
}

void unpack_1(const half* block, half* flat_code, uint32_t dim, uint32_t veclen, uint32_t offset)
{
detail::unpack_1<half>(block, flat_code, dim, veclen, offset);
}

void unpack_1(
const int8_t* block, int8_t* flat_code, uint32_t dim, uint32_t veclen, uint32_t offset)
{
Expand Down Expand Up @@ -149,51 +181,16 @@ void reset_index(const raft::resources& res, index<float, int64_t>* index)
detail::reset_index<float, int64_t>(res, index);
}

/**
* @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for
* externally modifying the index without going through the build stage. The data and indices of the
* IVF lists will be lost.
*
* Usage example:
* @code{.cpp}
* raft::resources res;
* using namespace cuvs::neighbors;
* // use default index parameters
* ivf_flat::index_params index_params;
* // initialize an empty index
* ivf_flat::index<int8_t, int64_t> index(res, index_params, D);
* // reset the index's state and list sizes
* ivf_flat::helpers::reset_index(res, &index);
* @endcode
*
* @param[in] res raft resource
* @param[inout] index pointer to IVF-Flat index
*/
void reset_index(const raft::resources& res, index<half, int64_t>* index)
{
detail::reset_index<half, int64_t>(res, index);
}

void reset_index(const raft::resources& res, index<int8_t, int64_t>* index)
{
detail::reset_index<int8_t, int64_t>(res, index);
}

/**
* @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for
* externally modifying the index without going through the build stage. The data and indices of the
* IVF lists will be lost.
*
* Usage example:
* @code{.cpp}
* raft::resources res;
* using namespace cuvs::neighbors;
* // use default index parameters
* ivf_flat::index_params index_params;
* // initialize an empty index
* ivf_flat::index<uint8_t, int64_t> index(res, index_params, D);
* // reset the index's state and list sizes
* ivf_flat::helpers::reset_index(res, &index);
* @endcode
*
* @param[in] res raft resource
* @param[inout] index pointer to IVF-Flat index
*/
void reset_index(const raft::resources& res, index<uint8_t, int64_t>* index)
{
detail::reset_index<uint8_t, int64_t>(res, index);
Expand All @@ -204,6 +201,11 @@ void recompute_internal_state(const raft::resources& res, index<float, int64_t>*
ivf::detail::recompute_internal_state(res, *index);
}

void recompute_internal_state(const raft::resources& res, index<half, int64_t>* index)
{
ivf::detail::recompute_internal_state(res, *index);
}

void recompute_internal_state(const raft::resources& res, index<int8_t, int64_t>* index)
{
ivf::detail::recompute_internal_state(res, *index);
Expand Down
48 changes: 48 additions & 0 deletions cpp/src/neighbors/ivf_flat/ivf_flat_search_half_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is generated by generate_ivf_flat.py
*
* Make changes there and run in this directory:
*
* > python generate_ivf_flat.py
*
*/

#include <cuvs/neighbors/ivf_flat.hpp>

#include "ivf_flat_search.cuh"

namespace cuvs::neighbors::ivf_flat {

#define CUVS_INST_IVF_FLAT_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
const cuvs::neighbors::ivf_flat::search_params& params, \
const cuvs::neighbors::ivf_flat::index<T, IdxT>& index, \
raft::device_matrix_view<const T, IdxT, raft::row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors, \
raft::device_matrix_view<float, IdxT, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
cuvs::neighbors::ivf_flat::detail::search( \
handle, params, index, queries, neighbors, distances, sample_filter); \
}
CUVS_INST_IVF_FLAT_SEARCH(half, int64_t);

#undef CUVS_INST_IVF_FLAT_SEARCH

} // namespace cuvs::neighbors::ivf_flat
35 changes: 35 additions & 0 deletions cpp/src/neighbors/ivf_flat/ivf_flat_serialize_half_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is generated by generate_ivf_flat.py
*
* Make changes there and run in this directory:
*
* > python generate_ivf_flat.py
*
*/

#include <cuvs/neighbors/ivf_flat.hpp>

#include "ivf_flat_serialize.cuh"

namespace cuvs::neighbors::ivf_flat {
CUVS_INST_IVF_FLAT_SERIALIZE(half, int64_t);

#undef CUVS_INST_IVF_FLAT_SERIALIZE

} // namespace cuvs::neighbors::ivf_flat
Loading

0 comments on commit 5e48f79

Please sign in to comment.