fix: Enhance checks around KIND_GPU and tensor parallelism #42

Merged 10 commits on May 31, 2024
110 changes: 82 additions & 28 deletions ci/L0_multi_gpu/vllm_backend/test.sh
@@ -31,50 +31,104 @@ TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
SERVER=${TRITON_DIR}/bin/tritonserver
BACKEND_DIR=${TRITON_DIR}/backends
SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --log-verbose=1"
SERVER_LOG="./vllm_multi_gpu_test_server.log"
CLIENT_LOG="./vllm_multi_gpu_test_client.log"
TEST_RESULT_FILE='test_results.txt'
CLIENT_PY="./vllm_multi_gpu_test.py"
SAMPLE_MODELS_REPO="../../../samples/model_repository"
EXPECTED_NUM_TESTS=1

### Helpers
function validate_file_contains() {
    local KEY="${1}"
    local FILE="${2}"

    if [ -z "${KEY}" ] || [ -z "${FILE}" ]; then
        echo "Error: KEY and FILE must be provided."
        return 1
    fi

    if [ ! -f "${FILE}" ]; then
        echo "Error: File '${FILE}' does not exist."
        return 1
    fi

    count=$(grep -o -w "${KEY}" "${FILE}" | wc -l)

    if [ "${count}" -ne 1 ]; then
        echo "Error: KEY '${KEY}' found ${count} times in '${FILE}'. Expected exactly once."
        return 1
    fi
}

function run_multi_gpu_test() {
    export KIND="${1}"
    export TENSOR_PARALLELISM="${2}"
    export INSTANCE_COUNT="${3}"
[Review comment]
Do you need `export`? It looks like all usages are local.

[Review comment]
Okay, I see it now. We should try to move the server setup into the Python test (setup/teardown); @jbkyang-nvi had done something similar.

[Review comment (@rmccorm4, author, May 30, 2024)]
Do you or @jbkyang-nvi have a reference for that? If not, I can probably just do all of this inside the pytest a bit more easily using the in-process Python API, if we don't need any frontend features and @oandreeva-nv doesn't mind.

[Review comment]
It's just turning what we do in bash into Python (spawning the process, file system manipulation, etc.).

Sadly, it seems like the change was reverted:
triton-inference-server/server#7195 (comment)

[Review comment (@rmccorm4, author)]
I agree that a common set of utils to prep/start/stop the server via Python + subprocess would be great. That would probably take me some time to write well, though. Can I merge these tests using bash and follow up on this after we deal with the P0s and pipeline failures? I'll take this test as a specific example to refactor using the common util I write. @GuanLuo @oandreeva-nv
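For reference, a rough sketch of what such a helper could look like with `subprocess` (illustrative only: it assumes `tritonserver` is on `PATH` and the default gRPC port, and the function names are hypothetical rather than an existing util):

```python
import signal
import subprocess
import time

import tritonclient.grpc as grpcclient


def start_server(model_repo: str, log_path: str) -> subprocess.Popen:
    """Launch tritonserver as a subprocess and block until it reports live."""
    cmd = [
        "tritonserver",
        f"--model-repository={model_repo}",
        "--model-control-mode=explicit",
        "--log-verbose=1",
    ]
    with open(log_path, "w") as log_file:
        proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)

    # Poll the gRPC health endpoint until the server is live or we give up.
    client = grpcclient.InferenceServerClient(url="localhost:8001")
    for _ in range(120):
        try:
            if client.is_server_live():
                return proc
        except Exception:
            pass
        time.sleep(1)
    proc.kill()
    raise RuntimeError("tritonserver did not become live in time")


def stop_server(proc: subprocess.Popen):
    """Ask tritonserver to shut down gracefully, then wait for it to exit."""
    proc.send_signal(signal.SIGINT)
    proc.wait(timeout=120)
```

A `setUp`/`tearDown` pair (or a pytest fixture) could then call `start_server`/`stop_server` around each test case.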

    # Setup a clean model repository
    export TEST_MODEL="vllm_opt_${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}"
    local TEST_MODEL_TRITON_CONFIG="models/${TEST_MODEL}/config.pbtxt"
    local TEST_MODEL_VLLM_CONFIG="models/${TEST_MODEL}/1/model.json"

    rm -rf models && mkdir -p models
    cp -r "${SAMPLE_MODELS_REPO}/vllm_model" "models/${TEST_MODEL}"
    sed -i "s/KIND_MODEL/${KIND}/" "${TEST_MODEL_TRITON_CONFIG}"
    sed -i "3s/^/ \"tensor_parallel_size\": ${TENSOR_PARALLELISM},\n/" "${TEST_MODEL_VLLM_CONFIG}"
    # Assert the correct kind is set in case the template config changes in the future
    validate_file_contains "${KIND}" "${TEST_MODEL_TRITON_CONFIG}"

    # Start server
    echo "Running multi-GPU test with kind=${KIND}, tp=${TENSOR_PARALLELISM}, instance_count=${INSTANCE_COUNT}"
    SERVER_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--server.log"
    run_server
    if [ "$SERVER_PID" == "0" ]; then
        cat $SERVER_LOG
        echo -e "\n***\n*** Failed to start $SERVER\n***"
        exit 1
    fi

    # Run unit tests
    set +e
    CLIENT_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--client.log"
    python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1
[Review comment]
Running all unit tests against different settings? Is that necessary?

[Review comment (@rmccorm4, author)]
There's only a single test right now, just lots of helpers. If I move the server/model setup into the Python test like you mentioned, then the bash part can be simplified.

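As a sketch of that simplification, the repository setup above (the `cp`/`sed` steps) could move into the Python test roughly like this (hypothetical helper: it assumes the same sample-repo layout and that `model.json` is plain JSON):

```python
import json
import shutil
from pathlib import Path


def setup_model_repo(sample_repo: str, model_name: str, kind: str, tensor_parallelism: int):
    """Copy the sample vLLM model and patch its configs (mirrors the bash cp/sed steps)."""
    model_dir = Path("models") / model_name
    shutil.rmtree("models", ignore_errors=True)
    model_dir.parent.mkdir(parents=True)
    shutil.copytree(Path(sample_repo) / "vllm_model", model_dir)

    # Set the instance group kind in config.pbtxt, like the sed on KIND_MODEL.
    config_path = model_dir / "config.pbtxt"
    config_path.write_text(config_path.read_text().replace("KIND_MODEL", kind))

    # Inject tensor_parallel_size into the vLLM engine args.
    engine_args_path = model_dir / "1" / "model.json"
    engine_args = json.loads(engine_args_path.read_text())
    engine_args["tensor_parallel_size"] = tensor_parallelism
    engine_args_path.write_text(json.dumps(engine_args, indent=2))
```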

    if [ $? -ne 0 ]; then
        cat $CLIENT_LOG
        echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***"
        RET=1
    else
        check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
        if [ $? -ne 0 ]; then
            cat $CLIENT_LOG
            echo -e "\n***\n*** Test Result Verification FAILED.\n***"
            RET=1
        fi
    fi
    set -e

    # Cleanup
    kill $SERVER_PID
    wait $SERVER_PID
}

### Test
rm -f *.log
RET=0

# Test the various cases of kind, tensor parallelism, and instance count
# for different ways to run multi-GPU models with vLLM on Triton
KINDS="KIND_MODEL KIND_GPU"
TPS="1 2"
INSTANCE_COUNTS="1 2"

for kind in ${KINDS}; do
    for tp in ${TPS}; do
        for count in ${INSTANCE_COUNTS}; do
            run_multi_gpu_test "${kind}" "${tp}" "${count}"
        done
    done
done

### Results
if [ $RET -eq 1 ]; then
    cat $CLIENT_LOG
    cat $SERVER_LOG
    echo -e "\n***\n*** Multi GPU Utilization test FAILED. \n***"
else
    echo -e "\n***\n*** Multi GPU Utilization test PASSED. \n***"
72 changes: 65 additions & 7 deletions ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py
@@ -24,6 +24,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import sys
import unittest
from functools import partial
@@ -40,7 +41,6 @@ class VLLMMultiGPUTest(TestResultCollector):
    def setUp(self):
        pynvml.nvmlInit()
        self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")

    def get_gpu_memory_utilization(self, gpu_id):
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
@@ -56,7 +56,12 @@ def get_available_gpu_ids(self):
                available_gpus.append(gpu_id)
        return available_gpus

    def _test_vllm_multi_gpu_utilization(self, model_name: str):
        """
        Test that loading a given vLLM model will increase GPU utilization
        across multiple GPUs, and run a sanity check inference to confirm
        that the loaded multi-gpu/multi-instance model is working as expected.
        """
        gpu_ids = self.get_available_gpu_ids()
        self.assertGreaterEqual(len(gpu_ids), 2, "Error: Detected single GPU")

@@ -67,8 +72,8 @@ def test_vllm_multi_gpu_utilization(self):
            print(f"GPU {gpu_id} Memory Utilization: {memory_utilization} bytes")
            mem_util_before_loading_model[gpu_id] = memory_utilization

        self.triton_client.load_model(model_name)
        self._test_vllm_model(model_name)

        print("=============== After Loading vLLM Model ===============")
        vllm_model_used_gpus = 0
@@ -80,7 +85,7 @@

        self.assertGreaterEqual(vllm_model_used_gpus, 2)

    def _test_vllm_model(self, model_name: str, send_parameters_as_tensor: bool = True):
        user_data = UserData()
        stream = False
        prompts = [
@@ -98,11 +103,11 @@ def _test_vllm_model(self, send_parameters_as_tensor=True):
                i,
                stream,
                sampling_parameters,
                model_name,
                send_parameters_as_tensor,
            )
            self.triton_client.async_stream_infer(
                model_name=model_name,
                request_id=request_data["request_id"],
                inputs=request_data["inputs"],
                outputs=request_data["outputs"],
@@ -118,6 +123,59 @@

        self.triton_client.stop_stream()

    def test_multi_gpu_model(self):
        """
        Tests that a multi-GPU vLLM model loads successfully on multiple GPUs
        and can handle a few sanity check inference requests.

        Multi-GPU models are currently defined here as either:
        - a single model instance with tensor parallelism > 1
        - multiple model instances each with tensor parallelism == 1

        FIXME: This test currently skips over a few combinations that may
        be enhanced in the future, such as:
        - tensor parallel models with multiple model instances
        - KIND_MODEL models with multiple model instances
        """
        model = os.environ.get("TEST_MODEL")
        kind = os.environ.get("KIND")
        tp = os.environ.get("TENSOR_PARALLELISM")
        instance_count = os.environ.get("INSTANCE_COUNT")
        for env_var in [model, kind, tp, instance_count]:
            self.assertIsNotNone(env_var)

        print(f"Test Matrix: {model=}, {kind=}, {tp=}, {instance_count=}")

        # Only support tensor parallelism or multiple instances for now, but not both.
        # Support for multi-instance tensor parallel models may require more
        # special handling in the backend to better handle device assignment.
        # NOTE: This eliminates the 1*1=1 and 2*2=4 test cases.
        if int(tp) * int(instance_count) != 2:
            msg = "TENSOR_PARALLELISM and INSTANCE_COUNT must have a product of 2 for this 2-GPU test"
            print("Skipping Test:", msg)
            self.skipTest(msg)
[Review comment (Contributor)]
Can we put this into @unittest.skipIf? It would be easier to locate then.

[Review comment (@rmccorm4, author)]
I think I'd have to move the tp and instance counts to be global, or pass them directly to the test somehow, to do this. I was trying to avoid being too fancy with these tests, but it looks like I'll need to rethink them based on the comments so far.
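For illustration, a rough sketch of the `@unittest.skipIf` version, assuming the matrix values were read at module scope (this is not what the PR does; names and defaults are assumptions):

```python
import os
import unittest

# skipIf conditions are evaluated when the class is defined, so the test
# matrix has to be available at module scope rather than inside the test.
KIND = os.environ.get("KIND", "")
TP = int(os.environ.get("TENSOR_PARALLELISM", "1"))
INSTANCE_COUNT = int(os.environ.get("INSTANCE_COUNT", "1"))


class VLLMMultiGPUTest(unittest.TestCase):
    @unittest.skipIf(
        TP * INSTANCE_COUNT != 2,
        "TENSOR_PARALLELISM and INSTANCE_COUNT must have a product of 2 for this 2-GPU test",
    )
    @unittest.skipIf(
        KIND == "KIND_MODEL" and INSTANCE_COUNT > 1,
        "Testing multiple model instances of KIND_MODEL is not implemented at this time",
    )
    def test_multi_gpu_model(self):
        ...
```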


        # Loading a KIND_GPU model with Tensor Parallelism > 1 should fail and
        # recommend using KIND_MODEL instead for multi-gpu model instances.
        if kind == "KIND_GPU" and int(tp) > 1:
            with self.assertRaisesRegex(
                InferenceServerException, "please specify KIND_MODEL"
            ):
                self._test_vllm_multi_gpu_utilization(model)

            return

        # Loading a KIND_MODEL model with multiple instances can cause
        # oversubscription to specific GPUs and cause a CUDA OOM if the
        # gpu_memory_utilization settings are high without further handling
        # of device assignment in the backend.
        if kind == "KIND_MODEL" and int(instance_count) > 1:
            msg = "Testing multiple model instances of KIND_MODEL is not implemented at this time"
            print("Skipping Test:", msg)
            self.skipTest(msg)
[Review comment (Contributor)]
Ditto (same @unittest.skipIf suggestion as above).

        self._test_vllm_multi_gpu_utilization(model)

    def tearDown(self):
        pynvml.nvmlShutdown()
        self.triton_client.close()
80 changes: 59 additions & 21 deletions src/model.py
@@ -31,6 +31,7 @@
from typing import Dict, List

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -98,12 +99,31 @@ def auto_complete_config(auto_complete_model_config):
        return auto_complete_model_config

    def initialize(self, args):
        self.args = args
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])
        output_config = pb_utils.get_output_config_by_name(
            self.model_config, "text_output"
        )
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

        # Prepare vLLM engine
        self.init_engine()

        # Counter to keep track of ongoing request counts
        self.ongoing_request_count = 0

        # Starting asyncio event loop to process the received requests asynchronously.
        self._loop = asyncio.get_event_loop()
        self._loop_thread = threading.Thread(
            target=self.engine_loop, args=(self._loop,)
        )
        self._shutdown_event = asyncio.Event()
        self._loop_thread.start()

    def init_engine(self):
        # Currently, Triton needs to use decoupled policy for asynchronously
        # forwarding requests to vLLM engine, so assert it.
        self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
            self.model_config
        )
@@ -118,17 +138,25 @@ def initialize(self, args):
            engine_args_filepath
        ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'"
        with open(engine_args_filepath) as file:
            self.vllm_engine_config = json.load(file)

        # Validate device and multi-processing settings are currently set based on model/configs.
        self.validate_device_config()

        # Check for LoRA config and set it up if enabled
        self.setup_lora()

        # Create an AsyncLLMEngine from the config from JSON
        self.llm_engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(**self.vllm_engine_config)
        )

    def setup_lora(self):
        self.enable_lora = False

        if (
            "enable_lora" in self.vllm_engine_config.keys()
            and self.vllm_engine_config["enable_lora"].lower() == "true"
        ):
            # create Triton LoRA weights repository
            multi_lora_args_filepath = os.path.join(
@@ -146,21 +174,31 @@ def initialize(self, args):
f"Triton backend cannot find {multi_lora_args_filepath}."
)

    def validate_device_config(self):
        triton_kind = self.args["model_instance_kind"]
        triton_device_id = int(self.args["model_instance_device_id"])
        triton_instance = f"{self.args['model_name']}_{triton_device_id}"

        # Triton's current definition of KIND_GPU makes assumptions that
        # models only use a single GPU. For multi-GPU models, the recommendation
        # is to specify KIND_MODEL to acknowledge that the model will take control
        # of the devices made available to it.
        # NOTE: Consider other parameters that would indicate multi-GPU in the future.
        tp_size = int(self.vllm_engine_config.get("tensor_parallel_size", 1))
        if tp_size > 1 and triton_kind == "GPU":
            raise ValueError(
                "KIND_GPU is currently for single-GPU models, please specify KIND_MODEL "
                "in the model's config.pbtxt for multi-GPU models"
            )

        # If KIND_GPU is specified, specify the device ID assigned by Triton to ensure that
        # multiple model instances do not oversubscribe the same default device.
        if triton_kind == "GPU" and triton_device_id >= 0:
            self.logger.log_info(
                f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}"
            )
            # vLLM doesn't currently (v0.4.2) expose device selection in the APIs
            torch.cuda.set_device(triton_device_id)

    def create_task(self, coro):
        """