From 6762fe540b19ea58904786c4cacf57ca5f0d9695 Mon Sep 17 00:00:00 2001 From: Vivek Narang <123010842+narangvivek10@users.noreply.github.com> Date: Thu, 4 Jan 2024 14:17:00 -0500 Subject: [PATCH 01/10] Remove hardcoded limit in `print_results` function (#2080) The `print_results` function here is currently hardcoded to print only 2 results irrespective of the number of queries. A better way here could be to replace the hardcoded limit and allow printing results for the actual number of queries. Authors: - Vivek Narang (https://github.com/narangvivek10) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2080 --- cpp/include/raft/neighbors/detail/refine_device.cuh | 2 +- cpp/template/src/common.cuh | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh index 337318f791..5c9f1459e7 100644 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/template/src/common.cuh b/cpp/template/src/common.cuh index 0b72d3bf3b..c2cb15bcf3 100644 --- a/cpp/template/src/common.cuh +++ b/cpp/template/src/common.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ void generate_dataset(raft::device_resources const& dev_resources, 1.0f); } -// Copy the results to host and print a few samples +// Copy the results to host and print them template void print_results(raft::device_resources const& dev_resources, raft::device_matrix_view neighbors, @@ -61,7 +61,7 @@ void print_results(raft::device_resources const& dev_resources, // We need to sync the stream before accessing the data. raft::resource::sync_stream(dev_resources, stream); - for (int query_id = 0; query_id < 2; query_id++) { + for (int query_id = 0; query_id < neighbors.extent(0); query_id++) { std::cout << "Query " << query_id << " neighbor indices: "; raft::print_host_vector("", &neighbors_host(query_id, 0), topk, std::cout); std::cout << "Query " << query_id << " neighbor distances: "; From 3b88d170e0402901a836960ff534faf45d1828fa Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Tue, 9 Jan 2024 15:21:01 +0900 Subject: [PATCH 02/10] Improve parallelism of refine host (#2059) This PR addresses https://github.com/rapidsai/raft/issues/2058 by changing the thread parallelism method. In the first half of the `refine` process, the distance calculation is performed on all candidate vectors, i.e., the number of queries * the original top-k vectors. Since the distance calculations for each vector can be performed independently, this part is thread-parallelized assuming that maximum parallelism is the number of queries * original top-k. This means that even if the number of queries is 1, this part can be executed in thread parallel. On the other hand, the second half of the `refine` process, the so-called top-k calculation, can be performed independently for each query, but it is difficult to thread parallelize the calculation for a given query, Therefore, this part is parallelized assuming the maximum parallelism is the number of queries, as in the current implementation. Authors: - Akira Naruse (https://github.com/anaruse) - Corey J. Nolet (https://github.com/cjnolet) - William Hicks (https://github.com/wphicks) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2059 --- .../raft/neighbors/detail/refine_host-inl.hpp | 55 ++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp index 14c53a4699..a54525f3e6 100644 --- a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp +++ b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -43,6 +44,58 @@ template %zu)", n_queries, orig_k, refined_k); auto suggested_n_threads = std::max(1, std::min(omp_get_num_procs(), omp_get_max_threads())); + + // If the number of queries is small, separate the distance calculation and + // the top-k calculation into separate loops, and apply finer-grained thread + // parallelism to the distance calculation loop. + if (n_queries < size_t(suggested_n_threads)) { + std::vector>> refined_pairs( + n_queries, std::vector>(orig_k)); + + // For efficiency, each thread should read a certain amount of array + // elements. The number of threads for distance computation is determined + // taking this into account. + auto n_elements = std::max(size_t(512), dim); + auto max_n_threads = raft::div_rounding_up_safe(n_queries * orig_k * dim, n_elements); + auto suggested_n_threads_for_distance = std::min(size_t(suggested_n_threads), max_n_threads); + + // The max number of threads for topk computation is the number of queries. + auto suggested_n_threads_for_topk = std::min(size_t(suggested_n_threads), n_queries); + + // Compute the refined distance using original dataset vectors +#pragma omp parallel for collapse(2) num_threads(suggested_n_threads_for_distance) + for (size_t i = 0; i < n_queries; i++) { + for (size_t j = 0; j < orig_k; j++) { + const DataT* query = queries.data_handle() + dim * i; + IdxT id = neighbor_candidates(i, j); + DistanceT distance = 0.0; + if (static_cast(id) >= n_rows) { + distance = std::numeric_limits::max(); + } else { + const DataT* row = dataset.data_handle() + dim * id; + for (size_t k = 0; k < dim; k++) { + distance += DC::template eval(query[k], row[k]); + } + } + refined_pairs[i][j] = std::make_tuple(distance, id); + } + } + + // Sort the query neighbors by their refined distances +#pragma omp parallel for num_threads(suggested_n_threads_for_topk) + for (size_t i = 0; i < n_queries; i++) { + std::sort(refined_pairs[i].begin(), refined_pairs[i].end()); + // Store first refined_k neighbors + for (size_t j = 0; j < refined_k; j++) { + indices(i, j) = std::get<1>(refined_pairs[i][j]); + if (distances.data_handle() != nullptr) { + distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j])); + } + } + } + return; + } + if (size_t(suggested_n_threads) > n_queries) { suggested_n_threads = n_queries; } #pragma omp parallel num_threads(suggested_n_threads) From 1484a03fc3cef56d88f0491778b3afdda9a9cc8e Mon Sep 17 00:00:00 2001 From: Micka Date: Tue, 9 Jan 2024 19:30:47 +0100 Subject: [PATCH 03/10] Fix `max_queries` for CAGRA (#2081) Fix for #2072: CAGRA search is launching a thread per query in single-CTA. The maximum number of thread is 65535 so the `max_queries` auto selection should be bounded to this number. Authors: - Micka (https://github.com/lowener) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2081 --- .../neighbors/detail/cagra/cagra_search.cuh | 7 ++-- .../neighbors/detail/cagra/search_plan.cuh | 4 +-- notebooks/utils.py | 5 ++- .../pylibraft/neighbors/cagra/cagra.pyx | 34 +++++++++---------- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 23a966d41f..41a43c9bce 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -118,7 +118,10 @@ void search_main(raft::resources const& res, RAFT_EXPECTS(queries.extent(1) == index.dim(), "Queries and index dim must match"); const uint32_t topk = neighbors.extent(1); - if (params.max_queries == 0) { params.max_queries = queries.extent(0); } + cudaDeviceProp deviceProp = resource::get_device_properties(res); + if (params.max_queries == 0) { + params.max_queries = std::min(queries.extent(0), deviceProp.maxGridSize[1]); + } common::nvtx::range fun_scope( "cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim()); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index f57b776ccf..f2f51617f4 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -147,7 +147,7 @@ struct search_plan_impl : public search_plan_impl_base { // defines hash_bitlen, small_hash_bitlen, small_hash_reset interval, hash_size inline void calc_hashmap_params(raft::resources const& res) { - // for multipel CTA search + // for multiple CTA search uint32_t mc_num_cta_per_query = 0; uint32_t mc_search_width = 0; uint32_t mc_itopk_size = 0; diff --git a/notebooks/utils.py b/notebooks/utils.py index 1c2e44a6ae..311efc98bc 100644 --- a/notebooks/utils.py +++ b/notebooks/utils.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,7 +73,7 @@ def benchmark_runs(self): self.timings.append(t1 - t0) -def load_dataset(dataset_url, work_folder=None): +def load_dataset(dataset_url="http://ann-benchmarks.com/sift-128-euclidean.hdf5", work_folder=None): """Download dataset from url. It is expected that the dataset contains a hdf5 file in ann-benchmarks format Parameters @@ -82,7 +82,6 @@ def load_dataset(dataset_url, work_folder=None): work_folder name of the local folder to store the dataset """ - dataset_url = "http://ann-benchmarks.com/sift-128-euclidean.hdf5" dataset_filename = dataset_url.split("/")[-1] # We'll need to load store some data in this tutorial diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index c11d933b27..c19faa826d 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -91,7 +91,7 @@ from pylibraft.neighbors.common cimport _get_metric_string cdef class IndexParams: - """" + """ Parameters to build index for CAGRA nearest neighbor search Parameters @@ -104,13 +104,13 @@ cdef class IndexParams: graph_degree : int, default = 64 - build_algo: string denoting the graph building algorithm to use, + build_algo: string denoting the graph building algorithm to use, \ default = "ivf_pq" Valid values for algo: ["ivf_pq", "nn_descent"], where - - ivf_pq will use the IVF-PQ algorithm for building the knn graph - - nn_descent (experimental) will use the NN-Descent algorithm for - building the knn graph. It is expected to be generally - faster than ivf_pq. + - ivf_pq will use the IVF-PQ algorithm for building the knn graph + - nn_descent (experimental) will use the NN-Descent algorithm for + building the knn graph. It is expected to be generally + faster than ivf_pq. """ cdef c_cagra.index_params params @@ -501,10 +501,10 @@ cdef class SearchParams: Upper limit of search iterations. Auto select when 0. algo: string denoting the search algorithm to use, default = "auto" Valid values for algo: ["auto", "single_cta", "multi_cta"], where - - auto will automatically select the best value based on query size - - single_cta is better when query contains larger number of - vectors (e.g >10) - - multi_cta is better when query contains only a few vectors + - auto will automatically select the best value based on query size + - single_cta is better when query contains larger number of + vectors (e.g >10) + - multi_cta is better when query contains only a few vectors team_size: int, default = 0 Number of threads used to calculate a single distance. 4, 8, 16, or 32. @@ -516,13 +516,13 @@ cdef class SearchParams: thread_block_size: int, default = 0 Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. - hashmap_mode: string denoting the type of hash map to use. It's - usually better to allow the algorithm to select this value., - default = "auto" + hashmap_mode: string denoting the type of hash map to use. + It's usually better to allow the algorithm to select this value, + default = "auto". Valid values for hashmap_mode: ["auto", "small", "hash"], where - - auto will automatically select the best value based on algo - - small will use the small shared memory hash table with resetting. - - hash will use a single hash table in global memory. + - auto will automatically select the best value based on algo + - small will use the small shared memory hash table with resetting. + - hash will use a single hash table in global memory. hashmap_min_bitlen: int, default = 0 Upper limit of hashmap fill rate. More than 0.1, less than 0.9. hashmap_max_fill_rate: float, default = 0.5 From 26d310b6111febe2f40ec3622014021d8b6d660a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 9 Jan 2024 19:09:51 -0800 Subject: [PATCH 04/10] Add public enum for select-k algorithm selection (#2046) Add an enum that controls which select-k algorithm is used. This takes the enum that was in the raft_internal and exposes in the public api. This lets users pick which select algorithm they want to use directly Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2046 --- cpp/bench/prims/matrix/select_k.cu | 44 ++++--- .../raft/matrix/detail/select_k-ext.cuh | 9 +- .../raft/matrix/detail/select_k-inl.cuh | 90 ++++++++----- cpp/include/raft/matrix/select_k.cuh | 11 +- cpp/include/raft/matrix/select_k_types.hpp | 101 +++++++++++++++ .../raft_internal/matrix/select_k.cuh | 121 +----------------- .../select_k/generate_heuristic.ipynb | 20 +-- .../matrix/detail/select_k_double_int64_t.cu | 5 +- .../matrix/detail/select_k_double_uint32_t.cu | 5 +- cpp/src/matrix/detail/select_k_float_int32.cu | 5 +- .../matrix/detail/select_k_float_int64_t.cu | 5 +- .../matrix/detail/select_k_float_uint32_t.cu | 5 +- .../matrix/detail/select_k_half_int64_t.cu | 5 +- .../matrix/detail/select_k_half_uint32_t.cu | 5 +- cpp/test/matrix/select_k.cu | 52 ++++---- cpp/test/matrix/select_k.cuh | 84 ++++++------ cpp/test/matrix/select_large_k.cu | 8 +- 17 files changed, 306 insertions(+), 269 deletions(-) create mode 100644 cpp/include/raft/matrix/select_k_types.hpp diff --git a/cpp/bench/prims/matrix/select_k.cu b/cpp/bench/prims/matrix/select_k.cu index 324d3aef84..6364ab17da 100644 --- a/cpp/bench/prims/matrix/select_k.cu +++ b/cpp/bench/prims/matrix/select_k.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,7 +52,7 @@ struct replace_with_mask { } }; -template +template struct selection : public fixture { explicit selection(const select::params& p) : fixture(p.use_memory_pool), @@ -110,16 +110,24 @@ struct selection : public fixture { int iter = 0; loop_on_state(state, [&iter, this]() { common::nvtx::range lap_scope("lap-", iter++); - select::select_k_impl(handle, - Algo, - in_dists_.data(), - params_.use_index_input ? in_ids_.data() : NULL, - params_.batch_size, - params_.len, - params_.k, - out_dists_.data(), - out_ids_.data(), - params_.select_min); + + std::optional> in_ids_view; + if (params_.use_index_input) { + in_ids_view = raft::make_device_matrix_view( + in_ids_.data(), params_.batch_size, params_.len); + } + + matrix::select_k(handle, + raft::make_device_matrix_view( + in_dists_.data(), params_.batch_size, params_.len), + in_ids_view, + raft::make_device_matrix_view( + out_dists_.data(), params_.batch_size, params_.k), + raft::make_device_matrix_view( + out_ids_.data(), params_.batch_size, params_.k), + params_.select_min, + false, + Algo); }); } catch (raft::exception& e) { state.SkipWithError(e.what()); @@ -213,13 +221,13 @@ const std::vector kInputs{ {1000, 10000, 256, true, false, false, true, 0.999}, }; -#define SELECTION_REGISTER(KeyT, IdxT, A) \ - namespace BENCHMARK_PRIVATE_NAME(selection) { \ - using SelectK = selection; \ - RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ +#define SELECTION_REGISTER(KeyT, IdxT, A) \ + namespace BENCHMARK_PRIVATE_NAME(selection) { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ } -SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT +SELECTION_REGISTER(float, uint32_t, kAuto); // NOLINT SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT @@ -252,7 +260,7 @@ SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT // register other benchmarks #define SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, A, input) \ { \ - using SelectK = selection; \ + using SelectK = selection; \ std::stringstream name; \ name << "SelectKDataset/" << #KeyT "/" #IdxT "/" #A << "/" << input.batch_size << "/" \ << input.len << "/" << input.k << "/" << input.use_index_input << "/" \ diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 870f0c3240..dfdbfa2d07 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include // uint32_t #include // __half #include +#include #include // RAFT_EXPLICIT #include // rmm:cuda_stream_view #include // rmm::mr::device_memory_resource @@ -38,7 +39,8 @@ void select_k(raft::resources const& handle, IdxT* out_idx, bool select_min, rmm::mr::device_memory_resource* mr = nullptr, - bool sorted = false) RAFT_EXPLICIT; + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -54,7 +56,8 @@ void select_k(raft::resources const& handle, IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 63aeff2f1c..0a6f292e68 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -1,5 +1,6 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +24,7 @@ #include #include #include +#include #include #include @@ -31,10 +33,6 @@ namespace raft::matrix::detail { -// this is a subset of algorithms, chosen by running the algorithm_selection -// notebook in cpp/scripts/heuristics/select_k -enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bitsExtraPass }; - /** * Predict the fastest select_k algorithm based on the number of rows/cols/k * @@ -47,31 +45,31 @@ enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bit * 'generate_heuristic' notebook there will replace the body of this function * with the latest learned heuristic */ -inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k) +inline SelectAlgo choose_select_k_algorithm(size_t rows, size_t cols, int k) { if (k > 256) { if (cols > 16862) { if (rows > 1020) { - return Algo::kRadix11bitsExtraPass; + return SelectAlgo::kRadix11bitsExtraPass; } else { - return Algo::kRadix11bits; + return SelectAlgo::kRadix11bits; } } else { - return Algo::kRadix11bitsExtraPass; + return SelectAlgo::kRadix11bitsExtraPass; } } else { if (k > 2) { if (cols > 22061) { - return Algo::kWarpDistributedShm; + return SelectAlgo::kWarpDistributedShm; } else { if (rows > 198) { - return Algo::kWarpDistributedShm; + return SelectAlgo::kWarpDistributedShm; } else { - return Algo::kWarpImmediate; + return SelectAlgo::kWarpImmediate; } } } else { - return Algo::kWarpImmediate; + return SelectAlgo::kWarpImmediate; } } } @@ -239,31 +237,48 @@ void select_k(raft::resources const& handle, IdxT* out_idx, bool select_min, rmm::mr::device_memory_resource* mr = nullptr, - bool sorted = false) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) { common::nvtx::range fun_scope( "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - auto stream = raft::resource::get_cuda_stream(handle); - auto algo = choose_select_k_algorithm(batch_size, len, k); + if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); } + + auto stream = raft::resource::get_cuda_stream(handle); switch (algo) { - case Algo::kRadix11bits: - case Algo::kRadix11bitsExtraPass: { - bool fused_last_filter = algo == Algo::kRadix11bits; - detail::select::radix::select_k(in_val, - in_idx, - batch_size, - len, - k, - out_val, - out_idx, - select_min, - fused_last_filter, - stream, - mr); + case SelectAlgo::kRadix8bits: + case SelectAlgo::kRadix11bits: + case SelectAlgo::kRadix11bitsExtraPass: { + if (algo == SelectAlgo::kRadix8bits) { + detail::select::radix::select_k(in_val, + in_idx, + batch_size, + len, + k, + out_val, + out_idx, + select_min, + true, // fused_last_filter + stream, + mr); + } else { + bool fused_last_filter = algo == SelectAlgo::kRadix11bits; + detail::select::radix::select_k(in_val, + in_idx, + batch_size, + len, + k, + out_val, + out_idx, + select_min, + fused_last_filter, + stream, + mr); + } if (sorted) { auto offsets = raft::make_device_vector(handle, (IdxT)(batch_size + 1)); @@ -283,14 +298,25 @@ void select_k(raft::resources const& handle, } return; } - case Algo::kWarpDistributedShm: + case SelectAlgo::kWarpDistributed: + return detail::select::warpsort:: + select_k_impl( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + case SelectAlgo::kWarpDistributedShm: return detail::select::warpsort:: select_k_impl( in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); - case Algo::kWarpImmediate: + case SelectAlgo::kWarpAuto: + return detail::select::warpsort::select_k( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + case SelectAlgo::kWarpImmediate: return detail::select::warpsort:: select_k_impl( in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + case SelectAlgo::kWarpFiltered: + return detail::select::warpsort:: + select_k_impl( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); default: RAFT_FAIL("K-selection Algorithm not supported."); } } diff --git a/cpp/include/raft/matrix/select_k.cuh b/cpp/include/raft/matrix/select_k.cuh index 37a36cbf6b..92d7db006d 100644 --- a/cpp/include/raft/matrix/select_k.cuh +++ b/cpp/include/raft/matrix/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -76,6 +77,8 @@ namespace raft::matrix { * whether to select k smallest (true) or largest (false) keys. * @param[in] sorted * whether to make sure selected pairs are sorted by value + * @param[in] algo + * the selection algorithm to use */ template void select_k(raft::resources const& handle, @@ -84,7 +87,8 @@ void select_k(raft::resources const& handle, raft::device_matrix_view out_val, raft::device_matrix_view out_idx, bool select_min, - bool sorted = false) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) { RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits::max()), "output k must fit the int type."); @@ -109,7 +113,8 @@ void select_k(raft::resources const& handle, out_idx.data_handle(), select_min, nullptr, - sorted); + sorted, + algo); } /** @} */ // end of group select_k diff --git a/cpp/include/raft/matrix/select_k_types.hpp b/cpp/include/raft/matrix/select_k_types.hpp new file mode 100644 index 0000000000..f001f91770 --- /dev/null +++ b/cpp/include/raft/matrix/select_k_types.hpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft::matrix { + +/** + * @defgroup select_k Batched-select k smallest or largest key/values + * @{ + */ + +/** + * @brief Algorithm used to select the k largest neighbors + * + * Details about how the the select-k algorithms in RAFT work can be found in the + * paper "Parallel Top-K Algorithms on GPU: A Comprehensive Study and New Methods" + * https://doi.org/10.1145/3581784.3607062. The kRadix* variants below correspond + * to the 'Air Top-k' algorithm described in the paper, and the kWarp* variants + * correspond to the 'GridSelect' algorithm. + */ +enum class SelectAlgo : uint8_t { + /** Automatically pick the select-k algorithm based off the input dimensions and k value */ + kAuto = 0, + /** Radix Select using 8 bits per pass */ + kRadix8bits = 1, + /** Radix Select using 11 bits per pass, fusing the last filter step */ + kRadix11bits = 2, + /** Radix Select using 11 bits per pass, without fusing the last filter step */ + kRadix11bitsExtraPass = 3, + /** + * Automatically switches between the kWarpImmediate and kWarpFiltered algorithms + * based off of input size + */ + kWarpAuto = 4, + /** + * This version of warp_sort adds every input element into the intermediate sorting + * buffer, and thus does the sorting step every `Capacity` input elements. + * + * This implementation is preferred for very small len values. + */ + kWarpImmediate = 5, + /** + * This version of warp_sort compares each input element against the current + * estimate of k-th value before adding it to the intermediate sorting buffer. + * This makes the algorithm do less sorting steps for long input sequences + * at the cost of extra checks on each step. + * + * This implementation is preferred for large len values. + */ + kWarpFiltered = 6, + /** + * This version of warp_sort compares each input element against the current + * estimate of k-th value before adding it to the intermediate sorting buffer. + * In contrast to `warp_sort_filtered`, it keeps one distributed buffer for + * all threads in a warp (independently of the subwarp size), which makes its flushing less often. + */ + kWarpDistributed = 7, + /** + * The same as `warp_sort_distributed`, but keeps the temporary value and index buffers + * in the given external pointers (normally, a shared memory pointer should be passed in). + */ + kWarpDistributedShm = 8, +}; + +inline auto operator<<(std::ostream& os, const SelectAlgo& algo) -> std::ostream& +{ + auto underlying_value = static_cast::type>(algo); + + switch (algo) { + case SelectAlgo::kAuto: return os << "kAuto=" << underlying_value; + case SelectAlgo::kRadix8bits: return os << "kRadix8bits=" << underlying_value; + case SelectAlgo::kRadix11bits: return os << "kRadix11bits=" << underlying_value; + case SelectAlgo::kRadix11bitsExtraPass: + return os << "kRadix11bitsExtraPass=" << underlying_value; + case SelectAlgo::kWarpAuto: return os << "kWarpAuto=" << underlying_value; + case SelectAlgo::kWarpImmediate: return os << "kWarpImmediate=" << underlying_value; + case SelectAlgo::kWarpFiltered: return os << "kWarpFiltered=" << underlying_value; + case SelectAlgo::kWarpDistributed: return os << "kWarpDistributed=" << underlying_value; + case SelectAlgo::kWarpDistributedShm: return os << "kWarpDistributedShm=" << underlying_value; + default: throw std::invalid_argument("invalid value for SelectAlgo"); + } +} + +/** @} */ // end of group select_k + +} // namespace raft::matrix diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index 93095ff82e..b899978f1c 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,123 +47,4 @@ inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& os << "}"; return os; } - -enum class Algo { - kPublicApi, - kRadix8bits, - kRadix11bits, - kRadix11bitsExtraPass, - kWarpAuto, - kWarpImmediate, - kWarpFiltered, - kWarpDistributed, - kWarpDistributedShm, -}; - -inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream& -{ - switch (algo) { - case Algo::kPublicApi: return os << "kPublicApi"; - case Algo::kRadix8bits: return os << "kRadix8bits"; - case Algo::kRadix11bits: return os << "kRadix11bits"; - case Algo::kRadix11bitsExtraPass: return os << "kRadix11bitsExtraPass"; - case Algo::kWarpAuto: return os << "kWarpAuto"; - case Algo::kWarpImmediate: return os << "kWarpImmediate"; - case Algo::kWarpFiltered: return os << "kWarpFiltered"; - case Algo::kWarpDistributed: return os << "kWarpDistributed"; - case Algo::kWarpDistributedShm: return os << "kWarpDistributedShm"; - default: return os << "unknown enum value"; - } -} - -template -void select_k_impl(const resources& handle, - const Algo& algo, - const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min) -{ - auto stream = resource::get_cuda_stream(handle); - switch (algo) { - case Algo::kPublicApi: { - auto in_extent = make_extents(batch_size, len); - auto out_extent = make_extents(batch_size, k); - auto in_span = make_mdspan(in, in_extent); - auto in_idx_span = - make_mdspan(in_idx, in_extent); - auto out_span = make_mdspan(out, out_extent); - auto out_idx_span = make_mdspan(out_idx, out_extent); - if (in_idx == nullptr) { - // NB: std::nullopt prevents automatic inference of the template parameters. - return matrix::select_k( - handle, in_span, std::nullopt, out_span, out_idx_span, select_min, true); - } else { - return matrix::select_k(handle, - in_span, - std::make_optional(in_idx_span), - out_span, - out_idx_span, - select_min, - true); - } - } - case Algo::kRadix8bits: - return detail::select::radix::select_k(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - true, // fused_last_filter - stream); - case Algo::kRadix11bits: - return detail::select::radix::select_k(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - true, // fused_last_filter - stream); - case Algo::kRadix11bitsExtraPass: - return detail::select::radix::select_k(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - false, // fused_last_filter - stream); - case Algo::kWarpAuto: - return detail::select::warpsort::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - case Algo::kWarpImmediate: - return detail::select::warpsort:: - select_k_impl( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - case Algo::kWarpFiltered: - return detail::select::warpsort:: - select_k_impl( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - case Algo::kWarpDistributed: - return detail::select::warpsort:: - select_k_impl( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - case Algo::kWarpDistributedShm: - return detail::select::warpsort:: - select_k_impl( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - } -} } // namespace raft::matrix::select diff --git a/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb b/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb index 50bc12556a..f764d2f88f 100644 --- a/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb +++ b/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb @@ -405,31 +405,31 @@ "name": "stdout", "output_type": "stream", "text": [ - "inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)\n", + "inline SelectAlgo choose_select_k_algorithm(size_t rows, size_t cols, int k)\n", "{\n", " if (k > 256) {\n", " if (cols > 16862) {\n", " if (rows > 1020) {\n", - " return Algo::kRadix11bitsExtraPass;\n", + " return SelectAlgo::kRadix11bitsExtraPass;\n", " } else {\n", - " return Algo::kRadix11bits;\n", + " return SelectAlgo::kRadix11bits;\n", " }\n", " } else {\n", - " return Algo::kRadix11bitsExtraPass;\n", + " return SelectAlgo::kRadix11bitsExtraPass;\n", " }\n", " } else {\n", " if (k > 2) {\n", " if (cols > 22061) {\n", - " return Algo::kWarpDistributedShm;\n", + " return SelectAlgo::kWarpDistributedShm;\n", " } else {\n", " if (rows > 198) {\n", - " return Algo::kWarpDistributedShm;\n", + " return SelectAlgo::kWarpDistributedShm;\n", " } else {\n", - " return Algo::kWarpImmediate;\n", + " return SelectAlgo::kWarpImmediate;\n", " }\n", " }\n", " } else {\n", - " return Algo::kWarpImmediate;\n", + " return SelectAlgo::kWarpImmediate;\n", " }\n", " }\n", "}\n" @@ -466,7 +466,7 @@ " if _is_leaf_node(nodeid):\n", " # we're a leaf node, just output the label of the most frequent algorithm\n", " class_name = _get_label(nodeid)\n", - " code.append(\" \" * indent + f\"return Algo::{class_name};\")\n", + " code.append(\" \" * indent + f\"return SelectAlgo::{class_name};\")\n", " else: \n", " feature = feature_names[tree.feature[nodeid]]\n", " threshold = int(np.floor(tree.threshold[nodeid]))\n", @@ -476,7 +476,7 @@ " _convert_node(tree.children_left[nodeid], indent + 2)\n", " code.append(\" \" * indent + \"}\")\n", " \n", - " code.append(\"inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)\")\n", + " code.append(\"inline SelectAlgo choose_select_k_algorithm(size_t rows, size_t cols, int k)\")\n", " code.append(\"{\")\n", " _convert_node(0, indent=2)\n", " code.append(\"}\")\n", diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index c75a5b5261..87e5d49d29 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(double, int64_t); diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 171c8a1ae7..67dce0e166 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(double, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index a21444dc0c..4be7c54839 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(float, int); diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 9542874ec0..6337994e86 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index fbf311d9bd..ad26547812 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(float, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index fdbfd66c46..e3c29a2033 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index 48a3e91f9d..3e3a738915 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(__half, uint32_t); diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index ce4e3e867e..f3eb32b2e1 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -93,56 +93,56 @@ auto inputs_random_many_infs = select::params{1000, 10000, 256, true, false, false, true, 0.999}); using ReferencedRandomFloatInt = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomFloatInt, Run) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomFloatInt, testing::Combine(inputs_random_longlist, - testing::Values(select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass, - select::Algo::kWarpImmediate, - select::Algo::kWarpFiltered, - select::Algo::kWarpDistributed, - select::Algo::kWarpDistributedShm))); + testing::Values(SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass, + SelectAlgo::kWarpImmediate, + SelectAlgo::kWarpFiltered, + SelectAlgo::kWarpDistributed, + SelectAlgo::kWarpDistributedShm))); using ReferencedRandomDoubleSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomDoubleSizeT, Run) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomDoubleSizeT, testing::Combine(inputs_random_longlist, - testing::Values(select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass, - select::Algo::kWarpImmediate, - select::Algo::kWarpFiltered, - select::Algo::kWarpDistributed, - select::Algo::kWarpDistributedShm))); + testing::Values(SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass, + SelectAlgo::kWarpImmediate, + SelectAlgo::kWarpFiltered, + SelectAlgo::kWarpDistributed, + SelectAlgo::kWarpDistributedShm))); using ReferencedRandomDoubleInt = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomDoubleInt, testing::Combine(inputs_random_largesize, - testing::Values(select::Algo::kWarpAuto, - select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass))); + testing::Values(SelectAlgo::kWarpAuto, + SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass))); using ReferencedRandomFloatIntkWarpsortAsGT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomFloatIntkWarpsortAsGT, Run) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomFloatIntkWarpsortAsGT, testing::Combine(inputs_random_many_infs, - testing::Values(select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass))); + testing::Values(SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass))); } // namespace raft::matrix diff --git a/cpp/test/matrix/select_k.cuh b/cpp/test/matrix/select_k.cuh index fdea982d6c..412a9ae5a2 100644 --- a/cpp/test/matrix/select_k.cuh +++ b/cpp/test/matrix/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,8 +49,8 @@ auto gen_simple_ids(uint32_t batch_size, uint32_t len) -> std::vector template struct io_simple { public: - bool not_supported = false; - std::optional algo = std::nullopt; + bool not_supported = false; + std::optional algo = std::nullopt; io_simple(const select::params& spec, const std::vector& in_dists, @@ -80,10 +80,10 @@ template struct io_computed { public: bool not_supported = false; - select::Algo algo; + SelectAlgo algo; io_computed(const select::params& spec, - const select::Algo& algo, + const SelectAlgo& algo, const std::vector& in_dists, const std::optional>& in_ids = std::nullopt) : algo(algo), @@ -94,11 +94,11 @@ struct io_computed { { // check if the size is supported by the algorithm switch (algo) { - case select::Algo::kWarpAuto: - case select::Algo::kWarpImmediate: - case select::Algo::kWarpFiltered: - case select::Algo::kWarpDistributed: - case select::Algo::kWarpDistributedShm: { + case SelectAlgo::kWarpAuto: + case SelectAlgo::kWarpImmediate: + case SelectAlgo::kWarpFiltered: + case SelectAlgo::kWarpDistributed: + case SelectAlgo::kWarpDistributedShm: { if (spec.k > raft::matrix::detail::select::warpsort::kMaxCapacity) { not_supported = true; return; @@ -118,16 +118,22 @@ struct io_computed { update_device(in_dists_d.data(), in_dists_.data(), in_dists_.size(), stream); update_device(in_ids_d.data(), in_ids_.data(), in_ids_.size(), stream); - select::select_k_impl(handle, - algo, - in_dists_d.data(), - spec.use_index_input ? in_ids_d.data() : nullptr, - spec.batch_size, - spec.len, - spec.k, - out_dists_d.data(), - out_ids_d.data(), - spec.select_min); + std::optional> in_ids_view; + if (spec.use_index_input) { + in_ids_view = raft::make_device_matrix_view( + in_ids_d.data(), spec.batch_size, spec.len); + } + + matrix::select_k( + handle, + raft::make_device_matrix_view( + in_dists_d.data(), spec.batch_size, spec.len), + in_ids_view, + raft::make_device_matrix_view(out_dists_d.data(), spec.batch_size, spec.k), + raft::make_device_matrix_view(out_ids_d.data(), spec.batch_size, spec.k), + spec.select_min, + false, + algo); update_host(out_dists_.data(), out_dists_d.data(), out_dists_.size(), stream); update_host(out_ids_.data(), out_ids_d.data(), out_ids_.size(), stream); @@ -194,13 +200,13 @@ struct io_computed { }; template -using Params = std::tuple; +using Params = std::tuple; template typename ParamsReader> struct SelectK // NOLINT : public testing::TestWithParam::params_t> { const select::params spec; - const select::Algo algo; + const SelectAlgo algo; typename ParamsReader::io_t ref; io_computed res; @@ -255,18 +261,18 @@ struct SelectK // NOLINT ASSERT_TRUE(hostVecMatch(ref.get_out_ids(), res.get_out_ids(), compare_ids)); } - auto forgive_algo(const std::optional& algo, IdxT ix) const -> bool + auto forgive_algo(const std::optional& algo, IdxT ix) const -> bool { if (!algo.has_value()) { return false; } switch (algo.value()) { // not sure which algo this is. - case select::Algo::kPublicApi: return true; + case SelectAlgo::kAuto: return true; // warp-sort-based algos currently return zero index for inf distances. - case select::Algo::kWarpAuto: - case select::Algo::kWarpImmediate: - case select::Algo::kWarpFiltered: - case select::Algo::kWarpDistributed: - case select::Algo::kWarpDistributedShm: return ix == 0; + case SelectAlgo::kWarpAuto: + case SelectAlgo::kWarpImmediate: + case SelectAlgo::kWarpFiltered: + case SelectAlgo::kWarpDistributed: + case SelectAlgo::kWarpDistributedShm: return ix == 0; // Do not forgive by default default: return false; } @@ -281,7 +287,7 @@ struct params_simple { std::optional>, std::vector, std::vector>; - using params_t = std::tuple; + using params_t = std::tuple; static auto read(params_t ps) -> Params { @@ -387,13 +393,13 @@ INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, SimpleFloatInt, testing::Combine(inputs_simple_f, - testing::Values(select::Algo::kPublicApi, - select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass, - select::Algo::kWarpImmediate, - select::Algo::kWarpFiltered, - select::Algo::kWarpDistributed))); + testing::Values(SelectAlgo::kAuto, + SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass, + SelectAlgo::kWarpImmediate, + SelectAlgo::kWarpFiltered, + SelectAlgo::kWarpDistributed))); template struct replace_with_mask { @@ -401,12 +407,12 @@ struct replace_with_mask { constexpr auto inline operator()(KeyT x, uint8_t mask) -> KeyT { return mask ? replacement : x; } }; -template +template struct with_ref { template struct params_random { using io_t = io_computed; - using params_t = std::tuple; + using params_t = std::tuple; static auto read(params_t ps) -> Params { diff --git a/cpp/test/matrix/select_large_k.cu b/cpp/test/matrix/select_large_k.cu index 2772e84eb3..baa07f5e87 100644 --- a/cpp/test/matrix/select_large_k.cu +++ b/cpp/test/matrix/select_large_k.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,12 +25,12 @@ auto inputs_random_largek = testing::Values(select::params{100, 100000, 1000, tr select::params{100, 100000, 1237, true}); using ReferencedRandomFloatSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, testing::Combine(inputs_random_largek, - testing::Values(select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass))); + testing::Values(SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass))); } // namespace raft::matrix From 93a504e00229c89c5b61814bdc24de09afe26534 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 11 Jan 2024 11:21:21 -0600 Subject: [PATCH 05/10] refactor CUDA versions in dependencies.yaml (#2086) Contributes to https://github.com/rapidsai/build-planning/issues/7. Proposes splitting the `cuda-version` dependency in `dependencies.yaml` out to its own thing, separate from the bits of the CUDA Toolkit this project needs. ### Benefits of this change * prevents accidental inclusion of multiple `cuda-version` version in environments * reduces update effort (via enabling more use of globs like `"12.*"`) * improves the chance that errors like "`conda` recipe is missing a dependency" are caught in CI Authors: - James Lamb (https://github.com/jameslamb) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - Bradley Dice (https://github.com/bdice) - Ray Douglass (https://github.com/raydouglass) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2086 --- .pre-commit-config.yaml | 2 +- dependencies.yaml | 48 +++++++++++++++++++++++++++++------------ 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80ad3614bc..c2e6d9fce4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -101,7 +101,7 @@ repos: args: ["--toml", "pyproject.toml"] exclude: (?x)^(^CHANGELOG.md$) - repo: https://github.com/rapidsai/dependency-file-generator - rev: v1.5.1 + rev: v1.8.0 hooks: - id: rapids-dependency-file-generator args: ["--clean"] diff --git a/dependencies.yaml b/dependencies.yaml index f049c75511..0e4d6d4693 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -8,7 +8,8 @@ files: includes: - build - build_pylibraft - - cudatoolkit + - cuda + - cuda_version - develop - checks - build_wheels @@ -26,19 +27,20 @@ files: arch: [x86_64, aarch64] includes: - build + - cuda + - cuda_version - develop - - cudatoolkit - nn_bench - nn_bench_python test_cpp: output: none includes: - - cudatoolkit + - cuda_version - test_libraft test_python: output: none includes: - - cudatoolkit + - cuda_version - py_version - test_python_common - test_pylibraft @@ -51,11 +53,11 @@ files: docs: output: none includes: - - test_pylibraft + - cuda_version - cupy - - cudatoolkit - docs - py_version + - test_pylibraft py_build_pylibraft: output: pyproject pyproject_dir: python/pylibraft @@ -155,8 +157,8 @@ dependencies: - sysroot_linux-aarch64==2.17 - output_types: conda matrices: - - matrix: {cuda: "12.0"} - packages: [cuda-version=12.0, cuda-nvcc] + - matrix: {cuda: "12.*"} + packages: [cuda-nvcc] - matrix: {cuda: "11.8", arch: x86_64} packages: [nvcc_linux-64=11.8] - matrix: {cuda: "11.8", arch: aarch64} @@ -239,15 +241,37 @@ dependencies: - pandas - pyyaml - pandas - - cudatoolkit: + cuda_version: specific: - output_types: conda matrices: + - matrix: + cuda: "11.2" + packages: + - cuda-version=11.2 + - matrix: + cuda: "11.4" + packages: + - cuda-version=11.4 + - matrix: + cuda: "11.5" + packages: + - cuda-version=11.5 + - matrix: + cuda: "11.8" + packages: + - cuda-version=11.8 - matrix: cuda: "12.0" packages: - cuda-version=12.0 + cuda: + specific: + - output_types: conda + matrices: + - matrix: + cuda: "12.*" + packages: - cuda-nvtx-dev - cuda-cudart-dev - cuda-profiler-api @@ -258,7 +282,6 @@ dependencies: - matrix: cuda: "11.8" packages: - - cuda-version=11.8 - cudatoolkit - cuda-nvtx=11.8 - cuda-profiler-api=11.8.86 @@ -273,7 +296,6 @@ dependencies: - matrix: cuda: "11.5" packages: - - cuda-version=11.5 - cudatoolkit - cuda-nvtx=11.5 - cuda-profiler-api>=11.4.240,<=11.8.86 # use any `11.x` version since pkg is missing several CUDA/arch packages @@ -288,7 +310,6 @@ dependencies: - matrix: cuda: "11.4" packages: - - cuda-version=11.4 - cudatoolkit - &cudanvtx114 cuda-nvtx=11.4 - cuda-profiler-api>=11.4.240,<=11.8.86 # use any `11.x` version since pkg is missing several CUDA/arch packages @@ -303,7 +324,6 @@ dependencies: - matrix: cuda: "11.2" packages: - - cuda-version=11.2 - cudatoolkit - *cudanvtx114 - cuda-profiler-api>=11.4.240,<=11.8.86 # use any `11.x` version since pkg is missing several CUDA/arch packages From 856288a2b4c4d9a74b5cbf4d0f5f2a64978072ba Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Thu, 11 Jan 2024 20:26:52 +0100 Subject: [PATCH 06/10] Add IVF-PQ example into the template project (#2091) A simple example with search and refinement. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2091 --- cpp/template/CMakeLists.txt | 5 +- cpp/template/src/common.cuh | 1 + cpp/template/src/ivf_pq_example.cu | 116 +++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 cpp/template/src/ivf_pq_example.cu diff --git a/cpp/template/CMakeLists.txt b/cpp/template/CMakeLists.txt index 538eac07ef..40a3795ed1 100644 --- a/cpp/template/CMakeLists.txt +++ b/cpp/template/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -39,3 +39,6 @@ target_link_libraries(CAGRA_EXAMPLE PRIVATE raft::raft raft::compiled) add_executable(IVF_FLAT_EXAMPLE src/ivf_flat_example.cu) target_link_libraries(IVF_FLAT_EXAMPLE PRIVATE raft::raft raft::compiled) + +add_executable(IVF_PQ_EXAMPLE src/ivf_pq_example.cu) +target_link_libraries(IVF_PQ_EXAMPLE PRIVATE raft::raft raft::compiled) diff --git a/cpp/template/src/common.cuh b/cpp/template/src/common.cuh index c2cb15bcf3..193abc747d 100644 --- a/cpp/template/src/common.cuh +++ b/cpp/template/src/common.cuh @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/template/src/ivf_pq_example.cu b/cpp/template/src/ivf_pq_example.cu new file mode 100644 index 0000000000..4bc0ba4348 --- /dev/null +++ b/cpp/template/src/ivf_pq_example.cu @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common.cuh" + +#include +#include +#include +#include + +#include +#include + +#include + +void ivf_pq_build_search(raft::device_resources const& dev_resources, + raft::device_matrix_view dataset, + raft::device_matrix_view queries) +{ + using namespace raft::neighbors; // NOLINT + + ivf_pq::index_params index_params; + index_params.n_lists = 1024; + index_params.kmeans_trainset_fraction = 0.1; + index_params.metric = raft::distance::DistanceType::L2Expanded; + index_params.pq_bits = 8; + index_params.pq_dim = 2; + + std::cout << "Building IVF-PQ index" << std::endl; + auto index = ivf_pq::build(dev_resources, index_params, dataset); + + std::cout << "Number of clusters " << index.n_lists() << ", number of vectors added to index " + << index.size() << std::endl; + + // Set search parameters. + ivf_pq::search_params search_params; + search_params.n_probes = 50; + // Set the internal search precision to 16-bit floats; + // usually, this improves the performance at a slight cost to the recall. + search_params.internal_distance_dtype = CUDA_R_16F; + search_params.lut_dtype = CUDA_R_16F; + + // Create output arrays. + int64_t topk = 10; + int64_t n_queries = queries.extent(0); + auto neighbors = raft::make_device_matrix(dev_resources, n_queries, topk); + auto distances = raft::make_device_matrix(dev_resources, n_queries, topk); + + // Search K nearest neighbors for each of the queries. + ivf_pq::search( + dev_resources, search_params, index, queries, neighbors.view(), distances.view()); + + // Re-ranking operation: refine the initial search results by computing exact distances + int64_t topk_refined = 7; + auto neighbors_refined = + raft::make_device_matrix(dev_resources, n_queries, topk_refined); + auto distances_refined = raft::make_device_matrix(dev_resources, n_queries, topk_refined); + + // Note, refinement requires the original dataset and the queries. + // Don't forget to specify the same distance metric as used by the index. + raft::neighbors::refine(dev_resources, + dataset, + queries, + raft::make_const_mdspan(neighbors.view()), + neighbors_refined.view(), + distances_refined.view(), + index.metric()); + + // Show both the original and the refined results + std::cout << std::endl << "Original results:" << std::endl; + print_results(dev_resources, neighbors.view(), distances.view()); + std::cout << std::endl << "Refined results:" << std::endl; + print_results(dev_resources, neighbors_refined.view(), distances_refined.view()); +} + +int main() +{ + raft::device_resources dev_resources; + + // Set pool memory resource with 1 GiB initial pool size. All allocations use the same pool. + rmm::mr::pool_memory_resource pool_mr( + rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull); + rmm::mr::set_current_device_resource(&pool_mr); + + // Alternatively, one could define a pool allocator for temporary arrays (used within RAFT + // algorithms). In that case only the internal arrays would use the pool, any other allocation + // uses the default RMM memory resource. Here is how to change the workspace memory resource to + // a pool with 2 GiB upper limit. + // raft::resource::set_workspace_to_pool_resource(dev_resources, 2 * 1024 * 1024 * 1024ull); + + // Create input arrays. + int64_t n_samples = 10000; + int64_t n_dim = 3; + int64_t n_queries = 10; + auto dataset = raft::make_device_matrix(dev_resources, n_samples, n_dim); + auto queries = raft::make_device_matrix(dev_resources, n_queries, n_dim); + generate_dataset(dev_resources, dataset.view(), queries.view()); + + // Simple build and search example. + ivf_pq_build_search(dev_resources, + raft::make_const_mdspan(dataset.view()), + raft::make_const_mdspan(queries.view())); +} From 7d5bb3c90d2f3338444f68bac24336bdbb9cc465 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 12 Jan 2024 12:35:10 -0500 Subject: [PATCH 07/10] Properly taking ownership of nccl subcomm (and destroying it) (#2094) Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) URL: https://github.com/rapidsai/raft/pull/2094 --- cpp/include/raft/comms/detail/std_comms.hpp | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index de2a7d3415..323e408cab 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,6 +81,7 @@ class std_comms : public comms_iface { num_ranks_(num_ranks), rank_(rank), subcomms_ucp_(subcomms_ucp), + own_nccl_comm_(false), ucp_worker_(ucp_worker), ucp_eps_(eps), next_request_id_(0) @@ -95,13 +96,18 @@ class std_comms : public comms_iface { * @param rank rank of the current worker * @param stream stream for ordering collective operations */ - std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, rmm::cuda_stream_view stream) + std_comms(const ncclComm_t nccl_comm, + int num_ranks, + int rank, + rmm::cuda_stream_view stream, + bool own_nccl_comm = false) : nccl_comm_(nccl_comm), stream_(stream), status_(stream), num_ranks_(num_ranks), rank_(rank), - subcomms_ucp_(false) + subcomms_ucp_(false), + own_nccl_comm_(own_nccl_comm) { initialize(); }; @@ -116,6 +122,11 @@ class std_comms : public comms_iface { { requests_in_flight_.clear(); free_requests_.clear(); + + if (own_nccl_comm_) { + RAFT_NCCL_TRY_NO_THROW(ncclCommDestroy(nccl_comm_)); + nccl_comm_ = nullptr; + } } int get_size() const { return num_ranks_; } @@ -172,7 +183,7 @@ class std_comms : public comms_iface { RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_size, id, key)); - return std::unique_ptr(new std_comms(nccl_comm, subcomm_size, key, stream_)); + return std::unique_ptr(new std_comms(nccl_comm, subcomm_size, key, stream_, true)); } void barrier() const @@ -515,6 +526,7 @@ class std_comms : public comms_iface { int rank_; bool subcomms_ucp_; + bool own_nccl_comm_; comms_ucp_handler ucp_handler_; ucp_worker_h ucp_worker_; From 0c75bb1985c1fe9177cf18e9c42d071d7115af7b Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 12 Jan 2024 12:54:18 -0500 Subject: [PATCH 08/10] Remove usages of rapids-env-update (#2095) Reference: https://github.com/rapidsai/ops/issues/2766 Replace rapids-env-update with rapids-configure-conda-channels, rapids-configure-sccache, and rapids-date-string. Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) URL: https://github.com/rapidsai/raft/pull/2095 --- ci/build_cpp.sh | 8 ++++++-- ci/build_python.sh | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ci/build_cpp.sh b/ci/build_cpp.sh index 178ce723a5..2778c2a7d7 100755 --- a/ci/build_cpp.sh +++ b/ci/build_cpp.sh @@ -1,9 +1,13 @@ #!/bin/bash -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. set -euo pipefail -source rapids-env-update +rapids-configure-conda-channels + +source rapids-configure-sccache + +source rapids-date-string export CMAKE_GENERATOR=Ninja diff --git a/ci/build_python.sh b/ci/build_python.sh index 3e67edd5db..a8b76269ae 100755 --- a/ci/build_python.sh +++ b/ci/build_python.sh @@ -1,9 +1,13 @@ #!/bin/bash -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. set -euo pipefail -source rapids-env-update +rapids-configure-conda-channels + +source rapids-configure-sccache + +source rapids-date-string export CMAKE_GENERATOR=Ninja From 1d9adab59d6eb273b5244b232813d8f7c86d74a9 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 16 Jan 2024 20:10:18 +0100 Subject: [PATCH 09/10] Add AIR-Top-k reference (#2031) Add reference to AIR top-k paper. Authors: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2031 --- README.md | 35 +++++++++++++++++++ .../raft/matrix/detail/select_radix.cuh | 8 ++++- .../raft/neighbors/nn_descent_types.hpp | 6 ++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9ab1168bdb..26ddc30ed4 100755 --- a/README.md +++ b/README.md @@ -354,3 +354,38 @@ If citing CAGRA, please consider the following bibtex: primaryClass={cs.DS} } ``` + +If citing the k-selection routines, please consider the following bibtex: + +```bibtex +@proceedings{10.1145/3581784, + title = {SC '23: Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis}, + year = {2023}, + isbn = {9798400701092}, + publisher = {Association for Computing Machinery}, + address = {New York, NY, USA}, + abstract = {Started in 1988, the SC Conference has become the annual nexus for researchers and practitioners from academia, industry and government to share information and foster collaborations to advance the state of the art in High Performance Computing (HPC), Networking, Storage, and Analysis.}, + location = {, Denver, CO, USA, } +} +``` + +If citing the nearest neighbors descent API, please consider the following bibtex: +```bibtex +@inproceedings{10.1145/3459637.3482344, + author = {Wang, Hui and Zhao, Wan-Lei and Zeng, Xiangxiang and Yang, Jianye}, + title = {Fast K-NN Graph Construction by GPU Based NN-Descent}, + year = {2021}, + isbn = {9781450384469}, + publisher = {Association for Computing Machinery}, + address = {New York, NY, USA}, + url = {https://doi.org/10.1145/3459637.3482344}, + doi = {10.1145/3459637.3482344}, + abstract = {NN-Descent is a classic k-NN graph construction approach. It is still widely employed in machine learning, computer vision, and information retrieval tasks due to its efficiency and genericness. However, the current design only works well on CPU. In this paper, NN-Descent has been redesigned to adapt to the GPU architecture. A new graph update strategy called selective update is proposed. It reduces the data exchange between GPU cores and GPU global memory significantly, which is the processing bottleneck under GPU computation architecture. This redesign leads to full exploitation of the parallelism of the GPU hardware. In the meantime, the genericness, as well as the simplicity of NN-Descent, are well-preserved. Moreover, a procedure that allows to k-NN graph to be merged efficiently on GPU is proposed. It makes the construction of high-quality k-NN graphs for out-of-GPU-memory datasets tractable. Our approach is 100-250\texttimes{} faster than the single-thread NN-Descent and is 2.5-5\texttimes{} faster than the existing GPU-based approaches as we tested on million as well as billion scale datasets.}, + booktitle = {Proceedings of the 30th ACM International Conference on Information \& Knowledge Management}, + pages = {1929–1938}, + numpages = {10}, + keywords = {high-dimensional, nn-descent, gpu, k-nearest neighbor graph}, + location = {Virtual Event, Queensland, Australia}, + series = {CIKM '21} +} +``` \ No newline at end of file diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 4245be42d6..b6ed03b93d 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1141,6 +1141,12 @@ void radix_topk_one_block(const T* in, * * Note, the output is NOT sorted within the groups of `k` selected elements. * + * Reference: + * Jingrong Zhang, Akira Naruse, Xipeng Li, and Yong Wang. 2023. Parallel Top-K Algorithms on GPU: + * A Comprehensive Study and New Methods. In The International Conference for High Performance + * Computing, Networking, Storage and Analysis (SC ’23), November 12–17, 2023, Denver, CO, USA. + * ACM, New York, NY, USA. https://doi.org/10.1145/3581784.3607062 + * * @tparam T * the type of the keys (what is being compared). * @tparam IdxT diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index 7d4f3d615b..fd1df2965e 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -58,6 +58,12 @@ struct index_params : ann::index_params { * The index contains an all-neighbors graph of the input dataset * stored in host memory of dimensions (n_rows, n_cols) * + * Reference: + * Hui Wang, Wan-Lei Zhao, Xiangxiang Zeng, and Jianye Yang. 2021. + * Fast k-NN Graph Construction by GPU based NN-Descent. In Proceedings of the 30th ACM + * International Conference on Information & Knowledge Management (CIKM '21). Association for + * Computing Machinery, New York, NY, USA, 1929–1938. https://doi.org/10.1145/3459637.3482344 + * * @tparam IdxT dtype to be used for constructing knn-graph */ template From 3c7586f813973c5489df70f25c2e221343b65853 Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 16 Jan 2024 13:00:42 -0800 Subject: [PATCH 10/10] [FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2067) (#2067) - Add support for SDDMM by wrapping the `cusparseSDDMM` - This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file. Authors: - rhdong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2067 --- cpp/bench/prims/CMakeLists.txt | 3 +- cpp/bench/prims/linalg/sddmm.cu | 275 +++++++++++++ .../distance/detail/kernels/gram_matrix.cuh | 4 +- cpp/include/raft/linalg/linalg_types.hpp | 9 +- .../raft/sparse/detail/cusparse_wrappers.h | 114 +++++- .../sparse/linalg/detail/cusparse_utils.hpp | 103 +++++ .../raft/sparse/linalg/detail/sddmm.hpp | 99 +++++ .../raft/sparse/linalg/detail/spmm.hpp | 54 +-- cpp/include/raft/sparse/linalg/sddmm.hpp | 83 ++++ cpp/include/raft/sparse/linalg/spmm.cuh | 66 +--- cpp/include/raft/sparse/linalg/spmm.hpp | 79 ++++ cpp/test/CMakeLists.txt | 1 + cpp/test/sparse/sddmm.cu | 365 ++++++++++++++++++ 13 files changed, 1136 insertions(+), 119 deletions(-) create mode 100644 cpp/bench/prims/linalg/sddmm.cu create mode 100644 cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp create mode 100644 cpp/include/raft/sparse/linalg/detail/sddmm.hpp create mode 100644 cpp/include/raft/sparse/linalg/sddmm.hpp create mode 100644 cpp/include/raft/sparse/linalg/spmm.hpp create mode 100644 cpp/test/sparse/sddmm.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index fe58453d0d..3a2431cd34 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -117,6 +117,7 @@ if(BUILD_PRIMS_BENCH) bench/prims/linalg/reduce_cols_by_key.cu bench/prims/linalg/reduce_rows_by_key.cu bench/prims/linalg/reduce.cu + bench/prims/linalg/sddmm.cu bench/prims/main.cpp ) diff --git a/cpp/bench/prims/linalg/sddmm.cu b/cpp/bench/prims/linalg/sddmm.cu new file mode 100644 index 0000000000..139a2b838d --- /dev/null +++ b/cpp/bench/prims/linalg/sddmm.cu @@ -0,0 +1,275 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace raft::bench::linalg { + +template +struct SDDMMBenchParams { + size_t m; + size_t k; + size_t n; + float sparsity; + bool transpose_a; + bool transpose_b; + ValueType alpha = 1.0; + ValueType beta = 0.0; +}; + +enum Alg { SDDMM, Inner }; + +template +inline auto operator<<(std::ostream& os, const SDDMMBenchParams& params) -> std::ostream& +{ + os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n + << "\tsparsity=" << params.sparsity << "\ttrans_a=" << (params.transpose_a ? "T" : "F") + << " trans_b=" << (params.transpose_b ? "T" : "F"); + return os; +} + +template +struct SDDMMBench : public fixture { + SDDMMBench(const SDDMMBenchParams& p) + : fixture(true), + params(p), + handle(stream), + a_data_d(0, stream), + b_data_d(0, stream), + c_indptr_d(0, stream), + c_indices_d(0, stream), + c_data_d(0, stream), + c_dense_data_d(0, stream) + { + a_data_d.resize(params.m * params.k, stream); + b_data_d.resize(params.k * params.n, stream); + + raft::random::RngState rng(2024ULL); + raft::random::uniform( + handle, rng, a_data_d.data(), params.m * params.k, ValueType(-1.0), ValueType(1.0)); + raft::random::uniform( + handle, rng, b_data_d.data(), params.k * params.n, ValueType(-1.0), ValueType(1.0)); + + std::vector c_dense_data_h(params.m * params.n); + + c_true_nnz = create_sparse_matrix(c_dense_data_h); + std::vector values(c_true_nnz); + std::vector indices(c_true_nnz); + std::vector indptr(params.m + 1); + + c_data_d.resize(c_true_nnz, stream); + c_indptr_d.resize(params.m + 1, stream); + c_indices_d.resize(c_true_nnz, stream); + + if (SDDMMorInner == Alg::Inner) { c_dense_data_d.resize(params.m * params.n, stream); } + + convert_to_csr(c_dense_data_h, params.m, params.n, values, indices, indptr); + RAFT_EXPECTS(c_true_nnz == c_indices_d.size(), + "Something wrong. The c_true_nnz != c_indices_d.size()!"); + + update_device(c_data_d.data(), values.data(), c_true_nnz, stream); + update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream); + update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream); + } + + void convert_to_csr(std::vector& matrix, + IndexType rows, + IndexType cols, + std::vector& values, + std::vector& indices, + std::vector& indptr) + { + IndexType offset_indptr = 0; + IndexType offset_values = 0; + indptr[offset_indptr++] = 0; + + for (IndexType i = 0; i < rows; ++i) { + for (IndexType j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + values[offset_values] = static_cast(1.0); + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + size_t create_sparse_matrix(std::vector& matrix) + { + size_t total_elements = static_cast(params.m * params.n); + size_t num_ones = static_cast((total_elements * 1.0f) * params.sparsity); + size_t res = num_ones; + + for (size_t i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis(gen); + + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + + ~SDDMMBench() {} + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto a = raft::make_device_matrix_view( + a_data_d.data(), + (!params.transpose_a ? params.m : params.k), + (!params.transpose_a ? params.k : params.m)); + + auto b = raft::make_device_matrix_view( + b_data_d.data(), + (!params.transpose_b ? params.k : params.n), + (!params.transpose_b ? params.n : params.k)); + + auto c_structure = raft::make_device_compressed_structure_view( + c_indptr_d.data(), + c_indices_d.data(), + params.m, + params.n, + static_cast(c_indices_d.size())); + + auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); + raft::resource::get_cusparse_handle(handle); + + resource::sync_stream(handle); + + auto op_a = params.transpose_a ? raft::linalg::Operation::TRANSPOSE + : raft::linalg::Operation::NON_TRANSPOSE; + auto op_b = params.transpose_b ? raft::linalg::Operation::TRANSPOSE + : raft::linalg::Operation::NON_TRANSPOSE; + + raft::sparse::linalg::sddmm(handle, + a, + b, + c, + op_a, + op_b, + raft::make_host_scalar_view(¶ms.alpha), + raft::make_host_scalar_view(¶ms.beta)); + resource::sync_stream(handle); + + loop_on_state(state, [this, &a, &b, &c, &op_a, &op_b]() { + if (SDDMMorInner == Alg::SDDMM) { + raft::sparse::linalg::sddmm(handle, + a, + b, + c, + op_a, + op_b, + raft::make_host_scalar_view(¶ms.alpha), + raft::make_host_scalar_view(¶ms.beta)); + resource::sync_stream(handle); + } else { + raft::distance::pairwise_distance(handle, + a_data_d.data(), + b_data_d.data(), + c_dense_data_d.data(), + static_cast(params.m), + static_cast(params.n), + static_cast(params.k), + raft::distance::DistanceType::InnerProduct, + std::is_same_v); + resource::sync_stream(handle); + } + }); + } + + private: + const raft::device_resources handle; + SDDMMBenchParams params; + + rmm::device_uvector a_data_d; + rmm::device_uvector b_data_d; + rmm::device_uvector c_dense_data_d; + + size_t c_true_nnz = 0; + rmm::device_uvector c_indptr_d; + rmm::device_uvector c_indices_d; + rmm::device_uvector c_data_d; +}; + +template +static std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + bool transpose_a; + bool transpose_b; + size_t m; + size_t k; + size_t n; + float sparsity; + }; + + const std::vector params_group = + raft::util::itertools::product({false, true}, + {false, true}, + {size_t(10), size_t(1024)}, + {size_t(128), size_t(1024)}, + {size_t(1024 * 1024)}, + {0.01f, 0.1f, 0.2f, 0.5f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(SDDMMBenchParams( + {params.m, params.k, params.n, params.sparsity, params.transpose_a, params.transpose_b})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); + +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); + +} // namespace raft::bench::linalg diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index e121c1be9c..14b4ba12c6 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ #include // #include #include -#include +#include #include #include diff --git a/cpp/include/raft/linalg/linalg_types.hpp b/cpp/include/raft/linalg/linalg_types.hpp index e50d3a8e79..9c81fbc177 100644 --- a/cpp/include/raft/linalg/linalg_types.hpp +++ b/cpp/include/raft/linalg/linalg_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,4 +32,11 @@ enum class Apply { ALONG_ROWS, ALONG_COLUMNS }; */ enum class FillMode { UPPER, LOWER }; +/** + * @brief Enum for this type indicates which operation is applied to the related input (e.g. sparse + * matrix, or vector). + * + */ +enum class Operation { NON_TRANSPOSE, TRANSPOSE }; + } // end namespace raft::linalg \ No newline at end of file diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index e8bf9c6de5..cc3ae3ab87 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -571,6 +571,118 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, alg, static_cast(externalBuffer)); } + +template +cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const T* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const T* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const float* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const float* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM_bufferSize( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, bufferSize); +} +template <> +inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const double* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const double* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM_bufferSize( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, bufferSize); +} +template +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const T* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const T* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + T* externalBuffer, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const float* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const float* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + float* externalBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_32F, + alg, + static_cast(externalBuffer)); +} +template <> +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const double* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const double* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + double* externalBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_64F, + alg, + static_cast(externalBuffer)); +} + /** @} */ #else /** diff --git a/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp new file mode 100644 index 0000000000..b15614905b --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +/** + * @brief create a cuSparse dense descriptor + * @tparam ValueType Data type of dense_view (float/double) + * @tparam IndexType Type of dense_view + * @tparam LayoutPolicy layout of dense_view + * @param[in] dense_view input raft::device_matrix_view + * @returns dense matrix descriptor to be used by cuSparse API + */ +template +cusparseDnMatDescr_t create_descriptor( + raft::device_matrix_view dense_view) +{ + bool is_row_major = raft::is_row_major(dense_view); + auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; + IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1); + cusparseDnMatDescr_t descr; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &descr, + dense_view.extent(0), + dense_view.extent(1), + ld, + const_cast*>(dense_view.data_handle()), + order)); + return descr; +} + +/** + * @brief create a cuSparse sparse descriptor + * @tparam ValueType Data type of sparse_view (float/double) + * @tparam IndptrType Data type of csr_matrix_view index pointers + * @tparam IndicesType Data type of csr_matrix_view indices + * @tparam NZType Type of sparse_view + * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns + * @returns sparse matrix descriptor to be used by cuSparse API + */ +template +cusparseSpMatDescr_t create_descriptor( + raft::device_csr_matrix_view sparse_view) +{ + cusparseSpMatDescr_t descr; + auto csr_structure = sparse_view.structure_view(); + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( + &descr, + static_cast(csr_structure.get_n_rows()), + static_cast(csr_structure.get_n_cols()), + static_cast(csr_structure.get_nnz()), + const_cast(csr_structure.get_indptr().data()), + const_cast(csr_structure.get_indices().data()), + const_cast*>(sparse_view.get_elements().data()))); + return descr; +} + +/** + * @brief convert the operation to cusparseOperation_t type + * @param param[in] op type of operation + */ +inline cusparseOperation_t convert_operation(const raft::linalg::Operation op) +{ + if (op == raft::linalg::Operation::TRANSPOSE) { + return CUSPARSE_OPERATION_TRANSPOSE; + } else if (op == raft::linalg::Operation::NON_TRANSPOSE) { + return CUSPARSE_OPERATION_NON_TRANSPOSE; + } else { + RAFT_EXPECTS(false, "The operation type is not allowed."); + } + return CUSPARSE_OPERATION_NON_TRANSPOSE; +} + +} // end namespace detail +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/detail/sddmm.hpp b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp new file mode 100644 index 0000000000..5088a20f46 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +/** + * @brief This function performs the multiplication of dense matrix A and dense matrix B, + * followed by an element-wise multiplication with the sparsity pattern of C. + * It computes the following equation: C = alpha · (op_a(A) * op_b(B) ∘ spy(C)) + beta · C + * where A,B are device matrix views and C is a CSR device matrix view + * + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of C + * @tparam LayoutPolicyA layout of A + * @tparam LayoutPolicyB layout of B + * @tparam NZType Type of C + * + * @param[in] handle raft resource handle + * @param[in] descr_a input dense descriptor + * @param[in] descr_b input dense descriptor + * @param[in/out] descr_c output sparse descriptor + * @param[in] op_a input Operation op(A) + * @param[in] op_b input Operation op(B) + * @param[in] alpha scalar pointer + * @param[in] beta scalar pointer + */ +template +void sddmm(raft::resources const& handle, + cusparseDnMatDescr_t& descr_a, + cusparseDnMatDescr_t& descr_b, + cusparseSpMatDescr_t& descr_c, + cusparseOperation_t op_a, + cusparseOperation_t op_b, + const ValueType* alpha, + const ValueType* beta) +{ + auto alg = CUSPARSE_SDDMM_ALG_DEFAULT; + size_t bufferSize; + + RAFT_CUSPARSE_TRY( + raft::sparse::detail::cusparsesddmm_bufferSize(resource::get_cusparse_handle(handle), + op_a, + op_b, + alpha, + descr_a, + descr_b, + beta, + descr_c, + alg, + &bufferSize, + resource::get_cuda_stream(handle))); + + resource::sync_stream(handle); + + rmm::device_uvector tmp(bufferSize, resource::get_cuda_stream(handle)); + + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle), + op_a, + op_b, + alpha, + descr_a, + descr_b, + beta, + descr_c, + alg, + tmp.data(), + resource::get_cuda_stream(handle))); +} + +} // end namespace detail +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/detail/spmm.hpp b/cpp/include/raft/sparse/linalg/detail/spmm.hpp index d8d73ee83f..6206348b02 100644 --- a/cpp/include/raft/sparse/linalg/detail/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/detail/spmm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,58 +48,6 @@ bool is_row_major(raft::device_matrix_view -cusparseDnMatDescr_t create_descriptor( - raft::device_matrix_view& dense_view, const bool is_row_major) -{ - auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; - IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1); - cusparseDnMatDescr_t descr; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( - &descr, - dense_view.extent(0), - dense_view.extent(1), - ld, - const_cast*>(dense_view.data_handle()), - order)); - return descr; -} - -/** - * @brief create a cuSparse sparse descriptor - * @tparam ValueType Data type of sparse_view (float/double) - * @tparam IndptrType Data type of csr_matrix_view index pointers - * @tparam IndicesType Data type of csr_matrix_view indices - * @tparam NZType Type of sparse_view - * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns - * @returns sparse matrix descriptor to be used by cuSparse API - */ -template -cusparseSpMatDescr_t create_descriptor( - raft::device_csr_matrix_view& sparse_view) -{ - cusparseSpMatDescr_t descr; - auto csr_structure = sparse_view.structure_view(); - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( - &descr, - static_cast(csr_structure.get_n_rows()), - static_cast(csr_structure.get_n_cols()), - static_cast(csr_structure.get_nnz()), - const_cast(csr_structure.get_indptr().data()), - const_cast(csr_structure.get_indices().data()), - const_cast*>(sparse_view.get_elements().data()))); - return descr; -} - /** * @brief SPMM function designed for handling all CSR * DENSE * combinations of operand layouts for cuSparse. diff --git a/cpp/include/raft/sparse/linalg/sddmm.hpp b/cpp/include/raft/sparse/linalg/sddmm.hpp new file mode 100644 index 0000000000..c19f1d9081 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/sddmm.hpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @brief This function performs the multiplication of dense matrix A and dense matrix B, + * followed by an element-wise multiplication with the sparsity pattern of C. + * It computes the following equation: C = alpha · (opA(A) * opB(B) ∘ spy(C)) + beta · C + * where A,B are device matrix views and C is a CSR device matrix view + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of C + * @tparam NZType Type of C + * @tparam LayoutPolicyA layout of A + * @tparam LayoutPolicyB layout of B + * @param[in] handle raft handle + * @param[in] A input raft::device_matrix_view + * @param[in] B input raft::device_matrix_view + * @param[inout] C output raft::device_csr_matrix_view + * @param[in] opA input Operation op(A) + * @param[in] opB input Operation op(B) + * @param[in] alpha input raft::host_scalar_view + * @param[in] beta input raft::host_scalar_view + */ +template +void sddmm(raft::resources const& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + raft::device_csr_matrix_view C, + const raft::linalg::Operation opA, + const raft::linalg::Operation opB, + raft::host_scalar_view alpha, + raft::host_scalar_view beta) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(A), "A is not contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(B), "B is not contiguous"); + + static_assert(std::is_same_v || std::is_same_v, + "The `ValueType` of sddmm only supports float/double."); + + auto descrA = detail::create_descriptor(A); + auto descrB = detail::create_descriptor(B); + auto descrC = detail::create_descriptor(C); + auto op_A = detail::convert_operation(opA); + auto op_B = detail::convert_operation(opB); + + detail::sddmm( + handle, descrA, descrB, descrC, op_A, op_B, alpha.data_handle(), beta.data_handle()); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descrA)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descrB)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descrC)); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/spmm.cuh b/cpp/include/raft/sparse/linalg/spmm.cuh index 064da4d8fb..439ed8c341 100644 --- a/cpp/include/raft/sparse/linalg/spmm.cuh +++ b/cpp/include/raft/sparse/linalg/spmm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,66 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __SPMM_H -#define __SPMM_H - #pragma once -#include "detail/spmm.hpp" - -namespace raft { -namespace sparse { -namespace linalg { - -/** - * @brief SPMM function designed for handling all CSR * DENSE - * combinations of operand layouts for cuSparse. - * It computes the following equation: Z = alpha . X * Y + beta . Z - * where X is a CSR device matrix view and Y,Z are device matrix views - * @tparam ValueType Data type of input/output matrices (float/double) - * @tparam IndexType Type of Y and Z - * @tparam NZType Type of X - * @tparam LayoutPolicyY layout of Y - * @tparam LayoutPolicyZ layout of Z - * @param[in] handle raft handle - * @param[in] trans_x transpose operation for X - * @param[in] trans_y transpose operation for Y - * @param[in] alpha scalar - * @param[in] x input raft::device_csr_matrix_view - * @param[in] y input raft::device_matrix_view - * @param[in] beta scalar - * @param[out] z output raft::device_matrix_view - */ -template -void spmm(raft::resources const& handle, - const bool trans_x, - const bool trans_y, - const ValueType* alpha, - raft::device_csr_matrix_view x, - raft::device_matrix_view y, - const ValueType* beta, - raft::device_matrix_view z) -{ - bool is_row_major = detail::is_row_major(y, z); - - auto descr_x = detail::create_descriptor(x); - auto descr_y = detail::create_descriptor(y, is_row_major); - auto descr_z = detail::create_descriptor(z, is_row_major); - - detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); - - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z)); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -} // end namespace linalg -} // end namespace sparse -} // end namespace raft +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the spmm.hpp at the same path instead.") -#endif +#include diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp new file mode 100644 index 0000000000..c2fdd64574 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __SPMM_H +#define __SPMM_H + +#pragma once + +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @brief SPMM function designed for handling all CSR * DENSE + * combinations of operand layouts for cuSparse. + * It computes the following equation: Z = alpha . X * Y + beta . Z + * where X is a CSR device matrix view and Y,Z are device matrix views + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of Y and Z + * @tparam NZType Type of X + * @tparam LayoutPolicyY layout of Y + * @tparam LayoutPolicyZ layout of Z + * @param[in] handle raft handle + * @param[in] trans_x transpose operation for X + * @param[in] trans_y transpose operation for Y + * @param[in] alpha scalar + * @param[in] x input raft::device_csr_matrix_view + * @param[in] y input raft::device_matrix_view + * @param[in] beta scalar + * @param[out] z output raft::device_matrix_view + */ +template +void spmm(raft::resources const& handle, + const bool trans_x, + const bool trans_y, + const ValueType* alpha, + raft::device_csr_matrix_view x, + raft::device_matrix_view y, + const ValueType* beta, + raft::device_matrix_view z) +{ + bool is_row_major = detail::is_row_major(y, z); + + auto descr_x = detail::create_descriptor(x); + auto descr_y = detail::create_descriptor(y); + auto descr_z = detail::create_descriptor(z); + + detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z)); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft + +#endif diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6e32281ec0..931530b66a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -315,6 +315,7 @@ if(BUILD_TESTS) test/sparse/normalize.cu test/sparse/reduce.cu test/sparse/row_op.cu + test/sparse/sddmm.cu test/sparse/sort.cu test/sparse/spgemmi.cu test/sparse/symmetrize.cu diff --git a/cpp/test/sparse/sddmm.cu b/cpp/test/sparse/sddmm.cu new file mode 100644 index 0000000000..9323ee8c2b --- /dev/null +++ b/cpp/test/sparse/sddmm.cu @@ -0,0 +1,365 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../test_utils.cuh" + +namespace raft { +namespace sparse { + +template +struct SDDMMInputs { + ValueType tolerance; + + IndexType m; + IndexType k; + IndexType n; + + ValueType alpha; + ValueType beta; + + bool transpose_a; + bool transpose_b; + + ValueType sparsity; + + unsigned long long int seed; +}; + +template +struct sum_abs_op { + __host__ __device__ ValueType operator()(const ValueType& x, const ValueType& y) const + { + return y >= ValueType(0.0) ? (x + y) : (x - y); + } +}; + +template +::std::ostream& operator<<(::std::ostream& os, const SDDMMInputs& params) +{ + os << " m: " << params.m << "\tk: " << params.k << "\tn: " << params.n + << "\talpha: " << params.alpha << "\tbeta: " << params.beta + << "\tsparsity: " << params.sparsity; + + return os; +} + +template +class SDDMMTest : public ::testing::TestWithParam> { + public: + SDDMMTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + a_data_d(0, resource::get_cuda_stream(handle)), + b_data_d(0, resource::get_cuda_stream(handle)), + c_indptr_d(0, resource::get_cuda_stream(handle)), + c_indices_d(0, resource::get_cuda_stream(handle)), + c_data_d(0, resource::get_cuda_stream(handle)), + c_expected_data_d(0, resource::get_cuda_stream(handle)) + { + } + + protected: + IndexType create_sparse_matrix(IndexType m, + IndexType n, + ValueType sparsity, + std::vector& matrix) + { + IndexType total_elements = static_cast(m * n); + IndexType num_ones = static_cast((total_elements * 1.0f) * sparsity); + IndexType res = num_ones; + + for (IndexType i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis(gen); + + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + + void convert_to_csr(std::vector& matrix, + IndexType rows, + IndexType cols, + std::vector& values, + std::vector& indices, + std::vector& indptr) + { + IndexType offset_indptr = 0; + IndexType offset_values = 0; + indptr[offset_indptr++] = 0; + + for (IndexType i = 0; i < rows; ++i) { + for (IndexType j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + values[offset_values] = static_cast(1.0); + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + void cpu_sddmm(const std::vector& A, + const std::vector& B, + std::vector& vals, + const std::vector& cols, + const std::vector& row_ptrs, + bool is_row_major_A, + bool is_row_major_B) + { + if (params.m * params.k != static_cast(A.size()) || + params.k * params.n != static_cast(B.size())) { + std::cerr << "Matrix dimensions and vector size do not match!" << std::endl; + return; + } + + bool trans_a = params.transpose_a ? !is_row_major_A : is_row_major_A; + bool trans_b = params.transpose_b ? !is_row_major_B : is_row_major_B; + + for (IndexType i = 0; i < params.m; ++i) { + for (IndexType j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j) { + ValueType sum = 0; + for (IndexType l = 0; l < params.k; ++l) { + IndexType a_index = trans_a ? i * params.k + l : l * params.m + i; + IndexType b_index = trans_b ? l * params.n + cols[j] : cols[j] * params.k + l; + sum += A[a_index] * B[b_index]; + } + vals[j] = params.alpha * sum + params.beta * vals[j]; + } + } + } + + void make_data() + { + IndexType a_size = params.m * params.k; + IndexType b_size = params.k * params.n; + IndexType c_size = params.m * params.n; + + std::vector a_data_h(a_size); + std::vector b_data_h(b_size); + + a_data_d.resize(a_size, stream); + b_data_d.resize(b_size, stream); + + auto blobs_a_b = raft::make_device_matrix(handle, 1, a_size + b_size); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_a_b.data_handle(), + labels.data_handle(), + 1, + a_size + b_size, + 1, + stream, + false, + nullptr, + nullptr, + ValueType(1.0), + false, + ValueType(-1.0f), + ValueType(1.0f), + uint64_t(2024)); + + raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_h.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + + raft::copy(a_data_d.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_d.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + + resource::sync_stream(handle); + + std::vector c_dense_data_h(c_size); + IndexType c_true_nnz = + create_sparse_matrix(params.m, params.n, params.sparsity, c_dense_data_h); + + std::vector c_indptr_h(params.m + 1); + std::vector c_indices_h(c_true_nnz); + std::vector c_data_h(c_true_nnz); + + convert_to_csr(c_dense_data_h, params.m, params.n, c_data_h, c_indices_h, c_indptr_h); + + bool is_row_major_A = (std::is_same_v); + bool is_row_major_B = (std::is_same_v); + + c_data_d.resize(c_data_h.size(), stream); + update_device(c_data_d.data(), c_data_h.data(), c_data_h.size(), stream); + resource::sync_stream(handle); + + cpu_sddmm( + a_data_h, b_data_h, c_data_h, c_indices_h, c_indptr_h, is_row_major_A, is_row_major_B); + + c_indptr_d.resize(c_indptr_h.size(), stream); + c_indices_d.resize(c_indices_h.size(), stream); + c_expected_data_d.resize(c_data_h.size(), stream); + + update_device(c_indptr_d.data(), c_indptr_h.data(), c_indptr_h.size(), stream); + update_device(c_indices_d.data(), c_indices_h.data(), c_indices_h.size(), stream); + update_device(c_expected_data_d.data(), c_data_h.data(), c_data_h.size(), stream); + + resource::sync_stream(handle); + } + + void SetUp() override { make_data(); } + + void Run() + { + auto a = raft::make_device_matrix_view( + a_data_d.data(), + (!params.transpose_a ? params.m : params.k), + (!params.transpose_a ? params.k : params.m)); + auto b = raft::make_device_matrix_view( + b_data_d.data(), + (!params.transpose_b ? params.k : params.n), + (!params.transpose_b ? params.n : params.k)); + + auto c_structure = raft::make_device_compressed_structure_view( + c_indptr_d.data(), + c_indices_d.data(), + params.m, + params.n, + static_cast(c_indices_d.size())); + + auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); + + auto op_a = params.transpose_a ? raft::linalg::Operation::TRANSPOSE + : raft::linalg::Operation::NON_TRANSPOSE; + auto op_b = params.transpose_b ? raft::linalg::Operation::TRANSPOSE + : raft::linalg::Operation::NON_TRANSPOSE; + + raft::sparse::linalg::sddmm(handle, + a, + b, + c, + op_a, + op_b, + raft::make_host_scalar_view(¶ms.alpha), + raft::make_host_scalar_view(¶ms.beta)); + + resource::sync_stream(handle); + + ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(), + c.get_elements().data(), + c_expected_data_d.size(), + raft::CompareApprox(params.tolerance), + stream)); + + thrust::device_ptr expected_data_ptr = + thrust::device_pointer_cast(c_expected_data_d.data()); + ValueType sum_abs = thrust::reduce(thrust::cuda::par.on(stream), + expected_data_ptr, + expected_data_ptr + c_expected_data_d.size(), + ValueType(0.0f), + sum_abs_op()); + ValueType avg = sum_abs / (1.0f * c_expected_data_d.size()); + + ASSERT_GE(avg, (params.tolerance * static_cast(0.001f))); + } + + raft::resources handle; + cudaStream_t stream; + SDDMMInputs params; + + rmm::device_uvector a_data_d; + rmm::device_uvector b_data_d; + + rmm::device_uvector c_indptr_d; + rmm::device_uvector c_indices_d; + rmm::device_uvector c_data_d; + + rmm::device_uvector c_expected_data_d; +}; + +using SDDMMTestF_Row_Col = SDDMMTest; +TEST_P(SDDMMTestF_Row_Col, Result) { Run(); } + +using SDDMMTestF_Col_Row = SDDMMTest; +TEST_P(SDDMMTestF_Col_Row, Result) { Run(); } + +using SDDMMTestF_Row_Row = SDDMMTest; +TEST_P(SDDMMTestF_Row_Row, Result) { Run(); } + +using SDDMMTestF_Col_Col = SDDMMTest; +TEST_P(SDDMMTestF_Col_Col, Result) { Run(); } + +using SDDMMTestD_Row_Col = SDDMMTest; +TEST_P(SDDMMTestD_Row_Col, Result) { Run(); } + +using SDDMMTestD_Col_Row = SDDMMTest; +TEST_P(SDDMMTestD_Col_Row, Result) { Run(); } + +using SDDMMTestD_Row_Row = SDDMMTest; +TEST_P(SDDMMTestD_Row_Row, Result) { Run(); } + +using SDDMMTestD_Col_Col = SDDMMTest; +TEST_P(SDDMMTestD_Col_Col, Result) { Run(); } + +const std::vector> sddmm_inputs_f = { + {0.0001f, 10, 5, 32, 1.0, 0.0, false, false, 0.01, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.3, 0.0, true, false, 0.1, 1234ULL}, + {0.0003f, 32, 1024, 1024, 1.0, 0.3, false, true, 0.2, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.2, 0.2, true, true, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 32, 0.1, 0.2, false, false, 0.3, 1234ULL}, + {0.0001f, 1024, 32, 1024, 1.0, 0.3, true, false, 0.4, 1234ULL}, + {0.0003f, 32, 1024, 1024, 2.0, 0.2, false, true, 0.19, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.0, 1.2, true, true, 0.1, 1234ULL}}; + +const std::vector> sddmm_inputs_d = { + {0.0001f, 10, 5, 32, 1.0, 0.0, false, false, 0.01, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.3, 0.0, true, false, 0.1, 1234ULL}, + {0.0001f, 32, 1024, 1024, 1.0, 0.3, false, true, 0.2, 1234ULL}, + {0.0001f, 1024, 1024, 1024, 0.2, 0.2, true, true, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 32, 0.1, 0.2, false, false, 0.3, 1234ULL}, + {0.0001f, 1024, 32, 1024, 1.0, 0.3, true, false, 0.4, 1234ULL}, + {0.0001f, 32, 1024, 1024, 2.0, 0.2, false, true, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 1024, 0.0, 1.2, true, true, 0.1, 1234ULL}}; + +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Col, ::testing::ValuesIn(sddmm_inputs_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Col_Row, ::testing::ValuesIn(sddmm_inputs_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Row, ::testing::ValuesIn(sddmm_inputs_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Col_Col, ::testing::ValuesIn(sddmm_inputs_f)); + +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Row_Col, ::testing::ValuesIn(sddmm_inputs_d)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Row, ::testing::ValuesIn(sddmm_inputs_d)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Row_Row, ::testing::ValuesIn(sddmm_inputs_d)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Col, ::testing::ValuesIn(sddmm_inputs_d)); + +} // namespace sparse +} // namespace raft