From 45703bfc5b3a78740ce8bc67f3e89acb95f08ffc Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 4 Feb 2025 16:14:02 -0800 Subject: [PATCH] Expose NN-Descent to C and Python (#635) Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/635 --- cpp/CMakeLists.txt | 1 + cpp/include/cuvs/neighbors/nn_descent.h | 181 +++++++++++++ cpp/src/neighbors/nn_descent_c.cpp | 167 ++++++++++++ python/cuvs/cuvs/distance/__init__.py | 4 +- python/cuvs/cuvs/distance/distance.pyx | 2 + python/cuvs/cuvs/neighbors/CMakeLists.txt | 1 + python/cuvs/cuvs/neighbors/__init__.py | 10 +- .../cuvs/neighbors/nn_descent/CMakeLists.txt | 28 ++ .../cuvs/neighbors/nn_descent/__init__.pxd | 0 .../cuvs/neighbors/nn_descent/__init__.py | 22 ++ .../cuvs/neighbors/nn_descent/nn_descent.pxd | 63 +++++ .../cuvs/neighbors/nn_descent/nn_descent.pyx | 239 ++++++++++++++++++ python/cuvs/cuvs/tests/test_nn_descent.py | 53 ++++ 13 files changed, 768 insertions(+), 3 deletions(-) create mode 100644 cpp/include/cuvs/neighbors/nn_descent.h create mode 100644 cpp/src/neighbors/nn_descent_c.cpp create mode 100644 python/cuvs/cuvs/neighbors/nn_descent/CMakeLists.txt create mode 100644 python/cuvs/cuvs/neighbors/nn_descent/__init__.pxd create mode 100644 python/cuvs/cuvs/neighbors/nn_descent/__init__.py create mode 100644 python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pxd create mode 100644 python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx create mode 100644 python/cuvs/cuvs/tests/test_nn_descent.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 65b1471f5..fb33c1ab7 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -691,6 +691,7 @@ target_compile_definitions(cuvs::cuvs INTERFACE $<$:NVTX_ENAB src/neighbors/ivf_pq_c.cpp src/neighbors/cagra_c.cpp $<$:src/neighbors/hnsw_c.cpp> + src/neighbors/nn_descent_c.cpp src/neighbors/refine/refine_c.cpp src/preprocessing/quantize/scalar_c.cpp src/distance/pairwise_distance_c.cpp diff --git a/cpp/include/cuvs/neighbors/nn_descent.h b/cpp/include/cuvs/neighbors/nn_descent.h new file mode 100644 index 000000000..81c162598 --- /dev/null +++ b/cpp/include/cuvs/neighbors/nn_descent.h @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @defgroup nn_descent_c_index_params The nn-descent algorithm parameters. + * @{ + */ +/** + * @brief Parameters used to build an nn-descent index + * + * `metric`: The distance metric to use + * `metric_arg`: The argument used by distance metrics like Minkowskidistance + * `graph_degree`: For an input dataset of dimensions (N, D), + * determines the final dimensions of the all-neighbors knn graph + * which turns out to be of dimensions (N, graph_degree) + * `intermediate_graph_degree`: Internally, nn-descent builds an + * all-neighbors knn graph of dimensions (N, intermediate_graph_degree) + * before selecting the final `graph_degree` neighbors. It's recommended + * that `intermediate_graph_degree` >= 1.5 * graph_degree + * `max_iterations`: The number of iterations that nn-descent will refine + * the graph for. More iterations produce a better quality graph at cost of performance + * `termination_threshold`: The delta at which nn-descent will terminate its iterations + */ +struct cuvsNNDescentIndexParams { + cuvsDistanceType metric; + float metric_arg; + size_t graph_degree; + size_t intermediate_graph_degree; + size_t max_iterations; + float termination_threshold; + bool return_distances; + size_t n_clusters; +}; + +typedef struct cuvsNNDescentIndexParams* cuvsNNDescentIndexParams_t; + +/** + * @brief Allocate NN-Descent Index params, and populate with default values + * + * @param[in] index_params cuvsNNDescentIndexParams_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsNNDescentIndexParamsCreate(cuvsNNDescentIndexParams_t* index_params); + +/** + * @brief De-allocate NN-Descent Index params + * + * @param[in] index_params + * @return cuvsError_t + */ +cuvsError_t cuvsNNDescentIndexParamsDestroy(cuvsNNDescentIndexParams_t index_params); +/** + * @} + */ + +/** + * @defgroup nn_descent_c_index NN-Descent index + * @{ + */ +/** + * @brief Struct to hold address of cuvs::neighbors::nn_descent::index and its active trained dtype + * + */ +typedef struct { + uintptr_t addr; + DLDataType dtype; +} cuvsNNDescentIndex; + +typedef cuvsNNDescentIndex* cuvsNNDescentIndex_t; + +/** + * @brief Allocate NN-Descent index + * + * @param[in] index cuvsNNDescentIndex_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsNNDescentIndexCreate(cuvsNNDescentIndex_t* index); + +/** + * @brief De-allocate NN-Descent index + * + * @param[in] index cuvsNNDescentIndex_t to de-allocate + */ +cuvsError_t cuvsNNDescentIndexDestroy(cuvsNNDescentIndex_t index); +/** + * @} + */ + +/** + * @defgroup nn_descent_c_index_build NN-Descent index build + * @{ + */ +/** + * @brief Build a NN-Descent index with a `DLManagedTensor` which has underlying + * `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`, + * or `kDLCPU`. Also, acceptable underlying types are: + * 1. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * 2. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 16` + * 3. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8` + * 4. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8` + * + * @code {.c} + * #include + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // Assume a populated `DLManagedTensor` type here + * DLManagedTensor dataset; + * + * // Create default index params + * cuvsNNDescentIndexParams_t index_params; + * cuvsError_t params_create_status = cuvsNNDescentIndexParamsCreate(&index_params); + * + * // Create NN-Descent index + * cuvsNNDescentIndex_t index; + * cuvsError_t index_create_status = cuvsNNDescentIndexCreate(&index); + * + * // Build the NN-Descent Index + * cuvsError_t build_status = cuvsNNDescentBuild(res, index_params, &dataset, index); + * + * // de-allocate `index_params`, `index` and `res` + * cuvsError_t params_destroy_status = cuvsNNDescentIndexParamsDestroy(index_params); + * cuvsError_t index_destroy_status = cuvsNNDescentIndexDestroy(index); + * cuvsError_t res_destroy_status = cuvsResourcesDestroy(res); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] index_params cuvsNNDescentIndexParams_t used to build NN-Descent index + * @param[in] dataset DLManagedTensor* training dataset on host or device memory + * @param[inout] graph Optional preallocated graph on host memory to store output + * @param[out] index cuvsNNDescentIndex_t Newly built NN-Descent index + * @return cuvsError_t + */ +cuvsError_t cuvsNNDescentBuild(cuvsResources_t res, + cuvsNNDescentIndexParams_t index_params, + DLManagedTensor* dataset, + DLManagedTensor* graph, + cuvsNNDescentIndex_t index); +/** + * @} + */ + +/** + * @brief Get the KNN graph from a built NN-Descent index + * + * @param[in] index cuvsNNDescentIndex_t Built NN-Descent index + * @param[inout] graph Optional preallocated graph on host memory to store output + * @return cuvsError_t + */ +cuvsError_t cuvsNNDescentIndexGetGraph(cuvsNNDescentIndex_t index, DLManagedTensor* graph); +#ifdef __cplusplus +} +#endif diff --git a/cpp/src/neighbors/nn_descent_c.cpp b/cpp/src/neighbors/nn_descent_c.cpp new file mode 100644 index 000000000..727b332cb --- /dev/null +++ b/cpp/src/neighbors/nn_descent_c.cpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace { + +template +void* _build(cuvsResources_t res, + cuvsNNDescentIndexParams params, + DLManagedTensor* dataset_tensor, + DLManagedTensor* graph_tensor) +{ + auto res_ptr = reinterpret_cast(res); + auto dataset = dataset_tensor->dl_tensor; + + auto build_params = cuvs::neighbors::nn_descent::index_params(); + build_params.metric = static_cast((int)params.metric), + build_params.metric_arg = params.metric_arg; + build_params.graph_degree = params.graph_degree; + build_params.intermediate_graph_degree = params.intermediate_graph_degree; + build_params.max_iterations = params.max_iterations; + build_params.termination_threshold = params.termination_threshold; + build_params.return_distances = params.return_distances; + build_params.n_clusters = params.n_clusters; + + using graph_type = raft::host_matrix_view; + std::optional graph; + if (graph_tensor != NULL) { graph = cuvs::core::from_dlpack(graph_tensor); } + + if (cuvs::core::is_dlpack_device_compatible(dataset)) { + using dataset_type = raft::device_matrix_view; + auto dataset = cuvs::core::from_dlpack(dataset_tensor); + auto index = cuvs::neighbors::nn_descent::build(*res_ptr, build_params, dataset, graph); + return new cuvs::neighbors::nn_descent::index(std::move(index)); + } else if (cuvs::core::is_dlpack_host_compatible(dataset)) { + using dataset_type = raft::host_matrix_view; + auto dataset = cuvs::core::from_dlpack(dataset_tensor); + auto index = cuvs::neighbors::nn_descent::build(*res_ptr, build_params, dataset, graph); + return new cuvs::neighbors::nn_descent::index(std::move(index)); + } else { + RAFT_FAIL("dataset must be accessible on host or device memory"); + } +} +} // namespace + +extern "C" cuvsError_t cuvsNNDescentIndexCreate(cuvsNNDescentIndex_t* index) +{ + return cuvs::core::translate_exceptions([=] { *index = new cuvsNNDescentIndex{}; }); +} + +extern "C" cuvsError_t cuvsNNDescentIndexDestroy(cuvsNNDescentIndex_t index_c_ptr) +{ + return cuvs::core::translate_exceptions([=] { + auto index = *index_c_ptr; + if ((index.dtype.code == kDLUInt) && (index.dtype.bits == 32)) { + auto index_ptr = reinterpret_cast*>(index.addr); + delete index_ptr; + } else { + RAFT_FAIL( + "Unsupported nn-descent index dtype: %d and bits: %d", index.dtype.code, index.dtype.bits); + } + delete index_c_ptr; + }); +} + +extern "C" cuvsError_t cuvsNNDescentBuild(cuvsResources_t res, + cuvsNNDescentIndexParams_t params, + DLManagedTensor* dataset_tensor, + DLManagedTensor* graph_tensor, + cuvsNNDescentIndex_t index) +{ + return cuvs::core::translate_exceptions([=] { + index->dtype.code = kDLUInt; + index->dtype.bits = 32; + + auto dtype = dataset_tensor->dl_tensor.dtype; + + if ((dtype.code == kDLFloat) && (dtype.bits == 32)) { + index->addr = reinterpret_cast( + _build(res, *params, dataset_tensor, graph_tensor)); + } else if ((dtype.code == kDLFloat) && (dtype.bits == 16)) { + index->addr = reinterpret_cast( + _build(res, *params, dataset_tensor, graph_tensor)); + } else if ((dtype.code == kDLInt) && (dtype.bits == 8)) { + index->addr = reinterpret_cast( + _build(res, *params, dataset_tensor, graph_tensor)); + } else if ((dtype.code == kDLUInt) && (dtype.bits == 8)) { + index->addr = reinterpret_cast( + _build(res, *params, dataset_tensor, graph_tensor)); + } else { + RAFT_FAIL("Unsupported nn-descent dataset dtype: %d and bits: %d", dtype.code, dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsNNDescentIndexParamsCreate(cuvsNNDescentIndexParams_t* params) +{ + return cuvs::core::translate_exceptions([=] { + // get defaults from cpp parameters struct + cuvs::neighbors::nn_descent::index_params cpp_params; + + *params = new cuvsNNDescentIndexParams{ + .metric = cpp_params.metric, + .metric_arg = cpp_params.metric_arg, + .graph_degree = cpp_params.graph_degree, + .intermediate_graph_degree = cpp_params.intermediate_graph_degree, + .max_iterations = cpp_params.max_iterations, + .termination_threshold = cpp_params.termination_threshold, + .return_distances = cpp_params.return_distances, + .n_clusters = cpp_params.n_clusters}; + }); +} + +extern "C" cuvsError_t cuvsNNDescentIndexParamsDestroy(cuvsNNDescentIndexParams_t params) +{ + return cuvs::core::translate_exceptions([=] { delete params; }); +} + +extern "C" cuvsError_t cuvsNNDescentIndexGetGraph(cuvsNNDescentIndex_t index, + DLManagedTensor* graph) +{ + return cuvs::core::translate_exceptions([=] { + auto dtype = index->dtype; + if ((dtype.code == kDLUInt) && (dtype.bits == 32)) { + auto index_ptr = reinterpret_cast*>(index->addr); + using output_mdspan_type = raft::host_matrix_view; + auto dst = cuvs::core::from_dlpack(graph); + auto src = index_ptr->graph(); + + RAFT_EXPECTS(src.extent(0) == dst.extent(0), "Output graph has incorrect number of rows"); + RAFT_EXPECTS(src.extent(1) == dst.extent(1), "Output graph has incorrect number of cols"); + std::copy(src.data_handle(), src.data_handle() + dst.size(), dst.data_handle()); + } else { + RAFT_FAIL("Unsupported nn-descent index dtype: %d and bits: %d", dtype.code, dtype.bits); + } + }); +} diff --git a/python/cuvs/cuvs/distance/__init__.py b/python/cuvs/cuvs/distance/__init__.py index 5c985e7b1..024f2f2c0 100644 --- a/python/cuvs/cuvs/distance/__init__.py +++ b/python/cuvs/cuvs/distance/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .distance import DISTANCE_TYPES, pairwise_distance +from .distance import DISTANCE_NAMES, DISTANCE_TYPES, pairwise_distance -__all__ = ["DISTANCE_TYPES", "pairwise_distance"] +__all__ = ["DISTANCE_NAMES", "DISTANCE_TYPES", "pairwise_distance"] diff --git a/python/cuvs/cuvs/distance/distance.pyx b/python/cuvs/cuvs/distance/distance.pyx index d50fc152f..6b80d43b2 100644 --- a/python/cuvs/cuvs/distance/distance.pyx +++ b/python/cuvs/cuvs/distance/distance.pyx @@ -48,6 +48,8 @@ DISTANCE_TYPES = { "dice": cuvsDistanceType.DiceExpanded, } +DISTANCE_NAMES = {v: k for k, v in DISTANCE_TYPES.items()} + SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product", "chebyshev", "minkowski", "canberra", "kl_divergence", "correlation", "russellrao", "hellinger", "lp", diff --git a/python/cuvs/cuvs/neighbors/CMakeLists.txt b/python/cuvs/cuvs/neighbors/CMakeLists.txt index b9161eefc..3ac426f2b 100644 --- a/python/cuvs/cuvs/neighbors/CMakeLists.txt +++ b/python/cuvs/cuvs/neighbors/CMakeLists.txt @@ -18,6 +18,7 @@ add_subdirectory(hnsw) add_subdirectory(ivf_flat) add_subdirectory(ivf_pq) add_subdirectory(filters) +add_subdirectory(nn_descent) # Set the list of Cython files to build set(cython_sources refine.pyx) diff --git a/python/cuvs/cuvs/neighbors/__init__.py b/python/cuvs/cuvs/neighbors/__init__.py index 52bb1eef8..1ba7f79ec 100644 --- a/python/cuvs/cuvs/neighbors/__init__.py +++ b/python/cuvs/cuvs/neighbors/__init__.py @@ -13,7 +13,14 @@ # limitations under the License. -from cuvs.neighbors import brute_force, cagra, filters, ivf_flat, ivf_pq +from cuvs.neighbors import ( + brute_force, + cagra, + filters, + ivf_flat, + ivf_pq, + nn_descent, +) from .refine import refine @@ -23,5 +30,6 @@ "filters", "ivf_flat", "ivf_pq", + "nn_descent", "refine", ] diff --git a/python/cuvs/cuvs/neighbors/nn_descent/CMakeLists.txt b/python/cuvs/cuvs/neighbors/nn_descent/CMakeLists.txt new file mode 100644 index 000000000..20f37bc23 --- /dev/null +++ b/python/cuvs/cuvs/neighbors/nn_descent/CMakeLists.txt @@ -0,0 +1,28 @@ +# ============================================================================= +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +# Set the list of Cython files to build +set(cython_sources nn_descent.pyx) +set(linked_libraries cuvs::cuvs cuvs::c_api) + +# Build all of the Cython targets +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" MODULE_PREFIX neighbors_nn_descent_ +) + +foreach(tgt IN LISTS RAPIDS_CYTHON_CREATED_TARGETS) + target_link_libraries(${tgt} PRIVATE cuvs_rmm_logger) +endforeach() diff --git a/python/cuvs/cuvs/neighbors/nn_descent/__init__.pxd b/python/cuvs/cuvs/neighbors/nn_descent/__init__.pxd new file mode 100644 index 000000000..e69de29bb diff --git a/python/cuvs/cuvs/neighbors/nn_descent/__init__.py b/python/cuvs/cuvs/neighbors/nn_descent/__init__.py new file mode 100644 index 000000000..312bc735d --- /dev/null +++ b/python/cuvs/cuvs/neighbors/nn_descent/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .nn_descent import Index, IndexParams, build + +__all__ = [ + "Index", + "IndexParams", + "build", +] diff --git a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pxd b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pxd new file mode 100644 index 000000000..6d0edb1f3 --- /dev/null +++ b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pxd @@ -0,0 +1,63 @@ +# +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: language_level=3 + +from libc.stdint cimport uint32_t, uintptr_t +from libcpp cimport bool + +from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t +from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor +from cuvs.distance_type cimport cuvsDistanceType + + +cdef extern from "cuvs/neighbors/nn_descent.h" nogil: + + ctypedef struct cuvsNNDescentIndexParams: + cuvsDistanceType metric + float metric_arg + size_t graph_degree + size_t intermediate_graph_degree + size_t max_iterations + float termination_threshold + bool return_distances + size_t n_clusters + + ctypedef cuvsNNDescentIndexParams* cuvsNNDescentIndexParams_t + + ctypedef struct cuvsNNDescentIndex: + uintptr_t addr + DLDataType dtype + + ctypedef cuvsNNDescentIndex* cuvsNNDescentIndex_t + + cuvsError_t cuvsNNDescentIndexParamsCreate( + cuvsNNDescentIndexParams_t* params) + + cuvsError_t cuvsNNDescentIndexParamsDestroy( + cuvsNNDescentIndexParams_t index) + + cuvsError_t cuvsNNDescentIndexCreate(cuvsNNDescentIndex_t* index) + + cuvsError_t cuvsNNDescentIndexDestroy(cuvsNNDescentIndex_t index) + + cuvsError_t cuvsNNDescentIndexGetGraph(cuvsNNDescentIndex_t index, + DLManagedTensor * output) + + cuvsError_t cuvsNNDescentBuild(cuvsResources_t res, + cuvsNNDescentIndexParams* params, + DLManagedTensor* dataset, + DLManagedTensor* graph, + cuvsNNDescentIndex_t index) except + diff --git a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx new file mode 100644 index 000000000..9dd14c27e --- /dev/null +++ b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx @@ -0,0 +1,239 @@ +# +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: language_level=3 + +import numpy as np + +cimport cuvs.common.cydlpack + +from cuvs.common.resources import auto_sync_resources + +from cython.operator cimport dereference as deref +from libcpp cimport bool, cast +from libcpp.string cimport string + +from cuvs.common cimport cydlpack +from cuvs.distance_type cimport cuvsDistanceType + +from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray +from pylibraft.common.cai_wrapper import wrap_array +from pylibraft.common.interruptible import cuda_interruptible + +from cuvs.distance import DISTANCE_NAMES, DISTANCE_TYPES +from cuvs.neighbors.common import _check_input_array + +from libc.stdint cimport ( + int8_t, + int64_t, + uint8_t, + uint32_t, + uint64_t, + uintptr_t, +) + +from cuvs.common.exceptions import check_cuvs + + +cdef class IndexParams: + """ + Parameters to build NN-Descent Index + + Parameters + ---------- + metric : str, default = "sqeuclidean" + String denoting the metric type. + distribution of the newly added data. + graph_degree : int + For an input dataset of dimensions (N, D), determines the final + dimensions of the all-neighbors knn graph which turns out to be of + dimensions (N, graph_degree) + intermediate_graph_degree : int + Internally, nn-descent builds an all-neighbors knn graph of dimensions + (N, intermediate_graph_degree) before selecting the final + `graph_degree` neighbors. It's recommended that + `intermediate_graph_degree` >= 1.5 * graph_degree + max_iterations : int + The number of iterations that nn-descent will refine the graph for. + More iterations produce a better quality graph at cost of performance + termination_threshold : float + The delta at which nn-descent will terminate its iterations + """ + + cdef cuvsNNDescentIndexParams* params + cdef object _metric + + def __cinit__(self): + cuvsNNDescentIndexParamsCreate(&self.params) + + def __dealloc__(self): + check_cuvs(cuvsNNDescentIndexParamsDestroy(self.params)) + + def __init__(self, *, + metric=None, + metric_arg=None, + graph_degree=None, + intermediate_graph_degree=None, + max_iterations=None, + termination_threshold=None, + n_clusters=None + ): + if metric is not None: + self.params.metric = DISTANCE_TYPES[metric] + if graph_degree is not None: + self.params.graph_degree = graph_degree + if intermediate_graph_degree is not None: + self.params.intermediate_graph_degree = intermediate_graph_degree + if max_iterations is not None: + self.params.max_iterations = max_iterations + if termination_threshold is not None: + self.params.termination_threshold = termination_threshold + if n_clusters is not None: + self.params.n_clusters = n_clusters + + # setting this parameter to true will cause an exception in the c++ + # api (`Using return_distances set to true requires distance view to + # be allocated.`) - so instead force to be false here + self.params.return_distances = False + + @property + def metric(self): + return DISTANCE_NAMES[self.params.metric] + + @property + def metric_arg(self): + return self.params.metric_arg + + @property + def graph_degree(self): + return self.params.graph_degree + + @property + def intermediate_graph_degree(self): + return self.params.intermediate_graph_degree + + @property + def max_iterations(self): + return self.params.max_iterations + + @property + def termination_threshold(self): + return self.params.termination_threshold + + @property + def n_clusters(self): + return self.params.n_clusters + +cdef class Index: + """ + NN-Descent index object. This object stores the trained NN-Descent index, + which can be used to get the NN-Descent graph and distances after + building + """ + + cdef cuvsNNDescentIndex_t index + cdef bool trained + cdef int64_t num_rows + cdef size_t graph_degree + + def __cinit__(self): + self.trained = False + self.num_rows = 0 + self.graph_degree = 0 + check_cuvs(cuvsNNDescentIndexCreate(&self.index)) + + def __dealloc__(self): + check_cuvs(cuvsNNDescentIndexDestroy(self.index)) + + @property + def trained(self): + return self.trained + + @property + def graph(self): + if not self.trained: + raise ValueError("Index needs to be built before getting graph") + + output = np.empty((self.num_rows, self.graph_degree), dtype='uint32') + ai = wrap_array(output) + cdef cydlpack.DLManagedTensor* output_dlpack = cydlpack.dlpack_c(ai) + check_cuvs(cuvsNNDescentIndexGetGraph(self.index, output_dlpack)) + return output + + def __repr__(self): + return "Index(type=NNDescent)" + + +@auto_sync_resources +def build(IndexParams index_params, dataset, graph=None, resources=None): + """ + Build KNN graph from the dataset + + Parameters + ---------- + index_params : :py:class:`cuvs.neighbors.nn_descent.IndexParams` + dataset : Array interface compliant matrix, on either host or device memory + Supported dtype [float, int8, uint8] + graph : Optional host matrix for storing output graph + {resources_docstring} + + Returns + ------- + index: py:class:`cuvs.neighbors.nn_descent.Index` + + Examples + -------- + + >>> import cupy as cp + >>> from cuvs.neighbors import nn_descent + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> k = 10 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> build_params = nn_descent.IndexParams(metric="sqeuclidean") + >>> index = nn_descent.build(build_params, dataset) + >>> graph = index.graph + """ + dataset_ai = wrap_array(dataset) + _check_input_array(dataset_ai, [np.dtype('float32'), np.dtype('byte'), + np.dtype('ubyte')]) + + cdef Index idx = Index() + cdef cydlpack.DLManagedTensor* dataset_dlpack = \ + cydlpack.dlpack_c(dataset_ai) + cdef cuvsNNDescentIndexParams* params = index_params.params + + cdef cuvsResources_t res = resources.get_c_obj() + + cdef cydlpack.DLManagedTensor* graph_dlpack = NULL + if graph is not None: + graph_ai = wrap_array(graph) + graph_dlpack = cydlpack.dlpack_c(graph_ai) + + with cuda_interruptible(): + check_cuvs(cuvsNNDescentBuild( + res, + params, + dataset_dlpack, + graph_dlpack, + idx.index + )) + idx.trained = True + idx.num_rows = dataset_ai.shape[0] + idx.graph_degree = params.graph_degree + + return idx diff --git a/python/cuvs/cuvs/tests/test_nn_descent.py b/python/cuvs/cuvs/tests/test_nn_descent.py new file mode 100644 index 000000000..e2cc5555f --- /dev/null +++ b/python/cuvs/cuvs/tests/test_nn_descent.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from pylibraft.common import device_ndarray + +from cuvs.neighbors import brute_force, nn_descent +from cuvs.tests.ann_utils import calc_recall + + +@pytest.mark.parametrize("n_rows", [1024, 2048]) +@pytest.mark.parametrize("n_cols", [32, 64]) +@pytest.mark.parametrize("device_memory", [True, False]) +@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize("inplace", [True, False]) +def test_nn_descent(n_rows, n_cols, device_memory, dtype, inplace): + metric = "sqeuclidean" + graph_degree = 64 + + input1 = np.random.random_sample((n_rows, n_cols)).astype(dtype) + input1_device = device_ndarray(input1) + graph = np.zeros((n_rows, graph_degree), dtype="uint32") + + params = nn_descent.IndexParams(metric=metric, graph_degree=graph_degree) + index = nn_descent.build( + params, + input1_device if device_memory else input1, + graph=graph if inplace else None, + ) + + if not inplace: + graph = index.graph + + bfknn_index = brute_force.build(input1_device, metric=metric) + _, bfknn_graph = brute_force.search( + bfknn_index, input1_device, k=graph_degree + ) + bfknn_graph = bfknn_graph.copy_to_host() + + assert calc_recall(graph, bfknn_graph) > 0.9