diff --git a/ci/L0_multi_gpu/vllm_backend/test.sh b/ci/L0_multi_gpu/vllm_backend/test.sh index 09a0bb08..fb583cb3 100755 --- a/ci/L0_multi_gpu/vllm_backend/test.sh +++ b/ci/L0_multi_gpu/vllm_backend/test.sh @@ -31,50 +31,104 @@ TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"} SERVER=${TRITON_DIR}/bin/tritonserver BACKEND_DIR=${TRITON_DIR}/backends SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --log-verbose=1" -SERVER_LOG="./vllm_multi_gpu_test_server.log" -CLIENT_LOG="./vllm_multi_gpu_test_client.log" TEST_RESULT_FILE='test_results.txt' CLIENT_PY="./vllm_multi_gpu_test.py" SAMPLE_MODELS_REPO="../../../samples/model_repository" EXPECTED_NUM_TESTS=1 -rm -rf models && mkdir -p models -cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt -sed -i '3s/^/ "tensor_parallel_size": 2,\n/' models/vllm_opt/1/model.json +### Helpers +function validate_file_contains() { + local KEY="${1}" + local FILE="${2}" -RET=0 + if [ -z "${KEY}" ] || [ -z "${FILE}" ]; then + echo "Error: KEY and FILE must be provided." + return 1 + fi -run_server -if [ "$SERVER_PID" == "0" ]; then - cat $SERVER_LOG - echo -e "\n***\n*** Failed to start $SERVER\n***" - exit 1 -fi + if [ ! -f "${FILE}" ]; then + echo "Error: File '${FILE}' does not exist." + return 1 + fi -set +e -python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1 + count=$(grep -o -w "${KEY}" "${FILE}" | wc -l) + + if [ "${count}" -ne 1 ]; then + echo "Error: KEY '${KEY}' found ${count} times in '${FILE}'. Expected exactly once." + return 1 + fi +} + +function run_multi_gpu_test() { + export KIND="${1}" + export TENSOR_PARALLELISM="${2}" + export INSTANCE_COUNT="${3}" + + # Setup a clean model repository + export TEST_MODEL="vllm_opt_${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}" + local TEST_MODEL_TRITON_CONFIG="models/${TEST_MODEL}/config.pbtxt" + local TEST_MODEL_VLLM_CONFIG="models/${TEST_MODEL}/1/model.json" + + rm -rf models && mkdir -p models + cp -r "${SAMPLE_MODELS_REPO}/vllm_model" "models/${TEST_MODEL}" + sed -i "s/KIND_MODEL/${KIND}/" "${TEST_MODEL_TRITON_CONFIG}" + sed -i "3s/^/ \"tensor_parallel_size\": ${TENSOR_PARALLELISM},\n/" "${TEST_MODEL_VLLM_CONFIG}" + # Assert the correct kind is set in case the template config changes in the future + validate_file_contains "${KIND}" "${TEST_MODEL_TRITON_CONFIG}" + + # Start server + echo "Running multi-GPU test with kind=${KIND}, tp=${TENSOR_PARALLELISM}, instance_count=${INSTANCE_COUNT}" + SERVER_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--server.log" + run_server + if [ "$SERVER_PID" == "0" ]; then + cat $SERVER_LOG + echo -e "\n***\n*** Failed to start $SERVER\n***" + exit 1 + fi + + # Run unit tests + set +e + CLIENT_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--client.log" + python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1 -if [ $? -ne 0 ]; then - cat $CLIENT_LOG - echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***" - RET=1 -else - check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS if [ $? -ne 0 ]; then cat $CLIENT_LOG - echo -e "\n***\n*** Test Result Verification FAILED.\n***" + echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***" RET=1 + else + check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS + if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Result Verification FAILED.\n***" + RET=1 + fi fi -fi -set -e + set -e + + # Cleanup + kill $SERVER_PID + wait $SERVER_PID +} + +### Test +rm -f *.log +RET=0 -kill $SERVER_PID -wait $SERVER_PID -rm -rf models/ +# Test the various cases of kind, tensor parallelism, and instance count +# for different ways to run multi-GPU models with vLLM on Triton +KINDS="KIND_MODEL KIND_GPU" +TPS="1 2" +INSTANCE_COUNTS="1 2" +for kind in ${KINDS}; do + for tp in ${TPS}; do + for count in ${INSTANCE_COUNTS}; do + run_multi_gpu_test "${kind}" "${tp}" "${count}" + done + done +done +### Results if [ $RET -eq 1 ]; then - cat $CLIENT_LOG - cat $SERVER_LOG echo -e "\n***\n*** Multi GPU Utilization test FAILED. \n***" else echo -e "\n***\n*** Multi GPU Utilization test PASSED. \n***" diff --git a/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py b/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py index f9bb56b3..c7d42fcd 100644 --- a/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py +++ b/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py @@ -24,6 +24,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import os import sys import unittest from functools import partial @@ -40,7 +41,6 @@ class VLLMMultiGPUTest(TestResultCollector): def setUp(self): pynvml.nvmlInit() self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001") - self.vllm_model_name = "vllm_opt" def get_gpu_memory_utilization(self, gpu_id): handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) @@ -56,7 +56,12 @@ def get_available_gpu_ids(self): available_gpus.append(gpu_id) return available_gpus - def test_vllm_multi_gpu_utilization(self): + def _test_vllm_multi_gpu_utilization(self, model_name: str): + """ + Test that loading a given vLLM model will increase GPU utilization + across multiple GPUs, and run a sanity check inference to confirm + that the loaded multi-gpu/multi-instance model is working as expected. + """ gpu_ids = self.get_available_gpu_ids() self.assertGreaterEqual(len(gpu_ids), 2, "Error: Detected single GPU") @@ -67,8 +72,8 @@ def test_vllm_multi_gpu_utilization(self): print(f"GPU {gpu_id} Memory Utilization: {memory_utilization} bytes") mem_util_before_loading_model[gpu_id] = memory_utilization - self.triton_client.load_model(self.vllm_model_name) - self._test_vllm_model() + self.triton_client.load_model(model_name) + self._test_vllm_model(model_name) print("=============== After Loading vLLM Model ===============") vllm_model_used_gpus = 0 @@ -80,7 +85,7 @@ def test_vllm_multi_gpu_utilization(self): self.assertGreaterEqual(vllm_model_used_gpus, 2) - def _test_vllm_model(self, send_parameters_as_tensor=True): + def _test_vllm_model(self, model_name: str, send_parameters_as_tensor: bool = True): user_data = UserData() stream = False prompts = [ @@ -98,11 +103,11 @@ def _test_vllm_model(self, send_parameters_as_tensor=True): i, stream, sampling_parameters, - self.vllm_model_name, + model_name, send_parameters_as_tensor, ) self.triton_client.async_stream_infer( - model_name=self.vllm_model_name, + model_name=model_name, request_id=request_data["request_id"], inputs=request_data["inputs"], outputs=request_data["outputs"], @@ -118,6 +123,59 @@ def _test_vllm_model(self, send_parameters_as_tensor=True): self.triton_client.stop_stream() + def test_multi_gpu_model(self): + """ + Tests that a multi-GPU vLLM model loads successfully on multiple GPUs + and can handle a few sanity check inference requests. + + Multi-GPU models are currently defined here as either: + - a single model instance with tensor parallelism > 1 + - multiple model instances each with tensor parallelism == 1 + + FIXME: This test currently skips over a few combinations that may + be enhanced in the future, such as: + - tensor parallel models with multiple model instances + - KIND_MODEL models with multiple model instances + """ + model = os.environ.get("TEST_MODEL") + kind = os.environ.get("KIND") + tp = os.environ.get("TENSOR_PARALLELISM") + instance_count = os.environ.get("INSTANCE_COUNT") + for env_var in [model, kind, tp, instance_count]: + self.assertIsNotNone(env_var) + + print(f"Test Matrix: {model=}, {kind=}, {tp=}, {instance_count=}") + + # Only support tensor parallelism or multiple instances for now, but not both. + # Support for multi-instance tensor parallel models may require more + # special handling in the backend to better handle device assignment. + # NOTE: This eliminates the 1*1=1 and 2*2=4 test cases. + if int(tp) * int(instance_count) != 2: + msg = "TENSOR_PARALLELISM and INSTANCE_COUNT must have a product of 2 for this 2-GPU test" + print("Skipping Test:", msg) + self.skipTest(msg) + + # Loading a KIND_GPU model with Tensor Parallelism > 1 should fail and + # recommend using KIND_MODEL instead for multi-gpu model instances. + if kind == "KIND_GPU" and int(tp) > 1: + with self.assertRaisesRegex( + InferenceServerException, "please specify KIND_MODEL" + ): + self._test_vllm_multi_gpu_utilization(model) + + return + + # Loading a KIND_MODEL model with multiple instances can cause + # oversubscription to specific GPUs and cause a CUDA OOM if the + # gpu_memory_utilization settings are high without further handling + # of device assignment in the backend. + if kind == "KIND_MODEL" and int(instance_count) > 1: + msg = "Testing multiple model instances of KIND_MODEL is not implemented at this time" + print("Skipping Test:", msg) + self.skipTest(msg) + + self._test_vllm_multi_gpu_utilization(model) + def tearDown(self): pynvml.nvmlShutdown() self.triton_client.close() diff --git a/src/model.py b/src/model.py index 2d9d8ff8..3fe7cd1e 100644 --- a/src/model.py +++ b/src/model.py @@ -31,6 +31,7 @@ from typing import Dict, List import numpy as np +import torch import triton_python_backend_utils as pb_utils from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -98,12 +99,31 @@ def auto_complete_config(auto_complete_model_config): return auto_complete_model_config def initialize(self, args): + self.args = args self.logger = pb_utils.Logger self.model_config = json.loads(args["model_config"]) + output_config = pb_utils.get_output_config_by_name( + self.model_config, "text_output" + ) + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + # Prepare vLLM engine + self.init_engine() + + # Counter to keep track of ongoing request counts + self.ongoing_request_count = 0 + + # Starting asyncio event loop to process the received requests asynchronously. + self._loop = asyncio.get_event_loop() + self._loop_thread = threading.Thread( + target=self.engine_loop, args=(self._loop,) + ) + self._shutdown_event = asyncio.Event() + self._loop_thread.start() - # assert are in decoupled mode. Currently, Triton needs to use - # decoupled policy for asynchronously forwarding requests to - # vLLM engine. + def init_engine(self): + # Currently, Triton needs to use decoupled policy for asynchronously + # forwarding requests to vLLM engine, so assert it. self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy( self.model_config ) @@ -118,17 +138,25 @@ def initialize(self, args): engine_args_filepath ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'" with open(engine_args_filepath) as file: - vllm_engine_config = json.load(file) + self.vllm_engine_config = json.load(file) + + # Validate device and multi-processing settings are currently set based on model/configs. + self.validate_device_config() + + # Check for LoRA config and set it up if enabled + self.setup_lora() # Create an AsyncLLMEngine from the config from JSON self.llm_engine = AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(**vllm_engine_config) + AsyncEngineArgs(**self.vllm_engine_config) ) + + def setup_lora(self): self.enable_lora = False if ( - "enable_lora" in vllm_engine_config.keys() - and vllm_engine_config["enable_lora"].lower() == "true" + "enable_lora" in self.vllm_engine_config.keys() + and self.vllm_engine_config["enable_lora"].lower() == "true" ): # create Triton LoRA weights repository multi_lora_args_filepath = os.path.join( @@ -146,21 +174,31 @@ def initialize(self, args): f"Triton backend cannot find {multi_lora_args_filepath}." ) - output_config = pb_utils.get_output_config_by_name( - self.model_config, "text_output" - ) - self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) - - # Counter to keep track of ongoing request counts - self.ongoing_request_count = 0 + def validate_device_config(self): + triton_kind = self.args["model_instance_kind"] + triton_device_id = int(self.args["model_instance_device_id"]) + triton_instance = f"{self.args['model_name']}_{triton_device_id}" + + # Triton's current definition of KIND_GPU makes assumptions that + # models only use a single GPU. For multi-GPU models, the recommendation + # is to specify KIND_MODEL to acknowledge that the model will take control + # of the devices made available to it. + # NOTE: Consider other parameters that would indicate multi-GPU in the future. + tp_size = int(self.vllm_engine_config.get("tensor_parallel_size", 1)) + if tp_size > 1 and triton_kind == "GPU": + raise ValueError( + "KIND_GPU is currently for single-GPU models, please specify KIND_MODEL " + "in the model's config.pbtxt for multi-GPU models" + ) - # Starting asyncio event loop to process the received requests asynchronously. - self._loop = asyncio.get_event_loop() - self._loop_thread = threading.Thread( - target=self.engine_loop, args=(self._loop,) - ) - self._shutdown_event = asyncio.Event() - self._loop_thread.start() + # If KIND_GPU is specified, specify the device ID assigned by Triton to ensure that + # multiple model instances do not oversubscribe the same default device. + if triton_kind == "GPU" and triton_device_id >= 0: + self.logger.log_info( + f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}" + ) + # vLLM doesn't currently (v0.4.2) expose device selection in the APIs + torch.cuda.set_device(triton_device_id) def create_task(self, coro): """