-
Notifications
You must be signed in to change notification settings - Fork 235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding DtoD support to CUDA shared memory feature #58
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,11 @@ def from_param(cls, value): | |
_ccudashm_shared_memory_region_set.argtypes = [ | ||
c_void_p, c_uint64, c_uint64, c_void_p | ||
] | ||
_ccudashm_shared_memory_region_set_dptr = _ccudashm.CudaSharedMemoryRegionSetDptr | ||
_ccudashm_shared_memory_region_set_dptr.restype = c_int | ||
_ccudashm_shared_memory_region_set_dptr.argtypes = [ | ||
c_void_p, c_uint64, c_uint64, c_void_p | ||
] | ||
_cshm_get_shared_memory_handle_info = _ccudashm.GetCudaSharedMemoryHandleInfo | ||
_cshm_get_shared_memory_handle_info.restype = c_int | ||
_cshm_get_shared_memory_handle_info.argtypes = [ | ||
|
@@ -149,15 +154,18 @@ def get_raw_handle(cuda_shm_handle): | |
return craw_handle.value | ||
|
||
|
||
def set_shared_memory_region(cuda_shm_handle, input_values): | ||
"""Copy the contents of the numpy array into the cuda shared memory region. | ||
def set_shared_memory_region(cuda_shm_handle, input_values, device='cpu'): | ||
"""Copy the contents of the numpy array or cuda dptr into the cuda shared memory region. | ||
|
||
Parameters | ||
---------- | ||
cuda_shm_handle : c_void_p | ||
The handle for the cuda shared memory region. | ||
input_values : list | ||
The list of numpy arrays to be copied into the shared memory region. | ||
device: str | ||
Use 'cpu' to copy numpy array into cuda shared memory region. | ||
Use 'gpu' to copy from cuda dptr into cuda shared memory region. | ||
|
||
Raises | ||
------ | ||
|
@@ -173,20 +181,27 @@ def set_shared_memory_region(cuda_shm_handle, input_values): | |
"input_values must be specified as a list/tuple of numpy arrays" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to expand checks here to allow for a single cuda device ptr instead of a list or tuple of numpy arrays |
||
) | ||
|
||
if 'cpu' == device: | ||
fptr = _ccudashm_shared_memory_region_set | ||
elif 'gpu' == device: | ||
fptr = _ccudashm_shared_memory_region_set_dptr | ||
else: | ||
_raise_error( | ||
"unsupported device type: cpu and gpu are supported only" | ||
) | ||
|
||
offset_current = 0 | ||
for input_value in input_values: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the input_values a list of cudaptr or a single cudaptr? I think a single cuda device ptr is suffcient. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a single cuda device ptr. |
||
input_value = np.ascontiguousarray(input_value).flatten() | ||
if input_value.dtype == np.object_: | ||
input_value = input_value.item() | ||
byte_size = np.dtype(np.byte).itemsize * len(input_value) | ||
_raise_if_error( | ||
c_int(_ccudashm_shared_memory_region_set(cuda_shm_handle, c_uint64(offset_current), \ | ||
c_uint64(byte_size), cast(input_value, c_void_p)))) | ||
_raise_if_error(c_int(fptr(cuda_shm_handle, c_uint64( | ||
offset_current), c_uint64(byte_size), cast(input_value, c_void_p)))) | ||
else: | ||
byte_size = input_value.size * input_value.itemsize | ||
_raise_if_error( | ||
c_int(_ccudashm_shared_memory_region_set(cuda_shm_handle, c_uint64(offset_current), \ | ||
c_uint64(byte_size), input_value.ctypes.data_as(c_void_p)))) | ||
_raise_if_error(c_int(fptr(cuda_shm_handle, c_uint64(offset_current), | ||
c_uint64(byte_size), input_value.ctypes.data_as(c_void_p)))) | ||
offset_current += byte_size | ||
return | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correction: "Copy the contents of the list/tuple of numpy arrays or the content pointed to by a cuda device pointer into the cuda shared memory region."