From 7f5b0ccd46fa8227a4f303d75eb71e50093bc984 Mon Sep 17 00:00:00 2001 From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com> Date: Tue, 23 Jul 2024 18:15:34 +0200 Subject: [PATCH] Fix instruct models UI issue (#78) * feat(tgi): allow top_k = 0 and top_p = 1 when do_sample = True This might not be the most elegant solution, but it will allow the server to keep working when the web ui gives a request with these parameters for instruct models. * chore: update version to 0.1.4 --- optimum/tpu/generation/logits_process.py | 5 ----- optimum/tpu/version.py | 2 +- .../server/text_generation_server/version.py | 2 +- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/optimum/tpu/generation/logits_process.py b/optimum/tpu/generation/logits_process.py index 2e40c67f..9e8bd088 100644 --- a/optimum/tpu/generation/logits_process.py +++ b/optimum/tpu/generation/logits_process.py @@ -48,8 +48,6 @@ def from_config(cls, generation_config: GenerationConfig) -> "FusedLogitsWarper" Returns: a `FusedLogitsWarper` or None if neither top-k nor top-p are configured. """ - if generation_config.do_sample and generation_config.top_k == 0 and generation_config.top_p == 1.0: - raise ValueError("Multinomial sampling requires at least top-k or top-p to be specified.") return cls(generation_config.temperature, generation_config.top_k, generation_config.top_p) def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]: @@ -59,9 +57,6 @@ def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch. do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1] do_top_p = self.top_p < 1.0 and self.top_p > 0.0 - if not do_top_k and not do_top_p: - return logits, None - if do_top_k: sorted_logits, sorted_indices = torch.topk(logits, self.top_k) else: diff --git a/optimum/tpu/version.py b/optimum/tpu/version.py index 4bd3c7ec..6078244a 100644 --- a/optimum/tpu/version.py +++ b/optimum/tpu/version.py @@ -15,5 +15,5 @@ from pkg_resources import parse_version -__version__ = "0.1.3" +__version__ = "0.1.4" VERSION = parse_version(__version__) diff --git a/text-generation-inference/server/text_generation_server/version.py b/text-generation-inference/server/text_generation_server/version.py index 9ee11539..16913b8c 100644 --- a/text-generation-inference/server/text_generation_server/version.py +++ b/text-generation-inference/server/text_generation_server/version.py @@ -1,5 +1,5 @@ from pkg_resources import parse_version -__version__ = "0.1.3" +__version__ = "0.1.4" VERSION = parse_version(__version__)