Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding DtoD support to CUDA shared memory feature #58

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def from_param(cls, value):
# ctypes binding: CudaSharedMemoryRegionSet(handle, offset, byte_size, data)
# copies host (CPU) memory into the CUDA shared memory region.
_ccudashm_shared_memory_region_set.argtypes = [
c_void_p, c_uint64, c_uint64, c_void_p
]
# ctypes binding: CudaSharedMemoryRegionSetDptr(handle, offset, byte_size, dptr)
# copies from a CUDA device pointer (device-to-device) into the CUDA shared
# memory region. Same argument layout as the host-memory variant above.
_ccudashm_shared_memory_region_set_dptr = _ccudashm.CudaSharedMemoryRegionSetDptr
_ccudashm_shared_memory_region_set_dptr.restype = c_int
_ccudashm_shared_memory_region_set_dptr.argtypes = [
c_void_p, c_uint64, c_uint64, c_void_p
]
_cshm_get_shared_memory_handle_info = _ccudashm.GetCudaSharedMemoryHandleInfo
_cshm_get_shared_memory_handle_info.restype = c_int
_cshm_get_shared_memory_handle_info.argtypes = [
Expand Down Expand Up @@ -149,15 +154,18 @@ def get_raw_handle(cuda_shm_handle):
return craw_handle.value


def set_shared_memory_region(cuda_shm_handle, input_values, device='cpu'):
    """Copy the contents of a list/tuple of numpy arrays, or the content
    pointed to by a cuda device pointer, into the cuda shared memory region.

    Parameters
    ----------
    cuda_shm_handle : c_void_p
        The handle for the cuda shared memory region.
    input_values : list
        The list of numpy arrays to be copied into the shared memory region.
    device : str
        Use 'cpu' to copy numpy arrays into the cuda shared memory region.
        Use 'gpu' to copy from a cuda device pointer (device-to-device copy)
        into the cuda shared memory region.

    Raises
    ------
    Exception
        If unable to set the values in the cuda shared memory region
        (raised via _raise_error / _raise_if_error).
    """
    if not isinstance(input_values, (list, tuple)):
        _raise_error(
            "input_values must be specified as a list/tuple of numpy arrays"
        )
    # NOTE(review): the check above only covers the 'cpu' (numpy) case; for
    # device='gpu' a single cuda device pointer is expected instead -- the
    # validation needs to be extended for that case (see PR discussion).

    # Select the native copy routine: host-memory copy for 'cpu' input,
    # device-to-device copy for 'gpu' input.
    if 'cpu' == device:
        fptr = _ccudashm_shared_memory_region_set
    elif 'gpu' == device:
        fptr = _ccudashm_shared_memory_region_set_dptr
    else:
        _raise_error(
            "unsupported device type: cpu and gpu are supported only"
        )

    offset_current = 0
    for input_value in input_values:
        input_value = np.ascontiguousarray(input_value).flatten()
        if input_value.dtype == np.object_:
            # Variable-length (object dtype, e.g. bytes) element: copy its
            # raw bytes, one byte at a time.
            input_value = input_value.item()
            byte_size = np.dtype(np.byte).itemsize * len(input_value)
            _raise_if_error(
                c_int(fptr(cuda_shm_handle, c_uint64(offset_current),
                           c_uint64(byte_size), cast(input_value, c_void_p))))
        else:
            # Fixed-size dtype: copy the array's underlying buffer directly.
            byte_size = input_value.size * input_value.itemsize
            _raise_if_error(
                c_int(fptr(cuda_shm_handle, c_uint64(offset_current),
                           c_uint64(byte_size),
                           input_value.ctypes.data_as(c_void_p))))
        offset_current += byte_size
    return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,37 @@ CudaSharedMemoryRegionSet(
return 0;
}

// Copy 'byte_size' bytes from the CUDA device pointer 'dptr' into the CUDA
// shared memory region described by 'cuda_shm_handle', starting 'offset'
// bytes into the region (device-to-device copy).
//
// Returns 0 on success, -1 if the region's device could not be selected,
// and -3 if the device-to-device copy failed. The caller's current CUDA
// device is restored before returning in every case.
int
CudaSharedMemoryRegionSetDptr(
    void* cuda_shm_handle, size_t offset, size_t byte_size, const void* dptr)
{
  SharedMemoryHandle* handle =
      reinterpret_cast<SharedMemoryHandle*>(cuda_shm_handle);

  // Remember the caller's device and switch to the device owning the region.
  int previous_device;
  cudaGetDevice(&previous_device);
  cudaError_t err = cudaSetDevice(handle->device_id_);
  if (err != cudaSuccess) {
    cudaSetDevice(previous_device);
    return -1;
  }

  // Device-to-device copy from the caller's buffer into the region.
  err = cudaMemcpy(
      reinterpret_cast<uint8_t*>(handle->base_addr_) + offset, dptr, byte_size,
      cudaMemcpyDeviceToDevice);
  if (err != cudaSuccess) {
    cudaSetDevice(previous_device);
    return -3;
  }

  // Restore the caller's device.
  cudaSetDevice(previous_device);

  return 0;
}

int
GetCudaSharedMemoryHandleInfo(
void* shm_handle, char** shm_addr, size_t* offset, size_t* byte_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ int CudaSharedMemoryGetRawHandle(
void* cuda_shm_handle, char** serialized_raw_handle);
// Copy byte_size bytes of host memory from 'data' into the shared memory
// region at 'offset'. Returns 0 on success, non-zero on failure.
int CudaSharedMemoryRegionSet(
void* cuda_shm_handle, size_t offset, size_t byte_size, const void* data);
// Copy byte_size bytes from the CUDA device pointer 'dptr' into the shared
// memory region at 'offset' (device-to-device copy). Returns 0 on success,
// non-zero on failure.
int CudaSharedMemoryRegionSetDptr(
void* cuda_shm_handle, size_t offset, size_t byte_size, const void* dptr);
// Report the address, offset and byte size recorded in 'shm_handle' through
// the out-parameters. Returns 0 on success, non-zero on failure.
int GetCudaSharedMemoryHandleInfo(
void* shm_handle, char** shm_addr, size_t* offset, size_t* byte_size);
// Release a char buffer previously allocated by this library (presumably a
// serialized raw handle -- confirm against the implementation).
int CudaSharedMemoryReleaseBuffer(char* ptr);
Expand Down