diff --git a/.gitignore b/.gitignore
index 8b4d3df..3a229d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,6 +28,7 @@ generated_modules
 *.ipynb
 *_params.pt
 *.json
+*.jsonl
 *.pt
 *.tti
 *.txt
@@ -36,6 +37,7 @@ db.sqlite3
 
 # secrets
 .env
+.env.llama31_8b
 
 # unignore
 !requirements.txt
diff --git a/tests/mock_vllm_api_server.py b/tests/mock_vllm_api_server.py
index e913cbf..727b423 100644
--- a/tests/mock_vllm_api_server.py
+++ b/tests/mock_vllm_api_server.py
@@ -13,8 +13,13 @@
 
 # import classes to mock
 from vllm.worker.tt_worker import TTWorker, TTCacheEngine
-from mock_vllm_model import new_init_cache_enginer, new_allocate_kv_cache, MockModel
-from vllm.engine.multiprocessing.engine import run_mp_engine
+from mock_vllm_model import (
+    new_init_cache_enginer,
+    new_allocate_kv_cache,
+    MockModel,
+    new__init__,
+)
+from vllm.engine.multiprocessing.engine import MQLLMEngine, run_mp_engine
 
 # register the mock model
 ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)
@@ -34,7 +39,9 @@ def patched_run_mp_engine(engine_args, usage_context, ipc_path):
     # so we need to apply the patches to this target function
     with patch.object(TTWorker, "init_device", new=lambda x: None), patch.object(
         TTWorker, "_init_cache_engine", new=new_init_cache_enginer
-    ), patch.object(TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache):
+    ), patch.object(
+        TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache
+    ), patch.object(MQLLMEngine, "__init__", new=new__init__):
         # Call the original `run_mp_engine` with patches applied
         run_mp_engine(engine_args, usage_context, ipc_path)
 
diff --git a/tests/mock_vllm_model.py b/tests/mock_vllm_model.py
index 50b487b..dce71d0 100644
--- a/tests/mock_vllm_model.py
+++ b/tests/mock_vllm_model.py
@@ -2,11 +2,11 @@
 import os
 import sys
 import time
+import json
+from datetime import datetime
 from dataclasses import dataclass
 from typing import List
-
 import torch
