Apply isort and black reformatting
Signed-off-by: athitten <[email protected]>
athitten committed Jan 3, 2025
1 parent 9e76daa commit 630966e
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions nemo/collections/llm/evaluation/base.py
@@ -63,7 +63,10 @@ def _generate_tokens_logits(self, payload, return_text: bool = False, return_log
        if return_text:
            return response["choices"][0]["text"]  # shape[batch_size, 1]
        if return_logits:
-            return response["choices"][0]["context_logits"], response["choices"][0]["generation_logits"]  # shape[batch_size, 1, num_tokens, vocab_size]
+            return (
+                response["choices"][0]["context_logits"],
+                response["choices"][0]["generation_logits"],
+            )  # shape[batch_size, 1, num_tokens, vocab_size]

    def tokenizer_type(self, tokenizer):
        """
@@ -120,10 +123,11 @@ def loglikelihood(self, requests: list[Instance]):
            # Get the logits from the model
            context_logits, generation_logits = self._generate_tokens_logits(payload, return_logits=True)
            import numpy as np
-            are_equal=np.array_equal(context_logits[0][len(context_logits[0])-1], generation_logits[0][0][0])
-            logits = context_logits[:,-num_cont_tokens:,:]
+
+            are_equal = np.array_equal(context_logits[0][len(context_logits[0]) - 1], generation_logits[0][0][0])
+            logits = context_logits[:, -num_cont_tokens:, :]
            # Convert generation_logits to torch tensor to easily get logprobs w/o manual implementation of log_softmax
-            #multi_logits = F.log_softmax(torch.tensor(generation_logits[0]), dim=-1)
+            # multi_logits = F.log_softmax(torch.tensor(generation_logits[0]), dim=-1)
            multi_logits = F.log_softmax(torch.tensor(logits), dim=-1)
            # Convert encoded continuation tokens to torch tensor
            cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0)
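For orientation, this hunk prepares the usual lm-eval-harness-style loglikelihood reduction: slice the context logits down to the continuation tokens, log-softmax them, then gather the logprob of each actual continuation token. A self-contained sketch with made-up tensor sizes (the names `logits`, `multi_logits`, `cont_toks`, `continuation_enc`, and `num_cont_tokens` mirror the diff; the final gather-and-sum step is an assumption about the elided code below the hunk):

```python
# Self-contained sketch of the step this hunk feeds into; tensor values are
# fabricated for illustration, not NeMo's actual server outputs.
import torch
import torch.nn.functional as F

vocab_size = 8
# Pretend context_logits came back as [batch=1, num_tokens=5, vocab_size].
context_logits = torch.randn(1, 5, vocab_size)
continuation_enc = [2, 7]                 # token ids of the continuation
num_cont_tokens = len(continuation_enc)

logits = context_logits[:, -num_cont_tokens:, :]    # as in the diff
multi_logits = F.log_softmax(logits, dim=-1)        # per-token logprobs
cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0)

# Gather the logprob assigned to each actual continuation token and sum;
# this is the standard loglikelihood reduction (assumed, not shown in the hunk).
token_logprobs = torch.gather(multi_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
print(float(token_logprobs.sum()))
```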
@@ -172,6 +176,7 @@ def generate_until(self, inputs: list[Instance]):

        return results

+
def wait_for_server_ready(url, model_name, max_retries=600, retry_interval=2):
    """
    Wait for the Triton server and model to be ready, with retry logic.
@@ -187,13 +192,14 @@ def wait_for_server_ready(url, model_name, max_retries=600, retry_interval=2):
"""

import time

import requests
from pytriton.client import ModelClient
from pytriton.client.exceptions import PyTritonClientTimeoutError, PyTritonClientModelUnavailableError
from pytriton.client.exceptions import PyTritonClientModelUnavailableError, PyTritonClientTimeoutError

# If gRPC URL, extract HTTP URL from gRPC URL for health checks
if url.startswith("grpc://"):
#TODO use triton port and grpc port instaed of harcoding
# TODO use triton port and grpc port instaed of harcoding
url = url.replace("grpc://", "http://").replace(":8001", ":8000")
health_url = f"{url}/v2/health/ready"

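This hunk only reorders imports and reformats the TODO comment, but the surrounding function is a polling loop against that `health_url`. A stripped-down sketch of such a loop (the function name `wait_for_http_ready` is hypothetical; the real `wait_for_server_ready` additionally verifies the model through pytriton's `ModelClient` and the exception types imported above):

```python
# Minimal sketch of a health-check retry loop in the spirit of
# wait_for_server_ready; error handling is simplified relative to the real
# function, which also checks model readiness via pytriton's ModelClient.
import time
import requests

def wait_for_http_ready(health_url: str, max_retries: int = 600, retry_interval: int = 2) -> bool:
    for attempt in range(max_retries):
        try:
            # Triton reports readiness with HTTP 200 on /v2/health/ready.
            if requests.get(health_url, timeout=retry_interval).status_code == 200:
                return True
        except requests.RequestException:
            pass  # server not up yet; fall through to the sleep below
        time.sleep(retry_interval)
    return False
```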
@@ -225,4 +231,4 @@ def wait_for_server_ready(url, model_name, max_retries=600, retry_interval=2):
        time.sleep(retry_interval)

    logging.error(f"Server or model '{model_name}' not ready after {max_retries} attempts.")
-    return False
\ No newline at end of file
+    return False
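A hedged usage sketch of the function as defined above (the endpoint and model name are placeholders; per the gRPC branch earlier in the diff, a grpc:// URL on port 8001 is probed over HTTP on port 8000):

```python
# Hypothetical call site; "my_model" and the endpoint are placeholders.
if wait_for_server_ready("grpc://localhost:8001", "my_model"):
    print("Triton server and model are ready for evaluation requests.")
else:
    print("Gave up waiting; check the Triton server logs.")
```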
