diff --git a/ci/L0_backend_vllm/accuracy_test/accuracy_test.py b/ci/L0_backend_vllm/accuracy_test/accuracy_test.py
index 16b05fe2..59c669b6 100644
--- a/ci/L0_backend_vllm/accuracy_test/accuracy_test.py
+++ b/ci/L0_backend_vllm/accuracy_test/accuracy_test.py
@@ -104,10 +104,10 @@ def test_vllm_model(self):
 
         for i in range(number_of_vllm_reqs):
             result = user_data._completed_requests.get()
-            self.assertIsNot(type(result), InferenceServerException)
+            self.assertIsNot(type(result), InferenceServerException, str(result))
 
             output = result.as_numpy("text_output")
-            self.assertIsNotNone(output)
+            self.assertIsNotNone(output, "`text_output` should not be None")
 
             triton_vllm_output.extend(output)
 
diff --git a/ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py b/ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py
index aec8c2eb..f723af8f 100644
--- a/ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py
+++ b/ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py
@@ -60,11 +60,11 @@ async def request_iterator():
 
         async for response in response_iterator:
             result, error = response
-            self.assertIsNone(error)
-            self.assertIsNotNone(result)
+            self.assertIsNone(error, str(error))
+            self.assertIsNotNone(result, str(result))
 
             output = result.as_numpy("text_output")
-            self.assertIsNotNone(output)
+            self.assertIsNotNone(output, "`text_output` should not be None")
 
 
 if __name__ == "__main__":
diff --git a/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py b/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py
index 2b670662..cd953746 100644
--- a/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py
+++ b/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py
@@ -107,10 +107,10 @@ def _test_vllm_model(self, send_parameters_as_tensor):
 
             result = user_data._completed_requests.get()
             if type(result) is InferenceServerException:
                 print(result.message())
-            self.assertIsNot(type(result), InferenceServerException)
+            self.assertIsNot(type(result), InferenceServerException, str(result))
 
             output = result.as_numpy("text_output")
-            self.assertIsNotNone(output)
+            self.assertIsNotNone(output, "`text_output` should not be None")
 
         self.triton_client.stop_stream()
diff --git a/samples/model_repository/vllm_model/config.pbtxt b/samples/model_repository/vllm_model/config.pbtxt
index 169f3815..b5a6c1ae 100644
--- a/samples/model_repository/vllm_model/config.pbtxt
+++ b/samples/model_repository/vllm_model/config.pbtxt
@@ -28,46 +28,6 @@
 backend: "vllm"
 
-# Disabling batching in Triton, let vLLM handle the batching on its own.
-max_batch_size: 0
-
-# We need to use decoupled transaction policy for saturating
-# vLLM engine for max throughtput.
-# TODO [DLIS:5233]: Allow asynchronous execution to lift this
-# restriction for cases there is exactly a single response to
-# a single request.
-model_transaction_policy {
-  decoupled: True
-}
-# Note: The vLLM backend uses the following input and output names.
-# Any change here needs to also be made in model.py
-input [
-  {
-    name: "text_input"
-    data_type: TYPE_STRING
-    dims: [ 1 ]
-  },
-  {
-    name: "stream"
-    data_type: TYPE_BOOL
-    dims: [ 1 ]
-  },
-  {
-    name: "sampling_parameters"
-    data_type: TYPE_STRING
-    dims: [ 1 ]
-    optional: true
-  }
-]
-
-output [
-  {
-    name: "text_output"
-    data_type: TYPE_STRING
-    dims: [ -1 ]
-  }
-]
-
 # The usage of device is deferred to the vLLM engine
 instance_group [
   {
diff --git a/src/model.py b/src/model.py
index acdf33cd..f21f2a35 100644
--- a/src/model.py
+++ b/src/model.py
@@ -41,6 +41,54 @@ class TritonPythonModel:
+    @staticmethod
+    def auto_complete_config(auto_complete_model_config):
+        inputs = [
+            {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
+            {
+                "name": "stream",
+                "data_type": "TYPE_BOOL",
+                "dims": [1],
+                "optional": True,
+            },
+            {
+                "name": "sampling_parameters",
+                "data_type": "TYPE_STRING",
+                "dims": [1],
+                "optional": True,
+            },
+        ]
+        outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
+
+        # Store the model configuration as a dictionary.
+        config = auto_complete_model_config.as_dict()
+        input_names = []
+        output_names = []
+        for input in config["input"]:
+            input_names.append(input["name"])
+        for output in config["output"]:
+            output_names.append(output["name"])
+
+        # Add only the missing inputs and outputs to the model configuration.
+        for input in inputs:
+            if input["name"] not in input_names:
+                auto_complete_model_config.add_input(input)
+        for output in outputs:
+            if output["name"] not in output_names:
+                auto_complete_model_config.add_output(output)
+
+        # We need the decoupled transaction policy to saturate the
+        # vLLM engine for maximum throughput.
+        # TODO [DLIS:5233]: Allow asynchronous execution to lift this
+        # restriction for cases where there is exactly a single response
+        # to a single request.
+        auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))
+
+        # Disable batching in Triton and let vLLM handle batching on its own.
+        auto_complete_model_config.set_max_batch_size(0)
+
+        return auto_complete_model_config
+
     def initialize(self, args):
         self.logger = pb_utils.Logger
         self.model_config = json.loads(args["model_config"])
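With `auto_complete_config` in place, the sample `config.pbtxt` no longer has to declare the inputs, outputs, `max_batch_size`, or the decoupled transaction policy; Triton fills those in when `--strict-model-config=false` (auto-complete) is in effect. As a minimal sketch of how a client could exercise the auto-completed model over the streaming gRPC API, mirroring the CI tests above: the server URL, model name `vllm_model`, prompt, and sampling parameters below are illustrative assumptions, not part of this diff.

```python
import json
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

# Responses (or errors) are delivered through the stream callback.
responses = queue.Queue()


def callback(responses, result, error):
    responses.put(error if error is not None else result)


# URL and model name are assumptions for this sketch.
client = grpcclient.InferenceServerClient(url="localhost:8001")

inputs = [
    grpcclient.InferInput("text_input", [1], "BYTES"),
    grpcclient.InferInput("stream", [1], "BOOL"),
    grpcclient.InferInput("sampling_parameters", [1], "BYTES"),
]
inputs[0].set_data_from_numpy(np.array(["What is Triton?"], dtype=np.object_))
inputs[1].set_data_from_numpy(np.array([False], dtype=bool))
inputs[2].set_data_from_numpy(
    np.array([json.dumps({"temperature": 0.1, "top_p": 0.95})], dtype=np.object_)
)

client.start_stream(callback=partial(callback, responses))
client.async_stream_infer(model_name="vllm_model", inputs=inputs, request_id="1")

# With stream=False the backend returns a single, complete response per request.
result = responses.get()
client.stop_stream()

if isinstance(result, InferenceServerException):
    raise result
print(result.as_numpy("text_output"))
```

Note that `stream` is now marked `optional: True` in the auto-completed config, so a client may omit it entirely and rely on the backend's default behavior.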