-from loguru import logger
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from tt_metal.models.demos.t3000.llama2_70b.tt.llama_common import (
@@ -19,6 +19,83 @@
 from tt_metal.models.demos.t3000.llama2_70b.tt.model_config import (
     get_model_config,
 )
+from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo
+from vllm.engine.metrics import logger
+
+
+import zmq
+import threading
+from typing import Optional
+from vllm import LLMEngine
+from vllm.envs import VLLM_RPC_TIMEOUT
+from vllm.engine.multiprocessing import (
+    IPC_DATA_EXT,
+    IPC_HEALTH_EXT,
+    IPC_INPUT_EXT,
+    IPC_OUTPUT_EXT,
+)
+
+
+# new init function for MQLLMEngine to be used in vllm api server (online inference)
+def new__init__(
+    self,
+    ipc_path: str,
+    use_async_sockets: bool,
+    *args,
+    log_requests: bool = True,
+    **kwargs,
+) -> None:
+    # For MQLLMEngine, we can use cached outputs, since each new request
+    # output is immediately pickled and sent over the socket, which frees
+    # the python object to be reused again.
+ kwargs["use_cached_outputs"] = True + + self.engine = LLMEngine(*args, **kwargs) + num_scheduler_steps = self.engine.scheduler_config.num_scheduler_steps + batch_size = self.engine.scheduler_config.max_num_seqs + self.engine.stat_loggers["raw_logging"] = RawStatLogger( + num_scheduler_steps, batch_size + ) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = ( + self._async_socket_engine_callback + ) + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + # Heartbeat thread + self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True) + self._heartbeat_stop_event = threading.Event() + # The heartbeat needs to be faster than what the client will wait for + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0 + + self._last_alive_time = time.time() + # The heartbeats can tolerate a long period of the engine chugging + # away at a generation request. + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0 def new_init_cache_enginer(self): @@ -166,7 +243,7 @@ def decode_forward_trace( start_pos: int, trace_id, tt_inp, - rot_mat, + rot_idxs_tt, cache_idxs_tt, tt_logits, page_table=None, @@ -318,3 +395,105 @@ def forward( kv_cache=kv_cache, prompt_lens=prompt_lens, ) + + +class RawStatLogger(StatLoggerBase): + def __init__(self, num_scheduler_steps, batch_size) -> None: + self.time_to_first_token = [] + self.time_per_output_token = [] + self.num_scheduler_steps = num_scheduler_steps + self.batch_size = batch_size + timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S") + self.filepath = f"/home/user/tests/statistics_{timestamp}.jsonl" + self.num_total_grouped_step = ( + 0 # number of iterations of size num_scheduler_steps + ) + self.num_inference = ( + 0 # number of times inference is done (ie. 
how many batches) + ) + + def log(self, stats: Stats, log_to_stdout=True) -> None: + if len(stats.time_to_first_tokens_iter) > 0: + self.time_to_first_token.append( + stats.time_to_first_tokens_iter + ) # Add all values to the list + + if log_to_stdout: + for user_idx, ttft in enumerate(stats.time_to_first_tokens_iter): + logger.info(f"User {user_idx}: Time to first token {ttft:.2f} s\n") + + if len(stats.time_per_output_tokens_iter) > 0: + tpot = [ + time / self.num_scheduler_steps + for time in stats.time_per_output_tokens_iter + ] + self.time_per_output_token.append(tpot) # Add all values to the list + + self._write_to_json(stats) + + def _write_to_json(self, stats): + data = {} + + # to record time per output token (decode stage) + if len(stats.time_per_output_tokens_iter) > 0: + data["tpot"] = {} + data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"] = {} + for user_idx, tpot in enumerate(stats.time_per_output_tokens_iter): + data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"][ + f"user_{user_idx}" + ] = tpot + + self.num_total_grouped_step += 1 + + # to record time to first token (prefill stage) + if len(stats.time_to_first_tokens_iter) > 0: + # if inference is done online, need to handle case where not all user requests are made at same engine step call + if os.path.exists(self.filepath): + with open(self.filepath, "r") as file: + lines = file.readlines() + # load in last line if time to first token not completed for all users + if lines: # ensure there is data + last_line = lines[-1] + last_data = json.loads(last_line) + if ( + "ttft" in last_data + ): # if still in prefill stage (incomplete for all users) or only doing prefill and no decode + if ( + len(list(last_data["ttft"].values())[0]) + < self.batch_size + ): # if incomplete prefill for all users + self._append_new_users(data) + # find the index of the last user for whicht the first token was computed + last_user_processed = len( + list(last_data["ttft"].values())[0] + ) + + else: # if prefill already complete for all users + last_user_processed = 0 + self._append_new_users(data) + + else: # if in decode stage + last_user_processed = 0 + self._append_new_users(data) + else: # if first forward pass + last_user_processed = 0 + self._append_new_users(data) + + for user_idx, ttft in enumerate(stats.time_to_first_tokens_iter): + data["ttft"][f"Inference_num:{self.num_inference}"][ + f"user_{user_idx + last_user_processed}" + ] = ttft + + self.num_inference += 1 # increase number of inference passes + + if data: + with open(self.filepath, "a") as file: + json.dump(data, file) + file.write("\n") # Ensure each JSON object is on a new line + + def _append_new_users(self, data): + data["ttft"] = {} + data["ttft"][f"Inference_num:{self.num_inference}"] = {} + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError diff --git a/tests/mock_vllm_offline_inference_tt.py b/tests/mock_vllm_offline_inference_tt.py index 990d06c..436fe6e 100644 --- a/tests/mock_vllm_offline_inference_tt.py +++ b/tests/mock_vllm_offline_inference_tt.py @@ -1,11 +1,10 @@ import argparse import json import time -from unittest.mock import patch - import uvloop -from mock_vllm_model import MockModel, new_allocate_kv_cache, new_init_cache_enginer from tqdm import tqdm +from unittest.mock import patch + from vllm import LLM, ModelRegistry, SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.multiprocessing.client import MQLLMEngineClient @@ -14,8 +13,14 @@ ) from vllm.inputs.data 
import TokensPrompt from vllm.utils import merge_async_iterators + +# import mocking utils + classes to mock +from mock_vllm_model import MockModel, new_allocate_kv_cache, new_init_cache_enginer from vllm.worker.tt_worker import TTCacheEngine, TTWorker +# importing logging utils +from mock_vllm_model import RawStatLogger + ModelRegistry.register_model("TTLlamaForCausalLM", MockModel) @@ -97,6 +102,11 @@ def run_inference( # Create and run LLM if not async_engine: llm = LLM(**engine_kw_args) + # Add raw stats logging to the llm engine + llm.llm_engine.stat_loggers["raw_logging"] = RawStatLogger( + num_scheduler_steps=engine_kw_args["num_scheduler_steps"], + batch_size=engine_kw_args["max_num_seqs"], + ) if not measure_perf: generate_tokens(llm, prompts, sampling_params, print_output=True) else: diff --git a/vllm-tt-metal-llama3-70b/src/logging_utils.py b/vllm-tt-metal-llama3-70b/src/logging_utils.py new file mode 100644 index 0000000..73e679e --- /dev/null +++ b/vllm-tt-metal-llama3-70b/src/logging_utils.py @@ -0,0 +1,183 @@ +import os +import json +from datetime import datetime +from vllm.engine.metrics import logger +from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo + +# imports for new__init__ function to add logging to the MQLLMEngine +import zmq +import time +import threading +from typing import Optional +from vllm import LLMEngine +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.engine.multiprocessing import ( + IPC_DATA_EXT, + IPC_HEALTH_EXT, + IPC_INPUT_EXT, + IPC_OUTPUT_EXT, +) + + +# new init function for MQLLMEngine to be used in vllm api server (online inference) +def new__init__( + self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs, +) -> None: + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and send over the socket, which frees + # the python object to be reused again. + kwargs["use_cached_outputs"] = True + + self.engine = LLMEngine(*args, **kwargs) + num_scheduler_steps = self.engine.scheduler_config.num_scheduler_steps + batch_size = self.engine.scheduler_config.max_num_seqs + self.engine.stat_loggers["raw_logging"] = RawStatLogger( + num_scheduler_steps, batch_size + ) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = ( + self._async_socket_engine_callback + ) + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. 
+    self._errored_with: Optional[BaseException] = None
+
+    # Heartbeat thread
+    self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True)
+    self._heartbeat_stop_event = threading.Event()
+    # The heartbeat needs to be faster than what the client will wait for
+    # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
+    self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
+
+    self._last_alive_time = time.time()
+    # The heartbeats can tolerate a long period of the engine chugging
+    # away at a generation request.
+    # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
+    self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
+
+
+class RawStatLogger(StatLoggerBase):
+    def __init__(self, num_scheduler_steps, batch_size) -> None:
+        self.time_to_first_token = []
+        self.time_per_output_token = []
+        self.num_scheduler_steps = num_scheduler_steps
+        self.batch_size = batch_size
+        timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
+        self.filepath = f"/home/user/tests/statistics_{timestamp}.jsonl"
+        self.num_total_grouped_step = (
+            0  # number of iterations of size num_scheduler_steps
+        )
+        self.num_inference = (
+            0  # number of times inference is done (i.e. how many batches)
+        )
+
+    def log(self, stats: Stats, log_to_stdout=True) -> None:
+        if len(stats.time_to_first_tokens_iter) > 0:
+            self.time_to_first_token.append(
+                stats.time_to_first_tokens_iter
+            )  # Add all values to the list
+
+            if log_to_stdout:
+                for user_idx, ttft in enumerate(stats.time_to_first_tokens_iter):
+                    logger.info(f"User {user_idx}: Time to first token {ttft:.2f} s\n")
+
+        if len(stats.time_per_output_tokens_iter) > 0:
+            tpot = [
+                time / self.num_scheduler_steps
+                for time in stats.time_per_output_tokens_iter
+            ]
+            self.time_per_output_token.append(tpot)  # Add all values to the list
+
+        self._write_to_json(stats)
+
+    def _write_to_json(self, stats):
+        data = {}
+
+        # to record time per output token (decode stage)
+        if len(stats.time_per_output_tokens_iter) > 0:
+            data["tpot"] = {}
+            data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"] = {}
+            for user_idx, tpot in enumerate(stats.time_per_output_tokens_iter):
+                data["tpot"][f"Total_step_num:{self.num_total_grouped_step}"][
+                    f"user_{user_idx}"
+                ] = tpot
+
+            self.num_total_grouped_step += 1
+
+        # to record time to first token (prefill stage)
+        if len(stats.time_to_first_tokens_iter) > 0:
+            # if inference is done online, need to handle case where not all user requests are made at same engine step call
+            if os.path.exists(self.filepath):
+                with open(self.filepath, "r") as file:
+                    lines = file.readlines()
+                    # load in last line if time to first token not completed for all users
+                    if lines:  # ensure there is data
+                        last_line = lines[-1]
+                        last_data = json.loads(last_line)
+                        if (
+                            "ttft" in last_data
+                        ):  # if still in prefill stage (incomplete for all users) or only doing prefill and no decode
+                            if (
+                                len(list(last_data["ttft"].values())[0])
+                                < self.batch_size
+                            ):  # if incomplete prefill for all users
+                                self._append_new_users(data)
+                                # find the index of the last user for which the first token was computed
+                                last_user_processed = len(
+                                    list(last_data["ttft"].values())[0]
+                                )
+
+                            else:  # if prefill already complete for all users
+                                last_user_processed = 0
+                                self._append_new_users(data)
+
+                        else:  # if in decode stage
+                            last_user_processed = 0
+                            self._append_new_users(data)
+            else:  # if first forward pass
+                last_user_processed = 0
+                self._append_new_users(data)
+
+            for user_idx, ttft in enumerate(stats.time_to_first_tokens_iter):
+                data["ttft"][f"Inference_num:{self.num_inference}"][
+                    f"user_{user_idx + last_user_processed}"
+                ] = ttft
+
+            self.num_inference += 1  # increase number of inference passes
+
+        if data:
+            with open(self.filepath, "a") as file:
+                json.dump(data, file)
+                file.write("\n")  # Ensure each JSON object is on a new line
+
+    def _append_new_users(self, data):
+        data["ttft"] = {}
+        data["ttft"][f"Inference_num:{self.num_inference}"] = {}
+
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        raise NotImplementedError
diff --git a/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py b/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py
index 3c6968d..9f4421f 100644
--- a/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py
+++ b/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py
@@ -10,6 +10,11 @@
 import jwt
 
 from vllm import ModelRegistry
 
+# importing logging utils
+from logging_utils import new__init__
+from vllm.engine.multiprocessing.engine import MQLLMEngine
+from unittest.mock import patch
+
 # importing from tt-metal install path
 from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
 
@@ -25,6 +30,7 @@ def get_encoded_api_key(jwt_secret):
     return encoded_jwt
 
 
+@patch.object(MQLLMEngine, "__init__", new=new__init__)
 def main():
     # vLLM CLI arguments
     args = {