Remove hardcoding of grpc and http port
Signed-off-by: Abhishree <[email protected]>
athitten committed Jan 7, 2025
1 parent 18da26c commit 8cbaef3
Showing 2 changed files with 16 additions and 7 deletions.
5 changes: 4 additions & 1 deletion nemo/collections/llm/api.py
@@ -418,6 +418,7 @@ def deploy(
def evaluate(
nemo_checkpoint_path: Path,
url: str = "grpc://0.0.0.0:8001",
+triton_http_port: int = 8000,
model_name: str = "triton_model",
eval_task: str = "gsm8k",
num_fewshot: Optional[int] = None,
@@ -438,6 +439,8 @@
nemo_checkpoint_path (Path): Path for nemo 2.0 checkpoint. This is used to get the tokenizer from the ckpt
which is required to tokenize the evaluation input and output prompts.
url (str): gRPC service URL that was used in the deploy method above, in the format grpc://{grpc_service_ip}:{grpc_port}.
+triton_http_port (int): HTTP port that was used for the PyTriton server in the deploy method. Default: 8000.
+Pass triton_http_port if a custom port was used in the deploy method.
model_name (str): Name of the model that is deployed on PyTriton server. It should be the same as
triton_model_name passed to the deploy method above to be able to launch evaluation. Default: "triton_model".
eval_task (str): Task to be evaluated on, e.g. "gsm8k", "gsm8k_cot", "mmlu", "lambada". Default: "gsm8k".
@@ -475,7 +478,7 @@ def evaluate(
# Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt.
tokenizer = io.load_context(nemo_checkpoint_path + "/context", subpath="model.tokenizer")
# Wait for server to be ready before starting evaluation
-evaluation.wait_for_server_ready(url=url, model_name=model_name)
+evaluation.wait_for_server_ready(url=url, triton_http_port=triton_http_port, model_name=model_name)
# Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate
model = evaluation.NeMoFWLMEval(
model_name, url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos
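With this change, a deployment on custom ports can be evaluated by passing the matching ports to evaluate. A minimal sketch (not part of the commit; the checkpoint path and port numbers are hypothetical), using the evaluate signature shown above:

from pathlib import Path

from nemo.collections.llm.api import evaluate

evaluate(
    nemo_checkpoint_path=Path("/ckpts/llama-nemo2"),  # hypothetical NeMo 2.0 checkpoint
    url="grpc://0.0.0.0:9001",  # gRPC endpoint chosen at deploy time
    triton_http_port=9000,      # must match the HTTP port passed to deploy
    model_name="triton_model",  # same as triton_model_name used in deploy
    eval_task="gsm8k",
)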
18 changes: 12 additions & 6 deletions nemo/collections/llm/evaluation/base.py
@@ -17,6 +17,7 @@
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from tqdm import tqdm
+import re

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
@@ -165,12 +166,13 @@ def generate_until(self, inputs: list[Instance]):

return results

-def wait_for_server_ready(url, model_name, max_retries=600, retry_interval=2):
+def wait_for_server_ready(url, triton_http_port, model_name, max_retries=600, retry_interval=2):
"""
-Wait for the Triton server and model to be ready, with retry logic.
+Wait for the PyTriton server and model to be ready.
Args:
url (str): The URL of the Triton server (e.g., "grpc://0.0.0.0:8001").
+triton_http_port (int): HTTP port of the PyTriton server.
model_name (str): The name of the deployed model.
max_retries (int): Maximum number of retries before giving up.
retry_interval (int): Time in seconds to wait between retries.
@@ -186,8 +188,12 @@ def wait_for_server_ready(url, model_name, max_retries=600, retry_interval=2):

# If gRPC URL, extract HTTP URL from gRPC URL for health checks
if url.startswith("grpc://"):
-#TODO use triton port and grpc port instaed of harcoding
-url = url.replace("grpc://", "http://").replace(":8001", ":8000")
+# Extract the gRPC port using regex
+pattern = r":(\d+)"  # Matches a colon followed by one or more digits
+match = re.search(pattern, url)
+grpc_port = match.group(1)
+# Replace 'grpc' with 'http' and the gRPC port with the HTTP port
+url = url.replace("grpc://", "http://").replace(f":{grpc_port}", f":{triton_http_port}")
health_url = f"{url}/v2/health/ready"

for _ in range(max_retries):
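For illustration, the added rewrite maps a gRPC endpoint to the matching HTTP health-check URL. A standalone sketch of the same logic, with made-up ports:

import re

url = "grpc://0.0.0.0:9001"  # gRPC URL as passed to wait_for_server_ready
triton_http_port = 9000      # HTTP port the PyTriton server listens on

grpc_port = re.search(r":(\d+)", url).group(1)  # "9001"
url = url.replace("grpc://", "http://").replace(f":{grpc_port}", f":{triton_http_port}")
print(f"{url}/v2/health/ready")  # http://0.0.0.0:9000/v2/health/ready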
@@ -211,8 +217,8 @@ def wait_for_server_ready(url, model_name, max_retries=600, retry_interval=2):
logging.info(f"Timeout: Server or model '{model_name}' not ready yet.")
except PyTritonClientModelUnavailableError:
logging.info(f"Model '{model_name}' is unavailable on the server.")
-except requests.exceptions.RequestException as e:
-logging.info(f"Error checking server readiness: {e}")
+except requests.exceptions.RequestException:
+logging.info(f"PyTriton server not ready yet. Retrying in {retry_interval} seconds...")

# Wait before retrying
time.sleep(retry_interval)
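Callers can also invoke the updated helper directly; a sketch assuming evaluation is imported the same way api.py uses it above:

from nemo.collections.llm import evaluation  # assumed import path, mirroring api.py

evaluation.wait_for_server_ready(
    url="grpc://0.0.0.0:9001",  # gRPC endpoint of the deployed model
    triton_http_port=9000,      # port used for the /v2/health/ready check
    model_name="triton_model",
    max_retries=600,            # poll up to 600 times
    retry_interval=2,           # waiting 2 seconds between attempts
)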
