From 6b35b65923933e6396ae61322ce2e9b0772eea4a Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Wed, 6 Nov 2024 08:24:44 +0100 Subject: [PATCH] CAGRA tech debt: distance descriptor and workspace memory (#436) This PR introduces two changes: 1. Refactor `dataset_descriptor_host` to pass and cache it by value while keeping the state in a thread-safe object in a shared pointers. Before this, the descriptor host itself was kept in shared pointer in LRU cache and was passed by reference; as a result, it could in theory die due to cache eviction while still being used via references to it. 2. Adjust the temporary buffers to always use the workspace resource in all CAGRA algo implementations (as of now, only SINGLE_CTA algo does this; the PR expands the change to MULTI_CTA and MULTI_KERNEL). Both of the changes are required for effective use of stream-ordered dynamic batching https://github.com/rapidsai/cuvs/pull/261 (1. fixes crashes and 2. fixes thread-blocking behavior). Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/436 --- .../neighbors/detail/cagra/cagra_search.cuh | 4 +- .../detail/cagra/compute_distance.hpp | 77 +++++++++++++------ cpp/src/neighbors/detail/cagra/factory.cuh | 20 ++--- .../detail/cagra/search_multi_cta.cuh | 12 +-- .../detail/cagra/search_multi_kernel.cuh | 53 +++++++------ .../neighbors/detail/cagra/search_plan.cuh | 2 +- 6 files changed, 100 insertions(+), 68 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 95c158675..5778d85a6 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -151,7 +151,7 @@ void search_main(raft::resources const& res, if (auto* strided_dset = dynamic_cast*>(&index.data()); strided_dset != nullptr) { // Search using a plain (strided) row-major dataset - auto& desc = dataset_descriptor_init_with_cache( + auto desc = dataset_descriptor_init_with_cache( res, params, *strided_dset, index.metric()); search_main_core( res, params, desc, graph_internal, queries, neighbors, distances, sample_filter); @@ -161,7 +161,7 @@ void search_main(raft::resources const& res, RAFT_FAIL("FP32 VPQ dataset support is coming soon"); } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); vpq_dset != nullptr) { - auto& desc = dataset_descriptor_init_with_cache( + auto desc = dataset_descriptor_init_with_cache( res, params, *vpq_dset, index.metric()); search_main_core( res, params, desc, graph_internal, queries, neighbors, distances, sample_filter); diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp index 297eb1f55..7eb798459 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp @@ -31,8 +31,10 @@ #include #include +#include #include #include +#include #include #include @@ -232,52 +234,77 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t { */ template struct dataset_descriptor_host { - using dev_descriptor_t = dataset_descriptor_base_t; - using dd_ptr_t = std::shared_ptr; - using init_f = - std::tuple, size_t>; + using dev_descriptor_t = dataset_descriptor_base_t; uint32_t smem_ws_size_in_bytes = 0; uint32_t team_size = 0; + struct state { + using ready_t = std::tuple; + using init_f = + std::tuple, size_t>; + + std::mutex mutex; + std::atomic ready; // Not sure if std::holds_alternative is thread-safe + std::variant value; + + template + state(InitF init, size_t size) : ready{false}, value{std::make_tuple(init, size)} + { + } + + ~state() noexcept + { + if (std::holds_alternative(value)) { + auto& [ptr, stream] = std::get(value); + RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(ptr, stream)); + } + } + + void eval(rmm::cuda_stream_view stream) + { + std::lock_guard lock(mutex); + if (std::holds_alternative(value)) { + auto& [fun, size] = std::get(value); + dev_descriptor_t* ptr = nullptr; + RAFT_CUDA_TRY(cudaMallocAsync(&ptr, size, stream)); + fun(ptr, stream); + value = std::make_tuple(ptr, stream); + ready.store(true, std::memory_order_release); + } + } + + auto get(rmm::cuda_stream_view stream) -> dev_descriptor_t* + { + if (!ready.load(std::memory_order_acquire)) { eval(stream); } + return std::get<0>(std::get(value)); + } + }; + template dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init) - : value_{std::make_tuple(init, sizeof(DescriptorImpl))}, + : value_{std::make_shared(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, team_size{dd_host.team_size()} { } + dataset_descriptor_host() = default; + /** * Return the device pointer, possibly evaluating it in the given thread. */ [[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) const -> const dev_descriptor_t* { - if (std::holds_alternative(value_)) { value_ = eval(std::get(value_), stream); } - return std::get(value_).get(); + return value_->get(stream); } + [[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) -> dev_descriptor_t* { - if (std::holds_alternative(value_)) { value_ = eval(std::get(value_), stream); } - return std::get(value_).get(); + return value_->get(stream); } private: - mutable std::variant value_; - - static auto eval(init_f init, rmm::cuda_stream_view stream) -> dd_ptr_t - { - using raft::RAFT_NAME; - auto& [fun, size] = init; - dd_ptr_t dev_ptr{ - [stream, s = size]() { - dev_descriptor_t* p; - RAFT_CUDA_TRY(cudaMallocAsync(&p, s, stream)); - return p; - }(), - [stream](dev_descriptor_t* p) { RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(p, stream)); }}; - fun(dev_ptr.get(), stream); - return dev_ptr; - } + mutable std::shared_ptr value_; }; /** diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index abc907da5..e6e7ff64f 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -135,11 +135,9 @@ template struct store { /** Number of descriptors to cache. */ static constexpr size_t kDefaultSize = 100; - raft::cache::lru, - std::shared_ptr>> - value{kDefaultSize}; + raft::cache:: + lru, dataset_descriptor_host> + value{kDefaultSize}; }; } // namespace descriptor_cache @@ -159,20 +157,18 @@ auto dataset_descriptor_init_with_cache(const raft::resources& res, const cagra::search_params& params, const DatasetT& dataset, cuvs::distance::DistanceType metric) - -> const dataset_descriptor_host& + -> dataset_descriptor_host { - using desc_t = dataset_descriptor_host; - auto key = descriptor_cache::make_key(params, dataset, metric); + auto key = descriptor_cache::make_key(params, dataset, metric); auto& cache = raft::resource::get_custom_resource>(res) ->value; - std::shared_ptr desc{nullptr}; + dataset_descriptor_host desc; if (!cache.get(key, &desc)) { - desc = std::make_shared( - std::move(dataset_descriptor_init(params, dataset, metric))); + desc = dataset_descriptor_init(params, dataset, metric); cache.set(key, desc); } - return *desc; + return desc; } }; // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 0003f2495..ecfd856f1 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -93,10 +93,10 @@ struct search : public search_plan_impl intermediate_indices; - rmm::device_uvector intermediate_distances; + lightweight_uvector intermediate_indices; + lightweight_uvector intermediate_distances; size_t topk_workspace_size; - rmm::device_uvector topk_workspace; + lightweight_uvector topk_workspace; search(raft::resources const& res, search_params params, @@ -105,9 +105,9 @@ struct search : public search_plan_impl<<<1, 1, 0, cuda_stream>>>(host_ptr, dev_ptr); } +template +auto get_value(const T* const dev_ptr, cudaStream_t stream) -> T +{ + T value; + RAFT_CUDA_TRY(cudaMemcpyAsync(&value, dev_ptr, sizeof(value), cudaMemcpyDefault, stream)); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + return value; +} + // MAX_DATASET_DIM : must equal to or greater than dataset_dim template RAFT_KERNEL random_pickup_kernel( @@ -609,18 +618,18 @@ struct search : search_plan_impl { using base_type::num_seeds; size_t result_buffer_allocation_size; - rmm::device_uvector result_indices; // results_indices_buffer - rmm::device_uvector result_distances; // result_distances_buffer - rmm::device_uvector parent_node_list; - rmm::device_uvector topk_hint; - rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; - rmm::device_uvector topk_workspace; + lightweight_uvector result_indices; // results_indices_buffer + lightweight_uvector result_distances; // result_distances_buffer + lightweight_uvector parent_node_list; + lightweight_uvector topk_hint; + lightweight_uvector terminate_flag; // dev_terminate_flag, host_terminate_flag.; + lightweight_uvector topk_workspace; // temporary storage for _find_topk - rmm::device_uvector input_keys_storage; - rmm::device_uvector output_keys_storage; - rmm::device_uvector input_values_storage; - rmm::device_uvector output_values_storage; + lightweight_uvector input_keys_storage; + lightweight_uvector output_keys_storage; + lightweight_uvector input_values_storage; + lightweight_uvector output_values_storage; search(raft::resources const& res, search_params params, @@ -629,16 +638,16 @@ struct search : search_plan_impl { int64_t graph_degree, uint32_t topk) : base_type(res, params, dataset_desc, dim, graph_degree, topk), - result_indices(0, raft::resource::get_cuda_stream(res)), - result_distances(0, raft::resource::get_cuda_stream(res)), - parent_node_list(0, raft::resource::get_cuda_stream(res)), - topk_hint(0, raft::resource::get_cuda_stream(res)), - topk_workspace(0, raft::resource::get_cuda_stream(res)), - terminate_flag(raft::resource::get_cuda_stream(res)), - input_keys_storage(0, raft::resource::get_cuda_stream(res)), - output_keys_storage(0, raft::resource::get_cuda_stream(res)), - input_values_storage(0, raft::resource::get_cuda_stream(res)), - output_values_storage(0, raft::resource::get_cuda_stream(res)) + result_indices(res), + result_distances(res), + parent_node_list(res), + topk_hint(res), + topk_workspace(res), + terminate_flag(res), + input_keys_storage(res), + output_keys_storage(res), + input_values_storage(res), + output_values_storage(res) { set_params(res); } @@ -662,7 +671,7 @@ struct search : search_plan_impl { itopk_size, max_queries, result_buffer_size, utils::get_cuda_data_type()); RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); topk_workspace.resize(topk_workspace_size, raft::resource::get_cuda_stream(res)); - + terminate_flag.resize(1, raft::resource::get_cuda_stream(res)); hashmap.resize(hashmap_size, raft::resource::get_cuda_stream(res)); } @@ -847,7 +856,7 @@ struct search : search_plan_impl { stream); // termination (2) - if (iter + 1 >= min_iterations && terminate_flag.value(stream)) { + if (iter + 1 >= min_iterations && get_value(terminate_flag.data(), stream)) { iter++; break; } diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index f23b96631..99254aa50 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -151,7 +151,7 @@ struct search_plan_impl : public search_plan_impl_base { lightweight_uvector hashmap; lightweight_uvector num_executed_iterations; // device or managed? lightweight_uvector dev_seed; - const dataset_descriptor_host& dataset_desc; + dataset_descriptor_host dataset_desc; search_plan_impl(raft::resources const& res, search_params params,