From 4b91f8c86e1329162c8118824ad44dd9be6b9e3d Mon Sep 17 00:00:00 2001
From: Yingge He <yinggeh@nvidia.com>
Date: Tue, 6 Aug 2024 22:20:15 -0700
Subject: [PATCH] Add vLLM latency histogram metrics and tests
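
The vLLM backend already exports vllm:prompt_tokens_total and
vllm:generation_tokens_total counters through the Triton metrics
endpoint. This change adds two latency histograms,
vllm:time_to_first_token_seconds and vllm:time_per_output_token_seconds,
logs per-iteration observations into them via a _log_histogram() helper,
and extends the L0 metrics test to verify the histogram _count and _sum
series alongside the counters.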

---
 .../metrics_test/vllm_metrics_test.py         | 21 +++---
 src/utils/metrics.py                          | 77 ++++++++++++++++++-
 2 files changed, 88 insertions(+), 10 deletions(-)
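
With the patch applied, each histogram appears on Triton's metrics
endpoint (port 8002 by default) as the usual Prometheus _bucket, _count,
and _sum series. The sample below is illustrative only: the values are
made up, and the label set shown ("model"/"version", with the CI model
name "vllm_opt") is an assumption about the backend's label scheme.

  $ curl -s localhost:8002/metrics | grep vllm:time_to_first_token
  vllm:time_to_first_token_seconds_bucket{model="vllm_opt",version="1",le="0.001"} 0
  vllm:time_to_first_token_seconds_bucket{model="vllm_opt",version="1",le="0.005"} 0
  vllm:time_to_first_token_seconds_bucket{model="vllm_opt",version="1",le="0.01"} 1
  ...
  vllm:time_to_first_token_seconds_bucket{model="vllm_opt",version="1",le="+Inf"} 3
  vllm:time_to_first_token_seconds_count{model="vllm_opt",version="1"} 3
  vllm:time_to_first_token_seconds_sum{model="vllm_opt",version="1"} 0.042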

diff --git a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
index 8284835b..196a0d64 100644
--- a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
+++ b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -112,21 +112,24 @@ def vllm_infer(
         self.triton_client.stop_stream()
 
     def test_vllm_metrics(self):
-        # All vLLM metrics from tritonserver
-        expected_metrics_dict = {
-            "vllm:prompt_tokens_total": 0,
-            "vllm:generation_tokens_total": 0,
-        }
-
         # Test vLLM metrics
         self.vllm_infer(
             prompts=self.prompts,
             sampling_parameters=self.sampling_parameters,
             model_name=self.vllm_model_name,
         )
-        expected_metrics_dict["vllm:prompt_tokens_total"] = 18
-        expected_metrics_dict["vllm:generation_tokens_total"] = 48
-        self.assertEqual(self.get_metrics(), expected_metrics_dict)
+        metrics_dict = self.get_metrics()
+
+        # vllm:prompt_tokens_total
+        self.assertEqual(metrics_dict["vllm:prompt_tokens_total"], 18)
+        # vllm:generation_tokens_total
+        self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 48)
+        # vllm:time_to_first_token_seconds
+        self.assertEqual(metrics_dict["vllm:time_to_first_token_seconds_count"], 3)
+        self.assertGreater(metrics_dict["vllm:time_to_first_token_seconds_sum"], 0)
+        # vllm:time_per_output_token_seconds
+        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45)
+        self.assertGreater(metrics_dict["vllm:time_per_output_token_seconds_sum"], 0)
 
     def tearDown(self):
         self.triton_client.close()
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index e8c58372..d8c71ebc 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -24,7 +24,7 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Dict, Union
+from typing import Dict, List, Union
 
 import triton_python_backend_utils as pb_utils
 from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
@@ -46,6 +46,16 @@ def __init__(self, labels):
             description="Number of generation tokens processed.",
             kind=pb_utils.MetricFamily.COUNTER,
         )
+        self.histogram_time_to_first_token_family = pb_utils.MetricFamily(
+            name="vllm:time_to_first_token_seconds",
+            description="Histogram of time to first token in seconds.",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
+        self.histogram_time_per_output_token_family = pb_utils.MetricFamily(
+            name="vllm:time_per_output_token_seconds",
+            description="Histogram of time per output token in seconds.",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
 
         # Initialize metrics
         # Iteration stats
@@ -55,6 +65,51 @@ def __init__(self, labels):
         self.counter_generation_tokens = self.counter_generation_tokens_family.Metric(
             labels=labels
         )
+        self.histogram_time_to_first_token = (
+            self.histogram_time_to_first_token_family.Metric(
+                labels=labels,
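+                # Bucket boundaries match vLLM's own Prometheus histograms.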
+                buckets=[
+                    0.001,
+                    0.005,
+                    0.01,
+                    0.02,
+                    0.04,
+                    0.06,
+                    0.08,
+                    0.1,
+                    0.25,
+                    0.5,
+                    0.75,
+                    1.0,
+                    2.5,
+                    5.0,
+                    7.5,
+                    10.0,
+                ],
+            )
+        )
+        self.histogram_time_per_output_token = (
+            self.histogram_time_per_output_token_family.Metric(
+                labels=labels,
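+                # Likewise taken from vLLM's Prometheus defaults for this metric.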
+                buckets=[
+                    0.01,
+                    0.025,
+                    0.05,
+                    0.075,
+                    0.1,
+                    0.15,
+                    0.2,
+                    0.3,
+                    0.4,
+                    0.5,
+                    0.75,
+                    1.0,
+                    2.5,
+                ],
+            )
+        )
 
 
 class VllmStatLogger(VllmStatLoggerBase):
@@ -93,6 +148,19 @@ def _log_counter(self, counter, data: Union[int, float]) -> None:
         """
         if data != 0:
             counter.increment(data)
+
+    def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
+        """Convenience function for logging list to histogram.
+
+        Args:
+            histogram: A histogram metric instance.
+            data: A list of int or float data to observe into the histogram metric.
+
+        Returns:
+            None
+        """
+        for datum in data:
+            histogram.observe(datum)
 
     def log(self, stats: VllmStats) -> None:
         """Logs tracked stats to triton metrics server every iteration.
@@ -108,4 +176,11 @@ def log(self, stats: VllmStats) -> None:
         )
         self._log_counter(
             self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter
+        )
+        self._log_histogram(
+            self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter
+        )
+        self._log_histogram(
+            self.metrics.histogram_time_per_output_token,
+            stats.time_per_output_tokens_iter,
         )