Enable Logprobs in MLC Batch Serving #82

Merged
Commits (42)
ab47b41  Squashed commit for logprobs implementation. (zxybazh, Jan 22, 2024)
86f6fa1  fix None check (Jan 23, 2024)
9a29650  Change detokenization to using token ids. (zxybazh, Jan 25, 2024)
012388d  Fix wrong usage of token ids. Remove logging. (zxybazh, Jan 29, 2024)
db31164  extend benchmarks for logprobs (Jan 26, 2024)
be81755  fix test without logprobs (Jan 26, 2024)
e8ec3fc  clean code (Jan 26, 2024)
49187f5  black format engine_common.py (Jan 26, 2024)
013ed5a  logprobs is strictly bool, top_logprobs is int (Jan 26, 2024)
79ec413  refactor logprob info collection to not reduce performance (Jan 28, 2024)
fca1a6f  quick fix for check (Jan 29, 2024)
675b631  review fix (Jan 29, 2024)
18f80fa  fix list index out of range (Jan 29, 2024)
29ea525  rollback after rebase (Jan 29, 2024)
aa99322  test (Jan 29, 2024)
8fa785e  Merge pull request #7 from Deelvin/vc/benchmark (Jan 29, 2024)
d57b197  Squashed commit for logprobs implementation. (zxybazh, Jan 22, 2024)
7995c84  fix None check (Jan 23, 2024)
ae3fc5b  Change detokenization to using token ids. (zxybazh, Jan 25, 2024)
0cb036f  Fix wrong usage of token ids. Remove logging. (zxybazh, Jan 29, 2024)
ed51e7d  extend benchmarks for logprobs (Jan 26, 2024)
ff17ae2  fix test without logprobs (Jan 26, 2024)
f5e4339  clean code (Jan 26, 2024)
a3f6e8b  black format engine_common.py (Jan 26, 2024)
c54a410  logprobs is strictly bool, top_logprobs is int (Jan 26, 2024)
379d991  refactor logprob info collection to not reduce performance (Jan 28, 2024)
58bac8f  quick fix for check (Jan 29, 2024)
7de8d88  review fix (Jan 29, 2024)
661fa18  fix list index out of range (Jan 29, 2024)
6662a65  rollback after rebase (Jan 29, 2024)
970d7f8  test (Jan 29, 2024)
c58d69c  small fix (Jan 30, 2024)
ebae200  rename for the sake of clarity (Jan 30, 2024)
b2863d5  some fixes with cpu-gpu tensor copying (Jan 30, 2024)
57b3a35  refactor logprob pass to calculate (Jan 30, 2024)
4e29403  remove excess deps for token detokenization (Jan 30, 2024)
a9157b9  small clean (Jan 30, 2024)
39efb61  small clean (Jan 31, 2024)
601e68d  return None instead of list of Nones (Jan 31, 2024)
4f9241b  resolve conflicts (Jan 31, 2024)
7ec21a7  fix mypy (Jan 31, 2024)
7aa60ed  Merge pull request #8 from Deelvin/vc/perf (Jan 31, 2024)
Commit ebae20023a7d4b77a2c3b3e1f21c8279682dcc2b: "rename for the sake of clarity"
Valery Chernov committed Jan 31, 2024
serve/mlc_serve/engine/base.py (4 changes: 2 additions & 2 deletions)

@@ -17,9 +17,9 @@

 @dataclass
 class RawLogprobsInfo:
-    current_token: int
+    current_token_id: int
     current_logprob: float
-    top_tokens: Optional[np.array]
+    top_token_ids: Optional[np.array]
     top_logprobs: Optional[np.array]
     previous_tokens: Optional[List[int]]
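
The commit only renames fields, but for orientation, this is the full dataclass as it stands after the change, with descriptive comments added. The imports are inferred from the annotations; note that np.ndarray is the conventional type annotation where the PR itself writes np.array:

from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class RawLogprobsInfo:
    current_token_id: int                 # id of the token actually sampled
    current_logprob: float                # its log-probability
    top_token_ids: Optional[np.ndarray]   # ids of the top-k alternative tokens
    top_logprobs: Optional[np.ndarray]    # log-probabilities of those alternatives
    previous_tokens: Optional[List[int]]  # earlier token ids, kept for detokenization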

serve/mlc_serve/engine/engine_common.py (10 changes: 5 additions & 5 deletions)

@@ -146,16 +146,16 @@ def logprob_detokenize(
         return None

     top_logprobs: List[TopLogprobs] = []
-    if logprob_info.top_tokens is not None and logprob_info.top_logprobs is not None:
-        top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs))
+    if logprob_info.top_token_ids is not None and logprob_info.top_logprobs is not None:
+        top_tokens = list(zip(logprob_info.top_token_ids, logprob_info.top_logprobs))
         if logprob_info.previous_tokens is None:
             logprob_info.previous_tokens = []
-        for top_token, top_logprob in top_tokens:
+        for top_token_id, top_logprob in top_tokens:
             # TODO(vvchernov): not clear what do we want
             # detokenized = tokenizer.convert_ids_to_tokens(
             #     logprob_info.previous_tokens + [top_token]
             # )[-1]
-            detokenized = tokenizer.decode(top_token)
+            detokenized = tokenizer.decode(top_token_id)
             top_logprobs.append(
                 TopLogprobs(
                     token=detokenized,
@@ -166,7 +166,7 @@ def logprob_detokenize(
             )

     logprobs_content = LogprobsContent(
-        token=tokenizer.decode([logprob_info.current_token]),
+        token=tokenizer.decode([logprob_info.current_token_id]),
         logprob=logprob_info.current_logprob,
         # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
         bytes=None,
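
Taken together, the two hunks turn a RawLogprobsInfo into OpenAI-style logprobs content. The following self-contained sketch replays that logic against any HuggingFace-style tokenizer; the TopLogprobs and LogprobsContent dataclasses below are simplified stand-ins for the engine's own types, not their actual definitions:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TopLogprobs:  # simplified stand-in for mlc_serve's type
    token: str
    logprob: float
    bytes: Optional[bytes] = None


@dataclass
class LogprobsContent:  # simplified stand-in for mlc_serve's type
    token: str
    logprob: float
    bytes: Optional[bytes] = None
    top_logprobs: List[TopLogprobs] = field(default_factory=list)


def logprob_detokenize_sketch(tokenizer, logprob_info) -> Optional[LogprobsContent]:
    """Decode a RawLogprobsInfo into OpenAI-style logprobs content."""
    if logprob_info is None:
        return None
    top_logprobs: List[TopLogprobs] = []
    if logprob_info.top_token_ids is not None and logprob_info.top_logprobs is not None:
        for top_token_id, top_logprob in zip(
            logprob_info.top_token_ids, logprob_info.top_logprobs
        ):
            # tokenizer.decode accepts a single token id or a list of ids
            top_logprobs.append(
                TopLogprobs(
                    token=tokenizer.decode(top_token_id),
                    logprob=float(top_logprob),
                )
            )
    return LogprobsContent(
        token=tokenizer.decode([logprob_info.current_token_id]),
        logprob=logprob_info.current_logprob,
        top_logprobs=top_logprobs,
    )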
serve/mlc_serve/model/model_common.py (12 changes: 6 additions & 6 deletions)

@@ -40,11 +40,11 @@ def get_num_cache_blocks(

 def get_raw_logprob_info(
     logits,
-    token,
+    token_id,
     top_logprobs_num,
 ) -> RawLogprobsInfo:
     logprobs = torch.log_softmax(logits, dim=-1)
-    res_logprob = logprobs[token].cpu().numpy()
+    res_logprob = logprobs[token_id].cpu().numpy()

     if top_logprobs_num == 0:
         top_logprobs = None
@@ -59,9 +59,9 @@ def get_raw_logprob_info(
     # Set to raw logprob info
     return RawLogprobsInfo(
-        current_token=token,
+        current_token_id=token_id,
         current_logprob=res_logprob,
-        top_tokens=top_tokens,
+        top_token_ids=top_tokens,
         top_logprobs=top_logprobs,
         previous_tokens=None
     )
@@ -72,7 +72,7 @@ def get_masked_logprobs(
     mask: torch.Tensor,
     sampling_params: List[SamplingParams],
     logits: torch.Tensor,
-    tokens: torch.Tensor,
+    token_ids: torch.Tensor,
 ) -> List[Optional[RawLogprobsInfo]]:
     num_seq = len(logprob_infos)

@@ -82,7 +82,7 @@ def get_masked_logprobs(
         if sampling_params[i].logprobs:
             logprob_infos[i] = get_raw_logprob_info(
                 logits[mask_counter],
-                tokens[mask_counter],
+                token_ids[mask_counter],
                 sampling_params[i].top_logprobs,
             )
             mask_counter = mask_counter + 1
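
For reference, the core computation in get_raw_logprob_info is a log-softmax over the vocabulary, a lookup of the sampled token, and, when requested, a top-k over the same distribution. A minimal runnable sketch of that step follows; the torch.topk call stands in for the part of the function the diff elides, so treat it as an assumption rather than the PR's exact code:

import torch


def raw_logprob_sketch(logits: torch.Tensor, token_id: int, top_logprobs_num: int):
    """Return the sampled token's logprob plus optional top-k alternatives."""
    logprobs = torch.log_softmax(logits, dim=-1)   # normalize once over the vocab
    res_logprob = logprobs[token_id].cpu().numpy()

    if top_logprobs_num == 0:
        return res_logprob, None, None

    # Assumed implementation of the top-k extraction elided by the diff.
    top_logprobs, top_token_ids = torch.topk(logprobs, top_logprobs_num, dim=-1)
    return res_logprob, top_token_ids.cpu().numpy(), top_logprobs.cpu().numpy()


# Example: logprob of token id 1 plus its two strongest alternatives.
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
lp, ids, lps = raw_logprob_sketch(logits, token_id=1, top_logprobs_num=2)

get_masked_logprobs then walks the batch: for each sequence whose sampling_params request logprobs, it selects the matching logits row and sampled token id via mask_counter and calls get_raw_logprob_info; all other entries stay None.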