From 5c034116d48f7a0037dde8e16cda429b54ca5605 Mon Sep 17 00:00:00 2001 From: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com> Date: Thu, 29 Feb 2024 17:44:52 -0800 Subject: [PATCH] Add `exclude_input_in_output` option to vllm backend (#35) --- .../enabled_stream/enabled_stream_test.py | 98 +++++++++++++-- ci/L0_backend_vllm/enabled_stream/test.sh | 2 +- ci/L0_backend_vllm/vllm_backend/test.sh | 2 +- .../vllm_backend/vllm_backend_test.py | 117 ++++++++++++++++-- ci/common/test_util.py | 5 + samples/client.py | 46 +++++-- src/model.py | 68 ++++++++-- 7 files changed, 300 insertions(+), 38 deletions(-) diff --git a/ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py b/ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py index f723af8f..5d82333e 100644 --- a/ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py +++ b/ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py @@ -34,37 +34,113 @@ sys.path.append("../../common") from test_util import AsyncTestResultCollector, create_vllm_request +PROMPTS = ["The most dangerous animal is"] +SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"} + class VLLMTritonStreamTest(AsyncTestResultCollector): - async def test_vllm_model_enabled_stream(self): + async def _test_vllm_model( + self, + prompts=PROMPTS, + sampling_parameters=SAMPLING_PARAMETERS, + stream=True, + exclude_input_in_output=None, + expected_output=None, + expect_error=False, + ): async with grpcclient.InferenceServerClient( url="localhost:8001" ) as triton_client: model_name = "vllm_opt" - stream = True - prompts = [ - "The most dangerous animal is", - "The future of AI is", - ] - sampling_parameters = {"temperature": "0", "top_p": "1"} async def request_iterator(): for i, prompt in enumerate(prompts): yield create_vllm_request( - prompt, i, stream, sampling_parameters, model_name + prompt, + i, + stream, + sampling_parameters, + model_name, + exclude_input_in_output=exclude_input_in_output, ) response_iterator = triton_client.stream_infer( inputs_iterator=request_iterator() ) - + final_response = [] async for response in response_iterator: result, error = response - self.assertIsNone(error, str(error)) - self.assertIsNotNone(result, str(result)) + if expect_error: + self.assertIsInstance(error, InferenceServerException) + self.assertEquals( + error.message(), + "Error generating stream: When streaming, `exclude_input_in_output` = False is not allowed.", + error, + ) + return + self.assertIsNone(error, error) + self.assertIsNotNone(result, result) output = result.as_numpy("text_output") self.assertIsNotNone(output, "`text_output` should not be None") + final_response.append(str(output[0], encoding="utf-8")) + if expected_output is not None: + self.assertEqual( + final_response, + expected_output, + 'Expected to receive the following response: "{}",\ + but received "{}".'.format( + expected_output, final_response + ), + ) + + async def test_vllm_model_enabled_stream(self): + """ + Verifying that request with multiple prompts runs successfully. + """ + prompts = [ + "The most dangerous animal is", + "The future of AI is", + ] + + await self._test_vllm_model(prompts=prompts) + + async def test_vllm_model_enabled_stream_exclude_input_in_output_default(self): + """ + Verifying that streaming request returns only generated diffs, which + is default behaviour for `stream=True`. 
+ """ + expected_output = [ + " the", + " one", + " that", + " is", + " most", + " likely", + " to", + " be", + " killed", + " by", + " a", + " car", + ".", + "\n", + "I", + "'m", + ] + await self._test_vllm_model(expected_output=expected_output) + + async def test_vllm_model_enabled_stream_exclude_input_in_output_false(self): + """ + Verifying that streaming request returns only generated diffs even if + `exclude_input_in_output` is set to False explicitly. + """ + expected_output = "Error generating stream: When streaming, `exclude_input_in_output` = False is not allowed." + await self._test_vllm_model( + exclude_input_in_output=False, + expected_output=expected_output, + expect_error=True, + ) if __name__ == "__main__": diff --git a/ci/L0_backend_vllm/enabled_stream/test.sh b/ci/L0_backend_vllm/enabled_stream/test.sh index 4080699f..4fab0618 100755 --- a/ci/L0_backend_vllm/enabled_stream/test.sh +++ b/ci/L0_backend_vllm/enabled_stream/test.sh @@ -36,7 +36,7 @@ CLIENT_LOG="./enabled_stream_client.log" TEST_RESULT_FILE='test_results.txt' CLIENT_PY="./enabled_stream_test.py" SAMPLE_MODELS_REPO="../../../samples/model_repository" -EXPECTED_NUM_TESTS=1 +EXPECTED_NUM_TESTS=3 rm -rf models && mkdir -p models cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt diff --git a/ci/L0_backend_vllm/vllm_backend/test.sh b/ci/L0_backend_vllm/vllm_backend/test.sh index 716ee6ff..81a8b41f 100755 --- a/ci/L0_backend_vllm/vllm_backend/test.sh +++ b/ci/L0_backend_vllm/vllm_backend/test.sh @@ -36,7 +36,7 @@ CLIENT_LOG="./vllm_backend_client.log" TEST_RESULT_FILE='test_results.txt' CLIENT_PY="./vllm_backend_test.py" SAMPLE_MODELS_REPO="../../../samples/model_repository" -EXPECTED_NUM_TESTS=3 +EXPECTED_NUM_TESTS=6 # Helpers ======================================= function assert_curl_success { diff --git a/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py b/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py index cd953746..5fe48b1c 100644 --- a/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py +++ b/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py @@ -35,6 +35,13 @@ sys.path.append("../../common") from test_util import TestResultCollector, UserData, callback, create_vllm_request +PROMPTS = [ + "The most dangerous animal is", + "The capital of France is", + "The future of AI is", +] +SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"} + class VLLMTritonBackendTest(TestResultCollector): def setUp(self): @@ -60,8 +67,18 @@ def test_vllm_triton_backend(self): self.assertFalse(self.triton_client.is_model_ready(self.python_model_name)) # Test vllm model and unload vllm model - self._test_vllm_model(send_parameters_as_tensor=True) - self._test_vllm_model(send_parameters_as_tensor=False) + self._test_vllm_model( + prompts=PROMPTS, + sampling_parameters=SAMPLING_PARAMETERS, + stream=False, + send_parameters_as_tensor=True, + ) + self._test_vllm_model( + prompts=PROMPTS, + sampling_parameters=SAMPLING_PARAMETERS, + stream=False, + send_parameters_as_tensor=False, + ) self.triton_client.unload_model(self.vllm_model_name) def test_model_with_invalid_attributes(self): @@ -74,16 +91,90 @@ def test_vllm_invalid_model_name(self): with self.assertRaises(InferenceServerException): self.triton_client.load_model(model_name) - def _test_vllm_model(self, send_parameters_as_tensor): - user_data = UserData() - stream = False + def test_exclude_input_in_output_default(self): + """ + Verifying default behavior for `exclude_input_in_output` + in non-streaming mode. + Expected result: prompt is returned with diffs. 
+ """ + self.triton_client.load_model(self.vllm_model_name) prompts = [ - "The most dangerous animal is", "The capital of France is", - "The future of AI is", ] - number_of_vllm_reqs = len(prompts) + expected_output = [ + b"The capital of France is the capital of the French Republic.\n\nThe capital of France is the capital" + ] + sampling_parameters = {"temperature": "0", "top_p": "1"} + self._test_vllm_model( + prompts, + sampling_parameters, + stream=False, + send_parameters_as_tensor=True, + expected_output=expected_output, + ) + self.triton_client.unload_model(self.vllm_model_name) + + def test_exclude_input_in_output_false(self): + """ + Verifying behavior for `exclude_input_in_output` = False + in non-streaming mode. + Expected result: prompt is returned with diffs. + """ + self.triton_client.load_model(self.vllm_model_name) + # Test vllm model and unload vllm model + prompts = [ + "The capital of France is", + ] + expected_output = [ + b"The capital of France is the capital of the French Republic.\n\nThe capital of France is the capital" + ] + sampling_parameters = {"temperature": "0", "top_p": "1"} + self._test_vllm_model( + prompts, + sampling_parameters, + stream=False, + send_parameters_as_tensor=True, + exclude_input_in_output=False, + expected_output=expected_output, + ) + self.triton_client.unload_model(self.vllm_model_name) + + def test_exclude_input_in_output_true(self): + """ + Verifying behavior for `exclude_input_in_output` = True + in non-streaming mode. + Expected result: only diffs are returned. + """ + self.triton_client.load_model(self.vllm_model_name) + # Test vllm model and unload vllm model + prompts = [ + "The capital of France is", + ] + expected_output = [ + b" the capital of the French Republic.\n\nThe capital of France is the capital" + ] sampling_parameters = {"temperature": "0", "top_p": "1"} + self._test_vllm_model( + prompts, + sampling_parameters, + stream=False, + send_parameters_as_tensor=True, + exclude_input_in_output=True, + expected_output=expected_output, + ) + self.triton_client.unload_model(self.vllm_model_name) + + def _test_vllm_model( + self, + prompts, + sampling_parameters, + stream, + send_parameters_as_tensor, + exclude_input_in_output=None, + expected_output=None, + ): + user_data = UserData() + number_of_vllm_reqs = len(prompts) self.triton_client.start_stream(callback=partial(callback, user_data)) for i in range(number_of_vllm_reqs): @@ -94,6 +185,7 @@ def _test_vllm_model(self, send_parameters_as_tensor): sampling_parameters, self.vllm_model_name, send_parameters_as_tensor, + exclude_input_in_output=exclude_input_in_output, ) self.triton_client.async_stream_infer( model_name=self.vllm_model_name, @@ -111,6 +203,15 @@ def _test_vllm_model(self, send_parameters_as_tensor): output = result.as_numpy("text_output") self.assertIsNotNone(output, "`text_output` should not be None") + if expected_output is not None: + self.assertEqual( + output, + expected_output[i], + 'Actual and expected outputs do not match.\n \ + Expected "{}" \n Actual:"{}"'.format( + output, expected_output[i] + ), + ) self.triton_client.stop_stream() diff --git a/ci/common/test_util.py b/ci/common/test_util.py index 23ea7784..61b2cfbe 100755 --- a/ci/common/test_util.py +++ b/ci/common/test_util.py @@ -95,6 +95,7 @@ def create_vllm_request( sampling_parameters, model_name, send_parameters_as_tensor=True, + exclude_input_in_output=None, ): inputs = [] @@ -111,6 +112,10 @@ def create_vllm_request( inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES")) 
inputs[-1].set_data_from_numpy(sampling_parameters_data) + if exclude_input_in_output is not None: + inputs.append(grpcclient.InferInput("exclude_input_in_output", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([exclude_input_in_output], dtype=bool)) + outputs = [grpcclient.InferRequestedOutput("text_output")] return { diff --git a/samples/client.py b/samples/client.py index 06bf0c3e..7940a5ca 100755 --- a/samples/client.py +++ b/samples/client.py @@ -45,7 +45,9 @@ def __init__(self, flags: argparse.Namespace): self._loop = asyncio.get_event_loop() self._results_dict = {} - async def async_request_iterator(self, prompts, sampling_parameters): + async def async_request_iterator( + self, prompts, sampling_parameters, exclude_input_in_output + ): try: for iter in range(self._flags.iterations): for i, prompt in enumerate(prompts): @@ -56,16 +58,17 @@ async def async_request_iterator(self, prompts, sampling_parameters): self._flags.streaming_mode, prompt_id, sampling_parameters, + exclude_input_in_output, ) except Exception as error: print(f"Caught an error in the request iterator: {error}") - async def stream_infer(self, prompts, sampling_parameters): + async def stream_infer(self, prompts, sampling_parameters, exclude_input_in_output): try: # Start streaming response_iterator = self._client.stream_infer( inputs_iterator=self.async_request_iterator( - prompts, sampling_parameters + prompts, sampling_parameters, exclude_input_in_output ), stream_timeout=self._flags.stream_timeout, ) @@ -75,33 +78,43 @@ async def stream_infer(self, prompts, sampling_parameters): print(error) sys.exit(1) - async def process_stream(self, prompts, sampling_parameters): + async def process_stream( + self, prompts, sampling_parameters, exclude_input_in_output + ): # Clear results in between process_stream calls self.results_dict = [] - + success = True # Read response from the stream - async for response in self.stream_infer(prompts, sampling_parameters): + async for response in self.stream_infer( + prompts, sampling_parameters, exclude_input_in_output + ): result, error = response if error: print(f"Encountered error while processing: {error}") + success = False else: output = result.as_numpy("text_output") for i in output: self._results_dict[result.get_response().id].append(i) + return success async def run(self): + exclude_input_in_output = self._flags.exclude_inputs_in_outputs sampling_parameters = {"temperature": "0.1", "top_p": "0.95"} with open(self._flags.input_prompts, "r") as file: print(f"Loading inputs from `{self._flags.input_prompts}`...") prompts = file.readlines() - await self.process_stream(prompts, sampling_parameters) + success = await self.process_stream( + prompts, sampling_parameters, exclude_input_in_output + ) with open(self._flags.results_file, "w") as file: for id in self._results_dict.keys(): for result in self._results_dict[id]: file.write(result.decode("utf-8")) - file.write("\n") + + file.write("\n") file.write("\n=========\n\n") print(f"Storing results into `{self._flags.results_file}`...") @@ -109,8 +122,10 @@ async def run(self): with open(self._flags.results_file, "r") as file: print(f"\nContents of `{self._flags.results_file}` ===>") print(file.read()) - - print("PASS: vLLM example") + if success: + print("PASS: vLLM example") + else: + print("FAIL: vLLM example") def run_async(self): self._loop.run_until_complete(self.run()) @@ -121,6 +136,7 @@ def create_request( stream, request_id, sampling_parameters, + exclude_input_in_output, send_parameters_as_tensor=True, ): inputs = 
[] @@ -146,6 +162,9 @@ def create_request( inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES")) inputs[-1].set_data_from_numpy(sampling_parameters_data) + inputs.append(grpcclient.InferInput("exclude_input_in_output", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([exclude_input_in_output], dtype=bool)) + # Add requested outputs outputs = [] outputs.append(grpcclient.InferRequestedOutput("text_output")) @@ -230,6 +249,13 @@ def create_request( default=False, help="Enable streaming mode", ) + parser.add_argument( + "--exclude-inputs-in-outputs", + action="store_true", + required=False, + default=False, + help="Exclude prompt from outputs", + ) FLAGS = parser.parse_args() client = LLMClient(FLAGS) diff --git a/src/model.py b/src/model.py index 80f51320..dce634b7 100644 --- a/src/model.py +++ b/src/model.py @@ -57,6 +57,12 @@ def auto_complete_config(auto_complete_model_config): "dims": [1], "optional": True, }, + { + "name": "exclude_input_in_output", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, ] outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}] @@ -206,12 +212,14 @@ def get_sampling_params_dict(self, params_json): return params_dict - def create_response(self, vllm_output): + def create_response(self, vllm_output, prepend_input): """ Parses the output from the vLLM engine into Triton response. """ - prompt = vllm_output.prompt + prompt = "" + if prepend_input: + prompt = vllm_output.prompt text_outputs = [ (prompt + output.text).encode("utf-8") for output in vllm_output.outputs ] @@ -220,6 +228,25 @@ def create_response(self, vllm_output): ) return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) + def create_stream_response(self, vllm_output, previous_outputs_lengths): + """ + Parses the output from the vLLM engine, extracts only newly generated + text and packs it into Triton response. + """ + if previous_outputs_lengths is None: + return self.create_response(vllm_output, prepend_input=False) + + text_outputs = [ + (output.text[prev_output_length:]).encode("utf-8") + for output, prev_output_length in zip( + vllm_output.outputs, previous_outputs_lengths + ) + ] + triton_output_tensor = pb_utils.Tensor( + "text_output", np.asarray(text_outputs, dtype=self.output_dtype) + ) + return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) + async def generate(self, request): """ Forwards single request to LLM engine and returns responses. @@ -238,6 +265,23 @@ async def generate(self, request): stream = stream.as_numpy()[0] else: stream = False + prepend_input = pb_utils.get_input_tensor_by_name( + request, "exclude_input_in_output" + ) + if prepend_input: + # When `exclude_input_in_output` is False, we want to prepend + # input prompt to output, thus prepend_input should be True, + # and vice versa. + prepend_input = not prepend_input.as_numpy()[0] + elif prepend_input is None and stream: + prepend_input = False + else: + prepend_input = True + + if prepend_input and stream: + raise ValueError( + "When streaming, `exclude_input_in_output` = False is not allowed." + ) # Request parameters are not yet supported via # BLS. 
Provide an optional mechanism to receive serialized @@ -255,6 +299,7 @@ async def generate(self, request): sampling_params = SamplingParams(**sampling_params_dict) last_output = None + prev_outputs = None async for output in self.llm_engine.generate( prompt, sampling_params, request_id ): @@ -264,19 +309,28 @@ async def generate(self, request): self.logger.log_info("[vllm] Successfully cancelled the request") break if stream: + prev_outputs_lengths = None + if prev_outputs is not None: + prev_outputs_lengths = [ + len(prev_output.text) + for prev_output in prev_outputs.outputs + ] if output.finished: response_sender.send( - self.create_response(output), + self.create_stream_response(output, prev_outputs_lengths), flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, ) else: - response_sender.send(self.create_response(output)) - else: - last_output = output + response_sender.send( + self.create_stream_response(output, prev_outputs_lengths) + ) + prev_outputs = output + + last_output = output if not stream: response_sender.send( - self.create_response(last_output), + self.create_response(last_output, prepend_input), flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, )
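
Illustrative note (not part of the diff): the streaming path above now returns only newly generated text. `create_stream_response` keeps, per request, the length of each output's accumulated text at the previous engine event (`previous_outputs_lengths`) and slices that prefix off the current text. A minimal standalone sketch of that bookkeeping, with plain strings standing in for the accumulated `output.text` values:

def stream_deltas(events):
    # Each event is the list of accumulated texts (one per candidate output) so far.
    prev_lengths = None
    for outputs in events:
        if prev_lengths is None:
            # First event: nothing to subtract (mirrors create_response with prepend_input=False).
            deltas = list(outputs)
        else:
            deltas = [text[prev:] for text, prev in zip(outputs, prev_lengths)]
        prev_lengths = [len(text) for text in outputs]
        yield deltas


events = [["The"], ["The capital"], ["The capital of"]]
print(list(stream_deltas(events)))
# -> [['The'], [' capital'], [' of']]: only diffs reach the client, which is what
#    enabled_stream_test.py asserts.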
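
Usage sketch (illustrative, not part of the diff): passing the new optional `exclude_input_in_output` BOOL input from a gRPC client. The endpoint `localhost:8001`, model name `vllm_opt`, and prompt are assumptions carried over from the CI tests above; because the model is decoupled, the streaming gRPC API is used even though `stream` is set to False.

import json
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


def callback(result_queue, result, error):
    # Queue either the error or the successful result for the main thread.
    result_queue.put(error if error is not None else result)


results = queue.Queue()
client = grpcclient.InferenceServerClient("localhost:8001")  # assumed endpoint

text = np.array(["The capital of France is".encode("utf-8")], dtype=np.object_)
stream_flag = np.array([False], dtype=bool)
exclude = np.array([True], dtype=bool)  # the new optional input added by this patch
params = np.array(
    [json.dumps({"temperature": "0", "top_p": "1"}).encode("utf-8")], dtype=np.object_
)

inputs = [
    grpcclient.InferInput("text_input", [1], "BYTES"),
    grpcclient.InferInput("stream", [1], "BOOL"),
    grpcclient.InferInput("exclude_input_in_output", [1], "BOOL"),
    grpcclient.InferInput("sampling_parameters", [1], "BYTES"),
]
for inp, data in zip(inputs, [text, stream_flag, exclude, params]):
    inp.set_data_from_numpy(data)

client.start_stream(callback=partial(callback, results))
client.async_stream_infer(
    model_name="vllm_opt",  # assumed: the model name used by the CI tests
    request_id="0",
    inputs=inputs,
    outputs=[grpcclient.InferRequestedOutput("text_output")],
)
response = results.get()
client.stop_stream()
client.close()

if isinstance(response, Exception):
    raise response
# With exclude_input_in_output=True only the completion is returned, no prompt echo.
print(response.as_numpy("text_output")[0].decode("utf-8"))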
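
With the sample client, the same behaviour can be exercised through the new command-line flag, e.g. `python3 samples/client.py --exclude-inputs-in-outputs`, combined with the sample's usual prompt and results-file arguments.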