update utils/prompt_client_cli.py and docs
tstescoTT committed Dec 12, 2024
1 parent 52bdbab commit e9337f1
Showing 2 changed files with 94 additions and 30 deletions.
100 changes: 77 additions & 23 deletions utils/README.md
@@ -23,20 +23,73 @@ The prompt client CLI tool allows you to send prompts to a vLLM API server with
- `CACHE_ROOT`: Directory for saving response files (default: current directory)
- `VLLM_MODEL`: Model name (default: meta-llama/Llama-3.1-70B-Instruct)

#### Key Arguments

- `--num_prompts`: Number of prompts to generate
- `--batch_size`: Number of concurrent requests
- `--max_prompt_length`: Maximum length for generated prompts
- `--output_seq_len`: Maximum length for completions
- `--num_full_iterations`: Number of times to repeat the full prompt set
- `--vary-batch-size`: Randomize batch sizes using normal distribution
- `--input_seq_len`: Fixed length for input sequences (-1 for variable)
- `--inter_batch_delay`: Delay between batches in seconds
- `--no-stream`: Disable streaming responses
- `--dataset`: Source dataset (random, alpaca_eval)
- `--distribution`: Prompt length distribution (fixed, uniform, normal)
- `--template`: Path to Jinja2 template or "chat_template" for model tokenizer default
#### Command Line Arguments

##### Core Parameters

- `--num_prompts` (default: 1)
Number of unique prompts to generate for testing.

- `--batch_size` (default: 32)
Number of concurrent requests to send to the API server. Controls parallelization level.

- `--num_full_iterations` (default: 1)
Number of complete iterations over the entire prompt set. Useful for extended testing cycles.

##### Model Configuration

- `--vllm_model` (default: "meta-llama/Llama-3.1-70B-Instruct")
Model identifier for the vLLM API server. Can be overridden by the `VLLM_MODEL` environment variable.

- `--tokenizer_model` (default: None)
Specific tokenizer model to use for vocabulary, truncation, and templating operations.

##### Sequence Length Controls

- `--input_seq_len` (default: -1)
Length parameter for input sequences when using random prompts. -1 allows variable lengths.

- `--output_seq_len` (default: 2048)
Forces all completions to a fixed maximum length for consistent testing.

- `--max_prompt_length` (default: -1)
Maximum allowed length for generated prompts. -1 indicates no length restriction. See the sketch below for how this flag interacts with `--input_seq_len`.
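
The two length flags interact: the updated `main()` in `prompt_client_cli.py` (see the diff further down) requires at least one of them and derives the missing value from the other. A minimal sketch of that behavior, using a hypothetical `resolve_lengths` helper:

```python
# Sketch only: mirrors the reconciliation logic added to main() in this commit.
def resolve_lengths(input_seq_len: int, max_prompt_length: int) -> tuple[int, int]:
    # At least one of the two flags must be given (i.e. not left at -1).
    assert input_seq_len != -1 or max_prompt_length != -1, (
        "Either --max_prompt_length or --input_seq_len must be provided."
    )
    if max_prompt_length == -1:
        max_prompt_length = input_seq_len  # fall back to the input length
    elif input_seq_len == -1:
        input_seq_len = max_prompt_length  # fall back to the prompt cap
    return input_seq_len, max_prompt_length

print(resolve_lengths(512, -1))  # -> (512, 512)
```

In practice, passing only `--input_seq_len 512` behaves like also passing `--max_prompt_length 512`, and vice versa.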

##### Batch Processing Options

- `--vary_batch_size` (default: False)
When enabled, randomizes the batch size for each prompt batch using a normal distribution.

- `--inter_batch_delay` (default: 0)
Seconds to wait between processing each batch. Useful for rate limiting.

- `--no-stream` (default: False)
Disables streaming responses. By default, streaming is enabled.

##### Prompt Generation Settings

- `--distribution` (default: "fixed")
Method for determining random prompt lengths (see the sketch after this list):
- "fixed": Constant length
- "uniform": Uniform distribution
- "normal": Normal distribution

- `--dataset` (default: "random")
Source dataset for prompt generation. Use "random" for synthetic prompts.

- `--template` (default: None)
Jinja2 template for formatting prompts. Can be a file path or template string.
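
For the random dataset, the distribution controls how per-prompt token lengths are drawn around `--input_seq_len`. The sketch below only illustrates the three modes; the `sample_prompt_lengths` function, the uniform range, and the 10% spread for the normal case are assumptions, not the tool's actual sampling code.

```python
# Sketch only: one plausible way lengths could be sampled per --distribution.
import numpy as np

def sample_prompt_lengths(distribution, target_len, num_prompts, max_len, seed=0):
    rng = np.random.default_rng(seed)
    if distribution == "fixed":
        lengths = np.full(num_prompts, target_len)
    elif distribution == "uniform":
        lengths = rng.integers(1, max_len + 1, size=num_prompts)
    elif distribution == "normal":
        # assumed spread: 10% of the target length
        lengths = rng.normal(loc=target_len, scale=0.1 * target_len, size=num_prompts)
    else:
        raise ValueError(f"unknown distribution: {distribution}")
    return np.clip(lengths, 1, max_len).astype(int)

print(sample_prompt_lengths("normal", target_len=128, num_prompts=5, max_len=512))
```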

##### Output Controls

- `--save_path` (default: None)
File path to save generated prompts in JSONL format.

- `--print_prompts` (default: False)
Enable printing of generated prompts to stdout.

- `--skip_trace_precapture` (default: False)
Skips the trace precapture phase; use this to speed up execution if trace captures have already completed.

#### Example Usage

@@ -54,7 +107,7 @@ python prompt_client_cli.py \
--num_prompts 10 \
--batch_size 4 \
--tokenizer_model meta-llama/Llama-3.1-70B-Instruct \
--max_prompt_length 512 \
--input_seq_len 512 \
--output_seq_len 2048

# send prompts from alpaca_eval using chat template from tokenizer
@@ -103,13 +156,14 @@ The client saves responses in JSON format with the following structure:

```json
{
"response_idx": 0,
"prompt": "example prompt",
"response": "model response",
"prompt_length": 128,
"num_completion_tokens": 256,
"tps": 45.6,
"ttft": 0.15
"response_idx": number, // Response index in batch
"prompt": string, // Input prompt
"response": string, // Generated completion text
"input_seq_len": number, // Prompt length in tokens
"output_seq_len": number, // Completion length in tokens
"inter_token_latencies": number[], // Per-token generation times in seconds
"time_per_output_token": number, // Average seconds per token
"ttft": number // Time to first token in seconds
}
```
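
Because the records are plain JSON, the CLI's summary statistics (mean TTFT, mean TPOT, user TPS) can be recomputed offline. A small sketch, assuming the responses were saved as a JSON list; the file name below is hypothetical and the real location depends on `CACHE_ROOT`:

```python
# Sketch only: recompute the CLI's summary metrics from a saved response file.
import json

import numpy as np

with open("responses.json") as f:  # hypothetical file name
    records = json.load(f)  # assumed: a list of response objects

mean_ttft = np.mean([r["ttft"] for r in records])
mean_tpot = np.mean([r["time_per_output_token"] for r in records])
print(f"Mean TTFT: {mean_ttft:.4f} s")
print(f"Mean TPOT: {mean_tpot:.4f} s/token")
print(f"Mean user TPS: {1.0 / max(mean_tpot, 1e-6):.4f}")
```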

@@ -139,7 +193,7 @@ args = SimpleNamespace(
input_seq_len=-1,
num_prompts=5,
distribution="normal",
template="templates/chat.j2",
template="prompt_templates/llama_instruct_example.jinja",
save_path="generated_prompts.jsonl",
lm_eval_task=None
)
24 changes: 17 additions & 7 deletions utils/prompt_client_cli.py
@@ -74,7 +74,7 @@ def add_client_args(parser):
parser.add_argument(
"--max_prompt_length",
type=int,
required=True,
default=-1,
help="Maximum length of generated prompts.",
)
parser.add_argument(
@@ -118,7 +118,7 @@ def add_client_args(parser):
"--skip_trace_precapture",
action="store_true",
default=False,
help="Print generated prompts.",
help="Skips trace precapture phase, use to speed up execution if trace captures have already completed.",
)
return parser

@@ -131,6 +131,16 @@ def main():
parser = add_client_args(parser)
args = parser.parse_args()

assert (
args.max_prompt_length != -1 or args.input_seq_len != -1
), "Either --max_prompt_length or --input_seq_len must be provided."
if args.max_prompt_length == -1:
assert args.input_seq_len > 0
args.max_prompt_length = args.input_seq_len
elif args.input_seq_len == -1:
assert args.max_prompt_length > 0
args.input_seq_len = args.max_prompt_length

# Create configs from arguments
prompt_config = PromptConfig(
input_seq_len=args.input_seq_len,
@@ -181,12 +191,12 @@

# Calculate and log summary statistics
if responses:
mean_decode_tps = np.mean([r["decode_tps"] for r in responses])
mean_total_tps = np.mean([r["total_tps"] for r in responses])
mean_tpot = np.mean([r["time_per_output_token"] for r in responses])
mean_ttft = np.mean([r["ttft"] for r in responses])
logger.info(f"Mean Decode TPS: {mean_decode_tps:.2f}")
logger.info(f"Mean Total TPS: {mean_total_tps:.2f}")
logger.info(f"Mean TTFT: {mean_ttft:.2f}")
logger.info(f"Mean TTFT: {mean_ttft:.4f}")
logger.info(f"Mean TPOT: {mean_tpot:.4f}")
mean_tps = 1.0 / max(mean_tpot, 1e-6)
logger.info(f"Mean User TPS: {mean_tps:.4f}")


if __name__ == "__main__":
