Generation utils update (minor) (#1468)
yafshar authored Dec 8, 2024
1 parent 899b364 commit e3b32d8
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions optimum/habana/transformers/generation/utils.py
@@ -70,7 +70,7 @@

 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from transformers.streamers import BaseStreamer
+    from transformers.generation.streamers import BaseStreamer
     from transformers.tokenization_utils_base import PreTrainedTokenizerBase

     from .candidate_generator import GaudiCandidateGenerator
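Because this import sits under if TYPE_CHECKING:, the old, incorrect path never failed at runtime; it only affected static type checkers. A quick check of the corrected path, not part of the diff and assuming a recent transformers release where the streamer classes live in transformers.generation.streamers:

    from transformers.generation.streamers import BaseStreamer, TextStreamer  # resolves

    print(BaseStreamer, TextStreamer)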
@@ -178,12 +178,12 @@ def _prepare_decoder_attention_mask(
         self,
         max_steps: int,  # current stopping criteria
         batch_size: int,
-        pad_token_id: int,
-        device: str,
-        dtype: str = bool,
+        device: Union[str, torch.device],
+        dtype: torch.dtype = torch.bool,
     ) -> torch.Tensor:
-        x = torch.zeros((batch_size, max_steps), device=device, dtype=dtype)
-        return x.index_fill(1, torch.tensor(0), 1)  # First the position with pad_token_id
+        decoder_attention_mask = torch.zeros((batch_size, max_steps), device=device, dtype=dtype)
+        index = torch.tensor(0, device=device)
+        return decoder_attention_mask.index_fill(1, index, 1)  # First position with 1

     def _prepare_decoder_input_ids_for_generation(
         self,
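The reworked helper no longer takes a pad_token_id argument: it allocates a (batch_size, max_steps) tensor of zeros and marks only the first position. A small sketch of the resulting mask, with toy sizes that are not taken from the diff:

    import torch

    batch_size, max_steps = 2, 4  # illustrative values
    mask = torch.zeros((batch_size, max_steps), dtype=torch.bool)
    mask = mask.index_fill(1, torch.tensor(0), 1)  # set only the first position
    print(mask)
    # tensor([[ True, False, False, False],
    #         [ True, False, False, False]])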
@@ -337,32 +337,37 @@ def _expand_dict_for_generation(dict_to_expand):
         return input_ids, model_kwargs

     def _pad_past_key_values(self, model_kwargs):
+        # Early return if no past key values to pad
+        past_key_values = model_kwargs.get("past_key_values")
+        if not past_key_values:
+            return
+
+        # Determine if the model is MQA or not
+        is_mqa_model = model_kwargs.get("mqa_model", False)
+        lazy_mode = model_kwargs.get("lazy_mode", False)
         pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
         kv_cache_len = model_kwargs.get("kv_cache_len", 0)
-        if model_kwargs["past_key_values"]:
-            if model_kwargs.get("mqa_model", False):
-                for i in range(len(model_kwargs["past_key_values"])):  # layer
-                    if (
-                        torch.is_tensor(model_kwargs["past_key_values"][i])
-                        and model_kwargs["past_key_values"][i].shape[-2] == kv_cache_len - pad_amount
-                    ):  # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked
-                        model_kwargs["past_key_values"][i] = torch.nn.functional.pad(
-                            model_kwargs["past_key_values"][i], (0, 0, 0, pad_amount)
-                        )
-                        if model_kwargs.get("lazy_mode", False):
-                            self.htcore_generation.mark_step()
+        kv_cache_len_pad_amount = kv_cache_len - pad_amount
+
+        # For MQA models, past_key_values is a tensor
+        if is_mqa_model:
+            for i, layer in enumerate(past_key_values):  # Iterate over layers
+                if torch.is_tensor(layer) and layer.shape[-2] == kv_cache_len_pad_amount:
+                    # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked
+                    past_key_values[i] = torch.nn.functional.pad(layer, (0, 0, 0, pad_amount))
+                    # Mark step if lazy mode is enabled
+                    if lazy_mode:
+                        self.htcore_generation.mark_step()
+        # For Non-MQA models, the past_key_values is a list of lists (k and v)
+        else:
+            for i, layer in enumerate(past_key_values):  # Iterate over layers
+                for j, k_or_v in enumerate(layer):  # Iterate over k and v
+                    if torch.is_tensor(k_or_v) and k_or_v.shape[-2] == kv_cache_len_pad_amount:
+                        # tensor(batch_size, n_heads, kv_cache_len, head_dim)
+                        past_key_values[i][j] = torch.nn.functional.pad(k_or_v, (0, 0, 0, pad_amount))
+                        # Mark step if lazy mode is enabled
+                        if lazy_mode:
+                            self.htcore_generation.mark_step()
-            else:
-                for i in range(len(model_kwargs["past_key_values"])):  # layer
-                    for j in range(len(model_kwargs["past_key_values"][i])):  # k or v
-                        if (
-                            torch.is_tensor(model_kwargs["past_key_values"][i][j])
-                            and model_kwargs["past_key_values"][i][j].shape[-2] == kv_cache_len - pad_amount
-                        ):  # tensor(batch_size, n_heads, kv_cache_len, head_dim)
-                            model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad(
-                                model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount)
-                            )
-                            if model_kwargs.get("lazy_mode", False):
-                                self.htcore_generation.mark_step()

     def _remove_past_key_values(self, model_kwargs):
         if model_kwargs["past_key_values"]:
@@ -1164,7 +1169,6 @@ def generate(
                 model_kwargs["decoder_attention_mask"] = self._prepare_decoder_attention_mask(
                     max_length,
                     inputs_tensor.shape[0],
-                    generation_config.pad_token_id,
                     inputs_tensor.device,
                 )

