Enable Logprobs in MLC Batch Serving #82

Merged
Changes from 1 commit

Commits (42)
ab47b41
Squashed commit for logprobs implementation.
zxybazh Jan 22, 2024
86f6fa1
fix None check
Jan 23, 2024
9a29650
Change detokenization to using token ids.
zxybazh Jan 25, 2024
012388d
Fix wrong usage of token ids. Remove logging.
zxybazh Jan 29, 2024
db31164
extend benchmarks for logprobs
Jan 26, 2024
be81755
fix test without logprobs
Jan 26, 2024
e8ec3fc
clean code
Jan 26, 2024
49187f5
black format engine_common.py
Jan 26, 2024
013ed5a
logprobs is strictly bool, top_logprobs is int
Jan 26, 2024
79ec413
refactor logprob info collection to not reduce performance
Jan 28, 2024
fca1a6f
quick fix for check
Jan 29, 2024
675b631
review fix
Jan 29, 2024
18f80fa
fix list index out of range
Jan 29, 2024
29ea525
rollback after rebase
Jan 29, 2024
aa99322
test
Jan 29, 2024
8fa785e
Merge pull request #7 from Deelvin/vc/benchmark
Jan 29, 2024
d57b197
Squashed commit for logprobs implementation.
zxybazh Jan 22, 2024
7995c84
fix None check
Jan 23, 2024
ae3fc5b
Change detokenization to using token ids.
zxybazh Jan 25, 2024
0cb036f
Fix wrong usage of token ids. Remove logging.
zxybazh Jan 29, 2024
ed51e7d
extend benchmarks for logprobs
Jan 26, 2024
ff17ae2
fix test without logprobs
Jan 26, 2024
f5e4339
clean code
Jan 26, 2024
a3f6e8b
black format engine_common.py
Jan 26, 2024
c54a410
logprobs is strictly bool, top_logprobs is int
Jan 26, 2024
379d991
refactor logprob info collection to not reduce performance
Jan 28, 2024
58bac8f
quick fix for check
Jan 29, 2024
7de8d88
review fix
Jan 29, 2024
661fa18
fix list index out of range
Jan 29, 2024
6662a65
rollback after rebase
Jan 29, 2024
970d7f8
test
Jan 29, 2024
c58d69c
small fix
Jan 30, 2024
ebae200
rename for the sake of clarity
Jan 30, 2024
b2863d5
some fixes with cpu-gpu tensor copying
Jan 30, 2024
57b3a35
refactor logprob pass to calculate
Jan 30, 2024
4e29403
remove excess deps for token detokenization
Jan 30, 2024
a9157b9
small clean
Jan 30, 2024
39efb61
small clean
Jan 31, 2024
601e68d
return None instead of list of Nones
Jan 31, 2024
4f9241b
resolve conflicts
Jan 31, 2024
7ec21a7
fix mypy
Jan 31, 2024
7aa60ed
Merge pull request #8 from Deelvin/vc/perf
Jan 31, 2024
refactor logprob info collection to not reduce performance
Valery Chernov committed Jan 29, 2024
commit 79ec4135b5d93a0be3684fb1dd84443be9cd533b
serve/mlc_serve/model/model_common.py — 120 changes: 58 additions & 62 deletions
@@ -38,41 +38,55 @@ def get_num_cache_blocks(
     )


-def fetch_raw_logprob_infos(
+def get_raw_logprob_info(
     logits,
-    res_tokens,
-    sampling_params,
+    token,
+    top_logprobs_num,
+) -> RawLogprobsInfo:
+    logprobs = torch.log_softmax(logits, dim=-1)
+    res_logprob = logprobs[token].cpu().numpy()
+
+    if top_logprobs_num == 0:
+        top_logprobs = None
+        top_tokens = None
+    else:
+        assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
+        top_logprobs, top_tokens = torch.topk(
+            logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
+        )
+        top_tokens=top_tokens.cpu().numpy()
+        top_logprobs=top_logprobs.cpu().numpy()
+
+    # Set to raw logprob info
+    return RawLogprobsInfo(
+        # TODO(vvchernov): it is number, cpu().numpy()?
+        current_token=token.cpu().numpy(),
Reviewer comment (Member):
I think .cpu().numpy() can be removed here. current_token seems to be an element from res_random or res_greedy, which is already on cpu. Remember, .cpu() and .numpy() involve cudaMemcpy and another copy on CPU.
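For illustration only (not code from this PR), a minimal sketch of the point being made, assuming a CUDA device: each per-element .cpu()/.numpy() call in a loop issues its own device-to-host transfer, whereas a single batched copy pays that cost once.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
sampled = torch.randint(0, 32000, (8,), device=device)  # stand-in for sampled token ids

# Per-element copies: on CUDA, each .cpu() here triggers its own cudaMemcpy.
per_token = [sampled[i].cpu().numpy() for i in range(sampled.shape[0])]

# Single batched transfer, then cheap host-side indexing.
sampled_cpu = sampled.cpu().numpy()
per_token_batched = [sampled_cpu[i] for i in range(sampled_cpu.shape[0])]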

+        current_logprob=res_logprob,
+        top_tokens=top_tokens,
+        top_logprobs=top_logprobs,
+        previous_tokens=None
+    )
+
+
+def get_masked_logprobs(
+    logprob_infos: List[Optional[RawLogprobsInfo]],
+    mask: torch.Tensor,
+    sampling_params: List[SamplingParams],
+    logits: torch.Tensor,
+    tokens: torch.Tensor,
 ) -> List[Optional[RawLogprobsInfo]]:
-    logprob_infos: List[Optional[RawLogprobsInfo]] = []
-    num_seq = logits.shape[0]
-    for index in range(num_seq):
-        if sampling_params[index].logprobs:
-            # Logprob sampling
-            logprobs = torch.log_softmax(logits[index], dim=-1)
-            res_logprob = logprobs[res_tokens[index]].cpu().numpy()
-
-            top_logprobs_num = sampling_params[index].top_logprobs
-            if top_logprobs_num == 0:
-                top_logprobs = None
-                top_tokens = None
-            else:
-                assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
-                top_logprobs, top_tokens = torch.topk(
-                    logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
+    num_seq = len(logprob_infos)
+
+    mask_counter = 0
+    for i in range(num_seq):
+        if mask[i]:
Reviewer comment (Member):

We don't want to loop over num_seq and check mask. I think mask is on GPU, so every time you access its element from python you get cudaMemcpy.

Please follow my suggestion #82 (comment) instead. Rather than using mask_greedy or mask_random, collect a list of indices that want logprobs. Then you only need to loop over such indices here without a mask check.

Reply:

> I think mask is on GPU
It is not so. Due to this I've prepared fix #178
I'm investigating the issue with performance degradation
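For illustration, a hedged sketch of the index-based alternative the reviewer suggests, reusing get_raw_logprob_info from this diff; the helper name collect_logprob_pairs and the exact call site are assumptions, not code from the PR.

def collect_logprob_pairs(sampling_params, mask_cpu):
    # Precompute, once and on the CPU, (row, seq_index) pairs for sequences that
    # both fall under this sampling group's mask and request logprobs; `row`
    # indexes the masked logits/tokens tensors.
    pairs = []
    row = 0
    for i, selected in enumerate(mask_cpu.tolist()):
        if selected:
            if sampling_params[i].logprobs:
                pairs.append((row, i))
            row += 1
    return pairs

# Inside sample(), the loop then touches only the requesting sequences:
# for row, i in collect_logprob_pairs(sampling_params, mask_greedy.cpu()):
#     logprob_infos[i] = get_raw_logprob_info(
#         logits_greedy[row], res_greedy[row], sampling_params[i].top_logprobs
#     )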

+            if sampling_params[i].logprobs:
+                logprob_infos[i] = get_raw_logprob_info(
+                    logits[mask_counter],
+                    tokens[mask_counter],
+                    sampling_params[i].top_logprobs,
                 )
-                top_tokens=top_tokens.cpu().numpy()
-                top_logprobs=top_logprobs.cpu().numpy()
-
-            # Set to raw logprob info
-            logprob_infos.append(RawLogprobsInfo(
-                current_token=res_tokens[index].cpu().numpy(),
-                current_logprob=res_logprob,
-                top_tokens=top_tokens,
-                top_logprobs=top_logprobs,
-                previous_tokens=None
-            ))
-        else:
-            logprob_infos.append(None)
+            mask_counter = mask_counter + 1

     return logprob_infos

@@ -100,25 +114,6 @@ def _apply_top_p_top_k(logits, top_ps, top_ks):
     return logits


-def update_masked_list(input_list, mask, update):
-    j = 0
-    for i in range(len(mask)):
-        if mask[i]:
-            input_list[i] = update[j]
-            j = j + 1
-
-    return input_list
-
-
-def filter_list_by_mask(i_list, mask):
-    o_list = []
-    for i in range(len(mask)):
-        if mask[i]:
-            o_list.append(i_list[i])
-
-    return o_list
-
-
 def sample(
     logits: Union[tvm.nd.NDArray, torch.Tensor],
     sampling_params: List[SamplingParams],
@@ -135,6 +130,8 @@ def _is_safe_to_sample(prob_like):
     logits = torch.from_dlpack(logits)
     num_seq = len(sampling_params)

+    logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq
+
     mask_random = torch.tensor(
         [p.sampling_type == SamplingType.RANDOM for p in sampling_params],
         dtype=torch.bool,
@@ -146,17 +143,19 @@ def _is_safe_to_sample(prob_like):
     if logits_greedy.shape[0] > 0:
         res_greedy = torch.argmax(logits_greedy, -1)

-        logprob_infos_greedy = fetch_raw_logprob_infos(
+        logprob_infos = get_masked_logprobs(
+            logprob_infos,
+            mask_greedy,
+            sampling_params,
             logits_greedy,
             res_greedy,
-            filter_list_by_mask(sampling_params, mask_greedy)
         )

         res_greedy = res_greedy.cpu().numpy()
         # Case when there's only greedy sampling
         if logits_greedy.shape[0] == num_seq:
             torch.cuda.nvtx.range_pop()
-            return res_greedy, logprob_infos_greedy
+            return res_greedy, logprob_infos

     temperatures = []
     top_ps = []
@@ -225,29 +224,26 @@ def _is_safe_to_sample(prob_like):

     res_random = torch.multinomial(probs, 1, True)[:, 0]

-    logprob_infos_random = fetch_raw_logprob_infos(
+    logprob_infos = get_masked_logprobs(
+        logprob_infos,
+        mask_random,
+        sampling_params,
         logits_random,
         res_random,
-        filter_list_by_mask(sampling_params, mask_random),
     )

     res_random = res_random.cpu().numpy()
     # Case when there's only random sampling
     if logits_random.shape[0] == num_seq:
         torch.cuda.nvtx.range_pop()
-        return res_random, logprob_infos_random
+        return res_random, logprob_infos

     res = np.empty((num_seq,), dtype=np.int32)
     res[mask_random] = res_random

-    logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq
-    logprob_infos = update_masked_list(logprob_infos, mask_random, logprob_infos_random)
-
     if logits_greedy.shape[0] > 0:
         res[mask_greedy] = res_greedy

-        logprob_infos = update_masked_list(logprob_infos, mask_greedy, logprob_infos_greedy)
-
     torch.cuda.nvtx.range_pop()
     return res, logprob_infos
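For context, a hedged sketch of how a caller might consume the result of sample() after this change; the field names come from RawLogprobsInfo in this diff, while the surrounding variables (logits, sampling_params, vocab_size) are assumed to be set up as elsewhere in mlc_serve, not shown here.

next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size)

for seq_id, info in enumerate(logprob_infos):
    if info is None:
        continue  # this sequence did not request logprobs
    print(seq_id, info.current_token, info.current_logprob)
    if info.top_logprobs is not None:
        # top_tokens/top_logprobs hold the top-k alternatives when top_logprobs > 0.
        for tok, lp in zip(info.top_tokens, info.top_logprobs):
            print("  alt:", tok, lp)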