From 0f3bf27b71458125198a564867b647a91ff9958d Mon Sep 17 00:00:00 2001
From: Xiyou Zhou <xiyou.zhou@gmail.com>
Date: Wed, 22 Nov 2023 23:03:52 +0000
Subject: [PATCH] Drop logger from API handler, rename token loop variable, update test_logprobs engine args

---
 serve/mlc_serve/api/handler.py                    |  2 --
 serve/mlc_serve/model/paged_cache_model.py        |  5 ++---
 serve/tests/unittest/test_engine_with_samplers.py | 10 +++++-----
 3 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py
index 71d51332b2..63deae4a94 100644
--- a/serve/mlc_serve/api/handler.py
+++ b/serve/mlc_serve/api/handler.py
@@ -40,8 +40,6 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse
 
 
 router = APIRouter()
-import logging
-logger = logging.getLogger(__name__)
 
 def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
     sampling_params = SamplingParams(
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 7d433fb091..b14459733d 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -26,7 +26,6 @@
 )
 
 logger = logging.getLogger(__name__)
-# keep track of the selected token's logprob, a list of top tokens and their logprobs.
 
 class KVCache:
     def __init__(
@@ -657,11 +656,11 @@ def generate(
             return [
                 TextGenerationResult(
                     sequence_id=sequence_id,
-                    generated_tokens=[next_token],
+                    generated_tokens=[new_token],
                     error=None,
                     logprob_info=fetch_logprobs(logprob_info, index, sampling_params[index]),
                 )
-                for index, (sequence_id, next_token) in enumerate(zip(sequence_ids, next_tokens))
+                for index, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens))
             ]
         except RuntimeError:
             # Fallback to per-token sampling in case some logits values are corrupted.
diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py
index cdada9dc1e..694d778f0e 100644
--- a/serve/tests/unittest/test_engine_with_samplers.py
+++ b/serve/tests/unittest/test_engine_with_samplers.py
@@ -184,16 +184,16 @@ def test_stop(
 def test_logprobs(
     model_artifact_path, 
     use_staging_engine, 
-    max_num_batched_tokens=2560, 
-    max_input_len=2560,
+    max_num_sequences=4,
+    max_input_len=512,
     num_requests=5,
     logprobs=3,
 ):
     prompt = "hi"
     engine = create_engine(
-        model_artifact_path, 
-        use_staging_engine, 
-        max_num_batched_tokens, 
+        model_artifact_path,
+        use_staging_engine,
+        max_num_sequences,
         max_input_len,
     )
     s = 113