Enable Logprobs in MLC Batch Serving #82
This commit reworks the logprob helpers in the sampler: the batch-level `fetch_raw_logprob_infos` is replaced by a per-token `get_raw_logprob_info` plus a mask-aware `get_masked_logprobs`, and `sample()` is rewired to fill a single preallocated `logprob_infos` list in place.

```diff
@@ -38,41 +38,55 @@ def get_num_cache_blocks(
     )


-def fetch_raw_logprob_infos(
+def get_raw_logprob_info(
     logits,
-    res_tokens,
-    sampling_params,
-) -> List[Optional[RawLogprobsInfo]]:
-    logprob_infos: List[Optional[RawLogprobsInfo]] = []
-    num_seq = logits.shape[0]
-    for index in range(num_seq):
-        if sampling_params[index].logprobs:
-            # Logprob sampling
-            logprobs = torch.log_softmax(logits[index], dim=-1)
-            res_logprob = logprobs[res_tokens[index]].cpu().numpy()
-
-            top_logprobs_num = sampling_params[index].top_logprobs
-            if top_logprobs_num == 0:
-                top_logprobs = None
-                top_tokens = None
-            else:
-                assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
-                top_logprobs, top_tokens = torch.topk(
-                    logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
-                )
-                top_tokens=top_tokens.cpu().numpy()
-                top_logprobs=top_logprobs.cpu().numpy()
-
-            # Set to raw logprob info
-            logprob_infos.append(RawLogprobsInfo(
-                current_token=res_tokens[index].cpu().numpy(),
-                current_logprob=res_logprob,
-                top_tokens=top_tokens,
-                top_logprobs=top_logprobs,
-                previous_tokens=None
-            ))
-        else:
-            logprob_infos.append(None)
+    token,
+    top_logprobs_num,
+) -> RawLogprobsInfo:
+    logprobs = torch.log_softmax(logits, dim=-1)
+    res_logprob = logprobs[token].cpu().numpy()
+
+    if top_logprobs_num == 0:
+        top_logprobs = None
+        top_tokens = None
+    else:
+        assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
+        top_logprobs, top_tokens = torch.topk(
+            logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
+        )
+        top_tokens = top_tokens.cpu().numpy()
+        top_logprobs = top_logprobs.cpu().numpy()
+
+    # Set to raw logprob info
+    return RawLogprobsInfo(
+        # TODO(vvchernov): it is number, cpu().numpy()?
+        current_token=token.cpu().numpy(),
+        current_logprob=res_logprob,
+        top_tokens=top_tokens,
+        top_logprobs=top_logprobs,
+        previous_tokens=None
+    )
+
+
+def get_masked_logprobs(
+    logprob_infos: List[Optional[RawLogprobsInfo]],
+    mask: torch.Tensor,
+    sampling_params: List[SamplingParams],
+    logits: torch.Tensor,
+    tokens: torch.Tensor,
+) -> List[Optional[RawLogprobsInfo]]:
+    num_seq = len(logprob_infos)
+
+    mask_counter = 0
+    for i in range(num_seq):
+        if mask[i]:
+            if sampling_params[i].logprobs:
+                logprob_infos[i] = get_raw_logprob_info(
+                    logits[mask_counter],
+                    tokens[mask_counter],
+                    sampling_params[i].top_logprobs,
+                )
+            mask_counter = mask_counter + 1

     return logprob_infos
```

Inline review discussion on the `if mask[i]:` loop in `get_masked_logprobs`:

Reviewer: We don't want to loop over … Please follow my suggestion in #82 (comment) instead, rather than using …

Author: It is not so. Due to this I've prepared fix #178.
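As an aside on the loop question above, here is a hedged sketch of how the same per-token logprobs and top-k values could be computed for a whole packed batch in one shot. This is not code from the PR: `batched_logprob_info`, `batch_logits`, and `batch_tokens` are illustrative names, and the sketch assumes a single `top_logprobs_num` for the batch, whereas the PR allows it to differ per sequence (which would need, for example, taking the maximum k and slicing per row).

```python
import torch

def batched_logprob_info(
    batch_logits: torch.Tensor,   # [num_rows, vocab_size]
    batch_tokens: torch.Tensor,   # [num_rows], chosen token id per row
    top_logprobs_num: int,
):
    # Log-softmax over the whole batch at once.
    logprobs = torch.log_softmax(batch_logits, dim=-1)
    # Logprob of each row's chosen token, gathered without a Python loop.
    token_logprobs = logprobs.gather(-1, batch_tokens.unsqueeze(-1)).squeeze(-1)

    top_logprobs = top_tokens = None
    if top_logprobs_num > 0:
        # Top-k for every row in a single kernel launch.
        top_logprobs, top_tokens = torch.topk(
            logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
        )

    # One device-to-host transfer per batch instead of one per sequence.
    return token_logprobs.cpu().numpy(), top_logprobs, top_tokens
```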
The now-unused list helpers are removed:

```diff
@@ -100,25 +114,6 @@ def _apply_top_p_top_k(logits, top_ps, top_ks):
     return logits


-def update_masked_list(input_list, mask, update):
-    j = 0
-    for i in range(len(mask)):
-        if mask[i]:
-            input_list[i] = update[j]
-            j = j + 1
-
-    return input_list
-
-
-def filter_list_by_mask(i_list, mask):
-    o_list = []
-    for i in range(len(mask)):
-        if mask[i]:
-            o_list.append(i_list[i])
-
-    return o_list
-
-
 def sample(
     logits: Union[tvm.nd.NDArray, torch.Tensor],
     sampling_params: List[SamplingParams],
```
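For reference only (not part of the PR), the two removed helpers are equivalent to short built-in idioms. The sketch below returns new lists rather than mutating in place, which matches how the callers in `sample()` used the return values anyway.

```python
from itertools import compress

def update_masked_list(input_list, mask, update):
    # Take the next item from `update` wherever the mask is set, otherwise keep the original.
    it = iter(update)
    return [next(it) if m else x for x, m in zip(input_list, mask)]

def filter_list_by_mask(i_list, mask):
    # Keep only the items whose mask entry is truthy.
    return list(compress(i_list, mask))
```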
`sample()` now preallocates one `logprob_infos` slot per sequence up front:

```diff
@@ -135,6 +130,8 @@ def _is_safe_to_sample(prob_like):
     logits = torch.from_dlpack(logits)
     num_seq = len(sampling_params)

+    logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq
+
     mask_random = torch.tensor(
         [p.sampling_type == SamplingType.RANDOM for p in sampling_params],
         dtype=torch.bool,
```
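A toy illustration of how the masks and the preallocated list line up. The values are invented, and treating the greedy mask as the complement of the random one is an assumption for the sketch, not something shown in this diff.

```python
import torch

# Suppose a batch of four sequences: rows 0 and 2 use random sampling, rows 1 and 3 greedy.
sampling_types = ["RANDOM", "GREEDY", "RANDOM", "GREEDY"]  # stand-in for p.sampling_type
mask_random = torch.tensor([t == "RANDOM" for t in sampling_types], dtype=torch.bool)
mask_greedy = ~mask_random  # assumption: greedy rows are simply the non-random ones

# One slot per sequence; get_masked_logprobs fills slots in place, so every sequence
# keeps its original batch position no matter how the batch is split by the masks.
logprob_infos = [None] * len(sampling_types)
```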
The greedy path fills that list in place instead of building a separate one:

```diff
@@ -146,17 +143,19 @@ def _is_safe_to_sample(prob_like):
     if logits_greedy.shape[0] > 0:
         res_greedy = torch.argmax(logits_greedy, -1)

-        logprob_infos_greedy = fetch_raw_logprob_infos(
+        logprob_infos = get_masked_logprobs(
+            logprob_infos,
+            mask_greedy,
+            sampling_params,
             logits_greedy,
             res_greedy,
-            filter_list_by_mask(sampling_params, mask_greedy)
         )

         res_greedy = res_greedy.cpu().numpy()
         # Case when there's only greedy sampling
         if logits_greedy.shape[0] == num_seq:
             torch.cuda.nvtx.range_pop()
-            return res_greedy, logprob_infos_greedy
+            return res_greedy, logprob_infos

     temperatures = []
     top_ps = []
```
The random path does the same, and the final merge no longer needs the masked-list helpers:

```diff
@@ -225,29 +224,26 @@ def _is_safe_to_sample(prob_like):
         res_random = torch.multinomial(probs, 1, True)[:, 0]

-        logprob_infos_random = fetch_raw_logprob_infos(
+        logprob_infos = get_masked_logprobs(
+            logprob_infos,
+            mask_random,
+            sampling_params,
             logits_random,
             res_random,
-            filter_list_by_mask(sampling_params, mask_random),
         )

         res_random = res_random.cpu().numpy()
         # Case when there's only random sampling
         if logits_random.shape[0] == num_seq:
             torch.cuda.nvtx.range_pop()
-            return res_random, logprob_infos_random
+            return res_random, logprob_infos

     res = np.empty((num_seq,), dtype=np.int32)
     res[mask_random] = res_random

-    logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq
-    logprob_infos = update_masked_list(logprob_infos, mask_random, logprob_infos_random)
-
     if logits_greedy.shape[0] > 0:
         res[mask_greedy] = res_greedy

-        logprob_infos = update_masked_list(logprob_infos, mask_greedy, logprob_infos_greedy)
-
     torch.cuda.nvtx.range_pop()
     return res, logprob_infos
```
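A worked toy example of the final merge (values invented for illustration): the sampled token ids still need a boolean-mask scatter into `res`, while `logprob_infos` needs no scatter at all because `get_masked_logprobs` already wrote into the full-length list in place.

```python
import numpy as np

num_seq = 4
mask_random = np.array([True, False, True, False])
mask_greedy = ~mask_random
res_random = np.array([11, 22], dtype=np.int32)  # packed results, one per random row
res_greedy = np.array([33, 44], dtype=np.int32)  # packed results, one per greedy row

res = np.empty((num_seq,), dtype=np.int32)
res[mask_random] = res_random  # scatters 11 and 22 into rows 0 and 2
res[mask_greedy] = res_greedy  # scatters 33 and 44 into rows 1 and 3
# res is now [11, 33, 22, 44]
```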
Review comment on the `current_token=token.cpu().numpy()` assignment in `get_raw_logprob_info`: I think `.cpu().numpy()` can be removed here. `current_token` seems to be an element from `res_random` or `res_greedy`, which is already on CPU. Remember, `.cpu()` and `.numpy()` involve a cudaMemcpy and another copy on the CPU.
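For illustration, a hedged sketch of what the suggested simplification could look like. `build_logprob_info` is a hypothetical wrapper, `RawLogprobsInfo` is the PR's own type, and the sketch assumes the caller passes a token that was already moved to the host via the single `res_greedy`/`res_random` conversion in `sample()`, which is not guaranteed by the call order shown in this diff.

```python
def build_logprob_info(token, res_logprob, top_tokens, top_logprobs):
    # `token` is assumed to be a host-side NumPy scalar here, so storing it directly
    # avoids the extra cudaMemcpy from .cpu() and the additional host copy from .numpy().
    return RawLogprobsInfo(
        current_token=token,
        current_logprob=res_logprob,
        top_tokens=top_tokens,
        top_logprobs=top_logprobs,
        previous_tokens=None,
    )
```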