Enable Logprobs in MLC Batch Serving (#82)
* Squashed commit for logprobs implementation.

Co-authored-by: Valery Chernov <darkscorpion@rambler.ru>
Co-authored-by: Ilya Kozulin <Ailurus1@users.noreply.github.com>

* fix None check

* Change detokenization to use token ids.

* Fix wrong usage of token ids. Remove logging.

* extend benchmarks for logprobs

* fix test without logprobs

* clean code

* black format engine_common.py

* logprobs is strictly bool, top_logprobs is int

* refactor logprob info collection to not reduce performance

* quick fix for check

* review fix

* fix list index out of range

* rollback after rebase

* test

* small fix

* rename for the sake of clarity

* some fixes with cpu-gpu tensor copying

* refactor logprob pass to calculate

* remove excess deps for token detokenization

* small clean

* small clean

* return None instead of list of Nones

* fix mypy

---------

Co-authored-by: Valery Chernov <darkscorpion@rambler.ru>
Co-authored-by: Ilya Kozulin <Ailurus1@users.noreply.github.com>
Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
4 people authored Jan 31, 2024
1 parent 4535ff5 commit 2b3fcf0
Showing 22 changed files with 376 additions and 53 deletions.
2 changes: 2 additions & 0 deletions serve/benchmarks/benchmark_latency.py
@@ -34,6 +34,8 @@ def create_request(request_id):
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
logprobs = args.sampling_setting["logprobs"],
top_logprobs = args.sampling_setting["top_logprobs"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
2 changes: 2 additions & 0 deletions serve/benchmarks/benchmark_throughput.py
@@ -139,6 +139,8 @@ def run_mlc(engine, requests, args) -> float:
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
logprobs = args.sampling_setting["logprobs"],
top_logprobs = args.sampling_setting["top_logprobs"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
18 changes: 18 additions & 0 deletions serve/benchmarks/utils.py
@@ -22,6 +22,18 @@ def add_sampling_flags(parser):
action="store_true",
help="Apply all penalties, logit bias, top-p and top-k.",
)
parser.add_argument(
"--logprobs",
action="store_true",
default=False,
help="Switch on logprobs output"
)
parser.add_argument(
"--top-logprobs",
type=int,
default=5,
help="Number of top logprobs to output, limited by 5. Works only with logprobs true."
)


def postproc_sampling_args(args):
@@ -33,6 +45,8 @@ def postproc_sampling_args(args):
"repetition_penalty": 1.0,
"top_p": 1.0,
"top_k": -1,
"logprobs": False,
"top_logprobs": 5,
}

if args.apply_all_sampling_params:
@@ -51,3 +65,7 @@ def postproc_sampling_args(args):
if args.apply_top_p_top_k:
args.sampling_setting["top_k"] = 2
args.sampling_setting["top_p"] = 0.7

if args.logprobs:
args.sampling_setting["logprobs"] = True
args.sampling_setting["top_logprobs"] = args.top_logprobs
27 changes: 21 additions & 6 deletions serve/mlc_serve/api/handler.py
@@ -9,7 +9,7 @@
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse

# TODO(amalyshe): hadnle random_seed
# TODO(amalyshe): handle random_seed
# from .base import set_global_random_seed
from ..api.protocol import (
ChatCompletionRequest,
@@ -20,6 +20,7 @@
ChatMessage,
DeltaMessage,
ErrorResponse,
Logprobs,
UsageInfo,
)
from ..engine import (
@@ -64,6 +65,9 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params.top_p = request.top_p
if request.logit_bias is not None:
sampling_params.logit_bias = request.logit_bias
if request.logprobs:
sampling_params.top_logprobs = request.top_logprobs
sampling_params.logprobs = request.logprobs
return sampling_params


@@ -156,7 +160,7 @@ async def generate_completion_stream(
created_time = int(time.time())

def create_stream_response(
choices: list[ChatCompletionResponseStreamChoice],
choices: List[ChatCompletionResponseStreamChoice],
) -> ChatCompletionStreamResponse:
return ChatCompletionStreamResponse(
id=request_id,
@@ -192,6 +196,7 @@ def create_stream_response(
finish_reason=seq.finish_reason.value
if seq.finish_reason is not None
else None,
logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info != [] else None
)
for seq in res.sequences
]
@@ -212,6 +217,7 @@ async def collect_result_stream(
finish_reasons = [None] * num_sequences
num_prompt_tokens = 0
num_generated_tokens = [0 for _ in range(num_sequences)]
logprob_infos = [[] for _ in range(num_sequences)] # type: ignore
async for res in result_generator:
# TODO: verify that the request cancellation happens after this returns
if res.error:
@@ -226,18 +232,27 @@ async def collect_result_stream(
if seq.delta:
sequences[seq.index].append(seq.delta)

if seq.logprob_info:
assert seq.delta
logprob_infos[seq.index].extend(seq.logprob_info)

if seq.is_finished:
assert seq.finish_reason is not None
finish_reasons[seq.index] = seq.finish_reason.value # type: ignore

choices = [
ChatCompletionResponseChoice(
choices = []
for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)):
logprobs = None
if logprob_info_seq != []:
logprobs = Logprobs(content=logprob_info_seq)

choice = ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content="".join(chunks)),
finish_reason=finish_reason,
logprobs=logprobs,
)
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons))
]
choices.append(choice)

usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
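
With the handler wired up, a client enables logprobs per request through the OpenAI-compatible chat completion fields. A hedged sketch of such a call follows; the endpoint URL, port, model name, and max_tokens follow the usual OpenAI schema and are placeholders not taken from this commit, while logprobs and top_logprobs are the fields the new code actually reads.

import requests

payload = {
    "model": "placeholder-model",
    "messages": [{"role": "user", "content": "Say hi"}],
    "max_tokens": 8,
    "logprobs": True,
    "top_logprobs": 3,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
choice = resp.json()["choices"][0]
# When logprobs is enabled, each choice carries a logprobs.content list with
# one entry per generated token, including its top_logprobs candidates.
print(choice["logprobs"]["content"][0]["top_logprobs"])
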
6 changes: 6 additions & 0 deletions serve/mlc_serve/api/protocol.py
@@ -6,6 +6,8 @@

from pydantic import BaseModel, Field

from ..openai_logprob_protocol import Logprobs


class ErrorResponse(BaseModel):
object: str = "error"
@@ -71,11 +73,14 @@ class ChatCompletionRequest(BaseModel):
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
ignore_eos: Optional[bool] = False
logprobs: bool = False
top_logprobs: int = 0


class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Logprobs] = None
finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None


@@ -96,6 +101,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[Logprobs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


4 changes: 3 additions & 1 deletion serve/mlc_serve/engine/__init__.py
@@ -16,5 +16,7 @@
RequestState,
PROMPT_SEQEUNCE_INDEX,
get_prompt_sequence_id,
RawLogprobsInfo,
RawLogprobsInfos,
)
from .sampling_params import SamplingParams, SamplingType
from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
5 changes: 2 additions & 3 deletions serve/mlc_serve/engine/async_connector.py
@@ -1,7 +1,6 @@
import asyncio
import structlog
from typing import AsyncIterator, Any
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterator, Dict
from collections import deque

from .base import (
@@ -26,7 +25,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1):
self.engine_loop_task = None
self.engine_loop_exception = None
self.shutdown_event = asyncio.Event()
self.result_queues = dict[RequestId, ResultQueue]()
self.result_queues: Dict[RequestId, ResultQueue] = {}
self.recent_cancelled_requests = deque[RequestId](maxlen=64)

async def start(self):
15 changes: 14 additions & 1 deletion serve/mlc_serve/engine/base.py
@@ -6,13 +6,25 @@

from typing import List, Callable, Any, Optional, Dict
import inspect
import numpy as np

from .sampling_params import SamplingParams, SamplingType
from ..openai_logprob_protocol import LogprobsContent

LOG = structlog.stdlib.get_logger(__name__)
RequestId = str


@dataclass
class RawLogprobsInfo:
current_token_id: int
current_logprob: float
top_token_ids: Optional[np.array]
top_logprobs: Optional[np.array]

RawLogprobsInfos = List[Optional[RawLogprobsInfo]]


# TODO(@sunggg): consider transition to something like Pydantic
@dataclass
class MLCServeEngineConfig:
@@ -155,6 +167,7 @@ class SequenceOutput:
finish_reason: Optional[FinishReason] = None
# Number of generated tokens so far
num_generated_tokens: int = 0
logprob_info: List[Optional[LogprobsContent]] = field(default_factory=list)

@property
def is_finished(self) -> bool:
@@ -164,7 +177,7 @@ def is_finished(self) -> bool:
@dataclass
class RequestOutput:
request_id: RequestId
sequences: list[SequenceOutput]
sequences: List[SequenceOutput]
# TODO: reconsider the place to put this number
# Only set for outputs with valid sequence outputs
num_prompt_tokens: Optional[int] = None
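
RawLogprobsInfo is the raw per-token record the sampler produces before detokenization; top_token_ids and top_logprobs stay None when the request asked for top_logprobs=0. A construction sketch with made-up values, assuming the serve package is importable:

import numpy as np
from mlc_serve.engine.base import RawLogprobsInfo

# One sampled token with two hypothetical top candidates.
info = RawLogprobsInfo(
    current_token_id=42,
    current_logprob=-0.31,
    top_token_ids=np.array([42, 7]),
    top_logprobs=np.array([-0.31, -1.8]),
)
print(info.current_token_id, info.top_token_ids)
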
6 changes: 3 additions & 3 deletions serve/mlc_serve/engine/dummy.py
@@ -12,9 +12,9 @@


class DummyInferenceEngine:
def __init__(self):
self.queue_lock = Lock()
self.has_new_requests = Condition(self.queue_lock)
def __init__(self) -> None:
self.queue_lock: Lock = Lock()
self.has_new_requests: Condition = Condition(self.queue_lock)
self.request_queue: Dict[RequestId, int] = {}

def add(self, requests: list[Request]):
53 changes: 51 additions & 2 deletions serve/mlc_serve/engine/engine_common.py
@@ -3,17 +3,19 @@
"""

import time
from typing import Tuple, Deque, Dict, Optional, Union, Callable
from typing import Tuple, Deque, Dict, Optional, Union, Callable, List
from collections import deque
from threading import Condition, Lock

import structlog

from .base import (
GenerationSequence,
RawLogprobsInfo,
RawLogprobsInfos,
Request,
RequestId,
RequestState,
GenerationSequence,
SequenceId,
StoppingCriteria,
)
@@ -27,6 +29,7 @@
Tokenizer as TokenizerP,
)
from ..model.base import ModelArtifactConfig
from ..openai_logprob_protocol import LogprobsContent, TopLogprobs

LOG = structlog.stdlib.get_logger(__name__)

@@ -135,6 +138,52 @@ def detokenize_incrementally(
return delta


def logprob_detokenize(
tokenizer: TokenizerP,
logprob_info: Optional[RawLogprobsInfo],
) -> Optional[LogprobsContent]:
"""Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent"""
if logprob_info is None:
return None

top_logprobs: List[TopLogprobs] = []
if logprob_info.top_token_ids is not None and logprob_info.top_logprobs is not None:
top_tokens = list(zip(logprob_info.top_token_ids, logprob_info.top_logprobs))
for top_token_id, top_logprob in top_tokens:
top_logprobs.append(
TopLogprobs(
token=tokenizer.decode(top_token_id),
logprob=float(top_logprob),
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
)
)

logprobs_content = LogprobsContent(
token=tokenizer.decode([logprob_info.current_token_id]),
logprob=logprob_info.current_logprob,
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
top_logprobs=top_logprobs,
)

return logprobs_content


def logprobs_detokenize(
tokenizer: TokenizerP,
logprob_info: Optional[RawLogprobsInfos],
) -> List[Optional[LogprobsContent]]:
if logprob_info is None:
return []

res: List[Optional[LogprobsContent]] = []
for info in logprob_info:
res.append(logprob_detokenize(tokenizer, info))

return res


def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
for t in stopping_criteria.stop_sequences:
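
logprob_detokenize only needs a decode() method on the tokenizer, so its mapping from RawLogprobsInfo to LogprobsContent can be exercised with a stub. A sketch under that assumption; the token ids, logprob values, and the stub tokenizer are made up, and the serve package is assumed importable.

import numpy as np
from mlc_serve.engine.base import RawLogprobsInfo
from mlc_serve.engine.engine_common import logprob_detokenize

class StubTokenizer:
    # decode() is called both with a single top token id and with a list
    # containing the current token id, so handle both shapes.
    def decode(self, ids):
        if isinstance(ids, (list, np.ndarray)):
            return "".join(f"<{int(i)}>" for i in ids)
        return f"<{int(ids)}>"

info = RawLogprobsInfo(
    current_token_id=5,
    current_logprob=-0.2,
    top_token_ids=np.array([5, 9]),
    top_logprobs=np.array([-0.2, -2.3]),
)
content = logprob_detokenize(StubTokenizer(), info)
print(content.token, [(t.token, t.logprob) for t in content.top_logprobs])
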
10 changes: 9 additions & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -4,7 +4,14 @@
from dataclasses import dataclass
from typing import Optional, Protocol, Union, List, Sequence

from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId
from .base import (
ChatMessage,
MLCServeEngineConfig,
RawLogprobsInfos,
RequestId,
RequestState,
SequenceId,
)
from ..model.base import ModelArtifactConfig
from .sampling_params import SamplingParams

@@ -44,6 +51,7 @@ class TextGenerationResult:
# making this a list of token ids to leave room for speculative decoding
generated_tokens: List[int]
error: Optional[str]
logprob_info: Optional[RawLogprobsInfos]


class KVCache(Protocol):
16 changes: 15 additions & 1 deletion serve/mlc_serve/engine/sampling_params.py
@@ -9,8 +9,8 @@
from functools import cached_property
from typing import Dict, Optional


_SAMPLING_EPS = 1e-5
LOGPROB_TOP_K_MAX = 5


class SamplingType(IntEnum):
@@ -46,6 +46,13 @@ class SamplingParams:
to -1 to consider all tokens.
logit_bias: The bias applied on the logit before sampling. Must be in
[-100, 100].
logprobs: Optional[bool] Whether to return log probabilities of the output
tokens or not. If true, returns the log probabilities of each output
token returned in the content of message.
top_logprobs: Optional[Integer] An integer between 0 and 5 specifying
the number of most likely tokens to return at each token position,
each with an associated log probability. logprobs must be set to
true if this parameter is used.
"""

presence_penalty: float = 0.0
@@ -58,6 +65,8 @@ class SamplingParams:
appeared_tokens_freq: Dict[int, int] = None
logit_bias_index: list[int] = None
logit_bias_value: list[float] = None
logprobs: bool = False
top_logprobs: int = 0

def __post_init__(self):
self.appeared_tokens_freq = {}
@@ -95,6 +104,11 @@ def _verify_args(self) -> None:
raise ValueError(
f"logit bias must be in [-100, 100], got {bias} for token {token}."
)
if self.logprobs:
if (self.top_logprobs < 0 or self.top_logprobs > LOGPROB_TOP_K_MAX):
raise ValueError(
f"top_logprobs must be between 0 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}."
)

def _verify_greedy_sampling(self) -> None:
if self.top_p < 1.0 - _SAMPLING_EPS:
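
The new SamplingParams fields default to logprobs=False and top_logprobs=0, and the check added to _verify_args rejects top_logprobs outside 0..LOGPROB_TOP_K_MAX whenever logprobs is set. A small sketch of a valid configuration, assuming the serve package is importable and using arbitrary values:

from mlc_serve.engine.sampling_params import LOGPROB_TOP_K_MAX, SamplingParams

params = SamplingParams(logprobs=True, top_logprobs=3)
assert 0 <= params.top_logprobs <= LOGPROB_TOP_K_MAX
# Requesting more than LOGPROB_TOP_K_MAX (5) candidates with logprobs enabled
# is expected to fail validation with a ValueError.
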
4 changes: 3 additions & 1 deletion serve/mlc_serve/engine/staging_engine.py
@@ -5,8 +5,8 @@
import multiprocessing
import queue
from threading import Lock
from typing import Callable
from collections import defaultdict
from typing import Callable

import structlog

@@ -24,6 +24,7 @@
from .engine_common import (
get_new_request_state,
update_sequence,
logprobs_detokenize
)
from .model_module import ModelModule, TokenizerModule
from .staging_engine_worker import (
@@ -251,6 +252,7 @@ def step(self) -> InferenceStepResult:
delta,
finish_reason,
num_generated_tokens=len(gen_seq.generated_token_ids),
logprob_info=logprobs_detokenize(self.tokenizer, seq_output.logprob_info),
)

seq_outputs[request_id].append(output)
12 changes: 8 additions & 4 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -4,20 +4,22 @@
import time
import multiprocessing
import multiprocessing.synchronize
from dataclasses import dataclass
from dataclasses import dataclass, field
from threading import Thread, Lock
from typing import Callable, Optional, Union, Any, Dict, List

import structlog

from .base import (
FinishReason,
RawLogprobsInfos,
RequestId,
RequestState,
ValidationError,
SequenceId,
GenerationSequence,
)

from .metrics import PrometheusMetrics
from .metrics_labels import *
from .model_module import (
@@ -40,7 +42,7 @@ class ShutdownCommand:

@dataclass
class AddRequestsCommand:
request_states: list[RequestState]
request_states: List[RequestState]


@dataclass
@@ -61,14 +63,15 @@ class StopSequenceCommand:
@dataclass
class SequenceGenerationOutput:
id: SequenceId
new_tokens: list[int]
new_tokens: List[int]
finish_reason: Optional[FinishReason] = None
error: Optional[Union[str, ValidationError]] = None
logprob_info: Optional[RawLogprobsInfos] = None


@dataclass
class GenerationLoopWorkerOutput:
sequences: list[SequenceGenerationOutput]
sequences: List[SequenceGenerationOutput]
error: Optional[BaseException] = None


@@ -288,6 +291,7 @@ def step(self) -> GenerationLoopWorkerOutput:
id=res.sequence_id,
new_tokens=new_tokens,
finish_reason=finish_reason,
logprob_info=res.logprob_info,
)
)

2 changes: 2 additions & 0 deletions serve/mlc_serve/engine/sync_engine.py
@@ -22,6 +22,7 @@
get_requests_to_process,
update_sequence,
EngineBase,
logprobs_detokenize
)
from .model_module import (
ModelModule,
@@ -222,6 +223,7 @@ def step(self) -> InferenceStepResult:
delta,
num_generated_tokens=len(gen_seq.generated_token_ids),
finish_reason=finish_reason,
logprob_info=logprobs_detokenize(self.tokenizer, res.logprob_info),
)
)

1 change: 1 addition & 0 deletions serve/mlc_serve/model/dummy_model.py
@@ -123,6 +123,7 @@ def generate(
generated_tokens=[req.token_ids[-1] + 1],
# generated_tokens=[1],
error=None,
logprob_info=None,
)
)
return result
115 changes: 110 additions & 5 deletions serve/mlc_serve/model/model_common.py
@@ -1,4 +1,4 @@
from typing import List, Union, Optional
from typing import List, Optional, Tuple, Union

import structlog
import numpy as np
@@ -9,6 +9,9 @@
from ..engine import (
SamplingType,
SamplingParams,
LOGPROB_TOP_K_MAX,
RawLogprobsInfo,
RawLogprobsInfos,
)

LOG = structlog.stdlib.get_logger(__name__)
@@ -36,6 +39,86 @@ def get_num_cache_blocks(
)


def get_raw_logprob_info(
logits,
token_id,
top_logprobs_num,
) -> RawLogprobsInfo:
logprobs = torch.log_softmax(logits, dim=-1)
res_logprob = logprobs[token_id]

if top_logprobs_num == 0:
top_logprobs = None
top_tokens = None
else:
assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
top_logprobs, top_tokens = torch.topk(
logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
)
top_tokens=top_tokens.cpu().numpy()
top_logprobs=top_logprobs.cpu().numpy()

# Set to raw logprob info
return RawLogprobsInfo(
current_token_id=token_id,
current_logprob=res_logprob,
top_token_ids=top_tokens,
top_logprobs=top_logprobs,
)


def get_logprob_indices(
sampling_params: List[SamplingParams],
num_seq: int,
) -> Tuple[List[Tuple[int, int, int]], List[Tuple[int, int, int]]]:
lgp_inds_greedy: List[Tuple[int, int, int]] = []
lgp_inds_random: List[Tuple[int, int, int]] = []

g_ind = 0
r_ind = 0
for i in range(num_seq):
sampling_param = sampling_params[i]
if sampling_param.sampling_type == SamplingType.RANDOM:
if sampling_param.logprobs:
lgp_inds_random.append((i, r_ind, sampling_param.top_logprobs))
r_ind = r_ind + 1
else:
if sampling_param.logprobs:
lgp_inds_greedy.append((i, g_ind, sampling_param.top_logprobs))
g_ind = g_ind + 1

return lgp_inds_greedy, lgp_inds_random


def get_raw_logprob_infos(
logprob_infos: RawLogprobsInfos,
indices: List[Tuple[int, int, int]],
logits: torch.Tensor,
token_ids: torch.Tensor,
) -> RawLogprobsInfos:
for (i, ind, top_logprobs) in indices:
logprob_infos[i] = get_raw_logprob_info(
logits[ind],
token_ids[ind],
top_logprobs,
)

return logprob_infos


def check_logprob_infos(
logprob_infos: RawLogprobsInfos,
) -> Optional[RawLogprobsInfos]:
check = False
for info in logprob_infos:
if info is not None:
check = True
break
if check:
return logprob_infos
return None


def _apply_top_p_top_k(logits, top_ps, top_ks):
p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
@@ -64,7 +147,7 @@ def sample(
sampling_params: List[SamplingParams],
vocab_size: int,
check_safety=False,
) -> Optional[np.ndarray]:
) -> Optional[Tuple[np.ndarray, Optional[RawLogprobsInfos]]]:
def _is_safe_to_sample(prob_like):
return (
torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0))
@@ -89,12 +172,26 @@ def _is_safe_to_sample(prob_like):

logits_greedy = logits[mask_greedy_dvc]

logprob_infos: RawLogprobsInfos = [None] * num_seq
lgp_inds_greedy, lgp_inds_random = get_logprob_indices(
sampling_params,
num_seq,
)

if logits_greedy.shape[0] > 0:
res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy()

logprob_infos = get_raw_logprob_infos(
logprob_infos,
lgp_inds_greedy,
logits_greedy,
res_greedy,
)

# Case when there's only greedy sampling
if logits_greedy.shape[0] == num_seq:
torch.cuda.nvtx.range_pop()
return res_greedy
return res_greedy, check_logprob_infos(logprob_infos)

temperatures = []
top_ps = []
@@ -163,9 +260,17 @@ def _is_safe_to_sample(prob_like):

res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy()

logprob_infos = get_raw_logprob_infos(
logprob_infos,
lgp_inds_random,
logits_random,
res_random,
)

# Case when there's only random sampling
if logits_random.shape[0] == num_seq:
torch.cuda.nvtx.range_pop()
return res_random
return res_random, check_logprob_infos(logprob_infos)

res = np.empty((num_seq,), dtype=np.int32)
res[mask_random_cpu] = res_random
@@ -174,7 +279,7 @@ def _is_safe_to_sample(prob_like):
res[mask_greedy_cpu] = res_greedy

torch.cuda.nvtx.range_pop()
return res
return res, check_logprob_infos(logprob_infos)


def prepare_inputs(
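
The core of get_raw_logprob_info above is a log-softmax over the vocabulary followed by a top-k selection. A standalone torch sketch of that math on a toy six-token vocabulary, with all values made up:

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0, 0.0, 3.0])  # toy vocab of size 6
token_id = 5              # the token that was actually sampled
top_logprobs_num = 3      # must not exceed LOGPROB_TOP_K_MAX (5)

logprobs = torch.log_softmax(logits, dim=-1)
current_logprob = logprobs[token_id]
top_logprobs, top_tokens = torch.topk(
    logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
)
print(float(current_logprob), top_tokens.tolist(), top_logprobs.tolist())
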
10 changes: 5 additions & 5 deletions serve/mlc_serve/model/paged_cache_manager.py
@@ -1,6 +1,6 @@
import math
from collections import defaultdict
from typing import List, Optional
from typing import Any, List, Optional

from ..engine import (
RequestId,
@@ -104,16 +104,16 @@ def replace_head_prompt_block_with(self, new_block):
class KVCacheInfo:
def __init__(
self,
block_size,
block_size: int
):
self.block_size = block_size

# SequenceId -> list[int]
self.prompt_block_tables = defaultdict(list)
self.slot_mappings = defaultdict(list)
self.prompt_block_tables = defaultdict(list) # type: ignore
self.slot_mappings = defaultdict(list) # type: ignore

# The core data structure
self.decode_block_tables = dict[SequenceId, DecodeBlockTable]()
self.decode_block_tables: dict = dict[SequenceId, DecodeBlockTable]()

# Record indices of blocks to copy after prefill in the format [src1, dst1, src2, dst2, ...]
self.pending_copy_from_to: list[int] = []
8 changes: 4 additions & 4 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -1,6 +1,6 @@
from typing import Union
from pathlib import Path
import structlog
from typing import List, Union

from .base import get_model_artifact_config
from .paged_cache_manager import CacheManager
@@ -10,11 +10,11 @@
from ..engine import MLCServeEngineConfig
from ..engine.model_module import (
DecodeRequest,
ModelModule,
PrefillRequest,
TextGenerationResult,
TextGenerator,
)
from ..engine.model_module import ModelModule

LOG = structlog.stdlib.get_logger(__name__)

@@ -24,8 +24,8 @@ def __init__(self, model: TextGenerator):
self.model = model

def generate(
self, requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache
) -> list[TextGenerationResult]:
self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache
) -> List[TextGenerationResult]:
prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)]
decode_requests = [r for r in requests if isinstance(r, DecodeRequest)]

42 changes: 27 additions & 15 deletions serve/mlc_serve/model/tvm_model.py
@@ -1,6 +1,6 @@
import math
import os
from typing import List, Union, Tuple, Sequence
from typing import List, Optional, Union, Tuple, Sequence

import structlog
import numpy as np
@@ -18,8 +18,9 @@
)

from ..engine import (
SequenceId,
PROMPT_SEQEUNCE_INDEX,
RawLogprobsInfos,
SequenceId,
get_prompt_sequence_id,
MLCServeEngineConfig,
)
@@ -203,6 +204,16 @@ def profile_memory_usage(self, seq_lens):

return self.get_used_memory()

def get_logprob_infos(
self,
i: int,
logprob_infos: Optional[RawLogprobsInfos],
) -> Optional[RawLogprobsInfos]:
if logprob_infos is None or logprob_infos[i] is None:
return None
return [logprob_infos[i]]


def generate(
self,
requests: Sequence[Union[PrefillRequest, DecodeRequest]],
@@ -282,13 +293,6 @@ def generate(
slot_mapping,
self.params,
)

if self.disco_session:
logits, _ = out.debug_get_from_remote(0)
else:
logits = out[
0
] # Ignore returned KV cache since it is updated in-place anyway.
else:
torch.cuda.nvtx.range_push(f"forward decode {input_shape}")

@@ -305,10 +309,12 @@ def generate(
self.params,
)

if self.disco_session:
logits, _ = out.debug_get_from_remote(0)
else:
logits = out[0]
if self.disco_session:
logits, _ = out.debug_get_from_remote(0)
else:
logits = out[
0
] # Ignore returned KV cache since it is updated in-place anyway.

torch.cuda.synchronize()
torch.cuda.nvtx.range_pop()
@@ -330,7 +336,7 @@ def generate(
cache.pending_copy_from_to = []

try:
next_tokens = sample(logits, sampling_params, self.vocab_size)
next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size)
assert next_tokens is not None
outputs = []
for i, (sequence_id, new_token) in enumerate(
@@ -346,6 +352,7 @@ def generate(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=[new_token],
error=None,
logprob_info=self.get_logprob_infos(i, logprob_infos),
)
)
else:
@@ -354,6 +361,7 @@ def generate(
sequence_id=sequence_id,
generated_tokens=[new_token],
error=None,
logprob_info=self.get_logprob_infos(i, logprob_infos),
)
)

@@ -369,7 +377,7 @@ def generate(
for i, (sequence_id, logits_per_token, sampling_param) in enumerate(
zip(sequence_ids, torch.from_dlpack(logits), sampling_params)
):
maybe_new_token = sample(
maybe_new_token, logprob_infos = sample(
torch.unsqueeze(logits_per_token, 0),
[sampling_param],
self.vocab_size,
@@ -393,6 +401,7 @@ def generate(
),
generated_tokens=[new_token], # type: ignore
error=None,
logprob_info=self.get_logprob_infos(0, logprob_infos),
)
)
else:
@@ -401,6 +410,7 @@ def generate(
sequence_id=sequence_id,
generated_tokens=[new_token], # type: ignore
error=None,
logprob_info=self.get_logprob_infos(0, logprob_infos),
)
)
else:
@@ -413,6 +423,7 @@ def generate(
),
generated_tokens=[],
error=err_msg,
logprob_info=self.get_logprob_infos(0, logprob_infos),
)
)
else:
@@ -421,6 +432,7 @@ def generate(
sequence_id=sequence_id,
generated_tokens=[],
error=err_msg,
logprob_info=self.get_logprob_infos(0, logprob_infos),
)
)

28 changes: 28 additions & 0 deletions serve/mlc_serve/openai_logprob_protocol.py
@@ -0,0 +1,28 @@
from typing import List, Optional

from pydantic import BaseModel

class TopLogprobs(BaseModel):
"""An OpenAI API compatible schema for logprobs output."""

token: str
logprob: float
bytes: Optional[List] = None


class LogprobsContent(BaseModel):
"""An OpenAI API compatible schema for logprobs output."""

token: str
logprob: float
bytes: Optional[List] = None
top_logprobs: List[TopLogprobs] # It can be empty


class Logprobs(BaseModel):
"""
An OpenAI API compatible schema for logprobs output.
See details in https://platform.openai.com/docs/api-reference/chat/object#chat-create-logprobs
"""

content: List[LogprobsContent]
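
The three models nest as Logprobs -> LogprobsContent -> TopLogprobs, matching the OpenAI chat response shape. A quick construction sketch with made-up tokens and values; bytes is left at its None default, as in the handler, and .json() is the pydantic v1 spelling (use model_dump_json() on v2):

from mlc_serve.openai_logprob_protocol import Logprobs, LogprobsContent, TopLogprobs

entry = LogprobsContent(
    token="hello",
    logprob=-0.12,
    top_logprobs=[
        TopLogprobs(token="hello", logprob=-0.12),
        TopLogprobs(token="hi", logprob=-2.4),
    ],
)
print(Logprobs(content=[entry]).json())
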
43 changes: 42 additions & 1 deletion serve/tests/unittest/test_engine_with_samplers.py
@@ -48,7 +48,7 @@ def create_engine(


def create_request(
idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, logit_bias=None
idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, top_logprobs=0, logprobs=False, logit_bias=None
):
return Request(
request_id=str(idx),
@@ -58,6 +58,8 @@ def create_request(
frequency_penalty=freq_pen,
presence_penalty=pre_pen,
logit_bias=logit_bias,
logprobs=logprobs,
top_logprobs=top_logprobs,
),
stopping_criteria=StoppingCriteria(max_tokens=max_tokens, stop_sequences=stop),
debug_options=DebugOptions(ignore_eos=ignore_eos),
@@ -337,6 +339,43 @@ def _test_penalty(
if use_staging_engine:
engine.stop()

def _test_logprobs(
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
max_input_len=512,
num_requests=5,
top_logprobs=3,
):
prompt = "hi"
engine = create_engine(
model_artifact_path,
use_staging_engine,
max_num_sequences,
max_input_len,
)
s = 113
requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, top_logprobs=top_logprobs, logprobs=True) for n in range(s, s+num_requests)]
engine.add(requests)

generated = ["" for _ in range(num_requests)]

while engine.has_pending_requests():
results = engine.step()
for res in results.outputs:
assert len(res.sequences) == 1
seq = res.sequences[0]

assert seq.finish_reason is not None or len(list(seq.logprobs.content[0]["top_logprobs"])) == top_logprobs

if seq.is_finished:
assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
assert seq.finish_reason == FinishReason.Length
else:
generated[int(res.request_id)] += seq.delta

if use_staging_engine:
engine.stop()

if __name__ == "__main__":
parser = get_default_mlc_serve_argparser("test engine with samplers")
@@ -349,6 +388,8 @@ def _test_penalty(
_test_ignore_eos(args.model_artifact_path, use_staging_engine=False)
_test_stop(args.model_artifact_path, use_staging_engine=False)
_test_stop(args.model_artifact_path, use_staging_engine=True)
_test_logprobs(args.model_artifact_path, use_staging_engine=True)
_test_logprobs(args.model_artifact_path, use_staging_engine=False)
# These tests are broken since we are now imposing no length limit
# if max_tokens = None. The tests do not finish in a reasonable time.
# _test_max_context_length(model_artifact_path, use_staging_engine=True)
