Improve text generation quality for bf16 models when sampling
skavulya committed Dec 20, 2024
1 parent a51475f · commit 7b8180b
Showing 1 changed file with 2 additions and 2 deletions.
optimum/habana/transformers/generation/utils.py (4 changes: 2 additions & 2 deletions)
@@ -2460,7 +2460,7 @@ def _sample(
             if token_idx is not None and outputs.logits.shape[-2] > 1:
                 # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size]
                 if self.config.is_encoder_decoder:
-                    next_token_logits = outputs.logits[:, token_idx - 1, :]
+                    next_token_logits = outputs.logits[:, token_idx - 1, :].float()
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
                 else:
                     if model_kwargs.get("num_virtual_tokens", 0) > 0:
@@ -2475,7 +2475,7 @@
                     next_token_scores = logits_processor(input_ids, next_token_logits)
             else:
                 # .float() is needed to retain precision for later logits manipulations
-                next_token_logits = outputs.logits[:, -1, :]
+                next_token_logits = outputs.logits[:, -1, :].float()
                 if token_idx is not None and self.config.is_encoder_decoder:
                     # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size]
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
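Why the cast matters: bf16 carries only 8 exponent and 7 mantissa bits (roughly 3 significant decimal digits), so performing the later logits manipulations directly on bf16 values (e.g., temperature scaling and top-p filtering in logits_processor, then the softmax feeding torch.multinomial) distorts the distribution being sampled from. Upcasting first is lossless, since every bf16 value is exactly representable in fp32. The following standalone sketch (illustrative only, not part of the commit; the vocabulary size is made up) shows the drift:

# Standalone illustration: softmax computed in bf16 vs. after an fp32 upcast.
import torch

torch.manual_seed(0)
vocab_size = 32000  # hypothetical vocabulary size
logits_bf16 = torch.randn(vocab_size).to(torch.bfloat16)  # stand-in for bf16 model logits

# Without the fix: the whole sampling distribution is computed in bf16.
probs_bf16 = torch.softmax(logits_bf16, dim=-1)

# With the fix: upcast first (exact for bf16 inputs), then compute in fp32.
probs_fp32 = torch.softmax(logits_bf16.float(), dim=-1)

# The bf16 distribution drifts from the fp32 reference; multinomial sampling
# then draws from a slightly distorted distribution at every decoding step.
drift = (probs_bf16.float() - probs_fp32).abs().max()
print(f"max per-token probability drift: {drift.item():.3e}")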
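For context, here is a minimal sketch of a run that exercises the changed code: a causal LM loaded in bf16 and decoded with sampling through optimum-habana's Gaudi-adapted generation loop. The model name and sampling settings are illustrative, and the snippet assumes Intel Gaudi (HPU) hardware with the Habana PyTorch bridge installed:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Patch transformers with the Gaudi generation utilities, including the
# _sample path touched by this commit.
adapt_transformers_to_gaudi()

model_name = "gpt2"  # illustrative; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("hpu")

inputs = tokenizer("The weather today is", return_tensors="pt").to("hpu")
# do_sample=True routes token selection through _sample, where the next-token
# logits are now upcast to fp32 before temperature/top-p and the softmax/multinomial draw.
output = model.generate(**inputs, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))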
