From 9383abdf7706e38e4f5306a81a293f8891de35dd Mon Sep 17 00:00:00 2001 From: Roman Arzumanyan <rarzumanyan@nvidia.com> Date: Thu, 13 Jan 2022 12:15:27 +0300 Subject: [PATCH] Adding DtoD support to CUDA shared memory feature --- .../utils/cuda_shared_memory/__init__.py | 31 ++++++++++++++----- .../cuda_shared_memory/cuda_shared_memory.cc | 31 +++++++++++++++++++ .../cuda_shared_memory/cuda_shared_memory.h | 2 ++ 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/src/python/library/tritonclient/utils/cuda_shared_memory/__init__.py b/src/python/library/tritonclient/utils/cuda_shared_memory/__init__.py index 7b562b205..704317f5c 100644 --- a/src/python/library/tritonclient/utils/cuda_shared_memory/__init__.py +++ b/src/python/library/tritonclient/utils/cuda_shared_memory/__init__.py @@ -61,6 +61,11 @@ def from_param(cls, value): _ccudashm_shared_memory_region_set.argtypes = [ c_void_p, c_uint64, c_uint64, c_void_p ] +_ccudashm_shared_memory_region_set_dptr = _ccudashm.CudaSharedMemoryRegionSetDptr +_ccudashm_shared_memory_region_set_dptr.restype = c_int +_ccudashm_shared_memory_region_set_dptr.argtypes = [ + c_void_p, c_uint64, c_uint64, c_void_p +] _cshm_get_shared_memory_handle_info = _ccudashm.GetCudaSharedMemoryHandleInfo _cshm_get_shared_memory_handle_info.restype = c_int _cshm_get_shared_memory_handle_info.argtypes = [ @@ -149,8 +154,8 @@ def get_raw_handle(cuda_shm_handle): return craw_handle.value -def set_shared_memory_region(cuda_shm_handle, input_values): - """Copy the contents of the numpy array into the cuda shared memory region. +def set_shared_memory_region(cuda_shm_handle, input_values, device='cpu'): + """Copy the contents of the numpy array or cuda dptr into the cuda shared memory region. Parameters ---------- @@ -158,6 +163,9 @@ def set_shared_memory_region(cuda_shm_handle, input_values): The handle for the cuda shared memory region. input_values : list The list of numpy arrays to be copied into the shared memory region. 
+ device: str + Use 'cpu' to copy numpy array into cuda shared memory region. + Use 'gpu' to copy from cuda dptr into cuda shared memory region. Raises ------ @@ -173,20 +181,27 @@ def set_shared_memory_region(cuda_shm_handle, input_values): "input_values must be specified as a list/tuple of numpy arrays" ) + if 'cpu' == device: + fptr = _ccudashm_shared_memory_region_set + elif 'gpu' == device: + fptr = _ccudashm_shared_memory_region_set_dptr + else: + _raise_error( + "unsupported device type: cpu and gpu are supported only" + ) + offset_current = 0 for input_value in input_values: input_value = np.ascontiguousarray(input_value).flatten() if input_value.dtype == np.object_: input_value = input_value.item() byte_size = np.dtype(np.byte).itemsize * len(input_value) - _raise_if_error( - c_int(_ccudashm_shared_memory_region_set(cuda_shm_handle, c_uint64(offset_current), \ - c_uint64(byte_size), cast(input_value, c_void_p)))) + _raise_if_error(c_int(fptr(cuda_shm_handle, c_uint64( + offset_current), c_uint64(byte_size), cast(input_value, c_void_p)))) else: byte_size = input_value.size * input_value.itemsize - _raise_if_error( - c_int(_ccudashm_shared_memory_region_set(cuda_shm_handle, c_uint64(offset_current), \ - c_uint64(byte_size), input_value.ctypes.data_as(c_void_p)))) + _raise_if_error(c_int(fptr(cuda_shm_handle, c_uint64(offset_current), + c_uint64(byte_size), input_value.ctypes.data_as(c_void_p)))) offset_current += byte_size return diff --git a/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.cc b/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.cc index ea86e50c6..8b5a1ccc4 100644 --- a/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.cc +++ b/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.cc @@ -157,6 +157,37 @@ CudaSharedMemoryRegionSet( return 0; } +int +CudaSharedMemoryRegionSetDptr( + void* cuda_shm_handle, size_t offset, size_t byte_size, const 
void* dptr) +{ + // remember previous device and set to new device + int previous_device; + cudaGetDevice(&previous_device); + cudaError_t err = cudaSetDevice( + reinterpret_cast<SharedMemoryHandle*>(cuda_shm_handle)->device_id_); + if (err != cudaSuccess) { + cudaSetDevice(previous_device); + return -1; + } + + // Copy data into cuda shared memory + void* base_addr = + reinterpret_cast<SharedMemoryHandle*>(cuda_shm_handle)->base_addr_; + err = cudaMemcpy( + reinterpret_cast<uint8_t*>(base_addr) + offset, dptr, byte_size, + cudaMemcpyDeviceToDevice); + if (err != cudaSuccess) { + cudaSetDevice(previous_device); + return -3; + } + + // Set device to previous GPU + cudaSetDevice(previous_device); + + return 0; +} + int GetCudaSharedMemoryHandleInfo( void* shm_handle, char** shm_addr, size_t* offset, size_t* byte_size) diff --git a/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.h b/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.h index ca4e5943b..f344e9502 100644 --- a/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.h +++ b/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.h @@ -40,6 +40,8 @@ int CudaSharedMemoryGetRawHandle( void* cuda_shm_handle, char** serialized_raw_handle); int CudaSharedMemoryRegionSet( void* cuda_shm_handle, size_t offset, size_t byte_size, const void* data); +int CudaSharedMemoryRegionSetDptr( + void* cuda_shm_handle, size_t offset, size_t byte_size, const void* dptr); int GetCudaSharedMemoryHandleInfo( void* shm_handle, char** shm_addr, size_t* offset, size_t* byte_size); int CudaSharedMemoryReleaseBuffer(char* ptr);