From 0b744d349ad47e0c19048689cbd9fa54f5c76313 Mon Sep 17 00:00:00 2001 From: Sebastian Walter Date: Sat, 9 Mar 2024 15:47:39 +0100 Subject: [PATCH] update sample top k --- python/text_utils/inference/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/text_utils/inference/__init__.py b/python/text_utils/inference/__init__.py index dafc4fa..81ad861 100644 --- a/python/text_utils/inference/__init__.py +++ b/python/text_utils/inference/__init__.py @@ -132,14 +132,13 @@ def sample_select_fn(sample_top_k: int) -> IdxSelectFn: def _sample(scores: torch.Tensor, _: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: assert scores.ndim == 2 k = min(sample_top_k, scores.shape[-1]) - sampled_indices = torch.randint( - k, - (len(scores), 1), - device=scores.device - ) top_k = torch.topk(scores, k, dim=-1) + values = top_k.values + values /= top_k.values.sum(dim=-1, keepdim=True) + sampled = torch.multinomial(values, 1) + sampled_indices = torch.gather(top_k.indices, -1, sampled) indices = torch.gather(top_k.indices, -1, sampled_indices).squeeze(-1) - scores = torch.gather(top_k.values, -1, sampled_indices).squeeze(-1) + scores = torch.gather(top_k.values, -1, indices).squeeze(-1) return indices, scores return _sample