updated with allocation tests and additional dlpack synchronization
nnshah1 committed Jan 10, 2024
1 parent a8fd94c commit ca117ce
Showing 3 changed files with 139 additions and 4 deletions.
122 changes: 120 additions & 2 deletions python/test/test_api.py
@@ -63,12 +63,130 @@ def test_create_request(self):
request = tritonserver.InferenceRequest(server.model("test"))


class AllocatorTests(unittest.TestCase):
def test_allocate_on_cpu_and_reshape(self):
allocator = tritonserver.default_memory_allocators[tritonserver.MemoryType.CPU]

memory_buffer = allocator.allocate(
memory_type=tritonserver.MemoryType.CPU, memory_type_id=0, size=200
)

cpu_array = memory_buffer.owner

self.assertEqual(memory_buffer.size, 200)

        fp32_size = memory_buffer.size // 4  # 4 bytes per FP32 element

tensor = tritonserver.Tensor(
tritonserver.DataType.FP32, shape=[fp32_size], memory_buffer=memory_buffer
)

cpu_fp32_array = numpy.from_dlpack(tensor)
self.assertEqual(cpu_array.ctypes.data, cpu_fp32_array.ctypes.data)
self.assertEqual(cpu_fp32_array.dtype, numpy.float32)
self.assertEqual(cpu_fp32_array.nbytes, 200)

torch_fp32_tensor = torch.from_dlpack(tensor)
self.assertEqual(torch_fp32_tensor.dtype, torch.float32)
self.assertEqual(torch_fp32_tensor.data_ptr(), cpu_array.ctypes.data)
self.assertEqual(torch_fp32_tensor.nbytes, 200)

@pytest.mark.skipif(cupy is None, reason="Skipping gpu memory, cpupy not installed")
def test_allocate_on_gpu_and_reshape(self):
if cupy is None:
return

allocator = tritonserver.default_memory_allocators[tritonserver.MemoryType.GPU]

memory_buffer = allocator.allocate(
memory_type=tritonserver.MemoryType.GPU, memory_type_id=0, size=200
)

        gpu_array = memory_buffer.owner

self.assertEqual(memory_buffer.size, 200)

        fp32_size = memory_buffer.size // 4  # 4 bytes per FP32 element

tensor = tritonserver.Tensor(
tritonserver.DataType.FP32, shape=[fp32_size], memory_buffer=memory_buffer
)

gpu_fp32_array = cupy.from_dlpack(tensor)
self.assertEqual(
gpu_array.__cuda_array_interface__["data"][0],
gpu_fp32_array.__cuda_array_interface__["data"][0],
)
self.assertEqual(gpu_fp32_array.dtype, cupy.float32)
self.assertEqual(gpu_fp32_array.nbytes, 200)

torch_fp32_tensor = torch.from_dlpack(tensor)
self.assertEqual(torch_fp32_tensor.dtype, torch.float32)
self.assertEqual(
torch_fp32_tensor.data_ptr(), gpu_array.__cuda_array_interface__["data"][0]
)
self.assertEqual(torch_fp32_tensor.nbytes, 200)


class TensorTests(unittest.TestCase):
@pytest.mark.skipif(cupy is None, reason="Skipping gpu memory, cupy not installed")
def test_cpu_to_gpu(self):
if cupy is None:
return
cpu_array = numpy.random.rand(1, 3, 100, 100).astype(numpy.float32)
cpu_tensor = tritonserver.Tensor.from_dlpack(cpu_array)
gpu_tensor = cpu_tensor.to_device("gpu")
cupy.from_dlpack(gpu_tensor)
gpu_tensor = cpu_tensor.to_device("gpu:0")
gpu_array = cupy.from_dlpack(gpu_tensor)

self.assertEqual(gpu_array.device, cupy.cuda.Device(0))

numpy.testing.assert_array_equal(cpu_array, gpu_array.get())

memory_buffer = tritonserver.MemoryBuffer.from_dlpack(gpu_array)

self.assertEqual(
gpu_array.__cuda_array_interface__["data"][0], memory_buffer.data_ptr
)

@pytest.mark.skipif(
torch is None, reason="Skipping gpu memory, torch not installed"
)
@pytest.mark.skipif(cupy is None, reason="Skipping gpu memory, cupy not installed")
def test_gpu_tensor_from_dl_pack(self):
if cupy is None or torch is None:
return
cupy_array = cupy.ones([100]).astype(cupy.float64)
tensor = tritonserver.Tensor.from_dlpack(cupy_array)
torch_tensor = torch.from_dlpack(cupy_array)

self.assertEqual(torch_tensor.data_ptr(), tensor.data_ptr)
self.assertEqual(torch_tensor.nbytes, tensor.size)
self.assertEqual(torch_tensor.__dlpack_device__(), tensor.__dlpack_device__())

@pytest.mark.skipif(torch is None, reason="Skipping test, torch not installed")
def test_tensor_from_numpy(self):
cpu_array = numpy.random.rand(1, 3, 100, 100).astype(numpy.float32)
tensor = tritonserver.Tensor.from_dlpack(cpu_array)
torch_tensor = torch.from_dlpack(tensor)
numpy.testing.assert_array_equal(torch_tensor.numpy(), cpu_array)
self.assertEqual(torch_tensor.data_ptr(), cpu_array.ctypes.data)

    @pytest.mark.skipif(torch is None, reason="Skipping test, torch not installed")
    def test_tensor_to_dlpack(self):
        ndarray = numpy.zeros([10, 20], dtype=numpy.uint8)
        memory_buffer = tritonserver.MemoryBuffer.from_dlpack(ndarray)
        tensor = tritonserver.Tensor(
            tritonserver.DataType.FP32, shape=[50], memory_buffer=memory_buffer
        )
ndarray_2 = numpy.from_dlpack(tensor)
self.assertEqual(ndarray.ctypes.data, ndarray_2.ctypes.data)
self.assertEqual(ndarray_2.dtype, numpy.float32)

torch_tensor = torch.from_dlpack(tensor)
self.assertEqual(torch_tensor.data_ptr(), ndarray.ctypes.data)
self.assertEqual(torch_tensor.dtype, torch.float32)


class ServerTests(unittest.TestCase):
16 changes: 14 additions & 2 deletions python/tritonserver/_api/_datautils.py
@@ -242,8 +242,20 @@ def parse_device_or_memory_type(
class DLPackObject:
def __init__(self, value) -> None:
        try:
            device, device_id = value.__dlpack_device__()
            if device == _dlpack.DLDeviceType.kDLCUDA:
                if cupy is None:
                    raise UnsupportedError(
                        f"DLPack synchronization on device {device}:{device_id} not supported"
                    )
                with cupy.cuda.Device(device_id):
                    # Create the capsule on the owning device, requesting
                    # synchronization on the legacy default stream (1) so any
                    # pending producer work completes before we consume it.
                    stream = 1
                    self._capsule = _dlpack.get_dlpack_capsule(value, stream)
                    self._tensor = _dlpack.get_managed_tensor(self._capsule).dl_tensor
            else:
                self._capsule = _dlpack.get_dlpack_capsule(value)
                self._tensor = _dlpack.get_managed_tensor(self._capsule).dl_tensor
except Exception as e:
raise InvalidArgumentError(
f"Object does not support DLPack protocol: {e}"
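For context on the change above: under the DLPack protocol, the consumer may pass a stream handle to the producer's __dlpack__ call so the producer can synchronize pending work before the capsule is handed over; on CUDA, the integer 1 denotes the legacy default stream. A minimal sketch of that consumer-side handshake (the helper below is illustrative and not part of this commit):

def get_capsule_with_sync(value):
    # Illustrative consumer-side DLPack handshake; not part of this commit.
    device_type, _ = value.__dlpack_device__()
    if device_type == 2:  # kDLCUDA in the DLPack spec
        # stream=1 is the legacy default CUDA stream; the producer must
        # make its pending work visible to that stream before returning.
        return value.__dlpack__(stream=1)
    # CPU (and other non-CUDA) producers need no stream synchronization.
    return value.__dlpack__()

Any object implementing __dlpack__ and __dlpack_device__ (NumPy, CuPy, or PyTorch arrays, for example) would work with a helper like this.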
5 changes: 5 additions & 0 deletions python/tritonserver/_api/_tensor.py
@@ -174,6 +174,11 @@ def _sync_on_requested_stream(self, requested_stream_ptr):
if cupy is None:
raise unsupported

        # NOTE: Technically this is not required by the protocol. It is
        # the responsibility of the caller (consumer) to ensure that we
        # are on the correct device. Added to keep the semantics correct -
        # but it should be a no-op. May be removed in the future.
with cupy.cuda.Device(self.memory_type_id):
current_stream = cupy.cuda.get_current_stream()
curr_stream_ptr = current_stream.ptr
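The body of _sync_on_requested_stream is truncated in this view. For intuition only: synchronizing with a consumer's requested stream in CuPy is typically done by recording an event on the current stream and making the requested stream wait on it. A sketch under that assumption (not the function's actual implementation):

import cupy

def sync_on_requested_stream(device_id: int, requested_stream_ptr: int) -> None:
    # Illustrative only: make the consumer's requested stream wait until
    # work already enqueued on our current stream has finished.
    with cupy.cuda.Device(device_id):
        current_stream = cupy.cuda.get_current_stream()
        if requested_stream_ptr == current_stream.ptr:
            return  # same stream, ordering is already guaranteed
        event = cupy.cuda.Event()
        event.record(current_stream)
        # Queue a wait on the consumer's stream so it cannot run ahead
        # of our pending work.
        cupy.cuda.ExternalStream(requested_stream_ptr).wait_event(event)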
