diff --git a/.gitignore b/.gitignore
index 8b4d3df..24914ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,6 +28,7 @@ generated_modules
 *.ipynb
 *_params.pt
 *.json
+*.jsonl
 *.pt
 *.tti
 *.txt
diff --git a/tests/mock_vllm_api_server.py b/tests/mock_vllm_api_server.py
index e913cbf..ed28740 100644
--- a/tests/mock_vllm_api_server.py
+++ b/tests/mock_vllm_api_server.py
@@ -12,9 +12,16 @@
 from vllm import ModelRegistry

 # import classes to mock
+# TODO: import logging_init_wrapper from vllm-tt-metal-llama3-70b/src/logging_utils.py after refactoring
 from vllm.worker.tt_worker import TTWorker, TTCacheEngine
-from mock_vllm_model import new_init_cache_enginer, new_allocate_kv_cache, MockModel
+from mock_vllm_model import (
+    new_init_cache_enginer,
+    new_allocate_kv_cache,
+    MockModel,
+    logging_init_wrapper,
+)
 from vllm.engine.multiprocessing.engine import run_mp_engine
+from vllm.engine.llm_engine import LLMEngine

 # register the mock model
 ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)
@@ -34,7 +41,9 @@ def patched_run_mp_engine(engine_args, usage_context, ipc_path):
     # so we need to apply the patches to this target function
     with patch.object(TTWorker, "init_device", new=lambda x: None), patch.object(
         TTWorker, "_init_cache_engine", new=new_init_cache_enginer
-    ), patch.object(TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache):
+    ), patch.object(
+        TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache
+    ), patch.object(LLMEngine, "__init__", new=logging_init_wrapper):
         # Call the original `run_mp_engine` with patches applied
         run_mp_engine(engine_args, usage_context, ipc_path)
diff --git a/tests/mock_vllm_model.py b/tests/mock_vllm_model.py
index 50b487b..1553c41 100644
--- a/tests/mock_vllm_model.py
+++ b/tests/mock_vllm_model.py
@@ -1,24 +1,40 @@
 import copy
 import os
-import sys
 import time
+from datetime import datetime
+import json
 from dataclasses import dataclass
 from typing import List
+from pathlib import Path

 import torch
-from loguru import logger

-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-from tt_metal.models.demos.t3000.llama2_70b.tt.llama_common import (
+from models.demos.t3000.llama2_70b.tt.llama_common import (
     setup_llama_env,
 )
-from tt_metal.models.demos.t3000.llama2_70b.tt.llama_generation import (
+from models.demos.t3000.llama2_70b.tt.llama_generation import (
     TtLlamaModelForGeneration,
     get_padded_prefill_len,
 )
-from tt_metal.models.demos.t3000.llama2_70b.tt.model_config import (
+from models.demos.t3000.llama2_70b.tt.model_config import (
     get_model_config,
 )
+from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo
+from vllm.engine.metrics import logger
+
+
+from vllm.engine.llm_engine import LLMEngine
+
+
+# new init wrapper for LLMEngine, used by the vLLM API server (online inference) when the engine is initialized inside MQLLMEngine
+original_init = LLMEngine.__init__
+
+
+def logging_init_wrapper(self, *args, **kwargs):
+    original_init(self, *args, **kwargs)  # Call the original __init__
+    num_scheduler_steps = self.scheduler_config.num_scheduler_steps
+    batch_size = self.scheduler_config.max_num_seqs
+    self.stat_loggers["raw_logging"] = RawStatLogger(num_scheduler_steps, batch_size)


 def new_init_cache_enginer(self):
@@ -166,7 +182,7 @@ def decode_forward_trace(
         start_pos: int,
         trace_id,
         tt_inp,
-        rot_mat,
+        rot_idxs_tt,
         cache_idxs_tt,
         tt_logits,
         page_table=None,
@@ -318,3 +334,106 @@ def forward(
             kv_cache=kv_cache,
             prompt_lens=prompt_lens,
         )
+
+
+class RawStatLogger(StatLoggerBase):
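+    # collects raw per-user TTFT and TPOT measurements at each engine step and
+    # appends them as JSON lines to statistics<timestamp>.jsonl under CACHE_ROOT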
+    def __init__(self, num_scheduler_steps, batch_size) -> None:
+        self.time_to_first_token = []
+        self.time_per_output_token = []
+        self.num_scheduler_steps = num_scheduler_steps
+        self.batch_size = batch_size
+        timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
+        cache_root = Path(os.getenv("CACHE_ROOT", "."))
+        self.filepath = cache_root / f"statistics{timestamp}.jsonl"
+        self.num_total_grouped_step = (
+            0  # number of iterations of size num_scheduler_steps
+        )
+        self.num_inference = (
+            0  # number of times inference is done (i.e. how many batches)
+        )
+
+    def log(self, stats: Stats, log_to_stdout=True) -> None:
+        if len(stats.time_to_first_tokens_iter) > 0:
+            self.time_to_first_token.append(
+                stats.time_to_first_tokens_iter
+            )  # Add all values to the list
+
+            if log_to_stdout:
+                for user_idx, ttft in enumerate(stats.time_to_first_tokens_iter):
+                    logger.info(f"User {user_idx}: Time to first token {ttft:.2f} s\n")
+
+        if len(stats.time_per_output_tokens_iter) > 0:
+            tpot = [
+                time / self.num_scheduler_steps
+                for time in stats.time_per_output_tokens_iter
+            ]
+            self.time_per_output_token.append(tpot)  # Add all values to the list
+
+        self._write_to_json(stats)
+
+    def _write_to_json(self, stats):
+        data = {}
+
+        # to record time per output token (decode stage)
+        if len(stats.time_per_output_tokens_iter) > 0:
+            data["tpot"] = {}
+            data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"] = {}
+            for user_idx, tpot in enumerate(stats.time_per_output_tokens_iter):
+                data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"][
+                    f"user_{user_idx}"
+                ] = tpot
+
+            self.num_total_grouped_step += 1
+
+        # to record time to first token (prefill stage)
+        if len(stats.time_to_first_tokens_iter) > 0:
+            # if inference is done online, need to handle case where not all user requests are made at same engine step call
+            if os.path.exists(self.filepath):
+                with open(self.filepath, "r") as file:
+                    lines = file.readlines()
+                    # load in last line if time to first token not completed for all users
+                    if lines:  # ensure there is data
+                        last_line = lines[-1]
+                        last_data = json.loads(last_line)
+                        if (
+                            "ttft" in last_data
+                        ):  # if still in prefill stage (incomplete for all users) or only doing prefill and no decode
+                            if (
+                                len(list(last_data["ttft"].values())[0])
+                                < self.batch_size
+                            ):  # if prefill incomplete for some users
+                                self._append_new_users(data)
+                                # find the index of the last user for which the first token was computed
+                                last_user_processed = len(
+                                    list(last_data["ttft"].values())[0]
+                                )
+
+                            else:  # if prefill already complete for all users
+                                last_user_processed = 0
+                                self._append_new_users(data)
+
+                        else:  # if in decode stage
+                            last_user_processed = 0
+                            self._append_new_users(data)
+            else:  # if first forward pass
+                last_user_processed = 0
+                self._append_new_users(data)
+
+            for user_idx, ttft in enumerate(stats.time_to_first_tokens_iter):
+                data["ttft"][f"Inference_num:{self.num_inference}"][
+                    f"user_{user_idx + last_user_processed}"
+                ] = ttft
+
+            self.num_inference += 1  # increase number of inference passes
+
+        if data:
+            with open(self.filepath, "a") as file:
+                json.dump(data, file)
+                file.write("\n")  # Ensure each JSON object is on a new line
+
+    def _append_new_users(self, data):
+        data["ttft"] = {}
+        data["ttft"][f"Inference_num:{self.num_inference}"] = {}
+
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        raise NotImplementedError
diff --git a/tests/mock_vllm_offline_inference_tt.py b/tests/mock_vllm_offline_inference_tt.py
index 990d06c..f874d2d 100644
--- a/tests/mock_vllm_offline_inference_tt.py
+++ b/tests/mock_vllm_offline_inference_tt.py
@@ -1,10 +1,16 @@
 import argparse
 import json
 import time
+import uvloop
 from unittest.mock import patch

-import uvloop
-from mock_vllm_model import MockModel, new_allocate_kv_cache, new_init_cache_enginer
+# TODO: import logging_init_wrapper from vllm-tt-metal-llama3-70b/src/logging_utils.py after refactoring
+from mock_vllm_model import (
+    MockModel,
+    new_allocate_kv_cache,
+    new_init_cache_enginer,
+    logging_init_wrapper,
+)
 from tqdm import tqdm
 from vllm import LLM, ModelRegistry, SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -15,6 +21,7 @@
 from vllm.inputs.data import TokensPrompt
 from vllm.utils import merge_async_iterators
 from vllm.worker.tt_worker import TTCacheEngine, TTWorker
+from vllm.engine.llm_engine import LLMEngine

 ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)
@@ -26,6 +33,7 @@
 @patch.object(
     TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache
 )  # Patch to stop allocation on TT device since nonexistent
+@patch.object(LLMEngine, "__init__", new=logging_init_wrapper)
 def run_inference(
     prompts_json,
     max_tokens=128,
diff --git a/vllm-tt-metal-llama3-70b/src/logging_utils.py b/vllm-tt-metal-llama3-70b/src/logging_utils.py
new file mode 100644
index 0000000..987e885
--- /dev/null
+++ b/vllm-tt-metal-llama3-70b/src/logging_utils.py
@@ -0,0 +1,122 @@
+import os
+from datetime import datetime
+import json
+from pathlib import Path
+
+from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo
+from vllm.engine.metrics import logger
+from vllm.engine.llm_engine import LLMEngine
+
+
+# new init wrapper for LLMEngine, used by the vLLM API server (online inference) when the engine is initialized inside MQLLMEngine
+original_init = LLMEngine.__init__
+
+
+def logging_init_wrapper(self, *args, **kwargs):
+    original_init(self, *args, **kwargs)  # Call the original __init__
+    num_scheduler_steps = self.scheduler_config.num_scheduler_steps
+    batch_size = self.scheduler_config.max_num_seqs
+    self.stat_loggers["raw_logging"] = RawStatLogger(num_scheduler_steps, batch_size)
+
+
+class RawStatLogger(StatLoggerBase):
+    def __init__(self, num_scheduler_steps, batch_size) -> None:
+        self.time_to_first_token = []
+        self.time_per_output_token = []
+        self.num_scheduler_steps = num_scheduler_steps
+        self.batch_size = batch_size
+        timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
+        cache_root = Path(os.getenv("CACHE_ROOT", "."))
+        self.filepath = cache_root / f"statistics{timestamp}.jsonl"
+        self.num_total_grouped_step = (
+            0  # number of iterations of size num_scheduler_steps
+        )
+        self.num_inference = (
+            0  # number of times inference is done (i.e. how many batches)
+        )
+
+    def log(self, stats: Stats, log_to_stdout=True) -> None:
+        if len(stats.time_to_first_tokens_iter) > 0:
+            self.time_to_first_token.append(
+                stats.time_to_first_tokens_iter
+            )  # Add all values to the list
+
+            if log_to_stdout:
+                for user_idx, ttft in enumerate(stats.time_to_first_tokens_iter):
+                    logger.info(f"User {user_idx}: Time to first token {ttft:.2f} s\n")
+
+        if len(stats.time_per_output_tokens_iter) > 0:
+            tpot = [
+                time / self.num_scheduler_steps
+                for time in stats.time_per_output_tokens_iter
+            ]
+            self.time_per_output_token.append(tpot)  # Add all values to the list
+
+        self._write_to_json(stats)
+
+    def _write_to_json(self, stats):
+        data = {}
+
+        # to record time per output token (decode stage)
+        if len(stats.time_per_output_tokens_iter) > 0:
+            data["tpot"] = {}
+            data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"] = {}
+            for user_idx, tpot in enumerate(stats.time_per_output_tokens_iter):
+                data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"][
+                    f"user_{user_idx}"
+                ] = tpot
+
+            self.num_total_grouped_step += 1
+
+        # to record time to first token (prefill stage)
+        if len(stats.time_to_first_tokens_iter) > 0:
+            # if inference is done online, need to handle case where not all user requests are made at same engine step call
+            if os.path.exists(self.filepath):
+                with open(self.filepath, "r") as file:
+                    lines = file.readlines()
+                    # load in last line if time to first token not completed for all users
+                    if lines:  # ensure there is data
+                        last_line = lines[-1]
+                        last_data = json.loads(last_line)
+                        if (
+                            "ttft" in last_data
+                        ):  # if still in prefill stage (incomplete for all users) or only doing prefill and no decode
+                            if (
+                                len(list(last_data["ttft"].values())[0])
+                                < self.batch_size
+                            ):  # if prefill incomplete for some users
+                                self._append_new_users(data)
+                                # find the index of the last user for which the first token was computed
+                                last_user_processed = len(
+                                    list(last_data["ttft"].values())[0]
+                                )
+
+                            else:  # if prefill already complete for all users
+                                last_user_processed = 0
+                                self._append_new_users(data)
+
+                        else:  # if in decode stage
+                            last_user_processed = 0
+                            self._append_new_users(data)
+            else:  # if first forward pass
+                last_user_processed = 0
+                self._append_new_users(data)
+
+            for user_idx, ttft in enumerate(stats.time_to_first_tokens_iter):
+                data["ttft"][f"Inference_num:{self.num_inference}"][
+                    f"user_{user_idx + last_user_processed}"
+                ] = ttft
+
+            self.num_inference += 1  # increase number of inference passes
+
+        if data:
+            with open(self.filepath, "a") as file:
+                json.dump(data, file)
+                file.write("\n")  # Ensure each JSON object is on a new line
+
+    def _append_new_users(self, data):
+        data["ttft"] = {}
+        data["ttft"][f"Inference_num:{self.num_inference}"] = {}
+
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        raise NotImplementedError
diff --git a/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py b/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py
index 3c6968d..8804aa4 100644
--- a/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py
+++ b/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py
@@ -10,6 +10,11 @@
 import jwt

 from vllm import ModelRegistry

+# importing logging utils
+from logging_utils import logging_init_wrapper
+from vllm.engine.llm_engine import LLMEngine
+from unittest.mock import patch
+
 # importing from tt-metal install path
 from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
@@ -25,6 +30,7 @@ def get_encoded_api_key(jwt_secret):
     return encoded_jwt

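+# wrap LLMEngine.__init__ so the RawStatLogger from logging_utils is registered when the engine is created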
+@patch.object(LLMEngine, "__init__", new=logging_init_wrapper)
 def main():
     # vLLM CLI arguments
     args = {