From e9337f16a978b4153e618181aba0d8739c04c019 Mon Sep 17 00:00:00 2001
From: Tom Stesco
Date: Thu, 12 Dec 2024 05:08:48 +0000
Subject: [PATCH] update utils/prompt_client_cli.py and docs

---
 utils/README.md            | 100 ++++++++++++++++++++++++++++---------
 utils/prompt_client_cli.py |  24 ++++++---
 2 files changed, 94 insertions(+), 30 deletions(-)

diff --git a/utils/README.md b/utils/README.md
index 389e7ce..984fda7 100644
--- a/utils/README.md
+++ b/utils/README.md
@@ -23,20 +23,73 @@ The prompt client CLI tool allows you to send prompts to a vLLM API server with
 - `CACHE_ROOT`: Directory for saving response files (default: current directory)
 - `VLLM_MODEL`: Model name (default: meta-llama/Llama-3.1-70B-Instruct)
 
-#### Key Arguments
-
-- `--num_prompts`: Number of prompts to generate
-- `--batch_size`: Number of concurrent requests
-- `--max_prompt_length`: Maximum length for generated prompts
-- `--output_seq_len`: Maximum length for completions
-- `--num_full_iterations`: Number of times to repeat the full prompt set
-- `--vary-batch-size`: Randomize batch sizes using normal distribution
-- `--input_seq_len`: Fixed length for input sequences (-1 for variable)
-- `--inter_batch_delay`: Delay between batches in seconds
-- `--no-stream`: Disable streaming responses
-- `--dataset`: Source dataset (random, alpaca_eval)
-- `--distribution`: Prompt length distribution (fixed, uniform, normal)
-- `--template`: Path to Jinja2 template or "chat_template" for model tokenizer default
+#### Command Line Arguments
+
+##### Core Parameters
+
+- `--num_prompts` (default: 1)
+  Number of unique prompts to generate for testing.
+
+- `--batch_size` (default: 32)
+  Number of concurrent requests to send to the API server. Controls the parallelization level.
+
+- `--num_full_iterations` (default: 1)
+  Number of complete iterations over the entire prompt set. Useful for extended testing cycles.
+
+##### Model Configuration
+
+- `--vllm_model` (default: "meta-llama/Llama-3.1-70B-Instruct")
+  Model identifier for the vLLM API server. Can be overridden by the VLLM_MODEL environment variable.
+
+- `--tokenizer_model` (default: None)
+  Specific tokenizer model to use for vocabulary, truncation, and templating operations.
+
+##### Sequence Length Controls
+
+- `--input_seq_len` (default: -1)
+  Length parameter for input sequences when using random prompts. -1 falls back to the value of `--max_prompt_length`.
+
+- `--output_seq_len` (default: 2048)
+  Forces all completions to a fixed maximum length for consistent testing.
+
+- `--max_prompt_length` (default: -1)
+  Maximum allowed length for generated prompts. -1 falls back to the value of `--input_seq_len`; at least one of the two must be set.
+
+##### Batch Processing Options
+
+- `--vary_batch_size` (default: False)
+  When enabled, randomizes the batch size for each prompt batch using a normal distribution.
+
+- `--inter_batch_delay` (default: 0)
+  Seconds to wait between processing batches. Useful for rate limiting.
+
+- `--no-stream` (default: False)
+  Disables streaming responses. By default, streaming is enabled.
+
+##### Prompt Generation Settings
+
+- `--distribution` (default: "fixed")
+  Method for determining random prompt lengths:
+  - "fixed": Constant length
+  - "uniform": Uniform distribution
+  - "normal": Normal distribution
+
+- `--dataset` (default: "random")
+  Source dataset for prompt generation, e.g. "alpaca_eval"; use "random" for synthetic prompts.
+
+- `--template` (default: None)
+  Jinja2 template for formatting prompts. Can be a file path or a template string.
+
+##### Output Controls
+
+- `--save_path` (default: None)
+  File path to save generated prompts in JSONL format.
+
+- `--print_prompts` (default: False)
+  Enable printing of generated prompts to stdout.
+
+- `--skip_trace_precapture` (default: False)
+  Skips the trace precapture phase; use this to speed up execution when trace captures have already completed.
 
 #### Example Usage
 
@@ -54,7 +107,7 @@ python prompt_client_cli.py \
   --num_prompts 10 \
   --batch_size 4 \
   --tokenizer_model meta-llama/Llama-3.1-70B-Instruct \
-  --max_prompt_length 512 \
+  --input_seq_len 512 \
   --output_seq_len 2048
 
 # send prompts from alpaca_eval using chat template from tokenizer
@@ -103,13 +156,14 @@ The client saves responses in JSON format with the following structure:
 
 ```json
 {
-  "response_idx": 0,
-  "prompt": "example prompt",
-  "response": "model response",
-  "prompt_length": 128,
-  "num_completion_tokens": 256,
-  "tps": 45.6,
-  "ttft": 0.15
+  "response_idx": number,              // Response index in batch
+  "prompt": string,                    // Input prompt
+  "response": string,                  // Generated completion text
+  "input_seq_len": number,             // Prompt length in tokens
+  "output_seq_len": number,            // Completion length in tokens
+  "inter_token_latencies": number[],   // Per-token generation times in seconds
+  "time_per_output_token": number,     // Average seconds per token
+  "ttft": number                       // Time to first token in seconds
 }
 ```
 
@@ -139,7 +193,7 @@ args = SimpleNamespace(
     input_seq_len=-1,
     num_prompts=5,
     distribution="normal",
-    template="templates/chat.j2",
+    template="prompt_templates/llama_instruct_example.jinja",
    save_path="generated_prompts.jsonl",
     lm_eval_task=None
 )
diff --git a/utils/prompt_client_cli.py b/utils/prompt_client_cli.py
index 31d97e6..7ab7ed6 100644
--- a/utils/prompt_client_cli.py
+++ b/utils/prompt_client_cli.py
@@ -74,7 +74,7 @@ def add_client_args(parser):
     parser.add_argument(
         "--max_prompt_length",
         type=int,
-        required=True,
+        default=-1,
         help="Maximum length of generated prompts.",
     )
     parser.add_argument(
@@ -118,7 +118,7 @@
         "--skip_trace_precapture",
         action="store_true",
         default=False,
-        help="Print generated prompts.",
+        help="Skips the trace precapture phase; use this to speed up execution when trace captures have already completed.",
     )
     return parser
 
@@ -131,6 +131,16 @@ def main():
     parser = add_client_args(parser)
     args = parser.parse_args()
 
+    assert (
+        args.max_prompt_length != -1 or args.input_seq_len != -1
+    ), "Either --max_prompt_length or --input_seq_len must be provided."
+    if args.max_prompt_length == -1:
+        assert args.input_seq_len > 0
+        args.max_prompt_length = args.input_seq_len
+    elif args.input_seq_len == -1:
+        assert args.max_prompt_length > 0
+        args.input_seq_len = args.max_prompt_length
+
     # Create configs from arguments
     prompt_config = PromptConfig(
         input_seq_len=args.input_seq_len,
@@ -181,12 +191,12 @@ def main():
 
     # Calculate and log summary statistics
     if responses:
-        mean_decode_tps = np.mean([r["decode_tps"] for r in responses])
-        mean_total_tps = np.mean([r["total_tps"] for r in responses])
+        mean_tpot = np.mean([r["time_per_output_token"] for r in responses])
         mean_ttft = np.mean([r["ttft"] for r in responses])
-        logger.info(f"Mean Decode TPS: {mean_decode_tps:.2f}")
-        logger.info(f"Mean Total TPS: {mean_total_tps:.2f}")
-        logger.info(f"Mean TTFT: {mean_ttft:.2f}")
+        logger.info(f"Mean TTFT: {mean_ttft:.4f}")
+        logger.info(f"Mean TPOT: {mean_tpot:.4f}")
+        mean_tps = 1.0 / max(mean_tpot, 1e-6)
+        logger.info(f"Mean User TPS: {mean_tps:.4f}")
 
 
 if __name__ == "__main__":
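
The length-resolution logic added to `main()` above can be exercised on its own. The sketch below mirrors that fallback between `--max_prompt_length` and `--input_seq_len` as a standalone function so the behaviour is easy to verify; the helper name `resolve_prompt_lengths` is illustrative only and is not part of `prompt_client_cli.py`.

```python
# Minimal sketch of the length-resolution rule introduced in main() above.
# resolve_prompt_lengths is a hypothetical helper, not part of the CLI.


def resolve_prompt_lengths(max_prompt_length: int = -1, input_seq_len: int = -1) -> tuple[int, int]:
    """-1 means 'take the value of the other flag'; at least one must be set."""
    assert (
        max_prompt_length != -1 or input_seq_len != -1
    ), "Either --max_prompt_length or --input_seq_len must be provided."
    if max_prompt_length == -1:
        assert input_seq_len > 0
        max_prompt_length = input_seq_len
    elif input_seq_len == -1:
        assert max_prompt_length > 0
        input_seq_len = max_prompt_length
    return max_prompt_length, input_seq_len


# Passing only one of the two flags now pins both values to the same length,
# which is what the README example using `--input_seq_len 512` relies on.
assert resolve_prompt_lengths(input_seq_len=512) == (512, 512)
assert resolve_prompt_lengths(max_prompt_length=256) == (256, 256)
```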
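
The summary statistics logged at the end of `main()` now derive a user-level tokens-per-second figure from the mean time per output token (TPOT) instead of reading precomputed TPS fields. The snippet below reproduces that calculation against hand-written records shaped like the response JSON documented in the README hunk above; the sample values are invented for illustration.

```python
import numpy as np

# Hypothetical response records matching the documented shape; values are made up.
responses = [
    {"ttft": 0.21, "time_per_output_token": 0.034, "output_seq_len": 128},
    {"ttft": 0.19, "time_per_output_token": 0.036, "output_seq_len": 256},
]

mean_ttft = np.mean([r["ttft"] for r in responses])
mean_tpot = np.mean([r["time_per_output_token"] for r in responses])
# User-visible tokens/second is the reciprocal of TPOT; the 1e-6 floor
# guards the division against a zero mean TPOT.
mean_tps = 1.0 / max(mean_tpot, 1e-6)

print(f"Mean TTFT: {mean_ttft:.4f} s")
print(f"Mean TPOT: {mean_tpot:.4f} s/token")
print(f"Mean User TPS: {mean_tps:.4f} tokens/s")
```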