Skip to content

Commit

Permalink
updated rot_mat -> rot_idxs_tt
Browse files — Browse the repository at this point in the history
  • Loading branch information
mvanniasingheTT committed Nov 18, 2024
1 parent 1557066 commit 2f6bd64
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tests/mock_vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import sys
import time
import datetime
from datetime import datetime
import json
from dataclasses import dataclass
from typing import List
Expand Down Expand Up @@ -183,7 +183,7 @@ def decode_forward_trace(
start_pos: int,
trace_id,
tt_inp,
rot_mat,
rot_idxs_tt,
cache_idxs_tt,
tt_logits,
page_table=None,
Expand Down
9 changes: 8 additions & 1 deletion tests/mock_vllm_offline_inference_tt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from unittest.mock import patch

import uvloop
from mock_vllm_model import MockModel, new_allocate_kv_cache, new_init_cache_enginer
from mock_vllm_model import (
MockModel,
new_allocate_kv_cache,
new_init_cache_enginer,
init_wrapper,
)
from tqdm import tqdm
from vllm import LLM, ModelRegistry, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
Expand All @@ -15,6 +20,7 @@
from vllm.inputs.data import TokensPrompt
from vllm.utils import merge_async_iterators
from vllm.worker.tt_worker import TTCacheEngine, TTWorker
from vllm.engine.llm_engine import LLMEngine

ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)

Expand All @@ -26,6 +32,7 @@
@patch.object(
TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache
) # Patch to stop allocation on TT device since nonexistent
@patch.object(LLMEngine, "__init__", new=init_wrapper)
def run_inference(
prompts_json,
max_tokens=128,
Expand Down

0 comments on commit 2f6bd64

Please sign in to comment.