Skip to content

Commit

Permalink
refactor: Simplify testing with ScopedTritonServer instead of pytest …
Browse files Browse the repository at this point in the history
…fixtures (#68)
  • Loading branch information
KrishnanPrash authored Jun 3, 2024
1 parent dbbfcdf commit 4e62d74
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 271 deletions.
8 changes: 0 additions & 8 deletions src/triton_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,11 @@

import sys
import logging
import io
from contextlib import redirect_stdout

logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("triton")


def run_and_capture_stdout(args):
with io.StringIO() as buf, redirect_stdout(buf):
run(args)
return buf.getvalue()


# Separate function that can raise exceptions used for testing
# to assert correct errors and messages.
# Optional argv used for testing - will default to sys.argv if None.
Expand Down
173 changes: 48 additions & 125 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@

import os
import pytest
from triton_cli.main import run, run_and_capture_stdout
from triton_cli.parser import KNOWN_MODEL_SOURCES, parse_args
import utils
import json
from utils import TritonCommands, ScopedTritonServer

KNOWN_MODELS = KNOWN_MODEL_SOURCES.keys()
KNOWN_SOURCES = KNOWN_MODEL_SOURCES.values()
Expand All @@ -50,110 +48,41 @@


class TestRepo:
def _list(self, repo=None):
args = ["list"]
if repo:
args += ["--repo", repo]
run(args)

def _clear(self, repo=None):
args = ["remove", "-m", "all"]
run(args)

def _import(self, model, source=None, repo=None, backend=None):
args = ["import", "-m", model]
if source:
args += ["--source", source]
if repo:
args += ["--repo", repo]
if backend:
args += ["--backend", backend]
run(args)

def _infer(self, model, prompt=None, protocol=None):
args = ["infer", "-m", model]
if prompt:
args += ["--prompt", prompt]
if protocol:
args += ["-i", protocol]
run(args)

def _metrics(self):
args = ["metrics"]
output = run_and_capture_stdout(args)
return json.loads(output)

def _config(self, model):
args = ["config", "-m", model]
output = run_and_capture_stdout(args)
return json.loads(output)

def _status(self):
args = ["status"]
output = run_and_capture_stdout(args)
return json.loads(output)

def _remove(self, model, repo=None):
args = ["remove", "-m", model]
if repo:
args += ["--repo", repo]
run(args)

class KillServerByPid:
def __init__(self):
self.pid = None

def kill_server(self):
if self.pid is not None:
utils.kill_server(self.pid)

@pytest.fixture
def setup_and_teardown(self):
# Setup before the test case is run.
kill_server = self.KillServerByPid()
self._clear()

yield kill_server

# Teardown after the test case is done.
kill_server.kill_server()
self._clear()

@pytest.mark.parametrize("repo", TEST_REPOS)
def test_clear(self, repo):
self._clear(repo)
TritonCommands._clear(repo)

# TODO: Add pre/post repo clear to a fixture for setup/teardown
@pytest.mark.parametrize("model", KNOWN_MODELS)
@pytest.mark.parametrize("repo", TEST_REPOS)
def test_import_known_model(self, model, repo):
self._clear(repo)
self._import(model, repo=repo)
self._clear(repo)
TritonCommands._clear(repo)
TritonCommands._import(model, repo=repo)
TritonCommands._clear(repo)

@pytest.mark.parametrize("source", KNOWN_SOURCES)
@pytest.mark.parametrize("repo", TEST_REPOS)
def test_import_known_source(self, source, repo):
self._clear(repo)
self._import("known_source", source=source, repo=repo)
self._clear(repo)
TritonCommands._clear(repo)
TritonCommands._import("known_source", source=source, repo=repo)
TritonCommands._clear(repo)

@pytest.mark.parametrize("model,source", CUSTOM_VLLM_MODEL_SOURCES)
def test_import_vllm(self, model, source):
self._clear()
self._import(model, source=source)
TritonCommands._clear()
TritonCommands._import(model, source=source)
# TODO: Parse repo to find model, with vllm backend in config
self._clear()
TritonCommands._clear()

@pytest.mark.skipif(
os.environ.get("IMAGE_KIND") != "TRTLLM", reason="Only run for TRT-LLM image"
)
@pytest.mark.parametrize("model,source", CUSTOM_TRTLLM_MODEL_SOURCES)
def test_repo_add_trtllm_build(self, model, source):
# TODO: Parse repo to find TRT-LLM models and backend in config
self._clear()
self._import(model, source=source, backend="tensorrtllm")
self._clear()
TritonCommands._clear()
TritonCommands._import(model, source=source, backend="tensorrtllm")
TritonCommands._clear()

@pytest.mark.skip(reason="Pre-built TRT-LLM engines not available")
def test_import_trtllm_prebuilt(self, model, source):
Expand All @@ -165,21 +94,21 @@ def test_import_no_source(self):
with pytest.raises(
Exception, match="Please use a known model, or provide a --source"
):
self._import("no_source", source=None)
TritonCommands._import("no_source", source=None)

def test_remove(self):
self._import("gpt2", source="hf:gpt2")
self._remove("gpt2")
TritonCommands._import("gpt2", source="hf:gpt2")
TritonCommands._remove("gpt2")

# TODO: Find a way to raise well-typed errors for testing purposes, without
# always dumping traceback to user-facing output.
def test_remove_nonexistent(self):
with pytest.raises(FileNotFoundError, match="No model folder exists"):
self._remove("does-not-exist")
TritonCommands._remove("does-not-exist")

@pytest.mark.parametrize("repo", TEST_REPOS)
def test_list(self, repo):
self._list(repo)
TritonCommands._list(repo)

# This test uses mock system args and a mock subprocess call
# to ensure that the correct subprocess call is made for profile.
Expand All @@ -192,48 +121,42 @@ def test_triton_profile(self, mocker, monkeypatch):
mock_run.assert_called_once_with(["genai-perf", "-m", "add_sub"], check=True)

@pytest.mark.parametrize("model", ["mock_llm"])
def test_triton_metrics(self, model, setup_and_teardown):
def test_triton_metrics(self, model):
# Import the Model Repo
pid = utils.run_server(repo=MODEL_REPO)
setup_and_teardown.pid = pid
utils.wait_for_server_ready()

metrics_before = self._metrics()
with ScopedTritonServer(repo=MODEL_REPO):
metrics_before = TritonCommands._metrics()

# Before Inference, Verifying Inference Count == 0
for loaded_models in metrics_before["nv_inference_request_success"]["metrics"]:
if loaded_models["labels"]["model"] == model: # If mock_llm
assert loaded_models["value"] == 0
# Before Inference, Verifying Inference Count == 0
for loaded_models in metrics_before["nv_inference_request_success"][
"metrics"
]:
if loaded_models["labels"]["model"] == model: # If mock_llm
assert loaded_models["value"] == 0

# Model Inference
self._infer(model, prompt=PROMPT)
# Model Inference
TritonCommands._infer(model, prompt=PROMPT)

metrics_after = self._metrics()
metrics_after = TritonCommands._metrics()

# After Inference, Verifying Inference Count == 0
for loaded_models in metrics_after["nv_inference_request_success"]["metrics"]:
if loaded_models["labels"]["model"] == model: # If mock_llm
assert loaded_models["value"] == 1
# After Inference, Verifying Inference Count == 1
for loaded_models in metrics_after["nv_inference_request_success"][
"metrics"
]:
if loaded_models["labels"]["model"] == model: # If mock_llm
assert loaded_models["value"] == 1

@pytest.mark.parametrize("model", ["mock_llm"])
def test_triton_config(self, model, setup_and_teardown):
def test_triton_config(self, model):
# Import the Model
pid = utils.run_server(repo=MODEL_REPO)
setup_and_teardown.pid = pid
utils.wait_for_server_ready()

config = self._config(model)

# Checks if correct model is loaded
assert config["name"] == model
with ScopedTritonServer(repo=MODEL_REPO):
config = TritonCommands._config(model)
# Checks if correct model is loaded
assert config["name"] == model

@pytest.mark.parametrize("model", ["mock_llm"])
def test_triton_status(self, model, setup_and_teardown):
pid = utils.run_server(repo=MODEL_REPO) # Import the Model
setup_and_teardown.pid = pid
utils.wait_for_server_ready()

status = self._status()

# Checks if model(s) are live and ready
assert status["live"] and status["ready"]
def test_triton_status(self, model):
# Import the Model
with ScopedTritonServer(repo=MODEL_REPO):
status = TritonCommands._status(protocol="grpc")
# Checks if model(s) are live and ready
assert status["live"] and status["ready"]
Loading

0 comments on commit 4e62d74

Please sign in to comment.