From 0fa015a8533f1c94e2784ed7b48c384c65012df8 Mon Sep 17 00:00:00 2001
From: Yingge He
Date: Wed, 21 Aug 2024 15:45:05 -0700
Subject: [PATCH] Make sampling parameters more comprehensive for testing
 metrics.

---
 .../metrics_test/vllm_metrics_test.py | 27 ++++++++++++++-----
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
index f8b7b08e..cfc62d14 100644
--- a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
+++ b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -112,10 +112,19 @@ def vllm_infer(
         self.triton_client.stop_stream()
 
     def test_vllm_metrics(self):
+        # Adding sampling parameters for testing metrics.
+        n, best_of = 2, 4
+        custom_sampling_parameters = self.sampling_parameters.copy()
+        # Changing "temperature" because "best_of" must be 1 when using greedy
+        # sampling, i.e. "temperature": "0".
+        custom_sampling_parameters.update(
+            {"n": str(n), "best_of": str(best_of), "temperature": "1"}
+        )
+
         # Test vLLM metrics
         self.vllm_infer(
             prompts=self.prompts,
-            sampling_parameters=self.sampling_parameters,
+            sampling_parameters=custom_sampling_parameters,
             model_name=self.vllm_model_name,
         )
         metrics_dict = self.parse_vllm_metrics()
@@ -124,7 +133,7 @@
         # vllm:prompt_tokens_total
         self.assertEqual(metrics_dict["vllm:prompt_tokens_total"], 18)
         # vllm:generation_tokens_total
-        self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 48)
+        self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 188)
         # vllm:time_to_first_token_seconds
         self.assertEqual(
             metrics_dict["vllm:time_to_first_token_seconds_count"], total_prompts
         )
@@ -155,23 +164,27 @@
         )
         # vllm:request_generation_tokens
         self.assertEqual(
-            metrics_dict["vllm:request_generation_tokens_count"], total_prompts
+            metrics_dict["vllm:request_generation_tokens_count"],
+            best_of * total_prompts,
         )
-        self.assertEqual(metrics_dict["vllm:request_generation_tokens_sum"], 48)
+        self.assertEqual(metrics_dict["vllm:request_generation_tokens_sum"], 188)
         self.assertEqual(
-            metrics_dict["vllm:request_generation_tokens_bucket"], total_prompts
+            metrics_dict["vllm:request_generation_tokens_bucket"],
+            best_of * total_prompts,
         )
         # vllm:request_params_best_of
         self.assertEqual(
             metrics_dict["vllm:request_params_best_of_count"], total_prompts
         )
-        self.assertEqual(metrics_dict["vllm:request_params_best_of_sum"], 3)
+        self.assertEqual(
+            metrics_dict["vllm:request_params_best_of_sum"], best_of * total_prompts
+        )
         self.assertEqual(
             metrics_dict["vllm:request_params_best_of_bucket"], total_prompts
         )
         # vllm:request_params_n
         self.assertEqual(metrics_dict["vllm:request_params_n_count"], total_prompts)
-        self.assertEqual(metrics_dict["vllm:request_params_n_sum"], 3)
+        self.assertEqual(metrics_dict["vllm:request_params_n_sum"], n * total_prompts)
         self.assertEqual(metrics_dict["vllm:request_params_n_bucket"], total_prompts)
 
     def test_vllm_metrics_disabled(self):