Skip to content

Commit

Permalink
CAGRA tech debt: distance descriptor and workspace memory (#436)
Browse files Browse the repository at this point in the history
This PR introduces two changes:

1.  Refactor `dataset_descriptor_host` to pass and cache it by value while keeping its state in a thread-safe object held by a shared pointer. Before this, the descriptor host itself was kept in a shared pointer in the LRU cache and was passed by reference; as a result, it could in theory die due to cache eviction while still being used via references to it.
2. Adjust the temporary buffers to always use the workspace resource in all CAGRA algo implementations (as of now, only SINGLE_CTA algo does this; the PR expands the change to MULTI_CTA and MULTI_KERNEL).

Both of the changes are required for effective use of stream-ordered dynamic batching #261 (1. fixes crashes and 2. fixes thread-blocking behavior).

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #436
  • Loading branch information
achirkin authored Nov 6, 2024
1 parent eff2cc5 commit 6b35b65
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 68 deletions.
4 changes: 2 additions & 2 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ void search_main(raft::resources const& res,
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index.data());
strided_dset != nullptr) {
// Search using a plain (strided) row-major dataset
auto& desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>(
auto desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>(
res, params, *strided_dset, index.metric());
search_main_core<T, InternalIdxT, DistanceT, CagraSampleFilterT>(
res, params, desc, graph_internal, queries, neighbors, distances, sample_filter);
Expand All @@ -161,7 +161,7 @@ void search_main(raft::resources const& res,
RAFT_FAIL("FP32 VPQ dataset support is coming soon");
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<half, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
auto& desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>(
auto desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>(
res, params, *vpq_dset, index.metric());
search_main_core<T, InternalIdxT, DistanceT, CagraSampleFilterT>(
res, params, desc, graph_internal, queries, neighbors, distances, sample_filter);
Expand Down
77 changes: 52 additions & 25 deletions cpp/src/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
#include <raft/util/device_loads_stores.cuh>
#include <raft/util/vectorized.cuh>

#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <type_traits>
#include <variant>

Expand Down Expand Up @@ -232,52 +234,77 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t {
*/
template <typename DataT, typename IndexT, typename DistanceT>
struct dataset_descriptor_host {
using dev_descriptor_t = dataset_descriptor_base_t<DataT, IndexT, DistanceT>;
using dd_ptr_t = std::shared_ptr<dev_descriptor_t>;
using init_f =
std::tuple<std::function<void(dev_descriptor_t*, rmm::cuda_stream_view stream)>, size_t>;
using dev_descriptor_t = dataset_descriptor_base_t<DataT, IndexT, DistanceT>;
uint32_t smem_ws_size_in_bytes = 0;
uint32_t team_size = 0;

/**
 * Lazily-evaluated descriptor state.
 *
 * Holds either the deferred initializer (a fill function plus the number of bytes to
 * allocate) or, once evaluated, the pointer obtained from cudaMallocAsync together with
 * the stream it was allocated on. The evaluation happens at most once and is guarded by
 * `mutex`; the buffer is released with cudaFreeAsync in the destructor.
 */
struct state {
  /** Evaluated form: the allocated descriptor pointer and the stream used to allocate it. */
  using ready_t = std::tuple<dev_descriptor_t*, rmm::cuda_stream_view>;
  /** Deferred form: a function that fills the allocated descriptor, and the size in bytes. */
  using init_f =
    std::tuple<std::function<void(dev_descriptor_t*, rmm::cuda_stream_view)>, size_t>;

  /** Serializes the one-time evaluation performed in `eval`. */
  std::mutex mutex;
  /**
   * Set (with release ordering) once `value` holds `ready_t`.
   * `get` checks this flag instead of querying the variant outside of the lock.
   */
  std::atomic<bool> ready;
  std::variant<ready_t, init_f> value;

  /**
   * Store the deferred initializer; nothing is allocated here.
   *
   * @param init callable invoked as `init(dev_descriptor_t*, rmm::cuda_stream_view)`
   * @param size number of bytes passed to cudaMallocAsync on first evaluation
   */
  template <typename InitF>
  state(InitF init, size_t size) : ready{false}, value{std::make_tuple(init, size)}
  {
  }

  /** Free the descriptor (if it was ever evaluated) on the stream it was allocated on. */
  ~state() noexcept
  {
    if (std::holds_alternative<ready_t>(value)) {
      auto& [ptr, stream] = std::get<ready_t>(value);
      RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(ptr, stream));
    }
  }

  /**
   * Allocate and initialize the descriptor on `stream` if that has not been done yet.
   *
   * If the initializer throws, the fresh allocation is released and the exception is
   * rethrown; the state remains un-evaluated in that case.
   */
  void eval(rmm::cuda_stream_view stream)
  {
    std::lock_guard<std::mutex> lock(mutex);
    // Re-check under the lock: another caller may have finished the evaluation already.
    if (!std::holds_alternative<init_f>(value)) { return; }
    auto& [fun, size]     = std::get<init_f>(value);
    dev_descriptor_t* ptr = nullptr;
    RAFT_CUDA_TRY(cudaMallocAsync(&ptr, size, stream));
    try {
      fun(ptr, stream);
    } catch (...) {
      // `value` still holds the initializer, so the destructor would never free `ptr`.
      RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(ptr, stream));
      throw;
    }
    // Replacing the alternative destroys the initializer; `fun`/`size` must not be used below.
    value = std::make_tuple(ptr, stream);
    ready.store(true, std::memory_order_release);
  }

  /**
   * Return the descriptor pointer, evaluating it on `stream` first if needed.
   *
   * @param stream stream used for the allocation and initialization when not yet evaluated
   */
  auto get(rmm::cuda_stream_view stream) -> dev_descriptor_t*
  {
    if (!ready.load(std::memory_order_acquire)) { eval(stream); }
    return std::get<0>(std::get<ready_t>(value));
  }
};

template <typename DescriptorImpl, typename InitF>
dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init)
: value_{std::make_tuple(init, sizeof(DescriptorImpl))},
: value_{std::make_shared<state>(init, sizeof(DescriptorImpl))},
smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()},
team_size{dd_host.team_size()}
{
}

dataset_descriptor_host() = default;

/**
* Return the device pointer, possibly evaluating it in the given thread.
*/
[[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) const -> const dev_descriptor_t*
{
if (std::holds_alternative<init_f>(value_)) { value_ = eval(std::get<init_f>(value_), stream); }
return std::get<dd_ptr_t>(value_).get();
return value_->get(stream);
}

[[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) -> dev_descriptor_t*
{
if (std::holds_alternative<init_f>(value_)) { value_ = eval(std::get<init_f>(value_), stream); }
return std::get<dd_ptr_t>(value_).get();
return value_->get(stream);
}

private:
mutable std::variant<dd_ptr_t, init_f> value_;

static auto eval(init_f init, rmm::cuda_stream_view stream) -> dd_ptr_t
{
using raft::RAFT_NAME;
auto& [fun, size] = init;
dd_ptr_t dev_ptr{
[stream, s = size]() {
dev_descriptor_t* p;
RAFT_CUDA_TRY(cudaMallocAsync(&p, s, stream));
return p;
}(),
[stream](dev_descriptor_t* p) { RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(p, stream)); }};
fun(dev_ptr.get(), stream);
return dev_ptr;
}
mutable std::shared_ptr<state> value_;
};

/**
Expand Down
20 changes: 8 additions & 12 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,9 @@ template <typename DataT, typename IndexT, typename DistanceT>
struct store {
/** Number of descriptors to cache. */
static constexpr size_t kDefaultSize = 100;
raft::cache::lru<key,
key_hash,
std::equal_to<>,
std::shared_ptr<dataset_descriptor_host<DataT, IndexT, DistanceT>>>
value{kDefaultSize};
raft::cache::
lru<key, key_hash, std::equal_to<>, dataset_descriptor_host<DataT, IndexT, DistanceT>>
value{kDefaultSize};
};

} // namespace descriptor_cache
Expand All @@ -159,20 +157,18 @@ auto dataset_descriptor_init_with_cache(const raft::resources& res,
const cagra::search_params& params,
const DatasetT& dataset,
cuvs::distance::DistanceType metric)
-> const dataset_descriptor_host<DataT, IndexT, DistanceT>&
-> dataset_descriptor_host<DataT, IndexT, DistanceT>
{
using desc_t = dataset_descriptor_host<DataT, IndexT, DistanceT>;
auto key = descriptor_cache::make_key(params, dataset, metric);
auto key = descriptor_cache::make_key(params, dataset, metric);
auto& cache =
raft::resource::get_custom_resource<descriptor_cache::store<DataT, IndexT, DistanceT>>(res)
->value;
std::shared_ptr<desc_t> desc{nullptr};
dataset_descriptor_host<DataT, IndexT, DistanceT> desc;
if (!cache.get(key, &desc)) {
desc = std::make_shared<desc_t>(
std::move(dataset_descriptor_init<DataT, IndexT, DistanceT>(params, dataset, metric)));
desc = dataset_descriptor_init<DataT, IndexT, DistanceT>(params, dataset, metric);
cache.set(key, desc);
}
return *desc;
return desc;
}

}; // namespace cuvs::neighbors::cagra::detail
12 changes: 6 additions & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
using base_type::num_seeds;

