Squashed commit for logprobs implementation.
Init with tests.
Server working.
Major fix, serve working great.
Minor fix and tests.
Remove extra line.
fix log_softmax
use constant for number of top logprobs
small clean
upstream to new OpenAI API

Co-authored-by: Valery Chernov <[email protected]>
zxybazh and vvchernov committed Dec 19, 2023
1 parent f32375a commit 9b053e8
Showing 12 changed files with 223 additions and 50 deletions.
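
For context, a minimal sketch of the request shape this change targets: the new logprobs and top_logprobs fields follow the updated OpenAI chat completions API. The endpoint URL, port, and model name below are assumptions for illustration only, not taken from this commit.

import requests

# Hypothetical local endpoint and model name; only the logprobs /
# top_logprobs fields are the point of this sketch.
payload = {
    "model": "llama-2-7b-chat-hf",
    "messages": [{"role": "user", "content": "Hello!"}],
    "logprobs": True,   # return per-token log probabilities
    "top_logprobs": 5,  # up to TOP_LOGPROBS_NUMBER alternatives per position
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
for entry in resp.json()["choices"][0]["logprobs"]["content"]:
    print(entry["token"], entry["logprob"])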
36 changes: 28 additions & 8 deletions serve/mlc_serve/api/handler.py
@@ -42,7 +42,6 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse

router = APIRouter()


def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params = SamplingParams(
# These params came from vllm
@@ -60,6 +59,9 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params.temperature = request.temperature
if request.top_p is not None:
sampling_params.top_p = request.top_p
if request.logprobs is not None:
sampling_params.top_logprobs = request.top_logprobs
sampling_params.logprobs = request.logprobs
return sampling_params


@@ -152,7 +154,7 @@ async def generate_completion_stream(
created_time = int(time.time())

def create_stream_response(
choices: list[ChatCompletionResponseStreamChoice],
choices: List[ChatCompletionResponseStreamChoice],
) -> ChatCompletionStreamResponse:
return ChatCompletionStreamResponse(
id=request_id,
@@ -172,7 +174,6 @@ def create_stream_response(
],
)
yield f"data: {json.dumps(first_chunk.dict(exclude_unset=True), ensure_ascii=False)}\n\n"

async for res in result_generator:
if res.error:
raise RuntimeError(f"Error when generating: {res.error}")
@@ -188,6 +189,7 @@ def create_stream_response(
finish_reason=seq.finish_reason.value
if seq.finish_reason is not None
else None,
logprob_info=seq.logprob_info[0] if seq.logprob_info else None
)
for seq in res.sequences
]
@@ -208,13 +210,16 @@ async def collect_result_stream(
finish_reasons = [None] * num_sequences
num_prompt_tokens = 0
num_generated_tokens = [0 for _ in range(num_sequences)]
logprob_infos = [[] for _ in range(num_sequences)]
async for res in result_generator:
# TODO: verify that the request cancellation happens after this returns
if res.error:
raise RuntimeError(f"Error when generating: {res.error}")
if res.num_prompt_tokens is not None:
num_prompt_tokens = res.num_prompt_tokens
for seq in res.sequences:
if seq.logprob_info:
logprob_infos[seq.index].append(seq.logprob_info)
if seq.index >= len(sequences):
raise RuntimeError(f"Unexpected sequence index: {seq.index}.")
num_generated_tokens[seq.index] = seq.num_generated_tokens
@@ -224,15 +229,30 @@ async def collect_result_stream(
else:
assert seq.delta is not None
sequences[seq.index].append(seq.delta)

choices = [
ChatCompletionResponseChoice(

choices = []
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)):
choice = ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content="".join(chunks)),
finish_reason=finish_reason,
)
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons))
]
content = []
if logprob_infos[index] != []:
for logprob_info in logprob_infos[index]:
content.append({
"token": str(logprob_info[0][0]),
"logprob": float(logprob_info[0][1]),
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
"bytes": None,
"top_logprobs": [{
"token": top_logprob[0],
"logprob": top_logprob[1],
"bytes": None,
} for top_logprob in logprob_info[1]],
})
choice.logprobs.content = content
choices.append(choice)

usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
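
Taken together, the handler changes above turn each per-token logprob_info tuple into an OpenAI-style entry under choices[...].logprobs.content. The helper below is an illustrative sketch of that conversion, not part of the commit, and assumes the ((token, logprob), [(top_token, top_logprob), ...]) layout used throughout this change.

from typing import Dict, List, Tuple

def logprob_info_to_content_entry(
    logprob_info: Tuple[Tuple[str, float], List[Tuple[str, float]]],
) -> Dict:
    # Mirrors the loop in collect_result_stream above; illustrative only.
    (token, logprob), top_logprobs = logprob_info
    return {
        "token": str(token),
        "logprob": float(logprob),
        "bytes": None,  # byte representation not implemented yet (see TODO above)
        "top_logprobs": [
            {"token": tok, "logprob": lp, "bytes": None} for tok, lp in top_logprobs
        ],
    }

# Example: one sampled token with two detokenized alternatives.
entry = logprob_info_to_content_entry((("Hello", -0.12), [("Hello", -0.12), ("Hi", -2.31)]))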
10 changes: 9 additions & 1 deletion serve/mlc_serve/api/protocol.py
@@ -2,7 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/vllm-project/vllm/blob/acbed3ef40f015fcf64460e629813922fab90380/vllm/entrypoints/openai/protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union, Tuple

from pydantic import BaseModel, Field

@@ -70,11 +70,18 @@ class ChatCompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
ignore_eos: Optional[bool] = False
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None


class Logprobs(BaseModel):
content: Optional[List[Dict]]


class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Logprobs]
finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None


@@ -95,6 +102,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprob_info: Optional[Tuple[Tuple, List[Tuple]]]
finish_reason: Optional[Literal["stop", "length"]] = None


2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/__init__.py
@@ -17,4 +17,4 @@
PROMPT_SEQEUNCE_INDEX,
get_prompt_sequence_id,
)
from .sampling_params import SamplingParams, SamplingType
from .sampling_params import SamplingParams, SamplingType, TOP_LOGPROBS_NUMBER
5 changes: 3 additions & 2 deletions serve/mlc_serve/engine/async_connector.py
@@ -1,6 +1,7 @@
import asyncio
import structlog
from typing import AsyncIterator, Any
from typing import AsyncIterator, Any, Dict
import logging

from .base import (
InferenceEngine,
@@ -29,7 +30,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1):
self.engine_loop_task = None
self.engine_loop_exception = None
self.shutdown_event = asyncio.Event()
self.result_queues = dict[RequestId, ResultQueue]()
self.result_queues: Dict[RequestId, ResultQueue] = {}

async def start(self):
"""
3 changes: 2 additions & 1 deletion serve/mlc_serve/engine/base.py
@@ -4,7 +4,7 @@
from enum import Enum
from abc import ABC, abstractmethod

from typing import List, Callable, Any, Optional, Dict
from typing import List, Callable, Any, Optional, Dict, Tuple
import inspect

from .sampling_params import SamplingParams, SamplingType
@@ -161,6 +161,7 @@ class SequenceOutput:
finish_reason: Optional[FinishReason] = None
# Number of generated tokens so far
num_generated_tokens: int = 0
logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None

@property
def is_finished(self) -> bool:
6 changes: 5 additions & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -2,13 +2,16 @@
Required interfaces for the actual inference capability in InferenceEngine.
"""
from dataclasses import dataclass
from typing import Optional, Protocol, Union, List
from typing import Optional, Protocol, Union, Tuple, List

from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId
from ..model.base import ModelArtifactConfig
from .sampling_params import SamplingParams


LOGPROBS_TYPE = Tuple[Tuple, List[Tuple]]
# ((token, logprob), [(top1_token, top1_logprob), ...])

@dataclass
class PrefillRequest:
request_id: RequestId
@@ -44,6 +47,7 @@ class TextGenerationResult:
# making this a list of token ids to leave room for speculative decoding
generated_tokens: List[int]
error: Optional[str]
logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None


class KVCache(Protocol):
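
As a concrete illustration of the LOGPROBS_TYPE alias above, a single generated position might carry a value like the following (toy numbers, token ids chosen arbitrarily):

# ((token, logprob), [(top1_token, top1_logprob), ...])
logprob_info = (
    (1534, -0.08),  # sampled token id and its log probability
    [(1534, -0.08), (3059, -2.71), (318, -4.02)],  # top alternatives at this position
)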
16 changes: 16 additions & 0 deletions serve/mlc_serve/engine/sampling_params.py
@@ -7,8 +7,10 @@
from enum import IntEnum
from functools import cached_property

from typing import Optional

_SAMPLING_EPS = 1e-5
TOP_LOGPROBS_NUMBER = 5


class SamplingType(IntEnum):
@@ -37,13 +39,22 @@ class SamplingParams:
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
logprobs: Optional[bool] Whether to return log probabilities of the output
tokens or not. If true, returns the log probabilities of each output
token returned in the content of message.
top_logprobs: Optional[Integer] An integer between 0 and 5 specifying
the number of most likely tokens to return at each token position,
each with an associated log probability. logprobs must be set to
true if this parameter is used.
"""

presence_penalty: float = 0.0
frequency_penalty: float = 0.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None

def __post_init__(self):
self._verify_args()
@@ -71,6 +82,11 @@ def _verify_args(self) -> None:
raise ValueError(
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
)
if self.logprobs is not None and self.logprobs:
if (self.top_logprobs < 0 or self.top_logprobs > TOP_LOGPROBS_NUMBER):
raise ValueError(
f"top_logprobs must be between 0 and {TOP_LOGPROBS_NUMBER}, got {self.top_logprobs}."
)

def _verify_greedy_sampling(self) -> None:
if self.top_p < 1.0 - _SAMPLING_EPS:
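
A minimal sketch of how the new validation is expected to behave; the import path is an assumption based on the file layout in this repository:

from mlc_serve.engine.sampling_params import SamplingParams, TOP_LOGPROBS_NUMBER

# Accepted: top_logprobs within [0, TOP_LOGPROBS_NUMBER] while logprobs is enabled.
params = SamplingParams(logprobs=True, top_logprobs=TOP_LOGPROBS_NUMBER)

try:
    SamplingParams(logprobs=True, top_logprobs=7)  # rejected by _verify_args
except ValueError as err:
    print(err)  # top_logprobs must be between 0 and 5, got 7.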
29 changes: 27 additions & 2 deletions serve/mlc_serve/engine/staging_engine.py
@@ -5,8 +5,8 @@
import multiprocessing
import queue
from threading import Lock
from typing import Callable
from collections import defaultdict
from typing import Callable, Tuple, List

import structlog

@@ -23,7 +23,7 @@
get_new_request_state,
update_sequence,
)
from .model_module import ModelModule, TokenizerModule
from .model_module import ModelModule, TokenizerModule, Tokenizer
from .staging_engine_worker import (
AddRequestsCommand,
CancelRequestCommand,
@@ -37,6 +37,30 @@
LOG = structlog.stdlib.get_logger(__name__)


def logprob_detokenize(tokenizer: Tokenizer, logprob_info: Tuple[Tuple, List[Tuple]]) -> Tuple[Tuple, List[Tuple]]:
"""Detokenize logprob information"""
if logprob_info is None:
return None
(res, res_logprob), top_tokens = logprob_info
top_tokens = list(top_tokens)
count = {}
logprob_dict = {}
# dedup duplicates
# Todo: Make sure decode can generate different tokens
for top_token, _ in top_tokens:
detokenized = tokenizer.decode(top_token)
if detokenized in count:
count[detokenized] += 1
else:
count[detokenized] = 1
for top_token, top_logprob in top_tokens:
detokenized = tokenizer.decode(top_token)
if count[detokenized] == 1:
logprob_dict[detokenized] = float(top_logprob)
else:
logprob_dict[f"{detokenized}_{top_token}"] = float(top_logprob)
return (str(tokenizer.decode(res)), res_logprob), logprob_dict

class StagingInferenceEngine(ScopedInferenceEngine):
"""
An implementation of InferenceEngine that offloads the text generation loop to another worker process,
@@ -235,6 +259,7 @@ def step(self) -> InferenceStepResult:
delta,
finish_reason=seq_output.finish_reason,
num_generated_tokens=len(gen_seq.generated_token_ids),
logprob_info=logprob_detokenize(self.tokenizer, seq_output.logprob_info),
)

seq_outputs[request_id].append(output)
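
The dedup logic in logprob_detokenize above disambiguates top tokens that detokenize to the same string by appending the token id to the key. A toy sketch of that behavior, with a hypothetical tokenizer and an import path assumed from the repository layout:

from mlc_serve.engine.staging_engine import logprob_detokenize

class ToyTokenizer:
    # Hypothetical tokenizer: ids 7 and 8 both decode to the same string "he".
    _vocab = {5: "hello", 7: "he", 8: "he"}

    def decode(self, token_id):
        return self._vocab[token_id]

info = ((5, -0.1), [(5, -0.1), (7, -1.5), (8, -2.0)])
print(logprob_detokenize(ToyTokenizer(), info))
# -> (('hello', -0.1), {'hello': -0.1, 'he_7': -1.5, 'he_8': -2.0})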
21 changes: 11 additions & 10 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -5,10 +5,10 @@
import multiprocessing
import multiprocessing.synchronize
from dataclasses import dataclass
from threading import Thread
from typing import Callable, Optional, Union, Any, Dict, List

from threading import Condition, Lock, Thread
from typing import Callable, Optional, Union, Tuple, Any, Dict, Deque, List
import structlog
import numpy as np

from .base import FinishReason, RequestId, RequestState, ValidationError, SequenceId
from .metrics import PrometheusMetrics
@@ -33,7 +33,7 @@ class ShutdownCommand:

@dataclass
class AddRequestsCommand:
request_states: list[RequestState]
request_states: List[RequestState]


@dataclass
@@ -54,14 +54,15 @@ class StopRequestCommand:
@dataclass
class SequenceGenerationOutput:
id: SequenceId
new_tokens: list[int]
new_tokens: List[int]
finish_reason: Optional[FinishReason] = None
error: Optional[Union[str, ValidationError]] = None
logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None


@dataclass
class GenerationLoopWorkerOutput:
sequences: list[SequenceGenerationOutput]
sequences: List[SequenceGenerationOutput]
error: Optional[str] = None


@@ -77,8 +78,8 @@ def __init__(
):
EngineBase.__init__(self, model_module)

self.cancelled_requests = list[RequestState]()
self.stopped_requests = list[RequestState]()
self.cancelled_requests: List[RequestState] = []
self.stopped_requests: List[RequestState] = []

self.prom_metrics = PrometheusMetrics()
self.inv_kv_cache_size = 1.0 / self.cache_manager.get_kv_cache_size()
@@ -167,7 +168,7 @@ def has_pending_requests(self) -> bool:
def step(self) -> GenerationLoopWorkerOutput:
LOG.debug("Starting new inference step.")

outputs = list[SequenceGenerationOutput]()
outputs: List[SequenceGenerationOutput] = []
result = GenerationLoopWorkerOutput(sequences=outputs)

# TODO: consolidate into a single function
@@ -263,7 +264,7 @@ def step(self) -> GenerationLoopWorkerOutput:

gen_seq.generated_token_ids.extend(new_tokens)
outputs.append(
SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens)
SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens, logprob_info=res.logprob_info)
)

if is_prompt_batch:
1 change: 1 addition & 0 deletions serve/mlc_serve/engine/sync_engine.py
@@ -218,6 +218,7 @@ def step(self) -> InferenceStepResult:
delta,
num_generated_tokens=len(gen_seq.generated_token_ids),
finish_reason=finish_reason,
logprob_info=res.logprob_info,
)
)
