fix: Enhance checks around KIND_GPU and tensor parallelism (#42)
Co-authored-by: Olga Andreeva <[email protected]>
rmccorm4 and oandreeva-nv authored May 31, 2024
1 parent 2a1691a commit 18a96e3
Showing 3 changed files with 206 additions and 56 deletions.
110 changes: 82 additions & 28 deletions ci/L0_multi_gpu/vllm_backend/test.sh
@@ -31,50 +31,104 @@ TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
SERVER=${TRITON_DIR}/bin/tritonserver
BACKEND_DIR=${TRITON_DIR}/backends
SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --log-verbose=1"
SERVER_LOG="./vllm_multi_gpu_test_server.log"
CLIENT_LOG="./vllm_multi_gpu_test_client.log"
TEST_RESULT_FILE='test_results.txt'
CLIENT_PY="./vllm_multi_gpu_test.py"
SAMPLE_MODELS_REPO="../../../samples/model_repository"
EXPECTED_NUM_TESTS=1

### Helpers
function validate_file_contains() {
local KEY="${1}"
local FILE="${2}"

if [ -z "${KEY}" ] || [ -z "${FILE}" ]; then
echo "Error: KEY and FILE must be provided."
return 1
fi

if [ ! -f "${FILE}" ]; then
echo "Error: File '${FILE}' does not exist."
return 1
fi

count=$(grep -o -w "${KEY}" "${FILE}" | wc -l)

if [ "${count}" -ne 1 ]; then
echo "Error: KEY '${KEY}' found ${count} times in '${FILE}'. Expected exactly once."
return 1
fi
}

function run_multi_gpu_test() {
export KIND="${1}"
export TENSOR_PARALLELISM="${2}"
export INSTANCE_COUNT="${3}"

# Setup a clean model repository
export TEST_MODEL="vllm_opt_${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}"
local TEST_MODEL_TRITON_CONFIG="models/${TEST_MODEL}/config.pbtxt"
local TEST_MODEL_VLLM_CONFIG="models/${TEST_MODEL}/1/model.json"

rm -rf models && mkdir -p models
cp -r "${SAMPLE_MODELS_REPO}/vllm_model" "models/${TEST_MODEL}"
sed -i "s/KIND_MODEL/${KIND}/" "${TEST_MODEL_TRITON_CONFIG}"
sed -i "3s/^/ \"tensor_parallel_size\": ${TENSOR_PARALLELISM},\n/" "${TEST_MODEL_VLLM_CONFIG}"
# Assert the correct kind is set in case the template config changes in the future
validate_file_contains "${KIND}" "${TEST_MODEL_TRITON_CONFIG}"

# Start server
echo "Running multi-GPU test with kind=${KIND}, tp=${TENSOR_PARALLELISM}, instance_count=${INSTANCE_COUNT}"
SERVER_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--server.log"
run_server
if [ "$SERVER_PID" == "0" ]; then
cat $SERVER_LOG
echo -e "\n***\n*** Failed to start $SERVER\n***"
exit 1
fi

# Run unit tests
set +e
CLIENT_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--client.log"
python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1

if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***"
RET=1
else
check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Result Verification FAILED.\n***"
RET=1
fi
fi
set -e

# Cleanup
kill $SERVER_PID
wait $SERVER_PID
}

### Test
rm -f *.log
RET=0

# Test the various cases of kind, tensor parallelism, and instance count
# for different ways to run multi-GPU models with vLLM on Triton
KINDS="KIND_MODEL KIND_GPU"
TPS="1 2"
INSTANCE_COUNTS="1 2"
for kind in ${KINDS}; do
for tp in ${TPS}; do
for count in ${INSTANCE_COUNTS}; do
run_multi_gpu_test "${kind}" "${tp}" "${count}"
done
done
done

### Results
if [ $RET -eq 1 ]; then
cat $CLIENT_LOG
cat $SERVER_LOG
echo -e "\n***\n*** Multi GPU Utilization test FAILED. \n***"
else
echo -e "\n***\n*** Multi GPU Utilization test PASSED. \n***"
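Note on the repository setup above: each call to run_multi_gpu_test rebuilds the model repository from the sample vllm_model, substitutes the instance kind into config.pbtxt, and injects "tensor_parallel_size" into the vLLM engine args. The following is a minimal Python sketch of the same setup, assuming the sample model.json is a flat JSON object and that config.pbtxt ships with a KIND_MODEL placeholder as in the sample repository; the build_test_model_repo helper is hypothetical, and the instance-count handling (not visible in this hunk) is omitted.

import json
import shutil
from pathlib import Path

# Hypothetical path mirroring the test script; adjust for your checkout.
SAMPLE_MODELS_REPO = Path("../../../samples/model_repository")


def build_test_model_repo(kind: str, tensor_parallelism: int, instance_count: int) -> Path:
    """Recreate the per-case model repository that the shell helper builds."""
    test_model = f"vllm_opt_{kind}_tp{tensor_parallelism}_count{instance_count}"
    model_dir = Path("models") / test_model

    # Clean repository, then copy the sample vLLM model under the new name.
    shutil.rmtree("models", ignore_errors=True)
    shutil.copytree(SAMPLE_MODELS_REPO / "vllm_model", model_dir)

    # Mirrors: sed -i "s/KIND_MODEL/${KIND}/" config.pbtxt
    config = model_dir / "config.pbtxt"
    config.write_text(config.read_text().replace("KIND_MODEL", kind))

    # Mirrors the line-3 sed insert of "tensor_parallel_size"; parsing the
    # JSON is more robust than inserting at a fixed line number.
    engine_args_path = model_dir / "1" / "model.json"
    engine_args = json.loads(engine_args_path.read_text())
    engine_args["tensor_parallel_size"] = tensor_parallelism
    engine_args_path.write_text(json.dumps(engine_args, indent=4))

    return model_dir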
72 changes: 65 additions & 7 deletions ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py
@@ -24,6 +24,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import sys
import unittest
from functools import partial
@@ -40,7 +41,6 @@ class VLLMMultiGPUTest(TestResultCollector):
def setUp(self):
pynvml.nvmlInit()
self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")

def get_gpu_memory_utilization(self, gpu_id):
handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
@@ -56,7 +56,12 @@ def get_available_gpu_ids(self):
available_gpus.append(gpu_id)
return available_gpus

def _test_vllm_multi_gpu_utilization(self, model_name: str):
"""
Test that loading a given vLLM model will increase GPU utilization
across multiple GPUs, and run a sanity check inference to confirm
that the loaded multi-gpu/multi-instance model is working as expected.
"""
gpu_ids = self.get_available_gpu_ids()
self.assertGreaterEqual(len(gpu_ids), 2, "Error: Detected single GPU")

@@ -67,8 +72,8 @@ def test_vllm_multi_gpu_utilization(self):
print(f"GPU {gpu_id} Memory Utilization: {memory_utilization} bytes")
mem_util_before_loading_model[gpu_id] = memory_utilization

self.triton_client.load_model(model_name)
self._test_vllm_model(model_name)

print("=============== After Loading vLLM Model ===============")
vllm_model_used_gpus = 0
@@ -80,7 +85,7 @@ def _test_vllm_model(self, send_parameters_as_tensor=True):

self.assertGreaterEqual(vllm_model_used_gpus, 2)

def _test_vllm_model(self, model_name: str, send_parameters_as_tensor: bool = True):
user_data = UserData()
stream = False
prompts = [
@@ -98,11 +103,11 @@ def _test_vllm_model(self, send_parameters_as_tensor=True):
i,
stream,
sampling_parameters,
model_name,
send_parameters_as_tensor,
)
self.triton_client.async_stream_infer(
model_name=model_name,
request_id=request_data["request_id"],
inputs=request_data["inputs"],
outputs=request_data["outputs"],
@@ -118,6 +123,59 @@ def _test_vllm_model(self, send_parameters_as_tensor=True):

self.triton_client.stop_stream()

def test_multi_gpu_model(self):
"""
Tests that a multi-GPU vLLM model loads successfully on multiple GPUs
and can handle a few sanity check inference requests.
Multi-GPU models are currently defined here as either:
- a single model instance with tensor parallelism > 1
- multiple model instances each with tensor parallelism == 1
FIXME: This test currently skips over a few combinations that may
be enhanced in the future, such as:
- tensor parallel models with multiple model instances
- KIND_MODEL models with multiple model instances
"""
model = os.environ.get("TEST_MODEL")
kind = os.environ.get("KIND")
tp = os.environ.get("TENSOR_PARALLELISM")
instance_count = os.environ.get("INSTANCE_COUNT")
for env_var in [model, kind, tp, instance_count]:
self.assertIsNotNone(env_var)

print(f"Test Matrix: {model=}, {kind=}, {tp=}, {instance_count=}")

# Only support tensor parallelism or multiple instances for now, but not both.
# Support for multi-instance tensor parallel models may require more
# special handling in the backend to better handle device assignment.
# NOTE: This eliminates the 1*1=1 and 2*2=4 test cases.
if int(tp) * int(instance_count) != 2:
msg = "TENSOR_PARALLELISM and INSTANCE_COUNT must have a product of 2 for this 2-GPU test"
print("Skipping Test:", msg)
self.skipTest(msg)

# Loading a KIND_GPU model with Tensor Parallelism > 1 should fail and
# recommend using KIND_MODEL instead for multi-gpu model instances.
if kind == "KIND_GPU" and int(tp) > 1:
with self.assertRaisesRegex(
InferenceServerException, "please specify KIND_MODEL"
):
self._test_vllm_multi_gpu_utilization(model)

return

# Loading a KIND_MODEL model with multiple instances can cause
# oversubscription to specific GPUs and cause a CUDA OOM if the
# gpu_memory_utilization settings are high without further handling
# of device assignment in the backend.
if kind == "KIND_MODEL" and int(instance_count) > 1:
msg = "Testing multiple model instances of KIND_MODEL is not implemented at this time"
print("Skipping Test:", msg)
self.skipTest(msg)

self._test_vllm_multi_gpu_utilization(model)

def tearDown(self):
pynvml.nvmlShutdown()
self.triton_client.close()
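Taken together, the guards in test_multi_gpu_model reduce the eight combinations driven by the test.sh loop to two regular runs, one expected load failure (KIND_GPU with tensor parallelism 2), and five skips on a two-GPU machine. The sketch below restates that decision logic as a standalone script; the classify helper is hypothetical and only mirrors the conditions shown in the diff.

from itertools import product


def classify(kind: str, tp: int, instance_count: int) -> str:
    """Classify a (kind, tp, instance_count) combination the way the test's guards do."""
    if tp * instance_count != 2:
        return "skip (needs tp * instance_count == 2 on a 2-GPU runner)"
    if kind == "KIND_GPU" and tp > 1:
        return "expect load failure (backend recommends KIND_MODEL)"
    if kind == "KIND_MODEL" and instance_count > 1:
        return "skip (multi-instance KIND_MODEL not implemented)"
    return "run (expect multi-GPU utilization + sanity inference)"


if __name__ == "__main__":
    for kind, tp, count in product(["KIND_MODEL", "KIND_GPU"], [1, 2], [1, 2]):
        print(f"{kind:10s} tp={tp} count={count}: {classify(kind, tp, count)}")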
80 changes: 59 additions & 21 deletions src/model.py
@@ -31,6 +31,7 @@
from typing import Dict, List

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -98,12 +99,31 @@ def auto_complete_config(auto_complete_model_config):
return auto_complete_model_config

def initialize(self, args):
self.args = args
self.logger = pb_utils.Logger
self.model_config = json.loads(args["model_config"])
output_config = pb_utils.get_output_config_by_name(
self.model_config, "text_output"
)
self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

# Prepare vLLM engine
self.init_engine()

# Counter to keep track of ongoing request counts
self.ongoing_request_count = 0

# Starting asyncio event loop to process the received requests asynchronously.
self._loop = asyncio.get_event_loop()
self._loop_thread = threading.Thread(
target=self.engine_loop, args=(self._loop,)
)
self._shutdown_event = asyncio.Event()
self._loop_thread.start()

def init_engine(self):
# Currently, Triton needs to use decoupled policy for asynchronously
# forwarding requests to vLLM engine, so assert it.
self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
self.model_config
)
@@ -118,17 +138,25 @@ def initialize(self, args):
engine_args_filepath
), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'"
with open(engine_args_filepath) as file:
self.vllm_engine_config = json.load(file)

# Validate device and multi-processing settings are currently set based on model/configs.
self.validate_device_config()

# Check for LoRA config and set it up if enabled
self.setup_lora()

# Create an AsyncLLMEngine from the config from JSON
self.llm_engine = AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(**self.vllm_engine_config)
)

def setup_lora(self):
self.enable_lora = False

if (
"enable_lora" in vllm_engine_config.keys()
and vllm_engine_config["enable_lora"].lower() == "true"
"enable_lora" in self.vllm_engine_config.keys()
and self.vllm_engine_config["enable_lora"].lower() == "true"
):
# create Triton LoRA weights repository
multi_lora_args_filepath = os.path.join(
@@ -146,21 +174,31 @@ def initialize(self, args):
f"Triton backend cannot find {multi_lora_args_filepath}."
)

def validate_device_config(self):
triton_kind = self.args["model_instance_kind"]
triton_device_id = int(self.args["model_instance_device_id"])
triton_instance = f"{self.args['model_name']}_{triton_device_id}"

# Triton's current definition of KIND_GPU makes assumptions that
# models only use a single GPU. For multi-GPU models, the recommendation
# is to specify KIND_MODEL to acknowledge that the model will take control
# of the devices made available to it.
# NOTE: Consider other parameters that would indicate multi-GPU in the future.
tp_size = int(self.vllm_engine_config.get("tensor_parallel_size", 1))
if tp_size > 1 and triton_kind == "GPU":
raise ValueError(
"KIND_GPU is currently for single-GPU models, please specify KIND_MODEL "
"in the model's config.pbtxt for multi-GPU models"
)

# If KIND_GPU is specified, pin the instance to the device ID assigned by Triton to ensure that
# multiple model instances do not oversubscribe the same default device.
if triton_kind == "GPU" and triton_device_id >= 0:
self.logger.log_info(
f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}"
)
# vLLM doesn't currently (v0.4.2) expose device selection in the APIs
torch.cuda.set_device(triton_device_id)

def create_task(self, coro):
"""
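For readers skimming the model.py changes: validate_device_config is the piece that enforces the new KIND_GPU rules. Below is a hedged, standalone restatement of that check, assuming (as the NOTE in the diff does) that tensor_parallel_size is the only signal of a multi-GPU model; check_device_config is illustrative only and does not call into Triton, vLLM, or torch.

from typing import Optional


def check_device_config(
    triton_kind: str,
    triton_device_id: int,
    engine_args: dict,
) -> Optional[int]:
    """Return the CUDA device a KIND_GPU instance should pin (or None), raising for unsupported configs."""
    tp_size = int(engine_args.get("tensor_parallel_size", 1))

    # KIND_GPU instances are assumed to own exactly one GPU, so tensor
    # parallelism > 1 must use KIND_MODEL instead.
    if tp_size > 1 and triton_kind == "GPU":
        raise ValueError(
            "KIND_GPU is currently for single-GPU models, please specify "
            "KIND_MODEL in the model's config.pbtxt for multi-GPU models"
        )

    # For KIND_GPU, pin the instance to the device Triton assigned so that
    # multiple instances do not pile onto the same default GPU.
    if triton_kind == "GPU" and triton_device_id >= 0:
        return triton_device_id
    return None


# Example: a KIND_GPU instance with tensor_parallel_size=2 is rejected, while
# two KIND_GPU instances with tp=1 would pin devices 0 and 1 respectively.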
