From 7b8180b20b6eb62b692b99c610dde80314930c61 Mon Sep 17 00:00:00 2001
From: "Kavulya, Soila P" <soila.p.kavulya@intel.com>
Date: Thu, 19 Dec 2024 16:04:16 -0800
Subject: [PATCH] Improve text generation quality for bf16 models when
 sampling

---
 optimum/habana/transformers/generation/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index d81e0d179a..7dbd06a160 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -2460,7 +2460,7 @@ def _sample(
             if token_idx is not None and outputs.logits.shape[-2] > 1:
                 # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size]
                 if self.config.is_encoder_decoder:
-                    next_token_logits = outputs.logits[:, token_idx - 1, :].float()
+                    next_token_logits = outputs.logits[:, token_idx - 1, :]
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
                 else:
                     if model_kwargs.get("num_virtual_tokens", 0) > 0:
@@ -2475,7 +2475,7 @@ def _sample(
                     next_token_scores = logits_processor(input_ids, next_token_logits)
             else:
                 # .float() is needed to retain precision for later logits manipulations
-                next_token_logits = outputs.logits[:, -1, :].float()
+                next_token_logits = outputs.logits[:, -1, :]
                 if token_idx is not None and self.config.is_encoder_decoder:
                     # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size]
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
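
Note for reviewers: below is a minimal, self-contained sketch (not part of
the diff above) of the dtype difference this change makes on the sampling
path. The random bf16 tensor, shapes, and the simplified softmax/multinomial
step are stand-ins; only the indexing and dtype handling mirror the patched
code.

    import torch

    # Stand-in for outputs.logits from a bf16 model:
    # shape [batch_size, seq_len, vocab_size] (values here are random).
    logits = torch.randn(1, 8, 32000, dtype=torch.bfloat16)

    # Before this patch: the last-position logits were upcast to float32
    # before the logits processors and sampling ran.
    next_token_logits_old = logits[:, -1, :].float()   # torch.float32

    # After this patch: the logits keep the dtype the model produced.
    next_token_logits_new = logits[:, -1, :]           # torch.bfloat16

    # Simplified downstream sampling: softmax, then a multinomial draw.
    # Probabilities are cast to float32 for multinomial, since bf16 support
    # for torch.multinomial is not guaranteed on every backend.
    probs = torch.softmax(next_token_logits_new, dim=-1)
    next_token = torch.multinomial(probs.float(), num_samples=1)
    print(next_token_logits_new.dtype, next_token.shape)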