uint32_t num_cta_per_query;
rmm::device_uvector<INDEX_T> intermediate_indices;
rmm::device_uvector<float> intermediate_distances;
lightweight_uvector<INDEX_T> intermediate_indices;
lightweight_uvector<float> intermediate_distances;
size_t topk_workspace_size;
rmm::device_uvector<uint32_t> topk_workspace;
lightweight_uvector<uint32_t> topk_workspace;

search(raft::resources const& res,
search_params params,
Expand All @@ -105,9 +105,9 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
int64_t graph_degree,
uint32_t topk)
: base_type(res, params, dataset_desc, dim, graph_degree, topk),
intermediate_indices(0, raft::resource::get_cuda_stream(res)),
intermediate_distances(0, raft::resource::get_cuda_stream(res)),
topk_workspace(0, raft::resource::get_cuda_stream(res))
intermediate_indices(res),
intermediate_distances(res),
topk_workspace(res)

{
set_params(res, params);
Expand Down
53 changes: 31 additions & 22 deletions cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ void get_value(T* const host_ptr, const T* const dev_ptr, cudaStream_t cuda_stre
get_value_kernel<T><<<1, 1, 0, cuda_stream>>>(host_ptr, dev_ptr);
}

/**
 * Fetch a single value of type T from `dev_ptr` and return it by value.
 *
 * The copy into a local variable is enqueued on `stream`, and the stream is synchronized
 * before returning, so the call blocks the calling thread until the value has arrived.
 */
template <class T>
auto get_value(const T* const dev_ptr, cudaStream_t stream) -> T
{
  T host_copy;
  RAFT_CUDA_TRY(cudaMemcpyAsync(&host_copy, dev_ptr, sizeof(T), cudaMemcpyDefault, stream));
  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
  return host_copy;
}

// MAX_DATASET_DIM : must equal to or greater than dataset_dim
template <class DATASET_DESCRIPTOR_T>
RAFT_KERNEL random_pickup_kernel(
Expand Down Expand Up @@ -609,18 +618,18 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
using base_type::num_seeds;

size_t result_buffer_allocation_size;
rmm::device_uvector<INDEX_T> result_indices; // results_indices_buffer
rmm::device_uvector<DISTANCE_T> result_distances; // result_distances_buffer
rmm::device_uvector<INDEX_T> parent_node_list;
rmm::device_uvector<uint32_t> topk_hint;
rmm::device_scalar<uint32_t> terminate_flag; // dev_terminate_flag, host_terminate_flag.;
rmm::device_uvector<uint32_t> topk_workspace;
lightweight_uvector<INDEX_T> result_indices; // results_indices_buffer
lightweight_uvector<DISTANCE_T> result_distances; // result_distances_buffer
lightweight_uvector<INDEX_T> parent_node_list;
lightweight_uvector<uint32_t> topk_hint;
lightweight_uvector<uint32_t> terminate_flag; // dev_terminate_flag, host_terminate_flag.;
lightweight_uvector<uint32_t> topk_workspace;

// temporary storage for _find_topk
rmm::device_uvector<float> input_keys_storage;
rmm::device_uvector<float> output_keys_storage;
rmm::device_uvector<INDEX_T> input_values_storage;
rmm::device_uvector<INDEX_T> output_values_storage;
lightweight_uvector<float> input_keys_storage;
lightweight_uvector<float> output_keys_storage;
lightweight_uvector<INDEX_T> input_values_storage;
lightweight_uvector<INDEX_T> output_values_storage;

search(raft::resources const& res,
search_params params,
Expand All @@ -629,16 +638,16 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
int64_t graph_degree,
uint32_t topk)
: base_type(res, params, dataset_desc, dim, graph_degree, topk),
result_indices(0, raft::resource::get_cuda_stream(res)),
result_distances(0, raft::resource::get_cuda_stream(res)),
parent_node_list(0, raft::resource::get_cuda_stream(res)),
topk_hint(0, raft::resource::get_cuda_stream(res)),
topk_workspace(0, raft::resource::get_cuda_stream(res)),
terminate_flag(raft::resource::get_cuda_stream(res)),
input_keys_storage(0, raft::resource::get_cuda_stream(res)),
output_keys_storage(0, raft::resource::get_cuda_stream(res)),
input_values_storage(0, raft::resource::get_cuda_stream(res)),
output_values_storage(0, raft::resource::get_cuda_stream(res))
result_indices(res),
result_distances(res),
parent_node_list(res),
topk_hint(res),
topk_workspace(res),
terminate_flag(res),
input_keys_storage(res),
output_keys_storage(res),
input_values_storage(res),
output_values_storage(res)
{
set_params(res);
}
Expand All @@ -662,7 +671,7 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
itopk_size, max_queries, result_buffer_size, utils::get_cuda_data_type<DATA_T>());
RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size);
topk_workspace.resize(topk_workspace_size, raft::resource::get_cuda_stream(res));

terminate_flag.resize(1, raft::resource::get_cuda_stream(res));
hashmap.resize(hashmap_size, raft::resource::get_cuda_stream(res));
}

Expand Down Expand Up @@ -847,7 +856,7 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
stream);

// termination (2)
if (iter + 1 >= min_iterations && terminate_flag.value(stream)) {
if (iter + 1 >= min_iterations && get_value(terminate_flag.data(), stream)) {
iter++;
break;
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ struct search_plan_impl : public search_plan_impl_base {
lightweight_uvector<INDEX_T> hashmap;
lightweight_uvector<uint32_t> num_executed_iterations; // device or managed?
lightweight_uvector<INDEX_T> dev_seed;
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc;
dataset_descriptor_host<DataT, IndexT, DistanceT> dataset_desc;

search_plan_impl(raft::resources const& res,
search_params params,
Expand Down

0 comments on commit 6b35b65

Please sign in to comment.