Commit

perf: Improve vLLM backend performance by using a separate thread for responses (#46)

Co-authored-by: Jacky <[email protected]>
Tabrizian and kthui authored Jul 26, 2024
1 parent 05c5a8b commit 128abc3
Showing 1 changed file with 59 additions and 18 deletions.
77 changes: 59 additions & 18 deletions src/model.py
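
The change follows a producer-consumer pattern: the asyncio generation loop hands finished responses to a dedicated thread through a queue.Queue, so sending responses back to the Triton frontend no longer stalls the vLLM engine. Below is a minimal, self-contained sketch of that pattern, not the backend's actual code; FakeResponseSender and RESPONSE_FINAL are illustrative stand-ins for Triton's response sender and pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL.

import queue
import threading

RESPONSE_FINAL = 1  # stand-in for pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL

class FakeResponseSender:
    """Illustrative stand-in for Triton's response sender."""

    def send(self, response, flags):
        print(f"sent {response!r} (flags={flags})")

def response_loop(response_queue):
    # Consumer thread: drain (sender, response, flags) tuples until a None sentinel arrives.
    while True:
        item = response_queue.get()
        if item is None:  # shutdown signal, mirrors finalize() putting None on the queue
            break
        sender, response, flags = item
        sender.send(response, flags)

if __name__ == "__main__":
    response_queue = queue.Queue()
    worker = threading.Thread(target=response_loop, args=(response_queue,))
    worker.start()

    # Producer side: the generation loop enqueues responses instead of sending them inline.
    sender = FakeResponseSender()
    response_queue.put_nowait((sender, "partial output", 0))
    response_queue.put_nowait((sender, "final output", RESPONSE_FINAL))

    # Shutdown: enqueue the sentinel, then join the worker, as finalize() does.
    response_queue.put(None)
    worker.join()

In the diff itself, gc.collect() is only invoked once ongoing_request_count drops to zero, which keeps the cost of a full collection off the per-response hot path.
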
@@ -25,8 +25,10 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import asyncio
import gc
import json
import os
import queue
import threading
from typing import Dict, List

@@ -113,13 +115,19 @@ def initialize(self, args):
# Counter to keep track of ongoing request counts
self.ongoing_request_count = 0

# Starting the response thread. It allows vLLM to keep making progress while
# response sender(s) are sending responses to server frontend.
self._response_queue = queue.Queue()
self._response_thread = threading.Thread(target=self.response_loop)
self._response_thread.start()

# Starting asyncio event loop to process the received requests asynchronously.
self._loop = asyncio.get_event_loop()
self._loop_thread = threading.Thread(
self._event_thread = threading.Thread(
target=self.engine_loop, args=(self._loop,)
)
self._shutdown_event = asyncio.Event()
self._loop_thread.start()
self._event_thread.start()

def init_engine(self):
# Currently, Triton needs to use decoupled policy for asynchronously
@@ -273,6 +281,27 @@ def get_sampling_params_dict(self, params_json):

return params_dict

def response_loop(self):
while True:
item = self._response_queue.get()
# To signal shutdown a None item will be added to the queue.
if item is None:
break
response_sender, response, response_flag = item
del item
try:
response_sender.send(response, response_flag)
except Exception as e:
self.logger.log_error(
f"An error occurred while sending a response: {e}"
)
finally:
if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
self.ongoing_request_count -= 1
del response_sender
if self.ongoing_request_count == 0:
gc.collect()

def create_response(self, vllm_output, prepend_input):
"""
Parses the output from the vLLM engine into Triton
@@ -314,6 +343,7 @@ async def generate(self, request):
"""
response_sender = request.get_response_sender()
self.ongoing_request_count += 1
decrement_ongoing_request_count = True
try:
request_id = random_uuid()
prompt = pb_utils.get_input_tensor_by_name(
@@ -368,9 +398,11 @@
lora_local_path = self.lora_repository[lora_name]
lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)

async for output in self.llm_engine.generate(
prompt, sampling_params, request_id, lora_request=lora_request
):
response_iterator = await self.llm_engine.add_request(
request_id, prompt, sampling_params, lora_request=lora_request
)

async for output in response_iterator:
if response_sender.is_cancelled():
self.logger.log_info("[vllm] Cancelling the request")
await self.llm_engine.abort(request_id)
@@ -383,15 +415,12 @@
len(prev_output.text)
for prev_output in prev_outputs.outputs
]
response = self.create_stream_response(output, prev_outputs_lengths)
flags = 0
if output.finished:
response_sender.send(
self.create_stream_response(output, prev_outputs_lengths),
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
)
else:
response_sender.send(
self.create_stream_response(output, prev_outputs_lengths)
)
flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
decrement_ongoing_request_count = False
self._response_queue.put_nowait((response_sender, response, flags))
prev_outputs = output

last_output = output
@@ -403,7 +432,7 @@
)

except Exception as e:
self.logger.log_info(f"[vllm] Error generating stream: {e}")
self.logger.log_error(f"[vllm] Error generating stream: {e}")
error = pb_utils.TritonError(f"Error generating stream: {e}")
triton_output_tensor = pb_utils.Tensor(
"text_output", np.asarray(["N/A"], dtype=self.output_dtype)
@@ -416,7 +445,11 @@
)
raise e
finally:
self.ongoing_request_count -= 1
if decrement_ongoing_request_count:
self.ongoing_request_count -= 1
del response_sender
if self.ongoing_request_count == 0:
gc.collect()

def verify_loras(self, request):
# We will check if the requested lora exists here, if not we will send a
@@ -483,6 +516,14 @@ def finalize(self):
"""
self.logger.log_info("[vllm] Issuing finalize to vllm backend")
self._shutdown_event.set()
if self._loop_thread is not None:
self._loop_thread.join()
self._loop_thread = None

# Shutdown the event thread.
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None

# Shutdown the response thread.
self._response_queue.put(None)
if self._response_thread is not None:
self._response_thread.join()
self._response_thread = None
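
For reference, the shutdown ordering in finalize() above is: signal the engine's event loop to shut down and join its thread first, then unblock the response thread with a None sentinel and join it. The sketch below shows the general shape of running an asyncio loop on a worker thread and stopping it this way (Python 3.10+); wait_for_shutdown is a simplified stand-in for the backend's engine loop, not its actual implementation.

import asyncio
import threading

async def wait_for_shutdown(shutdown_event):
    # Simplified stand-in for the engine loop: park until shutdown is requested.
    await shutdown_event.wait()

def engine_loop(loop, shutdown_event):
    # Runs on the worker thread and returns once the shutdown event is set.
    asyncio.set_event_loop(loop)
    loop.run_until_complete(wait_for_shutdown(shutdown_event))

if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    shutdown_event = asyncio.Event()  # on 3.10+ the event binds to the loop it is awaited in
    event_thread = threading.Thread(target=engine_loop, args=(loop, shutdown_event))
    event_thread.start()

    # ... request-handling coroutines would be scheduled onto `loop` here ...

    # Mirrors finalize(): request shutdown, join the worker thread, then clean up.
    loop.call_soon_threadsafe(shutdown_event.set)
    event_thread.join()
    loop.close()

call_soon_threadsafe is used in this sketch because the event is set from outside the loop's thread.
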
