Skip to content

Commit

Permalink
refactor: reflect latest comment on create / destroy behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
GuanLuo committed Oct 23, 2024
1 parent 4d7a0b8 commit c28ba19
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 76 deletions.
47 changes: 22 additions & 25 deletions src/python/library/tests/test_shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,7 @@ def setUp(self):

def tearDown(self):
for shm_handle in self.shm_handles:
# [NOTE] wrapper for old implementation that will fail
try:
shm.destroy_shared_memory_region(shm_handle)
except shm.SharedMemoryException as ex:
if "unlink" in str(ex):
pass
else:
raise ex
shm.destroy_shared_memory_region(shm_handle)

def test_lifecycle(self):
cpu_tensor = numpy.ones([4, 4], dtype=numpy.float32)
Expand All @@ -69,6 +62,15 @@ def test_lifecycle(self):

shm.destroy_shared_memory_region(self.shm_handles.pop(0))

def test_invalid_create_shm(self):
    """Creating a system shared memory region with an invalid size must fail.

    A byte size of -1 is requested, and the call is expected to raise an
    exception whose message is exactly "unable to initialize the size".
    """
    # assertRaises makes the test fail when no exception is raised at all;
    # a bare try/except would let that case pass silently.
    with self.assertRaises(Exception) as ctx:
        # If creation unexpectedly succeeds, the handle is still tracked so
        # that tearDown can release it.
        self.shm_handles.append(
            shm.create_shared_memory_region("dummy_data", "/dummy_data", -1)
        )
    # assertEqual reports both strings on mismatch, unlike assertTrue(a == b).
    self.assertEqual(str(ctx.exception), "unable to initialize the size")

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a == b) cannot provide an informative message. Using assertEqual(a, b) instead will give more informative messages.

def test_set_region_offset(self):
large_tensor = numpy.ones([4, 4], dtype=numpy.float32)
large_size = 64
Expand All @@ -87,7 +89,6 @@ def test_set_region_offset(self):

self.assertTrue(numpy.allclose(small_tensor, shm_tensor))

# [NOTE] current impl will fail
def test_set_region_oversize(self):
large_tensor = numpy.ones([4, 4], dtype=numpy.float32)
small_size = 32
Expand All @@ -101,55 +102,51 @@ def test_duplicate_key(self):
# [NOTE] change in behavior:
# previous: okay to create shared memory region of the same key with different size
# and that behavior was not clearly studied.
# now: only allow create by default, flag may be set to return the same handle if
# existed, warning will be print if size is different
# now: return the same handle if it exists; a warning is printed if the requested size is larger than the existing region
self.shm_handles.append(
shm.create_shared_memory_region("shm_name", "shm_key", 32)
)
with self.assertRaises(shm.SharedMemoryException):
self.shm_handles.append(
shm.create_shared_memory_region("shm_name", "shm_key", 32)
shm.create_shared_memory_region(
"shm_name", "shm_key", 32, create_only=True
)
)

# Get handle to the same shared memory region but with larger size requested,
# check if actual size is checked
self.shm_handles.append(
shm.create_shared_memory_region("shm_name", "shm_key", 64, create=False)
shm.create_shared_memory_region("shm_name", "shm_key", 64)
)

self.assertEqual(len(shm.mapped_shared_memory_regions()), 1)

large_tensor = numpy.ones([4, 4], dtype=numpy.float32)
small_size = 32
# [NOTE] current impl will fail
with self.assertRaises(shm.SharedMemoryException):
shm.set_shared_memory_region(self.shm_handles[-1], [large_tensor])

# [NOTE] current impl will fail
def test_destroy_duplicate(self):
# [NOTE] change in behavior:
# previous: raise exception if underlying shared memory has been unlinked
# now: the exception will be suppressed to align with Windows behavior, unless
# explicitly toggled
# now: no exception, as unlink only happens when the last managed handle is destroyed
self.assertEqual(len(shm.mapped_shared_memory_regions()), 0)
self.shm_handles.append(
shm.create_shared_memory_region("shm_name", "shm_key", 64)
)
self.shm_handles.append(
shm.create_shared_memory_region("shm_name", "shm_key", 32, create=False)
shm.create_shared_memory_region("shm_name", "shm_key", 32)
)
self.shm_handles.append(
shm.create_shared_memory_region("shm_name", "shm_key", 32, create=False)
shm.create_shared_memory_region("shm_name", "shm_key", 32)
)
self.assertEqual(len(shm.mapped_shared_memory_regions()), 1)

