Commit

updated with basic tests
nnshah1 committed Jan 9, 2024
1 parent d1394a5 commit e7d441d
Showing 4 changed files with 33 additions and 214 deletions.
230 changes: 21 additions & 209 deletions python/test/test_api.py
@@ -38,11 +38,6 @@
except ImportError:
cupy = None

try:
import torch
except ImportError:
torch = None


class ServerTests(unittest.TestCase):
server_options = tritonserver.Options(
@@ -75,212 +70,29 @@ def test_ready(self):
self.assertTrue(server.ready())


class ModelTests(unittest.TestCase):
pass


class InferenceTests(unittest.TestCase):
pass


class TensorTests(unittest.TestCase):
pass


# class TrtionServerAPITest(unittest.TestCase):
# def test_not_started(self):
# server = tritonserver.Server()
# with self.assertRaises(tritonserver.InvalidArgumentError):
# server.ready()

# @pytest.mark.skipif(cupy is None, reason="Skipping gpu memory, cpupy not installed")
# def test_gpu_memory(self):
# import cupy

# server = tritonserver.Server(
# model_repository="/workspace/models", exit_timeout=5
# )

# server.start(blocking=True)

# test = server.model("test")
# fp16_input = cupy.array([[5], [6], [7], [8]], dtype=numpy.float16)
# responses = test.infer(inputs={"fp16_input": fp16_input}, request_id="1")

# for response in responses:
# print(response)

# responses = server.models()["test"].infer(
# inputs={"fp16_input": fp16_input}, request_id="1"
# )

# for response in responses:
# print(response)
# try:
# pass
# # server.stop()
# except Exception as error:
# print(error)

# def test_unload(self):
# server = tritonserver.Server(
# model_repository="/workspace/models",
# exit_timeout=5,
# model_control_mode=tritonserver.ModelControlMode.EXPLICIT,
# startup_models=["test"],
# log_verbose=True,
# log_error=True,
# log_info=True,
# )
# server.start(blocking=True)

# model = server.models["test"]

# responses = model.infer(
# inputs={"fp16_input": numpy.array([[0.5]], dtype=numpy.float16)}
# )

# print(list(responses)[0])

# print(model.is_ready())

# model_versions = [key for key in server.models.keys() if key[0] == model.name]

# server.unload_model(model.name, blocking=True)

# server.unload_model("foo", blocking=True)

# # del responses

# while True:
# if [
# key
# for key in model_versions
# if (
# server.models[key].state not in server.models[key].state is not None
# and server.models[key].state != "UNAVAILABLE"
# )
# ]:
# print(list(server.models.items()))
# time.sleep(5)
# continue
# break
# print(server.models[model.name])
# print(list(server.models.items()))

# print(model.is_ready())
# server.stop()

# def test_inference(self):
# server = tritonserver.Server(
# model_repository="/workspace/models",
# exit_timeout=5
# # log_verbose=True,
# )
# # log_error=True,
# server.start()
# while not server.ready():
# pass

# response_queue = queue.SimpleQueue()

# test = server.get_model("test")
# test_2 = server.get_model("test_2")

# inputs = {
# "text_input": numpy.array(["hello"], dtype=numpy.object_),
# "fp16_input": numpy.array([["1"]], dtype=numpy.float16),
# }

# responses_1 = test.infer(
# inputs=inputs, request_id="1", response_queue=response_queue
# )
# responses_2 = test.infer(
# inputs=inputs, request_id="2", response_queue=response_queue
# )

# responses_3 = test_2.infer(inputs=inputs)

# for response in responses_3:
# print(response)

# count = 0
# while count < 2:
# response = response_queue.get()
# count += 1
# print(response, count)
# print(response.outputs["text_output"])
# print(bytes(response.outputs["text_output"][0]))
# print(type(response.outputs["text_output"][0]))
# print(response.outputs["fp16_output"])
# print(type(response.outputs["fp16_output"][0]))

# # for response in test.infer(inputs=inputs):
# # print(response.outputs["text_output"])
# # print(response.outputs["fp16_output"])

# print(test.statistics())
# print(test_2.statistics())

# # print(server.metrics())

# try:
# # pass
# server.stop()
# except Exception as error:
# print(error)


# class AsyncInferenceTest(unittest.IsolatedAsyncioTestCase):
# async def test_async_inference(self):
# server = tritonserver.Server(
# model_repository=["/workspace/models"],
# exit_timeout=30
# # log_verbose=True,
# # log_error=True)
# )
# server.start()
# while not server.is_ready():
# pass

# test = server.models["test"]

# inputs = {
# "text_input": numpy.array(["hello"], dtype=numpy.object_),
# "fp16_input": numpy.array([["1"]], dtype=numpy.float16),
# }

# response_queue = asyncio.Queue()
# responses = test.async_infer(
# inputs=inputs, response_queue=response_queue, request_id="1"
# )
# responses_2 = test.async_infer(
# inputs=inputs, response_queue=response_queue, request_id="2"
# )
# responses_3 = test.async_infer(
# inputs=inputs, response_queue=response_queue, request_id="3"
# )
server_options = tritonserver.Options(
server_id="TestServer",
model_repository="test_api_models",
log_verbose=1,
exit_on_error=False,
exit_timeout=5,
)

# print("here cancelling!", flush=True)
# responses.cancel()
# print("here cancelling!", flush=True)
def test_basic_inference(self):
server = tritonserver.Server(InferenceTests.server_options).start(
wait_until_ready=True
)

# async for response in responses:
# print("async")
# print(response.outputs["text_output"])
# print(response.outputs["fp16_output"])
# print(response.request_id)
self.assertTrue(server.ready())

# count = 0
# while count < 3:
# response = await response_queue.get()
# print(response, count)
# count += 1
fp16_input = numpy.array([[5]], dtype=numpy.float16)

# print("calling stop!")
# try:
# # pass
# server.stop()
# except Exception as error:
# print(error)
# print("stopped!", flush=True)
for response in server.model("test").infer(
inputs={"fp16_input": fp16_input},
output_memory_type="cpu",
raise_on_error=True,
):
fp16_output = numpy.from_dlpack(response.outputs["fp16_output"])
self.assertEqual(fp16_input, fp16_output)
server.stop()
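
For reference, the flow the new test_basic_inference exercises can be written as a standalone script. This is a minimal sketch, not a documented setup: it assumes a local model repository at "test_api_models" containing an identity-style model named "test" that echoes "fp16_input" back as "fp16_output", exactly as the test above implies.

# Standalone sketch of the flow exercised by test_basic_inference.
# Assumption: "test_api_models" holds a model "test" that echoes "fp16_input"
# back as "fp16_output".
import numpy
import tritonserver

options = tritonserver.Options(
    server_id="ExampleServer",
    model_repository="test_api_models",
    exit_on_error=False,
    exit_timeout=5,
)
server = tritonserver.Server(options).start(wait_until_ready=True)

fp16_input = numpy.array([[5]], dtype=numpy.float16)
for response in server.model("test").infer(
    inputs={"fp16_input": fp16_input},
    output_memory_type="cpu",
    raise_on_error=True,
):
    # Output tensors support DLPack, so numpy can view them without a copy.
    print(numpy.from_dlpack(response.outputs["fp16_output"]))

server.stop()
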
7 changes: 5 additions & 2 deletions python/tritonserver/_api/_model.py
@@ -26,11 +26,14 @@

"""Class for interacting with Triton Inference Server Models"""

from __future__ import annotations

import asyncio
import json
import queue
from typing import Any, Optional

from tritonserver._api._allocators import ResponseAllocator
from tritonserver._api._request import InferenceRequest
from tritonserver._api._response import AsyncResponseIterator, ResponseIterator
from tritonserver._c.triton_bindings import InvalidArgumentError
@@ -194,7 +197,7 @@ def async_infer(
raise_on_error,
)

response_allocator = _datautils.ResponseAllocator(
response_allocator = ResponseAllocator(
inference_request.output_memory_allocator,
inference_request.output_memory_type,
).create_TRITONSERVER_ResponseAllocator()
@@ -290,7 +293,7 @@ def infer(
inference_request.response_queue,
raise_on_error,
)
response_allocator = _datautils.ResponseAllocator(
response_allocator = ResponseAllocator(
inference_request.output_memory_allocator,
inference_request.output_memory_type,
).create_TRITONSERVER_ResponseAllocator()
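
The functional change in this file is that the allocator class is now imported directly from tritonserver._api._allocators instead of being referenced through _datautils. A minimal sketch of the wiring both infer() and async_infer() perform, assuming the constructor signature shown in the hunks above; the helper name _build_allocator is hypothetical:

# Sketch of the allocator wiring used by infer()/async_infer().
from tritonserver._api._allocators import ResponseAllocator

def _build_allocator(inference_request):
    # Parameterize output placement (e.g. "cpu") and an optional custom memory
    # allocator, then hand Triton the C-level response allocator object.
    return ResponseAllocator(
        inference_request.output_memory_allocator,
        inference_request.output_memory_type,
    ).create_TRITONSERVER_ResponseAllocator()
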
8 changes: 6 additions & 2 deletions python/tritonserver/_api/_response.py
@@ -183,7 +183,7 @@ def _response_callback(self, response, flags, unused):
if self._request is None:
raise InternalError("Response received after final response flag")

response = InferenceResponse._from_TRITONSERVER_InferenceResponse(
response = InferenceResponse._from_tritonserver_inference_response(
self._model, self._server, self._request, response, flags
)
asyncio.run_coroutine_threadsafe(self._queue.put(response), self._loop)
@@ -206,6 +206,8 @@ def _response_callback(self, response, flags, unused):
line_number,
str(e),
)
# catastrophic failure
raise e from None


class ResponseIterator:
@@ -330,7 +332,7 @@ def _response_callback(self, response, flags, unused):
if self._request is None:
raise InternalError("Response received after final response flag")

response = InferenceResponse._from_TRITONSERVER_InferenceResponse(
response = InferenceResponse._from_tritonserver_inference_response(
self._model, self._server, self._request, response, flags
)
self._queue.put(response)
@@ -352,6 +354,8 @@ def _response_callback(self, response, flags, unused):
line_number,
str(e),
)
# catastrophic failure
raise e from None


@dataclass
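
Both response callbacks now re-raise after logging rather than silently swallowing the exception. A self-contained illustration of the "raise e from None" idiom this diff adds; the function, payload, and print call below are hypothetical stand-ins, and only the log-then-re-raise shape comes from the change:

# Illustration of the log-then-re-raise pattern adopted by the callbacks.
def _handle(payload):
    try:
        return int(payload)  # stand-in for converting a Triton response
    except Exception as e:
        print(f"failed to convert response: {e}")  # stand-in for the logger
        # Re-raise the same exception; "from None" suppresses the implicit
        # "During handling of the above exception..." chaining in tracebacks.
        raise e from None

if __name__ == "__main__":
    try:
        _handle("not-a-number")
    except ValueError as err:
        print("caller still sees the error:", err)
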
2 changes: 1 addition & 1 deletion python/tritonserver/_api/_server.py
@@ -572,7 +572,7 @@ def start(
raise InvalidArgumentError("Server already started")

self._server = TRITONSERVER_Server(
self.options._create_TRITONSERVER_ServerOptions()
self.options._create_tritonserver_server_options()
)
start_time = time.time()
while (
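
The change here is only the rename of the options factory to _create_tritonserver_server_options; the surrounding start() logic (reject a double start, build the C-level server, then poll from start_time) is unchanged. For illustration, a readiness-polling loop in the same spirit; the timeout, poll interval, and TimeoutError are assumptions, not the actual implementation:

# Sketch of a wait-until-ready polling loop like the one start() begins here.
import time

def wait_until_ready(server, timeout=30.0, poll_interval=0.1):
    # Poll the server's ready() flag until it flips or the timeout elapses.
    start_time = time.time()
    while time.time() - start_time < timeout:
        if server.ready():
            return
        time.sleep(poll_interval)
    raise TimeoutError(f"server not ready after {timeout} seconds")
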
