From 7b8180b20b6eb62b692b99c610dde80314930c61 Mon Sep 17 00:00:00 2001
From: "Kavulya, Soila P" <soila.p.kavulya@intel.com>
Date: Thu, 19 Dec 2024 16:04:16 -0800
Subject: [PATCH] Improve text generation quality for bf16 models when sampling

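Keep the next-token logits produced in _sample() in the model's native
dtype (e.g. bf16) instead of upcasting them to float32 before the
logits processors run. This improves the quality of sampled text for
bf16 models.

A minimal sketch of how to exercise the changed path (the model name
and sampling parameters below are illustrative only; the example uses
the stock transformers generate() API, while on Gaudi the
optimum-habana generation overrides apply instead):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)

    inputs = tokenizer("The weather today is", return_tensors="pt")
    # do_sample=True routes generation through _sample(), where the
    # next-token logits now stay in bf16 rather than being cast to fp32.
    output = model.generate(**inputs, do_sample=True, temperature=0.7, max_new_tokens=20)
    print(tokenizer.decode(output[0], skip_special_tokens=True))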
---
 optimum/habana/transformers/generation/utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index d81e0d179a..7dbd06a160 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -2460,7 +2460,7 @@ def _sample(
             if token_idx is not None and outputs.logits.shape[-2] > 1:
                 # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size]
                 if self.config.is_encoder_decoder:
-                    next_token_logits = outputs.logits[:, token_idx - 1, :].float()
+                    next_token_logits = outputs.logits[:, token_idx - 1, :]
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
                 else:
                     if model_kwargs.get("num_virtual_tokens", 0) > 0:
@@ -2475,7 +2475,7 @@ def _sample(
                     next_token_scores = logits_processor(input_ids, next_token_logits)
             else:
-                # .float() is needed to retain precision for later logits manipulations
-                next_token_logits = outputs.logits[:, -1, :].float()
+                # keep the logits in the model's dtype (e.g. bf16) instead of upcasting to float32
+                next_token_logits = outputs.logits[:, -1, :]
                 if token_idx is not None and self.config.is_encoder_decoder:
                     # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size]
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)