From 13e43b5061f0160f24f2a5e21e0d9c1a98d06f6e Mon Sep 17 00:00:00 2001
From: Sebastian Walter
Date: Thu, 12 Dec 2024 22:06:09 +0100
Subject: [PATCH] add min p

---
 python/text_utils/inference/utils.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/python/text_utils/inference/utils.py b/python/text_utils/inference/utils.py
index b0c7b31..cea0e93 100644
--- a/python/text_utils/inference/utils.py
+++ b/python/text_utils/inference/utils.py
@@ -286,3 +286,19 @@ def _nuc(
         return sorted_logits.gather(-1, indices.argsort(-1))
 
     return _nuc
+
+
+def min_p_masking(min_p: float) -> LogitFn:
+    assert 0.0 <= min_p <= 1.0, "min_p must be in [0, 1]"
+
+    def _min_p(
+        _input_ids: torch.Tensor, logits: torch.Tensor, _: list[Beam]
+    ) -> torch.Tensor:
+        masked_logits = torch.full_like(logits, float("-inf"))
+        probs = torch.softmax(logits, dim=-1)
+        min_probs = probs.max(dim=-1, keepdim=True)[0] * min_p
+        mask = probs >= min_probs
+        masked_logits[mask] = logits[mask]
+        return masked_logits
+
+    return _min_p
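
Usage sketch (not part of the patch): min_p_masking returns a LogitFn closure that can be applied to a batch of logits before sampling. The decoding step below is an illustrative assumption, not part of the text_utils API; only min_p_masking itself comes from the patch, and the import path assumes the package installs as text_utils. Tensor shapes and the empty beam list are placeholders.

import torch

from text_utils.inference.utils import min_p_masking

# Build the logit function with a 10% threshold: tokens whose probability
# is below 0.1 * p(most likely token) are set to -inf.
logit_fn = min_p_masking(0.1)

# Dummy decoding step: batch of 2, vocabulary of 32000 (placeholder sizes).
logits = torch.randn(2, 32000)
input_ids = torch.zeros(2, 1, dtype=torch.long)  # placeholder prefix ids
masked = logit_fn(input_ids, logits, [])  # the beam list is unused by _min_p

# Sampling from the renormalized distribution ignores the masked tokens.
probs = torch.softmax(masked, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)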