shm.destroy_shared_memory_region(self.shm_handles.pop(0))
self.assertEqual(len(shm.mapped_shared_memory_regions()), 0)
shm.destroy_shared_memory_region(self.shm_handles.pop(0))
self.assertEqual(len(shm.mapped_shared_memory_regions()), 1)

shm.destroy_shared_memory_region(self.shm_handles.pop(0))
with self.assertRaises(shm.SharedMemoryException):
shm.destroy_shared_memory_region(
self.shm_handles.pop(0), raise_unlink_exception=True
)
self.assertEqual(len(shm.mapped_shared_memory_regions()), 0)

def test_numpy_bytes(self):
int_tensor = numpy.arange(start=0, stop=16, dtype=numpy.int32)
Expand Down
119 changes: 68 additions & 51 deletions src/python/library/tritonclient/utils/shared_memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -28,6 +28,7 @@

import os
import struct
import warnings
from ctypes import *

import numpy as np
Expand Down Expand Up @@ -72,6 +73,7 @@ def from_param(cls, value):
_cshm_shared_memory_region_destroy.argtypes = [c_void_p]

mapped_shm_regions = []
_key_mapping = {}


def _raise_if_error(errno):
Expand All @@ -90,7 +92,18 @@ def _raise_error(msg):
raise ex


def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create=True):
class SharedMemoryRegion:
    """Python-side handle for a system shared memory region.

    Holds the Triton-facing region name, the shared memory key, and the
    underlying C handle. The C handle starts out as an empty ``c_void_p``
    and is expected to be filled in by the code that creates the region.
    """

    def __init__(
        self,
        triton_shm_name: str,
        shm_key: str,
    ) -> None:
        # Start with a null C handle; nothing is allocated at construction.
        self._c_handle = c_void_p()
        # Record the identifying strings exactly as supplied by the caller.
        self._triton_shm_name, self._shm_key = triton_shm_name, shm_key


def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only=False):
"""Creates a system shared memory region with the specified name and size.
Parameters
Expand All @@ -101,10 +114,15 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create=True
The unique key of the shared memory object.
byte_size : int
The size in bytes of the shared memory region to be created.
create_only : bool
Whether a shared memory region must be created. If False and
a shared memory region with the same key already exists, a handle to
that region will be returned; the user must be aware that its size
can be different from the size requested.
Returns
-------
shm_handle : c_void_p
shm_handle : SharedMemoryRegion
The handle for the system shared memory region.
Raises
Expand All @@ -113,21 +131,47 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create=True
If unable to create the shared memory region.
"""

if create and shm_key in mapped_shm_regions:
if create_only and shm_key in mapped_shm_regions:
raise SharedMemoryException(
"unable to create the shared memory region, already exists"
)

