diff --git a/src/model.py b/src/model.py
index 2d9d8ff8..d0e1c6ad 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -28,18 +28,16 @@
 import json
 import os
 import threading
-from typing import Dict, List
+from typing import AsyncGenerator
 
 import numpy as np
 import triton_python_backend_utils as pb_utils
+from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.lora.request import LoRARequest
-from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
 _VLLM_ENGINE_ARGS_FILENAME = "model.json"
-_MULTI_LORA_ARGS_FILENAME = "multi_lora.json"
 
 
 class TritonPythonModel:
@@ -98,12 +96,32 @@ def auto_complete_config(auto_complete_model_config):
         return auto_complete_model_config
 
     def initialize(self, args):
+        self.args = args
         self.logger = pb_utils.Logger
         self.model_config = json.loads(args["model_config"])
-        # assert are in decoupled mode. Currently, Triton needs to use
-        # decoupled policy for asynchronously forwarding requests to
-        # vLLM engine.
+        # Prepare vLLM engine
+        self.init_engine()
+
+        output_config = pb_utils.get_output_config_by_name(
+            self.model_config, "text_output"
+        )
+        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
+
+        # Counter to keep track of ongoing request counts
+        self.ongoing_request_count = 0
+
+        # Starting asyncio event loop to process the received requests asynchronously.
+        self._loop = asyncio.get_event_loop()
+        self._loop_thread = threading.Thread(
+            target=self.engine_loop, args=(self._loop,)
+        )
+        self._shutdown_event = asyncio.Event()
+        self._loop_thread.start()
+
+    def init_engine(self):
+        # Currently, Triton needs to use decoupled policy for asynchronously
+        # forwarding requests to vLLM engine, so assert it.
         self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
             self.model_config
         )
@@ -111,56 +129,48 @@ def initialize(self, args):
             self.using_decoupled
         ), "vLLM Triton backend must be configured to use decoupled model transaction policy"
 
+        # Parse user-provided vLLM config
         engine_args_filepath = os.path.join(
             pb_utils.get_model_dir(), _VLLM_ENGINE_ARGS_FILENAME
        )
        assert os.path.isfile(
            engine_args_filepath
        ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'"
+
        with open(engine_args_filepath) as file:
            vllm_engine_config = json.load(file)
+
+        # Validate device and multi-processing settings are currently set based on model/configs.
+        kind = self.args["model_instance_kind"]
+        device_id = self.args["model_instance_device_id"]
+        self.validate_device_config(kind, device_id, vllm_engine_config)
+
         # Create an AsyncLLMEngine from the config from JSON
         self.llm_engine = AsyncLLMEngine.from_engine_args(
             AsyncEngineArgs(**vllm_engine_config)
         )
-        self.enable_lora = False
-
-        if (
-            "enable_lora" in vllm_engine_config.keys()
-            and vllm_engine_config["enable_lora"].lower() == "true"
-        ):
-            # create Triton LoRA weights repository
-            multi_lora_args_filepath = os.path.join(
-                pb_utils.get_model_dir(), _MULTI_LORA_ARGS_FILENAME
-            )
-            try:
-                with open(multi_lora_args_filepath) as lora_file:
-                    lora_repository: Dict[str, str] = json.load(lora_file)
-                self.lora_repository = lora_repository
-                self.supported_loras: List[str] = list(self.lora_repository.keys())
-                self.supported_loras_len = len(self.supported_loras)
-                self.enable_lora = True
-            except FileNotFoundError:
-                raise FileNotFoundError(
-                    f"Triton backend cannot find {multi_lora_args_filepath}."
-                )
-        output_config = pb_utils.get_output_config_by_name(
-            self.model_config, "text_output"
-        )
-        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
-
-        # Counter to keep track of ongoing request counts
-        self.ongoing_request_count = 0
 
+    def validate_device_config(self, triton_kind, triton_device_id, vllm_engine_config):
+        # Triton's current definition of KIND_GPU makes assumptions that
+        # models only use a single GPU. For multi-GPU models, the recommendation
+        # is to specify KIND_MODEL to acknowledge that the model will take control
+        # of the devices made available to it.
+        # NOTE: Consider other parameters that would indicate multi-GPU in the future.
+        tp_size = int(vllm_engine_config.get("tensor_parallel_size", 1))
+        if tp_size > 1 and triton_kind == "GPU":
+            raise ValueError(
+                "KIND_GPU is for single-GPU models, please specify KIND_MODEL in the model's config.pbtxt for multi-GPU models"
+            )
 
-        # Starting asyncio event loop to process the received requests asynchronously.
-        self._loop = asyncio.get_event_loop()
-        self._loop_thread = threading.Thread(
-            target=self.engine_loop, args=(self._loop,)
-        )
-        self._shutdown_event = asyncio.Event()
-        self._loop_thread.start()
+        # If KIND_GPU is specified, isolate the selected GPU device so that multiple
+        # model instances do not all end up on the same default device (usually device 0).
+        if triton_kind == "GPU" and int(triton_device_id) >= 0:
+            self.logger.log_info(
+                f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id}"
+            )
+            # NOTE: this only affects this process and its subprocesses, not other processes.
+            # vLLM doesn't currently seem to expose selecting a specific device in the APIs.
+            os.environ["CUDA_VISIBLE_DEVICES"] = triton_device_id
 
     def create_task(self, coro):
         """
@@ -319,19 +329,12 @@ async def generate(self, request):
             parameters = request.parameters()
 
             sampling_params_dict = self.get_sampling_params_dict(parameters)
-            lora_name = sampling_params_dict.pop("lora_name", None)
             sampling_params = SamplingParams(**sampling_params_dict)
+            last_output = None
             prev_outputs = None
-            lora_request = None
-            if lora_name is not None:
-                lora_id = str(self.supported_loras.index(lora_name) + 1)
-                lora_int_id = int(lora_id)
-                lora_local_path = self.lora_repository[lora_name]
-                lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)
-
             async for output in self.llm_engine.generate(
-                prompt, sampling_params, request_id, lora_request=lora_request
+                prompt, sampling_params, request_id
             ):
                 if response_sender.is_cancelled():
                     self.logger.log_info("[vllm] Cancelling the request")
@@ -380,49 +383,6 @@
         finally:
             self.ongoing_request_count -= 1
 
-    def verify_loras(self, request):
-        # We will check if the requested lora exists here, if not we will send a
-        # response with `LoRA not found` information. In this way we may avoid
-        # further processing.
-        verified_request = None
-        lora_error = None
-        lora_name = None
-        parameters_input_tensor = pb_utils.get_input_tensor_by_name(
-            request, "sampling_parameters"
-        )
-        if parameters_input_tensor:
-            parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
-            sampling_params_dict = self.get_sampling_params_dict(parameters)
-            lora_name = sampling_params_dict.pop("lora_name", None)
-
-        if lora_name is not None:
-            if not self.enable_lora:
-                lora_error = pb_utils.TritonError("LoRA feature is not enabled.")
-                self.logger.log_info(
-                    "[vllm] LoRA is not enabled, please restart the backend with LoRA enabled."
-                )
-            elif lora_name not in self.supported_loras:
-                lora_error = pb_utils.TritonError(
-                    f"LoRA {lora_name} is not supported, we currently support {self.supported_loras}"
-                )
-                self.logger.log_info(f"[vllm] LoRA {lora_name} not found.")
-
-        if lora_error is not None:
-            output_tensor = pb_utils.Tensor(
-                "text_output",
-                np.asarray(["[Error] Unsupported LoRA."], dtype=self.output_dtype),
-            )
-            response = pb_utils.InferenceResponse(
-                output_tensors=[output_tensor], error=lora_error
-            )
-            response_sender = request.get_response_sender()
-            response_sender.send(
-                response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
-            )
-        else:
-            verified_request = request
-        return verified_request
-
     def execute(self, requests):
         """
         Triton core issues requests to the backend via this method.
@@ -434,9 +394,7 @@ def execute(self, requests):
         We are pushing all the requests on vllm and let it handle the full traffic.
         """
         for request in requests:
-            request = self.verify_loras(request)
-            if request is not None:
-                self.create_task(self.generate(request))
+            self.create_task(self.generate(request))
         return None
 
     def finalize(self):
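
Note (not part of the patch): init_engine() builds the engine by unpacking the user-provided
model.json straight into vLLM's AsyncEngineArgs, and validate_device_config() keys off the
"tensor_parallel_size" field in that same file. A minimal sketch of that flow, using a
placeholder model name and illustrative field values:

# Sketch only: mirrors what init_engine() does with model.json; values are placeholders.
import json

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Example contents of model.json, placed next to model.py in the model repository.
vllm_engine_config = json.loads(
    """
    {
        "model": "facebook/opt-125m",
        "disable_log_requests": true,
        "gpu_memory_utilization": 0.5,
        "tensor_parallel_size": 1
    }
    """
)

# With tensor_parallel_size > 1, validate_device_config() rejects KIND_GPU instances
# and expects the instance_group kind KIND_MODEL in config.pbtxt instead.
engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**vllm_engine_config))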