From 9383abdf7706e38e4f5306a81a293f8891de35dd Mon Sep 17 00:00:00 2001
From: Roman Arzumanyan <rarzumanyan@nvidia.com>
Date: Thu, 13 Jan 2022 12:15:27 +0300
Subject: [PATCH] Adding DtoD support to CUDA shared memory feature

---
 .../utils/cuda_shared_memory/__init__.py      | 31 ++++++++++++++-----
 .../cuda_shared_memory/cuda_shared_memory.cc  | 31 +++++++++++++++++++
 .../cuda_shared_memory/cuda_shared_memory.h   |  2 ++
 3 files changed, 56 insertions(+), 8 deletions(-)

diff --git a/src/python/library/tritonclient/utils/cuda_shared_memory/__init__.py b/src/python/library/tritonclient/utils/cuda_shared_memory/__init__.py
index 7b562b205..704317f5c 100644
--- a/src/python/library/tritonclient/utils/cuda_shared_memory/__init__.py
+++ b/src/python/library/tritonclient/utils/cuda_shared_memory/__init__.py
@@ -61,6 +61,11 @@ def from_param(cls, value):
 _ccudashm_shared_memory_region_set.argtypes = [
     c_void_p, c_uint64, c_uint64, c_void_p
 ]
+_ccudashm_shared_memory_region_set_dptr = _ccudashm.CudaSharedMemoryRegionSetDptr
+_ccudashm_shared_memory_region_set_dptr.restype = c_int
+_ccudashm_shared_memory_region_set_dptr.argtypes = [
+    c_void_p, c_uint64, c_uint64, c_void_p  # handle, offset, byte_size, dptr
+]
 _cshm_get_shared_memory_handle_info = _ccudashm.GetCudaSharedMemoryHandleInfo
 _cshm_get_shared_memory_handle_info.restype = c_int
 _cshm_get_shared_memory_handle_info.argtypes = [
@@ -149,8 +154,8 @@ def get_raw_handle(cuda_shm_handle):
     return craw_handle.value
 
 
-def set_shared_memory_region(cuda_shm_handle, input_values):
-    """Copy the contents of the numpy array into the cuda shared memory region.
+def set_shared_memory_region(cuda_shm_handle, input_values, device='cpu'):
+    """Copy the contents of the numpy array or cuda dptr into the cuda shared memory region.
 
     Parameters
     ----------
@@ -158,6 +163,9 @@ def set_shared_memory_region(cuda_shm_handle, input_values):
         The handle for the cuda shared memory region.
     input_values : list
         The list of numpy arrays to be copied into the shared memory region.
+    device: str
+        Use 'cpu' to copy numpy array into cuda shared memory region.
+        Use 'gpu' to copy from cuda dptr into cuda shared memory region.
 
     Raises
     ------
@@ -173,20 +181,27 @@ def set_shared_memory_region(cuda_shm_handle, input_values):
                 "input_values must be specified as a list/tuple of numpy arrays"
             )
 
+    if 'cpu' == device:
+        fptr = _ccudashm_shared_memory_region_set
+    elif 'gpu' == device:
+        fptr = _ccudashm_shared_memory_region_set_dptr
+    else:
+        _raise_error(
+            "unsupported device type: cpu and gpu are supported only"
+        )
+
     offset_current = 0
     for input_value in input_values:
         input_value = np.ascontiguousarray(input_value).flatten()
         if input_value.dtype == np.object_:
             input_value = input_value.item()
             byte_size = np.dtype(np.byte).itemsize * len(input_value)
-            _raise_if_error(
-                c_int(_ccudashm_shared_memory_region_set(cuda_shm_handle, c_uint64(offset_current), \
-                    c_uint64(byte_size), cast(input_value, c_void_p))))
+            _raise_if_error(c_int(fptr(cuda_shm_handle, c_uint64(
+                offset_current),  c_uint64(byte_size), cast(input_value, c_void_p))))
         else:
             byte_size = input_value.size * input_value.itemsize
-            _raise_if_error(
-                c_int(_ccudashm_shared_memory_region_set(cuda_shm_handle, c_uint64(offset_current), \
-                    c_uint64(byte_size), input_value.ctypes.data_as(c_void_p))))
+            _raise_if_error(c_int(fptr(cuda_shm_handle, c_uint64(offset_current),
+                c_uint64(byte_size), input_value.ctypes.data_as(c_void_p))))
         offset_current += byte_size
     return
 
diff --git a/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.cc b/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.cc
index ea86e50c6..8b5a1ccc4 100644
--- a/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.cc
+++ b/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.cc
@@ -157,6 +157,37 @@ CudaSharedMemoryRegionSet(
   return 0;
 }
 
+int
+CudaSharedMemoryRegionSetDptr(
+    void* cuda_shm_handle, size_t offset, size_t byte_size, const void* dptr)
+{
+  // Copies byte_size bytes from device pointer dptr into the region at offset.
+  // Returns 0 on success, -1 if device selection fails, -3 if the copy fails.
+  SharedMemoryHandle* handle =
+      reinterpret_cast<SharedMemoryHandle*>(cuda_shm_handle);
+
+  // Remember the current device so it can be restored before returning. If it
+  // cannot be queried there is nothing valid to restore, so fail early.
+  int previous_device = 0;
+  if (cudaGetDevice(&previous_device) != cudaSuccess) {
+    return -1;
+  }
+  if (cudaSetDevice(handle->device_id_) != cudaSuccess) {
+    cudaSetDevice(previous_device);
+    return -1;
+  }
+
+  // Device-to-device copy into the shared memory region.
+  const cudaError_t err = cudaMemcpy(
+      reinterpret_cast<uint8_t*>(handle->base_addr_) + offset, dptr, byte_size,
+      cudaMemcpyDeviceToDevice);
+
+  // Restore the caller's device whether or not the copy succeeded.
+  cudaSetDevice(previous_device);
+
+  return (err == cudaSuccess) ? 0 : -3;
+}
+
 int
 GetCudaSharedMemoryHandleInfo(
     void* shm_handle, char** shm_addr, size_t* offset, size_t* byte_size)
diff --git a/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.h b/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.h
index ca4e5943b..f344e9502 100644
--- a/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.h
+++ b/src/python/library/tritonclient/utils/cuda_shared_memory/cuda_shared_memory.h
@@ -40,6 +40,8 @@ int CudaSharedMemoryGetRawHandle(
     void* cuda_shm_handle, char** serialized_raw_handle);
 int CudaSharedMemoryRegionSet(
     void* cuda_shm_handle, size_t offset, size_t byte_size, const void* data);
+int CudaSharedMemoryRegionSetDptr(
+    void* cuda_shm_handle, size_t offset, size_t byte_size, const void* dptr);
 int GetCudaSharedMemoryHandleInfo(
     void* shm_handle, char** shm_addr, size_t* offset, size_t* byte_size);
 int CudaSharedMemoryReleaseBuffer(char* ptr);