From 52e125f29a540f1eea261098cf6544eb7567985a Mon Sep 17 00:00:00 2001 From: jibxie Date: Mon, 14 Oct 2024 17:29:52 +0800 Subject: [PATCH 1/6] Support modal model input --- src/model.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/model.py b/src/model.py index 3f6e23bb..7639c3b0 100644 --- a/src/model.py +++ b/src/model.py @@ -52,6 +52,12 @@ class TritonPythonModel: def auto_complete_config(auto_complete_model_config): inputs = [ {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}, + { + "name": "multi_modal_data", + "data_type": "TYPE_STRING", + "dims": [1], + "optional": True, + }, { "name": "stream", "data_type": "TYPE_BOOL", @@ -385,6 +391,21 @@ async def generate(self, request): ).as_numpy()[0] if isinstance(prompt, bytes): prompt = prompt.decode("utf-8") + + multi_modal_data = pb_utils.get_input_tensor_by_name( + request, "multi_modal_data" + ).as_numpy()[0] + if isinstance(multi_modal_data, bytes): + multi_modal_data = multi_modal_data.decode("utf-8") + + if multi_modal_data is not None: + # Build TextPrompt format prompt for multi modal models + multi_modal_data = json.loads(multi_modal_data) + prompt = { + "prompt": prompt, + "multi_modal_data": multi_modal_data + } + stream = pb_utils.get_input_tensor_by_name(request, "stream") if stream: stream = stream.as_numpy()[0] From 114c2359684ecdcd028165d6004a80c1f2c33d1c Mon Sep 17 00:00:00 2001 From: jibxie Date: Mon, 14 Oct 2024 22:02:28 +0800 Subject: [PATCH 2/6] Update --- src/model.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/model.py b/src/model.py index 7639c3b0..33fb7622 100644 --- a/src/model.py +++ b/src/model.py @@ -392,13 +392,11 @@ async def generate(self, request): if isinstance(prompt, bytes): prompt = prompt.decode("utf-8") - multi_modal_data = pb_utils.get_input_tensor_by_name( + multi_modal_data_input_tensor = pb_utils.get_input_tensor_by_name( request, "multi_modal_data" - ).as_numpy()[0] - if isinstance(multi_modal_data, bytes): - multi_modal_data = multi_modal_data.decode("utf-8") - - if multi_modal_data is not None: + ) + if multi_modal_data_input_tensor: + multi_modal_data = multi_modal_data_input_tensor.as_numpy()[0].decode("utf-8") # Build TextPrompt format prompt for multi modal models multi_modal_data = json.loads(multi_modal_data) prompt = { From 4a8b7378a93e93b354dd77275065de97ffa6e0fa Mon Sep 17 00:00:00 2001 From: jibxie Date: Fri, 1 Nov 2024 15:52:23 +0800 Subject: [PATCH 3/6] support image list input --- src/model.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/model.py b/src/model.py index 33fb7622..0f09c3d3 100644 --- a/src/model.py +++ b/src/model.py @@ -31,7 +31,9 @@ import queue import threading from typing import Dict, List - +import base64 +from PIL import Image +from io import BytesIO import numpy as np import torch import triton_python_backend_utils as pb_utils @@ -397,12 +399,21 @@ async def generate(self, request): ) if multi_modal_data_input_tensor: multi_modal_data = multi_modal_data_input_tensor.as_numpy()[0].decode("utf-8") - # Build TextPrompt format prompt for multi modal models multi_modal_data = json.loads(multi_modal_data) - prompt = { - "prompt": prompt, - "multi_modal_data": multi_modal_data - } + if "image" in multi_modal_data: + image_list = [] + for image_base64_string in multi_modal_data["image"]: + if "base64," in image_base64_string: + image_base64_string = image_base64_string.split("base64,")[-1] + image_data = base64.b64decode(image_base64_string) + image = Image.open(BytesIO(image_data)).convert("RGB") + image_list.append(image) + prompt = { + "prompt": prompt, + "multi_modal_data": { + "image": image_list + } + } stream = pb_utils.get_input_tensor_by_name(request, "stream") if stream: From 9897102f7847a3bb1fe5d11470b44562bb7249b6 Mon Sep 17 00:00:00 2001 From: jibxie Date: Fri, 1 Nov 2024 15:59:45 +0800 Subject: [PATCH 4/6] Remove best of metrics to align with latest vllm --- src/utils/metrics.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 75c097dc..3588a0d0 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -76,11 +76,6 @@ def __init__(self, labels: List[str], max_model_len: int): description="Number of generation tokens processed.", kind=pb_utils.MetricFamily.HISTOGRAM, ) - self.histogram_best_of_request_family = pb_utils.MetricFamily( - name="vllm:request_params_best_of", - description="Histogram of the best_of request parameter.", - kind=pb_utils.MetricFamily.HISTOGRAM, - ) self.histogram_n_request_family = pb_utils.MetricFamily( name="vllm:request_params_n", description="Histogram of the n request parameter.", @@ -159,10 +154,6 @@ def __init__(self, labels: List[str], max_model_len: int): buckets=build_1_2_5_buckets(max_model_len), ) ) - self.histogram_best_of_request = self.histogram_best_of_request_family.Metric( - labels=labels, - buckets=[1, 2, 5, 10, 20], - ) self.histogram_n_request = self.histogram_n_request_family.Metric( labels=labels, buckets=[1, 2, 5, 10, 20], @@ -247,7 +238,6 @@ def log(self, stats: VllmStats) -> None: self.metrics.histogram_num_generation_tokens_request, stats.num_generation_tokens_requests, ), - (self.metrics.histogram_best_of_request, stats.best_of_requests), (self.metrics.histogram_n_request, stats.n_requests), ] From c85d9722d833018989b54ed70ec118dcadd6f81d Mon Sep 17 00:00:00 2001 From: jibxie Date: Mon, 4 Nov 2024 16:16:58 +0800 Subject: [PATCH 5/6] supply the images as individual tensors --- src/model.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/model.py b/src/model.py index 0f09c3d3..7211b10b 100644 --- a/src/model.py +++ b/src/model.py @@ -55,9 +55,9 @@ def auto_complete_config(auto_complete_model_config): inputs = [ {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}, { - "name": "multi_modal_data", + "name": "image", "data_type": "TYPE_STRING", - "dims": [1], + "dims": [-1], # can be multiple images as separate elements "optional": True, }, { @@ -394,20 +394,16 @@ async def generate(self, request): if isinstance(prompt, bytes): prompt = prompt.decode("utf-8") - multi_modal_data_input_tensor = pb_utils.get_input_tensor_by_name( - request, "multi_modal_data" + image_input_tensor = pb_utils.get_input_tensor_by_name( + request, "image" ) - if multi_modal_data_input_tensor: - multi_modal_data = multi_modal_data_input_tensor.as_numpy()[0].decode("utf-8") - multi_modal_data = json.loads(multi_modal_data) - if "image" in multi_modal_data: - image_list = [] - for image_base64_string in multi_modal_data["image"]: - if "base64," in image_base64_string: - image_base64_string = image_base64_string.split("base64,")[-1] - image_data = base64.b64decode(image_base64_string) - image = Image.open(BytesIO(image_data)).convert("RGB") - image_list.append(image) + if image_input_tensor: + image_list = [] + for image_raw in image_input_tensor.as_numpy(): + image_data = base64.b64decode(image_raw.decode("utf-8")) + image = Image.open(BytesIO(image_data)).convert("RGB") + image_list.append(image) + if len(image_list) > 0: prompt = { "prompt": prompt, "multi_modal_data": { From 566e0cc7fa90ad169b49314320e3937c06c6353e Mon Sep 17 00:00:00 2001 From: jibxie Date: Wed, 6 Nov 2024 22:08:33 +0800 Subject: [PATCH 6/6] Add vllm version check for compatibility --- src/model.py | 46 ++++++++++++++++++++++++-------------------- src/utils/metrics.py | 18 +++++++++++++++-- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/src/model.py b/src/model.py index 7211b10b..0fdbe0ce 100644 --- a/src/model.py +++ b/src/model.py @@ -42,6 +42,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid +from vllm.version import __version__ as _VLLM_VERSION from utils.metrics import VllmStatLogger @@ -54,12 +55,6 @@ class TritonPythonModel: def auto_complete_config(auto_complete_model_config): inputs = [ {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}, - { - "name": "image", - "data_type": "TYPE_STRING", - "dims": [-1], # can be multiple images as separate elements - "optional": True, - }, { "name": "stream", "data_type": "TYPE_BOOL", @@ -79,6 +74,14 @@ def auto_complete_config(auto_complete_model_config): "optional": True, }, ] + if _VLLM_VERSION >= "0.6.3.post1": + inputs.append({ + "name": "image", + "data_type": "TYPE_STRING", + "dims": [-1], # can be multiple images as separate elements + "optional": True, + }) + outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}] # Store the model configuration as a dictionary. @@ -394,22 +397,23 @@ async def generate(self, request): if isinstance(prompt, bytes): prompt = prompt.decode("utf-8") - image_input_tensor = pb_utils.get_input_tensor_by_name( - request, "image" - ) - if image_input_tensor: - image_list = [] - for image_raw in image_input_tensor.as_numpy(): - image_data = base64.b64decode(image_raw.decode("utf-8")) - image = Image.open(BytesIO(image_data)).convert("RGB") - image_list.append(image) - if len(image_list) > 0: - prompt = { - "prompt": prompt, - "multi_modal_data": { - "image": image_list + if _VLLM_VERSION >= "0.6.3.post1": + image_input_tensor = pb_utils.get_input_tensor_by_name( + request, "image" + ) + if image_input_tensor: + image_list = [] + for image_raw in image_input_tensor.as_numpy(): + image_data = base64.b64decode(image_raw.decode("utf-8")) + image = Image.open(BytesIO(image_data)).convert("RGB") + image_list.append(image) + if len(image_list) > 0: + prompt = { + "prompt": prompt, + "multi_modal_data": { + "image": image_list + } } - } stream = pb_utils.get_input_tensor_by_name(request, "stream") if stream: diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 3588a0d0..0504eef9 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -32,7 +32,7 @@ from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase from vllm.engine.metrics import Stats as VllmStats from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets - +from vllm.version import __version__ as _VLLM_VERSION class TritonMetrics: def __init__(self, labels: List[str], max_model_len: int): @@ -76,6 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int): description="Number of generation tokens processed.", kind=pb_utils.MetricFamily.HISTOGRAM, ) + # 'best_of' metric has been hidden since vllm 0.6.3 + # https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005 + if _VLLM_VERSION < "0.6.3": + self.histogram_best_of_request_family = pb_utils.MetricFamily( + name="vllm:request_params_best_of", + description="Histogram of the best_of request parameter.", + kind=pb_utils.MetricFamily.HISTOGRAM, + ) self.histogram_n_request_family = pb_utils.MetricFamily( name="vllm:request_params_n", description="Histogram of the n request parameter.", @@ -154,6 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int): buckets=build_1_2_5_buckets(max_model_len), ) ) + if _VLLM_VERSION < "0.6.3": + self.histogram_best_of_request = self.histogram_best_of_request_family.Metric( + labels=labels, + buckets=[1, 2, 5, 10, 20], + ) self.histogram_n_request = self.histogram_n_request_family.Metric( labels=labels, buckets=[1, 2, 5, 10, 20], @@ -240,7 +253,8 @@ def log(self, stats: VllmStats) -> None: ), (self.metrics.histogram_n_request, stats.n_requests), ] - + if _VLLM_VERSION < "0.6.3": + histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests)) for metric, data in counter_metrics: self._log_counter(metric, data) for metric, data in histogram_metrics: