Skip to content

Commit

Permalink
Adding autocomplete to vllm model.py (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
oandreeva-nv authored Nov 9, 2023
1 parent 797038d commit 0e5b209
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 47 deletions.
4 changes: 2 additions & 2 deletions ci/L0_backend_vllm/accuracy_test/accuracy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ def test_vllm_model(self):

for i in range(number_of_vllm_reqs):
result = user_data._completed_requests.get()
self.assertIsNot(type(result), InferenceServerException)
self.assertIsNot(type(result), InferenceServerException, str(result))

output = result.as_numpy("text_output")
self.assertIsNotNone(output)
self.assertIsNotNone(output, "`text_output` should not be None")

triton_vllm_output.extend(output)

Expand Down
6 changes: 3 additions & 3 deletions ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ async def request_iterator():

async for response in response_iterator:
result, error = response
self.assertIsNone(error)
self.assertIsNotNone(result)
self.assertIsNone(error, str(error))
self.assertIsNotNone(result, str(result))

output = result.as_numpy("text_output")
self.assertIsNotNone(output)
self.assertIsNotNone(output, "`text_output` should not be None")


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def _test_vllm_model(self, send_parameters_as_tensor):
result = user_data._completed_requests.get()
if type(result) is InferenceServerException:
print(result.message())
self.assertIsNot(type(result), InferenceServerException)
self.assertIsNot(type(result), InferenceServerException, str(result))

output = result.as_numpy("text_output")
self.assertIsNotNone(output)
self.assertIsNotNone(output, "`text_output` should not be None")

self.triton_client.stop_stream()

Expand Down
40 changes: 0 additions & 40 deletions samples/model_repository/vllm_model/config.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -28,46 +28,6 @@

backend: "vllm"

# Disabling batching in Triton, let vLLM handle the batching on its own.
max_batch_size: 0

# We need to use decoupled transaction policy for saturating
# vLLM engine for max throughtput.
# TODO [DLIS:5233]: Allow asynchronous execution to lift this
# restriction for cases there is exactly a single response to
# a single request.
model_transaction_policy {
decoupled: True
}
# Note: The vLLM backend uses the following input and output names.
# Any change here needs to also be made in model.py
input [
{
name: "text_input"
data_type: TYPE_STRING
dims: [ 1 ]
},
{
name: "stream"
data_type: TYPE_BOOL
dims: [ 1 ]
},
{
name: "sampling_parameters"
data_type: TYPE_STRING
dims: [ 1 ]
optional: true
}
]

output [
{
name: "text_output"
data_type: TYPE_STRING
dims: [ -1 ]
}
]

# The usage of device is deferred to the vLLM engine
instance_group [
{
Expand Down
48 changes: 48 additions & 0 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,54 @@


class TritonPythonModel:
@staticmethod
def auto_complete_config(auto_complete_model_config):
inputs = [
{"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
{
"name": "stream",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True,
},
{
"name": "sampling_parameters",
"data_type": "TYPE_STRING",
"dims": [1],
"optional": True,
},
]
outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]

# Store the model configuration as a dictionary.
config = auto_complete_model_config.as_dict()
input_names = []
output_names = []
for input in config["input"]:
input_names.append(input["name"])
for output in config["output"]:
output_names.append(output["name"])

# Add only missing inputs and output to the model configuration.
for input in inputs:
if input["name"] not in input_names:
auto_complete_model_config.add_input(input)
for output in outputs:
if output["name"] not in output_names:
auto_complete_model_config.add_output(output)

# We need to use decoupled transaction policy for saturating
# vLLM engine for max throughtput.
# TODO [DLIS:5233]: Allow asynchronous execution to lift this
# restriction for cases there is exactly a single response to
# a single request.
auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))

# Disabling batching in Triton, let vLLM handle the batching on its own.
auto_complete_model_config.set_max_batch_size(0)

return auto_complete_model_config

def initialize(self, args):
self.logger = pb_utils.Logger
self.model_config = json.loads(args["model_config"])
Expand Down

0 comments on commit 0e5b209

Please sign in to comment.