From 1b35e19042f6984c64c9b5f3689c164a05f8ddcb Mon Sep 17 00:00:00 2001
From: Ryan McCormick
Date: Wed, 27 Nov 2024 12:11:31 -0800
Subject: [PATCH] Launch TRTLLM build as a separate process to ensure memory
 cleanup, add support for MODEL_SOURCE in local testing

---
 src/triton_cli/repository.py | 9 ++++++++-
 tests/test_e2e.py            | 8 ++++++--
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/triton_cli/repository.py b/src/triton_cli/repository.py
index 4fc5de3..814f3dd 100644
--- a/src/triton_cli/repository.py
+++ b/src/triton_cli/repository.py
@@ -29,6 +29,7 @@
 import shutil
 import logging
 import subprocess
+import multiprocessing
 
 from pathlib import Path
 from directory_tree import display_tree
@@ -335,7 +336,13 @@ def __generate_trtllm_model(self, name: str, huggingface_id: str):
                 f"Found existing engine(s) at {engines_path}, skipping build."
             )
         else:
-            self.__build_trtllm_engine(huggingface_id, engines_path)
+            # Run the TRT-LLM build in a separate process to ensure any GPU
+            # memory it uses is released when the build completes.
+            p = multiprocessing.Process(
+                target=self.__build_trtllm_engine, args=(huggingface_id, engines_path)
+            )
+            p.start()
+            p.join()
 
         # NOTE: In every case, the TRT LLM template should be filled in with values.
         #       If the model exists, the CLI will raise an exception when creating the model repo.
diff --git a/tests/test_e2e.py b/tests/test_e2e.py
index bac88ab..d48aa39 100644
--- a/tests/test_e2e.py
+++ b/tests/test_e2e.py
@@ -58,8 +58,10 @@ def test_tensorrtllm_e2e(self, llm_server, protocol):
         # Only a single model will be passed per test to enable tests to run concurrently.
         model = os.environ.get("TRTLLM_MODEL")
         assert model is not None, "TRTLLM_MODEL env var must be set!"
+        # Source is optional if using a "known" model
+        source = os.environ.get("MODEL_SOURCE")
         TritonCommands._clear()
-        TritonCommands._import(model, backend="tensorrtllm")
+        TritonCommands._import(model, source=source, backend="tensorrtllm")
         llm_server.start()
         TritonCommands._infer(model, prompt=PROMPT, protocol=protocol)
         TritonCommands._profile(model, backend="tensorrtllm")
@@ -86,8 +88,10 @@ def test_vllm_e2e(self, llm_server, protocol):
         # Only a single model will be passed per test to enable tests to run concurrently.
         model = os.environ.get("VLLM_MODEL")
         assert model is not None, "VLLM_MODEL env var must be set!"
+        # Source is optional if using a "known" model
+        source = os.environ.get("MODEL_SOURCE")
         TritonCommands._clear()
-        TritonCommands._import(model)
+        TritonCommands._import(model, source=source)
         # vLLM will download the model on the fly, so give it a big timeout
         # TODO: Consider one of the following
         # (a) Pre-download and mount larger models in test environment
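
Note (commentary, not part of the patch to apply): the sketch below illustrates
the process-isolation pattern the repository.py hunk relies on, i.e. running a
GPU-heavy callable in a child process so its allocations are reclaimed by the
operating system when the child exits. The names build_engine and run_isolated,
the example arguments, and the exitcode check are hypothetical additions for
illustration only; the patch itself just starts and joins the child process.

import multiprocessing


def build_engine(model_id: str, output_dir: str) -> None:
    # Hypothetical stand-in for a GPU-heavy step such as a TRT-LLM engine
    # build; any GPU memory allocated here lives only in the child process.
    print(f"Pretending to build '{model_id}' into '{output_dir}'")


def run_isolated(target, *args) -> None:
    # Run `target` in a child process and wait for it, so memory it allocates
    # is returned to the system when the child exits.
    p = multiprocessing.Process(target=target, args=args)
    p.start()
    p.join()
    # Optional hardening, not present in the patch: surface child failures.
    if p.exitcode != 0:
        raise RuntimeError(f"child exited with code {p.exitcode}")


if __name__ == "__main__":
    run_isolated(build_engine, "example/model-id", "/tmp/engines")

The cleanup guarantee comes from process teardown: when the child exits, its
CUDA context and any framework caches it created are destroyed with it, which
an in-process call cannot promise. Depending on whether the parent process has
already touched the GPU, it may also be safer to use the "spawn" start method
rather than the default "fork" on Linux.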