Skip to content

Commit

Permalink
update sample top k
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Mar 9, 2024
1 parent 2856695 commit 0b744d3
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions python/text_utils/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,13 @@ def sample_select_fn(sample_top_k: int) -> IdxSelectFn:
def _sample(scores: torch.Tensor, _: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
assert scores.ndim == 2
k = min(sample_top_k, scores.shape[-1])
sampled_indices = torch.randint(
k,
(len(scores), 1),
device=scores.device
)
top_k = torch.topk(scores, k, dim=-1)
values = top_k.values
values /= top_k.values.sum(dim=-1, keepdim=True)
sampled = torch.multinomial(values, 1)
sampled_indices = torch.gather(top_k.indices, -1, sampled)
indices = torch.gather(top_k.indices, -1, sampled_indices).squeeze(-1)
scores = torch.gather(top_k.values, -1, sampled_indices).squeeze(-1)
scores = torch.gather(top_k.values, -1, indices).squeeze(-1)
return indices, scores

return _sample
Expand Down

0 comments on commit 0b744d3

Please sign in to comment.