diff --git a/src/triton_cli/server/server_local.py b/src/triton_cli/server/server_local.py index 0d51591..81f3718 100755 --- a/src/triton_cli/server/server_local.py +++ b/src/triton_cli/server/server_local.py @@ -135,12 +135,10 @@ def stop(self): if self._tritonserver_process is not None: self._tritonserver_process.terminate() try: - self._tritonserver_process.communicate( - timeout=SERVER_OUTPUT_TIMEOUT_SECS - ) + self._tritonserver_process.wait(timeout=SERVER_OUTPUT_TIMEOUT_SECS) except TimeoutExpired: self._tritonserver_process.kill() - self._tritonserver_process.communicate() + self._tritonserver_process.wait() self._tritonserver_process = None logger.debug("Stopped Triton Server.") diff --git a/tests/utils.py b/tests/utils.py index c84bf66..8e1f15b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -25,39 +25,13 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import io -import os -import sys import json import time import psutil -import requests import subprocess from contextlib import redirect_stdout -from typing import List, Dict, Optional - from triton_cli.main import run - - -def find_processes_by_name(names: List[str]): - def name_in_proc(name: str, proc: psutil.Process): - if name.lower() in proc.info["name"]: - return True - if not proc.info["cmdline"]: - return False - if any(name.lower() in cmd.lower() for cmd in proc.info["cmdline"]): - return True - - return False - - for proc in psutil.process_iter(["pid", "name", "cmdline"]): - try: - for name in names: - if name_in_proc(name, proc): - print(proc.info) - except psutil.ZombieProcess: - print(f"Zombie process detected: {proc.info}") - except (psutil.NoSuchProcess, psutil.AccessDenied): - pass +from subprocess import Popen class TritonCommands: @@ -136,199 +110,41 @@ def _clear(repo=None): run(args) -# TODO: Consider removal if other version works # Context Manager to start and kill a server running in background and used by testing functions -# class ScopedTritonServerOld: -# def __init__(self, repo=None, mode="local", timeout=60): -# self.repo = repo -# self.mode = mode -# self.timeout = timeout -# -# def __enter__(self): -# self.pid = self.run_server(self.repo, self.mode) -# self.wait_for_server_ready(timeout=self.timeout) # Polling -# -# def __exit__(self, type, value, traceback): -# self.kill_server(self.pid) -# self.repo, self.mode = None, None -# -# def run_server(self, repo=None, mode="local"): -# args = ["triton", "start"] -# if repo: -# args += ["--repo", repo] -# if mode: -# args += ["--mode", mode] -# # Use Popen to run the server in the background as a separate process. -# p = Popen(args) -# return p.pid -# -# def check_pid(self): -# """Check the 'triton start' PID and raise an exception if the process is unhealthy""" -# # Check if the PID exists, an exception is raised if not -# self.check_pid_with_signal() -# # If the PID exists, check the status of the process. Raise an exception -# # for a bad status. -# self.check_pid_status() -# -# def check_pid_with_signal(self): -# """Check for the existence of a PID by sending signal 0""" -# try: -# proc = psutil.Process(self.pid) -# proc.send_signal(0) -# except psutil.NoSuchProcess as e: -# # PID doesn't exist, passthrough the exception -# raise e -# -# def check_pid_status(self): -# """Check the status of the 'triton start' process based on its PID""" -# process = psutil.Process(self.pid) -# # NOTE: May need to check other statuses in the future, but zombie was observed -# # in some local test cases. -# if process.status() == psutil.STATUS_ZOMBIE: -# raise Exception(f"'triton start' PID {self.pid} was in a zombie state.") -# -# def wait_for_server_ready(self, timeout: int = 60): -# start = time.time() -# while time.time() - start < timeout: -# print( -# "[DEBUG] Waiting for server to be ready ", -# round(timeout - (time.time() - start)), -# flush=True, -# ) -# time.sleep(1) -# try: -# print(f"[DEBUG] Checking status of 'triton start' PID {self.pid}...") -# self.check_pid() -# -# # For simplicity in testing, make sure both HTTP and GRPC endpoints -# # are ready before marking server ready. -# if self.check_server_ready(protocol="http") and self.check_server_ready( -# protocol="grpc" -# ): -# print("[DEBUG] Server is ready!") -# return -# except ConnectionRefusedError as e: -# # Dump to log for testing transparency -# print(e) -# except InferenceServerException: -# pass -# raise Exception(f"=== Timeout {timeout} secs. Server not ready. ===") -# -# def kill_server(self, pid: int, sig: int = 2, timeout: int = 30): -# try: -# proc = psutil.Process(pid) -# proc.send_signal(sig) -# # Add wait timeout to avoid hanging if process can't be cleanly -# # stopped for some reason. -# proc.wait(timeout=timeout) -# except psutil.NoSuchProcess as e: -# print(e) -# -# def check_server_ready(self, protocol="grpc"): -# status = TritonCommands._status(protocol) -# return status["ready"] +class ScopedTritonServer: + def __init__(self, repo=None, mode="local", timeout=60): + self.repo = repo + self.mode = mode + self.timeout = timeout + def __enter__(self): + self.start() -class ScopedTritonServer: - def __init__( - self, - repo: Optional[str] = None, - mode: Optional[str] = "local", - timeout: int = 120, - extra_args: Optional[List[str]] = None, - env_dict: Optional[Dict[str, str]] = None, - ) -> None: - self.started = False - self.stopped = False + def __exit__(self, type, value, traceback): + self.stop() - # TODO: Be more coupled with Triton CLI settings - self.host = "localhost" - self.port = 8000 - self.start_timeout = timeout - self.proc = None + def start(self): + self.proc = self.run_server(self.repo, self.mode) + self.wait_for_server_ready(timeout=self.timeout) # Polling - env = os.environ.copy() - if env_dict is not None: - env.update(env_dict) + def stop(self): + self.kill_server() - args: List[str] = ["triton", "start"] + def run_server(self, repo=None, mode="local"): + args = ["triton", "start"] if repo: args += ["--repo", repo] if mode: args += ["--mode", mode] - if extra_args: - args += extra_args - - self.args = args - self.env = env - - def _startup(self): - print("Starting server ...") - self.proc = subprocess.Popen( - self.args, - env=self.env, - stdout=sys.stdout, - stderr=sys.stderr, - ) - print("Waiting for server ready...") - # Wait until health endpoint is responsive - self._wait_for_server( - url=self.url_for("v2", "health", "ready"), - timeout=self.start_timeout, - ) - print("DONE: Server ready!") + # Use Popen to run the server in the background as a separate process. + p = Popen(args) + return p - def _shutdown(self): - self.proc.terminate() - try: - wait_secs = 60 - self.proc.wait(wait_secs) - except subprocess.TimeoutExpired: - # force kill if needed - self.proc.kill() - - def start(self): - print("[DEBUG] Processes before server start:") - find_processes_by_name(["triton", "python"]) - - if self.started: - print("[WARNING] Server has already been started, skipping startup.") - return - - self.started = True - self._startup() - - print("[DEBUG] Processes after server start:") - find_processes_by_name(["triton", "python"]) - - def stop(self): - print("[DEBUG] Processes before server stop:") - find_processes_by_name(["triton", "python"]) - - # print(f"[DEBUG] =========== [START] STOPPING SERVER {self.proc} ========") - if self.stopped: - print("[WARNING] Server has already been stopped, skipping shutdown.") - return - - self.stopped = True - self._shutdown() - # print(f"[DEBUG] =========== [DONE] STOPPING SERVER {self.proc} ========") - - print("[DEBUG] Processes after server stop:") - find_processes_by_name(["triton", "python"]) - - def __enter__(self): - self.start() - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.stop() - - def _wait_for_server(self, *, url: str, timeout: float): + def wait_for_server_ready(self, timeout: int = 60): start = time.time() while True: try: - if requests.get(url).status_code == 200: + if self.check_server_ready(): break except Exception as err: result = self.proc.poll() @@ -339,9 +155,17 @@ def _wait_for_server(self, *, url: str, timeout: float): if time.time() - start > timeout: raise RuntimeError("Server failed to start in time.") from err - @property - def url_root(self) -> str: - return f"http://{self.host}:{self.port}" - - def url_for(self, *parts: str) -> str: - return self.url_root + "/" + "/".join(parts) + def kill_server(self, timeout: int = 60): + try: + self.proc.terminate() + self.proc.wait(timeout=timeout) # Wait for triton to clean up + except subprocess.TimeoutExpired: + self.proc.kill() + self.proc.wait() # Indefinetely wait until the process is cleaned up. + except psutil.NoSuchProcess as e: + print(e) + + def check_server_ready(self): + status_grpc = TritonCommands._status(protocol="grpc") + status_http = TritonCommands._status(protocol="http") + return status_grpc["ready"] and status_http["ready"]