
[fix] Sampling Parameters related improvements #80

Merged
merged 9 commits on Jan 9, 2025
118 changes: 105 additions & 13 deletions ci/L0_backend_vllm/accuracy_test/accuracy_test.py
@@ -1,4 +1,4 @@
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -26,6 +26,7 @@

import argparse
import asyncio
import json
import pickle
import sys
import unittest
@@ -36,6 +37,7 @@
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import random_uuid

sys.path.append("../../common")
@@ -53,14 +55,22 @@
"The future of AI is",
]

GUIDED_PROMPTS = ["Classify intent of the sentence: Harry Potter is underrated. "]

SAMPLING_PARAMETERS = {"temperature": 0, "top_p": 1}


async def generate_python_vllm_output(prompt, llm_engine):
async def generate_python_vllm_output(
prompt,
llm_engine,
sampling_params=SamplingParams(**SAMPLING_PARAMETERS),
guided_generation=None,
):
request_id = random_uuid()
sampling_params = SamplingParams(**SAMPLING_PARAMETERS)
python_vllm_output = None
last_output = None
if guided_generation:
sampling_params.guided_decoding = guided_generation

async for vllm_output in llm_engine.generate(prompt, sampling_params, request_id):
last_output = vllm_output
@@ -69,24 +79,28 @@ async def generate_python_vllm_output(prompt, llm_engine):
python_vllm_output = [
(prompt + output.text).encode("utf-8") for output in last_output.outputs
]

return python_vllm_output


def prepare_vllm_baseline_outputs():
def prepare_vllm_baseline_outputs(
export_file="vllm_baseline_output.pkl", prompts=PROMPTS, guided_generation=None
):
"""
Helper function that starts async vLLM engine and generates output for each
prompt in `PROMPTS`. Saves resulted baselines in `vllm_baseline_output.pkl`
prompt in `prompts`. Saves the resulting baselines to `export_file`
for further use.
"""
llm_engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**VLLM_ENGINE_CONFIG))
python_vllm_output = []
for i in range(len(PROMPTS)):
for i in range(len(prompts)):
python_vllm_output.extend(
asyncio.run(generate_python_vllm_output(PROMPTS[i], llm_engine))
asyncio.run(
generate_python_vllm_output(
prompts[i], llm_engine, guided_generation=guided_generation
)
)
)

with open("vllm_baseline_output.pkl", "wb") as f:
with open(export_file, "wb") as f:
pickle.dump(python_vllm_output, f)

return
@@ -96,6 +110,9 @@ class VLLMTritonAccuracyTest(TestResultCollector):
def setUp(self):
self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
self.vllm_model_name = "vllm_opt"

def test_vllm_model(self):
# Reading and verifying baseline data
self.python_vllm_output = []
with open("vllm_baseline_output.pkl", "rb") as f:
self.python_vllm_output = pickle.load(f)
@@ -116,11 +133,9 @@ def setUp(self):
),
)

def test_vllm_model(self):
user_data = UserData()
stream = False
triton_vllm_output = []

self.triton_client.start_stream(callback=partial(callback, user_data))
for i in range(len(PROMPTS)):
request_data = create_vllm_request(
@@ -131,7 +146,7 @@ def test_vllm_model(self):
request_id=request_data["request_id"],
inputs=request_data["inputs"],
outputs=request_data["outputs"],
parameters=SAMPLING_PARAMETERS,
parameters=request_data["parameters"],
)

for i in range(len(PROMPTS)):
@@ -146,6 +161,63 @@ def test_vllm_model(self):
self.triton_client.stop_stream()
self.assertEqual(sorted(self.python_vllm_output), sorted(triton_vllm_output))

def test_guided_decoding(self):
# Reading and verifying baseline data
self.python_vllm_output = []
with open("vllm_guided_baseline_output.pkl", "rb") as f:
self.python_vllm_output = pickle.load(f)

self.assertNotEqual(
self.python_vllm_output,
[],
"Loaded baseline outputs' list should not be empty",
)
self.assertIsNotNone(
self.python_vllm_output, "Loaded baseline outputs' list should not be None"
)
self.assertEqual(
len(self.python_vllm_output),
len(GUIDED_PROMPTS),
"Unexpected number of baseline outputs loaded, expected {}, but got {}".format(
len(GUIDED_PROMPTS), len(self.python_vllm_output)
),
)

user_data = UserData()
stream = False
triton_vllm_output = []

self.triton_client.start_stream(callback=partial(callback, user_data))
sampling_params = SAMPLING_PARAMETERS
guided_decoding_params = {
"choice": ["Positive", "Negative"],
"backend": "outlines",
}
sampling_params["guided_decoding"] = json.dumps(guided_decoding_params)
for i in range(len(GUIDED_PROMPTS)):
request_data = create_vllm_request(
GUIDED_PROMPTS[i], i, stream, sampling_params, self.vllm_model_name
)
self.triton_client.async_stream_infer(
model_name=self.vllm_model_name,
request_id=request_data["request_id"],
inputs=request_data["inputs"],
outputs=request_data["outputs"],
parameters=request_data["parameters"],
)

for i in range(len(GUIDED_PROMPTS)):
result = user_data._completed_requests.get()
self.assertIsNot(type(result), InferenceServerException, str(result))

output = result.as_numpy("text_output")
self.assertIsNotNone(output, "`text_output` should not be None")

triton_vllm_output.extend(output)

self.triton_client.stop_stream()
self.assertEqual(sorted(self.python_vllm_output), sorted(triton_vllm_output))

def tearDown(self):
self.triton_client.close()

@@ -159,9 +231,29 @@ def tearDown(self):
default=False,
help="Generates baseline output for accuracy tests",
)
parser.add_argument(
"--generate-guided-baseline",
action="store_true",
required=False,
default=False,
help="Generates baseline output for accuracy tests",
)
FLAGS = parser.parse_args()
if FLAGS.generate_baseline:
prepare_vllm_baseline_outputs()
exit(0)

if FLAGS.generate_guided_baseline:
guided_decoding_params = {
"choice": ["Positive", "Negative"],
"backend": "outlines",
}
guided_generation = GuidedDecodingParams(**guided_decoding_params)
prepare_vllm_baseline_outputs(
export_file="vllm_guided_baseline_output.pkl",
prompts=GUIDED_PROMPTS,
guided_generation=guided_generation,
)
exit(0)

unittest.main()
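For context, the new guided baseline path above drives vLLM's guided decoding directly through `GuidedDecodingParams`. Below is a minimal, self-contained sketch of that path; the `facebook/opt-125m` model name and engine arguments are stand-ins for the test's `VLLM_ENGINE_CONFIG`, and the prompt and choices mirror the test, so treat the details as assumptions rather than the exact CI configuration.

```python
# Hypothetical sketch of the guided-decoding baseline path exercised above.
# Engine arguments are illustrative stand-ins for VLLM_ENGINE_CONFIG.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import random_uuid

PROMPT = "Classify intent of the sentence: Harry Potter is underrated. "


async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", gpu_memory_utilization=0.3)
    )
    sampling_params = SamplingParams(temperature=0, top_p=1)
    # Constrain generation to one of the two labels via the outlines backend,
    # matching the choices used by test_guided_decoding.
    sampling_params.guided_decoding = GuidedDecodingParams(
        choice=["Positive", "Negative"], backend="outlines"
    )
    last_output = None
    async for request_output in engine.generate(PROMPT, sampling_params, random_uuid()):
        last_output = request_output
    print([output.text for output in last_output.outputs])


if __name__ == "__main__":
    asyncio.run(main())
```

The Triton-side test carries the same constraint as a JSON string under the `guided_decoding` key of the request's sampling parameters.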
8 changes: 6 additions & 2 deletions ci/L0_backend_vllm/accuracy_test/test.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -37,7 +37,7 @@ TEST_RESULT_FILE='test_results.txt'
CLIENT_PY="./accuracy_test.py"
SAMPLE_MODELS_REPO="../../../samples/model_repository"
VLLM_ENGINE_LOG="vllm_engine.log"
EXPECTED_NUM_TESTS=1
EXPECTED_NUM_TESTS=2

rm -rf models && mkdir -p models
cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt
@@ -50,6 +50,10 @@ set +e
# memory issues: https://github.com/vllm-project/vllm/issues/2248
python3 $CLIENT_PY --generate-baseline >> $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
wait $BASELINE_PID

python3 $CLIENT_PY --generate-guided-baseline >> $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
wait $BASELINE_PID

set -e

run_server
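For completeness, here is a hypothetical, trimmed-down version of the streaming gRPC request the accuracy test sends once `run_server` has the backend up. The input names, the `vllm_opt` model name, and the way sampling parameters ride on the request follow the test above, with the shared `create_vllm_request` and callback helpers from the CI common utilities replaced by inline equivalents; treat it as a sketch, not the exact client code.

```python
# Hypothetical client-side sketch of the guided request sent by the accuracy test.
import json
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient

responses = queue.Queue()


def callback(result_queue, result, error):
    # Streaming callback: each response arrives as either a result or an error.
    result_queue.put(error if error is not None else result)


sampling_parameters = {
    "temperature": 0,
    "top_p": 1,
    # Guided-decoding settings travel as a JSON string, as in test_guided_decoding.
    "guided_decoding": json.dumps(
        {"choice": ["Positive", "Negative"], "backend": "outlines"}
    ),
}

text_input = grpcclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(
    np.array(
        ["Classify intent of the sentence: Harry Potter is underrated. ".encode("utf-8")],
        dtype=np.object_,
    )
)
stream_input = grpcclient.InferInput("stream", [1], "BOOL")
stream_input.set_data_from_numpy(np.array([False], dtype=bool))

client = grpcclient.InferenceServerClient(url="localhost:8001")
client.start_stream(callback=partial(callback, responses))
client.async_stream_infer(
    model_name="vllm_opt",
    request_id="0",
    inputs=[text_input, stream_input],
    outputs=[grpcclient.InferRequestedOutput("text_output")],
    parameters=sampling_parameters,
)

result = responses.get()
client.stop_stream()
client.close()
if isinstance(result, Exception):
    raise result
print(result.as_numpy("text_output"))
```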