Enhance checks around KIND_GPU and tensor parallelism
rmccorm4 committed May 24, 2024
1 parent 861a198 commit dc12c3b
Showing 1 changed file with 56 additions and 98 deletions.
154 changes: 56 additions & 98 deletions src/model.py
@@ -1,4 +1,4 @@
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -28,18 +28,16 @@
import json
import os
import threading
from typing import Dict, List
from typing import AsyncGenerator

Check notice (Code scanning / CodeQL): Unused import. Import of 'AsyncGenerator' is not used.

import numpy as np
import triton_python_backend_utils as pb_utils
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

_VLLM_ENGINE_ARGS_FILENAME = "model.json"
_MULTI_LORA_ARGS_FILENAME = "multi_lora.json"


class TritonPythonModel:
@@ -98,69 +96,81 @@ def auto_complete_config(auto_complete_model_config):
return auto_complete_model_config

def initialize(self, args):
self.args = args
self.logger = pb_utils.Logger
self.model_config = json.loads(args["model_config"])

# assert are in decoupled mode. Currently, Triton needs to use
# decoupled policy for asynchronously forwarding requests to
# vLLM engine.
# Prepare vLLM engine
self.init_engine()

output_config = pb_utils.get_output_config_by_name(
self.model_config, "text_output"
)
self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

# Counter to keep track of ongoing request counts
self.ongoing_request_count = 0

# Starting asyncio event loop to process the received requests asynchronously.
self._loop = asyncio.get_event_loop()
self._loop_thread = threading.Thread(
target=self.engine_loop, args=(self._loop,)
)
self._shutdown_event = asyncio.Event()
self._loop_thread.start()
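
The loop and thread above let Triton's synchronous execute() hand coroutines off to a long-lived asyncio event loop. A minimal, standalone sketch of that pattern (the general shape only, not the backend's exact engine_loop/create_task code):

```python
# Standalone sketch of the thread-plus-event-loop pattern: one daemon
# thread runs an asyncio loop forever, and other threads submit
# coroutines to it with run_coroutine_threadsafe.
import asyncio
import threading


def run_loop(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


loop = asyncio.new_event_loop()
threading.Thread(target=run_loop, args=(loop,), daemon=True).start()


async def work(x: int) -> int:
    await asyncio.sleep(0.1)
    return x * 2


# Submit a coroutine from a regular (non-async) thread and wait for it.
future = asyncio.run_coroutine_threadsafe(work(21), loop)
print(future.result())  # prints 42
```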

def init_engine(self):
# Currently, Triton needs to use decoupled policy for asynchronously
# forwarding requests to vLLM engine, so assert it.
self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
self.model_config
)
assert (
self.using_decoupled
), "vLLM Triton backend must be configured to use decoupled model transaction policy"

# Parse user-provided vLLM config
engine_args_filepath = os.path.join(
pb_utils.get_model_dir(), _VLLM_ENGINE_ARGS_FILENAME
)
assert os.path.isfile(
engine_args_filepath
), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'"

with open(engine_args_filepath) as file:
vllm_engine_config = json.load(file)

# Validate that the device and multi-processing settings are compatible,
# based on the model instance config and the vLLM engine config.
kind = self.args["model_instance_kind"]
device_id = self.args["model_instance_device_id"]
self.validate_device_config(kind, device_id, vllm_engine_config)

# Create an AsyncLLMEngine from the config from JSON
self.llm_engine = AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(**vllm_engine_config)
)
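
The keys in model.json map onto vLLM's AsyncEngineArgs fields. A minimal sketch of that wiring, with made-up values (supported fields depend on the installed vLLM version):

```python
# Sketch: build the async vLLM engine from a JSON config, mirroring the
# flow above. The model name and settings are illustrative only.
import json

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

example_model_json = """
{
    "model": "facebook/opt-125m",
    "gpu_memory_utilization": 0.5,
    "tensor_parallel_size": 1,
    "disable_log_requests": true
}
"""

engine_config = json.loads(example_model_json)
llm_engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_config))
```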
self.enable_lora = False

if (
"enable_lora" in vllm_engine_config.keys()
and vllm_engine_config["enable_lora"].lower() == "true"
):
# create Triton LoRA weights repository
multi_lora_args_filepath = os.path.join(
pb_utils.get_model_dir(), _MULTI_LORA_ARGS_FILENAME
)
try:
with open(multi_lora_args_filepath) as lora_file:
lora_repository: Dict[str, str] = json.load(lora_file)
self.lora_repository = lora_repository
self.supported_loras: List[str] = list(self.lora_repository.keys())
self.supported_loras_len = len(self.supported_loras)
self.enable_lora = True
except FileNotFoundError:
raise FileNotFoundError(
f"Triton backend cannot find {multi_lora_args_filepath}."
)

output_config = pb_utils.get_output_config_by_name(
self.model_config, "text_output"
)
self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

# Counter to keep track of ongoing request counts
self.ongoing_request_count = 0
def validate_device_config(self, triton_kind, triton_device_id, vllm_engine_config):
# Triton's current definition of KIND_GPU assumes that a model only
# uses a single GPU. For multi-GPU models, the recommendation
# is to specify KIND_MODEL to acknowledge that the model will take control
# of the devices made available to it.
# NOTE: Consider other parameters that would indicate multi-GPU in the future.
tp_size = int(vllm_engine_config.get("tensor_parallel_size", 1))
if tp_size > 1 and triton_kind == "GPU":
raise ValueError(
"KIND_GPU is for single-GPU models, please specify KIND_MODEL in the model's config.pbtxt for multi-GPU models"
)
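
Taken on its own, the rule is: tensor_parallel_size greater than 1 combined with KIND_GPU is rejected, while KIND_MODEL leaves multi-GPU placement to vLLM. A small self-contained sketch of that behavior (check_kind is a made-up helper, not the backend's API):

```python
# Self-contained sketch of the check above; check_kind is illustrative,
# not part of the backend.
def check_kind(kind: str, engine_config: dict) -> None:
    tp_size = int(engine_config.get("tensor_parallel_size", 1))
    if tp_size > 1 and kind == "GPU":
        raise ValueError(
            "KIND_GPU is for single-GPU models; use KIND_MODEL for multi-GPU models"
        )


check_kind("MODEL", {"tensor_parallel_size": 2})  # accepted: vLLM manages the GPUs
check_kind("GPU", {"tensor_parallel_size": 1})    # accepted: single GPU
check_kind("GPU", {"tensor_parallel_size": 2})    # raises ValueError
```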

# Starting asyncio event loop to process the received requests asynchronously.
self._loop = asyncio.get_event_loop()
self._loop_thread = threading.Thread(
target=self.engine_loop, args=(self._loop,)
)
self._shutdown_event = asyncio.Event()
self._loop_thread.start()
# If KIND_GPU is specified, isolate the selected GPU device so that multiple
# model instances do not all land on the same default device (usually device 0).
if triton_kind == "GPU" and int(triton_device_id) >= 0:
self.logger.log_info(
f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id}"
)
# NOTE: this only affects this process and its subprocesses, not other processes.
# vLLM doesn't currently seem to expose selecting a specific device in its APIs.
os.environ["CUDA_VISIBLE_DEVICES"] = triton_device_id
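
Masking via CUDA_VISIBLE_DEVICES works because CUDA only enumerates the listed devices for this process and its children, renumbering them from zero. A quick way to observe the effect, assuming PyTorch is installed purely for the illustration (the device id is made up):

```python
# Sketch: after masking, only physical GPU 1 is visible to this process,
# and it shows up as device 0. Set the variable before any CUDA
# initialization happens in the process.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch  # noqa: E402  (imported after setting the mask on purpose)

print(torch.cuda.device_count())      # 1 on a multi-GPU host
print(torch.cuda.get_device_name(0))  # physical GPU 1, now exposed as cuda:0
```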

def create_task(self, coro):
"""
@@ -319,19 +329,12 @@ async def generate(self, request):
parameters = request.parameters()

sampling_params_dict = self.get_sampling_params_dict(parameters)
lora_name = sampling_params_dict.pop("lora_name", None)
sampling_params = SamplingParams(**sampling_params_dict)

last_output = None
prev_outputs = None
lora_request = None
if lora_name is not None:
lora_id = str(self.supported_loras.index(lora_name) + 1)
lora_int_id = int(lora_id)
lora_local_path = self.lora_repository[lora_name]
lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)

async for output in self.llm_engine.generate(
prompt, sampling_params, request_id, lora_request=lora_request
prompt, sampling_params, request_id
):
if response_sender.is_cancelled():
self.logger.log_info("[vllm] Cancelling the request")
@@ -380,49 +383,6 @@ async def generate(self, request):
finally:
self.ongoing_request_count -= 1
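
For reference, the sampling_parameters input is a JSON string whose fields are passed through to vLLM's SamplingParams. A minimal sketch with illustrative values (supported fields depend on the vLLM version in use):

```python
# Sketch: turn a JSON "sampling_parameters" string into vLLM SamplingParams,
# as generate() does above. The values below are illustrative.
import json

from vllm import SamplingParams

parameters = '{"temperature": 0.7, "top_p": 0.9, "max_tokens": 64}'
sampling_params = SamplingParams(**json.loads(parameters))
print(sampling_params)
```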

def verify_loras(self, request):
# We will check if the requested lora exists here, if not we will send a
# response with `LoRA not found` information. In this way we may avoid
# further processing.
verified_request = None
lora_error = None
lora_name = None
parameters_input_tensor = pb_utils.get_input_tensor_by_name(
request, "sampling_parameters"
)
if parameters_input_tensor:
parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
sampling_params_dict = self.get_sampling_params_dict(parameters)
lora_name = sampling_params_dict.pop("lora_name", None)

if lora_name is not None:
if not self.enable_lora:
lora_error = pb_utils.TritonError("LoRA feature is not enabled.")
self.logger.log_info(
"[vllm] LoRA is not enabled, please restart the backend with LoRA enabled."
)
elif lora_name not in self.supported_loras:
lora_error = pb_utils.TritonError(
f"LoRA {lora_name} is not supported, we currently support {self.supported_loras}"
)
self.logger.log_info(f"[vllm] LoRA {lora_name} not found.")

if lora_error is not None:
output_tensor = pb_utils.Tensor(
"text_output",
np.asarray(["[Error] Unsupported LoRA."], dtype=self.output_dtype),
)
response = pb_utils.InferenceResponse(
output_tensors=[output_tensor], error=lora_error
)
response_sender = request.get_response_sender()
response_sender.send(
response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
else:
verified_request = request
return verified_request

def execute(self, requests):
"""
Triton core issues requests to the backend via this method.
@@ -434,9 +394,7 @@ def execute(self, requests):
We push all the requests to vLLM and let it handle the full traffic.
"""
for request in requests:
request = self.verify_loras(request)
if request is not None:
self.create_task(self.generate(request))
self.create_task(self.generate(request))
return None

def finalize(self):
