Commit

use TPOT and TPS in benchmarking/prompt_client_online_benchmark.py, add support in client for ITL and TPOT
tstescoTT committed Dec 12, 2024
1 parent bec53b0 commit 52bdbab
Showing 3 changed files with 37 additions and 18 deletions.
25 changes: 16 additions & 9 deletions benchmarking/prompt_client_online_benchmark.py
@@ -140,15 +140,22 @@ def run_sequence_length_test(
         )
 
         # Calculate statistics
+        mean_tpot = np.mean([r["time_per_output_token"] for r in responses])
+        mean_tpot = max(mean_tpot, 1e-6)  # Avoid division by zero
+        mean_tps = 1.0 / mean_tpot
+        std_tpot = np.std([r["time_per_output_token"] for r in responses])
+        std_tpot = max(std_tpot, 1e-6)  # Avoid division by zero
+        std_tps = mean_tps - 1.0 / (mean_tpot + std_tpot)
         stats = {
             "input_seq_len": input_len,
             "output_seq_len": output_len,
             "batch_size": batch_size,
-            "mean_decode_tps": np.mean([r["decode_tps"] for r in responses]),
-            "mean_total_tps": np.mean([r["total_tps"] for r in responses]),
+            "total_output_tokens": sum([r["output_seq_len"] for r in responses]),
+            "mean_tpot": mean_tpot,
+            "mean_tps": mean_tps,
             "mean_ttft": np.mean([r["ttft"] for r in responses]),
-            "std_decode_tps": np.std([r["decode_tps"] for r in responses]),
-            "std_total_tps": np.std([r["total_tps"] for r in responses]),
+            "std_tpot": std_tpot,
+            "std_tps": std_tps,
             "std_ttft": np.std([r["ttft"] for r in responses]),
             "num_prompts": num_prompts,
             "num_iterations": num_iterations,
@@ -161,11 +168,11 @@ def run_sequence_length_test(
         # Log results
         logger.info(
             f"Results for combination {idx}/{total_combinations}:\n"
-            f"Mean Decode TPS: {stats['mean_decode_tps']:.2f} ± "
-            f"{stats['std_decode_tps']:.2f}\n"
-            f"Mean Total TPS: {stats['mean_total_tps']:.2f} ± "
-            f"{stats['std_total_tps']:.2f}\n"
-            f"Mean TTFT: {stats['mean_ttft']:.2f} ± {stats['std_ttft']:.2f}"
+            f"Mean TPOT: {stats['mean_tpot']:.4f} ± "
+            f"{stats['std_tpot']:.4f}\n"
+            f"Mean user TPS: {stats['mean_tps']:.4f} ± "
+            f"{stats['std_tps']:.4f}\n"
+            f"Mean TTFT: {stats['mean_ttft']:.4f} ± {stats['std_ttft']:.4f}"
         )
 
         # Save results after each combination
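To see the new metrics end to end, here is a minimal self-contained sketch of the TPOT-to-TPS conversion and the resulting log format; all values are invented for illustration:

import numpy as np

# Per-request TPOT samples (seconds per output token); in the benchmark
# these come from _process_response in utils/prompt_client.py.
responses = [
    {"time_per_output_token": 0.021},
    {"time_per_output_token": 0.025},
    {"time_per_output_token": 0.019},
]

tpots = [r["time_per_output_token"] for r in responses]
mean_tpot = max(np.mean(tpots), 1e-6)  # clamped so the reciprocal is safe
mean_tps = 1.0 / mean_tpot             # per-user tokens per second

# TPS is the reciprocal of TPOT, so its spread is propagated one-sided:
# std_tps is the TPS lost when TPOT grows by one standard deviation.
std_tpot = max(np.std(tpots), 1e-6)
std_tps = mean_tps - 1.0 / (mean_tpot + std_tpot)

print(f"Mean TPOT: {mean_tpot:.4f} ± {std_tpot:.4f}")
print(f"Mean user TPS: {mean_tps:.4f} ± {std_tps:.4f}")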
5 changes: 2 additions & 3 deletions utils/batch_processor.py
@@ -266,9 +266,8 @@ def _log_progress(
     ):
         logger.info(
             f"Processed {response_counter}/{total_prompts} responses. "
-            f"decode_tps: {response_data['decode_tps']:.2f}, "
-            f"total_tps: {response_data['total_tps']:.2f}, "
-            f"ttft: {response_data['ttft']:.2f}, "
+            f"TPOT: {response_data['time_per_output_token']:.4f}, "
+            f"TTFT: {response_data['ttft']:.4f}, "
             f"input_seq_len: {response_data['input_seq_len']}, "
             f"output_seq_len: {response_data['output_seq_len']}"
         )
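With invented values, the reworked progress line renders as:

Processed 8/32 responses. TPOT: 0.0213, TTFT: 0.1842, input_seq_len: 128, output_seq_len: 512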
25 changes: 19 additions & 6 deletions utils/prompt_client.py
@@ -158,7 +158,8 @@ def capture_traces(
                 )
                 logger.info(
                     f"tokens generated: {response_data['output_seq_len']}, "
-                    f"TTFT: {response_data['ttft']:.3f}s"
+                    f"TTFT: {response_data['ttft']:.3f}s, "
+                    f"TPOT: {response_data['time_per_output_token']:.3f}s"
                 )
             except Exception as e:
                 logger.error(f"Error processing prompt: {e}")
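Likewise, with invented values a trace-capture line now reads:

tokens generated: 512, TTFT: 0.184s, TPOT: 0.021s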
@@ -218,15 +219,17 @@ def _process_response(
         first_token_time = 0
         ttft = 0
         usage_dict = {}
+        token_timestamps = []
 
         if stream:
             assert (
                 response.headers.get("transfer-encoding") == "chunked"
             ), "Response is not chunked"
             for line in response.iter_lines(decode_unicode=True):
                 if line and line.startswith("data: "):
+                    current_time = time.perf_counter()
                     if num_completion_tokens == 0:
-                        first_token_time = time.perf_counter()
+                        first_token_time = current_time
                         ttft = first_token_time - req_time
 
                     data_str = line[len("data: ") :].strip()
@@ -237,6 +240,7 @@
                     data = json.loads(data_str)
                     if data["choices"]:
                         full_text += data["choices"][0].get("text", "")
+                        token_timestamps.append(current_time)
                         num_completion_tokens += 1
                     else:
                         usage_dict = data.get("usage", {})
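The per-token bookkeeping added in these two hunks reduces to the following standalone sketch; the chunk list is a stand-in for the real SSE "data: " lines:

import time

token_timestamps = []
ttft = 0.0
req_time = time.perf_counter()
for chunk in ["Hello", ",", " world"]:  # simulated streamed completions
    current_time = time.perf_counter()
    if not token_timestamps:
        ttft = current_time - req_time  # first chunk sets TTFT
    token_timestamps.append(current_time)  # one timestamp per token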
@@ -249,8 +253,17 @@
             usage_dict = data["usage"]
             first_token_time = req_time
 
-        decode_time = max(time.perf_counter() - first_token_time, 0.0001)
-        total_time = max(time.perf_counter() - req_time, 0.0001)
+        # Calculate inter-token latencies
+        inter_token_latencies = []
+        if len(token_timestamps) > 1:
+            inter_token_latencies = [
+                token_timestamps[i] - token_timestamps[i - 1]
+                for i in range(1, len(token_timestamps))
+            ]
+
+        gen_time = max(time.perf_counter() - first_token_time, 0.0001)
+        # discount the TTFT and 1st token time from the generation time
+        time_per_output_token = gen_time / max(num_completion_tokens - 1, 1)
 
         # verify the number of input tokens
         isl_diff = usage_dict["prompt_tokens"] - prompt_len
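Why divide by num_completion_tokens - 1: the first token's latency is already captured by TTFT, so TPOT should cover only the gaps between consecutive tokens. Measured up to the last token, TPOT is then exactly the mean inter-token latency, as this check with made-up timestamps shows:

token_timestamps = [1.00, 1.02, 1.05, 1.07]  # made-up perf_counter readings
inter_token_latencies = [
    token_timestamps[i] - token_timestamps[i - 1]
    for i in range(1, len(token_timestamps))
]  # [0.02, 0.03, 0.02]
gen_time = token_timestamps[-1] - token_timestamps[0]  # 0.07 s
tpot = gen_time / max(len(token_timestamps) - 1, 1)    # ~0.0233 s/token
mean_itl = sum(inter_token_latencies) / len(inter_token_latencies)
assert abs(tpot - mean_itl) < 1e-9  # the gaps telescope to gen_time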
@@ -281,7 +294,7 @@
             "response": full_text,
             "input_seq_len": prompt_len,
             "output_seq_len": num_completion_tokens,
-            "decode_tps": (max(num_completion_tokens, 1)) / decode_time,
-            "total_tps": (max(num_completion_tokens, 1)) / total_time,
+            "inter_token_latencies": inter_token_latencies,
+            "time_per_output_token": time_per_output_token,
             "ttft": ttft,
         }
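The new inter_token_latencies field exposes per-token jitter to callers. As one possible use (not part of this commit), a consumer could summarize it like this, again with invented numbers:

import numpy as np

result = {"inter_token_latencies": [0.021, 0.024, 0.019, 0.035]}  # invented
itls = np.array(result["inter_token_latencies"])
if itls.size:
    print(f"mean ITL: {itls.mean():.4f}s, p99 ITL: {np.percentile(itls, 99):.4f}s")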
