From dc12c3b61ef1cb7da84c59992b26457df474617c Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Fri, 24 May 2024 11:54:55 -0700 Subject: [PATCH 1/9] Enhance checks around KIND_GPU and tensor parallelism --- src/model.py | 154 +++++++++++++++++++-------------------------------- 1 file changed, 56 insertions(+), 98 deletions(-) diff --git a/src/model.py b/src/model.py index 2d9d8ff8..d0e1c6ad 100644 --- a/src/model.py +++ b/src/model.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -28,18 +28,16 @@ import json import os import threading -from typing import Dict, List +from typing import AsyncGenerator import numpy as np import triton_python_backend_utils as pb_utils +from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid _VLLM_ENGINE_ARGS_FILENAME = "model.json" -_MULTI_LORA_ARGS_FILENAME = "multi_lora.json" class TritonPythonModel: @@ -98,12 +96,32 @@ def auto_complete_config(auto_complete_model_config): return auto_complete_model_config def initialize(self, args): + self.args = args self.logger = pb_utils.Logger self.model_config = json.loads(args["model_config"]) - # assert are in decoupled mode. Currently, Triton needs to use - # decoupled policy for asynchronously forwarding requests to - # vLLM engine. + # Prepare vLLM engine + self.init_engine() + + output_config = pb_utils.get_output_config_by_name( + self.model_config, "text_output" + ) + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + # Counter to keep track of ongoing request counts + self.ongoing_request_count = 0 + + # Starting asyncio event loop to process the received requests asynchronously. + self._loop = asyncio.get_event_loop() + self._loop_thread = threading.Thread( + target=self.engine_loop, args=(self._loop,) + ) + self._shutdown_event = asyncio.Event() + self._loop_thread.start() + + def init_engine(self): + # Currently, Triton needs to use decoupled policy for asynchronously + # forwarding requests to vLLM engine, so assert it. self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy( self.model_config ) @@ -111,56 +129,48 @@ def initialize(self, args): self.using_decoupled ), "vLLM Triton backend must be configured to use decoupled model transaction policy" + # Parse user-provided vLLM config engine_args_filepath = os.path.join( pb_utils.get_model_dir(), _VLLM_ENGINE_ARGS_FILENAME ) assert os.path.isfile( engine_args_filepath ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'" + with open(engine_args_filepath) as file: vllm_engine_config = json.load(file) + # Validate device and multi-processing settings are currently set based on model/configs. 
+ kind = self.args["model_instance_kind"] + device_id = self.args["model_instance_device_id"] + self.validate_device_config(kind, device_id, vllm_engine_config) + # Create an AsyncLLMEngine from the config from JSON self.llm_engine = AsyncLLMEngine.from_engine_args( AsyncEngineArgs(**vllm_engine_config) ) - self.enable_lora = False - - if ( - "enable_lora" in vllm_engine_config.keys() - and vllm_engine_config["enable_lora"].lower() == "true" - ): - # create Triton LoRA weights repository - multi_lora_args_filepath = os.path.join( - pb_utils.get_model_dir(), _MULTI_LORA_ARGS_FILENAME - ) - try: - with open(multi_lora_args_filepath) as lora_file: - lora_repository: Dict[str, str] = json.load(lora_file) - self.lora_repository = lora_repository - self.supported_loras: List[str] = list(self.lora_repository.keys()) - self.supported_loras_len = len(self.supported_loras) - self.enable_lora = True - except FileNotFoundError: - raise FileNotFoundError( - f"Triton backend cannot find {multi_lora_args_filepath}." - ) - output_config = pb_utils.get_output_config_by_name( - self.model_config, "text_output" - ) - self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) - - # Counter to keep track of ongoing request counts - self.ongoing_request_count = 0 + def validate_device_config(self, triton_kind, triton_device_id, vllm_engine_config): + # Triton's current definition of KIND_GPU makes assumptions that + # models only use a single GPU. For multi-GPU models, the recommendation + # is to specify KIND_MODEL to acknowledge that the model will take control + # of the devices made available to it. + # NOTE: Consider other parameters that would indicate multi-GPU in the future. + tp_size = int(vllm_engine_config.get("tensor_parallel_size", 1)) + if tp_size > 1 and triton_kind == "GPU": + raise ValueError( + "KIND_GPU is for single-GPU models, please specify KIND_MODEL in the model's config.pbtxt for multi-GPU models" + ) - # Starting asyncio event loop to process the received requests asynchronously. - self._loop = asyncio.get_event_loop() - self._loop_thread = threading.Thread( - target=self.engine_loop, args=(self._loop,) - ) - self._shutdown_event = asyncio.Event() - self._loop_thread.start() + # If KIND_GPU is specified, isolate the selected GPU device to ensure that multiple model instances do not all use the same default + # device (usually device 0) when KIND_GPU is used. + if triton_kind == "GPU" and int(triton_device_id) >= 0: + self.logger.log_info( + f"Detected KIND_GPU model instance, explicitly setting GPU device={device_id}" + ) + # NOTE: this only affects this process and it's subprocesses, not other processes. + # vLLM doesn't currently seem to expose selecting a specific device in the APIs. 
+ os.environ["CUDA_VISIBLE_DEVICES"] = triton_device_id def create_task(self, coro): """ @@ -319,19 +329,12 @@ async def generate(self, request): parameters = request.parameters() sampling_params_dict = self.get_sampling_params_dict(parameters) - lora_name = sampling_params_dict.pop("lora_name", None) sampling_params = SamplingParams(**sampling_params_dict) + last_output = None prev_outputs = None - lora_request = None - if lora_name is not None: - lora_id = str(self.supported_loras.index(lora_name) + 1) - lora_int_id = int(lora_id) - lora_local_path = self.lora_repository[lora_name] - lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path) - async for output in self.llm_engine.generate( - prompt, sampling_params, request_id, lora_request=lora_request + prompt, sampling_params, request_id ): if response_sender.is_cancelled(): self.logger.log_info("[vllm] Cancelling the request") @@ -380,49 +383,6 @@ async def generate(self, request): finally: self.ongoing_request_count -= 1 - def verify_loras(self, request): - # We will check if the requested lora exists here, if not we will send a - # response with `LoRA not found` information. In this way we may avoid - # further processing. - verified_request = None - lora_error = None - lora_name = None - parameters_input_tensor = pb_utils.get_input_tensor_by_name( - request, "sampling_parameters" - ) - if parameters_input_tensor: - parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") - sampling_params_dict = self.get_sampling_params_dict(parameters) - lora_name = sampling_params_dict.pop("lora_name", None) - - if lora_name is not None: - if not self.enable_lora: - lora_error = pb_utils.TritonError("LoRA feature is not enabled.") - self.logger.log_info( - "[vllm] LoRA is not enabled, please restart the backend with LoRA enabled." - ) - elif lora_name not in self.supported_loras: - lora_error = pb_utils.TritonError( - f"LoRA {lora_name} is not supported, we currently support {self.supported_loras}" - ) - self.logger.log_info(f"[vllm] LoRA {lora_name} not found.") - - if lora_error is not None: - output_tensor = pb_utils.Tensor( - "text_output", - np.asarray(["[Error] Unsupported LoRA."], dtype=self.output_dtype), - ) - response = pb_utils.InferenceResponse( - output_tensors=[output_tensor], error=lora_error - ) - response_sender = request.get_response_sender() - response_sender.send( - response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL - ) - else: - verified_request = request - return verified_request - def execute(self, requests): """ Triton core issues requests to the backend via this method. @@ -434,9 +394,7 @@ def execute(self, requests): We are pushing all the requests on vllm and let it handle the full traffic. """ for request in requests: - request = self.verify_loras(request) - if request is not None: - self.create_task(self.generate(request)) + self.create_task(self.generate(request)) return None def finalize(self): From a32de53e9903d640944e718638f637fa6b24f4bf Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Fri, 24 May 2024 12:35:08 -0700 Subject: [PATCH 2/9] Log instance name --- src/model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/model.py b/src/model.py index d0e1c6ad..3da480ae 100644 --- a/src/model.py +++ b/src/model.py @@ -141,16 +141,18 @@ def init_engine(self): vllm_engine_config = json.load(file) # Validate device and multi-processing settings are currently set based on model/configs. 
- kind = self.args["model_instance_kind"] - device_id = self.args["model_instance_device_id"] - self.validate_device_config(kind, device_id, vllm_engine_config) + self.validate_device_config(vllm_engine_config) # Create an AsyncLLMEngine from the config from JSON self.llm_engine = AsyncLLMEngine.from_engine_args( AsyncEngineArgs(**vllm_engine_config) ) - def validate_device_config(self, triton_kind, triton_device_id, vllm_engine_config): + def validate_device_config(self, vllm_engine_config): + triton_kind = self.args["model_instance_kind"] + triton_device_id = self.args["model_instance_device_id"] + triton_instance = f"{self.args['model_name']}_{triton_device_id}" + # Triton's current definition of KIND_GPU makes assumptions that # models only use a single GPU. For multi-GPU models, the recommendation # is to specify KIND_MODEL to acknowledge that the model will take control @@ -166,7 +168,7 @@ def validate_device_config(self, triton_kind, triton_device_id, vllm_engine_conf # device (usually device 0) when KIND_GPU is used. if triton_kind == "GPU" and int(triton_device_id) >= 0: self.logger.log_info( - f"Detected KIND_GPU model instance, explicitly setting GPU device={device_id}" + f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}" ) # NOTE: this only affects this process and it's subprocesses, not other processes. # vLLM doesn't currently seem to expose selecting a specific device in the APIs. From cdd7e77b96adc5c193ebe9ebcace7cac8e53b549 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Fri, 24 May 2024 12:41:00 -0700 Subject: [PATCH 3/9] Sync back with main --- src/model.py | 156 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 98 insertions(+), 58 deletions(-) diff --git a/src/model.py b/src/model.py index 3da480ae..2d9d8ff8 100644 --- a/src/model.py +++ b/src/model.py @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -28,16 +28,18 @@ import json import os import threading -from typing import AsyncGenerator +from typing import Dict, List import numpy as np import triton_python_backend_utils as pb_utils -from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid _VLLM_ENGINE_ARGS_FILENAME = "model.json" +_MULTI_LORA_ARGS_FILENAME = "multi_lora.json" class TritonPythonModel: @@ -96,32 +98,12 @@ def auto_complete_config(auto_complete_model_config): return auto_complete_model_config def initialize(self, args): - self.args = args self.logger = pb_utils.Logger self.model_config = json.loads(args["model_config"]) - # Prepare vLLM engine - self.init_engine() - - output_config = pb_utils.get_output_config_by_name( - self.model_config, "text_output" - ) - self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) - - # Counter to keep track of ongoing request counts - self.ongoing_request_count = 0 - - # Starting asyncio event loop to process the received requests asynchronously. 
- self._loop = asyncio.get_event_loop() - self._loop_thread = threading.Thread( - target=self.engine_loop, args=(self._loop,) - ) - self._shutdown_event = asyncio.Event() - self._loop_thread.start() - - def init_engine(self): - # Currently, Triton needs to use decoupled policy for asynchronously - # forwarding requests to vLLM engine, so assert it. + # assert are in decoupled mode. Currently, Triton needs to use + # decoupled policy for asynchronously forwarding requests to + # vLLM engine. self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy( self.model_config ) @@ -129,50 +111,56 @@ def init_engine(self): self.using_decoupled ), "vLLM Triton backend must be configured to use decoupled model transaction policy" - # Parse user-provided vLLM config engine_args_filepath = os.path.join( pb_utils.get_model_dir(), _VLLM_ENGINE_ARGS_FILENAME ) assert os.path.isfile( engine_args_filepath ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'" - with open(engine_args_filepath) as file: vllm_engine_config = json.load(file) - # Validate device and multi-processing settings are currently set based on model/configs. - self.validate_device_config(vllm_engine_config) - # Create an AsyncLLMEngine from the config from JSON self.llm_engine = AsyncLLMEngine.from_engine_args( AsyncEngineArgs(**vllm_engine_config) ) - - def validate_device_config(self, vllm_engine_config): - triton_kind = self.args["model_instance_kind"] - triton_device_id = self.args["model_instance_device_id"] - triton_instance = f"{self.args['model_name']}_{triton_device_id}" - - # Triton's current definition of KIND_GPU makes assumptions that - # models only use a single GPU. For multi-GPU models, the recommendation - # is to specify KIND_MODEL to acknowledge that the model will take control - # of the devices made available to it. - # NOTE: Consider other parameters that would indicate multi-GPU in the future. - tp_size = int(vllm_engine_config.get("tensor_parallel_size", 1)) - if tp_size > 1 and triton_kind == "GPU": - raise ValueError( - "KIND_GPU is for single-GPU models, please specify KIND_MODEL in the model's config.pbtxt for multi-GPU models" + self.enable_lora = False + + if ( + "enable_lora" in vllm_engine_config.keys() + and vllm_engine_config["enable_lora"].lower() == "true" + ): + # create Triton LoRA weights repository + multi_lora_args_filepath = os.path.join( + pb_utils.get_model_dir(), _MULTI_LORA_ARGS_FILENAME ) + try: + with open(multi_lora_args_filepath) as lora_file: + lora_repository: Dict[str, str] = json.load(lora_file) + self.lora_repository = lora_repository + self.supported_loras: List[str] = list(self.lora_repository.keys()) + self.supported_loras_len = len(self.supported_loras) + self.enable_lora = True + except FileNotFoundError: + raise FileNotFoundError( + f"Triton backend cannot find {multi_lora_args_filepath}." + ) - # If KIND_GPU is specified, isolate the selected GPU device to ensure that multiple model instances do not all use the same default - # device (usually device 0) when KIND_GPU is used. - if triton_kind == "GPU" and int(triton_device_id) >= 0: - self.logger.log_info( - f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}" - ) - # NOTE: this only affects this process and it's subprocesses, not other processes. - # vLLM doesn't currently seem to expose selecting a specific device in the APIs. 
- os.environ["CUDA_VISIBLE_DEVICES"] = triton_device_id + output_config = pb_utils.get_output_config_by_name( + self.model_config, "text_output" + ) + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + # Counter to keep track of ongoing request counts + self.ongoing_request_count = 0 + + # Starting asyncio event loop to process the received requests asynchronously. + self._loop = asyncio.get_event_loop() + self._loop_thread = threading.Thread( + target=self.engine_loop, args=(self._loop,) + ) + self._shutdown_event = asyncio.Event() + self._loop_thread.start() def create_task(self, coro): """ @@ -331,12 +319,19 @@ async def generate(self, request): parameters = request.parameters() sampling_params_dict = self.get_sampling_params_dict(parameters) + lora_name = sampling_params_dict.pop("lora_name", None) sampling_params = SamplingParams(**sampling_params_dict) - last_output = None prev_outputs = None + lora_request = None + if lora_name is not None: + lora_id = str(self.supported_loras.index(lora_name) + 1) + lora_int_id = int(lora_id) + lora_local_path = self.lora_repository[lora_name] + lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path) + async for output in self.llm_engine.generate( - prompt, sampling_params, request_id + prompt, sampling_params, request_id, lora_request=lora_request ): if response_sender.is_cancelled(): self.logger.log_info("[vllm] Cancelling the request") @@ -385,6 +380,49 @@ async def generate(self, request): finally: self.ongoing_request_count -= 1 + def verify_loras(self, request): + # We will check if the requested lora exists here, if not we will send a + # response with `LoRA not found` information. In this way we may avoid + # further processing. + verified_request = None + lora_error = None + lora_name = None + parameters_input_tensor = pb_utils.get_input_tensor_by_name( + request, "sampling_parameters" + ) + if parameters_input_tensor: + parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") + sampling_params_dict = self.get_sampling_params_dict(parameters) + lora_name = sampling_params_dict.pop("lora_name", None) + + if lora_name is not None: + if not self.enable_lora: + lora_error = pb_utils.TritonError("LoRA feature is not enabled.") + self.logger.log_info( + "[vllm] LoRA is not enabled, please restart the backend with LoRA enabled." + ) + elif lora_name not in self.supported_loras: + lora_error = pb_utils.TritonError( + f"LoRA {lora_name} is not supported, we currently support {self.supported_loras}" + ) + self.logger.log_info(f"[vllm] LoRA {lora_name} not found.") + + if lora_error is not None: + output_tensor = pb_utils.Tensor( + "text_output", + np.asarray(["[Error] Unsupported LoRA."], dtype=self.output_dtype), + ) + response = pb_utils.InferenceResponse( + output_tensors=[output_tensor], error=lora_error + ) + response_sender = request.get_response_sender() + response_sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + else: + verified_request = request + return verified_request + def execute(self, requests): """ Triton core issues requests to the backend via this method. @@ -396,7 +434,9 @@ def execute(self, requests): We are pushing all the requests on vllm and let it handle the full traffic. 
""" for request in requests: - self.create_task(self.generate(request)) + request = self.verify_loras(request) + if request is not None: + self.create_task(self.generate(request)) return None def finalize(self): From a4e6162fbcae7f671ff5c1bd95487702942a98f8 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Fri, 24 May 2024 12:47:34 -0700 Subject: [PATCH 4/9] Bring back LoRA changes, with slight refactor --- src/model.py | 77 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 18 deletions(-) diff --git a/src/model.py b/src/model.py index 2d9d8ff8..3d0e01d7 100644 --- a/src/model.py +++ b/src/model.py @@ -98,8 +98,31 @@ def auto_complete_config(auto_complete_model_config): return auto_complete_model_config def initialize(self, args): + self.args = args self.logger = pb_utils.Logger self.model_config = json.loads(args["model_config"]) + output_config = pb_utils.get_output_config_by_name( + self.model_config, "text_output" + ) + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + # Prepare vLLM engine + self.init_engine() + + # Counter to keep track of ongoing request counts + self.ongoing_request_count = 0 + + # Starting asyncio event loop to process the received requests asynchronously. + self._loop = asyncio.get_event_loop() + self._loop_thread = threading.Thread( + target=self.engine_loop, args=(self._loop,) + ) + self._shutdown_event = asyncio.Event() + self._loop_thread.start() + + def init_engine(self): + # Currently, Triton needs to use decoupled policy for asynchronously + # forwarding requests to vLLM engine, so assert it. # assert are in decoupled mode. Currently, Triton needs to use # decoupled policy for asynchronously forwarding requests to @@ -118,17 +141,25 @@ def initialize(self, args): engine_args_filepath ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'" with open(engine_args_filepath) as file: - vllm_engine_config = json.load(file) + self.vllm_engine_config = json.load(file) + + # Validate device and multi-processing settings are currently set based on model/configs. + self.validate_device_config() + + # Check for LoRA config and set it up if enabled + self.setup_lora() # Create an AsyncLLMEngine from the config from JSON self.llm_engine = AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(**vllm_engine_config) + AsyncEngineArgs(**self.vllm_engine_config) ) + + def setup_lora(self): self.enable_lora = False if ( - "enable_lora" in vllm_engine_config.keys() - and vllm_engine_config["enable_lora"].lower() == "true" + "enable_lora" in self.vllm_engine_config.keys() + and self.vllm_engine_config["enable_lora"].lower() == "true" ): # create Triton LoRA weights repository multi_lora_args_filepath = os.path.join( @@ -146,21 +177,31 @@ def initialize(self, args): f"Triton backend cannot find {multi_lora_args_filepath}." ) - output_config = pb_utils.get_output_config_by_name( - self.model_config, "text_output" - ) - self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) - - # Counter to keep track of ongoing request counts - self.ongoing_request_count = 0 + def validate_device_config(self): + triton_kind = self.args["model_instance_kind"] + triton_device_id = self.args["model_instance_device_id"] + triton_instance = f"{self.args['model_name']}_{triton_device_id}" + + # Triton's current definition of KIND_GPU makes assumptions that + # models only use a single GPU. 
For multi-GPU models, the recommendation + # is to specify KIND_MODEL to acknowledge that the model will take control + # of the devices made available to it. + # NOTE: Consider other parameters that would indicate multi-GPU in the future. + tp_size = int(self.vllm_engine_config.get("tensor_parallel_size", 1)) + if tp_size > 1 and triton_kind == "GPU": + raise ValueError( + "KIND_GPU is for single-GPU models, please specify KIND_MODEL in the model's config.pbtxt for multi-GPU models" + ) - # Starting asyncio event loop to process the received requests asynchronously. - self._loop = asyncio.get_event_loop() - self._loop_thread = threading.Thread( - target=self.engine_loop, args=(self._loop,) - ) - self._shutdown_event = asyncio.Event() - self._loop_thread.start() + # If KIND_GPU is specified, isolate the selected GPU device to ensure that multiple model instances do not all use the same default + # device (usually device 0) when KIND_GPU is used. + if triton_kind == "GPU" and int(triton_device_id) >= 0: + self.logger.log_info( + f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}" + ) + # NOTE: this only affects this process and it's subprocesses, not other processes. + # vLLM doesn't currently seem to expose selecting a specific device in the APIs. + os.environ["CUDA_VISIBLE_DEVICES"] = triton_device_id def create_task(self, coro): """ From 34abb900c92b6c7b4787c6d0f224294624480009 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Fri, 24 May 2024 14:51:13 -0700 Subject: [PATCH 5/9] [comment only change] Remove dupe comment --- src/model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/model.py b/src/model.py index 3d0e01d7..41903f87 100644 --- a/src/model.py +++ b/src/model.py @@ -123,10 +123,6 @@ def initialize(self, args): def init_engine(self): # Currently, Triton needs to use decoupled policy for asynchronously # forwarding requests to vLLM engine, so assert it. - - # assert are in decoupled mode. Currently, Triton needs to use - # decoupled policy for asynchronously forwarding requests to - # vLLM engine. 
self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy( self.model_config ) From 935bf9224119b89fe48e41bed0732e8c0f5ffa3b Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Wed, 29 May 2024 15:18:30 -0700 Subject: [PATCH 6/9] Add tests for various combinations of kind, tensor_parallelism, and model instance_count --- ci/L0_multi_gpu/vllm_backend/test.sh | 111 +++++++++++++----- .../vllm_backend/vllm_multi_gpu_test.py | 73 ++++++++++-- 2 files changed, 148 insertions(+), 36 deletions(-) diff --git a/ci/L0_multi_gpu/vllm_backend/test.sh b/ci/L0_multi_gpu/vllm_backend/test.sh index 36369196..70f31788 100755 --- a/ci/L0_multi_gpu/vllm_backend/test.sh +++ b/ci/L0_multi_gpu/vllm_backend/test.sh @@ -31,52 +31,105 @@ TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"} SERVER=${TRITON_DIR}/bin/tritonserver BACKEND_DIR=${TRITON_DIR}/backends SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --log-verbose=1" -SERVER_LOG="./vllm_multi_gpu_test_server.log" -CLIENT_LOG="./vllm_multi_gpu_test_client.log" TEST_RESULT_FILE='test_results.txt' CLIENT_PY="./vllm_multi_gpu_test.py" SAMPLE_MODELS_REPO="../../../samples/model_repository" EXPECTED_NUM_TESTS=1 -rm -rf models && mkdir -p models -cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt -sed -i '3s/^/ "tensor_parallel_size": 2,\n/' models/vllm_opt/1/model.json +### Helpers +function validate_file_contains() { + local KEY="${1}" + local FILE="${2}" -python3 -m pip install --upgrade pip && pip3 install tritonclient[grpc] nvidia-ml-py3 + if [ -z "${KEY}" ] || [ -z "${FILE}" ]; then + echo "Error: KEY and FILE must be provided." + return 1 + fi -RET=0 + if [ ! -f "${FILE}" ]; then + echo "Error: File '${FILE}' does not exist." + return 1 + fi -run_server -if [ "$SERVER_PID" == "0" ]; then - cat $SERVER_LOG - echo -e "\n***\n*** Failed to start $SERVER\n***" - exit 1 -fi + count=$(grep -o -w "${KEY}" "${FILE}" | wc -l) -set +e -python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1 + if [ "${count}" -ne 1 ]; then + echo "Error: KEY '${KEY}' found ${count} times in '${FILE}'. Expected exactly once." + return 1 + fi +} + +function run_multi_gpu_test() { + export KIND="${1}" + export TENSOR_PARALLELISM="${2}" + export INSTANCE_COUNT="${3}" + + # Setup a clean model repository + export TEST_MODEL="vllm_opt_${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}" + local TEST_MODEL_TRITON_CONFIG="models/${TEST_MODEL}/config.pbtxt" + local TEST_MODEL_VLLM_CONFIG="models/${TEST_MODEL}/1/model.json" + + rm -rf models && mkdir -p models + cp -r "${SAMPLE_MODELS_REPO}/vllm_model" "models/${TEST_MODEL}" + sed -i "s/KIND_MODEL/${KIND}/" "${TEST_MODEL_TRITON_CONFIG}" + sed -i "3s/^/ \"tensor_parallel_size\": ${TENSOR_PARALLELISM},\n/" "${TEST_MODEL_VLLM_CONFIG}" + # Assert the correct kind is set in case the template config changes in the future + validate_file_contains "${KIND}" "${TEST_MODEL_TRITON_CONFIG}" + + # Start server + echo "Running multi-GPU test with kind=${KIND}, tp=${TENSOR_PARALLELISM}, instance_count=${INSTANCE_COUNT}" + SERVER_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--server.log" + run_server + if [ "$SERVER_PID" == "0" ]; then + cat $SERVER_LOG + echo -e "\n***\n*** Failed to start $SERVER\n***" + exit 1 + fi + + # Run unit tests + set +e + CLIENT_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--client.log" + python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1 -if [ $? 
-ne 0 ]; then - cat $CLIENT_LOG - echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***" - RET=1 -else - check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS if [ $? -ne 0 ]; then cat $CLIENT_LOG - echo -e "\n***\n*** Test Result Verification FAILED.\n***" + echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***" RET=1 + else + check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS + if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Result Verification FAILED.\n***" + RET=1 + fi fi -fi -set -e + set -e + + # Cleanup + kill $SERVER_PID + wait $SERVER_PID +} + +### Test +python3 -m pip install --upgrade pip && pip3 install tritonclient[grpc] nvidia-ml-py3 +rm -f *.log +RET=0 -kill $SERVER_PID -wait $SERVER_PID -rm -rf models/ +# Test the various cases of kind, tensor parallelism, and instance count +# for different ways to run multi-GPU models with vLLM on Triton +KINDS="KIND_MODEL KIND_GPU" +TPS="1 2" +INSTANCE_COUNTS="1 2" +for kind in ${KINDS}; do + for tp in ${TPS}; do + for count in ${INSTANCE_COUNTS}; do + run_multi_gpu_test "${kind}" "${tp}" "${count}" + done + done +done +### Results if [ $RET -eq 1 ]; then - cat $CLIENT_LOG - cat $SERVER_LOG echo -e "\n***\n*** Multi GPU Utilization test FAILED. \n***" else echo -e "\n***\n*** Multi GPU Utilization test PASSED. \n***" diff --git a/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py b/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py index baa71632..92b7d6e8 100644 --- a/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py +++ b/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py @@ -24,7 +24,9 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import os import sys +import time import unittest from functools import partial @@ -40,7 +42,6 @@ class VLLMMultiGPUTest(TestResultCollector): def setUp(self): nvidia_smi.nvmlInit() self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001") - self.vllm_model_name = "vllm_opt" def get_gpu_memory_utilization(self, gpu_id): handle = nvidia_smi.nvmlDeviceGetHandleByIndex(gpu_id) @@ -56,7 +57,12 @@ def get_available_gpu_ids(self): available_gpus.append(gpu_id) return available_gpus - def test_vllm_multi_gpu_utilization(self): + def _test_vllm_multi_gpu_utilization(self, model_name: str): + """ + Test that loading a given vLLM model will increase GPU utilization + across multiple GPUs, and run a sanity check inference to confirm + that the loaded multi-gpu/multi-instance model is working as expected. 
+ """ gpu_ids = self.get_available_gpu_ids() self.assertGreaterEqual(len(gpu_ids), 2, "Error: Detected single GPU") @@ -67,8 +73,8 @@ def test_vllm_multi_gpu_utilization(self): print(f"GPU {gpu_id} Memory Utilization: {memory_utilization} bytes") mem_util_before_loading_model[gpu_id] = memory_utilization - self.triton_client.load_model(self.vllm_model_name) - self._test_vllm_model() + self.triton_client.load_model(model_name) + self._test_vllm_model(model_name) print("=============== After Loading vLLM Model ===============") vllm_model_used_gpus = 0 @@ -80,7 +86,7 @@ def test_vllm_multi_gpu_utilization(self): self.assertGreaterEqual(vllm_model_used_gpus, 2) - def _test_vllm_model(self, send_parameters_as_tensor=True): + def _test_vllm_model(self, model_name: str, send_parameters_as_tensor: bool = True): user_data = UserData() stream = False prompts = [ @@ -98,11 +104,11 @@ def _test_vllm_model(self, send_parameters_as_tensor=True): i, stream, sampling_parameters, - self.vllm_model_name, + model_name, send_parameters_as_tensor, ) self.triton_client.async_stream_infer( - model_name=self.vllm_model_name, + model_name=model_name, request_id=request_data["request_id"], inputs=request_data["inputs"], outputs=request_data["outputs"], @@ -118,6 +124,59 @@ def _test_vllm_model(self, send_parameters_as_tensor=True): self.triton_client.stop_stream() + def test_multi_gpu_model(self): + """ + Tests that a multi-GPU vLLM model loads successfully on multiple GPUs + and can handle a few sanity check inference requests. + + Multi-GPU models are currently defined here as either: + - a single model instance with tensor parallelism > 1 + - multiple model instances each with tensor parallelism == 1 + + FIXME: This test currently skips over a few combinations that may + be enhanced in the future, such as: + - tensor parallel models with multiple model instances + - KIND_MODEL models with multiple model instances + """ + model = os.environ.get("TEST_MODEL") + kind = os.environ.get("KIND") + tp = os.environ.get("TENSOR_PARALLELISM") + instance_count = os.environ.get("INSTANCE_COUNT") + for env_var in [model, kind, tp, instance_count]: + self.assertIsNotNone(env_var) + + print(f"Test Matrix: {model=}, {kind=}, {tp=}, {instance_count=}") + + # Only support tensor parallelism or multiple instances for now, but not both. + # Support for multi-instance tensor parallel models may require more + # special handling in the backend to better handle device assignment. + # NOTE: This eliminates the 1*1=1 and 2*2=4 test cases. + if int(tp) * int(instance_count) != 2: + msg = "TENSOR_PARALLELISM and INSTANCE_COUNT must have a product of 2 for this 2-GPU test" + print("Skipping Test:", msg) + self.skipTest(msg) + + # Loading a KIND_GPU model with Tensor Parallelism > 1 should fail and + # recommend using KIND_MODEL instead for multi-gpu model instances. + if kind == "KIND_GPU" and int(tp) > 1: + with self.assertRaisesRegex( + InferenceServerException, "please specify KIND_MODEL" + ): + self._test_vllm_multi_gpu_utilization(model) + + return + + # Loading a KIND_MODEL model with multiple instances can cause + # oversubscription to specific GPUs and cause a CUDA OOM if the + # gpu_memory_utilization settings are high without further handling + # of device assignment in the backend. 
+ if kind == "KIND_MODEL" and int(instance_count) > 1: + msg = "Testing multiple model instances of KIND_MODEL is not implemented at this time" + print("Skipping Test:", msg) + self.skipTest(msg) + + self._test_vllm_multi_gpu_utilization(model) + def tearDown(self): nvidia_smi.nvmlShutdown() self.triton_client.close() From ef8a12ddbaafe160ec67424bc445d00ff61e1349 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Thu, 30 May 2024 10:43:16 -0700 Subject: [PATCH 7/9] Use torch.cuda.set_device instead of CUDA_VISIBLE_DEVICES to avoid hiding other GPUs for d2d copies --- src/model.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/model.py b/src/model.py index 41903f87..571cb184 100644 --- a/src/model.py +++ b/src/model.py @@ -31,6 +31,7 @@ from typing import Dict, List import numpy as np +import torch import triton_python_backend_utils as pb_utils from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -175,7 +176,7 @@ def setup_lora(self): def validate_device_config(self): triton_kind = self.args["model_instance_kind"] - triton_device_id = self.args["model_instance_device_id"] + triton_device_id = int(self.args["model_instance_device_id"]) triton_instance = f"{self.args['model_name']}_{triton_device_id}" # Triton's current definition of KIND_GPU makes assumptions that @@ -186,18 +187,18 @@ def validate_device_config(self): tp_size = int(self.vllm_engine_config.get("tensor_parallel_size", 1)) if tp_size > 1 and triton_kind == "GPU": raise ValueError( - "KIND_GPU is for single-GPU models, please specify KIND_MODEL in the model's config.pbtxt for multi-GPU models" + "KIND_GPU is currently for single-GPU models, please specify KIND_MODEL " + "in the model's config.pbtxt for multi-GPU models" ) - # If KIND_GPU is specified, isolate the selected GPU device to ensure that multiple model instances do not all use the same default - # device (usually device 0) when KIND_GPU is used. - if triton_kind == "GPU" and int(triton_device_id) >= 0: + # If KIND_GPU is specified, specify the device ID assigned by Triton to ensure that + # multiple model instances do not oversubscribe the same default device. + if triton_kind == "GPU" and triton_device_id >= 0: self.logger.log_info( f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}" ) - # NOTE: this only affects this process and it's subprocesses, not other processes. - # vLLM doesn't currently seem to expose selecting a specific device in the APIs. 
- os.environ["CUDA_VISIBLE_DEVICES"] = triton_device_id + # vLLM doesn't currently expose device selection in the APIs + torch.cuda.set_device(triton_device_id) def create_task(self, coro): """ From b0d753b99eab49a5dd1a3f4fff665710cc9a186b Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Thu, 30 May 2024 10:44:21 -0700 Subject: [PATCH 8/9] Remove unused import --- ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py b/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py index 74be7cbb..c7d42fcd 100644 --- a/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py +++ b/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py @@ -26,7 +26,6 @@ import os import sys -import time import unittest from functools import partial From e60b47a0b7ec7574f79f549cfa2e57dc97274527 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Thu, 30 May 2024 11:16:30 -0700 Subject: [PATCH 9/9] Add vllm version to comment for last known API info Co-authored-by: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com> --- src/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.py b/src/model.py index 571cb184..3fe7cd1e 100644 --- a/src/model.py +++ b/src/model.py @@ -197,7 +197,7 @@ def validate_device_config(self): self.logger.log_info( f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}" ) - # vLLM doesn't currently expose device selection in the APIs + # vLLM doesn't currently (v0.4.2) expose device selection in the APIs torch.cuda.set_device(triton_device_id) def create_task(self, coro):
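
For reference, the device handling that this patch series converges on can be captured in a small standalone sketch. This is a hedged illustration, not code taken from the patches: the function name validate_device_config_sketch and its arguments are invented for the example, and it only mirrors the behavior of src/model.py after patch 9/9 — reject KIND_GPU when tensor_parallel_size is greater than 1, otherwise pin the instance to its Triton-assigned GPU with torch.cuda.set_device rather than CUDA_VISIBLE_DEVICES, so peer GPUs stay visible for device-to-device copies (the motivation given in patch 7/9).

import torch


def validate_device_config_sketch(kind: str, device_id: str, engine_config: dict) -> None:
    # Sketch only: mirrors the checks added in this series, not the backend's actual API.
    tp_size = int(engine_config.get("tensor_parallel_size", 1))
    if tp_size > 1 and kind == "GPU":
        # Multi-GPU (tensor parallel) models should use KIND_MODEL and manage devices themselves.
        raise ValueError(
            "KIND_GPU is for single-GPU models, please specify KIND_MODEL for multi-GPU models"
        )
    if kind == "GPU" and int(device_id) >= 0:
        # Pin this process to the assigned device. Unlike setting CUDA_VISIBLE_DEVICES,
        # this keeps the other GPUs visible, so device-to-device copies remain possible.
        torch.cuda.set_device(int(device_id))


# Example: a KIND_GPU instance configured with tensor_parallel_size=2 is rejected up front,
# matching the error path exercised by the KIND_GPU/tp=2 case in the L0_multi_gpu test matrix.
try:
    validate_device_config_sketch("GPU", "0", {"tensor_parallel_size": 2})
except ValueError as err:
    print(err)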