From 630966ee9116839cbeb24e35f2b9e4b87dfcfe16 Mon Sep 17 00:00:00 2001
From: athitten
Date: Fri, 3 Jan 2025 22:46:29 +0000
Subject: [PATCH] Apply isort and black reformatting

Signed-off-by: athitten
---
 nemo/collections/llm/evaluation/base.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/nemo/collections/llm/evaluation/base.py b/nemo/collections/llm/evaluation/base.py
index e4fd05eb2fc4..0d7d2375df89 100644
--- a/nemo/collections/llm/evaluation/base.py
+++ b/nemo/collections/llm/evaluation/base.py
@@ -63,7 +63,10 @@ def _generate_tokens_logits(self, payload, return_text: bool = False, return_log
         if return_text:
             return response["choices"][0]["text"] # shape[batch_size, 1]
         if return_logits:
-            return response["choices"][0]["context_logits"], response["choices"][0]["generation_logits"] # shape[batch_size, 1, num_tokens, vocab_size]
+            return (
+                response["choices"][0]["context_logits"],
+                response["choices"][0]["generation_logits"],
+            ) # shape[batch_size, 1, num_tokens, vocab_size]
 
     def tokenizer_type(self, tokenizer):
         """
@@ -120,10 +123,11 @@ def loglikelihood(self, requests: list[Instance]):
             # Get the logits from the model
             context_logits, generation_logits = self._generate_tokens_logits(payload, return_logits=True)
             import numpy as np
-            are_equal=np.array_equal(context_logits[0][len(context_logits[0])-1], generation_logits[0][0][0])
-            logits = context_logits[:,-num_cont_tokens:,:]
+
+            are_equal = np.array_equal(context_logits[0][len(context_logits[0]) - 1], generation_logits[0][0][0])
+            logits = context_logits[:, -num_cont_tokens:, :]
             # Convert generation_logits to torch tensor to easily get logprobs wo manual implementation of log_softmax
-            #multi_logits = F.log_softmax(torch.tensor(generation_logits[0]), dim=-1)
+            # multi_logits = F.log_softmax(torch.tensor(generation_logits[0]), dim=-1)
             multi_logits = F.log_softmax(torch.tensor(logits), dim=-1)
             # Convert encoded continuation tokens to torch tensor
             cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0)
@@ -172,6 +176,7 @@ def generate_until(self, inputs: list[Instance]):
 
         return results
 
+
 def wait_for_server_ready(url, model_name, max_retries=600, retry_interval=2):
     """
     Wait for the Triton server and model to be ready, with retry logic.
@@ -187,13 +192,14 @@ def wait_for_server_ready(url, model_name, max_retries=600, retry_interval=2):
     """
 
     import time
+
     import requests
     from pytriton.client import ModelClient
-    from pytriton.client.exceptions import PyTritonClientTimeoutError, PyTritonClientModelUnavailableError
+    from pytriton.client.exceptions import PyTritonClientModelUnavailableError, PyTritonClientTimeoutError
 
     # If gRPC URL, extract HTTP URL from gRPC URL for health checks
     if url.startswith("grpc://"):
-        #TODO use triton port and grpc port instaed of harcoding
+        # TODO use triton port and grpc port instead of hardcoding
         url = url.replace("grpc://", "http://").replace(":8001", ":8000")
 
     health_url = f"{url}/v2/health/ready"
@@ -225,4 +231,4 @@ def wait_for_server_ready(url, model_name, max_retries=600, retry_interval=2):
         time.sleep(retry_interval)
 
     logging.error(f"Server or model '{model_name}' not ready after {max_retries} attempts.")
-    return False
\ No newline at end of file
+    return False
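
Note on the loglikelihood hunk: context_logits is indexed here as a [batch, num_tokens, vocab_size] array, the last num_cont_tokens positions are sliced out, and F.log_softmax turns that slice into log-probabilities; cont_toks holds the encoded continuation so its per-token scores can be picked out of that tensor. Below is a minimal sketch of that scoring pattern in the lm-eval-harness style, assuming the slice is already aligned so that position t predicts continuation token t; the helper name and variables are illustrative, not taken from the file.

import torch
import torch.nn.functional as F


def score_continuation(logits, continuation_enc):
    # logits: [1, num_cont_tokens, vocab_size] slice taken from context_logits
    # continuation_enc: encoded continuation token ids, as in the hunk above
    log_probs = F.log_softmax(torch.tensor(logits), dim=-1)  # [1, T, V]
    cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0)  # [1, T]
    # Pick the log-probability assigned to each continuation token
    token_logprobs = torch.gather(log_probs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, T]
    # lm-eval-harness style result: (sum of logprobs, whether greedy decoding matches the continuation)
    is_greedy = bool(log_probs.argmax(dim=-1).eq(cont_toks).all())
    return float(token_logprobs.sum()), is_greedy

The commented-out multi_logits line in the hunk was already disabled before this change; the reformatting keeps it as a reference while the live code scores the slice of context_logits instead of generation_logits.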
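
Note on wait_for_server_ready: it rewrites a grpc:// URL into the matching HTTP URL (Triton serves gRPC on 8001 and HTTP on 8000 by default, which is what the TODO about hardcoded ports refers to), then polls the /v2/health/ready endpoint and checks the model via PyTriton's ModelClient, sleeping retry_interval seconds between attempts and giving up after max_retries. Below is a minimal sketch of just the HTTP readiness-polling part, assuming the standard Triton health route; the function name and defaults are illustrative and the ModelClient check is omitted.

import logging
import time

import requests


def poll_triton_ready(base_url="http://0.0.0.0:8000", max_retries=600, retry_interval=2):
    # Triton answers /v2/health/ready with HTTP 200 once the server is up
    health_url = f"{base_url}/v2/health/ready"
    for attempt in range(max_retries):
        try:
            if requests.get(health_url, timeout=5).status_code == 200:
                return True
        except requests.RequestException:
            # Server not accepting connections yet; fall through and retry
            pass
        logging.info("Attempt %d: Triton not ready, retrying in %ss", attempt + 1, retry_interval)
        time.sleep(retry_interval)
    return False

The helper in the patch additionally uses ModelClient to confirm that the specific model, not just the server, is ready before returning True.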