diff --git a/tests/mock_vllm_model.py b/tests/mock_vllm_model.py
index 4c823ac..9b48c5c 100644
--- a/tests/mock_vllm_model.py
+++ b/tests/mock_vllm_model.py
@@ -2,7 +2,7 @@
 import os
 import sys
 import time
-import datetime
+from datetime import datetime
 import json
 from dataclasses import dataclass
 from typing import List
@@ -183,7 +183,7 @@ def decode_forward_trace(
     start_pos: int,
     trace_id,
     tt_inp,
-    rot_mat,
+    rot_idxs_tt,
     cache_idxs_tt,
     tt_logits,
     page_table=None,
diff --git a/tests/mock_vllm_offline_inference_tt.py b/tests/mock_vllm_offline_inference_tt.py
index 990d06c..c02d461 100644
--- a/tests/mock_vllm_offline_inference_tt.py
+++ b/tests/mock_vllm_offline_inference_tt.py
@@ -4,7 +4,12 @@
 from unittest.mock import patch
 
 import uvloop
-from mock_vllm_model import MockModel, new_allocate_kv_cache, new_init_cache_enginer
+from mock_vllm_model import (
+    MockModel,
+    new_allocate_kv_cache,
+    new_init_cache_enginer,
+    init_wrapper,
+)
 from tqdm import tqdm
 from vllm import LLM, ModelRegistry, SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -15,6 +20,7 @@
 from vllm.inputs.data import TokensPrompt
 from vllm.utils import merge_async_iterators
 from vllm.worker.tt_worker import TTCacheEngine, TTWorker
+from vllm.engine.llm_engine import LLMEngine
 
 ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)
 
@@ -26,6 +32,7 @@
 @patch.object(
     TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache
 )  # Patch to stop allocation on TT device since nonexistent
+@patch.object(LLMEngine, "__init__", new=init_wrapper)
 def run_inference(
     prompts_json,
     max_tokens=128,
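
Note: `init_wrapper` is imported from `mock_vllm_model` and patched over `LLMEngine.__init__` above, but its definition sits outside the hunks shown here. Below is a minimal sketch of what such a constructor wrapper could look like; the body is an illustrative assumption, not the actual implementation in `tests/mock_vllm_model.py`.

from vllm.engine.llm_engine import LLMEngine

# Capture the unpatched constructor at import time, before patch.object()
# swaps it out, so the wrapper can still delegate to the real initialization.
_original_llm_engine_init = LLMEngine.__init__

def init_wrapper(self, *args, **kwargs):
    # Illustrative assumption: run the real LLMEngine.__init__ first, then
    # apply test-only adjustments (e.g. substituting mock components that
    # would otherwise require a TT device).
    _original_llm_engine_init(self, *args, **kwargs)

Because `patch.object` is applied as a decorator, the replacement constructor is only active while `run_inference` executes and the original `LLMEngine.__init__` is restored afterwards.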