
Commit

update to use variable batch size - remove hard code
mvanniasingheTT committed Nov 14, 2024
1 parent 9dbc918 commit 75471e8
Showing 2 changed files with 25 additions and 21 deletions.
43 changes: 23 additions & 20 deletions tests/mock_vllm_model.py
@@ -52,7 +52,10 @@ def new__init__(
 
         self.engine = LLMEngine(*args, **kwargs)
         num_scheduler_steps = self.engine.scheduler_config.num_scheduler_steps
-        self.engine.stat_loggers["raw_logging"] = RawStatLogger(num_scheduler_steps)
+        batch_size = self.engine.scheduler_config.max_num_seqs
+        self.engine.stat_loggers["raw_logging"] = RawStatLogger(
+            num_scheduler_steps, batch_size
+        )
        self.log_requests = log_requests
 
         self.use_async_sockets = use_async_sockets
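
For context, the batch size now follows the engine's scheduler configuration instead of a literal. A minimal sketch of the lookup, using a stand-in for the config object that LLMEngine builds (`SchedulerConfigStub` is invented for illustration; the attribute names match the diff):

# Sketch only: SchedulerConfigStub stands in for the scheduler config that
# LLMEngine exposes as engine.scheduler_config. Attribute names match the diff.
class SchedulerConfigStub:
    def __init__(self, max_num_seqs: int, num_scheduler_steps: int):
        self.max_num_seqs = max_num_seqs  # max concurrent sequences per batch
        self.num_scheduler_steps = num_scheduler_steps

cfg = SchedulerConfigStub(max_num_seqs=32, num_scheduler_steps=10)
batch_size = cfg.max_num_seqs  # replaces the previously hard-coded 32
print(batch_size)  # -> 32
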
@@ -395,10 +398,11 @@ def forward(
 
 
 class RawStatLogger(StatLoggerBase):
-    def __init__(self, num_scheduler_steps) -> None:
+    def __init__(self, num_scheduler_steps, batch_size) -> None:
         self.time_to_first_token = []
         self.time_per_output_token = []
         self.num_scheduler_steps = num_scheduler_steps
+        self.batch_size = batch_size
         timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
         self.filepath = f"/home/user/tests/statistics_{timestamp}.jsonl"
         self.num_total_grouped_step = (
@@ -430,39 +434,38 @@ def log(self, stats: Stats, log_to_stdout=True) -> None:
     def _write_to_json(self, stats):
         data = {}
 
         # to record time per output token (decode stage)
         if len(stats.time_per_output_tokens_iter) > 0:
-            data["time per output token"] = {}
-            data["time per output token"][
-                f"Total step num:{self.num_total_grouped_step}"
-            ] = {}
+            data["tpot"] = {}
+            data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"] = {}
             for user_idx, tpot in enumerate(stats.time_per_output_tokens_iter):
-                data["time per output token"][
-                    f"Total step num:{self.num_total_grouped_step}"
-                ][f"user {user_idx}"] = tpot
+                data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"][
+                    f"user_{user_idx}"
+                ] = tpot
 
             self.num_total_grouped_step += 1
 
         # to record time to first token (prefill stage)
         if len(stats.time_to_first_tokens_iter) > 0:
             # if inference is done online, need to handle case where not all user requests are made at same engine step call
             if os.path.exists(self.filepath):
                 with open(self.filepath, "r") as file:
                     lines = file.readlines()
                     # load in last line if time to first token not completed for all users
-                    if lines:
+                    if lines:  # ensure there is data
                         last_line = lines[-1]
                         last_data = json.loads(last_line)
                         if (
-                            "time to first token" in last_data
+                            "ttft" in last_data
                         ):  # if still in prefill stage (incomplete for all users) or only doing prefill and no decode
                             if (
-                                len(list(last_data["time to first token"].values())[0])
-                                < 32
+                                len(list(last_data["ttft"].values())[0])
+                                < self.batch_size
                             ):  # if incomplete prefill for all users
                                 # data = last_data
                                 self._append_new_users(data)
                                 # find the index of the last user for which the first token was computed
                                 last_user_processed = len(
-                                    list(last_data["time to first token"].values())[0]
+                                    list(last_data["ttft"].values())[0]
                                 )
 
                             else:  # if prefill already complete for all users
@@ -477,20 +480,20 @@ def _write_to_json(self, stats):
                 self._append_new_users(data)
 
             for user_idx, ttft in enumerate(stats.time_to_first_tokens_iter):
-                data["time to first token"][f"Inference num:{self.num_inference}"][
-                    f"user {user_idx + last_user_processed}"
+                data["ttft"][f"Inference_num:{self.num_inference}"][
+                    f"user_{user_idx + last_user_processed}"
                 ] = ttft
 
-            self.num_inference += 1
+            self.num_inference += 1  # increase number of inference passes
 
         if data:
             with open(self.filepath, "a") as file:
                 json.dump(data, file)
                 file.write("\n")  # Ensure each JSON object is on a new line
 
     def _append_new_users(self, data):
-        data["time to first token"] = {}
-        data["time to first token"][f"Inference num:{self.num_inference}"] = {}
+        data["ttft"] = {}
+        data["ttft"][f"Inference_num:{self.num_inference}"] = {}
 
     def info(self, type: str, obj: SupportsMetricsInfo) -> None:
         raise NotImplementedError
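
For reference, a small sketch (not part of the commit) of the resume logic above: when logging online, the logger reads the last JSONL record and compares the number of users with a recorded time to first token against `batch_size`, where the old code compared against a hard-coded 32. The record below is illustrative; its shape mirrors the `ttft` entries written by `_write_to_json`.

import json

# Illustrative last line of the statistics .jsonl file; the shape mirrors
# what _write_to_json appends: {"ttft": {"Inference_num:<n>": {"user_<i>": seconds}}}
last_line = '{"ttft": {"Inference_num:0": {"user_0": 0.12, "user_1": 0.15}}}'
last_data = json.loads(last_line)

batch_size = 4  # now read from scheduler_config.max_num_seqs, not hard-coded

users_recorded = list(last_data["ttft"].values())[0]
if len(users_recorded) < batch_size:
    # Prefill incomplete: continue numbering users from the last index.
    last_user_processed = len(users_recorded)
else:
    # Prefill complete for all users: the next record starts at user_0.
    last_user_processed = 0

print(last_user_processed)  # -> 2 with the record above
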
3 changes: 2 additions & 1 deletion tests/mock_vllm_offline_inference_tt.py
@@ -104,7 +104,8 @@ def run_inference(
     llm = LLM(**engine_kw_args)
     # Add raw stats logging to the llm engine
     llm.llm_engine.stat_loggers["raw_logging"] = RawStatLogger(
-        engine_kw_args["num_scheduler_steps"]
+        num_scheduler_steps=engine_kw_args["num_scheduler_steps"],
+        batch_size=engine_kw_args["max_num_seqs"],
     )
     if not measure_perf:
         generate_tokens(llm, prompts, sampling_params, print_output=True)
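
A hedged usage sketch of the call above, with hypothetical values; the real test builds `engine_kw_args` itself. The point is that both keys the logger consumes must be present:

from mock_vllm_model import RawStatLogger  # assumed import path; class shown in the diff above

# Hypothetical values for illustration only.
engine_kw_args = {
    "num_scheduler_steps": 10,
    "max_num_seqs": 32,  # becomes RawStatLogger's batch_size
}

raw_logger = RawStatLogger(
    num_scheduler_steps=engine_kw_args["num_scheduler_steps"],
    batch_size=engine_kw_args["max_num_seqs"],
)
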
