refactor logprob info collection to not reduce performance
Valery Chernov committed Jan 29, 2024
1 parent 013ed5a commit 79ec413
Showing 1 changed file with 58 additions and 62 deletions.
serve/mlc_serve/model/model_common.py: 120 changes (58 additions, 62 deletions)
@@ -38,41 +38,55 @@ def get_num_cache_blocks(
     )


-def fetch_raw_logprob_infos(
-    logits,
-    res_tokens,
-    sampling_params,
-) -> List[Optional[RawLogprobsInfo]]:
-    logprob_infos: List[Optional[RawLogprobsInfo]] = []
-    num_seq = logits.shape[0]
-    for index in range(num_seq):
-        if sampling_params[index].logprobs:
-            # Logprob sampling
-            logprobs = torch.log_softmax(logits[index], dim=-1)
-            res_logprob = logprobs[res_tokens[index]].cpu().numpy()
-
-            top_logprobs_num = sampling_params[index].top_logprobs
-            if top_logprobs_num == 0:
-                top_logprobs = None
-                top_tokens = None
-            else:
-                assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
-                top_logprobs, top_tokens = torch.topk(
-                    logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
-                )
-                top_tokens=top_tokens.cpu().numpy()
-                top_logprobs=top_logprobs.cpu().numpy()
-
-            # Set to raw logprob info
-            logprob_infos.append(RawLogprobsInfo(
-                current_token=res_tokens[index].cpu().numpy(),
-                current_logprob=res_logprob,
-                top_tokens=top_tokens,
-                top_logprobs=top_logprobs,
-                previous_tokens=None
-            ))
-        else:
-            logprob_infos.append(None)
+def get_raw_logprob_info(
+    logits,
+    token,
+    top_logprobs_num,
+) -> RawLogprobsInfo:
+    logprobs = torch.log_softmax(logits, dim=-1)
+    res_logprob = logprobs[token].cpu().numpy()
+
+    if top_logprobs_num == 0:
+        top_logprobs = None
+        top_tokens = None
+    else:
+        assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
+        top_logprobs, top_tokens = torch.topk(
+            logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
+        )
+        top_tokens=top_tokens.cpu().numpy()
+        top_logprobs=top_logprobs.cpu().numpy()
+
+    # Set to raw logprob info
+    return RawLogprobsInfo(
+        # TODO(vvchernov): it is number, cpu().numpy()?
+        current_token=token.cpu().numpy(),
+        current_logprob=res_logprob,
+        top_tokens=top_tokens,
+        top_logprobs=top_logprobs,
+        previous_tokens=None
+    )
+
+
+def get_masked_logprobs(
+    logprob_infos: List[Optional[RawLogprobsInfo]],
+    mask: torch.Tensor,
+    sampling_params: List[SamplingParams],
+    logits: torch.Tensor,
+    tokens: torch.Tensor,
+) -> List[Optional[RawLogprobsInfo]]:
+    num_seq = len(logprob_infos)
+
+    mask_counter = 0
+    for i in range(num_seq):
+        if mask[i]:
+            if sampling_params[i].logprobs:
+                logprob_infos[i] = get_raw_logprob_info(
+                    logits[mask_counter],
+                    tokens[mask_counter],
+                    sampling_params[i].top_logprobs,
+                )
+            mask_counter = mask_counter + 1

     return logprob_infos

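Commentary on the hunk above (not part of the commit): get_raw_logprob_info now does the per-token work for a single sequence, and get_masked_logprobs walks the boolean mask, using mask_counter to map rows of the already-compacted logits/tokens tensors back to their slots in the shared logprob_infos list. The sketch below reproduces only the per-token math on toy inputs; vocab_size, top_logprobs_num, and the chosen token value are invented for illustration.

    import torch

    vocab_size, top_logprobs_num = 16, 3      # toy sizes, not from the repo
    logits = torch.randn(vocab_size)          # one batch row for one sequence
    token = torch.tensor(5)                   # token chosen for that sequence

    # Same steps as get_raw_logprob_info: log-softmax the row, read the chosen
    # token's logprob, then take the top-k alternatives when requested.
    logprobs = torch.log_softmax(logits, dim=-1)
    chosen_logprob = logprobs[token].item()
    top_logprobs, top_tokens = torch.topk(
        logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
    )
    print(chosen_logprob, top_tokens.tolist(), top_logprobs.tolist())
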
@@ -100,25 +114,6 @@ def _apply_top_p_top_k(logits, top_ps, top_ks):
     return logits


-def update_masked_list(input_list, mask, update):
-    j = 0
-    for i in range(len(mask)):
-        if mask[i]:
-            input_list[i] = update[j]
-            j = j + 1
-
-    return input_list
-
-
-def filter_list_by_mask(i_list, mask):
-    o_list = []
-    for i in range(len(mask)):
-        if mask[i]:
-            o_list.append(i_list[i])
-
-    return o_list
-
-
 def sample(
     logits: Union[tvm.nd.NDArray, torch.Tensor],
     sampling_params: List[SamplingParams],
@@ -135,6 +130,8 @@ def _is_safe_to_sample(prob_like):
     logits = torch.from_dlpack(logits)
     num_seq = len(sampling_params)

+    logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq
+
     mask_random = torch.tensor(
         [p.sampling_type == SamplingType.RANDOM for p in sampling_params],
         dtype=torch.bool,
@@ -146,17 +143,19 @@ def _is_safe_to_sample(prob_like):
     if logits_greedy.shape[0] > 0:
         res_greedy = torch.argmax(logits_greedy, -1)

-        logprob_infos_greedy = fetch_raw_logprob_infos(
+        logprob_infos = get_masked_logprobs(
+            logprob_infos,
+            mask_greedy,
+            sampling_params,
             logits_greedy,
             res_greedy,
-            filter_list_by_mask(sampling_params, mask_greedy)
         )

         res_greedy = res_greedy.cpu().numpy()
         # Case when there's only greedy sampling
         if logits_greedy.shape[0] == num_seq:
             torch.cuda.nvtx.range_pop()
-            return res_greedy, logprob_infos_greedy
+            return res_greedy, logprob_infos

     temperatures = []
     top_ps = []
@@ -225,29 +224,26 @@ def _is_safe_to_sample(prob_like):

         res_random = torch.multinomial(probs, 1, True)[:, 0]

-        logprob_infos_random = fetch_raw_logprob_infos(
+        logprob_infos = get_masked_logprobs(
+            logprob_infos,
+            mask_random,
+            sampling_params,
             logits_random,
             res_random,
-            filter_list_by_mask(sampling_params, mask_random),
         )

         res_random = res_random.cpu().numpy()
         # Case when there's only random sampling
         if logits_random.shape[0] == num_seq:
             torch.cuda.nvtx.range_pop()
-            return res_random, logprob_infos_random
+            return res_random, logprob_infos

     res = np.empty((num_seq,), dtype=np.int32)
     res[mask_random] = res_random

-    logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq
-    logprob_infos = update_masked_list(logprob_infos, mask_random, logprob_infos_random)
-
     if logits_greedy.shape[0] > 0:
         res[mask_greedy] = res_greedy

-        logprob_infos = update_masked_list(logprob_infos, mask_greedy, logprob_infos_greedy)
-
     torch.cuda.nvtx.range_pop()
     return res, logprob_infos

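Commentary on the sample() changes above (not part of the commit): logprob_infos is now allocated once as [None] * num_seq and both sampling branches write into it through get_masked_logprobs, replacing the old pattern of building logprob_infos_greedy / logprob_infos_random from sampling params filtered with filter_list_by_mask and scattering them back with update_masked_list. The toy sketch below shows only that bookkeeping; fill_masked and the string placeholders are stand-ins for get_masked_logprobs and the real RawLogprobsInfo entries.

    import torch

    def fill_masked(slots, mask, branch_results):
        # Stand-in for get_masked_logprobs: row i of the branch output belongs
        # to the i-th True position of the mask.
        row = 0
        for i, selected in enumerate(mask.tolist()):
            if selected:
                slots[i] = branch_results[row]
                row += 1
        return slots

    num_seq = 5
    mask_greedy = torch.tensor([True, False, True, False, False])
    mask_random = ~mask_greedy

    slots = [None] * num_seq                                     # like logprob_infos
    slots = fill_masked(slots, mask_greedy, ["g0", "g1"])        # greedy branch
    slots = fill_masked(slots, mask_random, ["r0", "r1", "r2"])  # random branch
    print(slots)  # ['g0', 'r0', 'g1', 'r1', 'r2']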