shm_handle = c_void_p()
_raise_if_error(
c_int(
_cshm_shared_memory_region_create(
triton_shm_name, shm_key, byte_size, byref(shm_handle)
shm_handle = SharedMemoryRegion(triton_shm_name, shm_key)
# Has been created
if shm_key in _key_mapping:
shm_handle._c_handle = _key_mapping[shm_key][0]
_key_mapping[shm_key][1] += 1
# check on the size
shm_fd = c_int()
region_offset = c_uint64()
shm_byte_size = c_uint64()
shm_addr = c_char_p()
shm_key = c_char_p()
_raise_if_error(
c_int(
_cshm_get_shared_memory_handle_info(
shm_handle._c_handle,
byref(shm_addr),
byref(shm_key),
byref(shm_fd),
byref(region_offset),
byref(shm_byte_size),
)
)
)
)

if create:
if byte_size > shm_byte_size.value:
warnings.warn(
f"reusing shared memory region with key '{shm_key}', region size is {shm_byte_size.value} instead of requested {byte_size}"
)
else:
_raise_if_error(
c_int(
_cshm_shared_memory_region_create(
triton_shm_name, shm_key, byte_size, byref(shm_handle._c_handle)
)
)
)
_key_mapping[shm_key] = [shm_handle._c_handle, 1]
mapped_shm_regions.append(shm_key)

return shm_handle
Expand All @@ -138,7 +182,7 @@ def set_shared_memory_region(shm_handle, input_values, offset=0):
Parameters
----------
shm_handle : c_void_p
shm_handle : SharedMemoryRegion
The handle for the system shared memory region.
input_values : list
The list of numpy arrays to be copied into the shared memory region.
Expand Down Expand Up @@ -167,7 +211,7 @@ def set_shared_memory_region(shm_handle, input_values, offset=0):
_raise_if_error(
c_int(
_cshm_shared_memory_region_set(
shm_handle,
shm_handle._c_handle,
c_uint64(offset_current),
c_uint64(byte_size),
cast(input_value, c_void_p),
Expand All @@ -179,7 +223,7 @@ def set_shared_memory_region(shm_handle, input_values, offset=0):
_raise_if_error(
c_int(
_cshm_shared_memory_region_set(
shm_handle,
shm_handle._c_handle,
c_uint64(offset_current),
c_uint64(byte_size),
input_value.ctypes.data_as(c_void_p),
Expand All @@ -196,7 +240,7 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
Parameters
----------
shm_handle : c_void_p
shm_handle : SharedMemoryRegion
The handle for the system shared memory region.
datatype : np.dtype
The datatype of the array to be returned.
Expand All @@ -220,7 +264,7 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
_raise_if_error(
c_int(
_cshm_get_shared_memory_handle_info(
shm_handle,
shm_handle._c_handle,
byref(shm_addr),
byref(shm_key),
byref(shm_fd),
Expand Down Expand Up @@ -278,56 +322,29 @@ def mapped_shared_memory_regions():
return mapped_shm_regions


def destroy_shared_memory_region(shm_handle, raise_unlink_exception=False):
def destroy_shared_memory_region(shm_handle):
"""Unlink a system shared memory region with the specified handle.
Parameters
----------
shm_handle : c_void_p
shm_handle : SharedMemoryRegion
The handle for the system shared memory region.
Raises
------
SharedMemoryException
If unable to unlink the shared memory region.
"""
shm_fd = c_int()
offset = c_uint64()
byte_size = c_uint64()
shm_addr = c_char_p()
shm_key = c_char_p()
_raise_if_error(
c_int(
_cshm_get_shared_memory_handle_info(
shm_handle,
byref(shm_addr),
byref(shm_key),
byref(shm_fd),
byref(offset),
byref(byte_size),
)
)
)
# It is safer to remove the shared memory key from the list before
# deleting the shared memory region because if the deletion should
# fail, a re-attempt could result in a segfault. Secondarily, if we
# fail to delete a region, we should not report it back to the user
# as a valid memory region.
try:
mapped_shm_regions.remove(shm_key.value.decode("utf-8"))
except ValueError:
# okay if mapped_shm_regions doesn't have the key as there can be
# destroy call on handles with the same shared memory key
pass
try:
_raise_if_error(c_int(_cshm_shared_memory_region_destroy(shm_handle)))
except SharedMemoryException as ex:
# Suppress unlink exception except when explicitly allow to raise
if not raise_unlink_exception and "unlink" in str(ex):
pass
else:
raise ex
return
_key_mapping[shm_handle._shm_key][1] -= 1
if _key_mapping[shm_handle._shm_key][1] == 0:
mapped_shm_regions.remove(shm_handle._shm_key)
_key_mapping.pop(shm_handle._shm_key)
_raise_if_error(c_int(_cshm_shared_memory_region_destroy(shm_handle._c_handle)))


class SharedMemoryException(Exception):
Expand Down

0 comments on commit c28ba19

Please sign in to comment.