Skip to content

Commit

Permalink
Fixing tests by making static_shapes False (huggingface#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
bhargaveede authored Mar 10, 2024
1 parent 2f55de3 commit 639c21b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
5 changes: 0 additions & 5 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,6 @@ def _update_model_kwargs_for_generation(
model_kwargs["attention_mask"] = attention_mask
else:
# update decoder attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
if token_idx is not None:
attention_mask.index_fill_(1, token_idx, 1)
model_kwargs["attention_mask"] = attention_mask
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
if token_idx is not None:
Expand Down
10 changes: 9 additions & 1 deletion tests/transformers/tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ def _get_encoder_outputs(
attention_mask = None
return encoder_outputs, input_ids, attention_mask

@staticmethod
def _get_static_shapes():
    # Test hook: forces generation_config.static_shapes to False so the
    # shared generation tests run with dynamic shapes (per this commit's
    # intent, "Fixing tests by making static_shapes False", huggingface#778).
    # Subclasses can presumably override this to re-enable static shapes —
    # TODO(review): confirm against callers in the test suite.
    return False

def _greedy_generate(
self,
model,
Expand All @@ -277,7 +281,7 @@ def _greedy_generate(

kwargs = {}
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}

model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
Expand Down Expand Up @@ -337,6 +341,7 @@ def _sample_generate(
torch.manual_seed(0)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
self._update_default_model_kwargs(model_kwargs)
model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=True,
Expand Down Expand Up @@ -406,6 +411,7 @@ def _beam_search_generate(
):
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
self._update_default_model_kwargs(model_kwargs)
model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
Expand Down Expand Up @@ -603,6 +609,7 @@ def _constrained_beam_search_generate(
):
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
self._update_default_model_kwargs(model_kwargs)
model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
Expand Down Expand Up @@ -679,6 +686,7 @@ def _contrastive_generate(
kwargs = {}
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
self._update_default_model_kwargs(model_kwargs)
model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
Expand Down

0 comments on commit 639c21b

Please sign in to comment.