
Commit

updated naming and example
nnshah1 committed Jan 9, 2024
1 parent: 8d2c3f8 · commit: d1394a5
Showing 3 changed files with 37 additions and 6 deletions.
1 change: 1 addition & 0 deletions python/tritonserver/__init__.py
@@ -50,6 +50,7 @@
 from tritonserver._api._allocators import (
     default_memory_allocators as default_memory_allocators,
 )
+from tritonserver._api._dlpack import DLDeviceType as DLDeviceType
 from tritonserver._api._model import Model as Model
 from tritonserver._api._model import ModelBatchFlag as ModelBatchFlag
 from tritonserver._api._model import ModelTxnPropertyFlag as ModelTxnPropertyFlag
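With this re-export in place, DLDeviceType can be imported straight from the package root instead of the private _api._dlpack module. A minimal sketch of the new import path (assuming the package is installed; kDLCUDA is the member used in the to_device docstring below):

    # Hypothetical check of the new top-level export: build a DLPack-style
    # device tuple without reaching into a private module.
    import tritonserver

    device = (tritonserver.DLDeviceType.kDLCUDA, 0)  # (device type, device id)
    print(device)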
11 changes: 6 additions & 5 deletions python/tritonserver/_api/_response.py
@@ -397,7 +397,7 @@ class InferenceResponse:
     final: bool = False
 
     @staticmethod
-    def _from_TRITONSERVER_InferenceResponse(
+    def _from_tritonserver_inference_response(
         model: _model.Model,
         server: TRITONSERVER_Server,
         request: TRITONSERVER_InferenceRequest,
@@ -435,10 +435,10 @@ def _from_TRITONSERVER_InferenceResponse(
                 name,
                 data_type,
                 shape,
-                data_ptr,
-                byte_size,
-                memory_type,
-                memory_type_id,
+                _data_ptr,
+                _byte_size,
+                _memory_type,
+                _memory_type_id,
                 memory_buffer,
             ) = response.output(output_index)
             tensor = Tensor(data_type, shape, memory_buffer)
@@ -450,6 +450,7 @@ def _from_TRITONSERVER_InferenceResponse(
             error.args += tuple(asdict(result).items())
             result.error = error
 
+        # TODO: support classification
         # values["classification_label"] = response.output_classification_label()
 
         return result
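Two renames drive this file: the private constructor helper moves to snake_case, and the tuple elements that are unpacked but never read gain a leading underscore. A self-contained sketch of that unpacking idiom, with made-up values standing in for response.output(output_index):

    # Hypothetical illustration of underscore-prefixed unpacking: only the
    # names without a leading underscore are used after this statement.
    output = ("logits", "FP32", (1, 3), 140234, 12, 0, 0, bytearray(12))
    (
        name,
        data_type,
        shape,
        _data_ptr,        # intentionally unused
        _byte_size,       # intentionally unused
        _memory_type,     # intentionally unused
        _memory_type_id,  # intentionally unused
        memory_buffer,
    ) = output
    print(name, data_type, shape, len(memory_buffer))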
31 changes: 30 additions & 1 deletion python/tritonserver/_api/_tensor.py
@@ -143,7 +143,8 @@ def __dlpack__(self, *, stream=None):
         Any
             A DLPack-compatible object representing the tensor.
         """
-
+        # TODO: Handle the stream argument correctly
+        #
         # if not (stream is None or (isinstance(stream, int) and stream == 0)):
         #     raise UnsupportedError(
         #         f"DLPack stream synchronization on {stream} not currently supported"
@@ -364,6 +365,34 @@ def to_device(self, device: DeviceOrMemoryType) -> Tensor:
         -------
         Tensor
             The tensor moved to the specified device.
+
+        Examples
+        --------
+        tensor_cpu = tritonserver.Tensor.from_dlpack(numpy.array([0,1,2], dtype=numpy.float16))
+
+        # Different ways to specify the device
+        tensor_gpu = tensor_cpu.to_device(MemoryType.GPU)
+        tensor_gpu = tensor_cpu.to_device((MemoryType.GPU,0))
+        tensor_gpu = tensor_cpu.to_device((DLDeviceType.kDLCUDA,0))
+        tensor_gpu = tensor_cpu.to_device("gpu")
+        tensor_gpu = tensor_cpu.to_device("gpu:0")
+
+        ndarray_gpu = cupy.from_dlpack(tensor_gpu)
+        ndarray_gpu[0] = ndarray_gpu.mean()
+        tensor_cpu = tensor_gpu.to_device("cpu")
+        ndarray_cpu = numpy.from_dlpack(tensor_cpu)
+        assert ndarray_cpu[0] == ndarray_gpu[0]
         """
         memory_type, memory_type_id = parse_device_or_memory_type(device)
         if self.memory_type == memory_type and self.memory_type_id == memory_type_id:
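The docstring enumerates several equivalent device specifiers, all of which funnel through parse_device_or_memory_type. The real parser ships with the package; the sketch below is only a guess at the normalization it performs, reducing every accepted form to a (MemoryType, device_id) pair (kDLCPU/kDLCUDA and the string grammar are assumptions based on the examples above):

    from tritonserver import DLDeviceType, MemoryType

    # Hypothetical normalizer for the device forms shown in the docstring.
    def parse_device_sketch(device):
        names = {"cpu": MemoryType.CPU, "gpu": MemoryType.GPU}
        dlpack = {DLDeviceType.kDLCPU: MemoryType.CPU, DLDeviceType.kDLCUDA: MemoryType.GPU}
        if isinstance(device, str):  # "gpu" or "gpu:0"
            name, _, index = device.partition(":")
            return names[name.lower()], int(index or 0)
        if isinstance(device, tuple):  # (MemoryType.GPU, 0) or (DLDeviceType.kDLCUDA, 0)
            kind, index = device
            return (kind, index) if isinstance(kind, MemoryType) else (dlpack[kind], index)
        return device, 0  # bare MemoryType, e.g. MemoryType.GPU

    assert parse_device_sketch("gpu:0") == parse_device_sketch((MemoryType.GPU, 0))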
