
Adding autocomplete to vllm model.py (#20) #21

Merged (1 commit, Nov 9, 2023)
4 changes: 2 additions & 2 deletions ci/L0_backend_vllm/accuracy_test/accuracy_test.py
@@ -104,10 +104,10 @@ def test_vllm_model(self):

for i in range(number_of_vllm_reqs):
    result = user_data._completed_requests.get()
-   self.assertIsNot(type(result), InferenceServerException)
+   self.assertIsNot(type(result), InferenceServerException, str(result))

    output = result.as_numpy("text_output")
-   self.assertIsNotNone(output)
+   self.assertIsNotNone(output, "`text_output` should not be None")

    triton_vllm_output.extend(output)

6 changes: 3 additions & 3 deletions ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py
@@ -60,11 +60,11 @@ async def request_iterator():

async for response in response_iterator:
    result, error = response
-   self.assertIsNone(error)
-   self.assertIsNotNone(result)
+   self.assertIsNone(error, str(error))
+   self.assertIsNotNone(result, str(result))

    output = result.as_numpy("text_output")
-   self.assertIsNotNone(output)
+   self.assertIsNotNone(output, "`text_output` should not be None")


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py
@@ -107,10 +107,10 @@ def _test_vllm_model(self, send_parameters_as_tensor):
result = user_data._completed_requests.get()
if type(result) is InferenceServerException:
    print(result.message())
- self.assertIsNot(type(result), InferenceServerException)
+ self.assertIsNot(type(result), InferenceServerException, str(result))

output = result.as_numpy("text_output")
- self.assertIsNotNone(output)
+ self.assertIsNotNone(output, "`text_output` should not be None")

self.triton_client.stop_stream()

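For reference, the flow these streaming tests exercise looks roughly like the client sketch below. This is a minimal sketch, not code taken from the test files: the server URL, model name vllm_model, prompt, and sampling parameters are illustrative assumptions; only the tensor names (text_input, stream, sampling_parameters, text_output) come from the backend's interface.

# Minimal sketch of a decoupled/streaming request against the vLLM backend.
# The URL, model name, prompt, and sampling parameters are assumptions.
import json
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


def callback(responses, result, error):
    # The stream callback receives one (result, error) pair per response.
    responses.put(error if error is not None else result)


responses = queue.Queue()
client = grpcclient.InferenceServerClient(url="localhost:8001")
client.start_stream(callback=partial(callback, responses))

text_input = grpcclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(
    np.array(["Hello, my name is".encode("utf-8")], dtype=np.object_)
)
stream = grpcclient.InferInput("stream", [1], "BOOL")
stream.set_data_from_numpy(np.array([True]))
sampling_parameters = grpcclient.InferInput("sampling_parameters", [1], "BYTES")
sampling_parameters.set_data_from_numpy(
    np.array([json.dumps({"temperature": "0.1"}).encode("utf-8")], dtype=np.object_)
)

client.async_stream_infer(
    model_name="vllm_model",
    inputs=[text_input, stream, sampling_parameters],
)
client.stop_stream()  # close the stream; responses arrive via the callback

while not responses.empty():
    item = responses.get()
    if isinstance(item, Exception):
        raise item
    print(item.as_numpy("text_output"))

Because the model is decoupled, a single request may produce more than one response, which is why the tests above drain a queue of completed responses rather than waiting for a single reply.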
40 changes: 0 additions & 40 deletions samples/model_repository/vllm_model/config.pbtxt
@@ -28,46 +28,6 @@

backend: "vllm"

- # Disabling batching in Triton, let vLLM handle the batching on its own.
- max_batch_size: 0
-
- # We need to use decoupled transaction policy for saturating
- # vLLM engine for max throughtput.
- # TODO [DLIS:5233]: Allow asynchronous execution to lift this
- # restriction for cases there is exactly a single response to
- # a single request.
- model_transaction_policy {
-   decoupled: True
- }
- # Note: The vLLM backend uses the following input and output names.
- # Any change here needs to also be made in model.py
- input [
-   {
-     name: "text_input"
-     data_type: TYPE_STRING
-     dims: [ 1 ]
-   },
-   {
-     name: "stream"
-     data_type: TYPE_BOOL
-     dims: [ 1 ]
-   },
-   {
-     name: "sampling_parameters"
-     data_type: TYPE_STRING
-     dims: [ 1 ]
-     optional: true
-   }
- ]
-
- output [
-   {
-     name: "text_output"
-     data_type: TYPE_STRING
-     dims: [ -1 ]
-   }
- ]

# The usage of device is deferred to the vLLM engine
instance_group [
{
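These entries can be removed because auto_complete_config in model.py (added below) now supplies the same inputs, outputs, transaction policy, and max batch size when Triton loads the model. One way to confirm what the server ended up with is to ask it for the completed config; a minimal sketch, assuming the sample model name vllm_model and a gRPC endpoint at localhost:8001:

# Minimal sketch: inspect the config Triton assembled after auto-completion.
# Model name and URL are assumptions.
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")
config = client.get_model_config("vllm_model").config

print([inp.name for inp in config.input])          # expected: text_input, stream, sampling_parameters
print([out.name for out in config.output])         # expected: text_output
print(config.model_transaction_policy.decoupled)   # expected: True
print(config.max_batch_size)                       # expected: 0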
48 changes: 48 additions & 0 deletions src/model.py
@@ -41,6 +41,54 @@


class TritonPythonModel:
+ @staticmethod
+ def auto_complete_config(auto_complete_model_config):
+     inputs = [
+         {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
+         {
+             "name": "stream",
+             "data_type": "TYPE_BOOL",
+             "dims": [1],
+             "optional": True,
+         },
+         {
+             "name": "sampling_parameters",
+             "data_type": "TYPE_STRING",
+             "dims": [1],
+             "optional": True,
+         },
+     ]
+     outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
+
+     # Store the model configuration as a dictionary.
+     config = auto_complete_model_config.as_dict()
+     input_names = []
+     output_names = []
+     for input in config["input"]:
+         input_names.append(input["name"])
+     for output in config["output"]:
+         output_names.append(output["name"])
+
+     # Add only the missing inputs and outputs to the model configuration.
+     for input in inputs:
+         if input["name"] not in input_names:
+             auto_complete_model_config.add_input(input)
+     for output in outputs:
+         if output["name"] not in output_names:
+             auto_complete_model_config.add_output(output)
+
+     # The decoupled transaction policy is needed to saturate the
+     # vLLM engine for maximum throughput.
+     # TODO [DLIS:5233]: Allow asynchronous execution to lift this
+     # restriction for cases where there is exactly a single response to
+     # a single request.
+     auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))
+
+     # Disable batching in Triton and let vLLM handle batching on its own.
+     auto_complete_model_config.set_max_batch_size(0)
+
+     return auto_complete_model_config
+
def initialize(self, args):
    self.logger = pb_utils.Logger
    self.model_config = json.loads(args["model_config"])
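The new method can also be exercised without a running server by handing it a small stand-in object. The sketch below is illustrative only: StubAutoCompleteConfig is a hypothetical class that mirrors just the methods auto_complete_config calls, and it assumes src/model.py is importable locally (its other dependencies, such as the vLLM engine, are installed, and triton_python_backend_utils is stubbed since that module only exists inside the Triton Python backend).

# Illustrative sketch only: StubAutoCompleteConfig is a hypothetical stand-in
# for the object Triton passes to auto_complete_config().
import sys
import types

# triton_python_backend_utils exists only inside the Triton Python backend,
# so stub it before importing model.py (assumption for running this locally).
sys.modules.setdefault(
    "triton_python_backend_utils", types.ModuleType("triton_python_backend_utils")
)

from model import TritonPythonModel  # assumes src/ is on sys.path


class StubAutoCompleteConfig:
    def __init__(self):
        self._config = {"input": [], "output": []}
        self.max_batch_size = None
        self.transaction_policy = None

    def as_dict(self):
        return self._config

    def add_input(self, inp):
        self._config["input"].append(inp)

    def add_output(self, out):
        self._config["output"].append(out)

    def set_model_transaction_policy(self, policy):
        self.transaction_policy = policy

    def set_max_batch_size(self, size):
        self.max_batch_size = size


completed = TritonPythonModel.auto_complete_config(StubAutoCompleteConfig())
assert [i["name"] for i in completed.as_dict()["input"]] == [
    "text_input",
    "stream",
    "sampling_parameters",
]
assert [o["name"] for o in completed.as_dict()["output"]] == ["text_output"]
assert completed.transaction_policy == {"decoupled": True}
assert completed.max_batch_size == 0
print("auto_complete_config filled in all missing fields")

Note that inputs and outputs already declared in an existing config.pbtxt are left untouched, since the method only adds names missing from config["input"] and config["output"].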