diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 95c158675..5778d85a6 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -151,7 +151,7 @@ void search_main(raft::resources const& res, if (auto* strided_dset = dynamic_cast*>(&index.data()); strided_dset != nullptr) { // Search using a plain (strided) row-major dataset - auto& desc = dataset_descriptor_init_with_cache( + auto desc = dataset_descriptor_init_with_cache( res, params, *strided_dset, index.metric()); search_main_core( res, params, desc, graph_internal, queries, neighbors, distances, sample_filter); @@ -161,7 +161,7 @@ void search_main(raft::resources const& res, RAFT_FAIL("FP32 VPQ dataset support is coming soon"); } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); vpq_dset != nullptr) { - auto& desc = dataset_descriptor_init_with_cache( + auto desc = dataset_descriptor_init_with_cache( res, params, *vpq_dset, index.metric()); search_main_core( res, params, desc, graph_internal, queries, neighbors, distances, sample_filter); diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp index 297eb1f55..7eb798459 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp @@ -31,8 +31,10 @@ #include #include +#include #include #include +#include #include #include @@ -232,52 +234,77 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t { */ template struct dataset_descriptor_host { - using dev_descriptor_t = dataset_descriptor_base_t; - using dd_ptr_t = std::shared_ptr; - using init_f = - std::tuple, size_t>; + using dev_descriptor_t = dataset_descriptor_base_t; uint32_t smem_ws_size_in_bytes = 0; uint32_t team_size = 0; + struct state { + using ready_t = std::tuple; + using init_f = + std::tuple, size_t>; + + std::mutex mutex; + std::atomic ready; // Not sure if std::holds_alternative is thread-safe + std::variant value; + + template + state(InitF init, size_t size) : ready{false}, value{std::make_tuple(init, size)} + { + } + + ~state() noexcept + { + if (std::holds_alternative(value)) { + auto& [ptr, stream] = std::get(value); + RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(ptr, stream)); + } + } + + void eval(rmm::cuda_stream_view stream) + { + std::lock_guard lock(mutex); + if (std::holds_alternative(value)) { + auto& [fun, size] = std::get(value); + dev_descriptor_t* ptr = nullptr; + RAFT_CUDA_TRY(cudaMallocAsync(&ptr, size, stream)); + fun(ptr, stream); + value = std::make_tuple(ptr, stream); + ready.store(true, std::memory_order_release); + } + } + + auto get(rmm::cuda_stream_view stream) -> dev_descriptor_t* + { + if (!ready.load(std::memory_order_acquire)) { eval(stream); } + return std::get<0>(std::get(value)); + } + }; + template dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init) - : value_{std::make_tuple(init, sizeof(DescriptorImpl))}, + : value_{std::make_shared(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, team_size{dd_host.team_size()} { } + dataset_descriptor_host() = default; + /** * Return the device pointer, possibly evaluating it in the given thread. */ [[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) const -> const dev_descriptor_t* { - if (std::holds_alternative(value_)) { value_ = eval(std::get(value_), stream); } - return std::get(value_).get(); + return value_->get(stream); } + [[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) -> dev_descriptor_t* { - if (std::holds_alternative(value_)) { value_ = eval(std::get(value_), stream); } - return std::get(value_).get(); + return value_->get(stream); } private: - mutable std::variant value_; - - static auto eval(init_f init, rmm::cuda_stream_view stream) -> dd_ptr_t - { - using raft::RAFT_NAME; - auto& [fun, size] = init; - dd_ptr_t dev_ptr{ - [stream, s = size]() { - dev_descriptor_t* p; - RAFT_CUDA_TRY(cudaMallocAsync(&p, s, stream)); - return p; - }(), - [stream](dev_descriptor_t* p) { RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(p, stream)); }}; - fun(dev_ptr.get(), stream); - return dev_ptr; - } + mutable std::shared_ptr value_; }; /** diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index abc907da5..e6e7ff64f 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -135,11 +135,9 @@ template struct store { /** Number of descriptors to cache. */ static constexpr size_t kDefaultSize = 100; - raft::cache::lru, - std::shared_ptr>> - value{kDefaultSize}; + raft::cache:: + lru, dataset_descriptor_host> + value{kDefaultSize}; }; } // namespace descriptor_cache @@ -159,20 +157,18 @@ auto dataset_descriptor_init_with_cache(const raft::resources& res, const cagra::search_params& params, const DatasetT& dataset, cuvs::distance::DistanceType metric) - -> const dataset_descriptor_host& + -> dataset_descriptor_host { - using desc_t = dataset_descriptor_host; - auto key = descriptor_cache::make_key(params, dataset, metric); + auto key = descriptor_cache::make_key(params, dataset, metric); auto& cache = raft::resource::get_custom_resource>(res) ->value; - std::shared_ptr desc{nullptr}; + dataset_descriptor_host desc; if (!cache.get(key, &desc)) { - desc = std::make_shared( - std::move(dataset_descriptor_init(params, dataset, metric))); + desc = dataset_descriptor_init(params, dataset, metric); cache.set(key, desc); } - return *desc; + return desc; } }; // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 0003f2495..ecfd856f1 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -93,10 +93,10 @@ struct search : public search_plan_impl intermediate_indices; - rmm::device_uvector intermediate_distances; + lightweight_uvector intermediate_indices; + lightweight_uvector intermediate_distances; size_t topk_workspace_size; - rmm::device_uvector topk_workspace; + lightweight_uvector topk_workspace; search(raft::resources const& res, search_params params, @@ -105,9 +105,9 @@ struct search : public search_plan_impl<<<1, 1, 0, cuda_stream>>>(host_ptr, dev_ptr); } +template +auto get_value(const T* const dev_ptr, cudaStream_t stream) -> T +{ + T value; + RAFT_CUDA_TRY(cudaMemcpyAsync(&value, dev_ptr, sizeof(value), cudaMemcpyDefault, stream)); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + return value; +} + // MAX_DATASET_DIM : must equal to or greater than dataset_dim template RAFT_KERNEL random_pickup_kernel( @@ -609,18 +618,18 @@ struct search : search_plan_impl { using base_type::num_seeds; size_t result_buffer_allocation_size; - rmm::device_uvector result_indices; // results_indices_buffer - rmm::device_uvector result_distances; // result_distances_buffer - rmm::device_uvector parent_node_list; - rmm::device_uvector topk_hint; - rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; - rmm::device_uvector topk_workspace; + lightweight_uvector result_indices; // results_indices_buffer + lightweight_uvector result_distances; // result_distances_buffer + lightweight_uvector parent_node_list; + lightweight_uvector topk_hint; + lightweight_uvector terminate_flag; // dev_terminate_flag, host_terminate_flag.; + lightweight_uvector topk_workspace; // temporary storage for _find_topk - rmm::device_uvector input_keys_storage; - rmm::device_uvector output_keys_storage; - rmm::device_uvector input_values_storage; - rmm::device_uvector output_values_storage; + lightweight_uvector input_keys_storage; + lightweight_uvector output_keys_storage; + lightweight_uvector input_values_storage; + lightweight_uvector output_values_storage; search(raft::resources const& res, search_params params, @@ -629,16 +638,16 @@ struct search : search_plan_impl { int64_t graph_degree, uint32_t topk) : base_type(res, params, dataset_desc, dim, graph_degree, topk), - result_indices(0, raft::resource::get_cuda_stream(res)), - result_distances(0, raft::resource::get_cuda_stream(res)), - parent_node_list(0, raft::resource::get_cuda_stream(res)), - topk_hint(0, raft::resource::get_cuda_stream(res)), - topk_workspace(0, raft::resource::get_cuda_stream(res)), - terminate_flag(raft::resource::get_cuda_stream(res)), - input_keys_storage(0, raft::resource::get_cuda_stream(res)), - output_keys_storage(0, raft::resource::get_cuda_stream(res)), - input_values_storage(0, raft::resource::get_cuda_stream(res)), - output_values_storage(0, raft::resource::get_cuda_stream(res)) + result_indices(res), + result_distances(res), + parent_node_list(res), + topk_hint(res), + topk_workspace(res), + terminate_flag(res), + input_keys_storage(res), + output_keys_storage(res), + input_values_storage(res), + output_values_storage(res) { set_params(res); } @@ -662,7 +671,7 @@ struct search : search_plan_impl { itopk_size, max_queries, result_buffer_size, utils::get_cuda_data_type()); RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); topk_workspace.resize(topk_workspace_size, raft::resource::get_cuda_stream(res)); - + terminate_flag.resize(1, raft::resource::get_cuda_stream(res)); hashmap.resize(hashmap_size, raft::resource::get_cuda_stream(res)); } @@ -847,7 +856,7 @@ struct search : search_plan_impl { stream); // termination (2) - if (iter + 1 >= min_iterations && terminate_flag.value(stream)) { + if (iter + 1 >= min_iterations && get_value(terminate_flag.data(), stream)) { iter++; break; } diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index f23b96631..99254aa50 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -151,7 +151,7 @@ struct search_plan_impl : public search_plan_impl_base { lightweight_uvector hashmap; lightweight_uvector num_executed_iterations; // device or managed? lightweight_uvector dev_seed; - const dataset_descriptor_host& dataset_desc; + dataset_descriptor_host dataset_desc; search_plan_impl(raft::resources const& res, search_params params,