Skip to content

Commit

Permalink
add min p
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Dec 12, 2024
1 parent 4a15d6a commit 13e43b5
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/text_utils/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,19 @@ def _nuc(
return sorted_logits.gather(-1, indices.argsort(-1))

return _nuc


def min_p_masking(min_p: float) -> LogitFn:
assert 0.0 <= min_p <= 1.0, "min_p must be in [0, 1]"

def _min_p(
_input_ids: torch.Tensor, logits: torch.Tensor, _: list[Beam]
) -> torch.Tensor:
masked_logits = torch.full_like(logits, float("-inf"))
probs = torch.softmax(logits, dim=-1)
min_probs = probs.max(dim=-1, keepdim=True)[0] * min_p
mask = probs >= min_probs
masked_logits[mask] = logits[mask]
return masked_logits

return _min_p

0 comments on commit 13e43b5

Please sign in to comment.