diff --git a/src/model.py b/src/model.py
index 3fe7cd1e..c9517208 100644
--- a/src/model.py
+++ b/src/model.py
@@ -25,8 +25,10 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import asyncio
+import gc
 import json
 import os
+import queue
 import threading
 from typing import Dict, List
 
@@ -113,13 +115,19 @@ def initialize(self, args):
         # Counter to keep track of ongoing request counts
         self.ongoing_request_count = 0
 
+        # Starting the response thread. It allows vLLM to keep making progress while
+        # response sender(s) are sending responses to server frontend.
+        self._response_queue = queue.Queue()
+        self._response_thread = threading.Thread(target=self.response_loop)
+        self._response_thread.start()
+
         # Starting asyncio event loop to process the received requests asynchronously.
         self._loop = asyncio.get_event_loop()
-        self._loop_thread = threading.Thread(
+        self._event_thread = threading.Thread(
             target=self.engine_loop, args=(self._loop,)
         )
         self._shutdown_event = asyncio.Event()
-        self._loop_thread.start()
+        self._event_thread.start()
 
     def init_engine(self):
         # Currently, Triton needs to use decoupled policy for asynchronously
@@ -273,6 +281,27 @@ def get_sampling_params_dict(self, params_json):
 
         return params_dict
 
+    def response_loop(self):
+        while True:
+            item = self._response_queue.get()
+            # To signal shutdown a None item will be added to the queue.
+            if item is None:
+                break
+            response_sender, response, response_flag = item
+            del item
+            try:
+                response_sender.send(response, response_flag)
+            except Exception as e:
+                self.logger.log_error(
+                    f"An error occurred while sending a response: {e}"
+                )
+            finally:
+                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
+                    self.ongoing_request_count -= 1
+                    del response_sender
+                    if self.ongoing_request_count == 0:
+                        gc.collect()
+
     def create_response(self, vllm_output, prepend_input):
         """
         Parses the output from the vLLM engine into Triton
@@ -314,6 +343,7 @@ async def generate(self, request):
         """
         response_sender = request.get_response_sender()
        self.ongoing_request_count += 1
+        decrement_ongoing_request_count = True
         try:
             request_id = random_uuid()
             prompt = pb_utils.get_input_tensor_by_name(
@@ -368,9 +398,11 @@ async def generate(self, request):
                 lora_local_path = self.lora_repository[lora_name]
                 lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)
 
-            async for output in self.llm_engine.generate(
-                prompt, sampling_params, request_id, lora_request=lora_request
-            ):
+            response_iterator = await self.llm_engine.add_request(
+                request_id, prompt, sampling_params, lora_request=lora_request
+            )
+
+            async for output in response_iterator:
                 if response_sender.is_cancelled():
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
@@ -383,15 +415,12 @@ async def generate(self, request):
                             len(prev_output.text)
                             for prev_output in prev_outputs.outputs
                         ]
+                    response = self.create_stream_response(output, prev_outputs_lengths)
+                    flags = 0
                     if output.finished:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths),
-                            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-                        )
-                    else:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths)
-                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                    self._response_queue.put_nowait((response_sender, response, flags))
                     prev_outputs = output
 
                 last_output = output
@@ -403,7 +432,7 @@ async def generate(self, request):
                 )
 
         except Exception as e:
-            self.logger.log_info(f"[vllm] Error generating stream: {e}")
+            self.logger.log_error(f"[vllm] Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
             triton_output_tensor = pb_utils.Tensor(
                 "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
@@ -416,7 +445,11 @@ async def generate(self, request):
             )
             raise e
         finally:
-            self.ongoing_request_count -= 1
+            if decrement_ongoing_request_count:
+                self.ongoing_request_count -= 1
+            del response_sender
+            if self.ongoing_request_count == 0:
+                gc.collect()
 
     def verify_loras(self, request):
         # We will check if the requested lora exists here, if not we will send a
@@ -483,6 +516,14 @@ def finalize(self):
         """
         self.logger.log_info("[vllm] Issuing finalize to vllm backend")
         self._shutdown_event.set()
-        if self._loop_thread is not None:
-            self._loop_thread.join()
-            self._loop_thread = None
+
+        # Shutdown the event thread.
+        if self._event_thread is not None:
+            self._event_thread.join()
+            self._event_thread = None
+
+        # Shutdown the response thread.
+        self._response_queue.put(None)
+        if self._response_thread is not None:
+            self._response_thread.join()
+            self._response_thread = None
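
Not part of the patch: the following minimal, standalone sketch illustrates the pattern this diff introduces in initialize()/response_loop()/finalize() -- a dedicated consumer thread draining a queue.Queue of (response_sender, response, flags) tuples so response delivery does not block the engine's event loop, with a None sentinel used for shutdown. ResponseLoop, FINAL_FLAG, and the callable "sender" are invented stand-ins for illustration only; they are not Triton's pb_utils API.

import queue
import threading

# Toy flag value standing in for pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL.
FINAL_FLAG = 1


class ResponseLoop:
    """Delivers queued responses on a dedicated thread (sketch of the patch's pattern)."""

    def __init__(self):
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def submit(self, sender, response, flags=0):
        # Producer side: e.g. called from an async generate() coroutine.
        self._queue.put_nowait((sender, response, flags))

    def _loop(self):
        while True:
            item = self._queue.get()
            if item is None:  # None is the shutdown sentinel, as in finalize() above.
                break
            sender, response, flags = item
            try:
                sender(response, flags)
            except Exception as exc:  # never let a failed send kill the worker thread
                print(f"error while sending a response: {exc}")

    def shutdown(self):
        # Enqueue the sentinel, then wait for queued responses to drain.
        self._queue.put(None)
        self._thread.join()


if __name__ == "__main__":
    loop = ResponseLoop()
    # A "sender" here is just a callable; the real response_sender.send()
    # takes a response plus flags in a similar shape.
    sender = lambda response, flags: print(f"sent {response!r} (flags={flags})")
    loop.submit(sender, "partial output")
    loop.submit(sender, "final output", FINAL_FLAG)
    loop.shutdown()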