-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
1,287 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# ============================================================================= | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except | ||
# in compliance with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software distributed under the License | ||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
# or implied. See the License for the specific language governing permissions and limitations under | ||
# the License. | ||
# ============================================================================= | ||
|
||
# Set the list of Cython files to build | ||
set(cython_sources dlpack.pyx) | ||
set(linked_libraries cuvs::cuvs) | ||
|
||
# Build all of the Cython targets | ||
rapids_cython_create_modules( | ||
CXX | ||
SOURCE_FILES "${cython_sources}" | ||
LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS cuvs MODULE_PREFIX common_ | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from .cagra import Index, IndexParams, SearchParams, build, load, save, search | ||
|
||
__all__ = [ | ||
"Index", | ||
"IndexParams", | ||
"SearchParams", | ||
"build", | ||
"load", | ||
"save", | ||
"search", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# cython: language_level=3 | ||
|
||
from libc.stdint cimport int32_t, int64_t, uint8_t, uint16_t, uint64_t | ||
|
||
cdef extern from 'dlpack.h' nogil: | ||
ctypedef enum DLDeviceType: | ||
kDLCPU | ||
kDLCUDA | ||
kDLCUDAHost | ||
kDLOpenCL | ||
kDLVulkan | ||
kDLMetal | ||
kDLVPI | ||
kDLROCM | ||
kDLROCMHost | ||
kDLExtDev | ||
kDLCUDAManaged | ||
kDLOneAPI | ||
kDLWebGPU | ||
kDLHexagon | ||
|
||
ctypedef struct DLDevice: | ||
DLDeviceType device_type | ||
int32_t device_id | ||
|
||
ctypedef enum DLDataTypeCode: | ||
kDLInt | ||
kDLUInt | ||
kDLFloat | ||
kDLBfloat | ||
kDLComplex | ||
kDLBool | ||
|
||
ctypedef struct DLDataType: | ||
uint8_t code | ||
uint8_t bits | ||
uint16_t lanes | ||
|
||
ctypedef struct DLTensor: | ||
void* data | ||
DLDevice device | ||
int32_t ndim | ||
DLDataType dtype | ||
int64_t* shape | ||
int64_t* strides | ||
uint64_t byte_offset | ||
|
||
ctypedef struct DLManagedTensor: | ||
DLTensor dl_tensor | ||
void* manager_ctx | ||
void (*deleter)(DLManagedTensor*) # noqa: E211 | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# cython: language_level=3 | ||
|
||
import numpy as np | ||
|
||
|
||
cdef void deleter(DLManagedTensor* tensor): | ||
if tensor.manager_ctx is NULL: | ||
return | ||
stdlib.free(tensor.dl_tensor.shape) | ||
tensor.manager_ctx = NULL | ||
stdlib.free(tensor) | ||
|
||
|
||
cdef DLManagedTensor dlpack_c(ary): | ||
#todo(dgd): add checking options/parameters | ||
cdef DLDeviceType dev_type | ||
cdef DLDevice dev | ||
cdef DLDataType dtype | ||
cdef DLTensor tensor | ||
cdef DLManagedTensor dlm | ||
|
||
if hasattr(ary, "__cuda_array_interface__"): | ||
dev_type = DLDeviceType.kDLCUDA | ||
else: | ||
dev_type = DLDeviceType.kDLCPU | ||
|
||
dev.device_type = dev_type | ||
dev.device_id = 0 | ||
|
||
# todo (dgd): change to nice dict | ||
if ary.dtype == np.float32: | ||
dtype.code = DLDataTypeCode.kDLFloat | ||
dtype.bits = 32 | ||
elif ary.dtype == np.float64: | ||
dtype.code = DLDataTypeCode.kDLFloat | ||
dtype.bits = 64 | ||
elif ary.dtype == np.int32: | ||
dtype.code = DLDataTypeCode.kDLInt | ||
dtype.bits = 32 | ||
elif ary.dtype == np.int64: | ||
dtype.code = DLDataTypeCode.kDLFloat | ||
dtype.bits = 64 | ||
elif ary.dtype == np.bool: | ||
dtype.code = DLDataTypeCode.kDLFloat | ||
|
||
if hasattr(ary, "__cuda_array_interface__"): | ||
tensor_ptr = ary.__cuda_array_interface__["data"][0] | ||
else: | ||
tensor_ptr = ary.__array_interface__["data"][0] | ||
|
||
|
||
tensor.data = <void*> tensor_ptr | ||
tensor.device = dev | ||
tensor.dtype = dtype | ||
|
||
dlm.dl_tensor = tensor | ||
dlm.manager_ct = NULL | ||
dlm.deleter = deleter | ||
|
||
return dlm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
|
||
|
||
def auto_sync_resources(f): | ||
""" | ||
This is identical to auto_sync_handle except for the proposed name change. | ||
""" | ||
|
||
@functools.wraps(f) | ||
def wrapper(*args, resources=None, **kwargs): | ||
sync_handle = resources is None | ||
resources = resources if resources is not None else DeviceResources() | ||
|
||
ret_value = f(*args, resources=resources, **kwargs) | ||
|
||
if sync_handle: | ||
resources.sync() | ||
|
||
return ret_value | ||
|
||
wrapper.__doc__ = wrapper.__doc__.format( | ||
handle_docstring=_HANDLE_PARAM_DOCSTRING | ||
) | ||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# ============================================================================= | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except | ||
# in compliance with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software distributed under the License | ||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
# or implied. See the License for the specific language governing permissions and limitations under | ||
# the License. | ||
# ============================================================================= | ||
|
||
# Set the list of Cython files to build | ||
set(cython_sources cagra.pyx) | ||
set(linked_libraries cuvs::cuvs) | ||
|
||
# Build all of the Cython targets | ||
rapids_cython_create_modules( | ||
CXX | ||
SOURCE_FILES "${cython_sources}" | ||
LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS cuvs MODULE_PREFIX neighbors_cagra_ | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from .cagra import Index, IndexParams, SearchParams, build, load, save, search | ||
|
||
__all__ = [ | ||
"Index", | ||
"IndexParams", | ||
"SearchParams", | ||
"build", | ||
"load", | ||
"save", | ||
"search", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# cython: language_level=3 | ||
|
||
from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t | ||
|
||
|
||
cdef extern from "cuvs/core/c_api.h" | ||
ctypedef uintptr_t cuvsResources_t | ||
|
||
ctypedef enum cuvsError_t: | ||
CUVS_ERROR, | ||
CUVS_SUCCESS | ||
|
||
cuvsError_t cuvsResourcesCreate(cuvsResources_t* res) | ||
cuvsError_t cuvsResourcesDestroy(cuvsResources_t res) | ||
cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream) | ||
|
||
|
||
cdef extern from "cuvs/neighborscagra_c.h" nogil: | ||
|
||
ctypedef enum cagraGraphBuildAlgo: | ||
IVF_PQ | ||
NN_DESCENT | ||
|
||
|
||
ctypedef struct cagraIndexParams: | ||
size_t intermediate_graph_degree | ||
size_t graph_degree | ||
cagraGraphBuildAlgo build_algo | ||
size_t nn_descent_niter | ||
|
||
|
||
ctypedef enum search_algo: | ||
SINGLE_CTA, | ||
MULTI_CTA, | ||
MULTI_KERNEL, | ||
AUTO | ||
|
||
ctypedef enum cagraHashMode: | ||
HASH, | ||
SMALL, | ||
AUTO_HASH | ||
|
||
ctypedef struct cagraSearchParams: | ||
size_t max_queries | ||
size_t itopk_size | ||
size_t max_iterations | ||
cagraSearchAlgo algo | ||
size_t team_size | ||
size_t search_width | ||
size_t min_iterations | ||
size_t thread_block_size | ||
cagraHashMode hashmap_mode | ||
size_t hashmap_min_bitlen | ||
float hashmap_max_fill_rate | ||
uint32_t num_random_samplings | ||
uint64_t rand_xor_mask | ||
|
||
ctypedef struct cagraIndex: | ||
uintptr_t addr | ||
DLDataType dtype | ||
|
||
ctypedef cagraIndex* cagraIndex_t | ||
|
||
cuvsError_t cagraIndexCreate(cagraIndex_t* index) | ||
|
||
cuvsError_t cagraIndexDestroy(cagraIndex_t index) | ||
|
||
cuvsError_t cagraBuild(cuvsResources_t res, | ||
struct cagraIndexParams params, | ||
DLManagedTensor* dataset, | ||
cagraIndex_t index); | ||
|
||
cuvsError_t cagraSearch(cuvsResources_t res, | ||
cagraSearchParams params, | ||
cagraIndex_t index, | ||
DLManagedTensor* queries, | ||
DLManagedTensor* neighbors, | ||
DLManagedTensor* distances) |
Oops, something went wrong.