perf: vLLM metrics optimization #66

Merged: 3 commits, Sep 21, 2024
src/model.py (10 changes: 8 additions & 2 deletions)
@@ -161,6 +161,7 @@ def init_engine(self):
         self.llm_engine = AsyncLLMEngine.from_engine_args(aync_engine_args)

         # Create vLLM custom metrics
+        self.vllm_metrics = None
         if (
             "REPORT_CUSTOM_METRICS" in self.model_config["parameters"]
             and self.model_config["parameters"]["REPORT_CUSTOM_METRICS"]["string_value"]
@@ -174,9 +175,10 @@ def init_engine(self):
                 }
                 # Add vLLM custom metrics
                 engine_config = self.llm_engine.engine.model_config
-                self.llm_engine.add_logger(
-                    "triton", VllmStatLogger(labels, engine_config.max_model_len)
+                self.vllm_metrics = VllmStatLogger(
+                    labels, engine_config.max_model_len, self.logger
                 )
+                self.llm_engine.add_logger("triton", self.vllm_metrics)
             except pb_utils.TritonModelException as e:
                 if "metrics not supported" in str(e):
                     # Metrics are disabled at the server
@@ -572,6 +574,10 @@ def finalize(self):
             self._response_thread.join()
             self._response_thread = None

+        # Shutdown the logger thread.
+        if self.vllm_metrics is not None:
+            self.vllm_metrics.finalize()
+
         # When using parallel tensors, the stub process may not shutdown due to
         # unreleased references, so manually run the garbage collector once.
         self.logger.log_info("[vllm] Running Garbage Collector on finalize...")
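Aside from the diff: once REPORT_CUSTOM_METRICS is enabled for the model, the stats registered by this logger are exported through Triton's Prometheus endpoint. A quick way to check them, assuming Triton's default metrics port (8002) and the "vllm:" name prefix used by TritonMetrics; everything in this snippet is illustrative and not part of the PR:

```python
# Smoke test (not part of the PR): scrape Triton's Prometheus endpoint and
# print only the vLLM custom metrics. Assumes the default metrics port 8002;
# adjust the URL for your deployment.
import urllib.request

with urllib.request.urlopen("http://localhost:8002/metrics") as resp:
    text = resp.read().decode("utf-8")

for line in text.splitlines():
    # Custom vLLM metrics registered by VllmStatLogger use the "vllm:" prefix.
    if line.startswith("vllm:"):
        print(line)
```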
src/utils/metrics.py (38 changes: 34 additions & 4 deletions)
@@ -24,6 +24,8 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+import queue
+import threading
 from typing import Dict, List, Union

 import triton_python_backend_utils as pb_utils
@@ -170,11 +172,18 @@ def __init__(self, labels: List[str], max_model_len: int):
 class VllmStatLogger(VllmStatLoggerBase):
     """StatLogger is used as an adapter between vLLM stats collector and Triton metrics provider."""

-    # local_interval not used here. It's for vLLM logs to stdout.
-    def __init__(self, labels: Dict, max_model_len: int) -> None:
+    def __init__(self, labels: Dict, max_model_len: int, log_logger) -> None:
         # Tracked stats over current local logging interval.
+        # local_interval not used here. It's for vLLM logs to stdout.
         super().__init__(local_interval=0)
         self.metrics = TritonMetrics(labels, max_model_len)
+        self.log_logger = log_logger
+
+        # Starting the metrics thread. It allows vLLM to keep making progress
+        # while reporting metrics to triton metrics service.
+        self._logger_queue = queue.Queue()
+        self._logger_thread = threading.Thread(target=self.logger_loop)
+        self._logger_thread.start()

     def info(self, type: str, obj: SupportsMetricsInfo) -> None:
         pass
@@ -190,7 +199,7 @@ def _log_counter(self, counter, data: Union[int, float]) -> None:
             None
         """
         if data != 0:
-            counter.increment(data)
+            self._logger_queue.put_nowait((counter, "increment", data))

     def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
         """Convenience function for logging list to histogram.
@@ -203,7 +212,7 @@ def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
             None
         """
         for datum in data:
-            histogram.observe(datum)
+            self._logger_queue.put_nowait((histogram, "observe", datum))

     def log(self, stats: VllmStats) -> None:
         """Report stats to Triton metrics server.
@@ -246,3 +255,24 @@ def log(self, stats: VllmStats) -> None:
             self._log_counter(metric, data)
         for metric, data in histogram_metrics:
             self._log_histogram(metric, data)
+
+    def logger_loop(self):
+        while True:
+            item = self._logger_queue.get()
+            # To signal shutdown a None item will be added to the queue.
+            if item is None:
+                break
+            metric, command, data = item
+            if command == "increment":
+                metric.increment(data)
+            elif command == "observe":
+                metric.observe(data)
+            else:
+                self.log_logger.log_error(f"Undefined command name: {command}")
+
+    def finalize(self):
+        # Shutdown the logger thread.
+        self._logger_queue.put(None)
+        if self._logger_thread is not None:
+            self._logger_thread.join()
+            self._logger_thread = None
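In effect, the change replaces direct counter.increment() / histogram.observe() calls inside vLLM's logging callback with a single-consumer queue: the callback only enqueues (metric, command, value) tuples, one background thread applies them to Triton's metric objects, and a None sentinel shuts that thread down from finalize(). A minimal self-contained sketch of the same pattern, with illustrative names (AsyncMetricSink and FakeCounter are not part of the backend):

```python
# Standalone sketch of the queue-plus-worker pattern adopted above: callers
# enqueue metric updates and return immediately; a single background thread
# applies them; a None sentinel stops the thread.
import queue
import threading


class AsyncMetricSink:
    """Applies metric updates on a background thread so callers never block."""

    def __init__(self):
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._worker)
        self._thread.start()

    def increment(self, counter, value):
        # Hot path: just enqueue the update.
        self._queue.put_nowait((counter, "increment", value))

    def observe(self, histogram, value):
        self._queue.put_nowait((histogram, "observe", value))

    def _worker(self):
        while True:
            item = self._queue.get()
            if item is None:  # None is the shutdown sentinel.
                break
            metric, command, value = item
            if command == "increment":
                metric.increment(value)
            elif command == "observe":
                metric.observe(value)

    def finalize(self):
        # The sentinel is enqueued last, so pending updates drain first.
        self._queue.put(None)
        self._thread.join()


# Tiny usage example with a stand-in counter object.
class FakeCounter:
    def __init__(self):
        self.total = 0

    def increment(self, value):
        self.total += value


if __name__ == "__main__":
    sink = AsyncMetricSink()
    fake = FakeCounter()
    sink.increment(fake, 3)
    sink.finalize()
    print(fake.total)  # -> 3
```

Because the None sentinel is put on the queue after all pending updates, finalize() drains outstanding metric updates before the worker exits, which mirrors how VllmStatLogger.finalize() behaves above.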