From 2ba520a9f12467bfeb25a8fde4dee7caf27c0067 Mon Sep 17 00:00:00 2001 From: Gustavo Malkomes Date: Wed, 11 Dec 2024 11:12:04 -0600 Subject: [PATCH] Update transformers tests generation util v4.45.2 (#1441) Co-authored-by: Gustavo Co-authored-by: Yaser Afshar Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- conftest.py | 141 + .../habana/transformers/generation/utils.py | 29 +- .../transformers/models/bart/modeling_bart.py | 4 +- pyproject.toml | 9 + .../generation/test_framework_agnostic.py | 43 +- .../tests/generation/test_utils.py | 3975 ++++++++++------- 6 files changed, 2611 insertions(+), 1590 deletions(-) diff --git a/conftest.py b/conftest.py index 71cb6bb7ca..5775644c48 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,88 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# tests directory-specific settings - this file is run automatically +# by pytest before any tests are run +import doctest +import sys +import warnings +from os.path import abspath, dirname, join + +import _pytest +import pytest +from transformers.testing_utils import HfDoctestModule, HfDocTestParser + + +NOT_DEVICE_TESTS = { + "test_tokenization", + "test_processor", + "test_processing", + "test_beam_constraints", + "test_configuration_utils", + "test_data_collator", + "test_trainer_callback", + "test_trainer_utils", + "test_feature_extraction", + "test_image_processing", + "test_image_processor", + "test_image_transforms", + "test_optimization", + "test_retrieval", + "test_config", + "test_from_pretrained_no_checkpoint", + "test_keep_in_fp32_modules", + "test_gradient_checkpointing_backward_compatibility", + "test_gradient_checkpointing_enable_disable", + "test_save_load_fast_init_from_base", + "test_fast_init_context_manager", + "test_fast_init_tied_embeddings", + "test_save_load_fast_init_to_base", + "test_torch_save_load", + "test_initialization", + "test_forward_signature", + "test_model_get_set_embeddings", + "test_model_main_input_name", + "test_correct_missing_keys", + "test_tie_model_weights", + "test_can_use_safetensors", + "test_load_save_without_tied_weights", + "test_tied_weights_keys", + "test_model_weights_reload_no_missing_tied_weights", + "test_pt_tf_model_equivalence", + "test_mismatched_shapes_have_properly_initialized_weights", + "test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist", + "test_model_is_small", + "test_tf_from_pt_safetensors", + "test_flax_from_pt_safetensors", + "ModelTest::test_pipeline_", # None of the pipeline tests from PipelineTesterMixin (of which XxxModelTest inherits from) are running on device + "ModelTester::test_pipeline_", + "/repo_utils/", + "/utils/", + "/agents/", +} + +# allow having multiple repository checkouts and not needing to remember to rerun +# `pip install -e '.[dev]'` when switching between checkouts and running tests. +git_repo_path = abspath(join(dirname(__file__), "src")) +sys.path.insert(1, git_repo_path) + +# silence FutureWarning warnings in tests since often we can't act on them until +# they become normal warnings - i.e. the tests still need to test the current functionality +warnings.simplefilter(action="ignore", category=FutureWarning) + + class Secret: """ Taken from: https://stackoverflow.com/a/67393351 @@ -13,9 +98,47 @@ def __str___(self): return "*******" +def pytest_configure(config): + config.addinivalue_line( + "markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested" + ) + config.addinivalue_line( + "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested" + ) + config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested") + config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment") + config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate") + config.addinivalue_line("markers", "agent_tests: mark the agent tests that are run on their specific schedule") + config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu") + + +def pytest_collection_modifyitems(items): + for item in items: + if any(test_name in item.nodeid for test_name in NOT_DEVICE_TESTS): + item.add_marker(pytest.mark.not_device_test) + + def pytest_addoption(parser): parser.addoption("--token", action="store", default=None) + from transformers.testing_utils import pytest_addoption_shared + + pytest_addoption_shared(parser) + + +def pytest_terminal_summary(terminalreporter): + from transformers.testing_utils import pytest_terminal_summary_main + + make_reports = terminalreporter.config.getoption("--make-reports") + if make_reports: + pytest_terminal_summary_main(terminalreporter, id=make_reports) + + +def pytest_sessionfinish(session, exitstatus): + # If no tests are collected, pytest exists with code 5, which makes the CI fail. + if exitstatus == 5: + session.exitstatus = 0 + def pytest_generate_tests(metafunc): # This is called for every test. Only get/set command line arguments @@ -23,3 +146,21 @@ def pytest_generate_tests(metafunc): option_value = Secret(metafunc.config.option.token) if "token" in metafunc.fixturenames: metafunc.parametrize("token", [option_value]) + + +# Doctest custom flag to ignore output. +IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT") + +OutputChecker = doctest.OutputChecker + + +class CustomOutputChecker(OutputChecker): + def check_output(self, want, got, optionflags): + if IGNORE_RESULT & optionflags: + return True + return OutputChecker.check_output(self, want, got, optionflags) + + +doctest.OutputChecker = CustomOutputChecker +_pytest.doctest.DoctestModule = HfDoctestModule +doctest.DocTestParser = HfDocTestParser diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 68b445c1b2..d81e0d179a 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -211,19 +211,20 @@ def _prepare_decoder_input_ids_for_generation( # 2. `decoder_start_token_id` must have shape (batch_size, 1) if device is None: device = self.device - if token_idx is None: - if decoder_start_token_id.ndim == 1: - if decoder_start_token_id.shape[0] != batch_size: - raise ValueError( - f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" - ) - decoder_start_token_id = decoder_start_token_id.view(-1, 1) - else: - decoder_start_token_id = ( - torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + if decoder_start_token_id.ndim == 1: + if decoder_start_token_id.shape[0] != batch_size: + raise ValueError( + f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" ) + decoder_start_token_id = decoder_start_token_id.view(-1, 1) else: - # creating padded decoder_input_ids to achieve static shapes. Later new tokens once generated are copied in to decoder_input_ids based on token_idx + decoder_start_token_id = ( + torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + ) + + if token_idx is not None: + # creating padded decoder_input_ids to achieve static shapes. + # Later new tokens once generated are copied in to decoder_input_ids based on token_idx max_length = max_new_tokens + 1 if max_new_tokens is not None else self.generation_config.max_length decoder_start_token_id = ( torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id @@ -3039,7 +3040,8 @@ def expand_if_needed(tensor, new_size, value, dim=-1): if self.generation_config.early_stopping: num_eos_tokens.add_(beam_tokens[0:num_beams].eq(self.config.eos_token_id).sum()) - beam_scores.add_(torch.where(beam_tokens.eq(self.config.eos_token_id), float("-inf"), 0.0)) + if self.config.eos_token_id is not None: + beam_scores.add_(torch.where(beam_tokens.eq(self.config.eos_token_id), float("-inf"), 0.0)) beam_scores = beam_scores.view(batch_size, -1).unsqueeze(0) _, selected = torch.topk(beam_scores, k=num_beams, dim=-1, largest=True, sorted=True) offset = torch.arange(0, torch.numel(beam_scores), beam_scores.shape[-1]).unsqueeze(-1) @@ -3211,6 +3213,9 @@ def move(obj, device): if not output_scores: sequence_outputs["sequence_scores"] = None + if self.generation_config.static_shapes: + raise NotImplementedError("sequence_scores is not implemented for static_shapes") + if self.config.is_encoder_decoder: return GenerateBeamEncoderDecoderOutput( sequences=sequence_outputs["sequences"], diff --git a/optimum/habana/transformers/models/bart/modeling_bart.py b/optimum/habana/transformers/models/bart/modeling_bart.py index 3e5f822cb1..08ea48e1a5 100644 --- a/optimum/habana/transformers/models/bart/modeling_bart.py +++ b/optimum/habana/transformers/models/bart/modeling_bart.py @@ -458,7 +458,9 @@ def gaudi_BartDecoder_forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - tensor_past_key_values_length = token_idx - 1 if use_cache else torch.tensor(past_key_values_length) + tensor_past_key_values_length = ( + token_idx - 1 if (use_cache and token_idx is not None) else torch.tensor(past_key_values_length) + ) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) diff --git a/pyproject.toml b/pyproject.toml index b7896da5e8..f53b25d1c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,3 +41,12 @@ skip-magic-trailing-comma = false # Like Black, automatically detect the appropriate line ending. line-ending = "auto" + +[tool.pytest.ini_options] +addopts = "--doctest-glob='**/*.md'" +doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS" +markers = [ + "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", + "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", + "generate: marks tests that use the GenerationTesterMixin" +] diff --git a/tests/transformers/tests/generation/test_framework_agnostic.py b/tests/transformers/tests/generation/test_framework_agnostic.py index 7fcc4de752..906a90a95a 100644 --- a/tests/transformers/tests/generation/test_framework_agnostic.py +++ b/tests/transformers/tests/generation/test_framework_agnostic.py @@ -3,8 +3,12 @@ """ import numpy as np +import pytest from transformers import AutoTokenizer -from transformers.testing_utils import slow, torch_device +from transformers.testing_utils import slow + + +torch_device = "hpu" class GenerationIntegrationTestsMixin: @@ -46,6 +50,8 @@ def test_validate_generation_inputs(self): valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))} model.generate(input_ids, **valid_model_kwargs) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_custom_logits_processor(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] logits_processor_list_cls = self.framework_dependent_parameters["LogitsProcessorList"] @@ -66,6 +72,8 @@ def test_custom_logits_processor(self): bart_model.config.min_length = None bart_model.generate(input_ids, logits_processor=logits_processor) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_max_new_tokens_encoder_decoder(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -222,6 +230,8 @@ def test_transition_scores_greedy_search_normalized(self): ) self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3)) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_transition_scores_beam_search_encoder_decoder(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -257,6 +267,8 @@ def test_transition_scores_beam_search_encoder_decoder(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_transition_scores_beam_search_encoder_decoder_with_eos(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -291,6 +303,8 @@ def test_transition_scores_beam_search_encoder_decoder_with_eos(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_transition_scores_beam_search_decoder_only(self): model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -328,6 +342,8 @@ def test_transition_scores_beam_search_decoder_only(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_transition_scores_beam_sample_encoder_decoder(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -365,6 +381,7 @@ def test_transition_scores_beam_sample_encoder_decoder(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) @slow + @pytest.mark.skip("Not Implemented: sequence_scores is not implemented for static_shapes") def test_transition_scores_early_stopping(self): # This is an aggressive test that makes sure that `beam_search's` # transition scores are computed correctly for varying `num_return_sequences`, `num_beams` and `batch_size > 1` @@ -400,6 +417,8 @@ def test_transition_scores_early_stopping(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores)) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_encoder_decoder_generate_attention_mask(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -501,6 +520,8 @@ def test_generate_too_many_encoder_kwargs(self): with self.assertRaises(ValueError): model.generate(input_ids=input_ids, inputs_embeds=input_ids) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_generate_input_features_as_encoder_kwarg(self): model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"] floats_tensor = self.framework_dependent_parameters["floats_tensor"] @@ -542,6 +563,8 @@ def test_generate_pixel_values_as_encoder_kwarg(self): self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs)) self.assertEqual(output_sequences.shape, (2, 5)) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_generate_encoder_outputs_attention_mask(self): model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"] floats_tensor = self.framework_dependent_parameters["floats_tensor"] @@ -576,7 +599,6 @@ def test_eos_token_id_int_and_list_greedy_search(self): "do_sample": False, "num_beams": 1, } - expectation = 13 tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") text = """Hello, my dog is cute and""" @@ -586,6 +608,7 @@ def test_eos_token_id_int_and_list_greedy_search(self): model = model.to(torch_device) tokens = tokens.to(torch_device) + expectation = model.config.max_length # static shape should give max_length eos_token_id = 873 generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0])) @@ -605,7 +628,6 @@ def test_eos_token_id_int_and_list_contrastive_search(self): "penalty_alpha": 0.6, "top_k": 4, } - expectation = 17 tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") text = """Hello, my dog is cute and""" @@ -615,6 +637,7 @@ def test_eos_token_id_int_and_list_contrastive_search(self): model = model.to(torch_device) tokens = tokens.to(torch_device) + expectation = model.config.max_length # static shape should give max_length eos_token_id = 225 generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0])) @@ -623,6 +646,8 @@ def test_eos_token_id_int_and_list_contrastive_search(self): generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0])) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_eos_token_id_int_and_list_beam_search(self): model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -648,7 +673,10 @@ def test_eos_token_id_int_and_list_beam_search(self): padded_correct_condition = expectation < len(generated_tokens[0]) and all( token == model.config.pad_token_id for token in generated_tokens[0][expectation:] ) - self.assertTrue(unpadded_correct_condition or padded_correct_condition) + static_shape_condition = expectation < len(generated_tokens[0]) and all( + token == eos_token_id for token in generated_tokens[0][expectation:] + ) + self.assertTrue(unpadded_correct_condition or padded_correct_condition or static_shape_condition) eos_token_id = [873, 198] generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) @@ -656,8 +684,13 @@ def test_eos_token_id_int_and_list_beam_search(self): padded_correct_condition = expectation < len(generated_tokens[0]) and all( token == model.config.pad_token_id for token in generated_tokens[0][expectation:] ) - self.assertTrue(unpadded_correct_condition or padded_correct_condition) + static_shape_condition = expectation < len(generated_tokens[0]) and all( + token in eos_token_id for token in generated_tokens[0][expectation:] + ) + self.assertTrue(unpadded_correct_condition or padded_correct_condition or static_shape_condition) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_generate_vision2text_conditioning(self): model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"] floats_tensor = self.framework_dependent_parameters["floats_tensor"] diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py index 512935e9dd..954bcd14d5 100644 --- a/tests/transformers/tests/generation/test_utils.py +++ b/tests/transformers/tests/generation/test_utils.py @@ -14,14 +14,27 @@ # limitations under the License. +import copy import inspect +import tempfile import unittest import warnings import numpy as np import pytest -from transformers import is_torch_available, pipeline -from transformers.testing_utils import require_torch, slow +from parameterized import parameterized +from transformers import is_torch_available, pipeline, set_seed +from transformers.testing_utils import ( + is_flaky, + require_accelerate, + require_auto_gptq, + require_quanto, + require_torch, + require_torch_gpu, + require_torch_multi_accelerator, + require_torch_multi_gpu, + slow, +) from optimum.habana.checkpoint_utils import model_is_optimized from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi @@ -32,54 +45,50 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, + AutoProcessor, AutoTokenizer, + BartForCausalLM, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer, ImageGPTForCausalImageModeling, - PreTrainedModel, SpeechEncoderDecoderModel, + T5ForConditionalGeneration, ) + from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, - BeamSearchScorer, - ConstrainedBeamSearchScorer, DisjunctiveConstraint, - ForcedBOSTokenLogitsProcessor, - ForcedEOSTokenLogitsProcessor, GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, + GenerationConfig, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput, - HammingDiversityLogitsProcessor, LogitsProcessorList, MaxLengthCriteria, MinLengthLogitsProcessor, - NoBadWordsLogitsProcessor, - NoRepeatNGramLogitsProcessor, PhrasalConstraint, - RepetitionPenaltyLogitsProcessor, + PromptLookupCandidateGenerator, SampleDecoderOnlyOutput, SampleEncoderDecoderOutput, StoppingCriteria, StoppingCriteriaList, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, + WatermarkDetector, + WatermarkingConfig, ) - from transformers.generation.candidate_generator import AssistedCandidateGenerator, CandidateGenerator - from transformers.generation.streamers import BaseStreamer + from transformers.generation.utils import _speculative_sampling torch_device = "hpu" adapt_transformers_to_gaudi() @@ -91,116 +100,84 @@ class GenerationTesterMixin: input_name = "input_ids" max_new_tokens = 3 - def _update_default_model_kwargs(self, model_kwargs): - model_kwargs["limit_hpu_graphs"] = False - model_kwargs["reuse_cache"] = False - model_kwargs["bucket_size"] = -1 - def _get_input_ids_and_config(self, batch_size=2): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - input_ids = inputs_dict[self.input_name] + # TODO: @raushan or @gante, use `model.main_input_name` as the main input instead of relyinn on `input_ids` + input_ids = inputs_dict.pop(self.input_name)[:batch_size, :] + inputs_dict.pop("attention_mask", None) + + # we don't want encoder-decoder models to start from filled decoder ids + inputs_dict.pop("decoder_input_ids", None) + inputs_dict.pop("decoder_attention_mask", None) # cut to half length & take max batch_size 3 sequence_length = input_ids.shape[-1] // 2 input_ids = input_ids[:batch_size, :sequence_length] - # generate max 3 tokens - max_length = input_ids.shape[-1] + 3 + # we'll set cache use in each test differently + inputs_dict.pop("use_cache", None) + + inputs_dict = { + k: v[:batch_size, ...] + for k, v in inputs_dict.items() + if "head_mask" not in k and isinstance(v, torch.Tensor) + } if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` if isinstance(config.eos_token_id, int): config.eos_token_id = [config.eos_token_id] config.pad_token_id = config.eos_token_id[0] - # TransfoXL has no attention mask - if "transfoxl" in config.__class__.__name__.lower(): - attention_mask = None + + if self.has_attentions: + attention_mask = torch.ones_like(input_ids, dtype=torch.long) else: - attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length] - - return config, input_ids, attention_mask, max_length - - @staticmethod - def _get_logits_processor_and_kwargs( - input_length, - eos_token_id, - forced_bos_token_id=None, - forced_eos_token_id=None, - max_length=None, - diversity_penalty=None, - ): - process_kwargs = { - "min_length": input_length + 1 if max_length is None else max_length - 1, + attention_mask = None + + # It is important set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated + config.eos_token_id = None + config.forced_eos_token_id = None + + return config, input_ids, attention_mask, inputs_dict + + def _get_logits_processor_kwargs(self, do_sample=False, config=None): + logits_processor_kwargs = { "bad_words_ids": [[1, 0]], - "no_repeat_ngram_size": 2, "repetition_penalty": 1.2, + "remove_invalid_values": True, } - logits_processor = LogitsProcessorList( - ( - [ - HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2), - ] - if diversity_penalty is not None - else [] - ) - + ( - [ - MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id), - ] - if eos_token_id is not None - else [] + if do_sample: + logits_processor_kwargs.update( + { + "top_k": 10, + "top_p": 0.7, + "temperature": 0.7, + } ) - + ( - [ - ForcedBOSTokenLogitsProcessor(forced_bos_token_id), - ] - if forced_bos_token_id is not None - else [] - ) - + ( - [ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)] - if forced_eos_token_id is not None - else [] - ) - + [ - NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id), - NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]), - RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]), - ] - ) - return process_kwargs, logits_processor - - @staticmethod - def _get_warper_and_kwargs(num_beams): - warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7} - logits_warper = LogitsProcessorList( - [ - TemperatureLogitsWarper(warp_kwargs["temperature"]), - TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), - TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), - ] - ) - return warp_kwargs, logits_warper - - @staticmethod - def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): + # TODO (joao, raushan): see this comment for a long-term fix + # https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264) + # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them + # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens. + if config is not None: + image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None + video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None + if image_token_index is not None and image_token_index < config.get_text_config().vocab_size: + logits_processor_kwargs["bad_words_ids"].append([image_token_index]) + if video_token_index is not None and video_token_index < config.get_text_config().vocab_size: + logits_processor_kwargs["bad_words_ids"].append([video_token_index]) + + return logits_processor_kwargs + + def _get_beam_kwargs(self, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, "num_beams": 2, "num_return_sequences": num_return_sequences, } - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=beam_kwargs["num_beams"], - device=torch_device, - length_penalty=beam_kwargs["length_penalty"], - do_early_stopping=beam_kwargs["early_stopping"], - num_beam_hyps_to_keep=num_return_sequences, - ) - return beam_kwargs, beam_scorer + return beam_kwargs - @staticmethod - def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): + def _get_diverse_beam_kwargs(self, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, @@ -209,93 +186,46 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque "num_beam_groups": 2, # one beam per group "diversity_penalty": 2.0, } - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=beam_kwargs["num_beams"], - device=torch_device, - length_penalty=beam_kwargs["length_penalty"], - do_early_stopping=beam_kwargs["early_stopping"], - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=beam_kwargs["num_beam_groups"], - ) - return beam_kwargs, beam_scorer + return beam_kwargs - @staticmethod - def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1): + def _get_constrained_beam_kwargs(self, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, "num_beams": num_return_sequences * 4, "num_return_sequences": num_return_sequences, } - beam_scorer = ConstrainedBeamSearchScorer( - batch_size=batch_size, - constraints=constraints, - num_beams=beam_kwargs["num_beams"], - device=torch_device, - length_penalty=beam_kwargs["length_penalty"], - do_early_stopping=beam_kwargs["early_stopping"], - num_beam_hyps_to_keep=num_return_sequences, - ) - return beam_kwargs, beam_scorer - - @staticmethod - def _get_encoder_outputs( - model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1 - ): - encoder = model.get_encoder() - encoder_outputs = encoder( - input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( - num_interleave, dim=0 - ) - input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id() - attention_mask = None - return encoder_outputs, input_ids, attention_mask - - @staticmethod - def _get_static_shapes(): - return False + return beam_kwargs def _greedy_generate( self, model, input_ids, attention_mask, - max_length, + inputs_dict, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): - if model.config.is_encoder_decoder: - max_length = 4 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - eos_token_id=model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, - ) - + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, + output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, - **logits_process_kwargs, + use_cache=use_cache, + **logits_processor_kwargs, **model_kwargs, + **inputs_dict, ) return output_generate @@ -305,35 +235,33 @@ def _sample_generate( model, input_ids, attention_mask, - max_length, + inputs_dict, num_return_sequences, - logits_processor, - logits_warper, - logits_warper_kwargs, - process_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): torch.manual_seed(0) + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=True, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, num_return_sequences=num_return_sequences, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, - **logits_warper_kwargs, - **process_kwargs, + use_cache=use_cache, + **logits_processor_kwargs, **model_kwargs, + **inputs_dict, ) return output_generate @@ -343,31 +271,31 @@ def _beam_search_generate( model, input_ids, attention_mask, - max_length, - beam_scorer, + inputs_dict, beam_kwargs, - logits_processor, - logits_process_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, + use_cache=use_cache, **beam_kwargs, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, + **inputs_dict, ) return output_generate @@ -377,32 +305,34 @@ def _beam_sample_generate( model, input_ids, attention_mask, - max_length, - beam_scorer, + inputs_dict, beam_kwargs, - logits_warper, - logits_warper_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): torch.manual_seed(0) + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) output_generate = model.generate( input_ids, do_sample=True, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, + use_cache=use_cache, **beam_kwargs, - **logits_warper_kwargs, + **logits_processor_kwargs, **model_kwargs, + **inputs_dict, ) + return output_generate def _group_beam_search_generate( @@ -410,30 +340,31 @@ def _group_beam_search_generate( model, input_ids, attention_mask, - max_length, - beam_scorer, + inputs_dict, beam_kwargs, - logits_processor, - logits_process_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) output_generate = model.generate( input_ids, do_sample=False, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, + use_cache=use_cache, **beam_kwargs, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, + **inputs_dict, ) return output_generate @@ -443,33 +374,33 @@ def _constrained_beam_search_generate( model, input_ids, attention_mask, - max_length, - constrained_beam_scorer, + inputs_dict, constraints, beam_kwargs, - logits_processor, - logits_process_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, constraints=constraints, + use_cache=use_cache, **beam_kwargs, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, + **inputs_dict, ) return output_generate @@ -479,76 +410,72 @@ def _contrastive_generate( model, input_ids, attention_mask, - max_length, + inputs_dict, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): contrastive_search_kwargs = { "penalty_alpha": 0.6, "top_k": 5, } - if model.config.is_encoder_decoder: - max_length = 4 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - eos_token_id=model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, - ) - + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, num_beams=1, - max_length=max_length, + max_new_tokens=self.max_new_tokens, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, + output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, - **logits_process_kwargs, + use_cache=use_cache, + **logits_processor_kwargs, **model_kwargs, **contrastive_search_kwargs, + **inputs_dict, ) return output_generate + @pytest.mark.generate def test_greedy_generate(self): - # check `generate()` and `greedy_search()` are equal for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - # test old generation output for backwards compatibility + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( - model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length + model=model, input_ids=input_ids, attention_mask=attention_mask, inputs_dict=inputs_dict ) + if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + @pytest.mark.generate def test_greedy_generate_dict_outputs(self): for model_class in self.all_generative_model_classes: - # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - config.use_cache = False + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, + inputs_dict=inputs_dict, output_scores=True, + output_logits=True, output_hidden_states=True, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -564,58 +491,50 @@ def test_greedy_generate_dict_outputs(self): self._check_outputs(output_generate, input_ids, model.config) + @pytest.mark.generate def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: - # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() if not hasattr(config, "use_cache"): - # only relevant if model has "use_cache" - return + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): + self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, + inputs_dict=inputs_dict, output_scores=True, + output_logits=True, output_hidden_states=True, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self._check_outputs(output_generate, input_ids, model.config, use_cache=True) + @pytest.mark.generate def test_sample_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - model = model_class(config).to(torch_device).eval() - - if model.config.is_encoder_decoder: - max_length = 4 - - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, - ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, + inputs_dict=inputs_dict, num_return_sequences=1, - logits_processor=logits_processor, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, ) if model.config.is_encoder_decoder: @@ -623,38 +542,24 @@ def test_sample_generate(self): else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + @pytest.mark.generate def test_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: - # disable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - config.use_cache = False - model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - model.config.eos_token_id, - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - max_length=max_length, - ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, + inputs_dict=inputs_dict, num_return_sequences=2, - logits_processor=logits_processor, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, output_scores=True, + output_logits=True, output_hidden_states=True, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -670,38 +575,20 @@ def test_sample_generate_dict_output(self): self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2) + @pytest.mark.generate def test_beam_search_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - config.eos_token_id, - config.forced_bos_token_id, - config.forced_eos_token_id, - max_length, - ) - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, + inputs_dict=inputs_dict, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, - logits_processor=logits_processor, ) if model.config.is_encoder_decoder: @@ -709,72 +596,26 @@ def test_beam_search_generate(self): else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - if model.config.is_encoder_decoder: - max_length = 4 - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - - output_generate = self._beam_search_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, - beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, - logits_processor=logits_processor, - ) - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) - else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - + @pytest.mark.generate def test_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - # disable cache - config.use_cache = False - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - config.eos_token_id, - config.forced_bos_token_id, - config.forced_eos_token_id, - max_length, - ) - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, + inputs_dict=inputs_dict, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, - logits_processor=logits_processor, output_scores=True, + output_logits=True, output_hidden_states=True, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) - if model.config.is_encoder_decoder: - self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) - else: - self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - - self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) - self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) @@ -790,148 +631,139 @@ def test_beam_search_generate_dict_output(self): output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) + @pytest.mark.generate def test_beam_search_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() if not hasattr(config, "use_cache"): - # only relevant if model has "use_cache" - return + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): + self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - config.eos_token_id, - config.forced_bos_token_id, - config.forced_eos_token_id, - max_length, - ) - - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + beam_kwargs = self._get_beam_kwargs() - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, + inputs_dict=inputs_dict, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, - logits_processor=logits_processor, output_scores=True, + output_logits=True, output_hidden_states=True, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) else: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self._check_outputs( - output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams + output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"] ) - @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet") - def test_beam_sample_generate(self): + @require_accelerate + @require_torch_multi_accelerator + @pytest.mark.generate + def test_model_parallel_beam_search(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + if "xpu" in torch_device: + return unittest.skip(reason="device_map='auto' does not work with XPU devices") - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None + if model_class._no_split_modules is None: + continue - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - model = model_class(config).to(torch_device).eval() + model = model_class(config).eval() + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + new_model = model_class.from_pretrained(tmp_dir, device_map="auto") - # check `generate()` and `beam_search()` are equal - if model.config.is_encoder_decoder: - max_length = 4 - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + new_model.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=self.max_new_tokens, + num_beams=2, + **inputs_dict, + ) + + @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet") + @pytest.mark.generate + def test_beam_sample_generate(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, + inputs_dict=inputs_dict, beam_kwargs=beam_kwargs, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters): - input_embeds = model.get_input_embeddings()(input_ids) - beam_kwargs.update({"inputs_embeds": input_embeds}) - output_generate2 = self._beam_sample_generate( - model=model, - input_ids=None, - attention_mask=attention_mask, - beam_kwargs=beam_kwargs, - logits_warper_kwargs=logits_warper_kwargs, + + # for VLMs inputs embeds won't match input ids unless images are encoded and merged with ids properly + # no quick fix available, since obtaining image embeddings step is very model-specific + if any(name in model.__class__.__name__.lower() for name in ("blip", "llava", "paligemma")): + prepare_inputs_for_generation_args = set( + inspect.signature(model.prepare_inputs_for_generation).parameters ) + # `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling + # code is up to date with our most recent standards + if ( + "inputs_embeds" in prepare_inputs_for_generation_args + and "cache_positions" in prepare_inputs_for_generation_args + ): + input_embeds = model.get_input_embeddings()(input_ids) + beam_kwargs.update({"inputs_embeds": input_embeds}) + output_generate2 = self._beam_sample_generate( + model=model, + input_ids=None, + attention_mask=attention_mask, + inputs_dict={}, + beam_kwargs=beam_kwargs, + ) - torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) + torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet") + @pytest.mark.generate def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - # disable cache - config.use_cache = False - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - - if model.config.is_encoder_decoder: - max_length = 4 - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, + inputs_dict=inputs_dict, beam_kwargs=beam_kwargs, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, output_scores=True, + output_logits=True, output_hidden_states=True, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) - self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) - self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) @@ -947,192 +779,131 @@ def test_beam_sample_generate_dict_output(self): output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) + @pytest.mark.generate def test_generate_without_input_ids(self): - config, _, _, max_length = self._get_input_ids_and_config() + config, _, _, _ = self._get_input_ids_and_config() # if no bos token id => cannot generate from None if config.bos_token_id is None: - return + self.skipTest(reason="bos_token_id is None") + + # hack in case they are equal, otherwise the attn mask will be [0] + if config.bos_token_id == config.pad_token_id: + config.pad_token_id = None for model_class in self.all_generative_model_classes: model = model_class(config).to(torch_device) model.eval() - output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True) + output_ids_generate = model.generate( + do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True + ) self.assertIsNotNone(output_ids_generate) @pytest.mark.skip("Group beam search is not supported by optimum-habana") + @pytest.mark.generate def test_group_beam_search_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - config.eos_token_id, - config.forced_bos_token_id, - config.forced_eos_token_id, - max_length, - diversity_penalty=2.0, - ) - # check `generate()` and `group_beam_search()` are equal - beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - output_generate, output_group_beam_search = self._group_beam_search_generate( + beam_kwargs = self._get_diverse_beam_kwargs() + output_generate = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, + inputs_dict=inputs_dict, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, - logits_process_kwargs=logits_process_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist()) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - # check `generate()` and `group_beam_search()` are equal for `num_return_sequences` + # check `group_beam_search` for higher than 1 `num_return_sequences` num_return_sequences = 2 - if model.config.is_encoder_decoder: - max_length = 4 - beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, num_return_sequences=num_return_sequences - ) - output_generate, output_group_beam_search = self._group_beam_search_generate( + beam_kwargs = self._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences) + output_generate = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, + inputs_dict=inputs_dict, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, - logits_process_kwargs=logits_process_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist()) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) @pytest.mark.skip("Group beam search is not supported by optimum-habana") + @pytest.mark.generate def test_group_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - config.use_cache = False - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 4 - - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - config.eos_token_id, - config.forced_bos_token_id, - config.forced_eos_token_id, - max_length, - diversity_penalty=2.0, - ) - - num_return_sequences = 1 - beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, num_return_sequences=num_return_sequences - ) - output_generate, output_group_beam_search = self._group_beam_search_generate( + beam_kwargs = self._get_diverse_beam_kwargs() + output_generate = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - beam_scorer=beam_scorer, + inputs_dict=inputs_dict, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, - logits_process_kwargs=logits_process_kwargs, output_scores=True, + output_logits=True, output_hidden_states=True, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput) + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput) + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist()) - self.assertTrue( - torch.allclose( - output_generate["sequences_scores"], output_group_beam_search["sequences_scores"], atol=1e-3 - ) + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) - self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) - self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - - for output in (output_group_beam_search, output_generate): - self._check_outputs( - output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams - ) + # TODO: @gante + @is_flaky() + @pytest.mark.generate def test_constrained_beam_search_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - max_length = 20 - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - config.eos_token_id, - config.forced_bos_token_id, - config.forced_eos_token_id, - max_length, - ) - - # check `generate()` and `constrained_beam_search()` are equal # Sample constraints - if not input_ids.dtype == torch.float32: - min_id = torch.min(input_ids) + 3 - max_id = torch.max(input_ids) - else: - # otherwise this throws an error for Speech2TextModel since its inputs are floating points - min_id = 3 - max_id = 100 + min_id = 3 + max_id = config.get_text_config(decoder=True).vocab_size force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ PhrasalConstraint(force_tokens), ] - beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, constraints, num_return_sequences=1 - ) + beam_kwargs = self._get_constrained_beam_kwargs() output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - constrained_beam_scorer=beam_scorer, + inputs_dict=inputs_dict, constraints=constraints, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, - logits_process_kwargs=logits_process_kwargs, ) - self.assertTrue(output_generate.shape[-1] == max_length) + + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) @@ -1144,86 +915,63 @@ def test_constrained_beam_search_generate(self): PhrasalConstraint(force_tokens), ] - num_return_sequences = 2 - max_length = 20 - - beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences - ) + beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2) output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - constrained_beam_scorer=beam_scorer, + inputs_dict=inputs_dict, constraints=constraints, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, - logits_process_kwargs=logits_process_kwargs, ) - self.assertTrue(output_generate.shape[-1] == max_length) + + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) + @pytest.mark.generate def test_constrained_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - # disable cache - config.use_cache = False - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - if model.config.is_encoder_decoder: - max_length = 20 - - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], - config.eos_token_id, - config.forced_bos_token_id, - config.forced_eos_token_id, - max_length, - ) # Sample constraints min_id = 3 - max_id = model.config.vocab_size + max_id = model.config.get_text_config(decoder=True).vocab_size force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ PhrasalConstraint(force_tokens), ] - beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( - input_ids.shape[0], max_length, constraints, num_return_sequences=1 - ) + beam_kwargs = self._get_constrained_beam_kwargs() output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, - constrained_beam_scorer=beam_scorer, + inputs_dict=inputs_dict, constraints=constraints, beam_kwargs=beam_kwargs, - logits_processor=logits_processor, - logits_process_kwargs=logits_process_kwargs, output_scores=True, + output_logits=True, output_hidden_states=True, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) - self.assertTrue(output_generate.sequences.shape[-1] == max_length) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) @@ -1232,47 +980,52 @@ def test_constrained_beam_search_generate_dict_output(self): output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) - self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) - self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - + @pytest.mark.generate def test_contrastive_generate(self): - # check `generate()` and `contrastive_search()` are equal for model_class in self.all_generative_model_classes: + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support contrastive search generation") + # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return + self.skipTest(reason="Won't fix: old model with different cache format") - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - return - config.use_cache = True + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.is_decoder = True # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() output_generate = self._contrastive_generate( - model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_dict=inputs_dict, + use_cache=True, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + @pytest.mark.generate def test_contrastive_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support contrastive search generation") + # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return + self.skipTest(reason="Won't fix: old model with different cache format") - # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - return - config.use_cache = True + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.is_decoder = True model = model_class(config).to(torch_device).eval() @@ -1280,36 +1033,40 @@ def test_contrastive_generate_dict_outputs_use_cache(self): model=model, input_ids=input_ids, attention_mask=attention_mask, - max_length=max_length, + inputs_dict=inputs_dict, output_scores=True, + output_logits=True, output_hidden_states=True, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) else: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self._check_outputs(output_generate, input_ids, model.config, use_cache=True) + @pytest.mark.generate def test_contrastive_generate_low_memory(self): # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: - # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format). - if any( - model_name in model_class.__name__.lower() - for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"] - ): - return + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support contrastive search generation") - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]): + self.skipTest(reason="Won't fix: old model with different cache format") + if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]): + self.skipTest(reason="TODO: fix me") + + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - config.use_cache = True config.is_decoder = True # test output equality of low versus high memory @@ -1320,8 +1077,10 @@ def test_contrastive_generate_low_memory(self): top_k=4, penalty_alpha=0.6, low_memory=True, - max_length=max_length, + max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, + **inputs_dict, + use_cache=True, ) high_output = model.generate( @@ -1329,8 +1088,10 @@ def test_contrastive_generate_low_memory(self): top_k=4, penalty_alpha=0.6, low_memory=False, - max_length=max_length, + max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, + **inputs_dict, + use_cache=True, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) @@ -1377,89 +1138,75 @@ def test_contrastive_generate_dynamic_shapes(self): ) self.assertListEqual(dynamic_output.tolist(), static_output.tolist()) - # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable - @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana") - @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%. - def test_assisted_decoding_matches_greedy_search(self): - # This test ensures that the assisted generation does not introduce output changes over greedy search. - # It breaks the pattern in the tests above, for multiple reasons: - # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to - # prepare the assistant encoder outputs in the main generate body); - # - assisted_decoding does not support `use_cache = False` - # - assisted_decoding does not support `batch_size > 1` - + @pytest.mark.generate + @unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703") + def test_beam_search_low_memory(self): + # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: - # won't fix: FSMT and Reformer have a different cache variable type (and format). + if model_class._is_stateful: + self.skipTest(reason="May fix in the future: need custom cache handling") if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return - # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes + self.skipTest(reason="Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() - for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"] + for model_name in [ + "ctrl", + "gptbigcode", + "transo_xl", + "xlnet", + "cpm", + "jamba", + ] ): - return - - # This for loop is a naive and temporary effort to make the test less flaky. - failed = 0 - for i in range(10): - # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + self.skipTest(reason="May fix in the future: need model-specific fixes") + config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2) + # batch_size=1 is ok, but batch_size>1 will cause non-identical output - # NOTE: assisted generation only works with cache on at the moment. - if not hasattr(config, "use_cache"): - return + config.use_cache = True + config.is_decoder = True - config.use_cache = True - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - output_greedy = model.generate( - input_ids, - attention_mask=attention_mask, - max_length=max_length, - num_beams=1, - do_sample=False, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will - # be correct - output_assisted = model.generate( - input_ids, - attention_mask=attention_mask, - max_length=max_length, - num_beams=1, - do_sample=False, - assistant_model=model, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) + # test output equality of low versus high memory + model = model_class(config).to(torch_device).eval() - try: - self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) + low_output = model.generate( + input_ids, + max_new_tokens=8, + num_beams=5, + early_stopping=True, + low_memory=True, + use_cache=True, + ) - for output in (output_greedy, output_assisted): - self._check_outputs(output, input_ids, model.config, use_cache=True) - except AssertionError: - failed += 1 - if failed > 1: - self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) + high_output = model.generate( + input_ids, + max_new_tokens=8, + num_beams=5, + early_stopping=True, + low_memory=False, + use_cache=True, + ) + self.assertListEqual(low_output.tolist(), high_output.tolist()) - for output in (output_greedy, output_assisted): - self._check_outputs(output, input_ids, model.config, use_cache=True) + @pytest.mark.generate + @parameterized.expand([("random",), ("same",)]) + @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail. + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + # This test ensures that the assisted generation does not introduce output changes over greedy search. + # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul + # shape differences -- and it may result in a different output. The input shape difference happens in the + # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at + # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info. + # NOTE (2): It breaks the pattern in the tests above, for multiple reasons: + # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to + # prepare the assistant encoder outputs in the main generate body); + # - assisted_decoding does not support `use_cache = False` + # - assisted_decoding does not support `batch_size > 1` - # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable - @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana") - def test_assisted_decoding_sample(self): - # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not - # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with - # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535). for model_class in self.all_generative_model_classes: + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support assisted generation") if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest("Won't fix: old model with different cache format") + self.skipTest(reason="Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() for model_name in [ @@ -1473,16 +1220,15 @@ def test_assisted_decoding_sample(self): "clvp", ] ): - self.skipTest("May fix in the future: need model-specific fixes") + self.skipTest(reason="May fix in the future: need model-specific fixes") # enable cache - config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): - self.skipTest("This model doesn't support caching") + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1491,503 +1237,253 @@ def test_assisted_decoding_sample(self): # the assistant model is correct # c) there are at least two forward passes in the main model, to ensure the input preparation of # the main model is correct - assistant_model = model + generation_kwargs = { + "eos_token_id": -1, # see a) + "max_new_tokens": 4, # see c) + "num_beams": 1, + "do_sample": False, + "output_scores": True, + "output_logits": True, + "output_hidden_states": True, + "output_attentions": self.has_attentions, + "return_dict_in_generate": True, + "use_cache": True, + } + output_greedy = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict + ) + + # test with the same assistant model or randomly init one + # in the first case all candidate tokens are accepted, in the second none is accepted + # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :) + if assistant_type == "random": + assistant_model = model_class(config).to(torch_device).eval() + else: + assistant_model = model assistant_model.generation_config.num_assistant_tokens = 2 # see b) assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) + generation_kwargs.update({"assistant_model": assistant_model}) + output_assisted = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict + ) + + # The two outputs must match and their shape must be as expected + + self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) + for output in (output_greedy, output_assisted): + self._check_outputs(output, input_ids, model.config, use_cache=True) + + @is_flaky() + @pytest.mark.generate + def test_prompt_lookup_decoding_matches_greedy_search(self): + # This test ensures that the prompt lookup generation does not introduce output changes over greedy search. + # This test is mostly a copy of test_assisted_decoding_matches_greedy_search + + for model_class in self.all_generative_model_classes: + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support assisted generation") + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + self.skipTest(reason="Won't fix: old model with different cache format") + if any( + model_name in model_class.__name__.lower() + for model_name in [ + "bigbirdpegasus", + "led", + "mega", + "speech2text", + "git", + "prophetnet", + "seamlessm4t", + "clvp", + ] + ): + self.skipTest(reason="May fix in the future: need model-specific fixes") + + # enable cache + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) + + # NOTE: assisted generation only works with cache on at the moment. + if not hasattr(config, "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + # Sets assisted generation arguments such that: + # a) no EOS is generated, to ensure generation doesn't break early + # b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of + # prompt lookup is correct + # c) there are at least two forward passes in the main model, to ensure the input preparation of + # the main model is correct generation_kwargs = { "eos_token_id": -1, # see a) "max_new_tokens": 4, # see c) "num_beams": 1, - "do_sample": True, - "assistant_model": assistant_model, + "do_sample": False, "output_scores": True, + "output_logits": True, "output_hidden_states": True, - "output_attentions": True, + "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } - ####################################################################### - # Monkey patch assisted decoding function till SW issue is resolved - import copy - from types import MethodType - from typing import List, Optional, Union - - from transformers.generation.utils import ( - GenerateDecoderOnlyOutput, - _crop_past_key_values, - _prepare_attention_mask, - _prepare_token_type_ids, - _split_model_outputs, - ) - - def _speculative_sampling( - candidate_input_ids, - candidate_logits, - candidate_length, - new_logits, - last_assistant_token_is_eos, - max_matches, - ): - """ - Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns - the selected tokens, as well as the number of candidate matches. - - NOTE: Unless otherwise stated, the variable names match those in the paper. - """ - new_candidate_input_ids = candidate_input_ids[:, -candidate_length:] - # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens - # selected by the assistant, respectively. - q = candidate_logits.softmax(dim=-1) - q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids.squeeze()].squeeze(0, 1) - p = new_logits.softmax(dim=-1) - p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids.squeeze()].squeeze(0, 1) - probability_ratio = p_i / q_i - - # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller - # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio - # (= keep with p = probability_ratio). Keep all the tokens until the first rejection - r_i = torch.rand_like(probability_ratio) - is_accepted = r_i <= probability_ratio - n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 - - # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior) - if last_assistant_token_is_eos and n_matches == candidate_length: - # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model - # due to acceptance on EOS we fix `n_matches` - n_matches -= 1 - valid_tokens = new_candidate_input_ids[:, : n_matches + 1] - else: - n_matches = min(n_matches, max_matches) - - # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. - gamma = min(candidate_logits.shape[1], max_matches) - p_n_plus_1 = p[:, n_matches, :] - if n_matches < gamma: - q_n_plus_1 = q[:, n_matches, :] - p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0) - p_prime.div_(p_prime.sum()) - else: - p_prime = p_n_plus_1 - t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] - - # The selected tokens include the matches (if any) plus the next sampled tokens - if n_matches > 0: - valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1) - else: - valid_tokens = t - - return valid_tokens, n_matches - - def assisted_decoding( - self, - input_ids: torch.LongTensor, - assistant_model: Optional["PreTrainedModel"] = None, - candidate_generator: Optional["CandidateGenerator"] = None, - do_sample: bool = False, - logits_processor: Optional[LogitsProcessorList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - streamer: Optional["BaseStreamer"] = None, - **model_kwargs, - ): - r""" - Generates sequences of token ids for models with a language modeling head using **greedy decoding** or - **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a - candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text - models. - - - - In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use - generate() instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - candidate_generator (`CandidateGenerator`, *optional*): - A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For - more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function. - assistant_model (`PreTrainedModel`, *optional*): - An assistant model that can be used to accelerate generation. The assistant model must have the exact - same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model - is much faster than running generation with the model you're calling generate from. As such, the - assistant model should be much smaller. - do_sample (`bool`, *optional*, defaults to `False`): - Whether or not to use sampling ; use greedy decoding otherwise. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - model_kwargs: - Additional model specific keyword arguments will be forwarded to the `forward` function of the model. - If model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForCausalLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... StoppingCriteriaList, - ... MaxLengthCriteria, - ... ) - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2") - >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token - >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id - >>> input_prompt = "It might be possible to" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), - ... ] - ... ) - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - >>> outputs = model.assisted_decoding( - ... input_ids, - ... assistant_model=assistant_model, - ... logits_processor=logits_processor, - ... stopping_criteria=stopping_criteria, - ... ) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["It might be possible to get a better understanding of the nature of the problem, but it's not"] - ```""" - # handling deprecated arguments - if (assistant_model is None) == (candidate_generator is None): - raise ValueError( - "One (and only one) of `assistant_model` and `candidate_generator` should be defined." - ) + output_greedy = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict + ) - if assistant_model is not None: - candidate_generator = AssistedCandidateGenerator( - input_ids=input_ids, - assistant_model=assistant_model, - logits_processor=logits_processor, - model_kwargs=model_kwargs, - eos_token_id=eos_token_id, - ) - warnings.warn( - "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. " - "Pass the `candidate_generator` argument instead.", - FutureWarning, - ) + generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b) + output_prompt_lookup = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict + ) - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if eos_token_id is not None and pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = ( - torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - ) - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) + # The two outputs must match and their shape must be as expected - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist()) + for output in (output_greedy, output_prompt_lookup): + self._check_outputs(output, input_ids, model.config, use_cache=True) - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) + @pytest.mark.generate + def test_dola_decoding_sample(self): + # TODO (joao): investigate skips, try to reduce incompatibilities + for model_class in self.all_generative_model_classes: + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support DoLa decoding") - # keep track of which sequences are already finished - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - - # other auxiliary variables - max_len = stopping_criteria[0].max_length - - this_peer_finished = False # used by synced_gpus only - while True: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - torch.dist.all_reduce(this_peer_finished_flag, op=torch.dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - cur_len = input_ids.shape[-1] - - # 1. Fetch candidate sequences from a `CandidateGenerator` - candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) - candidate_input_ids = candidate_input_ids.to(self.device) - if candidate_logits is not None: - candidate_logits = candidate_logits.to(self.device) - - candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] - last_assistant_token_is_eos = ( - ~candidate_input_ids[:, -1] - .tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) - .bool() - ) + if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): + self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.") - # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain - # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, - # we use this forward pass to also pick the subsequent logits in the original model. + if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]): + self.skipTest("DoLa is not supported for models that don't return layerwise hidden states") - # 2.1. Prepare the model inputs - candidate_kwargs = copy.copy(model_kwargs) - candidate_kwargs = _prepare_attention_mask( - candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder - ) - candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) + # Encoder-decoder models are not supported + if config.is_encoder_decoder: + self.skipTest("DoLa is not supported for encoder-decoder models") + config.is_decoder = True + model = model_class(config).to(torch_device).eval() - # 2.2. Run a forward pass on the candidate sequence - outputs = self( - **model_inputs, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) + if model.get_output_embeddings() is None: + self.skipTest("DoLa is not supported for models that don't have output embeddings") + # Sets dola generation arguments such that: + # a) no EOS is generated, to ensure generation doesn't break early + # b) there are at least two forward passes in the main model, to ensure the input preparation of + # the main model is correct + generation_kwargs = { + "eos_token_id": -1, # see a) + "max_new_tokens": 4, # see b) + "num_beams": 1, + "do_sample": True, + "output_scores": True, + "output_logits": True, + "output_hidden_states": True, + "output_attentions": self.has_attentions, + "return_dict_in_generate": True, + "use_cache": hasattr(config, "use_cache"), # Some models don't support the cache + } + generation_kwargs.update({"dola_layers": "low"}) + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs, **inputs_dict) + self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache")) - # 2.3. Process the new logits - new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present - if len(logits_processor) > 0: - for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_processor( - candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] - ) - if len(logits_warper) > 0: - for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_warper( - candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] - ) + @pytest.mark.generate + def test_assisted_decoding_sample(self): + # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not + # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with + # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535). + for model_class in self.all_generative_model_classes: + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support assisted generation") + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + self.skipTest(reason="Won't fix: old model with different cache format") + if any( + model_name in model_class.__name__.lower() + for model_name in [ + "bigbirdpegasus", + "led", + "mega", + "speech2text", + "git", + "prophetnet", + "seamlessm4t", + "clvp", + ] + ): + self.skipTest(reason="May fix in the future: need model-specific fixes") - # 3. Select the accepted tokens. There are two possible cases: - # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) - # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). - max_matches = max_len - cur_len - 1 - if do_sample and candidate_logits is not None: - valid_tokens, n_matches = _speculative_sampling( - candidate_input_ids, - candidate_logits, - candidate_length, - new_logits, - last_assistant_token_is_eos, - max_matches, - ) + # enable cache + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) - # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the - # original model logits with the candidate tokens. We can keep the candidate tokens until the first - # mismatch, or until the max length is reached. - else: - if do_sample: - probs = new_logits.softmax(dim=-1) - selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] - else: - selected_tokens = new_logits.argmax(dim=-1) - - candidate_new_tokens = candidate_input_ids[:, cur_len:] - n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() - - # Ensure we don't generate beyond max_len or an EOS token - if last_assistant_token_is_eos and n_matches == candidate_length: - n_matches -= 1 - n_matches = min(n_matches, max_matches) - valid_tokens = selected_tokens[:, : n_matches + 1] - - # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated - # by the model after the last candidate match is also valid, as it is generated from a correct sequence. - # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there - # is no match. - - # 4.1. Get the valid continuation, after the matching tokens - input_ids = torch.cat((input_ids, valid_tokens), dim=-1) - if streamer is not None: - streamer.put(valid_tokens.cpu()) - new_cur_len = input_ids.shape[-1] - - # 4.2. Discard past key values relative to unused assistant tokens - new_cache_size = new_cur_len - 1 - outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) - - # 5. Update the candidate generation strategy if needed - candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) - - if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need - - # Store scores, attentions and hidden_states when required - # Assistant: modified to append one tuple element per token, as in the other generation methods. - if return_dict_in_generate: - if output_scores: - scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) - - if "past_key_values" not in model_kwargs: - added_len = new_cur_len - else: - added_len = n_matches + 1 - - if output_attentions: - if self.config.is_encoder_decoder: - cross_attentions = _split_model_outputs( - cross_attentions, outputs.cross_attentions, cur_len, added_len - ) - decoder_attentions = _split_model_outputs( - decoder_attentions, - outputs.decoder_attentions, - cur_len, - added_len, - is_decoder_attention=True, - ) - else: - decoder_attentions = _split_model_outputs( - decoder_attentions, - outputs.attentions, - cur_len, - added_len, - is_decoder_attention=True, - ) - if output_hidden_states: - if self.config.is_encoder_decoder: - decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len - ) - else: - decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.hidden_states, cur_len, added_len - ) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) + # NOTE: assisted generation only works with cache on at the moment. + if not hasattr(config, "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - input_ids[:, -1] - .tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) - ) + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + # Sets assisted generation arguments such that: + # a) no EOS is generated, to ensure generation doesn't break early + # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of + # the assistant model is correct + # c) there are at least two forward passes in the main model, to ensure the input preparation of + # the main model is correct + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 # see b) + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) + generation_kwargs = { + "eos_token_id": -1, # see a) + "max_new_tokens": 4, # see c) + "num_beams": 1, + "do_sample": True, + "assistant_model": assistant_model, + "output_scores": True, + "output_logits": True, + "output_hidden_states": True, + "output_attentions": self.has_attentions, + "return_dict_in_generate": True, + "use_cache": True, + } + output_assisted = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict + ) - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True - - # stop if we exceed the maximum length - if stopping_criteria(input_ids, scores): - this_peer_finished = True - - if this_peer_finished and not synced_gpus: - break - - if streamer is not None: - streamer.end() - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids + self._check_outputs(output_assisted, input_ids, config, use_cache=True) + + @pytest.mark.generate + def test_prompt_lookup_decoding_stops_at_eos(self): + # This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens + # (see https://github.com/huggingface/transformers/pull/31301) + + # The main idea is to have an ngram (unigram in our case) that is repeated twice in the input ids. + # First time at the very end, so input ends with the unigrams, and second any arbitrary location. + # Also, we need an EOS token which will be injected just after the arbitrary located ngram. + # We verify that PLD will not copy and propose candidated that contain an EOS token, even if there are overlapping ngrams + # in input ids. Otherwise a proposed EOS along with the trailing (ngrams-1) tokens might be accepted by the target model. + # That seems as if the model "generated" and EOS but didn't stop from user's perspective - model.assisted_decoding = MethodType(assisted_decoding, model) + input_ids = torch.randint(1, 50, (1, 10), device=torch_device) # generate inputs in range from 1-50 + arbitrary_ngram = 51 # this is the arbitrary ngram, specifically chosen OOV to prevent flaky tests + input_ids[:, 3] = arbitrary_ngram # set pre-eos to arbitrary_ngram which is for sure not present in inputs + input_ids[:, -1] = arbitrary_ngram # put arbitrary_ngram in the end for the necessary match to happen - ####################################################################### + eos_token_id = torch.tensor([0], device=torch_device) + input_ids[:, 4] = eos_token_id # inject eos-token-id in input ids so that it is located after arbitrary_ngram - output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + # init cand geenerator with max_matching_ngram_size=1 to match per-token + candidate_generator = PromptLookupCandidateGenerator( + eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1 + ) + output_prompt_lookup = candidate_generator.get_candidates(input_ids)[0] - self._check_outputs(output_assisted, input_ids, model.config, use_cache=True) + # PLD shouldn't propose any new tokens based on eos-match + self.assertTrue(output_prompt_lookup.shape[-1] == 10) + @pytest.mark.generate def test_generate_with_head_masking(self): """Test designed for encoder-decoder models to ensure the attention head masking is used.""" attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() # We want to test only encoder-decoder models if not config.is_encoder_decoder: continue @@ -2013,60 +1509,93 @@ def test_generate_with_head_masking(self): input_ids, attention_mask=attention_mask, num_beams=1, - output_attentions=True, + output_attentions=self.has_attentions, return_dict_in_generate=True, remove_invalid_values=True, **{name: mask}, + **inputs_dict, ) # We check the state of decoder_attentions and cross_attentions just from the last step attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) + @pytest.mark.generate def test_left_padding_compatibility(self): - # The check done in this test is fairly difficult -- depending on the model architecture, passing the right - # position index for the position embeddings can still result in a different output, due to numerical masking. - # On the other hand, for some types of position embeddings, an incorrect position index can have a minimal - # impact on the output. - # There are two tricks employed to check whether left-padding compatibility is in place: - # 1 - To reduce the negative impact of the numerical attention mask on a correct position index, we set the - # padding size to 1. - # 2 - To reduce the chance of false positives (i.e. passing when it should be failing), we run the check - # multiple times with random inputs, and it has to pass with all of them. - # NOTE: because of 2), there is some chance of false positives in this test. + # NOTE: left-padding results in small numerical differences. This is expected. + # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 + + # First, filter out models that don't support left padding + # - The model must have generative capabilities + if len(self.all_generative_model_classes) == 0: + self.skipTest(reason="No generative architecture available for this model.") + # - The model must support padding + if not self.has_attentions: + self.skipTest(reason="This model doesn't support padding.") + + # - The model must be a decoder-only architecture (encoder-based architectures use right-padding) + decoder_only_classes = [] for model_class in self.all_generative_model_classes: config, _, _, _ = self._get_input_ids_and_config() if config.is_encoder_decoder: - continue # skip for encoder-decoder models -- they don't need left-padding compatibility + continue + else: + decoder_only_classes.append(model_class) + if len(decoder_only_classes) == 0: + self.skipTest(reason="No decoder-only architecture available for this model.") + + # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't + # added support for it yet. We skip these models for now. + has_encoder_attributes = any( + attr_name + for attr_name in config.to_dict().keys() + if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" + ) + if has_encoder_attributes: + self.skipTest( + reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." + ) + + # Then, test left-padding + def _prepare_model_kwargs(input_ids, attention_mask, signature): + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(attention_mask, dim=-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + if "cache_position" in signature: + cache_position = torch.arange(input_ids.shape[-1], device=torch_device) + model_kwargs["cache_position"] = cache_position + return model_kwargs + + for model_class in decoder_only_classes: + config, input_ids, attention_mask, _ = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() signature = inspect.signature(model.forward).parameters.keys() - no_failures = True - for _ in range(10): # there may be false positives with 10 runs, we rely on the CI to catch the flakiness - _, input_ids, attention_mask, _ = self._get_input_ids_and_config() - model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} - if "position_ids" in signature: - position_ids = torch.cumsum(attention_mask, dim=-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - model_kwargs["position_ids"] = position_ids - next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] - - pad_size = (input_ids.shape[0], 1) - padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id - padded_input_ids = torch.cat((padding, input_ids), dim=1) - padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) - model_kwargs = {"input_ids": padded_input_ids, "attention_mask": padded_attention_mask} - if "position_ids" in signature: - position_ids = torch.cumsum(padded_attention_mask, dim=-1) - 1 - position_ids.masked_fill_(padded_attention_mask == 0, 1) - model_kwargs["position_ids"] = position_ids - next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] - if not torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-7): - no_failures = False - break - - self.assertTrue(no_failures) + # no cache as some models require special cache classes to be init outside forward + model.generation_config.use_cache = False + + # Without padding + model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) + next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] + + # With left-padding (length 32) + # can hardcode pad_token to be 0 as we'll do attn masking anyway + pad_token_id = ( + config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 + ) + pad_size = (input_ids.shape[0], 32) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) + next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] + + # They should result in very similar logits + self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5)) + @pytest.mark.generate def test_past_key_values_format(self): # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a # standard KV cache format is important for a consistent API (and for advanced generation methods). @@ -2075,7 +1604,7 @@ def test_past_key_values_format(self): # If it doesn't support cache, pass the test if not hasattr(config, "use_cache"): - return + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") model = model_class(config).to(torch_device) if "use_cache" not in inputs: @@ -2084,7 +1613,7 @@ def test_past_key_values_format(self): # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format) if "past_key_values" not in outputs: - return + self.skipTest(reason="This model doesn't return `past_key_values`") num_hidden_layers = ( getattr(config, "decoder_layers", None) @@ -2138,6 +1667,7 @@ def test_past_key_values_format(self): past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) ) + @pytest.mark.generate def test_generate_from_inputs_embeds_decoder_only(self): # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids` # if fails, you should probably update the `prepare_inputs_for_generation` function @@ -2164,106 +1694,587 @@ def test_generate_from_inputs_embeds_decoder_only(self): continue # Traditional way of generating text - outputs_from_ids = model.generate(input_ids) - self.assertEqual(outputs_from_ids.shape, (2, 20)) + outputs_from_ids = model.generate( + input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True + ) + self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) inputs_embeds = model.get_input_embeddings()(input_ids) - outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds) - self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist()) + outputs_from_embeds = model.generate( + input_ids, + inputs_embeds=inputs_embeds, + max_new_tokens=5, + return_dict_in_generate=True, + output_scores=True, + ) + self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist()) - # But if we pass different inputs_embeds, we should get different outputs - torch.manual_seed(0) + # But if we pass different inputs_embeds, we should get different outputs (the output text may be the + # same, but the logits will almost surely be different) random_embeds = torch.rand_like(inputs_embeds) - outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist()) + outputs_from_rand_embeds = model.generate( + input_ids, + inputs_embeds=random_embeds, + max_new_tokens=5, + return_dict_in_generate=True, + output_scores=True, + ) + for i in range(len(outputs_from_rand_embeds.scores)): + self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same outputs_from_embeds_wo_ids = model.generate( - inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1] + inputs_embeds=inputs_embeds, max_new_tokens=5, return_dict_in_generate=True, output_scores=True ) self.assertListEqual( - outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(), - outputs_from_embeds_wo_ids.tolist(), + outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(), + outputs_from_embeds_wo_ids.sequences.tolist(), ) - def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): - batch_size, seq_length = input_ids.shape - num_sequences_in_output = batch_size * num_return_sequences - gen_len = ( - output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length - ) + @pytest.mark.generate + def test_generate_from_inputs_embeds_with_static_cache(self): + """ + Test that StaticCache can generate from inputs_embeds and calculates max_cache_length + correctly in `generate()`. We force the model to not stop generation until max-length is reached + to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache. + """ + for model_class in self.all_generative_model_classes: + if not model_class._supports_static_cache: + self.skipTest(reason="This model does not support the static cache format") - # scores - self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + if config.is_encoder_decoder: + self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") - # Attentions - if config.is_encoder_decoder: - # encoder - self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length) - # decoder - self._check_attentions_for_generate( - num_sequences_in_output, - output.decoder_attentions, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, + model = model_class(config).to(torch_device).eval() + if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): + self.skipTest(reason="This model does not support `inputs_embeds` in generation") + + model.config.use_cache = True + model.config.is_decoder = True + batch_size, seq_length = input_ids.shape + max_cache_len = 30 + + # here we force to not stop at eos and go until max-length + model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1 + generation_kwargs = { + "max_length": max_cache_len, + "cache_implementation": "static", + "return_dict_in_generate": True, # Required to return `past_key_values` + } + + text_config = model.config.get_text_config() + head_dim = ( + text_config.head_dim + if hasattr(text_config, "head_dim") + else text_config.hidden_size // text_config.num_attention_heads ) - else: - # if use_cache first input is equal to no use_cache, so skip here - attentions = output.attentions if not use_cache else output.attentions[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_attentions_for_generate( - num_sequences_in_output, - attentions=attentions, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, + num_key_value_heads = ( + text_config.num_attention_heads + if getattr(text_config, "num_key_value_heads", None) is None + else text_config.num_key_value_heads ) + num_hidden_layers = text_config.num_hidden_layers - # Hidden States - if config.is_encoder_decoder: - # encoder - self._check_encoder_hidden_states_for_generate( - output.encoder_hidden_states, batch_size, config, seq_length + inputs_embeds = model.get_input_embeddings()(input_ids) + outputs = model.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs, **inputs_dict ) - # decoder - self._check_hidden_states_for_generate( - num_sequences_in_output, - output.decoder_hidden_states, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - else: - # if use_cache first input is equal to no use_cache, so skip here - hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_hidden_states_for_generate( - num_sequences_in_output, - hidden_states, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) + # we should get `max_length` in shape, not `max_length - embeds_length` + cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) + self.assertTrue(isinstance(outputs.past_key_values, StaticCache)) + self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers) + self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape) - def _check_scores(self, batch_size, scores, length, config): - expected_shape = (batch_size, config.vocab_size) - self.assertIsInstance(scores, tuple) - self.assertEqual(len(scores), length) - self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) + @pytest.mark.generate + def test_generate_continue_from_past_key_values(self): + # Tests that we can continue generating from past key values, returned from a previous `generate` call + for model_class in self.all_generative_model_classes: + if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]): + self.skipTest(reason="Won't fix: old model with unique inputs/caches/other") + if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): + self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility") - def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 - ): - self.assertIsInstance(attentions, tuple) - self.assertListEqual( + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + if not hasattr(config, "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + + # Let's make it always: + # 1. use cache (for obvious reasons) + # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which + # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the + # continuation would force it to generate beyond an EOS token) + # 3. ignore `token_type_ids` for simplicity + # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is + # active by default on some models + # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When + # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents + # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls + # with cache, what is considered a prompt is different in the two cases. + + if "token_type_ids" in inputs: + del inputs["token_type_ids"] + + model = model_class(config).to(torch_device) + model.eval() + model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 + model.generation_config.forced_eos_token_id = None + model.generation_config.encoder_no_repeat_ngram_size = 0 + model.generation_config.use_cache = True + + # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) + outputs = model(**inputs) + if "past_key_values" not in outputs: + self.skipTest(reason="This model doesn't return `past_key_values`") + + # Traditional way of generating text, with `return_dict_in_generate` to return the past key values + outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True) + + # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the + # inputs may need to be tweaked across `generate` calls (like the attention mask). + outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True) + + # Continue from the tokens generated above, preparing the inputs accordingly + inputs["past_key_values"] = outputs_cached.past_key_values + new_attention_len = outputs_cached.sequences.shape[-1] + if config.is_encoder_decoder: + inputs["decoder_input_ids"] = outputs_cached.sequences + if "decoder_attention_mask" in inputs: + inputs["decoder_attention_mask"] = torch.nn.functional.pad( + inputs["decoder_attention_mask"], + (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]), + mode="constant", + value=1, + ) + else: + inputs["input_ids"] = outputs_cached.sequences + if "attention_mask" in inputs: + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], + (0, new_attention_len - inputs["attention_mask"].shape[1]), + mode="constant", + value=1, + ) + outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True) + + # The two sets of generated text and past kv should be equal to each other + self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist()) + for layer_idx in range(len(outputs_cached.past_key_values)): + for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): + self.assertTrue( + torch.allclose( + outputs.past_key_values[layer_idx][kv_idx], + outputs_cached.past_key_values[layer_idx][kv_idx], + ) + ) + + @parameterized.expand([(1, False), (1, True), (4, False)]) + @pytest.mark.generate + def test_new_cache_format(self, num_beams, do_sample): + # Tests that generating with the new format is exactly the same as the legacy one (for models that support it). + # 👉 tests with and without beam search so that we can test with and without cache reordering. + # 👉 tests with and without sampling so we can cover the most common use cases. + for model_class in self.all_generative_model_classes: + if not model_class._supports_cache_class: + self.skipTest(reason="This model does not support the new cache format") + + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + + model = model_class(config).to(torch_device).eval() + generation_kwargs = { + "max_new_tokens": 5, + "do_sample": do_sample, + "num_beams": num_beams, + "num_return_sequences": num_beams, + "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, + } + + # Sets seed before calling `generate` for the case with do_sample=True + seed = torch.randint(0, 1000000, (1,)).item() + set_seed(seed) + legacy_results = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict + ) + set_seed(seed) + if config.is_encoder_decoder: + cache_cls = EncoderDecoderCache + past_key_values = cache_cls(DynamicCache(), DynamicCache()) + past_key_values = cache_cls(DynamicCache(), DynamicCache()) + else: + cache_cls = DynamicCache + past_key_values = cache_cls() + + new_results = model.generate( + input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + **generation_kwargs, + **inputs_dict, + ) + + # The two sets of generated sequences must match, despite the cache format between forward passes being + # different + self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist()) + self.assertTrue(isinstance(legacy_results.past_key_values, tuple)) + self.assertTrue(isinstance(new_results.past_key_values, cache_cls)) + + # The contents of the two caches, when converted to the same format (in both directions!), must match + legacy_cache = legacy_results.past_key_values + new_cache_converted = new_results.past_key_values.to_legacy_cache() + for layer_idx in range(len(legacy_cache)): + for kv_idx in range(len(legacy_cache[layer_idx])): + # TODO: @raushan, please look into this for new cache format + if legacy_cache[layer_idx][kv_idx] != []: + self.assertTrue( + torch.allclose( + legacy_cache[layer_idx][kv_idx], + new_cache_converted[layer_idx][kv_idx], + ) + ) + + new_cache = new_results.past_key_values + legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) + for layer_idx in range(len(new_cache)): + for kv_idx in range(len(new_cache[layer_idx])): + # TODO: @raushan, please look into this for new cache format + if new_cache[layer_idx][kv_idx] != []: + self.assertTrue( + torch.allclose( + new_cache[layer_idx][kv_idx], + legacy_cache_converted[layer_idx][kv_idx], + ) + ) + + @pytest.mark.generate + def test_generate_with_static_cache(self): + """ + Tests if StaticCache works if we set attn_implementation=static when generation. + This doesn't test if generation quality is good, but tests that models with + self._supports_static_cache don't throw an error when generating and return + a StaticCache object at the end. + """ + for model_class in self.all_generative_model_classes: + if not model_class._supports_static_cache: + self.skipTest(reason="This model does not support the static cache format") + + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + if config.is_encoder_decoder: + self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") + + config.is_decoder = True + batch_size, seq_length = input_ids.shape + max_new_tokens = 20 + + model = model_class(config).to(torch_device).eval() + generation_kwargs = { + "max_length": None, + "max_new_tokens": max_new_tokens, + "cache_implementation": "static", + "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, + } + + max_cache_len = seq_length + max_new_tokens + config = config.text_config if hasattr(config, "text_config") else config + head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + num_hidden_layers = config.num_hidden_layers + results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict) + + cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) + self.assertTrue(isinstance(results.past_key_values, StaticCache)) + self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers) + self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape) + + @require_quanto + @pytest.mark.generate + def test_generate_with_quant_cache(self): + for model_class in self.all_generative_model_classes: + if not model_class._supports_quantized_cache: + self.skipTest(reason="This model does not support the quantized cache format") + + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + generation_kwargs = { + "max_new_tokens": 5, + "cache_implementation": "quantized", + # careful with group size, should be divisor of model's hidden size + "cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128}, + "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, + } + + results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict) + self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache)) + + # passing past key values of different type should raise Error + with self.assertRaises(ValueError): + num_hidden_layers = config.get_text_config().num_hidden_layers + model.generate( + input_ids, + attention_mask=attention_mask, + past_key_valyes=DynamicCache(num_hidden_layers), + **generation_kwargs, + ) + + # setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense + generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128} + with self.assertRaises(ValueError): + model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + + @pytest.mark.generate + @require_torch_gpu + @slow + @is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky + def test_generate_compile_fullgraph(self): + """ + Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. + ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ + """ + for model_class in self.all_generative_model_classes: + if not model_class._supports_static_cache: + self.skipTest("This model doesn't support static cache") + # TODO (joao) -- fix and enable me :) + if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]): + self.skipTest("whisper model end-to-end generate compile not yet supported") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # TODO (joao) -- fix and enable me :) + if config.is_encoder_decoder: + self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported") + + model = model_class(config).to(torch_device) + model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time + + input_ids = inputs_dict["input_ids"].to(torch_device) + # creates two sets of *different* inputs with the same shape + half_batch_size = input_ids.shape[0] // 2 + input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]] + self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape) + + generation_kwargs = { + "do_sample": False, + "max_new_tokens": 10, + } + + max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"] + config = config.get_text_config() + past_key_values = StaticCache( + config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device + ) + + for model_inputs in input_ids_sets: + # eager dynamic cache + output_dynamic = model.generate(model_inputs, **generation_kwargs) + + # end-to-end compiled dynamic cache + torch.compiler.reset() + compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") + generation_config = copy.deepcopy(model.generation_config) + generation_config.update(**generation_kwargs) + output_compiled = compiled_generate( + model_inputs, generation_config=generation_config, past_key_values=past_key_values + ) + self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) + + @pytest.mark.generate + def test_generate_methods_with_num_logits_to_keep(self): + for model_class in self.all_generative_model_classes: + if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config.use_cache = True + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + # All generation methods (except assisted decoding) rely on always extracting the last token logits of the + # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, + # other methods will work as well) + generation_kwargs = { + "max_new_tokens": 10, + "do_sample": False, + } + + # Setting num_logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0 + ) + # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + without_all_logits = model.generate( + input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs + ) + self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) + + @pytest.mark.generate + @is_flaky() # assisted generation tests are flaky (minor fp ops differences) + def test_assisted_decoding_with_num_logits_to_keep(self): + for model_class in self.all_generative_model_classes: + if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support assisted generation") + + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) + config.use_cache = True + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + assistant_model = model + # All generation methods (except assisted decoding) rely on always extracting the last token logits of the + # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, + # other methods will work as well) + generation_kwargs = { + "max_new_tokens": 10, + "do_sample": False, + "assistant_model": assistant_model, + } + + assistant_model.generation_config.assistant_confidence_threshold = None + # Setting num_logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0 + ) + # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + without_all_logits = model.generate( + input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs + ) + self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) + + @pytest.mark.generate + def test_inherits_generation_mixin(self): + """ + Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel` + to inherit it. + """ + for model_class in self.all_generative_model_classes: + self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): + batch_size, seq_length = input_ids.shape + config = config.text_config if hasattr(config, "text_config") else config + num_sequences_in_output = batch_size * num_return_sequences + + gen_len = ( + output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length + ) + + # scores + self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) + + # unprocessed logits + self._check_logits(num_sequences_in_output, output.logits, config=config) + + # Attentions + if self.has_attentions: + if config.is_encoder_decoder: + # encoder + self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length) + # decoder + self._check_attentions_for_generate( + num_sequences_in_output, + output.decoder_attentions, + min_length=1, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + else: + # if use_cache first input is equal to no use_cache, so skip here + attentions = output.attentions if not use_cache else output.attentions[1:] + min_length = seq_length if not use_cache else seq_length + 1 + self._check_attentions_for_generate( + num_sequences_in_output, + attentions=attentions, + min_length=min_length, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + + # Hidden States + if config.is_encoder_decoder: + # encoder + self._check_encoder_hidden_states_for_generate( + output.encoder_hidden_states, batch_size, config, seq_length + ) + + # decoder + self._check_hidden_states_for_generate( + num_sequences_in_output, + output.decoder_hidden_states, + min_length=1, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + else: + # if use_cache first input is equal to no use_cache, so skip here + hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] + min_length = seq_length if not use_cache else seq_length + 1 + self._check_hidden_states_for_generate( + num_sequences_in_output, + hidden_states, + min_length=min_length, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + + # Past Key Value States -- a few notes here: + # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1" + # 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the + # standard cache format (e.g.gptbigcode ) + models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet") + has_standard_cache = not any( + model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache + ) + if has_standard_cache: + if use_cache: + past_key_values = output.past_key_values + past_sequence_length = output.sequences.shape[-1] - 1 + self._check_past_key_values_for_generate( + num_sequences_in_output, + past_key_values, + seq_length=past_sequence_length, + config=config, + ) + elif use_cache is False: + self.assertTrue(output.past_key_values is None) + + def _check_scores(self, batch_size, scores, length, config): + vocab_size = config.get_text_config(decoder=True).vocab_size + expected_shape = (batch_size, vocab_size) + self.assertIsInstance(scores, tuple) + self.assertEqual(len(scores), length) + self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) + + def _check_logits(self, batch_size, scores, config): + vocab_size = config.get_text_config(decoder=True).vocab_size + self.assertIsInstance(scores, tuple) + self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores)) + # vocabulary difference equal to one (imagegptmodel?) or zero (all other models) + vocab_diff = vocab_size - scores[0].shape[-1] + self.assertTrue(vocab_diff in [0, 1]) + self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores)) + + def _check_attentions_for_generate( + self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) ) self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) @@ -2318,6 +2329,30 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c [encoder_expected_shape] * len(hidden_states), ) + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1): + self.assertIsInstance(past_key_values, tuple) + self.assertListEqual( + [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values], + [True] * len(past_key_values), + ) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size * num_beam_groups, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + seq_length, + config.hidden_size // config.num_attention_heads, + ) + # check shape key, value + self.assertListEqual( + [layer_past_key_values[0].shape for layer_past_key_values in past_key_values], + [expected_shape] * len(past_key_values), + ) + self.assertListEqual( + [layer_past_key_values[1].shape for layer_past_key_values in past_key_values], + [expected_shape] * len(past_key_values), + ) + def _check_sequence_inside_sequence(self, tensor_1, tensor_2): # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1. # set to same device. we don't care what device. @@ -2342,6 +2377,45 @@ def _check_sequence_inside_sequence(self, tensor_1, tensor_2): self.assertTrue(flag) +@require_torch +class UtilsFunctionsTest(unittest.TestCase): + def test_speculative_sampling(self): + # assume vocab size 10, input length 5 + 3 generated candidates + candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]]) # input tokens + candidate_logits = torch.tensor( + [ + [ + [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 1 + [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 4 + [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0], # generated 5 + ] + ] + ) + candidate_length = 3 + inf = float("inf") + new_logits = torch.tensor( + [ + [ + [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # accepts 1 + [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # accepts 4 + [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf], # rejects 5, accepts 8 + [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # N/A + ] + ] + ) + last_assistant_token_is_eos = False + validated_tokens, n_matches = _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + last_assistant_token_is_eos, + ) + self.assertTrue(n_matches.item() == 2) + self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8]) + + +@pytest.mark.generate @require_torch class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin): # setting framework_dependent_parameters needs to be gated, just like its contents' imports @@ -2359,6 +2433,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi } @slow + @pytest.mark.skip("Group beam search is not supported by optimum-habana") def test_diverse_beam_search(self): # PT-only test: TF doesn't have a diverse beam search implementation article = """Justin Timberlake and Jessica Biel, welcome to parenthood. @@ -2393,257 +2468,80 @@ def test_diverse_beam_search(self): ], ) - def test_max_length_backward_compat_greedy(self): + def test_max_length_if_input_embeds(self): # PT-only test: TF doesn't have StoppingCriteria - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + article = "Today a dragon flew over Paris." + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + inputs_embeds = model.get_input_embeddings()(input_ids) - max_length = 20 - input_ids = input_ids.expand(2, -1) - model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) + # Controlling max_length via the configuration is deprecated in favor of max_new_tokens + max_new_tokens = 20 + input_len = input_ids.shape[-1] + out_gen = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens) + out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=max_new_tokens) + self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1]) - with self.assertWarns(UserWarning): - bart_model.greedy_search( - input_ids, - max_length=max_length, - pad_token_id=bart_model.config.pad_token_id, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) + def test_min_length_if_input_embeds(self): + # PT-only test: TF doesn't have StoppingCriteria + article = "Today a dragon flew over Paris." + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + inputs_embeds = model.get_input_embeddings()(input_ids) + + # Controlling max_length via the configuration is deprecated in favor of max_new_tokens + min_length = 10 + input_len = input_ids.shape[-1] + out_gen = model.generate(input_ids=input_ids, min_length=min_length, max_new_tokens=20) + out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=min_length, max_new_tokens=20) + self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1]) - def test_max_length_backward_compat_sample(self): + def test_custom_stopping_criteria_overload_error(self): # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") + bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) - max_length = 20 - input_ids = input_ids.expand(2, -1) - model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) - with torch.no_grad(): - with self.assertWarns(UserWarning): - bart_model.sample( - input_ids, - max_length=max_length, - pad_token_id=bart_model.config.pad_token_id, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + stopping_criteria = StoppingCriteriaList() + stopping_criteria.append(MaxLengthCriteria(max_length=42)) + with self.assertRaises(ValueError): + bart_model.generate(input_ids, stopping_criteria=stopping_criteria) + with self.assertRaises(ValueError): + bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32) - def test_max_length_backward_compat_beam_search(self): + def test_custom_stopping_criteria(self): # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) + bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") + bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - batch_size = 1 - max_length = 20 - num_beams = 2 - - input_ids = input_ids.expand(2, -1) - model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) + class DummyCriteria(StoppingCriteria): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids.shape[-1] >= 20 - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=torch_device, - ) - with self.assertWarns(UserWarning): - _ = bart_model.beam_search( - input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs - ) + stopping_criteria = StoppingCriteriaList() + stopping_criteria.append(DummyCriteria()) - def test_max_length_backward_compat_group_beam_search(self): - # PT-only test: TF doesn't have StoppingCriteria & group beam search - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device + output = bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=22) + self.assertEqual( + list(output.shape), + [1, 22], # still produces the max_length ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + # make sure final tokens are padding + self.assertEqual(output[:, 20:].tolist(), [[bart_model.config.pad_token_id, bart_model.config.pad_token_id]]) - batch_size = 1 - max_length = 20 - num_beams = 6 - num_beam_groups = 3 - num_return_sequences = num_beams * batch_size - - input_ids = input_ids.expand(6, -1) - model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) - - diverse_beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=torch_device, - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=num_beam_groups, - ) - with self.assertWarns(UserWarning): - bart_model.group_beam_search( - input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs - ) - - def test_max_length_warning_if_different(self): - # PT-only test: TF doesn't have StoppingCriteria - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - batch_size = 1 - - max_length = 20 - num_beams = 6 - num_beam_groups = 3 - num_return_sequences = num_beams * batch_size - stopping_criteria_max_length = 18 - stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)]) - - # Greedy - input_ids = input_ids.expand(6, -1) - model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) - input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( - batch_size=input_ids.shape[0], - model_input_name=bart_model.main_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=bart_model.config.decoder_start_token_id, - bos_token_id=bart_model.config.bos_token_id, - ) - - with self.assertWarns(UserWarning): - bart_model.greedy_search( - input_ids, - max_length=max_length, - pad_token_id=bart_model.config.pad_token_id, - stopping_criteria=stopping_criteria, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) - - # Sample - with self.assertWarns(UserWarning): - with torch.no_grad(): - bart_model.sample( - input_ids, - max_length=max_length, - stopping_criteria=stopping_criteria, - pad_token_id=bart_model.config.pad_token_id, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) - - # Beam - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=torch_device, - ) - with self.assertWarns(UserWarning): - with torch.no_grad(): - bart_model.beam_search( - input_ids, - num_beams=num_beams, - stopping_criteria=stopping_criteria, - max_length=max_length, - beam_scorer=beam_scorer, - **model_kwargs, - ) - - # Grouped beam search - diverse_beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=torch_device, - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=num_beam_groups, - ) - with self.assertWarns(UserWarning): - bart_model.group_beam_search( - input_ids, - diverse_beam_scorer, - stopping_criteria=stopping_criteria, - num_beams=num_beams, - max_length=max_length, - **model_kwargs, - ) - - def test_custom_stopping_criteria_overload_error(self): - # PT-only test: TF doesn't have StoppingCriteria - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) - - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - stopping_criteria = StoppingCriteriaList() - stopping_criteria.append(MaxLengthCriteria(max_length=42)) - with self.assertRaises(ValueError): - bart_model.generate(input_ids, stopping_criteria=stopping_criteria) - with self.assertRaises(ValueError): - bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32) - - def test_custom_stopping_criteria(self): - # PT-only test: TF doesn't have StoppingCriteria - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - class DummyCriteria(StoppingCriteria): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - return input_ids.shape[-1] >= 20 - - stopping_criteria = StoppingCriteriaList() - stopping_criteria.append(DummyCriteria()) - - self.assertEqual( - list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=22).shape), - [1, 20], - ) - self.assertEqual( - list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=18).shape), - [1, 18], + self.assertEqual( + list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=18).shape), + [1, 18], ) + # TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail def test_stop_sequence_stopping_criteria(self): # PT-only test: TF doesn't have StoppingCriteria prompt = """Hello I believe in""" @@ -2651,17 +2549,11 @@ def test_stop_sequence_stopping_criteria(self): output = generator(prompt) self.assertEqual( output, - [ - { - "generated_text": ( - "Hello I believe in in in number number number number number number number number number" - ) - } - ], + [{"generated_text": ("Hello I believe in we we we we we we we we we")}], ) - output = generator(prompt, stop_sequence=" number") - self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}]) + output = generator(prompt, stop_sequence=" we") + self.assertEqual(output, [{"generated_text": "Hello I believe in we"}]) def test_generate_non_nlp_input_ids_as_kwarg(self): # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input @@ -2687,6 +2579,7 @@ def test_generate_input_values_as_encoder_kwarg(self): self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) self.assertEqual(output_sequences.shape, (2, 5)) + @pytest.mark.skip("Group beam search is not supported by optimum-habana") def test_transition_scores_group_beam_search_encoder_decoder(self): # PT-only test: TF doesn't have group beam search articles = [ @@ -2716,13 +2609,61 @@ def test_transition_scores_group_beam_search_encoder_decoder(self): self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_beam_search_low_memory(self): + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + tokenizer.pad_token_id = tokenizer.eos_token_id + model_inputs = tokenizer("I", return_tensors="pt")["input_ids"] + + low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True) + + high_output = model.generate( + model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False + ) + self.assertListEqual(low_output.tolist(), high_output.tolist()) + + @slow + @pytest.mark.skip("Watermarking is not supported by optimum-habana yet") + def test_watermark_generation(self): + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device) + tokenizer.pad_token_id = tokenizer.eos_token_id + model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device) + input_len = model_inputs["input_ids"].shape[-1] + + # generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :) + watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash") + _ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15) + + # We will not check watermarked text, since we check it in `logits_processors` tests + # Checking if generated ids are as expected fails on different hardware + args = { + "bias": 2.0, + "context_width": 1, + "seeding_scheme": "selfhash", + "greenlist_ratio": 0.25, + "hashing_key": 15485863, + } + output = model.generate(**model_inputs, do_sample=False, max_length=15) + output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15) + + # Check that the detector is detecting watermarked text + detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args) + detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True) + detection_out = detector(output[:, input_len:], return_dict=True) + + self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True]) + self.assertListEqual(detection_out.prediction.tolist(), [False]) + @slow def test_beam_search_example_integration(self): # PT-only test: TF doesn't have a BeamSearchScorer # exactly the example provided in the docstrings of beam search, which previously # failed after directly copying from it. Refer to PR #15555 - tokenizer = AutoTokenizer.from_pretrained("t5-base") - model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") encoder_input_str = "translate English to German: How old are you?" encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids @@ -2730,31 +2671,15 @@ def test_beam_search_example_integration(self): # lets run beam search using 3 beams num_beams = 3 # define decoder start token ids - input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long) input_ids = input_ids * model.config.decoder_start_token_id # add encoder_outputs to model keyword arguments - model_kwargs = { - "encoder_outputs": model.get_encoder()( - encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ) - } - - # instantiate beam scorer - beam_scorer = BeamSearchScorer( - batch_size=1, - num_beams=num_beams, - device=model.device, - ) + model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)} - # instantiate logits processors - logits_processor = LogitsProcessorList( - [ - MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ] + outputs = model.generate( + input_ids, num_beams=num_beams, min_length=5, eos_token_id=model.config.eos_token_id, **model_kwargs ) - - outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) self.assertListEqual(outputs, ["Wie alt bist du?"]) @@ -2762,8 +2687,8 @@ def test_beam_search_example_integration(self): @slow def test_constrained_beam_search(self): # PT-only test: TF doesn't have constrained beam search - model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids @@ -2800,8 +2725,8 @@ def test_constrained_beam_search(self): @slow def test_constrained_beam_search_mixed(self): # PT-only test: TF doesn't have constrained beam search - model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids flexible_phrases = tokenizer( @@ -2841,8 +2766,8 @@ def test_constrained_beam_search_mixed(self): @slow def test_constrained_beam_search_mixed_mixin(self): # PT-only test: TF doesn't have constrained beam search - model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") force_word = "scared" force_flexible = ["scream", "screams", "screaming", "screamed"] @@ -2877,9 +2802,15 @@ def test_constrained_beam_search_mixed_mixin(self): ) @slow + @pytest.mark.xfail def test_cfg_mixin(self): - model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + + # add pad_token_id for static shape + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + model.generation_config.pad_token_id = model.generation_config.eos_token_id input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True) input["input_ids"] = input["input_ids"].to(torch_device) @@ -2919,8 +2850,8 @@ def test_cfg_mixin(self): @slow def test_constrained_beam_search_example_translation_mixin(self): # PT-only test: TF doesn't have constrained beam search - tokenizer = AutoTokenizer.from_pretrained("t5-base") - model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") encoder_input_str = "translate English to German: How old are you?" force_words = ["sind"] @@ -2944,8 +2875,8 @@ def test_constrained_beam_search_example_translation_mixin(self): @slow def test_constrained_beam_search_example_integration(self): # PT-only test: TF doesn't have constrained beam search - tokenizer = AutoTokenizer.from_pretrained("t5-base") - model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") encoder_input_str = "translate English to German: How old are you?" encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids @@ -2953,38 +2884,65 @@ def test_constrained_beam_search_example_integration(self): # lets run beam search using 5 beams num_beams = 5 # define decoder start token ids - input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long) input_ids = input_ids * model.config.decoder_start_token_id # add encoder_outputs to model keyword arguments - model_kwargs = { - "encoder_outputs": model.get_encoder()( - encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ) - } + model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)} constraint_str = "sind" constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token - constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] - # instantiate beam scorer - beam_scorer = ConstrainedBeamSearchScorer( - batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints + outputs = model.generate( + input_ids, + num_beams=num_beams, + force_words_ids=[constraint_token_ids], + min_length=5, + eos_token_id=model.config.eos_token_id, + **model_kwargs, ) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) - # instantiate logits processors - logits_processor = LogitsProcessorList( - [ - MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ] - ) + self.assertListEqual(outputs, ["Wie alt sind Sie?"]) + + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + @slow + def test_per_row_stopping_criteria(self): + text = [ + "They completed the challenging puzzle, revealing the hidden", + "Today a dragon flew over France", + "The aroma of freshly baked pizza filled the kitchen", + ] + stop_strings = ["secrets"] - outputs = model.constrained_beam_search( - input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs + model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + tokenizer.padding_side = "left" + tokenizer.pad_token_id = tokenizer.eos_token_id + input_ids = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to( + torch_device ) - outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) - self.assertListEqual(outputs, ["Wie alt sind Sie?"]) + # normal generation with one stopping criteria + out = model.generate(input_ids, max_length=15) + out_text = tokenizer.batch_decode(out) + expected_out = [ + "They completed the challenging puzzle, revealing the hidden secrets of the world.\n", + "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", + "The aroma of freshly baked pizza filled the kitchen with a sense of freshness", + ] + self.assertListEqual(out_text, expected_out) + + # generation should stop at "secrets" for first batch only, filling the rest with eos tokens + out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer) + out_text = tokenizer.batch_decode(out) + expected_out = [ + "They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", + "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", + "The aroma of freshly baked pizza filled the kitchen with a sense of freshness", + ] + self.assertListEqual(out_text, expected_out) def test_constrained_beam_search_mixin_type_checks(self): # PT-only test: TF doesn't have constrained beam search @@ -3027,6 +2985,55 @@ def test_constrained_beam_search_mixin_type_checks(self): with self.assertRaises(ValueError): model.generate(input_ids, force_words_ids=[[[-1]]]) + def test_batched_decoder_start_id(self): + # PT-only test: TF doesn't support batched_decoder_start_id + articles = [ + "Justin Timberlake and Jessica Biel, welcome to parenthood.", + "Michael Phelps is arguably the most decorated Olympian of all time.", + ] + bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) + input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) + decoder_start_token_id = bart_model.generation_config.decoder_start_token_id + decoder_start_token_id_batch = [decoder_start_token_id] * input_ids.shape[0] + + outputs = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id) + + outputs_batched_ids = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id_batch) + + self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist()) + + def test_decoder_start_id_from_config(self): + # Refer to: (#30899) + articles = [ + "Justin Timberlake and Jessica Biel, welcome to parenthood.", + "Michael Phelps is arguably the most decorated Olympian of all time.", + ] + bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) + input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) + decoder_start_token_id = bart_model.generation_config.decoder_start_token_id + + # we should be able to take `decoder_start_token_id` from model's generation config if user passes a `GenerationConfig` type + outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False)) + + # If the generatoin config has no `decoder_start_token_id` or `bos_token_id`, we will raise an error unless user passes it in config + bart_model.generation_config.decoder_start_token_id = None + bart_model.generation_config.bos_token_id = None + outputs_with_user_id = bart_model.generate( + input_ids, + generation_config=GenerationConfig(do_sample=False, decoder_start_token_id=decoder_start_token_id), + ) + + self.assertListEqual(outputs.tolist(), outputs_with_user_id.tolist()) + + with self.assertRaises(ValueError): + outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False)) + def test_contrastive_search_batched(self): # PT-only test: TF doesn't have constrained beam search # Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs) @@ -3053,6 +3060,27 @@ def test_contrastive_search_batched(self): max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max() self.assertTrue(max_score_diff < 1e-5) + def test_logits_processor_not_inplace(self): + # PT-only test: TF fixes were not made + article = "Today a dragon flew over Paris." + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + out = model.generate(input_ids, output_logits=True, output_scores=True, return_dict_in_generate=True) + out_with_temp = model.generate( + input_ids, + temperature=0.5, + do_sample=True, + output_logits=True, + output_scores=True, + return_dict_in_generate=True, + ) + + # if no logits processor is used, scores == logits. Otherwise, the processor has to modify the scores + self.assertListEqual(out.logits[-1].tolist(), out.scores[-1].tolist()) + self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist()) + def test_eos_token_id_int_and_list_top_k_top_sampling(self): # Has TF equivalent: this test relies on random sampling generation_kwargs = { @@ -3107,6 +3135,10 @@ def forward(self, input_ids, foo=None, **kwargs): # because it doesn't do signature filtering. class FakeEncoder(bart_model.model.encoder.__class__): def forward(self, input_ids, **kwargs): + # We remove these to pass gaudi_BartEncoder_forward TypeError + kwargs.pop("bucket_size", None) + kwargs.pop("bucket_internal", None) + kwargs.pop("reduce_recompile", None) return super().forward(input_ids, **kwargs) fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared).to(torch_device) @@ -3121,15 +3153,16 @@ def forward(self, input_ids, **kwargs): def test_default_max_length_warning(self): model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - model.config.pad_token_id = tokenizer.eos_token_id + model.generation_config.pad_token_id = tokenizer.eos_token_id text = "Hello world" tokenized_inputs = tokenizer([text], return_tensors="pt") input_ids = tokenized_inputs.input_ids.to(torch_device) # Default generation config value of 20 -> emits warning - with self.assertWarns(UserWarning): - model.generate(input_ids) + # NOTE: in OH we do not have this warning + # with self.assertWarns(UserWarning): + # model.generate(input_ids) # Explicitly setting max_length to 20 -> no warning with warnings.catch_warnings(record=True) as warning_list: @@ -3138,7 +3171,805 @@ def test_default_max_length_warning(self): # Generation config max_length != 20 -> no warning with warnings.catch_warnings(record=True) as warning_list: + # generation_config is modified -> legacy mode is disabled = generation_config takes precedence model.generation_config.max_length = 10 - model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence model.generate(input_ids) self.assertEqual(len(warning_list), 0) + + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_length_warning_assisted_generation(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model.generation_config.pad_token_id = tokenizer.eos_token_id + assistant.generation_config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # This should not raise any warning that min length is not feasible in candidate generation + with warnings.catch_warnings(record=True) as warning_list: + model.generate( + input_ids, + assistant_model=assistant, + min_new_tokens=10, + max_length=20, + ) + self.assertEqual(len(warning_list), 0) + + def test_default_assisted_generation(self): + # Initialize the GenerationConfig object + config = GenerationConfig() + + # Check the default values + self.assertEqual(config.num_assistant_tokens, 20) + self.assertEqual(config.num_assistant_tokens_schedule, "constant") + self.assertEqual(config.assistant_confidence_threshold, 0.4) + self.assertEqual(config.is_assistant, False) + + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_generated_length_assisted_generation(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model.generation_config.pad_token_id = tokenizer.eos_token_id + assistant.generation_config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + input_length = input_ids.shape[-1] + + out = model.generate( + input_ids, + assistant_model=assistant, + min_new_tokens=10, + max_new_tokens=20, + ) + self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length)) + + out = model.generate( + input_ids, + assistant_model=assistant, + min_new_tokens=10, + ) + self.assertTrue((input_length + 10) <= out.shape[-1] <= 20) + + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_model_kwarg_assisted_decoding_decoder_only(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model.generation_config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with token_type_ids + outputs_tti = model.generate( + input_ids, + token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), + ) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + assistant.config.pad_token_id = tokenizer.eos_token_id + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist()) + + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_model_kwarg_assisted_decoding_encoder_decoder(self): + """ + Tests that the following scenario is compatible with assisted generation: + 1. encoder-decoder main model + 2. encoder-decoder assistant model + 3. both have a custom input + (e.g. Whisper) + """ + + # PT-only test: TF doesn't support assisted decoding yet. + # Bart subclass with a kwarg that distorts the output + class FakeBart(BartForConditionalGeneration): + def forward(self, input_ids, past_key_values, foo=False, **kwargs): + outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs) + if foo: + outs["logits"][:, :, :] = 0.0 + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): + kwargs["encoder_outputs"] = encoder_outputs + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + inputs["foo"] = foo + return inputs + + model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + torch_device + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with foo + outputs_foo = model.generate(input_ids, foo=True) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + torch_device + ) + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + foo=True, + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + # Check that passing encoder_outputs directly also works as expected + encoder_outputs = assistant.get_encoder()(input_ids) + + outputs_assisted = model.generate( + foo=True, + assistant_model=assistant, + encoder_outputs=encoder_outputs, + assistant_encoder_outputs=encoder_outputs, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_assisted_decoding_encoder_decoder_shared_encoder(self): + """ + Tests that the following scenario is compatible with assisted generation: + 1. encoder-decoder main model + 2. decoder-only assistant model + 3. both have a custom input + (e.g. DistilWhisper) + """ + + # PT-only test: TF doesn't support assisted decoding yet. + # Bart subclass with a kwarg called foo that distorts the output + class FakeBartSeq2Seq(BartForConditionalGeneration): + def forward(self, input_ids, foo=False, **kwargs): + outs = super().forward(input_ids, **kwargs) + if foo: + outs["logits"][:, :, :] = 0.0 + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): + kwargs["encoder_outputs"] = encoder_outputs + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + inputs["foo"] = foo + return inputs + + class FakeBartCausalLM(BartForCausalLM): + def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs): + outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs) + if foo: + outs["logits"][:, :, :] = 0.0 + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): + kwargs["encoder_outputs"] = encoder_outputs + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + inputs["foo"] = foo + return inputs + + model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + torch_device + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with foo + outputs_foo = model.generate(input_ids, foo=True) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = FakeBartCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-BartForConditionalGeneration" + ).to(torch_device) + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + foo=True, + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + # Check that passing encoder_outputs directly also works as expected + encoder_outputs = model.get_encoder()(input_ids) + + outputs_assisted = model.generate( + foo=True, + assistant_model=assistant, + encoder_outputs=encoder_outputs, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self): + # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly. + + prompt = "Alice and Bob" + checkpoint = "EleutherAI/pythia-160m-deduped" + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + inputs = tokenizer(prompt, return_tensors="pt") + + model = AutoModelForCausalLM.from_pretrained(checkpoint) + + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 5 + assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic" + generation_kwargs = { + "eos_token_id": -1, + "max_new_tokens": 5, + "do_sample": False, + "assistant_model": assistant_model, + } + model.generate(**inputs, **generation_kwargs) + # update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7 + self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7)) + + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self): + # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly. + + prompt = "Alice and Bob" + checkpoint = "EleutherAI/pythia-160m-deduped" + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + inputs = tokenizer(prompt, return_tensors="pt") + + model = AutoModelForCausalLM.from_pretrained(checkpoint) + + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 5 + assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic_transient" + generation_kwargs = { + "eos_token_id": -1, + "max_new_tokens": 5, + "do_sample": False, + "assistant_model": assistant_model, + } + model.generate(**inputs, **generation_kwargs) + # update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5 + self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5) + + # TODO [gustavo] Enable this test to Optimum-habana + @slow + @pytest.mark.xfail + def test_validate_assistant(self): + # Generate a random sample: + inputs = np.random.rand(160000) + + # Load a main encoder-decoder model: + model_id = "openai/whisper-large-v2" + processor = AutoProcessor.from_pretrained(model_id) + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, + low_cpu_mem_usage=True, + use_safetensors=True, + ) + model.to(torch_device) + + # process the input: + features = processor(inputs, return_tensors="pt").to(torch_device) + + # Load an encoder-decoder assistant with same encoder as the main model: + assistant_distil_model_id = "distil-whisper/distil-large-v2" + assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( + assistant_distil_model_id, + use_safetensors=True, + ).to(torch_device) + self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum()) + + # Load its decoder only version: + assistant_causal_lm = AutoModelForCausalLM.from_pretrained( + assistant_distil_model_id, + low_cpu_mem_usage=True, + use_safetensors=True, + ).to(torch_device) + self.assertTrue(model.generate(**features, assistant_model=assistant_causal_lm).sum()) + + # Load an encoder-decoder assistant with a different encoder than the main model: + assistant_distil_model_id = "openai/whisper-tiny" + assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( + assistant_distil_model_id, + use_safetensors=True, + ).to(torch_device) + self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum()) + + # Load its decoder only version: + assistant_causal_lm = AutoModelForCausalLM.from_pretrained( + assistant_distil_model_id, + low_cpu_mem_usage=True, + use_safetensors=True, + ).to(torch_device) + # It will raise an error as the encoder of the main and assistant model are not compatible: + with self.assertRaises(ValueError): + model.generate(**features, assistant_model=assistant_causal_lm) + + # Load an encoder-decoder model with a different tokenizer than the main model: + assistant_distil_model_id = "hf-internal-testing/tiny-random-SeamlessM4Tv2ForSpeechToText" + assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( + assistant_distil_model_id, + ).to(torch_device) + # This should raise an error as the main and assistant model don't use the same tokenizer: + with self.assertRaises(ValueError): + model.generate(**features, assistant_model=assistant_seq_to_seq) + + def test_compare_unprocessed_logit_scores(self): + # Get unprocessed logit scores back from model generate function. + # Assert that unprocessed logits from generate() are same as those from modal eval() + + # tell model to generate text and return unprocessed/unwarped logit scores + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = "generate yes or no: " + input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device) + + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + with torch.no_grad(): + # Get logits for the next token from fwd pass + logits_fwd = model(input_ids).logits[:, -1, :][0] + + # Get logits for the next token from generate function + outputs = model.generate( + input_ids=input_ids, + return_dict_in_generate=True, + output_logits=True, + max_new_tokens=1, + do_sample=True, + ) + logits_gen = outputs.logits[0][0] + + # assert that unprocessed logits from generate() are same as those from modal eval() + self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist()) + + def test_return_unprocessed_logit_scores(self): + # tell model to generate text and return unprocessed/unwarped logit scores + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = "generate yes or no: " + input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device) + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + outputs = model.generate( + input_ids=input_ids, return_dict_in_generate=True, output_logits=True, max_new_tokens=3 + ) + + # perform dummy check if unpreprocessed logits make sense. + # do preselection on high probabilities; find scores of y and n tokens + probs_all = torch.nn.functional.softmax(outputs.logits[2][0], dim=-1) + indices = torch.argwhere(probs_all > 0.001) + indices = indices[:, -1] + tokens_max = tokenizer.batch_decode(indices, skip_special_tokens=True) + probs_max = probs_all[probs_all > 0.001] + + self.assertTrue(len(indices) >= 2) + next_token_dict = {str(t): p for t, p in zip(tokens_max, probs_max)} + self.assertTrue("n" in next_token_dict) + self.assertTrue("y" in next_token_dict) + y_prob = next_token_dict["y"] + n_prob = next_token_dict["n"] + + self.assertTrue(y_prob > 0.001 and n_prob > 0.001) + self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0) + + @slow + @require_torch_multi_gpu + def test_assisted_decoding_in_different_gpu(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0") + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( + "cuda:1" + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + model.config.pad_token_id = tokenizer.eos_token_id + assistant.config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + input_length = input_ids.shape[-1] + + out = model.generate( + input_ids, + assistant_model=assistant, + max_new_tokens=20, + ) + self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) + + @slow + @require_torch_gpu + def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda") + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( + "cpu" + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + model.config.pad_token_id = tokenizer.eos_token_id + assistant.config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + input_length = input_ids.shape[-1] + + out = model.generate( + input_ids, + assistant_model=assistant, + max_new_tokens=20, + ) + self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) + + def test_special_tokens_fall_back_to_model_default(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( + torch_device + ) + test_bos_id = 50 + + # Sanity-check: the model has a BOS token set, and the first generated token is a BOS token + gen_output = model.generate() + self.assertTrue(model.generation_config.bos_token_id is not None) + self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) + + # If we pass a generation config **with** a BOS token, `generate` will use it + generation_config = GenerationConfig(bos_token_id=test_bos_id) + gen_output = model.generate(generation_config=generation_config) + self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0]) + self.assertTrue(generation_config.bos_token_id == gen_output[0, 0]) + self.assertTrue(test_bos_id == gen_output[0, 0]) + + # If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from + # `model.generation_config` + generation_config = GenerationConfig(bos_token_id=None) + gen_output = model.generate(generation_config=generation_config) + self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) + self.assertFalse(test_bos_id == gen_output[0, 0]) + self.assertTrue(generation_config.bos_token_id is None) + + # Changing `model.generation_config` will affect fallback behavior + model.generation_config.bos_token_id = test_bos_id + gen_output = model.generate(generation_config=generation_config) + self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) + self.assertTrue(test_bos_id == gen_output[0, 0]) + self.assertTrue(generation_config.bos_token_id is None) + + @pytest.mark.generate + @require_torch_multi_gpu + def test_generate_with_static_cache_multi_gpu(self): + """ + Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus. + """ + # need to split manually as auto doesn't work well with unbalanced model + device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0} + model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + generation_kwargs = { + "max_new_tokens": 20, + "cache_implementation": "static", + "return_dict_in_generate": True, # Required to return `past_key_values` + } + + results = model.generate(input_ids, **generation_kwargs) + self.assertTrue(isinstance(results.past_key_values, StaticCache)) + + # check device of each layer + key_cache_0 = results.past_key_values.key_cache[0] + value_cache_0 = results.past_key_values.value_cache[0] + self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + + key_cache_1 = results.past_key_values.key_cache[1] + value_cache_1 = results.past_key_values.value_cache[1] + self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + + @pytest.mark.generate + @require_torch_multi_gpu + def test_init_static_cache_multi_gpu(self): + """ + Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup. + """ + # need to split manually as auto doesn't work well with unbalanced model + device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0} + model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + generation_kwargs = { + "max_new_tokens": 20, + "return_dict_in_generate": True, # Required to return `past_key_values` + } + + # TODO: We need to raise a warning in case the cache is not set correctly + # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"): + # past_key_values = StaticCache( + # config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype + # ) + # results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) + + # deduced from the device_map : layer 0 on device 0 and layer 1 on device 1 + layer_device_map = {0: 0, 1: 1} + past_key_values = StaticCache( + config=model.config, + batch_size=1, + max_cache_len=30, + device=torch_device, + dtype=model.dtype, + layer_device_map=layer_device_map, + ) + results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) + + # check device of each layer + key_cache_0 = results.past_key_values.key_cache[0] + value_cache_0 = results.past_key_values.value_cache[0] + self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + + key_cache_1 = results.past_key_values.key_cache[1] + value_cache_1 = results.past_key_values.value_cache[1] + self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + + @slow + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_padding_input_contrastive_search_gpt2(self): + # Load the pre-trained GPT-2 model and tokenizer + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2") + model.to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True) + + # Set the tokenizer to left-pad the sequences + tokenizer.padding_side = "left" + + # Define the PAD token as the EOS token + tokenizer.pad_token = tokenizer.eos_token + model.generation_config.pad_token_id = model.generation_config.eos_token_id + + # Define the input prompt + prompt_text = "The whispered legends of the haunted mansion spoke" + + # Tokenize the input prompt + encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True) + input_ids = encoded_prompt.input_ids.to(torch_device) + attention_mask = encoded_prompt.attention_mask.to(torch_device) + + # Define the contrastive search params + penalty_alpha = 0.6 + top_k = 4 + + # Define the padding length to add to the input IDs and attention mask + padding_length = 10 + + # Generate text without padding + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + do_sample=False, + penalty_alpha=penalty_alpha, + top_k=top_k, + max_new_tokens=64, + ) + generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Pad the input IDs and attention mask on the left + padded_input_ids = F.pad( + input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id + ) + padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0) + + # Generate text with padded inputs + outputs_with_padding = model.generate( + input_ids=padded_input_ids, + attention_mask=padded_attention_mask, + do_sample=False, + penalty_alpha=penalty_alpha, + top_k=top_k, + max_new_tokens=64, + ) + generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True) + + # Assert that the generated texts are identical for padded and non-padded inputs + self.assertEqual(generated_text_no_padding, generated_text_with_padding) + self.assertEqual( + generated_text_with_padding, + 'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling ' + 'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been ' + 'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea', + ) + + @slow + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_padding_input_contrastive_search_t5(self): + # Load the pre-trained T5 model and tokenizer + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + model.to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True) + + # Define the input prompt + prompt_text = "translate English to German: I need to finish this task before the end of the day." + + # Tokenize the input prompt + encoded_prompt = tokenizer(prompt_text, return_tensors="pt") + input_ids = encoded_prompt.input_ids.to(torch_device) + attention_mask = encoded_prompt.attention_mask.to(torch_device) + + # Define the decoder prompt + decoder_prompt_text = "Ich muss diese Aufgabe" + encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt") + decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device) + decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device) + + # Define the contrastive search params + penalty_alpha = 0.6 + top_k = 4 + + # Generate text without padding + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + do_sample=False, + penalty_alpha=penalty_alpha, + top_k=top_k, + max_new_tokens=64, + ) + generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Define the padding length to add to the input IDs and attention mask + padding_length = 10 + + # Pad the decoder input IDs and attention mask on the left + padded_decoder_input_ids = F.pad( + decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id + ) + padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0) + # Since the decoder_start_token_id is the same as the pad_token_id, + # the last padded token represents the decoder start token. + # Set the attention mask for the decoder_start_token_id to True (1). + padded_decoder_attention_mask[:, padding_length - 1] = 1 + # Generate text with padded inputs + outputs_with_padding = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=padded_decoder_input_ids, + decoder_attention_mask=padded_decoder_attention_mask, + do_sample=False, + penalty_alpha=penalty_alpha, + top_k=top_k, + max_new_tokens=64, + ) + generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True) + + # Assert that the generated texts are identical for padded and non-padded inputs + self.assertEqual(generated_text_no_padding, generated_text_with_padding) + self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.") + + # TODO [gustavo] Enable this test to Optimum-habana + @pytest.mark.xfail + def test_generate_compile_fullgraph_tiny(self): + """ + Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash) + NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the + non-slow tests to prevent regressions! + """ + model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto" + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + + # compile generate + compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") + + # compiled generate does NOT accept parameterization except a) model inputs b) a generation config + generation_config = copy.deepcopy(model.generation_config) + generation_config.pad_token_id = model.config.eos_token_id + + model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt") + model_inputs = model_inputs.to(model.device) + gen_out = compiled_generate(**model_inputs, generation_config=generation_config) + self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated + + +@require_torch +class TokenHealingTestCase(unittest.TestCase): + @parameterized.expand( + [ + ( + "square_bracket", + 'An example ["like this"] and another example [', + 'An example ["like this"] and another example ["', + ), + ("url", 'The link is