diff --git a/src/python/library/tests/test_cuda_shared_memory.py b/src/python/library/tests/test_cuda_shared_memory.py
index f970b2a0d..caa232755 100644
--- a/src/python/library/tests/test_cuda_shared_memory.py
+++ b/src/python/library/tests/test_cuda_shared_memory.py
@@ -42,7 +42,7 @@ class DLPackTest(unittest.TestCase):
     def test_from_gpu(self):
         # Create GPU tensor via PyTorch and CUDA shared memory region with
         # enough space
-        gpu_tensor = torch.ones(4, 4).cuda(0)
+        gpu_tensor = torch.ones(1, 4, 4).cuda(0)
         byte_size = 64
         shm_handle = cudashm.create_shared_memory_region("cudashm_data", byte_size, 0)
@@ -51,7 +51,7 @@ def test_from_gpu(self):
         # Make sure the DLPack specification of the shared memory region can
         # be consumed by PyTorch
-        smt = cudashm.as_shared_memory_tensor(shm_handle, "FP32", [4, 4])
+        smt = cudashm.as_shared_memory_tensor(shm_handle, "FP32", [1, 4, 4])
         generated_torch_tensor = torch.from_dlpack(smt)
         self.assertTrue(torch.allclose(gpu_tensor, generated_torch_tensor))
diff --git a/src/python/library/tritonclient/utils/_dlpack.py b/src/python/library/tritonclient/utils/_dlpack.py
index 94d3dfe5f..643ed0964 100644
--- a/src/python/library/tritonclient/utils/_dlpack.py
+++ b/src/python/library/tritonclient/utils/_dlpack.py
@@ -227,6 +227,9 @@ def is_contiguous_data(
     calculated_stride = 1
     # iterate stride in reverse order [ndim-1, -1)
     for i in reversed(range(ndim)):
+        # don't check stride when shape is 1
+        if shape[i] == 1:
+            continue
         if stride[i] != calculated_stride:
             return False
         calculated_stride *= shape[i]
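
Note on why skipping size-1 dimensions is correct: in a dense row-major layout, a dimension of extent 1 is never stepped over, so the stride reported for it cannot affect element addressing, and producers are free to report any value there. PyTorch's own contiguity check likewise ignores size-1 dimensions, and the updated test exercises this case by giving the tensor a leading dimension of 1. Below is a pure-Python sketch of the patched loop, illustrative only: the real is_contiguous_data reads ndim/shape/stride out of the DLPack C struct via ctypes, and the helper name is_contiguous_rowmajor is invented here for the example.

    # Pure-Python sketch of the patched contiguity check (illustrative only).
    import torch

    def is_contiguous_rowmajor(ndim, shape, stride):
        calculated_stride = 1
        for i in reversed(range(ndim)):
            # A dimension of extent 1 is never stepped over, so its stride
            # cannot affect addressing -- skip it instead of comparing.
            if shape[i] == 1:
                continue
            if stride[i] != calculated_stride:
                return False
            calculated_stride *= shape[i]
        return True

    # A size-1 dimension may legitimately carry a "non-standard" stride.
    # Transposing a (1, 4) tensor yields shape (4, 1) with stride (1, 4):
    t = torch.ones(1, 4).transpose(0, 1)
    assert t.is_contiguous()  # PyTorch also ignores size-1 dims here
    assert is_contiguous_rowmajor(t.dim(), t.shape, t.stride())
    # Without the shape[i] == 1 skip, the loop would compare stride[1] == 4
    # against an expected value of 1 and wrongly report non-contiguous data.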