From 934645cc7c6a6c19b535d770163ac1f776b95eea Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 20 Aug 2024 10:21:11 -0700 Subject: [PATCH] [FEA] Support for half-float mixed precise in brute-force (#225) - distance supports half-float mixed precision - prefiltered_brute_force supports half - migrate the ann brute force test cases and support half Authors: - rhdong (https://github.com/rhdong) - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/225 --- cpp/CMakeLists.txt | 15 + cpp/bench/ann/CMakeLists.txt | 18 + .../ann/src/cuvs/cuvs_brute_force_knn.cu | 333 +++++ cpp/cmake/thirdparty/get_raft.cmake | 4 +- cpp/include/cuvs/distance/distance.hpp | 87 ++ cpp/include/cuvs/neighbors/brute_force.hpp | 139 +- cpp/src/distance/detail/distance.cuh | 84 +- .../distance/detail/distance_ops/canberra.cuh | 26 +- .../detail/distance_ops/correlation.cuh | 32 +- .../distance/detail/distance_ops/cosine.cuh | 27 +- .../distance/detail/distance_ops/hamming.cuh | 6 +- .../detail/distance_ops/hellinger.cuh | 6 +- .../detail/distance_ops/jensen_shannon.cuh | 14 +- .../detail/distance_ops/kl_divergence.cuh | 26 +- cpp/src/distance/detail/distance_ops/l1.cuh | 9 +- .../distance/detail/distance_ops/l2_exp.cuh | 28 +- .../distance/detail/distance_ops/l2_unexp.cuh | 15 +- .../distance/detail/distance_ops/l_inf.cuh | 4 +- .../distance/detail/distance_ops/lp_unexp.cuh | 10 +- .../detail/distance_ops/russel_rao.cuh | 9 +- .../distance/detail/distance_ops/template.cuh | 4 +- .../distance/detail/masked_distance_base.cuh | 2 +- .../detail/pairwise_distance_base.cuh | 26 +- .../detail/pairwise_distance_cutlass_base.cuh | 30 +- .../distance/detail/pairwise_distance_gemm.h | 103 +- .../detail/pairwise_matrix/dispatch-ext.cuh | 174 +-- .../detail/pairwise_matrix/dispatch-inl.cuh | 12 +- .../detail/pairwise_matrix/dispatch.cuh | 2 - .../pairwise_matrix/dispatch_00_generate.py | 10 +- ...patch_canberra_double_double_double_int.cu | 4 +- ...dispatch_canberra_float_float_float_int.cu | 4 +- .../dispatch_canberra_half_float_float_int.cu | 50 + ...ch_correlation_double_double_double_int.cu | 4 +- ...patch_correlation_float_float_float_int.cu | 4 +- ...spatch_correlation_half_float_float_int.cu | 50 + ...ispatch_cosine_double_double_double_int.cu | 4 +- .../dispatch_cosine_float_float_float_int.cu | 4 +- .../dispatch_cosine_half_float_float_int.cu | 51 + ...ing_unexpanded_double_double_double_int.cu | 4 +- ...amming_unexpanded_float_float_float_int.cu | 4 +- ...hamming_unexpanded_half_float_float_int.cu | 50 + ...inger_expanded_double_double_double_int.cu | 4 +- ...ellinger_expanded_float_float_float_int.cu | 4 +- ...hellinger_expanded_half_float_float_int.cu | 50 + ...jensen_shannon_double_double_double_int.cu | 4 +- ...ch_jensen_shannon_float_float_float_int.cu | 4 +- ...tch_jensen_shannon_half_float_float_int.cu | 55 + ..._kl_divergence_double_double_double_int.cu | 4 +- ...tch_kl_divergence_float_float_float_int.cu | 4 +- ...atch_kl_divergence_half_float_float_int.cu | 50 + .../dispatch_l1_double_double_double_int.cu | 4 +- .../dispatch_l1_float_float_float_int.cu | 4 +- .../dispatch_l1_half_float_float_int.cu | 50 + ...ch_l2_expanded_double_double_double_int.cu | 4 +- ...patch_l2_expanded_float_float_float_int.cu | 4 +- ...spatch_l2_expanded_half_float_float_int.cu | 51 + ..._l2_unexpanded_double_double_double_int.cu | 4 +- ...tch_l2_unexpanded_float_float_float_int.cu | 4 +- ...atch_l2_unexpanded_half_float_float_int.cu | 50 + ...dispatch_l_inf_double_double_double_int.cu | 4 +- .../dispatch_l_inf_float_float_float_int.cu | 4 +- .../dispatch_l_inf_half_float_float_int.cu | 50 + ..._lp_unexpanded_double_double_double_int.cu | 4 +- ...tch_lp_unexpanded_float_float_float_int.cu | 4 +- ...atch_lp_unexpanded_half_float_float_int.cu | 50 + .../detail/pairwise_matrix/dispatch_rbf.cu | 12 +- ...tch_russel_rao_double_double_double_int.cu | 4 +- ...spatch_russel_rao_float_float_float_int.cu | 4 +- ...ispatch_russel_rao_half_float_float_int.cu | 50 + .../detail/pairwise_matrix/params.cuh | 4 +- cpp/src/distance/distance-ext.cuh | 1122 ++++------------- cpp/src/distance/distance-inl.cuh | 48 +- cpp/src/distance/distance.cu | 1073 +++------------- cpp/src/distance/distance.cuh | 2 - cpp/src/distance/pairwise_distance.cu | 36 + cpp/src/neighbors/brute_force.cu | 105 +- .../cagra/search_multi_cta_kernel-ext.cuh | 4 +- .../detail/cagra/search_multi_cta_kernel.cuh | 2 +- .../cagra/search_single_cta_kernel-ext.cuh | 4 +- .../detail/cagra/search_single_cta_kernel.cuh | 2 +- cpp/src/neighbors/detail/fused_l2_knn.cuh | 179 +-- .../neighbors/detail/haversine_distance.cuh | 52 +- cpp/src/neighbors/detail/knn_brute_force.cuh | 214 ++-- cpp/src/neighbors/detail/knn_utils.cuh | 37 +- cpp/test/CMakeLists.txt | 15 + cpp/test/distance/dist_canberra.cu | 25 +- cpp/test/distance/dist_correlation.cu | 49 +- cpp/test/distance/dist_cos.cu | 69 +- cpp/test/distance/dist_hamming.cu | 24 +- cpp/test/distance/dist_hellinger.cu | 24 +- cpp/test/distance/dist_inner_product.cu | 25 +- cpp/test/distance/dist_jensen_shannon.cu | 24 +- cpp/test/distance/dist_kl_divergence.cu | 24 +- cpp/test/distance/dist_l1.cu | 25 +- cpp/test/distance/dist_l2_exp.cu | 72 +- cpp/test/distance/dist_l2_sqrt_exp.cu | 27 +- cpp/test/distance/dist_l2_unexp.cu | 24 +- cpp/test/distance/dist_l_inf.cu | 25 +- cpp/test/distance/dist_lp_unexp.cu | 26 +- cpp/test/distance/dist_russell_rao.cu | 24 +- cpp/test/distance/distance_base.cuh | 347 ++--- cpp/test/neighbors/ann_brute_force.cuh | 200 +++ .../neighbors/ann_brute_force/test_float.cu | 28 + .../neighbors/ann_brute_force/test_half.cu | 30 + cpp/test/neighbors/brute_force.cu | 231 +++- cpp/test/neighbors/brute_force_prefiltered.cu | 202 ++- 106 files changed, 3692 insertions(+), 2760 deletions(-) create mode 100644 cpp/bench/ann/src/cuvs/cuvs_brute_force_knn.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l1_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_half_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_half_float_float_int.cu create mode 100644 cpp/test/neighbors/ann_brute_force.cuh create mode 100644 cpp/test/neighbors/ann_brute_force/test_float.cu create mode 100644 cpp/test/neighbors/ann_brute_force/test_half.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 3b483538a..76fd09b58 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -216,30 +216,43 @@ add_library( src/cluster/kmeans_transform_float.cu src/cluster/single_linkage_float.cu src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_canberra_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_rbf.cu src/distance/detail/fused_distance_nn.cu @@ -413,6 +426,8 @@ add_library( src/selection/select_k_half_uint32_t.cu ) +target_compile_definitions(cuvs PRIVATE "CUVS_EXPLICIT_INSTANTIATE_ONLY") + target_compile_options( cuvs INTERFACE $<$:--expt-extended-lambda --expt-relaxed-constexpr> diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 80c1f3530..6fe23483e 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -35,6 +35,7 @@ option(CUVS_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" OFF) option(CUVS_ANN_BENCH_SINGLE_EXE "Make a single executable with benchmark as shared library modules" OFF ) +option(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE "Include cuVS brute force knn in benchmark" ON) # ################################################################################################## # * Process options ---------------------------------------------------------- @@ -53,6 +54,9 @@ if(BUILD_CPU_ONLY) set(CUVS_ANN_BENCH_USE_CUVS_BRUTE_FORCE OFF) set(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB OFF) set(CUVS_ANN_BENCH_USE_GGNN OFF) + set(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE OFF) +else() + set(CUVS_FAISS_ENABLE_GPU ON) endif() set(CUVS_ANN_BENCH_USE_CUVS OFF) @@ -61,6 +65,7 @@ if(CUVS_ANN_BENCH_USE_CUVS_IVF_PQ OR CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT OR CUVS_ANN_BENCH_USE_CUVS_CAGRA OR CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB + OR CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE ) set(CUVS_ANN_BENCH_USE_CUVS ON) endif() @@ -169,6 +174,8 @@ function(ConfigureAnnBench) ) endif() + target_compile_definitions(${BENCH_NAME} PRIVATE "CUVS_EXPLICIT_INSTANTIATE_ONLY") + target_include_directories( ${BENCH_NAME} PUBLIC "$" @@ -223,6 +230,17 @@ if(CUVS_ANN_BENCH_USE_CUVS_BRUTE_FORCE) ConfigureAnnBench(NAME CUVS_BRUTE_FORCE PATH src/cuvs/cuvs_benchmark.cu LINKS cuvs) endif() +if(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE) + ConfigureAnnBench( + NAME + CUVS_KNN_BRUTE_FORCE + PATH + $<$:src/cuvs/cuvs_brute_force_knn.cu> + LINKS + cuvs + ) +endif() + if(CUVS_ANN_BENCH_USE_CUVS_CAGRA) ConfigureAnnBench( NAME diff --git a/cpp/bench/ann/src/cuvs/cuvs_brute_force_knn.cu b/cpp/bench/ann/src/cuvs/cuvs_brute_force_knn.cu new file mode 100644 index 000000000..4c38b3420 --- /dev/null +++ b/cpp/bench/ann/src/cuvs/cuvs_brute_force_knn.cu @@ -0,0 +1,333 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace cuvs::neighbors::brute_force { + +struct print_metric { + cuvs::distance::DistanceType value; +}; + +struct RandomKNNInputs { + int num_queries; + int num_db_vecs; + int dim; + int k; + cuvs::distance::DistanceType metric; + bool row_major; +}; + +inline auto operator<<(std::ostream& os, const print_metric& p) -> std::ostream& +{ + switch (p.value) { + case cuvs::distance::DistanceType::L2Expanded: os << "L2Expanded"; break; + case cuvs::distance::DistanceType::L2SqrtExpanded: os << "L2SqrtExpanded"; break; + case cuvs::distance::DistanceType::CosineExpanded: os << "CosineExpanded"; break; + case cuvs::distance::DistanceType::L1: os << "L1"; break; + case cuvs::distance::DistanceType::L2Unexpanded: os << "L2Unexpanded"; break; + case cuvs::distance::DistanceType::L2SqrtUnexpanded: os << "L2SqrtUnexpanded"; break; + case cuvs::distance::DistanceType::InnerProduct: os << "InnerProduct"; break; + case cuvs::distance::DistanceType::Linf: os << "Linf"; break; + case cuvs::distance::DistanceType::Canberra: os << "Canberra"; break; + case cuvs::distance::DistanceType::LpUnexpanded: os << "LpUnexpanded"; break; + case cuvs::distance::DistanceType::CorrelationExpanded: os << "CorrelationExpanded"; break; + case cuvs::distance::DistanceType::JaccardExpanded: os << "JaccardExpanded"; break; + case cuvs::distance::DistanceType::HellingerExpanded: os << "HellingerExpanded"; break; + case cuvs::distance::DistanceType::Haversine: os << "Haversine"; break; + case cuvs::distance::DistanceType::BrayCurtis: os << "BrayCurtis"; break; + case cuvs::distance::DistanceType::JensenShannon: os << "JensenShannon"; break; + case cuvs::distance::DistanceType::HammingUnexpanded: os << "HammingUnexpanded"; break; + case cuvs::distance::DistanceType::KLDivergence: os << "KLDivergence"; break; + case cuvs::distance::DistanceType::RusselRaoExpanded: os << "RusselRaoExpanded"; break; + case cuvs::distance::DistanceType::DiceExpanded: os << "DiceExpanded"; break; + case cuvs::distance::DistanceType::Precomputed: os << "Precomputed"; break; + default: RAFT_FAIL("unreachable code"); + } + return os; +} + +std::ostream& operator<<(std::ostream& os, const RandomKNNInputs& input) +{ + return os << "num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs + << " dim:" << input.dim << " k:" << input.k << " metric:" << print_metric{input.metric} + << " row_major:" << input.row_major; +} + +template +class BruteForceKNNBenchmark { + public: + BruteForceKNNBenchmark(const RandomKNNInputs& params, const std::string& type_str) + : stream_(raft::resource::get_cuda_stream(handle_)), + params_(params), + type_str_(type_str), + database(params_.num_db_vecs * params_.dim, stream_), + search_queries(params_.num_queries * params_.dim, stream_), + cuvs_indices_(params_.num_queries * params_.k, stream_), + cuvs_distances_(params_.num_queries * params_.k, stream_) + { + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(database.data(), params_.num_db_vecs, params_.dim), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(search_queries.data(), params_.num_queries, params_.dim), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(cuvs_distances_.data(), params_.num_queries, params_.k), + DistT{0.0}); + } + + void runBenchmark() + { + DistT metric_arg = 3.0; + rmm::device_uvector workspace(0, stream_); + + std::chrono::duration build_dur; + std::chrono::duration search_dur; + + auto indices = raft::make_device_matrix_view( + cuvs_indices_.data(), params_.num_queries, params_.k); + auto distances = raft::make_device_matrix_view( + cuvs_distances_.data(), params_.num_queries, params_.k); + raft::resource::sync_stream(handle_, stream_); + + if (params_.row_major) { + { + auto idx_warm = + cuvs::neighbors::brute_force::build(handle_, + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + params_.metric, + metric_arg); + cuvs::neighbors::brute_force::search( + handle_, + idx_warm, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + indices, + distances, + std::nullopt); + flush_l2_cache(); + raft::resource::sync_stream(handle_, stream_); + } + + auto start = std::chrono::high_resolution_clock::now(); + auto idx = + cuvs::neighbors::brute_force::build(handle_, + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + params_.metric, + metric_arg); + raft::resource::sync_stream(handle_, stream_); + auto end = std::chrono::high_resolution_clock::now(); + build_dur = end - start; + + start = std::chrono::high_resolution_clock::now(); + cuvs::neighbors::brute_force::search( + handle_, + idx, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + indices, + distances, + std::nullopt); + raft::resource::sync_stream(handle_, stream_); + end = std::chrono::high_resolution_clock::now(); + search_dur = end - start; + + } else { + { + auto idx_warm = + cuvs::neighbors::brute_force::build(handle_, + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + params_.metric, + metric_arg); + cuvs::neighbors::brute_force::search( + handle_, + idx_warm, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + indices, + distances, + std::nullopt); + flush_l2_cache(); + raft::resource::sync_stream(handle_, stream_); + } + + auto start = std::chrono::high_resolution_clock::now(); + auto idx = cuvs::neighbors::brute_force::build( + handle_, + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + params_.metric, + metric_arg); + raft::resource::sync_stream(handle_, stream_); + auto end = std::chrono::high_resolution_clock::now(); + build_dur = end - start; + + start = std::chrono::high_resolution_clock::now(); + cuvs::neighbors::brute_force::search( + handle_, + idx, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + indices, + distances, + std::nullopt); + raft::resource::sync_stream(handle_, stream_); + end = std::chrono::high_resolution_clock::now(); + search_dur = end - start; + } + + double total_dur = build_dur.count() + search_dur.count(); + double throughput = static_cast(params_.num_queries) / (total_dur / 1000.0); + ; + printResult(params_, build_dur.count(), search_dur.count(), total_dur, throughput); + } + + void setUp() + { + unsigned long long int seed = 1234ULL; + raft::random::RngState r(seed); + + // JensenShannon distance requires positive values + T min_val = params_.metric == cuvs::distance::DistanceType::JensenShannon ? T(0.0) : T(-1.0); + uniform(handle_, r, database.data(), params_.num_db_vecs * params_.dim, min_val, T(1.0)); + uniform(handle_, r, search_queries.data(), params_.num_queries * params_.dim, min_val, T(1.0)); + } + + private: + void flush_l2_cache() + { + int l2_cache_size = 0; + int device_id = 0; + RAFT_CUDA_TRY(cudaGetDevice(&device_id)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&l2_cache_size, cudaDevAttrL2CacheSize, device_id)); + scratch_buf_ = rmm::device_buffer(l2_cache_size * 3, stream_); + RAFT_CUDA_TRY(cudaMemsetAsync(scratch_buf_.data(), 0, scratch_buf_.size(), stream_)); + }; + + void printResult(const RandomKNNInputs& params, + double build_time, + double search_time, + double total_time, + double throughput) + { + std::cout << std::left << std::setw(15) << type_str_ << std::setw(10) << params.num_queries + << std::setw(10) << params.num_db_vecs << std::setw(10) << params.dim << std::setw(10) + << params.k << std::setw(20) << print_metric{params.metric} << std::setw(15) + << (params.row_major ? "row" : "col") << std::right << std::setw(20) << std::fixed + << std::setprecision(3) << build_time << std::right << std::setw(20) << std::fixed + << std::setprecision(3) << search_time << std::right << std::setw(20) << std::fixed + << std::setprecision(3) << total_time << std::right << std::setw(20) << std::fixed + << std::setprecision(3) << throughput << "\n"; + } + raft::resources handle_; + cudaStream_t stream_ = 0; + RandomKNNInputs params_; + rmm::device_uvector database; + rmm::device_uvector search_queries; + rmm::device_uvector cuvs_indices_; + rmm::device_uvector cuvs_distances_; + rmm::device_buffer scratch_buf_; + std::string type_str_; +}; + +static std::vector getInputs() +{ + std::vector param_vec; + struct TestParams { + int num_queries; + int num_db_vecs; + int dim; + int k; + cuvs::distance::DistanceType metric; + bool row_major; + }; + + const std::vector params_group = raft::util::itertools::product( + {int(10), int(100), int(1024)}, + {int(1000000)}, + {int(32), int(256), int(1024)}, + {int(128), int(1024)}, + {cuvs::distance::DistanceType::InnerProduct, cuvs::distance::DistanceType::L2SqrtExpanded}, + {true, false}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(RandomKNNInputs({params.num_queries, + params.num_db_vecs, + params.dim, + params.k, + params.metric, + params.row_major})); + } + return param_vec; +} + +void printHeader() +{ + std::cout << std::left << std::setw(15) << "Type" << std::setw(10) << "Queries" << std::setw(10) + << "Vectors" << std::setw(10) << "Dim" << std::setw(10) << "K" << std::setw(20) + << "Metric" << std::setw(15) << "Layout" << std::right << std::setw(20) + << "Build Time (ms)" << std::right << std::setw(20) << "Search Time (ms)" << std::right + << std::setw(20) << "Total Time (ms)" << std::right << std::setw(20) + << "Throughput (q/s)" + << "\n"; + std::cout << std::string(165, '-') << "\n"; +} + +void runBenchmarkForType() +{ + auto selected_inputs = getInputs(); + for (const auto& input : selected_inputs) { + { + BruteForceKNNBenchmark benchmark(input, "float"); + benchmark.setUp(); + benchmark.runBenchmark(); + } + { + BruteForceKNNBenchmark benchmark(input, "half"); + benchmark.setUp(); + benchmark.runBenchmark(); + } + } +} + +} // namespace cuvs::neighbors::brute_force + +int main() +{ + cuvs::neighbors::brute_force::printHeader(); + cuvs::neighbors::brute_force::runBenchmarkForType(); + return 0; +} diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 7640fbfa6..fb25623ce 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -61,8 +61,8 @@ endfunction() # To use a different RAFT locally, set the CMake variable # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${RAFT_VERSION}.00 - FORK ${RAFT_FORK} - PINNED_TAG ${RAFT_PINNED_TAG} + FORK rhdong + PINNED_TAG half_knn ENABLE_MNMG_DEPENDENCIES OFF ENABLE_NVTX OFF USE_RAFT_STATIC ${CUVS_USE_RAFT_STATIC} diff --git a/cpp/include/cuvs/distance/distance.hpp b/cpp/include/cuvs/distance/distance.hpp index 5786b0a32..def72641e 100644 --- a/cpp/include/cuvs/distance/distance.hpp +++ b/cpp/include/cuvs/distance/distance.hpp @@ -19,6 +19,7 @@ #include "distance.h" #include +#include #include #include @@ -156,6 +157,49 @@ void pairwise_distance( raft::device_matrix_view dist, cuvs::distance::DistanceType metric, double metric_arg = 2.0f); +/** + * @brief Compute pairwise distances for two matrices + * + * Note: Only contiguous row- or column-major layouts supported currently. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * + * raft::resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * + * // ... fill input with data ... + * + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; + * cuvs::distance::pairwise_distance(handle, + * raft::make_const(input.view()), + * raft::make_const(input.view()), + * output.view(), + * metric); + * @endcode + * + * @param[in] handle raft handle for managing expensive resources + * @param[in] x first set of points (size n*k) + * @param[in] y second set of points (size m*k) + * @param[out] dist output distance matrix (size n*m) + * @param[in] metric distance to evaluate + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg = 2.0f); /** * @brief Compute pairwise distances for two matrices @@ -243,6 +287,49 @@ void pairwise_distance( raft::device_matrix_view dist, cuvs::distance::DistanceType metric, double metric_arg = 2.0f); +/** + * @brief Compute pairwise distances for two matrices + * + * Note: Only contiguous row- or column-major layouts supported currently. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * + * raft::resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * + * // ... fill input with data ... + * + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; + * cuvs::distance::pairwise_distance(handle, + * raft::make_const(input.view()), + * raft::make_const(input.view()), + * output.view(), + * metric); + * @endcode + * + * @param[in] handle raft handle for managing expensive resources + * @param[in] x first set of points (size n*k) + * @param[in] y second set of points (size m*k) + * @param[out] dist output distance matrix (size n*m) + * @param[in] metric distance to evaluate + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg = 2.0f); /** @} */ // end group pairwise_distance_runtime diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index db70a7fa6..5408eb1a0 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -23,6 +23,8 @@ #include #include +#include + namespace cuvs::neighbors::brute_force { /** @@ -36,7 +38,7 @@ namespace cuvs::neighbors::brute_force { * * @tparam T data element type */ -template +template struct index : cuvs::neighbors::index { public: index(const index&) = delete; @@ -54,9 +56,9 @@ struct index : cuvs::neighbors::index { */ index(raft::resources const& res, raft::host_matrix_view dataset_view, - std::optional>&& norms, + std::optional>&& norms, cuvs::distance::DistanceType metric, - T metric_arg = 0.0); + DistT metric_arg = 0.0); /** Construct a brute force index from dataset * @@ -67,9 +69,9 @@ struct index : cuvs::neighbors::index { */ index(raft::resources const& res, raft::device_matrix_view dataset_view, - std::optional>&& norms, + std::optional>&& norms, cuvs::distance::DistanceType metric, - T metric_arg = 0.0); + DistT metric_arg = 0.0); /** Construct a brute force index from dataset * @@ -78,9 +80,9 @@ struct index : cuvs::neighbors::index { */ index(raft::resources const& res, raft::device_matrix_view dataset_view, - std::optional> norms_view, + std::optional> norms_view, cuvs::distance::DistanceType metric, - T metric_arg = 0.0); + DistT metric_arg = 0.0); /** Construct a brute force index from dataset * @@ -91,9 +93,9 @@ struct index : cuvs::neighbors::index { */ index(raft::resources const& res, raft::device_matrix_view dataset_view, - std::optional>&& norms, + std::optional>&& norms, cuvs::distance::DistanceType metric, - T metric_arg = 0.0); + DistT metric_arg = 0.0); /** Construct a brute force index from dataset * @@ -102,9 +104,9 @@ struct index : cuvs::neighbors::index { */ index(raft::resources const& res, raft::device_matrix_view dataset_view, - std::optional> norms_view, + std::optional> norms_view, cuvs::distance::DistanceType metric, - T metric_arg = 0.0); + DistT metric_arg = 0.0); /** * Replace the dataset with a new dataset. @@ -124,7 +126,7 @@ struct index : cuvs::neighbors::index { cuvs::distance::DistanceType metric() const noexcept { return metric_; } /** Metric argument */ - T metric_arg() const noexcept { return metric_arg_; } + DistT metric_arg() const noexcept { return metric_arg_; } /** Total length of the index (number of vectors). */ size_t size() const noexcept { return dataset_view_.extent(0); } @@ -139,7 +141,7 @@ struct index : cuvs::neighbors::index { } /** Dataset norms */ - raft::device_vector_view norms() const + raft::device_vector_view norms() const { return norms_view_.value(); } @@ -150,10 +152,10 @@ struct index : cuvs::neighbors::index { private: cuvs::distance::DistanceType metric_; raft::device_matrix dataset_; - std::optional> norms_; - std::optional> norms_view_; + std::optional> norms_; + std::optional> norms_view_; raft::device_matrix_view dataset_view_; - T metric_arg_; + DistT metric_arg_; }; /** * @} @@ -183,8 +185,28 @@ struct index : cuvs::neighbors::index { auto build(raft::resources const& handle, raft::device_matrix_view dataset, cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded, - float metric_arg = 0) -> cuvs::neighbors::brute_force::index; - + float metric_arg = 0) -> cuvs::neighbors::brute_force::index; +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // create and fill the index from a [N, D] dataset + * auto index = brute_force::build(handle, dataset, metric); + * @endcode + * + * @param[in] handle + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * @param[in] metric cuvs::distance::DistanceType + * @param[in] metric_arg metric argument + * + * @return the constructed ivf-flat index + */ +auto build(raft::resources const& handle, + raft::device_matrix_view dataset, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded, + float metric_arg = 0) -> cuvs::neighbors::brute_force::index; /** * @brief Build the index from the dataset for efficient search. * @@ -205,7 +227,28 @@ auto build(raft::resources const& handle, auto build(raft::resources const& handle, raft::device_matrix_view dataset, cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded, - float metric_arg = 0) -> cuvs::neighbors::brute_force::index; + float metric_arg = 0) -> cuvs::neighbors::brute_force::index; +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // create and fill the index from a [N, D] dataset + * auto index = brute_force::build(handle, dataset, metric); + * @endcode + * + * @param[in] handle + * @param[in] dataset a device pointer to a col-major matrix [n_rows, dim] + * @param[in] metric cuvs::distance::DistanceType + * @param[in] metric_arg metric argument + * + * @return the constructed bruteforce index + */ +auto build(raft::resources const& handle, + raft::device_matrix_view dataset, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded, + float metric_arg = 0) -> cuvs::neighbors::brute_force::index; /** * @} */ @@ -244,12 +287,46 @@ auto build(raft::resources const& handle, * `index->size()` bits to indicate whether queries[0] should compute the distance with dataset. */ void search(raft::resources const& handle, - const cuvs::neighbors::brute_force::index& index, + const cuvs::neighbors::brute_force::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, std::optional> sample_filter); +/** + * @brief Search ANN using the constructed index. + * + * See the [brute_force::build](#brute_force::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * brute_force::search(handle, index, queries1, out_inds1, out_dists1); + * brute_force::search(handle, index, queries2, out_inds2, out_dists2); + * brute_force::search(handle, index, queries3, out_inds3, out_dists3); + * ... + * @endcode + * + * @param[in] handle + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a + * given + */ +void search(raft::resources const& handle, + const cuvs::neighbors::brute_force::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + std::optional> sample_filter); /** * @brief Search ANN using the constructed index. * @@ -265,11 +342,31 @@ void search(raft::resources const& handle, * given query */ void search(raft::resources const& handle, - const cuvs::neighbors::brute_force::index& index, + const cuvs::neighbors::brute_force::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, std::optional> sample_filter); +/** + * @brief Search ANN using the constructed index. + * + * See the [brute_force::build](#brute_force::build) documentation for a usage example. + * + * @param[in] handle + * @param[in] index bruteforce constructed index + * @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a + * given query + */ +void search(raft::resources const& handle, + const cuvs::neighbors::brute_force::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + std::optional> sample_filter); /** * @} */ diff --git a/cpp/src/distance/detail/distance.cuh b/cpp/src/distance/detail/distance.cuh index c8dde4ea4..d6fd04646 100644 --- a/cpp/src/distance/detail/distance.cuh +++ b/cpp/src/distance/detail/distance.cuh @@ -27,6 +27,7 @@ #include #include #include +#include // to_float #include @@ -104,8 +105,8 @@ void distance_impl(raft::resources const& handle, { ops::canberra_distance_op distance_op{}; - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( @@ -217,8 +218,8 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - DataT* x_norm = workspace; - DataT* y_norm = workspace; + OutT* x_norm = reinterpret_cast(workspace); + OutT* y_norm = reinterpret_cast(workspace); // TODO: Column major case looks to have lower accuracy for X == Y, // perhaps the use of stridedSummationKernel could be causing this, // need to investigate and fix. @@ -255,8 +256,8 @@ void distance_impl(raft::resources const& handle, { ops::hamming_distance_op distance_op{k}; - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -319,8 +320,8 @@ void distance_impl(raft::resources const& handle, // Then calculate Hellinger distance ops::hellinger_distance_op distance_op{}; - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); @@ -349,8 +350,8 @@ void distance_impl(raft::resources const& handle, { ops::jensen_shannon_distance_op distance_op{}; - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -376,14 +377,24 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto unaryOp_lambda = [] __device__(DataT input) { - const bool x_zero = (input == 0); - return (!x_zero) * raft::log(input + x_zero); + auto input_ = raft::to_float(input); + const bool x_zero = (input_ == 0); + if constexpr (std::is_same_v) { + return __float2half((!x_zero) * raft::log(input_ + x_zero)); + } else { + return (!x_zero) * raft::log(input_ + x_zero); + } }; auto unaryOp_lambda_reverse = [] __device__(DataT input) { // reverse previous log (x) back to x using (e ^ log(x)) - const bool x_zero = (input == 0); - return (!x_zero) * raft::exp(input); + auto input_ = raft::to_float(input); + const bool x_zero = (input_ == 0); + if constexpr (std::is_same_v) { + return __float2half((!x_zero) * raft::exp(input_)); + } else { + return (!x_zero) * raft::exp(input_); + } }; if (x != y) { @@ -391,8 +402,8 @@ void distance_impl(raft::resources const& handle, (DataT*)y, y, n * k, unaryOp_lambda, stream); } - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; // This op takes some shortcuts when x equals y. So its behavior changes based // on this. @@ -425,8 +436,8 @@ void distance_impl(raft::resources const& handle, { ops::l1_distance_op distance_op{}; - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( @@ -460,8 +471,13 @@ void distance_impl_l2_expanded( // NOTE: different name ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); - DataT* x_norm = workspace; - DataT* y_norm = workspace; + // TODO: May we have a better method to avoid misalignment? + uintptr_t offset = alignof(OutT) - (reinterpret_cast(workspace) % alignof(OutT)); + if (offset == alignof(OutT)) { offset = 0; } + OutT* x_norm = reinterpret_cast(reinterpret_cast(workspace) + offset); + + offset = (reinterpret_cast(x_norm) % alignof(OutT)); + OutT* y_norm = x_norm; // TODO: Column major case looks to have lower accuracy for X == Y, // perhaps the use of stridedSummationKernel could be causing this, // need to investigate and fix. @@ -548,8 +564,8 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -576,8 +592,8 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -602,8 +618,8 @@ void distance_impl(raft::resources const& handle, { ops::l_inf_distance_op distance_op{}; - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -628,8 +644,8 @@ void distance_impl(raft::resources const& handle, { ops::lp_unexp_distance_op distance_op{metric_arg}; - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -654,8 +670,8 @@ void distance_impl(raft::resources const& handle, { ops::russel_rao_distance_op distance_op{k}; - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; + const OutT* x_norm = nullptr; + const OutT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -705,8 +721,8 @@ void distance(raft::resources const& handle, void* workspace, size_t worksize, FinalLambda fin_op, - bool isRowMajor = true, - InType metric_arg = 2.0f) + bool isRowMajor = true, + OutType metric_arg = 2.0f) { // raft distance support inputs as float/double and output as uint8_t/float/double. static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), @@ -762,8 +778,8 @@ void distance(raft::resources const& handle, Index_ k, void* workspace, size_t worksize, - bool isRowMajor = true, - InType metric_arg = 2.0f) + bool isRowMajor = true, + OutType metric_arg = 2.0f) { auto fin_op = raft::identity_op(); diff --git a/cpp/src/distance/detail/distance_ops/canberra.cuh b/cpp/src/distance/detail/distance_ops/canberra.cuh index 8bbdc9945..bf01caf98 100644 --- a/cpp/src/distance/detail/distance_ops/canberra.cuh +++ b/cpp/src/distance/detail/distance_ops/canberra.cuh @@ -19,6 +19,8 @@ #include // raft::abs #include // DI +#include + namespace cuvs::distance::detail::ops { /** @@ -50,17 +52,27 @@ struct canberra_distance_op { DI void core(AccT& acc, DataT& x, DataT& y) const { - const auto diff = raft::abs(x - y); - const auto add = raft::abs(x) + raft::abs(y); - // deal with potential for 0 in denominator by - // forcing 0/1 instead - acc += ((add != 0) * diff / (add + (add == 0))); + if constexpr ((std::is_same_v && std::is_same_v)) { + AccT _x = __half2float(x); + AccT _y = __half2float(y); + const auto diff = raft::abs(_x - _y); + const auto add = raft::abs(_x) + raft::abs(_y); + // deal with potential for 0 in denominator by + // forcing 0/1 instead + acc += ((add != 0) * diff / (add + (add == 0))); + } else { + const auto diff = raft::abs(x - y); + const auto add = raft::abs(x) + raft::abs(y); + // deal with potential for 0 in denominator by + // forcing 0/1 instead + acc += ((add != 0) * diff / (add + (add == 0))); + } }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { diff --git a/cpp/src/distance/detail/distance_ops/correlation.cuh b/cpp/src/distance/detail/distance_ops/correlation.cuh index f033f3dfa..810d3f90c 100644 --- a/cpp/src/distance/detail/distance_ops/correlation.cuh +++ b/cpp/src/distance/detail/distance_ops/correlation.cuh @@ -18,6 +18,8 @@ #include // DI +#include + namespace cuvs::distance::detail::ops { /** @brief The correlation distance @@ -34,20 +36,20 @@ struct correlation_distance_op { using AccT = AccType; using IdxT = IdxType; - const DataT* x2n; - const DataT* y2n; + const AccT* x2n; + const AccT* y2n; IdxT m; IdxT n; IdxT k; correlation_distance_op( - bool is_row_major, const DataT* x2n_, const DataT* y2n_, IdxT m_, IdxT n_, IdxT k_) noexcept + bool is_row_major, const AccT* x2n_, const AccT* y2n_, IdxT m_, IdxT n_, IdxT k_) noexcept : x2n(x2n_), y2n(y2n_), m(m_), n(n_), k(k_) { // The distance op is typically created before the row-major/col-major // swapping has been done. So we do it here. if (!is_row_major) { - std::swap(x2n, y2n); + std::swap(x2n, y2n); std::swap(m, n); } } @@ -63,15 +65,18 @@ struct correlation_distance_op { template static constexpr size_t shared_mem_size() { - return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(AccT)); } - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += raft::to_float(x) * raft::to_float(y); + }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { @@ -80,23 +85,22 @@ struct correlation_distance_op { // changes, this will be where we find the bugs. extern __shared__ char smem[]; - DataT regx2n[Policy::AccRowsPerTh], regy2n[Policy::AccColsPerTh]; + AccT regx2n[Policy::AccRowsPerTh], regy2n[Policy::AccColsPerTh]; - DataT* sx2Norm = - (DataT*)(&smem[Policy::SmemSize + (Policy::Mblk + Policy::Nblk) * sizeof(DataT)]); - DataT* sy2Norm = (&sx2Norm[Policy::Mblk]); + AccT* sx2Norm = (AccT*)(&smem[Policy::SmemSize + (Policy::Mblk + Policy::Nblk) * sizeof(AccT)]); + AccT* sy2Norm = (&sx2Norm[Policy::Mblk]); // Load x & y norms required by this threadblock in shmem buffer if (gridStrideX == blockIdx.x * Policy::Nblk) { for (int i = threadIdx.x; i < Policy::Mblk; i += Policy::Nthreads) { auto idx = gridStrideY + i; - sx2Norm[i] = idx < m ? x2n[idx] : 0; + sx2Norm[i] = idx < m ? raft::to_float(x2n[idx]) : 0; } } for (int i = threadIdx.x; i < Policy::Nblk; i += Policy::Nthreads) { auto idx = gridStrideX + i; - sy2Norm[i] = idx < n ? y2n[idx] : 0; + sy2Norm[i] = idx < n ? raft::to_float(y2n[idx]) : 0; } __syncthreads(); diff --git a/cpp/src/distance/detail/distance_ops/cosine.cuh b/cpp/src/distance/detail/distance_ops/cosine.cuh index d48731651..b0a8b867c 100644 --- a/cpp/src/distance/detail/distance_ops/cosine.cuh +++ b/cpp/src/distance/detail/distance_ops/cosine.cuh @@ -18,17 +18,19 @@ #include // DI +#include + namespace cuvs::distance::detail::ops { // Epilogue operator for CUTLASS based kernel template struct cosine_cutlass_op { __device__ cosine_cutlass_op() noexcept {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + __device__ AccT operator()(AccT& aNorm, const AccT& bNorm, AccT& accVal) const noexcept { return static_cast(1.0) - static_cast(accVal / (aNorm * bNorm)); } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } + __device__ AccT operator()(DataT aData) const noexcept { return raft::to_float(aData); } }; /** @@ -55,15 +57,22 @@ struct cosine_distance_op { template static constexpr size_t shared_mem_size() { - return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(AccT)); } - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + DI void core(AccT& acc, DataT& x, DataT& y) const + { + if constexpr ((std::is_same_v && std::is_same_v)) { + acc += __half2float(x) * __half2float(y); + } else { + acc += x * y; + } + }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { @@ -71,7 +80,11 @@ struct cosine_distance_op { for (int i = 0; i < Policy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j])); + if constexpr ((std::is_same_v && std::is_same_v)) { + acc[i][j] = 1.0 - (acc[i][j] / (__half2float(regxn[i]) * __half2float(regyn[j]))); + } else { + acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j])); + } } } } diff --git a/cpp/src/distance/detail/distance_ops/hamming.cuh b/cpp/src/distance/detail/distance_ops/hamming.cuh index 7c6553f38..8548df752 100644 --- a/cpp/src/distance/detail/distance_ops/hamming.cuh +++ b/cpp/src/distance/detail/distance_ops/hamming.cuh @@ -54,12 +54,12 @@ struct hamming_distance_op { template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { - const DataT one_over_k = DataT(1.0) / k; + const AccT one_over_k = AccT(1.0) / k; #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { #pragma unroll diff --git a/cpp/src/distance/detail/distance_ops/hellinger.cuh b/cpp/src/distance/detail/distance_ops/hellinger.cuh index ad5ca3156..5d9dd2259 100644 --- a/cpp/src/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/src/distance/detail/distance_ops/hellinger.cuh @@ -50,14 +50,14 @@ struct hellinger_distance_op { DI void core(AccT& acc, DataT& x, DataT& y) const { // This is sqrt(x) * sqrt(y). - const auto product = x * y; + const AccT product = raft::to_float(x) * raft::to_float(y); acc += product; }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { diff --git a/cpp/src/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/src/distance/detail/distance_ops/jensen_shannon.cuh index 216639494..8f8324ed0 100644 --- a/cpp/src/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/src/distance/detail/distance_ops/jensen_shannon.cuh @@ -52,19 +52,21 @@ struct jensen_shannon_distance_op { DI void core(AccT& acc, DataT& x, DataT& y) const { - const DataT m = 0.5f * (x + y); + AccT x_ = raft::to_float(x); + AccT y_ = raft::to_float(y); + const AccT m = 0.5f * (x_ + y_); const bool m_zero = (m == 0); const auto logM = (!m_zero) * raft::log(m + m_zero); - const bool x_zero = (x == 0); - const bool y_zero = (y == 0); - acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); + const bool x_zero = (x_ == 0); + const bool y_zero = (y_ == 0); + acc += (-x_ * (logM - raft::log(x_ + x_zero))) + (-y_ * (logM - raft::log(y_ + y_zero))); }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { diff --git a/cpp/src/distance/detail/distance_ops/kl_divergence.cuh b/cpp/src/distance/detail/distance_ops/kl_divergence.cuh index 929c3a559..7f183b159 100644 --- a/cpp/src/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/src/distance/detail/distance_ops/kl_divergence.cuh @@ -59,31 +59,33 @@ struct kl_divergence_op { { // TODO: make sure that these branches get hoisted out of main loop.. Could // be quite expensive otherwise. + AccT x_ = raft::to_float(x); + AccT y_ = raft::to_float(y); if (x_equal_y) { if (is_row_major) { - const bool x_zero = (x == 0); - const bool y_zero = (y == 0); - acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); + const bool x_zero = (x_ == 0); + const bool y_zero = (y_ == 0); + acc += x_ * (raft::log(x_ + x_zero) - (!y_zero) * raft::log(y_ + y_zero)); } else { - const bool y_zero = (y == 0); - const bool x_zero = (x == 0); - acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); + const bool y_zero = (y_ == 0); + const bool x_zero = (x_ == 0); + acc += y_ * (raft::log(y_ + y_zero) - (!x_zero) * raft::log(x_ + x_zero)); } } else { if (is_row_major) { - const bool x_zero = (x == 0); - acc += x * (raft::log(x + x_zero) - y); + const bool x_zero = (x_ == 0); + acc += x_ * (raft::log(x_ + x_zero) - y_); } else { - const bool y_zero = (y == 0); - acc += y * (raft::log(y + y_zero) - x); + const bool y_zero = (y_ == 0); + acc += y_ * (raft::log(y_ + y_zero) - x_); } } }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { diff --git a/cpp/src/distance/detail/distance_ops/l1.cuh b/cpp/src/distance/detail/distance_ops/l1.cuh index 76eaffaf3..278702ea3 100644 --- a/cpp/src/distance/detail/distance_ops/l1.cuh +++ b/cpp/src/distance/detail/distance_ops/l1.cuh @@ -46,12 +46,15 @@ struct l1_distance_op { return Policy::SmemSize; } - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += raft::abs(x - y); }; + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += raft::abs(raft::to_float(x) - raft::to_float(y)); + }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { diff --git a/cpp/src/distance/detail/distance_ops/l2_exp.cuh b/cpp/src/distance/detail/distance_ops/l2_exp.cuh index f45c41206..04817aa8b 100644 --- a/cpp/src/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/src/distance/detail/distance_ops/l2_exp.cuh @@ -19,6 +19,8 @@ #include #include // DI +#include + namespace cuvs::distance::detail::ops { /** @@ -52,8 +54,8 @@ struct l2_exp_cutlass_op { * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal) * can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead. */ - outVal = outVal * !((outVal * outVal < get_clamp_precision()) * (aNorm == bNorm)); - return sqrt ? raft::sqrt(outVal * (outVal > 0)) : outVal; + outVal = outVal * AccT(!((outVal * outVal < get_clamp_precision()) * (aNorm == bNorm))); + return sqrt ? raft::sqrt(outVal * static_cast(outVal > AccT(0))) : outVal; } __device__ AccT operator()(DataT aData) const noexcept { return aData; } @@ -88,15 +90,22 @@ struct l2_exp_distance_op { template static constexpr size_t shared_mem_size() { - return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(AccT)); } - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + DI void core(AccT& acc, DataT& x, DataT& y) const + { + if constexpr ((std::is_same_v && std::is_same_v)) { + acc += __half2float(x) * __half2float(y); + } else { + acc += x * y; + } + }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { @@ -104,8 +113,8 @@ struct l2_exp_distance_op { for (int i = 0; i < Policy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < Policy::AccColsPerTh; ++j) { - DataT accVal = acc[i][j]; - DataT val = regxn[i] + regyn[j] - (DataT)2.0 * accVal; + AccT accVal = acc[i][j]; + AccT val = regxn[i] + regyn[j] - (AccT)2.0 * accVal; /** * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product @@ -113,7 +122,8 @@ struct l2_exp_distance_op { * instead. */ acc[i][j] = - val * (val > 0) * !((val * val < get_clamp_precision()) * (regxn[i] == regyn[j])); + val * static_cast((val > AccT(0))) * + static_cast(!((val * val < get_clamp_precision()) * (regxn[i] == regyn[j]))); } } if (sqrt) { diff --git a/cpp/src/distance/detail/distance_ops/l2_unexp.cuh b/cpp/src/distance/detail/distance_ops/l2_unexp.cuh index aa6cc27f3..f12820d8e 100644 --- a/cpp/src/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/src/distance/detail/distance_ops/l2_unexp.cuh @@ -18,6 +18,8 @@ #include // DI +#include + namespace cuvs::distance::detail::ops { /** @@ -53,14 +55,19 @@ struct l2_unexp_distance_op { DI void core(AccT& acc, DataT& x, DataT& y) const { - const auto diff = x - y; - acc += diff * diff; + if constexpr ((std::is_same_v && std::is_same_v)) { + const auto diff = __half2float(x) - __half2float(y); + acc += diff * diff; + } else { + const auto diff = x - y; + acc += diff * diff; + } }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { diff --git a/cpp/src/distance/detail/distance_ops/l_inf.cuh b/cpp/src/distance/detail/distance_ops/l_inf.cuh index d8f9384d7..d8559a7d1 100644 --- a/cpp/src/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/src/distance/detail/distance_ops/l_inf.cuh @@ -55,8 +55,8 @@ struct l_inf_distance_op { template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { diff --git a/cpp/src/distance/detail/distance_ops/lp_unexp.cuh b/cpp/src/distance/detail/distance_ops/lp_unexp.cuh index 6136f9f3e..2adf0befa 100644 --- a/cpp/src/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/src/distance/detail/distance_ops/lp_unexp.cuh @@ -53,18 +53,18 @@ struct lp_unexp_distance_op { DI void core(AccT& acc, DataT& x, DataT& y) const { - const auto diff = raft::abs(x - y); - acc += raft::pow(diff, p); + const AccT diff = raft::abs(raft::to_float(x) - raft::to_float(y)); + acc += raft::pow(diff, raft::to_float(p)); }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { - const auto one_over_p = 1.0f / p; + const AccT one_over_p = 1.0f / static_cast(raft::to_float(p)); #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { #pragma unroll diff --git a/cpp/src/distance/detail/distance_ops/russel_rao.cuh b/cpp/src/distance/detail/distance_ops/russel_rao.cuh index 5dffdcdb8..4988c7353 100644 --- a/cpp/src/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/src/distance/detail/distance_ops/russel_rao.cuh @@ -52,12 +52,15 @@ struct russel_rao_distance_op { return Policy::SmemSize; } - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += raft::to_float(x) * raft::to_float(y); + }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { diff --git a/cpp/src/distance/detail/distance_ops/template.cuh b/cpp/src/distance/detail/distance_ops/template.cuh index bdb933237..cb26e210d 100644 --- a/cpp/src/distance/detail/distance_ops/template.cuh +++ b/cpp/src/distance/detail/distance_ops/template.cuh @@ -52,8 +52,8 @@ struct template_distance_op { template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, + AccT* regxn, + AccT* regyn, IdxT gridStrideX, IdxT gridStrideY) const { diff --git a/cpp/src/distance/detail/masked_distance_base.cuh b/cpp/src/distance/detail/masked_distance_base.cuh index 2c41ee3be..ec7270baa 100644 --- a/cpp/src/distance/detail/masked_distance_base.cuh +++ b/cpp/src/distance/detail/masked_distance_base.cuh @@ -266,7 +266,7 @@ struct MaskedDistances : public BaseClass { for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; + acc[i][j] = BaseClass::Zero(); } } } diff --git a/cpp/src/distance/detail/pairwise_distance_base.cuh b/cpp/src/distance/detail/pairwise_distance_base.cuh index 990f845fd..72d75ec12 100644 --- a/cpp/src/distance/detail/pairwise_distance_base.cuh +++ b/cpp/src/distance/detail/pairwise_distance_base.cuh @@ -72,8 +72,8 @@ struct PairwiseDistances : public BaseClass { private: typedef Policy P; - const DataT* xn; - const DataT* yn; + const OutT* xn; + const OutT* yn; const DataT* const yBase; OutT* dOutput; char* smem; @@ -99,8 +99,8 @@ struct PairwiseDistances : public BaseClass { IdxT _lda, IdxT _ldb, IdxT _ldd, - const DataT* _xn, - const DataT* _yn, + const OutT* _xn, + const OutT* _yn, OutT* _dOutput, char* _smem, OpT _distance_op, @@ -154,7 +154,7 @@ struct PairwiseDistances : public BaseClass { // Epilog: if (distance_op.use_norms) { - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + OutT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn); // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); @@ -200,7 +200,7 @@ struct PairwiseDistances : public BaseClass { for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; + acc[i][j] = BaseClass::Zero(); } } } @@ -242,23 +242,23 @@ struct PairwiseDistances : public BaseClass { DI void load_norms(IdxT tile_idx_m, IdxT tile_idx_n, - DataT (®xn)[P::AccRowsPerTh], - DataT (®yn)[P::AccColsPerTh]) + OutT (®xn)[P::AccRowsPerTh], + OutT (®yn)[P::AccColsPerTh]) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); + OutT* sxNorm = (OutT*)(&smem[P::SmemSize]); + OutT* syNorm = (&sxNorm[P::Mblk]); // Load x & y norms required by this threadblock in shmem buffer if (tile_idx_n == blockIdx.x * P::Nblk) { for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; + sxNorm[i] = idx < this->m ? xn[idx] : OutT(0); } } for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { auto idx = tile_idx_n + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; + syNorm[i] = idx < this->n ? yn[idx] : OutT(0); } __syncthreads(); @@ -285,7 +285,7 @@ struct PairwiseDistances : public BaseClass { auto colId = startx + j * P::AccThCols; if (rowId < this->m && colId < this->n) { // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], AccT(0)); } } } diff --git a/cpp/src/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/src/distance/detail/pairwise_distance_cutlass_base.cuh index d41c5d30c..d4d86d7f4 100644 --- a/cpp/src/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/src/distance/detail/pairwise_distance_cutlass_base.cuh @@ -58,8 +58,8 @@ template std::enable_if_t::value> cutlassDistanceKernel(const DataT* x, const DataT* y, - const DataT* xn, - const DataT* yn, + const OutT* xn, + const OutT* yn, IdxT m, IdxT n, IdxT k, @@ -77,12 +77,12 @@ std::enable_if_t::value> cutlassDistanceKernel(const Da auto dist_op = distance_op.get_cutlass_op(); using DistanceFn = decltype(dist_op); using EpilogueOutputOp = - epilogue::thread::PairwiseDistanceEpilogueElementwise; constexpr int batch_count = 1; @@ -143,13 +143,13 @@ std::enable_if_t::value> cutlassDistanceKernel(const Da epilog_op_param, a, b, - xn, // C matrix eq vector param, which here is A norm - nullptr, // tensor_Z, - (DataT*)yn + offsetN, // this is broadcast vec, which is required to be non-const param - dOutput + offsetN, // Output distance matrix - (int64_t)0, // batch stride A - (int64_t)0, // batch stride B - (int64_t)0, // batch stride Norm A + xn, // C matrix eq vector param, which here is A norm + nullptr, // tensor_Z, + (OutT*)yn + offsetN, // this is broadcast vec, which is required to be non-const param + dOutput + offsetN, // Output distance matrix + (int64_t)0, // batch stride A + (int64_t)0, // batch stride B + (int64_t)0, // batch stride Norm A (int64_t)0, (int64_t)0, // batch stride Norm B (int64_t)0, // batch stride Output diff --git a/cpp/src/distance/detail/pairwise_distance_gemm.h b/cpp/src/distance/detail/pairwise_distance_gemm.h index 6ac13f27b..a74613812 100644 --- a/cpp/src/distance/detail/pairwise_distance_gemm.h +++ b/cpp/src/distance/detail/pairwise_distance_gemm.h @@ -19,11 +19,15 @@ #include "./pairwise_distance_epilogue.h" #include +#include +#include #include #include #include #include +#include + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cuvs { @@ -235,8 +239,105 @@ struct PairwiseDistanceGemm; }; +template < + /// Layout type for A matrix operand + int kAlignmentA, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct PairwiseDistanceGemm { + // using Transform = cutlass::ComplexTransform::kNone; + // Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = + cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 64, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 32, N = 32, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAdd; + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename cutlass::gemm::device::GemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cuvs::epilogue::threadblock::PairwiseDistanceEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = cutlass::gemm::kernel:: + GemmWithFusedEpilogue; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace kernel } // namespace gemm -} // namespace cuvs \ No newline at end of file +} // namespace cuvs diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh index bc8189c70..3107f0fa4 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -38,8 +38,8 @@ void pairwise_matrix_dispatch(OpT distance_op, IdxT k, const DataT* x, const DataT* y, - const DataT* x_norm, - const DataT* y_norm, + const OutT* x_norm, + const OutT* y_norm, OutT* out, FinOpT fin_op, cudaStream_t stream, @@ -47,9 +47,9 @@ void pairwise_matrix_dispatch(OpT distance_op, }; // namespace cuvs::distance::detail -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY +#endif // CUVS_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ +#define instantiate_cuvs_distance_detail_pairwise_matrix_dispatch( \ OpT, DataT, AccT, OutT, FinOpT, IdxT) \ extern template void cuvs::distance::detail:: \ pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ @@ -59,136 +59,70 @@ void pairwise_matrix_dispatch(OpT distance_op, IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ bool is_row_major) +#define instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default(OpT, IdxT) \ + instantiate_cuvs_distance_detail_pairwise_matrix_dispatch( \ + OpT, float, float, float, raft::identity_op, IdxT); \ + instantiate_cuvs_distance_detail_pairwise_matrix_dispatch( \ + OpT, double, double, double, raft::identity_op, IdxT); \ + instantiate_cuvs_distance_detail_pairwise_matrix_dispatch( \ + OpT, half, float, float, raft::identity_op, IdxT); + +#define instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo(OpT, IdxT, FinOpT) \ + instantiate_cuvs_distance_detail_pairwise_matrix_dispatch( \ + OpT, float, float, float, FinOpT, IdxT); \ + instantiate_cuvs_distance_detail_pairwise_matrix_dispatch( \ + OpT, double, double, double, FinOpT, IdxT); \ + instantiate_cuvs_distance_detail_pairwise_matrix_dispatch( \ + OpT, half, float, float, FinOpT, IdxT); + /* * Hierarchy of instantiations: * * This file defines extern template instantiations of the distance kernels. The - * instantiation of the public API is handled in raft/distance/distance-ext.cuh. + * instantiation of the public API is handled in cuvs/distance/distance-ext.cuh. * * After adding an instance here, make sure to also add the instance there. */ // The following two instances are used in the RBF kernel object. Note the use of int64_t for the // index type. -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_unexp_distance_op, - float, - float, - float, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_unexp_distance_op, - double, - double, - double, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); - -// Rest of instances -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::canberra_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::correlation_distance_op, - float, - float, - float, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::correlation_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::hellinger_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::jensen_shannon_distance_op, - float, - float, - float, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::jensen_shannon_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::canberra_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::correlation_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::cosine_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::hamming_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::hellinger_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::jensen_shannon_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::kl_divergence_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::l1_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::l2_exp_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::l2_unexp_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::l_inf_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::lp_unexp_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::russel_rao_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo( cuvs::distance::detail::ops::l2_unexp_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::lp_unexp_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::russel_rao_distance_op, - double, - double, - double, - raft::identity_op, - int); + int64_t, + cuvs::distance::kernels::detail::rbf_fin_op); -#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch +#undef instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo +#undef instantiate_cuvs_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh index fc3849feb..96d7c265e 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -72,8 +72,8 @@ void pairwise_matrix_dispatch(OpT distance_op, IdxT k, const DataT* x, const DataT* y, - const DataT* x_norm, - const DataT* y_norm, + const OutT* x_norm, + const OutT* y_norm, OutT* out, FinOpT fin_op, cudaStream_t stream, @@ -113,7 +113,13 @@ void pairwise_matrix_dispatch(OpT distance_op, void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); - if (cutlass_range.contains(runtime_arch)) { + // TODO: the cutlass doesn't support the odd `k` on half DataT. + bool if_unsupported_on_half = (sizeof(DataT) == 2) && ((k % 2) != 0); + + if (if_unsupported_on_half) { + auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); + pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); + } else if (cutlass_range.contains(runtime_arch) && !if_unsupported_on_half) { // If device is SM_80 or later, use CUTLASS-based kernel. pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); } else { diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch.cuh index 06b039c3a..0521a5713 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch.cuh @@ -19,6 +19,4 @@ #include "dispatch-inl.cuh" #endif -#ifdef RAFT_COMPILED #include "dispatch-ext.cuh" -#endif diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py index e013db1e1..1bd51aef9 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -56,8 +56,8 @@ IdxT k, \\ const DataT* x, \\ const DataT* y, \\ - const DataT* x_norm, \\ - const DataT* y_norm, \\ + const OutT* x_norm, \\ + const OutT* y_norm, \\ OutT* out, \\ FinOpT fin_op, \\ cudaStream_t stream, \\ @@ -77,6 +77,12 @@ OutT="double", IdxT="int", ), + dict( + DataT="half", + AccT="float", + OutT="float", + IdxT="int", + ), ] op_instances = [ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu index f82df6cc0..c2c44dc53 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu index a20ca5f47..00099dcae 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_half_float_float_int.cu new file mode 100644 index 000000000..0b70f2341 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::canberra_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu index 7bb7e4a96..d9b6feb9c 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu index 34fcc4be4..dfb6f62a8 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_half_float_float_int.cu new file mode 100644 index 000000000..b2c959b55 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::correlation_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu index cb23743c1..d7046d4e2 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu @@ -38,8 +38,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu index ad71ff295..215805dde 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu @@ -38,8 +38,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_half_float_float_int.cu new file mode 100644 index 000000000..5f9928958 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_half_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include "dispatch_sm80.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::cosine_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu index e81d54411..d558604d8 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu index ddbdab602..632523194 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_half_float_float_int.cu new file mode 100644 index 000000000..76360ebd8 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hamming_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu index d2acecaf0..707d9c08d 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu index 034d76679..7dceab56c 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_half_float_float_int.cu new file mode 100644 index 000000000..35e7adf06 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hellinger_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu index 030faeecd..e3e074479 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu index f7551a566..6eff40550 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_half_float_float_int.cu new file mode 100644 index 000000000..24302c6e6 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_half_float_float_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::jensen_shannon_distance_op, + half, + float, + float, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu index 6640d3949..4f45adf27 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu index 60cafa474..b2cac754f 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_half_float_float_int.cu new file mode 100644 index 000000000..9347a026c --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::kl_divergence_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu index 8f6e3a35d..82d9d1abe 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu index 73868a486..ad5f06048 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_half_float_float_int.cu new file mode 100644 index 000000000..99043958f --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l1_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu index 8ac80b77d..b2911f16b 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu @@ -38,8 +38,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu index abebb9121..93a416643 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu @@ -38,8 +38,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_half_float_float_int.cu new file mode 100644 index 000000000..e30499ae7 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_half_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include "dispatch_sm80.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_exp_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu index ffa6bf02b..eecab9ec4 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu index acef42a4e..9f58f5f85 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_half_float_float_int.cu new file mode 100644 index 000000000..73e9352ee --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu index c2bbbf06b..812dda450 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu index 163b9f37b..f95dd7a87 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_half_float_float_int.cu new file mode 100644 index 000000000..ba9c976c4 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l_inf_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu index d13532ac6..a88875f3a 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu index 65e0163d7..b8b3775d1 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_half_float_float_int.cu new file mode 100644 index 000000000..ef323f57d --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::lp_unexp_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu index 23f2b34e8..1cb0ed8ae 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu @@ -38,8 +38,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ @@ -61,4 +61,12 @@ instantiate_raft_distance_detail_pairwise_matrix_dispatch( cuvs::distance::kernels::detail::rbf_fin_op, int64_t); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, + half, + float, + float, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); + #undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu index 1a5e5cf98..1afbd690e 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu index a9b1f6bb4..217f61b84 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu @@ -37,8 +37,8 @@ IdxT k, \ const DataT* x, \ const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ + const OutT* x_norm, \ + const OutT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_half_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_half_float_float_int.cu new file mode 100644 index 000000000..c65fb24f6 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_half_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::russel_rao_distance_op, half, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/params.cuh b/cpp/src/distance/detail/pairwise_matrix/params.cuh index aa419aca0..739c4a9f6 100644 --- a/cpp/src/distance/detail/pairwise_matrix/params.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/params.cuh @@ -27,8 +27,8 @@ struct pairwise_matrix_params { IdxT ld_out; const DataT* x; const DataT* y; - const DataT* x_norm; - const DataT* y_norm; + const OutT* x_norm; + const OutT* y_norm; OutT* out; FinOpT fin_op; bool is_row_major; diff --git a/cpp/src/distance/distance-ext.cuh b/cpp/src/distance/distance-ext.cuh index 148951afa..e7fa30f03 100644 --- a/cpp/src/distance/distance-ext.cuh +++ b/cpp/src/distance/distance-ext.cuh @@ -24,6 +24,8 @@ #include // rmm::device_uvector +#include + #ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY namespace cuvs { @@ -45,8 +47,8 @@ void distance(raft::resources const& handle, void* workspace, size_t worksize, FinalLambda fin_op, - bool isRowMajor = true, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; + bool isRowMajor = true, + OutT metric_arg = 2.0f) RAFT_EXPLICIT; template +template void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, - Type* dist, + DistT* dist, IdxT m, IdxT n, IdxT k, rmm::device_uvector& workspace, cuvs::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) RAFT_EXPLICIT; + bool isRowMajor = true, + DistT metric_arg = DistT(2.0f)) RAFT_EXPLICIT; -template +template void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, - Type* dist, + DistT* dist, IdxT m, IdxT n, IdxT k, cuvs::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) RAFT_EXPLICIT; + bool isRowMajor = true, + DistT metric_arg = DistT(2.0f)) RAFT_EXPLICIT; template void distance(raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, + raft::device_matrix_view const x, + raft::device_matrix_view const y, raft::device_matrix_view dist, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; + OutT metric_arg = 2.0f) RAFT_EXPLICIT; -template +template void pairwise_distance(raft::resources const& handle, - device_matrix_view const x, - device_matrix_view const y, - device_matrix_view dist, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, cuvs::distance::DistanceType metric, - Type metric_arg = 2.0f) RAFT_EXPLICIT; + DistT metric_arg = DistT(2.0f)) RAFT_EXPLICIT; }; // namespace distance }; // namespace cuvs -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY +#endif // CUVS_EXPLICIT_INSTANTIATE_ONLY /* * Hierarchy of instantiations: @@ -158,909 +163,220 @@ void pairwise_distance(raft::resources const& handle, * dispatch-ext.cuh and the corresponding .cu files. */ -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ - extern template void cuvs::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - FinalLambda fin_op, \ - bool isRowMajor, \ - DataT metric_arg) +#define instantiate_cuvs_distance_distance(DistT, DataT, AccT, OutT, IdxT) \ + extern template void \ + cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + raft::identity_op fin_op, \ + bool isRowMajor, \ + OutT metric_arg); \ + \ + extern template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + OutT metric_arg); \ + \ + extern template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + OutT metric_arg); \ + \ + extern template void \ + cuvs::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + OutT metric_arg); \ + \ + extern template void \ + cuvs::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + OutT metric_arg) + +#define instantiate_cuvs_distance_distance_by_algo(DistT) \ + instantiate_cuvs_distance_distance(DistT, float, float, float, int); \ + instantiate_cuvs_distance_distance(DistT, double, double, double, int); \ + instantiate_cuvs_distance_distance(DistT, half, float, float, int) + +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::Canberra); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::CorrelationExpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::CosineExpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::HammingUnexpanded); + +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::HellingerExpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::InnerProduct); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::JensenShannon); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::KLDivergence); + +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L1); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L2Expanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L2SqrtExpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L2SqrtUnexpanded); + +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L2Unexpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::Linf); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::LpUnexpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::RusselRaoExpanded); + +#undef instantiate_cuvs_distance_distance_by_algo +#undef instantiate_cuvs_distance_distance // The following two instances are used in test/distance/gram.cu. Note the use // of int64_t for the index type. -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - float, - float, - float, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); - -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); - -#undef instantiate_raft_distance_distance - -// Same, but without raft::identity_op -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ - extern template void cuvs::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - bool isRowMajor, \ +#define instantiate_cuvs_distance_distance_extra(DistT, DataT, AccT, OutT, FinalLambda, IdxT) \ + extern template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ DataT metric_arg) -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_distance - -// Same, but without workspace -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ - extern template void cuvs::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ - extern template size_t cuvs::distance::getWorkspaceSize( \ - const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) - -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_getWorkspaceSize - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ - extern template size_t cuvs::distance::getWorkspaceSize( \ - raft::device_matrix_view const& x, \ - raft::device_matrix_view const& y) - -// We could consider not taking template parameters for this function. The -// number of instantiations seems a bit excessive.. -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); - -#undef instantiate_raft_distance_getWorkspaceSize - -#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ +instantiate_cuvs_distance_distance_extra(cuvs::distance::DistanceType::L2Unexpanded, + float, + float, + float, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_cuvs_distance_distance_extra(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); + +#undef instantiate_cuvs_distance_distance_extra + +#define instantiate_cuvs_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t cuvs::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k); \ + \ + extern template size_t \ + cuvs::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y); \ + \ + extern template size_t \ + cuvs::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +#define instantiate_cuvs_distance_getWorkspaceSize_by_algo(DistT) \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, float, float, float, int); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, double, double, double, int); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, half, float, float, int); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, float, float, float, int64_t); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, double, double, double, int64_t); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, half, float, float, int64_t) + +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::Canberra); +instantiate_cuvs_distance_getWorkspaceSize_by_algo( + cuvs::distance::DistanceType::CorrelationExpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::CosineExpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::HammingUnexpanded); + +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::HellingerExpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::InnerProduct); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::JensenShannon); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::KLDivergence); + +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L1); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L2Expanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L2SqrtExpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L2SqrtUnexpanded); + +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L2Unexpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::Linf); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::LpUnexpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::RusselRaoExpanded); + +#undef instantiate_cuvs_distance_getWorkspaceSize_by_algo +#undef instantiate_cuvs_distance_getWorkspaceSize + +#define instantiate_cuvs_distance_pairwise_distance(DataT, IdxT, DistT) \ extern template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ const DataT* x, \ const DataT* y, \ - DataT* dist, \ + DistT* dist, \ IdxT m, \ IdxT n, \ IdxT k, \ rmm::device_uvector& workspace, \ cuvs::distance::DistanceType metric, \ bool isRowMajor, \ - DataT metric_arg) + DistT metric_arg) -instantiate_raft_distance_pairwise_distance(float, int); -instantiate_raft_distance_pairwise_distance(double, int); +instantiate_cuvs_distance_pairwise_distance(float, int, float); +instantiate_cuvs_distance_pairwise_distance(double, int, double); +instantiate_cuvs_distance_pairwise_distance(half, int, float); -#undef instantiate_raft_distance_pairwise_distance +#undef instantiate_cuvs_distance_pairwise_distance // Same, but without workspace -#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ +#define instantiate_cuvs_distance_pairwise_distance(DataT, IdxT, DistT) \ extern template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ const DataT* x, \ const DataT* y, \ - DataT* dist, \ + DistT* dist, \ IdxT m, \ IdxT n, \ IdxT k, \ cuvs::distance::DistanceType metric, \ bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, int); -instantiate_raft_distance_pairwise_distance(double, int); - -#undef instantiate_raft_distance_pairwise_distance - -// Version with mdspan -#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ - extern template void cuvs::distance::distance( \ - raft::resources const& handle, \ - raft::device_matrix_view const x, \ - raft::device_matrix_view const y, \ - raft::device_matrix_view dist, \ - DataT metric_arg) - -// Again, we might want to consider reigning in the number of instantiations... -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ - extern template void cuvs::distance::pairwise_distance( \ - raft::resources const& handle, \ - raft::device_matrix_view const x, \ - raft::device_matrix_view const y, \ - raft::device_matrix_view dist, \ - cuvs::distance::DistanceType metric, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); -instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); -instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); -instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); - -#undef instantiate_raft_distance_pairwise_distance + DistT metric_arg) + +instantiate_cuvs_distance_pairwise_distance(float, int, float); +instantiate_cuvs_distance_pairwise_distance(double, int, double); +instantiate_cuvs_distance_pairwise_distance(half, int, float); + +#undef instantiate_cuvs_distance_pairwise_distance + +#define instantiate_cuvs_distance_pairwise_distance(DataT, layout, IdxT, DistT) \ + extern template void cuvs::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + cuvs::distance::DistanceType metric, \ + DistT metric_arg) + +instantiate_cuvs_distance_pairwise_distance(float, raft::layout_c_contiguous, int, float); +instantiate_cuvs_distance_pairwise_distance(float, raft::layout_f_contiguous, int, float); +instantiate_cuvs_distance_pairwise_distance(double, raft::layout_c_contiguous, int, double); +instantiate_cuvs_distance_pairwise_distance(double, raft::layout_f_contiguous, int, double); +instantiate_cuvs_distance_pairwise_distance(half, raft::layout_c_contiguous, int, float); +instantiate_cuvs_distance_pairwise_distance(half, raft::layout_f_contiguous, int, float); + +#undef instantiate_cuvs_distance_pairwise_distance diff --git a/cpp/src/distance/distance-inl.cuh b/cpp/src/distance/distance-inl.cuh index 6236901c3..e047d3144 100644 --- a/cpp/src/distance/distance-inl.cuh +++ b/cpp/src/distance/distance-inl.cuh @@ -75,8 +75,8 @@ void distance(raft::resources const& handle, void* workspace, size_t worksize, FinalLambda fin_op, - bool isRowMajor = true, - DataT metric_arg = 2.0f) + bool isRowMajor = true, + OutT metric_arg = 2.0f) { detail::distance( handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); @@ -115,8 +115,8 @@ void distance(raft::resources const& handle, IdxT k, void* workspace, size_t worksize, - bool isRowMajor = true, - DataT metric_arg = 2.0f) + bool isRowMajor = true, + OutT metric_arg = 2.0f) { detail::distance( handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); @@ -206,8 +206,8 @@ void distance(raft::resources const& handle, IdxT m, IdxT n, IdxT k, - bool isRowMajor = true, - DataT metric_arg = 2.0f) + bool isRowMajor = true, + OutT metric_arg = 2.0f) { auto stream = raft::resource::get_cuda_stream(handle); rmm::device_uvector workspace(0, stream); @@ -222,6 +222,7 @@ void distance(raft::resources const& handle, * into compile time for the purpose of dispatch * @tparam Type input/accumulation/output data-type * @tparam IdxT indexing type + * @tparam DistT output type, equal to Type by default * @param handle raft handle for managing expensive resources * @param x first set of points * @param y second set of points @@ -235,25 +236,25 @@ void distance(raft::resources const& handle, * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) */ -template +template void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, - Type* dist, + DistT* dist, IdxT m, IdxT n, IdxT k, rmm::device_uvector& workspace, cuvs::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) + bool isRowMajor = true, + DistT metric_arg = 2.0f) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto dispatch = [&](auto distance_type) { - auto worksize = getWorkspaceSize(x, y, m, n, k); + auto worksize = getWorkspaceSize(x, y, m, n, k); workspace.resize(worksize, stream); - detail::distance( + detail::distance( handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); }; @@ -315,6 +316,7 @@ void pairwise_distance(raft::resources const& handle, * into compile time for the purpose of dispatch * @tparam Type input/accumulation/output data-type * @tparam IdxT indexing type + * @tparam DistT output type, equal to Type by default * @param handle raft handle for managing expensive resources * @param x first set of points * @param y second set of points @@ -326,21 +328,21 @@ void pairwise_distance(raft::resources const& handle, * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) */ -template +template void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, - Type* dist, + DistT* dist, IdxT m, IdxT n, IdxT k, cuvs::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) + bool isRowMajor = true, + DistT metric_arg = 2.0f) { auto stream = raft::resource::get_cuda_stream(handle); rmm::device_uvector workspace(0, stream); - pairwise_distance( + pairwise_distance( handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); } @@ -397,7 +399,7 @@ void distance(raft::resources const& handle, raft::device_matrix_view const x, raft::device_matrix_view const y, raft::device_matrix_view dist, - DataT metric_arg = 2.0f) + OutT metric_arg = 2.0f) { RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); RAFT_EXPECTS(dist.extent(0) == x.extent(0), @@ -428,6 +430,7 @@ void distance(raft::resources const& handle, * into compile time for the purpose of dispatch * @tparam Type input/accumulation/output data-type * @tparam IdxT indexing type + * @tparam DistT output type, equal to Type by default * @param handle raft handle for managing expensive resources * @param x first matrix of points (size mxk) * @param y second matrix of points (size nxk) @@ -435,13 +438,16 @@ void distance(raft::resources const& handle, * @param metric distance metric * @param metric_arg metric argument (used for Minkowski distance) */ -template +template void pairwise_distance(raft::resources const& handle, raft::device_matrix_view const x, raft::device_matrix_view const y, - raft::device_matrix_view dist, + raft::device_matrix_view dist, cuvs::distance::DistanceType metric, - Type metric_arg = 2.0f) + DistT metric_arg = DistT(2.0f)) { RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); RAFT_EXPECTS(dist.extent(0) == x.extent(0), diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu index 02c071d13..72be93f10 100644 --- a/cpp/src/distance/distance.cu +++ b/cpp/src/distance/distance.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,910 +25,219 @@ * kernels is handled in distance/detail/pairwise_matrix/dispatch_*.cu. * */ - -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ - template void cuvs::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - FinalLambda fin_op, \ - bool isRowMajor, \ - DataT metric_arg) +#define instantiate_cuvs_distance_distance(DistT, DataT, AccT, OutT, IdxT) \ + template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + raft::identity_op fin_op, \ + bool isRowMajor, \ + OutT metric_arg); \ + \ + template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + OutT metric_arg); \ + \ + template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + OutT metric_arg); \ + \ + template void \ + cuvs::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + OutT metric_arg); \ + \ + template void \ + cuvs::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + OutT metric_arg) + +#define instantiate_cuvs_distance_distance_by_algo(DistT) \ + instantiate_cuvs_distance_distance(DistT, float, float, float, int); \ + instantiate_cuvs_distance_distance(DistT, double, double, double, int); \ + instantiate_cuvs_distance_distance(DistT, half, float, float, int) + +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::Canberra); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::CorrelationExpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::CosineExpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::HammingUnexpanded); + +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::HellingerExpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::InnerProduct); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::JensenShannon); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::KLDivergence); + +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L1); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L2Expanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L2SqrtExpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L2SqrtUnexpanded); + +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::L2Unexpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::Linf); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::LpUnexpanded); +instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::RusselRaoExpanded); + +#undef instantiate_cuvs_distance_distance_by_algo +#undef instantiate_cuvs_distance_distance // The following two instances are used in test/distance/gram.cu. Note the use // of int64_t for the index type. -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - float, - float, - float, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); - -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); - -#undef instantiate_raft_distance_distance - -// Same, but without raft::identity_op -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ - template void cuvs::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_distance - -// Same, but without workspace -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ - template void cuvs::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - bool isRowMajor, \ +#define instantiate_cuvs_distance_distance_extra(DistT, DataT, AccT, OutT, FinalLambda, IdxT) \ + template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ DataT metric_arg) -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ - template size_t cuvs::distance::getWorkspaceSize( \ - const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) - -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_getWorkspaceSize - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ - template size_t cuvs::distance::getWorkspaceSize( \ - raft::device_matrix_view const& x, \ - raft::device_matrix_view const& y) - -// We could consider not taking template parameters for this function. The -// number of instantiations seems a bit excessive.. -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); - -#undef instantiate_raft_distance_getWorkspaceSize - -#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ +instantiate_cuvs_distance_distance_extra(cuvs::distance::DistanceType::L2Unexpanded, + float, + float, + float, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_cuvs_distance_distance_extra(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); + +#undef instantiate_cuvs_distance_distance_extra + +#define instantiate_cuvs_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + template size_t cuvs::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k); \ + \ + template size_t \ + cuvs::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y); \ + \ + template size_t \ + cuvs::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +#define instantiate_cuvs_distance_getWorkspaceSize_by_algo(DistT) \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, float, float, float, int); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, double, double, double, int); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, half, float, float, int); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, float, float, float, int64_t); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, double, double, double, int64_t); \ + instantiate_cuvs_distance_getWorkspaceSize(DistT, half, float, float, int64_t) + +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::Canberra); +instantiate_cuvs_distance_getWorkspaceSize_by_algo( + cuvs::distance::DistanceType::CorrelationExpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::CosineExpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::HammingUnexpanded); + +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::HellingerExpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::InnerProduct); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::JensenShannon); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::KLDivergence); + +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L1); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L2Expanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L2SqrtExpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L2SqrtUnexpanded); + +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::L2Unexpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::Linf); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::LpUnexpanded); +instantiate_cuvs_distance_getWorkspaceSize_by_algo(cuvs::distance::DistanceType::RusselRaoExpanded); + +#undef instantiate_cuvs_distance_getWorkspaceSize_by_algo +#undef instantiate_cuvs_distance_getWorkspaceSize + +#define instantiate_cuvs_distance_pairwise_distance(DataT, IdxT, DistT) \ template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ const DataT* x, \ const DataT* y, \ - DataT* dist, \ + DistT* dist, \ IdxT m, \ IdxT n, \ IdxT k, \ rmm::device_uvector& workspace, \ cuvs::distance::DistanceType metric, \ bool isRowMajor, \ - DataT metric_arg) + DistT metric_arg) -instantiate_raft_distance_pairwise_distance(float, int); -instantiate_raft_distance_pairwise_distance(double, int); +instantiate_cuvs_distance_pairwise_distance(float, int, float); +instantiate_cuvs_distance_pairwise_distance(double, int, double); +instantiate_cuvs_distance_pairwise_distance(half, int, float); -#undef instantiate_raft_distance_pairwise_distance +#undef instantiate_cuvs_distance_pairwise_distance // Same, but without workspace -#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ +#define instantiate_cuvs_distance_pairwise_distance(DataT, IdxT, DistT) \ template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ const DataT* x, \ const DataT* y, \ - DataT* dist, \ + DistT* dist, \ IdxT m, \ IdxT n, \ IdxT k, \ cuvs::distance::DistanceType metric, \ bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, int); -instantiate_raft_distance_pairwise_distance(double, int); - -#undef instantiate_raft_distance_pairwise_distance - -// Version with mdspan -#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ - template void cuvs::distance::distance( \ - raft::resources const& handle, \ - raft::device_matrix_view const x, \ - raft::device_matrix_view const y, \ - raft::device_matrix_view dist, \ - DataT metric_arg) - -// Again, we might want to consider reigning in the number of instantiations... -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ - template void cuvs::distance::pairwise_distance( \ - raft::resources const& handle, \ - raft::device_matrix_view const x, \ - raft::device_matrix_view const y, \ - raft::device_matrix_view dist, \ - cuvs::distance::DistanceType metric, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); -instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); -instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); -instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); - -#undef instantiate_raft_distance_pairwise_distance + DistT metric_arg) + +instantiate_cuvs_distance_pairwise_distance(float, int, float); +instantiate_cuvs_distance_pairwise_distance(double, int, double); +instantiate_cuvs_distance_pairwise_distance(half, int, float); + +#undef instantiate_cuvs_distance_pairwise_distance + +#define instantiate_cuvs_distance_pairwise_distance(DataT, layout, IdxT, DistT) \ + template void cuvs::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + cuvs::distance::DistanceType metric, \ + DistT metric_arg) + +instantiate_cuvs_distance_pairwise_distance(float, raft::layout_c_contiguous, int, float); +instantiate_cuvs_distance_pairwise_distance(float, raft::layout_f_contiguous, int, float); +instantiate_cuvs_distance_pairwise_distance(double, raft::layout_c_contiguous, int, double); +instantiate_cuvs_distance_pairwise_distance(double, raft::layout_f_contiguous, int, double); +instantiate_cuvs_distance_pairwise_distance(half, raft::layout_c_contiguous, int, float); +instantiate_cuvs_distance_pairwise_distance(half, raft::layout_f_contiguous, int, float); + +#undef instantiate_cuvs_distance_pairwise_distance diff --git a/cpp/src/distance/distance.cuh b/cpp/src/distance/distance.cuh index b5bfc07cb..d1bfc8212 100644 --- a/cpp/src/distance/distance.cuh +++ b/cpp/src/distance/distance.cuh @@ -19,6 +19,4 @@ #include "distance-inl.cuh" #endif -#ifdef RAFT_COMPILED #include "distance-ext.cuh" -#endif diff --git a/cpp/src/distance/pairwise_distance.cu b/cpp/src/distance/pairwise_distance.cu index 10f096b96..a802ce91c 100644 --- a/cpp/src/distance/pairwise_distance.cu +++ b/cpp/src/distance/pairwise_distance.cu @@ -61,6 +61,24 @@ void pairwise_distance( handle, x_v, y_v, d_v, metric, metric_arg); } +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg) +{ + auto x_v = raft::make_device_matrix_view( + x.data_handle(), x.extent(0), x.extent(1)); + auto y_v = raft::make_device_matrix_view( + y.data_handle(), y.extent(0), y.extent(1)); + auto d_v = raft::make_device_matrix_view( + dist.data_handle(), dist.extent(0), dist.extent(1)); + pairwise_distance( + handle, x_v, y_v, d_v, metric, metric_arg); +} + void pairwise_distance( raft::resources const& handle, raft::device_matrix_view const x, @@ -97,6 +115,24 @@ void pairwise_distance( handle, x_v, y_v, d_v, metric, metric_arg); } +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg) +{ + auto x_v = raft::make_device_matrix_view( + x.data_handle(), x.extent(0), x.extent(1)); + auto y_v = raft::make_device_matrix_view( + y.data_handle(), y.extent(0), y.extent(1)); + auto d_v = raft::make_device_matrix_view( + dist.data_handle(), dist.extent(0), dist.extent(1)); + pairwise_distance( + handle, x_v, y_v, d_v, metric, metric_arg); +} + /** @} */ // end group pairwise_distance_runtime } // namespace cuvs::distance diff --git a/cpp/src/neighbors/brute_force.cu b/cpp/src/neighbors/brute_force.cu index ce21e2d39..c76feb015 100644 --- a/cpp/src/neighbors/brute_force.cu +++ b/cpp/src/neighbors/brute_force.cu @@ -21,12 +21,12 @@ #include namespace cuvs::neighbors::brute_force { -template -index::index(raft::resources const& res, - raft::host_matrix_view dataset, - std::optional>&& norms, - cuvs::distance::DistanceType metric, - T metric_arg) +template +index::index(raft::resources const& res, + raft::host_matrix_view dataset, + std::optional>&& norms, + cuvs::distance::DistanceType metric, + DistT metric_arg) : cuvs::neighbors::index(), metric_(metric), dataset_(raft::make_device_matrix(res, 0, 0)), @@ -38,12 +38,12 @@ index::index(raft::resources const& res, raft::resource::sync_stream(res); } -template -index::index(raft::resources const& res, - raft::device_matrix_view dataset, - std::optional>&& norms, - cuvs::distance::DistanceType metric, - T metric_arg) +template +index::index(raft::resources const& res, + raft::device_matrix_view dataset, + std::optional>&& norms, + cuvs::distance::DistanceType metric, + DistT metric_arg) : cuvs::neighbors::index(), metric_(metric), dataset_(raft::make_device_matrix(res, 0, 0)), @@ -54,12 +54,12 @@ index::index(raft::resources const& res, update_dataset(res, dataset); } -template -index::index(raft::resources const& res, - raft::device_matrix_view dataset_view, - std::optional> norms_view, - cuvs::distance::DistanceType metric, - T metric_arg) +template +index::index(raft::resources const& res, + raft::device_matrix_view dataset_view, + std::optional> norms_view, + cuvs::distance::DistanceType metric, + DistT metric_arg) : cuvs::neighbors::index(), metric_(metric), dataset_(raft::make_device_matrix(res, 0, 0)), @@ -69,12 +69,12 @@ index::index(raft::resources const& res, { } -template -index::index(raft::resources const& res, - raft::device_matrix_view dataset_view, - std::optional>&& norms, - cuvs::distance::DistanceType metric, - T metric_arg) +template +index::index(raft::resources const& res, + raft::device_matrix_view dataset_view, + std::optional>&& norms, + cuvs::distance::DistanceType metric, + DistT metric_arg) : cuvs::neighbors::index(), metric_(metric), dataset_( @@ -99,12 +99,12 @@ index::index(raft::resources const& res, dataset_view_ = raft::make_const_mdspan(dataset_.view()); } -template -index::index(raft::resources const& res, - raft::device_matrix_view dataset_view, - std::optional> norms_view, - cuvs::distance::DistanceType metric, - T metric_arg) +template +index::index(raft::resources const& res, + raft::device_matrix_view dataset_view, + std::optional> norms_view, + cuvs::distance::DistanceType metric, + DistT metric_arg) : cuvs::neighbors::index(), metric_(metric), dataset_( @@ -129,73 +129,74 @@ index::index(raft::resources const& res, dataset_view_ = raft::make_const_mdspan(dataset_.view()); } -template -void index::update_dataset(raft::resources const& res, - raft::device_matrix_view dataset) +template +void index::update_dataset( + raft::resources const& res, raft::device_matrix_view dataset) { dataset_view_ = dataset; } -template -void index::update_dataset(raft::resources const& res, - raft::host_matrix_view dataset) +template +void index::update_dataset( + raft::resources const& res, raft::host_matrix_view dataset) { dataset_ = raft::make_device_matrix(res, dataset.extent(0), dataset.extent(1)); raft::copy(res, dataset_.view(), dataset); dataset_view_ = raft::make_const_mdspan(dataset_.view()); } -#define CUVS_INST_BFKNN(T) \ +#define CUVS_INST_BFKNN(T, DistT) \ auto build(raft::resources const& res, \ raft::device_matrix_view dataset, \ cuvs::distance::DistanceType metric, \ - T metric_arg) \ - ->cuvs::neighbors::brute_force::index \ + DistT metric_arg) \ + ->cuvs::neighbors::brute_force::index \ { \ - return detail::build(res, dataset, metric, metric_arg); \ + return detail::build(res, dataset, metric, metric_arg); \ } \ auto build(raft::resources const& res, \ raft::device_matrix_view dataset, \ cuvs::distance::DistanceType metric, \ - T metric_arg) \ - ->cuvs::neighbors::brute_force::index \ + DistT metric_arg) \ + ->cuvs::neighbors::brute_force::index \ { \ - return detail::build(res, dataset, metric, metric_arg); \ + return detail::build(res, dataset, metric, metric_arg); \ } \ \ void search( \ raft::resources const& res, \ - const cuvs::neighbors::brute_force::index& idx, \ + const cuvs::neighbors::brute_force::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ + raft::device_matrix_view distances, \ std::optional> sample_filter = std::nullopt) \ { \ if (!sample_filter.has_value()) { \ - detail::brute_force_search(res, idx, queries, neighbors, distances); \ + detail::brute_force_search(res, idx, queries, neighbors, distances); \ } else { \ - detail::brute_force_search_filtered( \ + detail::brute_force_search_filtered( \ res, idx, queries, *sample_filter, neighbors, distances); \ } \ } \ void search( \ raft::resources const& res, \ - const cuvs::neighbors::brute_force::index& idx, \ + const cuvs::neighbors::brute_force::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ + raft::device_matrix_view distances, \ std::optional> sample_filter = std::nullopt) \ { \ if (!sample_filter.has_value()) { \ - detail::brute_force_search(res, idx, queries, neighbors, distances); \ + detail::brute_force_search(res, idx, queries, neighbors, distances); \ } else { \ RAFT_FAIL("filtered search isn't available with col_major queries yet"); \ } \ } \ \ - template struct cuvs::neighbors::brute_force::index; + template struct cuvs::neighbors::brute_force::index; -CUVS_INST_BFKNN(float); +CUVS_INST_BFKNN(float, float); +CUVS_INST_BFKNN(half, float); #undef CUVS_INST_BFKNN diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh index b00d6617c..e907568f5 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh @@ -24,7 +24,7 @@ namespace cuvs::neighbors::cagra::detail { namespace multi_cta_search { -#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY +#ifdef _CUVS_EXPLICIT_INSTANTIATE_ONLY template __launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL fusedL2kNN(const DataT* x, const DataT* y, - const DataT* _xn, - const DataT* _yn, + const OutT* _xn, + const OutT* _yn, const IdxT m, const IdxT n, const IdxT k, @@ -342,8 +342,8 @@ __launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL fusedL2kNN(const DataT* x, auto epilog_lambda = [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, + OutT * regxn, + OutT * regyn, IdxT gridStrideX, IdxT gridStrideY) { // Use ::template to disambiguate (See: @@ -536,8 +536,8 @@ void fusedL2UnexpKnnImpl(const DataT* x, void* workspace, size_t& worksize) { - typedef typename raft::linalg::Policy2x8::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; typedef typename std::conditional::type KPolicy; @@ -705,8 +705,8 @@ template void fusedL2ExpKnnImpl(const DataT* x, const DataT* y, - const DataT* xn, - const DataT* yn, + const AccT* xn, + const AccT* yn, IdxT m, IdxT n, IdxT k, @@ -721,8 +721,8 @@ void fusedL2ExpKnnImpl(const DataT* x, void* workspace, size_t& worksize) { - typedef typename raft::linalg::Policy2x8::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; typedef typename std::conditional::type KPolicy; @@ -777,7 +777,7 @@ void fusedL2ExpKnnImpl(const DataT* x, int32_t* mutexes = nullptr; if (grid.x > 1) { const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); - const auto normsSize = (x != y) ? (m + n) * sizeof(DataT) : n * sizeof(DataT); + const auto normsSize = (x != y) ? (m + n) * sizeof(AccT) : n * sizeof(AccT); const auto requiredSize = sizeof(int32_t) * numMutexes + normsSize; if (worksize < requiredSize) { worksize = requiredSize; @@ -790,8 +790,8 @@ void fusedL2ExpKnnImpl(const DataT* x, // calculate norms if they haven't been passed in if (!xn) { - DataT* xn_ = (DataT*)workspace; - workspace = xn_ + m; + AccT* xn_ = (AccT*)workspace; + workspace = xn_ + m; raft::linalg::rowNorm( xn_, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); xn = xn_; @@ -800,7 +800,7 @@ void fusedL2ExpKnnImpl(const DataT* x, if (x == y) { yn = xn; } else { - DataT* yn_ = (DataT*)(workspace); + AccT* yn_ = (AccT*)(workspace); raft::linalg::rowNorm( yn_, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); yn = yn_; @@ -843,8 +843,8 @@ void fusedL2ExpKnn(IdxT m, IdxT ldd, const DataT* x, const DataT* y, - const DataT* xn, - const DataT* yn, + const AccT* xn, + const AccT* yn, bool sqrt, OutT* out_dists, IdxT* out_inds, @@ -930,10 +930,13 @@ void fusedL2ExpKnn(IdxT m, * @param[in] rowMajorQuery are the query array in row-major layout? * @param[in] stream stream to order kernel launch */ -template +template void fusedL2Knn(size_t D, value_idx* out_inds, - value_t* out_dists, + distance_t* out_dists, const value_t* index, const value_t* query, size_t n_index_rows, @@ -943,8 +946,8 @@ void fusedL2Knn(size_t D, bool rowMajorQuery, cudaStream_t stream, cuvs::distance::DistanceType metric, - const value_t* index_norms = NULL, - const value_t* query_norms = NULL) + const distance_t* index_norms = NULL, + const distance_t* query_norms = NULL) { // Validate the input data ASSERT(k > 0, "l2Knn: k must be > 0"); @@ -975,83 +978,87 @@ void fusedL2Knn(size_t D, tempWorksize = cuvs::distance::getWorkspaceSize(query, index, n_query_rows, n_index_rows, D); worksize = tempWorksize; workspace.resize(worksize, stream); - fusedL2ExpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - query_norms, - index_norms, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); + fusedL2ExpKnn( + n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + query_norms, + index_norms, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); if (worksize > tempWorksize) { workspace.resize(worksize, stream); - fusedL2ExpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - query_norms, - index_norms, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); + fusedL2ExpKnn( + n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + query_norms, + index_norms, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); } break; case cuvs::distance::DistanceType::L2Unexpanded: case cuvs::distance::DistanceType::L2SqrtUnexpanded: - fusedL2UnexpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); + fusedL2UnexpKnn( + n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); if (worksize) { workspace.resize(worksize, stream); - fusedL2UnexpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); + fusedL2UnexpKnn( + n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); } break; default: printf("only L2 distance metric is supported\n"); break; diff --git a/cpp/src/neighbors/detail/haversine_distance.cuh b/cpp/src/neighbors/detail/haversine_distance.cuh index fc6aa477d..ee972cf35 100644 --- a/cpp/src/neighbors/detail/haversine_distance.cuh +++ b/cpp/src/neighbors/detail/haversine_distance.cuh @@ -22,15 +22,30 @@ #include #include +#include + namespace cuvs::neighbors::detail { -template -DI value_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) +template +DI distance_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) { - value_t sin_0 = raft::sin(0.5 * (x1 - y1)); - value_t sin_1 = raft::sin(0.5 * (x2 - y2)); - value_t rdist = sin_0 * sin_0 + raft::cos(x1) * raft::cos(y1) * sin_1 * sin_1; - - return 2 * raft::asin(raft::sqrt(rdist)); + if constexpr ((std::is_same_v && std::is_same_v)) { + distance_t _x1 = __half2float(x1); + distance_t _y1 = __half2float(y1); + distance_t _x2 = __half2float(x2); + distance_t _y2 = __half2float(y2); + + distance_t sin_0 = raft::sin(distance_t(0.5) * (_x1 - _y1)); + distance_t sin_1 = raft::sin(distance_t(0.5) * (_x2 - _y2)); + distance_t rdist = sin_0 * sin_0 + raft::cos(_x1) * raft::cos(_y1) * sin_1 * sin_1; + + return static_cast(2) * raft::asin(raft::sqrt(rdist)); + } else { + distance_t sin_0 = raft::sin(distance_t(0.5) * (x1 - y1)); + distance_t sin_1 = raft::sin(distance_t(0.5) * (x2 - y2)); + distance_t rdist = sin_0 * sin_0 + raft::cos(x1) * raft::cos(y1) * sin_1 * sin_1; + + return static_cast(2) * raft::asin(raft::sqrt(rdist)); + } } /** @@ -46,9 +61,14 @@ DI value_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) * @param[in] n_index_rows number of rows in index array * @param[in] k number of closest neighbors to return */ -template +template RAFT_KERNEL haversine_knn_kernel(value_idx* out_inds, - value_t* out_dists, + distance_t* out_dists, const value_t* index, const value_t* query, size_t n_index_rows, @@ -56,12 +76,12 @@ RAFT_KERNEL haversine_knn_kernel(value_idx* out_inds, { constexpr int kNumWarps = tpb / raft::WarpSize; - __shared__ value_t smemK[kNumWarps * warp_q]; + __shared__ distance_t smemK[kNumWarps * warp_q]; __shared__ value_idx smemV[kNumWarps * warp_q]; using namespace raft::neighbors::detail::faiss_select; - BlockSelect, warp_q, thread_q, tpb> heap( - std::numeric_limits::max(), std::numeric_limits::max(), smemK, smemV, k); + BlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), std::numeric_limits::max(), smemK, smemV, k); // Grid is exactly sized to rows available int limit = raft::Pow2::roundDown(n_index_rows); @@ -77,7 +97,7 @@ RAFT_KERNEL haversine_knn_kernel(value_idx* out_inds, value_t y1 = idx_ptr[0]; value_t y2 = idx_ptr[1]; - value_t dist = compute_haversine(x1, y1, x2, y2); + distance_t dist = compute_haversine(x1, y1, x2, y2); heap.add(dist, i); } @@ -88,7 +108,7 @@ RAFT_KERNEL haversine_knn_kernel(value_idx* out_inds, value_t y1 = idx_ptr[0]; value_t y2 = idx_ptr[1]; - value_t dist = compute_haversine(x1, y1, x2, y2); + distance_t dist = compute_haversine(x1, y1, x2, y2); heap.addThreadQ(dist, i); } @@ -117,9 +137,9 @@ RAFT_KERNEL haversine_knn_kernel(value_idx* out_inds, * @param[in] k number of closest neighbors to return * @param[in] stream stream to order kernel launch */ -template +template void haversine_knn(value_idx* out_inds, - value_t* out_dists, + distance_t* out_dists, const value_t* index, const value_t* query, size_t n_index_rows, diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index 559d33cc2..e3f7acc96 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -48,9 +48,9 @@ #include #include +#include #include #include - #include #include @@ -62,7 +62,7 @@ namespace cuvs::neighbors::detail { * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances */ -template +template void tiled_brute_force_knn(const raft::resources& handle, const ElementType* search, // size (m ,d) const ElementType* index, // size (n ,d) @@ -70,15 +70,15 @@ void tiled_brute_force_knn(const raft::resources& handle, size_t n, size_t d, size_t k, - ElementType* distances, // size (m, k) - IndexType* indices, // size (m, k) + DistanceT* distances, // size (m, k) + IndexType* indices, // size (m, k) cuvs::distance::DistanceType metric, - float metric_arg = 2.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0, - const ElementType* precomputed_index_norms = nullptr, - const ElementType* precomputed_search_norms = nullptr, - const uint32_t* filter_bitmap = nullptr) + DistanceT metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + const DistanceT* precomputed_index_norms = nullptr, + const DistanceT* precomputed_search_norms = nullptr, + const uint32_t* filter_bitmap = nullptr) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -88,7 +88,7 @@ void tiled_brute_force_knn(const raft::resources& handle, auto total_mem = rmm::available_device_memory().second; cuvs::neighbors::detail::faiss_select::chooseTileSize( - m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); + m, n, d, sizeof(DistanceT), total_mem, tile_rows, tile_cols); // for unittesting, its convenient to be able to put a max size on the tiles // so we can test the tiling logic without having to use huge inputs. @@ -99,13 +99,13 @@ void tiled_brute_force_knn(const raft::resources& handle, tile_cols = std::max(tile_cols, k); // stores pairwise distances for the current tile - rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); + rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); // calculate norms for L2 expanded distances - this lets us avoid calculating // norms repeatedly per-tile, and just do once for the entire input auto pairwise_metric = metric; - rmm::device_uvector search_norms(0, stream); - rmm::device_uvector index_norms(0, stream); + rmm::device_uvector search_norms(0, stream); + rmm::device_uvector index_norms(0, stream); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded) { @@ -162,14 +162,14 @@ void tiled_brute_force_knn(const raft::resources& handle, if (n < k) { raft::matrix::fill(handle, raft::make_device_matrix_view(distances, m, k), - std::numeric_limits::lowest()); + std::numeric_limits::lowest()); if constexpr (std::is_signed_v) { raft::matrix::fill(handle, raft::make_device_matrix_view(indices, m, k), IndexType{-1}); } } - rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); + rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); rmm::device_uvector temp_out_indices(tile_rows * temp_out_cols, stream); bool select_min = cuvs::distance::is_min_close(metric); @@ -189,7 +189,7 @@ void tiled_brute_force_knn(const raft::resources& handle, search + i * d, current_query_size, d), raft::make_device_matrix_view( index + j * d, current_centroid_size, d), - raft::make_device_matrix_view( + raft::make_device_matrix_view( temp_distances.data(), current_query_size, current_centroid_size), pairwise_metric, metric_arg); @@ -208,7 +208,7 @@ void tiled_brute_force_knn(const raft::resources& handle, IndexType row = i + (idx / current_centroid_size); IndexType col = j + (idx % current_centroid_size); - cuvs::distance::detail::ops::l2_exp_cutlass_op l2_op(sqrt); + cuvs::distance::detail::ops::l2_exp_cutlass_op l2_op(sqrt); return l2_op(row_norms[row], col_norms[col], dist[idx]); }); } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { @@ -222,16 +222,16 @@ void tiled_brute_force_knn(const raft::resources& handle, [=] __device__(IndexType idx) { IndexType row = i + (idx / current_centroid_size); IndexType col = j + (idx % current_centroid_size); - auto val = 1.0 - dist[idx] / (row_norms[row] * col_norms[col]); + auto val = DistanceT(1.0) - dist[idx] / DistanceT(row_norms[row] * col_norms[col]); return val; }); } if (filter_bitmap != nullptr) { - auto distances_ptr = temp_distances.data(); - auto count = thrust::make_counting_iterator(0); - ElementType masked_distance = select_min ? std::numeric_limits::infinity() - : std::numeric_limits::lowest(); + auto distances_ptr = temp_distances.data(); + auto count = thrust::make_counting_iterator(0); + DistanceT masked_distance = select_min ? std::numeric_limits::infinity() + : std::numeric_limits::lowest(); thrust::for_each(raft::resource::get_thrust_policy(handle), count, count + current_query_size * current_centroid_size, @@ -250,10 +250,10 @@ void tiled_brute_force_knn(const raft::resources& handle, cuvs::selection::select_k( handle, - raft::make_device_matrix_view( + raft::make_device_matrix_view( temp_distances.data(), current_query_size, current_centroid_size), std::nullopt, - raft::make_device_matrix_view( + raft::make_device_matrix_view( distances + i * k, current_query_size, current_k), raft::make_device_matrix_view( indices + i * k, current_query_size, current_k), @@ -269,10 +269,10 @@ void tiled_brute_force_knn(const raft::resources& handle, // concatenation. // Fix both of these problems in a single pass here if (tile_cols != n) { - const ElementType* in_distances = distances + i * k; - const IndexType* in_indices = indices + i * k; - ElementType* out_distances = temp_out_distances.data(); - IndexType* out_indices = temp_out_indices.data(); + const DistanceT* in_distances = distances + i * k; + const IndexType* in_indices = indices + i * k; + DistanceT* out_distances = temp_out_distances.data(); + IndexType* out_indices = temp_out_indices.data(); auto count = thrust::make_counting_iterator(0); thrust::for_each(raft::resource::get_thrust_policy(handle), @@ -292,11 +292,11 @@ void tiled_brute_force_knn(const raft::resources& handle, // select the actual top-k items here from the temporary output cuvs::selection::select_k( handle, - raft::make_device_matrix_view( + raft::make_device_matrix_view( temp_out_distances.data(), current_query_size, temp_out_cols), raft::make_device_matrix_view( temp_out_indices.data(), current_query_size, temp_out_cols), - raft::make_device_matrix_view( + raft::make_device_matrix_view( distances + i * k, current_query_size, k), raft::make_device_matrix_view( indices + i * k, current_query_size, k), @@ -330,7 +330,10 @@ void tiled_brute_force_knn(const raft::resources& handle, * @param[in] metric corresponds to the cuvs::distance::DistanceType enum (default is L2Expanded) * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm */ -template +template void brute_force_knn_impl( raft::resources const& handle, std::vector& input, @@ -339,15 +342,15 @@ void brute_force_knn_impl( value_t* search_items, IntType n, IdxType* res_I, - value_t* res_D, + DistType* res_D, IntType k, bool rowMajorIndex = true, bool rowMajorQuery = true, std::vector* translations = nullptr, cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded, - float metricArg = 0, - std::vector* input_norms = nullptr, - const value_t* search_norms = nullptr) + DistType metricArg = 0, + std::vector* input_norms = nullptr, + const DistType* search_norms = nullptr) { auto userStream = raft::resource::get_cuda_stream(handle); @@ -375,11 +378,11 @@ void brute_force_knn_impl( rmm::device_uvector trans(id_ranges->size(), userStream); raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream); - rmm::device_uvector all_D(0, userStream); + rmm::device_uvector all_D(0, userStream); rmm::device_uvector all_I(0, userStream); - value_t* out_D = res_D; - IdxType* out_I = res_I; + DistType* out_D = res_D; + IdxType* out_I = res_I; if (input.size() > 1) { all_D.resize(input.size() * k * n, userStream); @@ -419,8 +422,8 @@ void brute_force_knn_impl( size_t total_rows_processed = 0; for (size_t i = 0; i < input.size(); i++) { - value_t* out_d_ptr = out_D + (i * k * n); - IdxType* out_i_ptr = out_I + (i * k * n); + DistType* out_d_ptr = out_D + (i * k * n); + IdxType* out_i_ptr = out_I + (i * k * n); auto stream = raft::resource::get_next_usable_stream(handle, i); @@ -448,13 +451,13 @@ void brute_force_knn_impl( if (metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::L2SqrtUnexpanded || metric == cuvs::distance::DistanceType::LpUnexpanded) { - value_t p = 0.5; // standard l2 + DistType p = 0.5; // standard l2 if (metric == cuvs::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; - raft::linalg::unaryOp( + raft::linalg::unaryOp( res_D, res_D, n * k, - [p] __device__(value_t input) { return powf(fabsf(input), p); }, + [p] __device__(DistType input) { return powf(fabsf(input), p); }, stream); } } else { @@ -514,14 +517,17 @@ void brute_force_knn_impl( if (translations == nullptr) delete id_ranges; }; -template +template void brute_force_search( raft::resources const& res, - const cuvs::neighbors::brute_force::index& idx, + const cuvs::neighbors::brute_force::index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - std::optional> query_norms = std::nullopt) + raft::device_matrix_view distances, + std::optional> query_norms = std::nullopt) { RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), @@ -532,36 +538,37 @@ void brute_force_search( std::vector dataset = {const_cast(idx.dataset().data_handle())}; std::vector sizes = {idx.dataset().extent(0)}; - std::vector norms; - if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } - - brute_force_knn_impl(res, - dataset, - sizes, - d, - const_cast(queries.data_handle()), - queries.extent(0), - neighbors.data_handle(), - distances.data_handle(), - k, - true, - std::is_same_v, - nullptr, - idx.metric(), - idx.metric_arg(), - norms.size() ? &norms : nullptr, - query_norms ? query_norms->data_handle() : nullptr); + std::vector norms; + if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } + + brute_force_knn_impl( + res, + dataset, + sizes, + d, + const_cast(queries.data_handle()), + queries.extent(0), + neighbors.data_handle(), + distances.data_handle(), + k, + true, + std::is_same_v, + nullptr, + idx.metric(), + idx.metric_arg(), + norms.size() ? &norms : nullptr, + query_norms ? query_norms->data_handle() : nullptr); } -template +template void brute_force_search_filtered( raft::resources const& res, - const cuvs::neighbors::brute_force::index& idx, + const cuvs::neighbors::brute_force::index& idx, raft::device_matrix_view queries, cuvs::core::bitmap_view filter, raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - std::optional> query_norms = std::nullopt) + raft::device_matrix_view distances, + std::optional> query_norms = std::nullopt) { auto metric = idx.metric(); @@ -604,26 +611,26 @@ void brute_force_search_filtered( if (sparsity > 0.01f) { raft::resources stream_pool_handle(res); raft::resource::set_cuda_stream(stream_pool_handle, stream); - auto idx_norm = idx.has_norms() ? const_cast(idx.norms().data_handle()) : nullptr; - - tiled_brute_force_knn(stream_pool_handle, - queries.data_handle(), - idx.dataset().data_handle(), - n_queries, - n_dataset, - dim, - k, - distances.data_handle(), - neighbors.data_handle(), - metric, - 2.0, - 0, - 0, - idx_norm, - nullptr, - filter.data()); + auto idx_norm = idx.has_norms() ? const_cast(idx.norms().data_handle()) : nullptr; + + tiled_brute_force_knn(stream_pool_handle, + queries.data_handle(), + idx.dataset().data_handle(), + n_queries, + n_dataset, + dim, + k, + distances.data_handle(), + neighbors.data_handle(), + metric, + DistanceT{2.0}, + 0, + 0, + idx_norm, + nullptr, + filter.data()); } else { - auto csr = raft::make_device_csr_matrix(res, n_queries, n_dataset, nnz_h); + auto csr = raft::make_device_csr_matrix(res, n_queries, n_dataset, nnz_h); // fill csr raft::sparse::convert::bitmap_to_csr(res, filter, csr); @@ -639,20 +646,20 @@ void brute_force_search_filtered( auto dataset_view = raft::make_device_matrix_view( idx.dataset().data_handle(), n_dataset, dim); - auto csr_view = raft::make_device_csr_matrix_view( + auto csr_view = raft::make_device_csr_matrix_view( csr.get_elements().data(), compressed_csr_view); raft::sparse::linalg::masked_matmul(res, queries, dataset_view, filter, csr_view); // post process - std::optional> query_norms_; + std::optional> query_norms_; if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded) { if (metric == cuvs::distance::DistanceType::CosineExpanded) { if (!query_norms) { - query_norms_ = raft::make_device_vector(res, n_queries); - raft::linalg::rowNorm((T*)(query_norms_->data_handle()), + query_norms_ = raft::make_device_vector(res, n_queries); + raft::linalg::rowNorm((DistanceT*)(query_norms_->data_handle()), queries.data_handle(), dim, n_queries, @@ -663,8 +670,8 @@ void brute_force_search_filtered( } } else { if (!query_norms) { - query_norms_ = raft::make_device_vector(res, n_queries); - raft::linalg::rowNorm((T*)(query_norms_->data_handle()), + query_norms_ = raft::make_device_vector(res, n_queries); + raft::linalg::rowNorm((DistanceT*)(query_norms_->data_handle()), queries.data_handle(), dim, n_queries, @@ -686,7 +693,7 @@ void brute_force_search_filtered( } // select k - auto const_csr_view = raft::make_device_csr_matrix_view( + auto const_csr_view = raft::make_device_csr_matrix_view( csr.get_elements().data(), compressed_csr_view); std::optional> no_opt = std::nullopt; bool select_min = cuvs::distance::is_min_close(metric); @@ -697,21 +704,21 @@ void brute_force_search_filtered( return; } -template -cuvs::neighbors::brute_force::index build( +template +cuvs::neighbors::brute_force::index build( raft::resources const& res, raft::device_matrix_view dataset, cuvs::distance::DistanceType metric, - T metric_arg) + DistT metric_arg) { // certain distance metrics can benefit by pre-calculating the norms for the index dataset // which lets us avoid calculating these at query time - std::optional> norms; + std::optional> norms; if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded) { - norms = raft::make_device_vector(res, dataset.extent(0)); + norms = raft::make_device_vector(res, dataset.extent(0)); // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric == cuvs::distance::DistanceType::CosineExpanded) { raft::linalg::norm(res, @@ -729,6 +736,7 @@ cuvs::neighbors::brute_force::index build( } } - return cuvs::neighbors::brute_force::index(res, dataset, std::move(norms), metric, metric_arg); + return cuvs::neighbors::brute_force::index( + res, dataset, std::move(norms), metric, metric_arg); } } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/knn_utils.cuh b/cpp/src/neighbors/detail/knn_utils.cuh index 1cc709fa4..60d5f6e30 100644 --- a/cpp/src/neighbors/detail/knn_utils.cuh +++ b/cpp/src/neighbors/detail/knn_utils.cuh @@ -21,15 +21,16 @@ #include #include +#include #include namespace cuvs::neighbors::detail { -template -RAFT_KERNEL epilogue_on_csr_kernel(value_t* __restrict__ compressed_C, +template +RAFT_KERNEL epilogue_on_csr_kernel(output_t* __restrict__ compressed_C, const value_idx* __restrict__ rows, const value_idx* __restrict__ cols, - const value_t* __restrict__ Q_sq_norms, + const output_t* __restrict__ Q_sq_norms, const value_t* __restrict__ R_sq_norms, value_idx nnz, expansion_f expansion_func) @@ -43,13 +44,13 @@ RAFT_KERNEL epilogue_on_csr_kernel(value_t* __restrict__ compressed_C, compressed_C[tid] = expansion_func(compressed_C[tid], Q_sq_norms[i], R_sq_norms[j]); } -template +template void epilogue_on_csr(raft::resources const& handle, - value_t* compressed_C, + output_t* compressed_C, const value_idx nnz, const value_idx* rows, const value_idx* cols, - const value_t* Q_sq_norms, + const output_t* Q_sq_norms, const value_t* R_sq_norms, cuvs::distance::DistanceType metric) { @@ -65,8 +66,12 @@ void epilogue_on_csr(raft::resources const& handle, Q_sq_norms, R_sq_norms, nnz, - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { - return value_t(-2.0) * dot + q_norm + r_norm; + [] __device__ __host__(output_t dot, output_t q_norm, value_t r_norm) -> output_t { + if constexpr (std::is_same_v) { + return output_t(-2.0) * dot + q_norm + __half2float(r_norm); + } else { + return output_t(-2.0) * dot + q_norm + r_norm; + } }); } else if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { epilogue_on_csr_kernel<<>>( @@ -76,8 +81,12 @@ void epilogue_on_csr(raft::resources const& handle, Q_sq_norms, R_sq_norms, nnz, - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { - return raft::sqrt(value_t(-2.0) * dot + q_norm + r_norm); + [] __device__ __host__(output_t dot, output_t q_norm, value_t r_norm) -> output_t { + if constexpr (std::is_same_v) { + return raft::sqrt(output_t(-2.0) * dot + q_norm + __half2float(r_norm)); + } else { + return raft::sqrt(output_t(-2.0) * dot + q_norm + r_norm); + } }); } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { epilogue_on_csr_kernel<<>>( @@ -87,8 +96,12 @@ void epilogue_on_csr(raft::resources const& handle, Q_sq_norms, R_sq_norms, nnz, - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) -> value_t { - return value_t(1.0) - dot / (q_norm * r_norm); + [] __device__ __host__(output_t dot, output_t q_norm, value_t r_norm) -> output_t { + if constexpr (std::is_same_v) { + return output_t(1.0) - dot / (q_norm * __half2float(r_norm)); + } else { + return output_t(1.0) - dot / (q_norm * r_norm); + } }); } RAFT_CUDA_TRY(cudaPeekAtLastError()); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 3495b2344..a30e2dec7 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -70,6 +70,9 @@ function(ConfigureTest) ${TEST_NAME} PRIVATE "$<$:${CUVS_CXX_FLAGS}>" "$<$:${CUVS_CUDA_FLAGS}>" ) + + target_compile_definitions(${TEST_NAME} PRIVATE "CUVS_EXPLICIT_INSTANTIATE_ONLY") + if(_CUVS_TEST_NOCUDA) target_compile_definitions(${TEST_NAME} PRIVATE "CUVS_DISABLE_CUDA") endif() @@ -126,6 +129,18 @@ if(BUILD_TESTS) 100 ) + ConfigureTest( + NAME + NEIGHBORS_ANN_BRUTE_FORCE_TEST + PATH + neighbors/ann_brute_force/test_float.cu + neighbors/ann_brute_force/test_half.cu + GPUS + 1 + PERCENT + 100 + ) + ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_TEST diff --git a/cpp/test/distance/dist_canberra.cu b/cpp/test/distance/dist_canberra.cu index 2bf590601..e7ffa9d0f 100644 --- a/cpp/test/distance/dist_canberra.cu +++ b/cpp/test/distance/dist_canberra.cu @@ -20,8 +20,9 @@ namespace cuvs { namespace distance { -template -class DistanceCanberra : public DistanceTest {}; +template +class DistanceCanberra + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -63,6 +64,26 @@ TEST_P(DistanceCanberraD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCanberra DistanceCanberraH; +TEST_P(DistanceCanberraH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraH, ::testing::ValuesIn(inputsh)); + class BigMatrixCanberra : public BigMatrixDistanceTest {}; TEST_F(BigMatrixCanberra, Result) {} diff --git a/cpp/test/distance/dist_correlation.cu b/cpp/test/distance/dist_correlation.cu index 9e061bebc..70f3a9adb 100644 --- a/cpp/test/distance/dist_correlation.cu +++ b/cpp/test/distance/dist_correlation.cu @@ -20,13 +20,15 @@ namespace cuvs { namespace distance { -template +template class DistanceCorrelation - : public DistanceTest {}; + : public DistanceTest {}; -template +template class DistanceCorrelationXequalY - : public DistanceTestSameBuffer {}; + : public DistanceTestSameBuffer {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -48,6 +50,26 @@ TEST_P(DistanceCorrelationF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF, ::testing::ValuesIn(inputsf)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCorrelation DistanceCorrelationH; +TEST_P(DistanceCorrelationH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationH, ::testing::ValuesIn(inputsh)); + typedef DistanceCorrelationXequalY DistanceCorrelationXequalYF; TEST_P(DistanceCorrelationXequalYF, Result) { @@ -87,6 +109,25 @@ TEST_P(DistanceCorrelationD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationD, ::testing::ValuesIn(inputsd)); +typedef DistanceCorrelationXequalY DistanceCorrelationXequalYH; +TEST_P(DistanceCorrelationXequalYH, Result) +{ + int m = params.m; + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + m, + cuvs::CompareApprox(params.tolerance), + stream)); + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m / 2, + m, + cuvs::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationXequalYH, ::testing::ValuesIn(inputsh)); + class BigMatrixCorrelation : public BigMatrixDistanceTest {}; TEST_F(BigMatrixCorrelation, Result) {} diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu index e134f045f..78e2c745f 100644 --- a/cpp/test/distance/dist_cos.cu +++ b/cpp/test/distance/dist_cos.cu @@ -20,13 +20,15 @@ namespace cuvs { namespace distance { -template -class DistanceExpCos : public DistanceTest { -}; +template +class DistanceExpCos + : public DistanceTest {}; -template +template class DistanceExpCosXequalY - : public DistanceTestSameBuffer {}; + : public DistanceTestSameBuffer {}; const std::vector> inputsf = { {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, @@ -52,6 +54,30 @@ const std::vector> inputsXeqYf = { {0.03f, 1024, 1024, 1024, false, 1234ULL}, }; +const std::vector> inputsh = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; + +const std::vector> inputsXeqYh = { + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, +}; + typedef DistanceExpCos DistanceExpCosF; TEST_P(DistanceExpCosF, Result) { @@ -62,6 +88,16 @@ TEST_P(DistanceExpCosF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosF, ::testing::ValuesIn(inputsf)); +typedef DistanceExpCos DistanceExpCosH; +TEST_P(DistanceExpCosH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosH, ::testing::ValuesIn(inputsh)); + typedef DistanceExpCosXequalY DistanceExpCosXequalYF; TEST_P(DistanceExpCosXequalYF, Result) { @@ -85,6 +121,29 @@ TEST_P(DistanceExpCosXequalYF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosXequalYF, ::testing::ValuesIn(inputsXeqYf)); +typedef DistanceExpCosXequalY DistanceExpCosXequalYH; +TEST_P(DistanceExpCosXequalYH, Result) +{ + int m = params.m; + int n = params.m; + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + n, + cuvs::CompareApprox(params.tolerance), + stream)); + n = params.isRowMajor ? m : m / 2; + m = params.isRowMajor ? m / 2 : m; + + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m, + n, + cuvs::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosXequalYH, ::testing::ValuesIn(inputsXeqYh)); + const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, diff --git a/cpp/test/distance/dist_hamming.cu b/cpp/test/distance/dist_hamming.cu index 0cf753eca..3073ed939 100644 --- a/cpp/test/distance/dist_hamming.cu +++ b/cpp/test/distance/dist_hamming.cu @@ -20,9 +20,9 @@ namespace cuvs { namespace distance { -template +template class DistanceHamming - : public DistanceTest {}; + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -64,6 +64,26 @@ TEST_P(DistanceHammingD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHamming DistanceHammingH; +TEST_P(DistanceHammingH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingH, ::testing::ValuesIn(inputsh)); + class BigMatrixHamming : public BigMatrixDistanceTest {}; TEST_F(BigMatrixHamming, Result) {} diff --git a/cpp/test/distance/dist_hellinger.cu b/cpp/test/distance/dist_hellinger.cu index 3998a60ab..692bfeeff 100644 --- a/cpp/test/distance/dist_hellinger.cu +++ b/cpp/test/distance/dist_hellinger.cu @@ -20,9 +20,9 @@ namespace cuvs { namespace distance { -template +template class DistanceHellingerExp - : public DistanceTest {}; + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -64,6 +64,26 @@ TEST_P(DistanceHellingerExpD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHellingerExp DistanceHellingerExpH; +TEST_P(DistanceHellingerExpH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpH, ::testing::ValuesIn(inputsh)); + class BigMatrixHellingerExp : public BigMatrixDistanceTest {}; TEST_F(BigMatrixHellingerExp, Result) {} diff --git a/cpp/test/distance/dist_inner_product.cu b/cpp/test/distance/dist_inner_product.cu index 1d6709d52..aaedb5bf1 100644 --- a/cpp/test/distance/dist_inner_product.cu +++ b/cpp/test/distance/dist_inner_product.cu @@ -20,9 +20,9 @@ namespace cuvs { namespace distance { -template +template class DistanceInnerProduct - : public DistanceTest {}; + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 10, 5, 32, true, 1234ULL}, @@ -66,6 +66,27 @@ TEST_P(DistanceInnerProductD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 10, 5, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceInnerProduct DistanceInnerProductH; +TEST_P(DistanceInnerProductH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductH, ::testing::ValuesIn(inputsh)); + class BigMatrixInnerProduct : public BigMatrixDistanceTest {}; TEST_F(BigMatrixInnerProduct, Result) {} diff --git a/cpp/test/distance/dist_jensen_shannon.cu b/cpp/test/distance/dist_jensen_shannon.cu index 43b7b361d..f50830d4e 100644 --- a/cpp/test/distance/dist_jensen_shannon.cu +++ b/cpp/test/distance/dist_jensen_shannon.cu @@ -20,9 +20,9 @@ namespace cuvs { namespace distance { -template +template class DistanceJensenShannon - : public DistanceTest {}; + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -64,6 +64,26 @@ TEST_P(DistanceJensenShannonD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceJensenShannon DistanceJensenShannonH; +TEST_P(DistanceJensenShannonH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonH, ::testing::ValuesIn(inputsh)); + class BigMatrixJensenShannon : public BigMatrixDistanceTest {}; TEST_F(BigMatrixJensenShannon, Result) {} diff --git a/cpp/test/distance/dist_kl_divergence.cu b/cpp/test/distance/dist_kl_divergence.cu index 5e5692841..3d2373bf7 100644 --- a/cpp/test/distance/dist_kl_divergence.cu +++ b/cpp/test/distance/dist_kl_divergence.cu @@ -20,9 +20,9 @@ namespace cuvs { namespace distance { -template +template class DistanceKLDivergence - : public DistanceTest {}; + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -64,6 +64,26 @@ TEST_P(DistanceKLDivergenceD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceKLDivergence DistanceKLDivergenceH; +TEST_P(DistanceKLDivergenceH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceH, ::testing::ValuesIn(inputsh)); + class BigMatrixKLDivergence : public BigMatrixDistanceTest {}; TEST_F(BigMatrixKLDivergence, Result) {} diff --git a/cpp/test/distance/dist_l1.cu b/cpp/test/distance/dist_l1.cu index a3ecd21fe..cd9a9219b 100644 --- a/cpp/test/distance/dist_l1.cu +++ b/cpp/test/distance/dist_l1.cu @@ -20,8 +20,9 @@ namespace cuvs { namespace distance { -template -class DistanceUnexpL1 : public DistanceTest {}; +template +class DistanceUnexpL1 + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -63,6 +64,26 @@ TEST_P(DistanceUnexpL1D, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1D, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceUnexpL1 DistanceUnexpL1H; +TEST_P(DistanceUnexpL1H, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1H, ::testing::ValuesIn(inputsh)); + class BigMatrixUnexpL1 : public BigMatrixDistanceTest {}; TEST_F(BigMatrixUnexpL1, Result) {} diff --git a/cpp/test/distance/dist_l2_exp.cu b/cpp/test/distance/dist_l2_exp.cu index f3d038cbc..5804bb6c5 100644 --- a/cpp/test/distance/dist_l2_exp.cu +++ b/cpp/test/distance/dist_l2_exp.cu @@ -20,13 +20,14 @@ namespace cuvs { namespace distance { -template -class DistanceEucExpTest : public DistanceTest { -}; +template +class DistanceEucExpTest + : public DistanceTest {}; -template +template class DistanceEucExpTestXequalY - : public DistanceTestSameBuffer {}; + : public DistanceTestSameBuffer { +}; const std::vector> inputsf = { {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, @@ -58,6 +59,36 @@ const std::vector> inputsXeqYf = { {0.03f, 1021, 1021, 1021, false, 1234ULL}, }; +const std::vector> inputsh = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, + {0.001f, 2048, 4096, 128, true, 1234ULL}, + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.003f, 1021, 1021, 1021, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, + {0.003f, 1021, 1021, 1021, false, 1234ULL}, +}; + +const std::vector> inputsXeqYh = { + {0.01f, 2048, 4096, 128, true, 1234ULL}, + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.03f, 1021, 1021, 1021, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, + {0.03f, 1021, 1021, 1021, false, 1234ULL}, +}; + typedef DistanceEucExpTest DistanceEucExpTestF; TEST_P(DistanceEucExpTestF, Result) { @@ -68,6 +99,16 @@ TEST_P(DistanceEucExpTestF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestF, ::testing::ValuesIn(inputsf)); +typedef DistanceEucExpTest DistanceEucExpTestH; +TEST_P(DistanceEucExpTestH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestH, ::testing::ValuesIn(inputsh)); + typedef DistanceEucExpTestXequalY DistanceEucExpTestXequalYF; TEST_P(DistanceEucExpTestXequalYF, Result) { @@ -89,6 +130,27 @@ INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestXequalYF, ::testing::ValuesIn(inputsXeqYf)); +typedef DistanceEucExpTestXequalY DistanceEucExpTestXequalYH; +TEST_P(DistanceEucExpTestXequalYH, Result) +{ + int m = params.m; + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + m, + cuvs::CompareApprox(params.tolerance), + stream)); + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m / 2, + m, + cuvs::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, + DistanceEucExpTestXequalYH, + ::testing::ValuesIn(inputsXeqYh)); + const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, diff --git a/cpp/test/distance/dist_l2_sqrt_exp.cu b/cpp/test/distance/dist_l2_sqrt_exp.cu index b24384be8..a2e09fdb0 100644 --- a/cpp/test/distance/dist_l2_sqrt_exp.cu +++ b/cpp/test/distance/dist_l2_sqrt_exp.cu @@ -20,9 +20,9 @@ namespace cuvs { namespace distance { -template +template class DistanceEucSqrtExpTest - : public DistanceTest {}; + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 2048, 4096, 128, true, 1234ULL}, @@ -67,6 +67,29 @@ TEST_P(DistanceEucSqrtExpTestD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 2048, 4096, 128, true, 1234ULL}, + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.003f, 1021, 1021, 1021, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, + {0.003f, 1021, 1021, 1021, false, 1234ULL}, +}; +typedef DistanceEucSqrtExpTest DistanceEucSqrtExpTestH; +TEST_P(DistanceEucSqrtExpTestH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestH, ::testing::ValuesIn(inputsh)); + class BigMatrixEucSqrtExp : public BigMatrixDistanceTest {}; TEST_F(BigMatrixEucSqrtExp, Result) {} diff --git a/cpp/test/distance/dist_l2_unexp.cu b/cpp/test/distance/dist_l2_unexp.cu index c057434fa..3f9e6458f 100644 --- a/cpp/test/distance/dist_l2_unexp.cu +++ b/cpp/test/distance/dist_l2_unexp.cu @@ -20,9 +20,9 @@ namespace cuvs { namespace distance { -template +template class DistanceEucUnexpTest - : public DistanceTest {}; + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -64,6 +64,26 @@ TEST_P(DistanceEucUnexpTestD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceEucUnexpTest DistanceEucUnexpTestH; +TEST_P(DistanceEucUnexpTestH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestH, ::testing::ValuesIn(inputsh)); + class BigMatrixEucUnexp : public BigMatrixDistanceTest { }; TEST_F(BigMatrixEucUnexp, Result) {} diff --git a/cpp/test/distance/dist_l_inf.cu b/cpp/test/distance/dist_l_inf.cu index b9ced68f3..21e9a6c87 100644 --- a/cpp/test/distance/dist_l_inf.cu +++ b/cpp/test/distance/dist_l_inf.cu @@ -20,8 +20,9 @@ namespace cuvs { namespace distance { -template -class DistanceLinf : public DistanceTest {}; +template +class DistanceLinf : public DistanceTest { +}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -63,6 +64,26 @@ TEST_P(DistanceLinfD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceLinf DistanceLinfH; +TEST_P(DistanceLinfH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfH, ::testing::ValuesIn(inputsh)); + class BigMatrixLinf : public BigMatrixDistanceTest {}; TEST_F(BigMatrixLinf, Result) {} diff --git a/cpp/test/distance/dist_lp_unexp.cu b/cpp/test/distance/dist_lp_unexp.cu index 26620b44b..95e521fb3 100644 --- a/cpp/test/distance/dist_lp_unexp.cu +++ b/cpp/test/distance/dist_lp_unexp.cu @@ -20,9 +20,9 @@ namespace cuvs { namespace distance { -template -class DistanceLpUnexp : public DistanceTest { -}; +template +class DistanceLpUnexp + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL, 4.0f}, @@ -64,6 +64,26 @@ TEST_P(DistanceLpUnexpD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL, 4.0f}, + {0.001f, 1024, 32, 1024, true, 1234ULL, 3.0f}, + {0.001f, 32, 1024, 1024, true, 1234ULL, 4.0f}, + {0.003f, 1024, 1024, 1024, true, 1234ULL, 3.0f}, + {0.001f, 1024, 1024, 32, false, 1234ULL, 4.0f}, + {0.001f, 1024, 32, 1024, false, 1234ULL, 3.0f}, + {0.001f, 32, 1024, 1024, false, 1234ULL, 4.0f}, + {0.003f, 1024, 1024, 1024, false, 1234ULL, 3.0f}, +}; +typedef DistanceLpUnexp DistanceLpUnexpH; +TEST_P(DistanceLpUnexpH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpH, ::testing::ValuesIn(inputsh)); + class BigMatrixLpUnexp : public BigMatrixDistanceTest { }; TEST_F(BigMatrixLpUnexp, Result) {} diff --git a/cpp/test/distance/dist_russell_rao.cu b/cpp/test/distance/dist_russell_rao.cu index 46da7f9cd..814a0503f 100644 --- a/cpp/test/distance/dist_russell_rao.cu +++ b/cpp/test/distance/dist_russell_rao.cu @@ -20,9 +20,9 @@ namespace cuvs { namespace distance { -template +template class DistanceRussellRao - : public DistanceTest {}; + : public DistanceTest {}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -64,6 +64,26 @@ TEST_P(DistanceRussellRaoD, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoD, ::testing::ValuesIn(inputsd)); +const std::vector> inputsh = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceRussellRao DistanceRussellRaoH; +TEST_P(DistanceRussellRaoH, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoH, ::testing::ValuesIn(inputsh)); + class BigMatrixRussellRao : public BigMatrixDistanceTest {}; TEST_F(BigMatrixRussellRao, Result) {} diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 2213db87e..8a431f49a 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -18,7 +18,6 @@ #include "../test_utils.cuh" -#include #include // cuvs::distance::DistanceType #include // raft::common::nvtx::range #include //raft::make_device_matrix_view @@ -34,8 +33,18 @@ namespace cuvs { namespace distance { -template -RAFT_KERNEL naiveDistanceKernel(DataType* dist, +template +_RAFT_DEVICE inline auto half2float(T& a) +{ + if constexpr (std::is_same_v::type, half>) { + return __half2float(a); + } else { + return a; + } +} + +template +RAFT_KERNEL naiveDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -47,11 +56,11 @@ RAFT_KERNEL naiveDistanceKernel(DataType* dist, std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; - DataType acc = DataType(0); + OutputType acc = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto diff = x[xidx] - y[yidx]; + auto diff = half2float(x[xidx]) - half2float(y[yidx]); acc += diff * diff; } if (type == cuvs::distance::DistanceType::L2SqrtExpanded || @@ -61,8 +70,8 @@ RAFT_KERNEL naiveDistanceKernel(DataType* dist, dist[outidx] = acc; } -template -RAFT_KERNEL naiveL1_Linf_CanberraDistanceKernel(DataType* dist, +template +RAFT_KERNEL naiveL1_Linf_CanberraDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -75,12 +84,12 @@ RAFT_KERNEL naiveL1_Linf_CanberraDistanceKernel(DataType* dist, std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) { return; } - DataType acc = DataType(0); + OutputType acc = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); auto diff = (a > b) ? (a - b) : (b - a); if (type == cuvs::distance::DistanceType::Linf) { acc = raft::max(acc, diff); @@ -98,8 +107,8 @@ RAFT_KERNEL naiveL1_Linf_CanberraDistanceKernel(DataType* dist, dist[outidx] = acc; } -template -RAFT_KERNEL naiveCosineDistanceKernel(DataType* dist, +template +RAFT_KERNEL naiveCosineDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -111,15 +120,15 @@ RAFT_KERNEL naiveCosineDistanceKernel(DataType* dist, std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) { return; } - DataType acc_a = DataType(0); - DataType acc_b = DataType(0); - DataType acc_ab = DataType(0); + OutputType acc_a = OutputType(0); + OutputType acc_b = OutputType(0); + OutputType acc_ab = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); acc_a += a * a; acc_b += b * b; acc_ab += a * b; @@ -128,11 +137,11 @@ RAFT_KERNEL naiveCosineDistanceKernel(DataType* dist, std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; // Use 1.0 - (cosine similarity) to calc the distance - dist[outidx] = (DataType)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); + dist[outidx] = (OutputType)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); } -template -RAFT_KERNEL naiveInnerProductKernel(DataType* dist, +template +RAFT_KERNEL naiveInnerProductKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -144,13 +153,13 @@ RAFT_KERNEL naiveInnerProductKernel(DataType* dist, std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) { return; } - DataType acc_ab = DataType(0); + OutputType acc_ab = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); acc_ab += a * b; } @@ -158,8 +167,8 @@ RAFT_KERNEL naiveInnerProductKernel(DataType* dist, dist[outidx] = acc_ab; } -template -RAFT_KERNEL naiveHellingerDistanceKernel(DataType* dist, +template +RAFT_KERNEL naiveHellingerDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -171,13 +180,13 @@ RAFT_KERNEL naiveHellingerDistanceKernel(DataType* dist, std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) { return; } - DataType acc_ab = DataType(0); + OutputType acc_ab = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); acc_ab += raft::sqrt(a) * raft::sqrt(b); } @@ -189,25 +198,25 @@ RAFT_KERNEL naiveHellingerDistanceKernel(DataType* dist, dist[outidx] = raft::sqrt(rectifier * acc_ab); } -template -RAFT_KERNEL naiveLpUnexpDistanceKernel(DataType* dist, +template +RAFT_KERNEL naiveLpUnexpDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, std::int64_t n, std::int64_t k, bool isRowMajor, - DataType p) + OutputType p) { std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; - DataType acc = DataType(0); + OutputType acc = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); auto diff = raft::abs(a - b); acc += raft::pow(diff, p); } @@ -217,8 +226,8 @@ RAFT_KERNEL naiveLpUnexpDistanceKernel(DataType* dist, dist[outidx] = acc; } -template -RAFT_KERNEL naiveHammingDistanceKernel(DataType* dist, +template +RAFT_KERNEL naiveHammingDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -229,12 +238,12 @@ RAFT_KERNEL naiveHammingDistanceKernel(DataType* dist, std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; - DataType acc = DataType(0); + OutputType acc = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); acc += (a != b); } acc = acc / k; @@ -242,8 +251,8 @@ RAFT_KERNEL naiveHammingDistanceKernel(DataType* dist, dist[outidx] = acc; } -template -RAFT_KERNEL naiveJensenShannonDistanceKernel(DataType* dist, +template +RAFT_KERNEL naiveJensenShannonDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -254,19 +263,19 @@ RAFT_KERNEL naiveJensenShannonDistanceKernel(DataType* dist, std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; - DataType acc = DataType(0); + OutputType acc = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); - DataType m = 0.5f * (a + b); - bool a_zero = a == 0; - bool b_zero = b == 0; + OutputType m = 0.5f * (a + b); + bool a_zero = a == 0; + bool b_zero = b == 0; - DataType p = (!a_zero * m) / (a_zero + a); - DataType q = (!b_zero * m) / (b_zero + b); + OutputType p = (!a_zero * m) / (a_zero + a); + OutputType q = (!b_zero * m) / (b_zero + b); bool p_zero = p == 0; bool q_zero = q == 0; @@ -278,8 +287,8 @@ RAFT_KERNEL naiveJensenShannonDistanceKernel(DataType* dist, dist[outidx] = acc; } -template -RAFT_KERNEL naiveRussellRaoDistanceKernel(OutType* dist, +template +RAFT_KERNEL naiveRussellRaoDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -290,12 +299,12 @@ RAFT_KERNEL naiveRussellRaoDistanceKernel(OutType* dist, std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; - OutType acc = OutType(0); + OutputType acc = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); acc += (a * b); } acc = (k - acc) / k; @@ -303,8 +312,8 @@ RAFT_KERNEL naiveRussellRaoDistanceKernel(OutType* dist, dist[outidx] = acc; } -template -RAFT_KERNEL naiveKLDivergenceDistanceKernel(OutType* dist, +template +RAFT_KERNEL naiveKLDivergenceDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -315,12 +324,12 @@ RAFT_KERNEL naiveKLDivergenceDistanceKernel(OutType* dist, std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; - OutType acc = OutType(0); + OutputType acc = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); bool b_zero = (b == 0); bool a_zero = (a == 0); acc += a * (log(a + a_zero) - log(b + b_zero)); @@ -330,8 +339,8 @@ RAFT_KERNEL naiveKLDivergenceDistanceKernel(OutType* dist, dist[outidx] = acc; } -template -RAFT_KERNEL naiveCorrelationDistanceKernel(OutType* dist, +template +RAFT_KERNEL naiveCorrelationDistanceKernel(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -342,16 +351,16 @@ RAFT_KERNEL naiveCorrelationDistanceKernel(OutType* dist, std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; - OutType acc = OutType(0); - auto a_norm = DataType(0); - auto b_norm = DataType(0); - auto a_sq_norm = DataType(0); - auto b_sq_norm = DataType(0); + OutputType acc = OutputType(0); + auto a_norm = OutputType(0); + auto b_norm = OutputType(0); + auto a_sq_norm = OutputType(0); + auto b_sq_norm = OutputType(0); for (std::int64_t i = 0; i < k; ++i) { std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; + auto a = half2float(x[xidx]); + auto b = half2float(y[yidx]); a_norm += a; b_norm += b; a_sq_norm += (a * a); @@ -369,8 +378,8 @@ RAFT_KERNEL naiveCorrelationDistanceKernel(OutType* dist, dist[outidx] = acc; } -template -void naiveDistance(DataType* dist, +template +void naiveDistance(OutputType* dist, const DataType* x, const DataType* y, std::int64_t m, @@ -378,8 +387,8 @@ void naiveDistance(DataType* dist, std::int64_t k, cuvs::distance::DistanceType type, bool isRowMajor, - DataType metric_arg = 2.0f, - cudaStream_t stream = 0) + OutputType metric_arg = 2.0f, + cudaStream_t stream = 0) { static const dim3 TPB(4, 256, 1); dim3 nblks(raft::ceildiv(m, (std::int64_t)TPB.x), raft::ceildiv(n, (std::int64_t)TPB.y), 1); @@ -388,49 +397,50 @@ void naiveDistance(DataType* dist, case cuvs::distance::DistanceType::Canberra: case cuvs::distance::DistanceType::Linf: case cuvs::distance::DistanceType::L1: - naiveL1_Linf_CanberraDistanceKernel + naiveL1_Linf_CanberraDistanceKernel <<>>(dist, x, y, m, n, k, type, isRowMajor); break; case cuvs::distance::DistanceType::L2SqrtUnexpanded: case cuvs::distance::DistanceType::L2Unexpanded: case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Expanded: - naiveDistanceKernel + naiveDistanceKernel <<>>(dist, x, y, m, n, k, type, isRowMajor); break; case cuvs::distance::DistanceType::CosineExpanded: - naiveCosineDistanceKernel + naiveCosineDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; case cuvs::distance::DistanceType::HellingerExpanded: - naiveHellingerDistanceKernel + naiveHellingerDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; case cuvs::distance::DistanceType::LpUnexpanded: - naiveLpUnexpDistanceKernel + naiveLpUnexpDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor, metric_arg); break; case cuvs::distance::DistanceType::HammingUnexpanded: - naiveHammingDistanceKernel + naiveHammingDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; case cuvs::distance::DistanceType::InnerProduct: - naiveInnerProductKernel<<>>(dist, x, y, m, n, k, isRowMajor); + naiveInnerProductKernel + <<>>(dist, x, y, m, n, k, isRowMajor); break; case cuvs::distance::DistanceType::JensenShannon: - naiveJensenShannonDistanceKernel + naiveJensenShannonDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; case cuvs::distance::DistanceType::RusselRaoExpanded: - naiveRussellRaoDistanceKernel + naiveRussellRaoDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; case cuvs::distance::DistanceType::KLDivergence: - naiveKLDivergenceDistanceKernel + naiveKLDivergenceDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; case cuvs::distance::DistanceType::CorrelationExpanded: - naiveCorrelationDistanceKernel + naiveCorrelationDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; default: FAIL() << "should be here\n"; @@ -438,13 +448,13 @@ void naiveDistance(DataType* dist, RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template +template struct DistanceInputs { - DataType tolerance; + OutputType tolerance; std::int64_t m, n, k; bool isRowMajor; unsigned long long int seed; - DataType metric_arg = 2.0f; + OutputType metric_arg = 2.0f; }; template @@ -472,31 +482,38 @@ constexpr bool layout_to_row_major() return false; } -template +template void distanceLauncher(raft::resources const& handle, DataType* x, DataType* y, - DataType* dist, - DataType* dist2, + OutputType* dist, + OutputType* dist2, std::int64_t m, std::int64_t n, std::int64_t k, - DistanceInputs& params, - DataType threshold, - DataType metric_arg = 2.0f) + DistanceInputs& params, + OutputType threshold, + OutputType metric_arg = 2.0f) { - auto x_v = raft::make_device_matrix_view(x, m, k); - auto y_v = raft::make_device_matrix_view(y, n, k); - auto dist_v = raft::make_device_matrix_view(dist, m, n); + // Create device matrix views for the input and output data + auto x_v = raft::make_device_matrix_view(x, m, k); + auto y_v = raft::make_device_matrix_view(y, n, k); + auto dist_v = raft::make_device_matrix_view(dist, m, n); + // Explicitly instantiate the template function cuvs::distance::pairwise_distance(handle, x_v, y_v, dist_v, distanceType, metric_arg); } -template -class DistanceTest : public ::testing::TestWithParam> { +template +class DistanceTest : public ::testing::TestWithParam> { public: DistanceTest() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(raft::resource::get_cuda_stream(handle)), x(params.m * params.k, stream), y(params.n * params.k, stream), @@ -513,11 +530,11 @@ class DistanceTest : public ::testing::TestWithParam> { "test::%s/%s", testInfo->test_suite_name(), testInfo->name()); raft::random::RngState r(params.seed); - std::int64_t m = params.m; - std::int64_t n = params.n; - std::int64_t k = params.k; - DataType metric_arg = params.metric_arg; - bool isRowMajor = params.isRowMajor; + std::int64_t m = params.m; + std::int64_t n = params.n; + std::int64_t k = params.k; + OutputType metric_arg = params.metric_arg; + bool isRowMajor = params.isRowMajor; if (distanceType == cuvs::distance::DistanceType::HellingerExpanded || distanceType == cuvs::distance::DistanceType::JensenShannon || distanceType == cuvs::distance::DistanceType::KLDivergence) { @@ -537,33 +554,33 @@ class DistanceTest : public ::testing::TestWithParam> { naiveDistance( dist_ref.data(), x.data(), y.data(), m, n, k, distanceType, isRowMajor, metric_arg, stream); - DataType threshold = -10000.f; + OutputType threshold = -10000.f; if (isRowMajor) { - distanceLauncher(handle, - x.data(), - y.data(), - dist.data(), - dist2.data(), - m, - n, - k, - params, - threshold, - metric_arg); + distanceLauncher(handle, + x.data(), + y.data(), + dist.data(), + dist2.data(), + m, + n, + k, + params, + threshold, + metric_arg); } else { - distanceLauncher(handle, - x.data(), - y.data(), - dist.data(), - dist2.data(), - m, - n, - k, - params, - threshold, - metric_arg); + distanceLauncher(handle, + x.data(), + y.data(), + dist.data(), + dist2.data(), + m, + n, + k, + params, + threshold, + metric_arg); } raft::resource::sync_stream(handle, stream); } @@ -572,8 +589,9 @@ class DistanceTest : public ::testing::TestWithParam> { raft::resources handle; cudaStream_t stream; - DistanceInputs params; - rmm::device_uvector x, y, dist_ref, dist, dist2; + DistanceInputs params; + rmm::device_uvector x, y; + rmm::device_uvector dist_ref, dist, dist2; }; /* @@ -583,12 +601,15 @@ class DistanceTest : public ::testing::TestWithParam> { * It may happen that though both X and Y are same buffer but user passes * different dimensions for them like in case of tiled_brute_force_knn. */ -template -class DistanceTestSameBuffer : public ::testing::TestWithParam> { +template +class DistanceTestSameBuffer + : public ::testing::TestWithParam> { public: - using dev_vector = rmm::device_uvector; + using dev_vector = rmm::device_uvector; DistanceTestSameBuffer() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(raft::resource::get_cuda_stream(handle)), x(params.m * params.k, stream), dist_ref({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}), @@ -604,11 +625,11 @@ class DistanceTestSameBuffer : public ::testing::TestWithParamtest_suite_name(), testInfo->name()); raft::random::RngState r(params.seed); - std::int64_t m = params.m; - std::int64_t n = params.m; - std::int64_t k = params.k; - DataType metric_arg = params.metric_arg; - bool isRowMajor = params.isRowMajor; + std::int64_t m = params.m; + std::int64_t n = params.m; + std::int64_t k = params.k; + OutputType metric_arg = params.metric_arg; + bool isRowMajor = params.isRowMajor; if (distanceType == cuvs::distance::DistanceType::HellingerExpanded || distanceType == cuvs::distance::DistanceType::JensenShannon || distanceType == cuvs::distance::DistanceType::KLDivergence) { @@ -637,33 +658,35 @@ class DistanceTestSameBuffer : public ::testing::TestWithParam(handle, - x.data(), - x.data(), - dist[i].data(), - dist2[i].data(), - m, - n, - k, - params, - threshold, - metric_arg); + distanceLauncher( + handle, + x.data(), + x.data(), + dist[i].data(), + dist2[i].data(), + m, + n, + k, + params, + threshold, + metric_arg); } else { - distanceLauncher(handle, - x.data(), - x.data(), - dist[i].data(), - dist2[i].data(), - m, - n, - k, - params, - threshold, - metric_arg); + distanceLauncher( + handle, + x.data(), + x.data(), + dist[i].data(), + dist2[i].data(), + m, + n, + k, + params, + threshold, + metric_arg); } } raft::resource::sync_stream(handle, stream); @@ -673,8 +696,8 @@ class DistanceTestSameBuffer : public ::testing::TestWithParam params; - dev_vector x; + DistanceInputs params; + rmm::device_uvector x; static const std::int64_t N = 2; std::array dist_ref, dist, dist2; }; diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh new file mode 100644 index 000000000..461a202f2 --- /dev/null +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../test_utils.cuh" + +#include "knn_utils.cuh" +#include "naive_knn.cuh" + +#include + +#include +#include + +#include + +namespace cuvs::neighbors::brute_force { + +template +struct AnnBruteForceInputs { + IdxT num_queries; + IdxT num_db_vecs; + IdxT dim; + IdxT k; + cuvs::distance::DistanceType metric; + float metric_arg = 0.0f; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const AnnBruteForceInputs& p) +{ + os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " + << static_cast(p.metric) << static_cast(p.metric_arg) << '}' << std::endl; + return os; +} + +template +class AnnBruteForceTest : public ::testing::TestWithParam> { + public: + AnnBruteForceTest() + : stream_(raft::resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam>::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + void testBruteForce() + { + size_t queries_size = ps.num_queries * ps.k; + + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + + cuvs::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); + raft::resource::sync_stream(handle_); + + { + // Require exact result for brute force + rmm::device_uvector distances_bruteforce_dev(queries_size, stream_); + rmm::device_uvector indices_bruteforce_dev(queries_size, stream_); + + auto idx = [this]() { + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + + return brute_force::build(handle_, database_view, ps.metric, ps.metric_arg); + }(); + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = raft::make_device_matrix_view( + indices_bruteforce_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_bruteforce_dev.data(), ps.num_queries, ps.k); + + brute_force::search( + handle_, idx, search_queries_view, indices_out_view, dists_out_view, std::nullopt); + + raft::resource::sync_stream(handle_); + + ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(indices_naive_dev.data(), + indices_bruteforce_dev.data(), + distances_naive_dev.data(), + distances_bruteforce_dev.data(), + ps.num_queries, + ps.k, + 0.001f, + stream_, + true)); + brute_force::search( + handle_, idx, search_queries_view, indices_out_view, dists_out_view, std::nullopt); + } + } + + void SetUp() override + { + database.resize(ps.num_db_vecs * ps.dim, stream_); + search_queries.resize(ps.num_queries * ps.dim, stream_); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{} || std::is_same{}) { + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + } + raft::resource::sync_stream(handle_); + } + + void TearDown() override + { + raft::resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnBruteForceInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + +const std::vector> inputs = { + // test various dims (aligned and not aligned to vector sizes) + {1000, 10000, 1, 16, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 2, 16, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 3, 16, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 4, 16, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 5, 16, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 8, 16, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 5, 16, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 8, 16, cuvs::distance::DistanceType::L2SqrtExpanded}, + + // test dims that do not fit into kernel shared memory limits + {1000, 10000, 2048, 16, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 2049, 16, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 2050, 16, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 2051, 16, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 2052, 16, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 2053, 16, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 2056, 16, cuvs::distance::DistanceType::L2Expanded}, + + // test fused_l2_knn + {100, 1000, 16, 10, cuvs::distance::DistanceType::L2Expanded}, + {256, 256, 30, 10, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, cuvs::distance::DistanceType::L2Expanded}, + {100, 1000, 16, 50, cuvs::distance::DistanceType::L2Expanded}, + {20, 10000, 16, 10, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 50, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 32, 50, cuvs::distance::DistanceType::L2Expanded}, + {10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded}, + {100, 1000, 16, 10, cuvs::distance::DistanceType::L2Unexpanded}, + {1000, 10000, 16, 10, cuvs::distance::DistanceType::L2Unexpanded}, + {100, 1000, 16, 50, cuvs::distance::DistanceType::L2Unexpanded}, + {20, 10000, 16, 50, cuvs::distance::DistanceType::L2Unexpanded}, + {1000, 10000, 16, 50, cuvs::distance::DistanceType::L2Unexpanded}, + {1000, 10000, 32, 50, cuvs::distance::DistanceType::L2Unexpanded}, + {10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Unexpanded}, + + // test tile + {256, 512, 16, 8, cuvs::distance::DistanceType::L2Expanded}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded}, + {256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct}, + {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded}, + {10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded}, + {789, 20516, 64, 256, cuvs::distance::DistanceType::L2SqrtExpanded}, + {4, 12, 32, 6, cuvs::distance::DistanceType::L2Expanded}, + {1, 40, 32, 30, cuvs::distance::DistanceType::L2Expanded}, + {1000, 500000, 128, 128, cuvs::distance::DistanceType::L2Expanded}}; +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/test/neighbors/ann_brute_force/test_float.cu b/cpp/test/neighbors/ann_brute_force/test_float.cu new file mode 100644 index 000000000..ded371c42 --- /dev/null +++ b/cpp/test/neighbors/ann_brute_force/test_float.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../ann_brute_force.cuh" + +#include + +namespace cuvs::neighbors::brute_force { + +using AnnBruteForceTest_float = AnnBruteForceTest; +TEST_P(AnnBruteForceTest_float, AnnBruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(AnnBruteForceTest, AnnBruteForceTest_float, ::testing::ValuesIn(inputs)); + +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/test/neighbors/ann_brute_force/test_half.cu b/cpp/test/neighbors/ann_brute_force/test_half.cu new file mode 100644 index 000000000..39b7b7982 --- /dev/null +++ b/cpp/test/neighbors/ann_brute_force/test_half.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../ann_brute_force.cuh" + +#include + +namespace cuvs::neighbors::brute_force { + +using AnnBruteForceTest_half_float = AnnBruteForceTest; +TEST_P(AnnBruteForceTest_half_float, AnnBruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(AnnBruteForceTest, + AnnBruteForceTest_half_float, + ::testing::ValuesIn(inputs)); + +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/test/neighbors/brute_force.cu b/cpp/test/neighbors/brute_force.cu index c97bb5531..f1a05e045 100644 --- a/cpp/test/neighbors/brute_force.cu +++ b/cpp/test/neighbors/brute_force.cu @@ -25,9 +25,13 @@ #include #include +#include + namespace cuvs::neighbors::brute_force { + +template struct KNNInputs { - std::vector> input; + std::vector> input; int k; std::vector labels; }; @@ -53,11 +57,11 @@ RAFT_KERNEL build_expected_output(int* output, int n_rows, int k, const int* lab } } -template -class KNNTest : public ::testing::TestWithParam { +template +class KNNTest : public ::testing::TestWithParam> { public: KNNTest() - : params_(::testing::TestWithParam::GetParam()), + : params_(::testing::TestWithParam>::GetParam()), stream(raft::resource::get_cuda_stream(handle)), actual_labels_(0, stream), expected_labels_(0, stream), @@ -85,7 +89,7 @@ class KNNTest : public ::testing::TestWithParam { auto indices = raft::make_device_matrix_view(indices_.data(), rows_, k_); auto distances = - raft::make_device_matrix_view(distances_.data(), rows_, k_); + raft::make_device_matrix_view(distances_.data(), rows_, k_); auto metric = cuvs::distance::DistanceType::L2Unexpanded; auto idx = cuvs::neighbors::brute_force::build(handle, index, metric); @@ -119,23 +123,22 @@ class KNNTest : public ::testing::TestWithParam { cudaMemsetAsync(actual_labels_.data(), 0, actual_labels_.size() * sizeof(int), stream)); RAFT_CUDA_TRY( cudaMemsetAsync(expected_labels_.data(), 0, expected_labels_.size() * sizeof(int), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(input_.data(), 0, input_.size() * sizeof(float), stream)); - RAFT_CUDA_TRY( - cudaMemsetAsync(search_data_.data(), 0, search_data_.size() * sizeof(float), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(input_.data(), 0, input_.size() * sizeof(T), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(search_data_.data(), 0, search_data_.size() * sizeof(T), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(indices_.data(), 0, indices_.size() * sizeof(IdxT), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(distances_.data(), 0, distances_.size() * sizeof(float), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(distances_.data(), 0, distances_.size() * sizeof(DistT), stream)); RAFT_CUDA_TRY( cudaMemsetAsync(search_labels_.data(), 0, search_labels_.size() * sizeof(int), stream)); - std::vector row_major_input; + std::vector row_major_input; for (std::size_t i = 0; i < params_.input.size(); ++i) { for (std::size_t j = 0; j < params_.input[i].size(); ++j) { row_major_input.push_back(params_.input[i][j]); } } rmm::device_buffer input_d = - rmm::device_buffer(row_major_input.data(), row_major_input.size() * sizeof(float), stream); - float* input_ptr = static_cast(input_d.data()); + rmm::device_buffer(row_major_input.data(), row_major_input.size() * sizeof(T), stream); + T* input_ptr = static_cast(input_d.data()); rmm::device_buffer labels_d = rmm::device_buffer(params_.labels.data(), params_.labels.size() * sizeof(int), stream); @@ -151,13 +154,13 @@ class KNNTest : public ::testing::TestWithParam { raft::resources handle; cudaStream_t stream; - KNNInputs params_; + KNNInputs params_; int rows_; int cols_; - rmm::device_uvector input_; - rmm::device_uvector search_data_; + rmm::device_uvector input_; + rmm::device_uvector search_data_; rmm::device_uvector indices_; - rmm::device_uvector distances_; + rmm::device_uvector distances_; int k_; rmm::device_uvector search_labels_; @@ -165,7 +168,8 @@ class KNNTest : public ::testing::TestWithParam { rmm::device_uvector expected_labels_; }; -const std::vector inputs = { +template +const std::vector> inputs = { // 2D {{ {2.7810836, 2.550537003}, @@ -182,10 +186,14 @@ const std::vector inputs = { 2, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}}}; -typedef KNNTest KNNTestFint64_t; -TEST_P(KNNTestFint64_t, BruteForce) { this->testBruteForce(); } +typedef KNNTest KNNTest_float_int64_t; +TEST_P(KNNTest_float_int64_t, BruteForce) { this->testBruteForce(); } + +typedef KNNTest KNNTest_half_int64_t; +TEST_P(KNNTest_half_int64_t, BruteForce) { this->testBruteForce(); } -INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint64_t, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(KNNTest, KNNTest_float_int64_t, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(KNNTest, KNNTest_half_int64_t, ::testing::ValuesIn(inputs)); // Also test with larger random inputs, including col-major inputs struct RandomKNNInputs { @@ -205,7 +213,7 @@ std::ostream& operator<<(std::ostream& os, const RandomKNNInputs& input) << " row_major:" << input.row_major; } -template +template class RandomBruteForceKNNTest : public ::testing::TestWithParam { public: RandomBruteForceKNNTest() @@ -229,67 +237,153 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam raft::matrix::fill( handle_, raft::make_device_matrix_view(cuvs_distances_.data(), params_.num_queries, params_.k), - T{0.0}); + DistT{0.0}); raft::matrix::fill( handle_, raft::make_device_matrix_view(ref_distances_.data(), params_.num_queries, params_.k), - T{0.0}); + DistT{0.0}); } protected: + void cpu_distance(const T* d_A, + const T* d_B, + DistT* d_vals, + bool is_row_major_A, + bool is_row_major_B, + bool is_row_major_C, + cudaStream_t stream, + DistT alpha = 1.0, + DistT beta = 0.0) + { + size_t size_A = params_.num_queries * params_.dim * sizeof(T); + size_t size_B = params_.num_db_vecs * params_.dim * sizeof(T); + size_t size_vals = params_.num_queries * params_.num_db_vecs * sizeof(DistT); + + T* h_A = static_cast(malloc(size_A)); + T* h_B = static_cast(malloc(size_B)); + DistT* h_vals = static_cast(malloc(size_vals)); + + cudaMemcpyAsync(h_A, d_A, size_A, cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(h_B, d_B, size_B, cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(h_vals, d_vals, size_vals, cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + bool trans_a = is_row_major_A; + bool trans_b = is_row_major_B; + bool trans_c = is_row_major_C; + + for (int64_t i = 0; i < params_.num_queries; ++i) { + for (int64_t j = 0; j < params_.num_db_vecs; ++j) { + DistT sum = 0; + DistT norms_A = 0; + DistT norms_B = 0; + + for (int64_t l = 0; l < params_.dim; ++l) { + int64_t a_index = trans_a ? i * params_.dim + l : l * params_.num_queries + i; + int64_t b_index = trans_b ? j * params_.dim + l : l * params_.num_db_vecs + j; + DistT A_v; + DistT B_v; + if constexpr (sizeof(T) == 2) { + A_v = __half2float(h_A[a_index]); + B_v = __half2float(h_B[b_index]); + } else { + A_v = h_A[a_index]; + B_v = h_B[b_index]; + } + + sum += A_v * B_v; + + norms_A += A_v * A_v; + norms_B += B_v * B_v; + } + + int64_t c_index = trans_c ? i * params_.num_db_vecs + j : j * params_.num_queries + i; + + h_vals[c_index] = alpha * sum + beta * h_vals[c_index]; + if (params_.metric == cuvs::distance::DistanceType::L2Expanded) { + h_vals[c_index] = DistT(-2.0) * h_vals[c_index] + norms_A + norms_B; + } else if (params_.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + h_vals[c_index] = std::sqrt(DistT(-2.0) * h_vals[c_index] + norms_A + norms_B); + } else if (params_.metric == cuvs::distance::DistanceType::CosineExpanded) { + h_vals[c_index] = DistT(1.0) - h_vals[c_index] / std::sqrt(norms_A * norms_B); + } + } + } + cudaMemcpyAsync(d_vals, h_vals, size_vals, cudaMemcpyHostToDevice, stream); + cudaStreamSynchronize(stream); + + free(h_A); + free(h_B); + free(h_vals); + } + void testBruteForce() { - float metric_arg = 3.0; + DistT metric_arg = 3.0; // calculate the naive knn, by calculating the full pairwise distances and doing a k-select - rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); + rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); rmm::device_uvector workspace(0, stream_); auto temp_dist = temp_distances.data(); - rmm::device_uvector temp_row_major_dist(num_db_vecs * num_queries, stream_); - - if (params_.row_major) { - distance::pairwise_distance( - handle_, - raft::make_device_matrix_view( - search_queries.data(), params_.num_queries, params_.dim), - raft::make_device_matrix_view( - database.data(), params_.num_db_vecs, params_.dim), - raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), - metric, - metric_arg); - + rmm::device_uvector temp_row_major_dist(num_db_vecs * num_queries, stream_); + + // For the complex post processes in these algorithms, we use CPU logic to make the baseline. + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded) { + cpu_distance(search_queries.data(), + database.data(), + temp_distances.data(), + params_.row_major, + params_.row_major, + true, + stream_); } else { - distance::pairwise_distance(handle_, - raft::make_device_matrix_view( - search_queries.data(), params_.num_queries, params_.dim), - raft::make_device_matrix_view( - database.data(), params_.num_db_vecs, params_.dim), - raft::make_device_matrix_view( - temp_distances.data(), num_queries, num_db_vecs), - metric, - metric_arg); - - // the pairwisse_distance call assumes that the inputs and outputs are all either row-major - // or col-major - meaning we have to transpose the output back for col-major queries - // for comparison - raft::linalg::transpose( - handle_, temp_dist, temp_row_major_dist.data(), num_queries, num_db_vecs, stream_); - temp_dist = temp_row_major_dist.data(); + if (params_.row_major) { + distance::pairwise_distance(handle_, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + raft::make_device_matrix_view( + temp_distances.data(), num_queries, num_db_vecs), + metric, + metric_arg); + + } else { + distance::pairwise_distance( + handle_, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + raft::make_device_matrix_view( + temp_distances.data(), num_queries, num_db_vecs), + metric, + metric_arg); + + // the pairwise_distance call assumes that the inputs and outputs are all either row-major + // or col-major - meaning we have to transpose the output back for col-major queries + // for comparison + raft::linalg::transpose( + handle_, temp_dist, temp_row_major_dist.data(), num_queries, num_db_vecs, stream_); + temp_dist = temp_row_major_dist.data(); + } } cuvs::selection::select_k( handle_, - raft::make_device_matrix_view(temp_dist, num_queries, num_db_vecs), + raft::make_device_matrix_view(temp_dist, num_queries, num_db_vecs), std::nullopt, raft::make_device_matrix_view(ref_distances_.data(), params_.num_queries, params_.k), raft::make_device_matrix_view(ref_indices_.data(), params_.num_queries, params_.k), cuvs::distance::is_min_close(metric), true); - auto indices = raft::make_device_matrix_view( + auto indices = raft::make_device_matrix_view( cuvs_indices_.data(), params_.num_queries, params_.k); - auto distances = raft::make_device_matrix_view( + auto distances = raft::make_device_matrix_view( cuvs_distances_.data(), params_.num_queries, params_.k); if (params_.row_major) { @@ -332,7 +426,7 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam cuvs_distances_.data(), num_queries, k_, - float(0.001), + DistT(0.001), stream_, true)); } @@ -364,16 +458,16 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam rmm::device_uvector database; rmm::device_uvector search_queries; rmm::device_uvector cuvs_indices_; - rmm::device_uvector cuvs_distances_; + rmm::device_uvector cuvs_distances_; rmm::device_uvector ref_indices_; - rmm::device_uvector ref_distances_; + rmm::device_uvector ref_distances_; int k_; cuvs::distance::DistanceType metric; }; const std::vector random_inputs = { // test each distance metric on a small-ish input, with row-major inputs - {256, 512, 16, 8, cuvs::distance::DistanceType::L2Expanded, true}, + {100, 256, 2, 65, cuvs::distance::DistanceType::L2Expanded, true}, {256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, true}, {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true}, {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true}, @@ -387,7 +481,7 @@ const std::vector random_inputs = { {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true}, {256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, true}, // test each distance metric with col-major inputs - {256, 512, 16, 8, cuvs::distance::DistanceType::L2Expanded, false}, + {256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, false}, {256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, false}, {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false}, {256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, false}, @@ -404,17 +498,24 @@ const std::vector random_inputs = { {10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded, false}, {345, 1023, 16, 128, cuvs::distance::DistanceType::CosineExpanded, true}, {789, 20516, 64, 256, cuvs::distance::DistanceType::L2SqrtExpanded, false}, - {1000, 500000, 128, 128, cuvs::distance::DistanceType::L2Expanded, true}, - {1000, 500000, 128, 128, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, false}, {1000, 5000, 128, 128, cuvs::distance::DistanceType::LpUnexpanded, true}, {1000, 5000, 128, 128, cuvs::distance::DistanceType::L2SqrtExpanded, false}, {1000, 5000, 128, 128, cuvs::distance::DistanceType::InnerProduct, false}}; -typedef RandomBruteForceKNNTest RandomBruteForceKNNTestF; +typedef RandomBruteForceKNNTest RandomBruteForceKNNTestF; TEST_P(RandomBruteForceKNNTestF, BruteForce) { this->testBruteForce(); } +typedef RandomBruteForceKNNTest RandomBruteForceKNNTestH; +TEST_P(RandomBruteForceKNNTestH, BruteForce) { this->testBruteForce(); } + INSTANTIATE_TEST_CASE_P(RandomBruteForceKNNTest, RandomBruteForceKNNTestF, ::testing::ValuesIn(random_inputs)); +INSTANTIATE_TEST_CASE_P(RandomBruteForceKNNTest, + RandomBruteForceKNNTestH, + ::testing::ValuesIn(random_inputs)); + } // namespace cuvs::neighbors::brute_force diff --git a/cpp/test/neighbors/brute_force_prefiltered.cu b/cpp/test/neighbors/brute_force_prefiltered.cu index 2b8ae9d9a..9304ee045 100644 --- a/cpp/test/neighbors/brute_force_prefiltered.cu +++ b/cpp/test/neighbors/brute_force_prefiltered.cu @@ -30,7 +30,11 @@ #include +#include +#include #include +#include +#include #include #include @@ -94,6 +98,10 @@ RAFT_KERNEL normalize_kernel( } } +struct float_to_half { + __host__ __device__ __half operator()(const float x) const { return __float2half(x); } +}; + template void normalize(OutT* theta, const InT* in_vals, @@ -137,7 +145,8 @@ void set_bitmap(const index_t* src, <<>>(src, dst, bitmap, n_edges, n_cols); RAFT_CUDA_TRY(cudaGetLastError()); } -template + +template class PrefilteredBruteForceTest : public ::testing::TestWithParam> { public: @@ -157,22 +166,22 @@ class PrefilteredBruteForceTest protected: index_t create_sparse_matrix_with_rmat(index_t m, index_t n, - value_t sparsity, + float sparsity, rmm::device_uvector& filter_d) { index_t r_scale = (index_t)std::log2(m); index_t c_scale = (index_t)std::log2(n); - index_t n_edges = (index_t)(m * n * 1.0 * sparsity); + index_t n_edges = (index_t)(m * n * 1.0f * sparsity); index_t max_scale = std::max(r_scale, c_scale); rmm::device_uvector out_src{(unsigned long)n_edges, stream}; rmm::device_uvector out_dst{(unsigned long)n_edges, stream}; - rmm::device_uvector theta{(unsigned long)(4 * max_scale), stream}; + rmm::device_uvector theta{(unsigned long)(4 * max_scale), stream}; raft::random::RngState state{2024ULL, raft::random::GeneratorType::GenPC}; - raft::random::uniform(handle, state, theta.data(), theta.size(), 0.0f, 1.0f); - normalize( + raft::random::uniform(handle, state, theta.data(), theta.size(), 0.0f, 1.0f); + normalize( theta.data(), theta.data(), max_scale, r_scale, c_scale, r_scale != c_scale, true, stream); raft::random::rmat_rectangular_gen((index_t*)nullptr, out_src.data(), @@ -236,15 +245,15 @@ class PrefilteredBruteForceTest } } - void cpu_sddmm(const std::vector& A, - const std::vector& B, - std::vector& vals, + void cpu_sddmm(const std::vector& A, + const std::vector& B, + std::vector& vals, const std::vector& cols, const std::vector& row_ptrs, bool is_row_major_A, bool is_row_major_B, - value_t alpha = 1.0, - value_t beta = 0.0) + dist_t alpha = 1.0, + dist_t beta = 0.0) { if (params.n_queries * params.dim != static_cast(A.size()) || params.dim * params.n_dataset != static_cast(B.size())) { @@ -257,24 +266,35 @@ class PrefilteredBruteForceTest for (index_t i = 0; i < params.n_queries; ++i) { for (index_t j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j) { - value_t sum = 0; - value_t norms_A = 0; - value_t norms_B = 0; + dist_t sum = 0; + dist_t norms_A = 0; + dist_t norms_B = 0; + for (index_t l = 0; l < params.dim; ++l) { index_t a_index = trans_a ? i * params.dim + l : l * params.n_queries + i; index_t b_index = trans_b ? l * params.n_dataset + cols[j] : cols[j] * params.dim + l; - sum += A[a_index] * B[b_index]; - - norms_A += A[a_index] * A[a_index]; - norms_B += B[b_index] * B[b_index]; + dist_t A_v; + dist_t B_v; + if constexpr (sizeof(value_t) == 2) { + A_v = __half2float(__float2half(A[a_index])); + B_v = __half2float(__float2half(B[b_index])); + } else { + A_v = A[a_index]; + B_v = B[b_index]; + } + + sum += A_v * B_v; + + norms_A += A_v * A_v; + norms_B += B_v * B_v; } vals[j] = alpha * sum + beta * vals[j]; if (params.metric == cuvs::distance::DistanceType::L2Expanded) { - vals[j] = value_t(-2.0) * vals[j] + norms_A + norms_B; + vals[j] = dist_t(-2.0) * vals[j] + norms_A + norms_B; } else if (params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - vals[j] = std::sqrt(value_t(-2.0) * vals[j] + norms_A + norms_B); + vals[j] = std::sqrt(dist_t(-2.0) * vals[j] + norms_A + norms_B); } else if (params.metric == cuvs::distance::DistanceType::CosineExpanded) { - vals[j] = value_t(1.0) - vals[j] / std::sqrt(norms_A * norms_B); + vals[j] = dist_t(1.0) - vals[j] / std::sqrt(norms_A * norms_B); } } } @@ -282,32 +302,31 @@ class PrefilteredBruteForceTest void cpu_select_k(const std::vector& indptr_h, const std::vector& indices_h, - const std::vector& values_h, + const std::vector& values_h, std::optional>& in_idx_h, index_t n_queries, index_t n_dataset, index_t top_k, - std::vector& out_values_h, + std::vector& out_values_h, std::vector& out_indices_h, bool select_min = true) { - auto comp = [select_min](const std::pair& a, - const std::pair& b) { + auto comp = [select_min](const std::pair& a, + const std::pair& b) { return select_min ? a.first < b.first : a.first >= b.first; }; for (index_t row = 0; row < n_queries; ++row) { - std::priority_queue, - std::vector>, + std::priority_queue, + std::vector>, decltype(comp)> pq(comp); - for (index_t idx = indptr_h[row]; idx < indptr_h[row + 1]; ++idx) { pq.push({values_h[idx], (in_idx_h.has_value()) ? (*in_idx_h)[idx] : indices_h[idx]}); if (pq.size() > size_t(top_k)) { pq.pop(); } } - std::vector> row_pairs; + std::vector> row_pairs; while (!pq.empty()) { row_pairs.push_back(pq.top()); pq.pop(); @@ -347,40 +366,80 @@ class PrefilteredBruteForceTest index_t dataset_size = params.n_dataset * params.dim; index_t queries_size = params.n_queries * params.dim; - std::vector dataset_h(dataset_size); - std::vector queries_h(queries_size); + std::vector dataset_h(dataset_size); + std::vector queries_h(queries_size); dataset_d.resize(dataset_size, stream); queries_d.resize(queries_size, stream); auto blobs_in_val = - raft::make_device_matrix(handle, 1, dataset_size + queries_size); + raft::make_device_matrix(handle, 1, dataset_size + queries_size); auto labels = raft::make_device_vector(handle, 1); - raft::random::make_blobs(blobs_in_val.data_handle(), - labels.data_handle(), - 1, - dataset_size + queries_size, - 1, - stream, - false, - nullptr, - nullptr, - value_t(1.0), - false, - value_t(-1.0f), - value_t(1.0f), - uint64_t(2024)); + if constexpr (!std::is_same_v) { + raft::random::make_blobs(blobs_in_val.data_handle(), + labels.data_handle(), + 1, + dataset_size + queries_size, + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-1.0f), + value_t(1.0f), + uint64_t(2024)); + } else { + raft::random::make_blobs(blobs_in_val.data_handle(), + labels.data_handle(), + 1, + dataset_size + queries_size, + 1, + stream, + false, + nullptr, + nullptr, + dist_t(1.0), + false, + dist_t(-1.0f), + dist_t(1.0f), + uint64_t(2024)); + } raft::copy(dataset_h.data(), blobs_in_val.data_handle(), dataset_size, stream); - raft::copy(dataset_d.data(), blobs_in_val.data_handle(), dataset_size, stream); + + if constexpr (std::is_same_v) { + thrust::device_ptr d_output_ptr = + thrust::device_pointer_cast(blobs_in_val.data_handle()); + thrust::device_ptr d_value_ptr = thrust::device_pointer_cast(dataset_d.data()); + thrust::transform(thrust::cuda::par.on(stream), + d_output_ptr, + d_output_ptr + dataset_size, + d_value_ptr, + float_to_half()); + } else { + raft::copy(dataset_d.data(), blobs_in_val.data_handle(), dataset_size, stream); + } raft::copy(queries_h.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); - raft::copy(queries_d.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); + if constexpr (std::is_same_v) { + thrust::device_ptr d_output_ptr = + thrust::device_pointer_cast(blobs_in_val.data_handle() + dataset_size); + thrust::device_ptr d_value_ptr = thrust::device_pointer_cast(queries_d.data()); + thrust::transform(thrust::cuda::par.on(stream), + d_output_ptr, + d_output_ptr + queries_size, + d_value_ptr, + float_to_half()); + } else { + raft::copy(queries_d.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); + } raft::resource::sync_stream(handle); - std::vector values_h(nnz); + std::vector values_h(nnz); std::vector indices_h(nnz); std::vector indptr_h(params.n_queries + 1); @@ -390,9 +449,9 @@ class PrefilteredBruteForceTest bool select_min = cuvs::distance::is_min_close(params.metric); - std::vector out_val_h(params.n_queries * params.top_k, - select_min ? std::numeric_limits::infinity() - : std::numeric_limits::lowest()); + std::vector out_val_h( + params.n_queries * params.top_k, + select_min ? std::numeric_limits::infinity() : std::numeric_limits::lowest()); std::vector out_idx_h(params.n_queries * params.top_k, static_cast(0)); out_val_d.resize(params.n_queries * params.top_k, stream); @@ -404,7 +463,6 @@ class PrefilteredBruteForceTest raft::resource::sync_stream(handle); std::optional> optional_indices_h = std::nullopt; - cpu_select_k(indptr_h, indices_h, values_h, @@ -415,10 +473,11 @@ class PrefilteredBruteForceTest out_val_h, out_idx_h, select_min); - out_val_expected_d.resize(params.n_queries * params.top_k, stream); out_idx_expected_d.resize(params.n_queries * params.top_k, stream); + // dump_vector(out_val_h.data(), out_val_h.size(), "out_val_h"); + raft::update_device(out_val_expected_d.data(), out_val_h.data(), out_val_h.size(), stream); raft::update_device(out_idx_expected_d.data(), out_idx_h.data(), out_idx_h.size(), stream); @@ -438,12 +497,17 @@ class PrefilteredBruteForceTest auto filter = cuvs::core::bitmap_view( (const bitmap_t*)filter_d.data(), params.n_queries, params.n_dataset); - auto out_val = raft::make_device_matrix_view( + auto out_val = raft::make_device_matrix_view( out_val_d.data(), params.n_queries, params.top_k); auto out_idx = raft::make_device_matrix_view( out_idx_d.data(), params.n_queries, params.top_k); brute_force::search(handle, dataset, queries, out_idx, out_val, std::make_optional(filter)); + std::vector out_val_h(params.n_queries * params.top_k, + std::numeric_limits::infinity()); + + raft::update_host(out_val_h.data(), out_val_d.data(), out_val_h.size(), stream); + raft::resource::sync_stream(handle); ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(out_idx_expected_d.data(), out_idx.data_handle(), @@ -468,49 +532,61 @@ class PrefilteredBruteForceTest rmm::device_uvector queries_d; rmm::device_uvector filter_d; - rmm::device_uvector out_val_d; - rmm::device_uvector out_val_expected_d; + rmm::device_uvector out_val_d; + rmm::device_uvector out_val_expected_d; rmm::device_uvector out_idx_d; rmm::device_uvector out_idx_expected_d; }; -using PrefilteredBruteForceTest_float_int64 = PrefilteredBruteForceTest; +using PrefilteredBruteForceTest_float_int64 = PrefilteredBruteForceTest; TEST_P(PrefilteredBruteForceTest_float_int64, Result) { Run(); } +using PrefilteredBruteForceTest_half_int64 = PrefilteredBruteForceTest; +TEST_P(PrefilteredBruteForceTest_half_int64, Result) { Run(); } + template const std::vector> selectk_inputs = { + {8, 131072, 255, 255, 0.01, cuvs::distance::DistanceType::L2Expanded}, + {8, 131072, 255, 255, 0.01, cuvs::distance::DistanceType::InnerProduct}, + {8, 131072, 255, 255, 0.01, cuvs::distance::DistanceType::L2SqrtExpanded}, + {8, 131072, 255, 255, 0.01, cuvs::distance::DistanceType::CosineExpanded}, {2, 131072, 255, 255, 0.4, cuvs::distance::DistanceType::L2Expanded}, + {8, 131072, 512, 16, 0.5, cuvs::distance::DistanceType::L2Expanded}, {16, 131072, 2052, 16, 0.2, cuvs::distance::DistanceType::L2Expanded}, + {2, 8192, 255, 16, 0.01, cuvs::distance::DistanceType::InnerProduct}, {2, 8192, 255, 16, 0.4, cuvs::distance::DistanceType::InnerProduct}, {16, 8192, 512, 16, 0.5, cuvs::distance::DistanceType::InnerProduct}, + {128, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::InnerProduct}, {1024, 8192, 1, 0, 0.1, cuvs::distance::DistanceType::L2Expanded}, {1024, 8192, 3, 0, 0.1, cuvs::distance::DistanceType::InnerProduct}, {1024, 8192, 5, 0, 0.1, cuvs::distance::DistanceType::L2SqrtExpanded}, {1024, 8192, 8, 0, 0.1, cuvs::distance::DistanceType::CosineExpanded}, - {1024, 8192, 1, 1, 0.1, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 1, 1, 0.1, cuvs::distance::DistanceType::L2Expanded}, //-- {1024, 8192, 3, 1, 0.1, cuvs::distance::DistanceType::InnerProduct}, {1024, 8192, 5, 1, 0.1, cuvs::distance::DistanceType::L2SqrtExpanded}, {1024, 8192, 8, 1, 0.1, cuvs::distance::DistanceType::CosineExpanded}, - {1024, 8192, 2050, 16, 0.4, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 2051, 16, 0.5, cuvs::distance::DistanceType::L2Expanded}, {1024, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::L2Expanded}, {1024, 8192, 2050, 16, 0.4, cuvs::distance::DistanceType::InnerProduct}, {1024, 8192, 2051, 16, 0.5, cuvs::distance::DistanceType::InnerProduct}, {1024, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::InnerProduct}, + {1024, 8192, 2050, 16, 0.4, cuvs::distance::DistanceType::L2SqrtExpanded}, {1024, 8192, 2051, 16, 0.5, cuvs::distance::DistanceType::L2SqrtExpanded}, {1024, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::L2SqrtExpanded}, {1024, 8192, 2050, 16, 0.4, cuvs::distance::DistanceType::CosineExpanded}, {1024, 8192, 2051, 16, 0.5, cuvs::distance::DistanceType::CosineExpanded}, - {1024, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::CosineExpanded}, + {1024, 8192, 2052, 16, 0.2, cuvs::distance::DistanceType::CosineExpanded}, {1024, 8192, 1, 16, 0.5, cuvs::distance::DistanceType::L2Expanded}, {1024, 8192, 2, 16, 0.2, cuvs::distance::DistanceType::L2Expanded}, + {1024, 8192, 3, 16, 0.4, cuvs::distance::DistanceType::InnerProduct}, {1024, 8192, 4, 16, 0.5, cuvs::distance::DistanceType::InnerProduct}, {1024, 8192, 5, 16, 0.2, cuvs::distance::DistanceType::L2SqrtExpanded}, @@ -522,4 +598,8 @@ INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceTest, PrefilteredBruteForceTest_float_int64, ::testing::ValuesIn(selectk_inputs)); +INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceTest, + PrefilteredBruteForceTest_half_int64, + ::testing::ValuesIn(selectk_inputs)); + } // namespace cuvs::neighbors::brute_force