From aa3f41f7529ed078e9225b2fc1edbb8c71f58f99 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 17 Jan 2025 18:17:48 +0100 Subject: [PATCH] FIX: Generating with mixed adapter batches and with beam search enabled (#2287) See #2283 Right now, using mixed adapter batches with beam search generations does not work. This is because users need to pass the adapter names associated with each sample, i.e. the number of adapter names should be identical to the number of samples in the input. When applying beam search, transformers internally repeats the samples once per beam (or so it looks like). Therefore, we have more samples during generation than samples in the input. Consequently, the adapter names have to be extended accordingly. This is now taken care of. --- src/peft/tuners/lora/model.py | 23 ++++++++ tests/test_decoder_models.py | 12 +++++ tests/test_encoder_decoder_models.py | 12 +++++ tests/testing_common.py | 81 ++++++++++++++++++++++++++++ 4 files changed, 128 insertions(+) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 32631647b2..2967b8da9c 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -444,6 +444,18 @@ def _enable_peft_forward_hooks(self, *args, **kwargs): if unexpected_adapters: raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}") + # deal with beam search + num_beams = kwargs.get("num_beams", None) + uses_beam_search = isinstance(num_beams, int) and (num_beams > 1) + original_adapter_names = adapter_names[:] + if uses_beam_search: + if not isinstance(adapter_names, (list, tuple)): + raise TypeError(f"Got adapter names of type {type(adapter_names)}, expected a list of str.") + # When there is beam search, the inputs are repeated n times, thus we repeat each adapter name n times and + # then flatten the nested list. For encoder-decoder models, this extended list should not be applied to the + # encoder part. Further below, the original argument is thus restored for the encoder. + adapter_names = sum(([n] * kwargs["num_beams"] for n in adapter_names), []) + hook_handles = [] for module in self.modules(): if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper): @@ -451,6 +463,17 @@ def _enable_peft_forward_hooks(self, *args, **kwargs): handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) hook_handles.append(handle) + if uses_beam_search and hasattr(self.model, "get_encoder"): + # For encoder-decoder models, even when applying beam search, the encoder part of the model should not use + # the extended adapter_names. This is because the encoder still uses the original, non-extended samples. + for module in self.model.get_encoder().modules(): + if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper): + # Add another hook to overwrite the kwargs with the original adapter names -- this is easier than + # trying to exclude the encoder. + pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=original_adapter_names) + handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) + hook_handles.append(handle) + yield for handle in hook_handles: diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index a1aca79b12..78b4947db6 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -303,6 +303,18 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + ) + ) + def test_generate_with_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): + self._test_generate_with_mixed_adapter_batches_and_beam_search(model_id, config_cls, config_kwargs) + @parameterized.expand( PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_oft_or_hra_and_gpt2) ) diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index e22f010089..8f8eb9c0dd 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -118,6 +118,18 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "task_type": "SEQ_2_SEQ_LM", + }, + ) + ) + def test_generate_with_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): + self._test_generate_with_mixed_adapter_batches_and_beam_search(model_id, config_cls, config_kwargs) + # skip non lora models - generate does not work for prefix tuning, prompt tuning @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_generate(self, test_name, model_id, config_cls, config_kwargs): diff --git a/tests/testing_common.py b/tests/testing_common.py index fec265812b..a553b24747 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -941,6 +941,87 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_adapter0[1::3], logits_mixed[1::3], atol=atol, rtol=rtol) assert torch.allclose(logits_adapter1[2::3], logits_mixed[2::3], atol=atol, rtol=rtol) + def _test_generate_with_mixed_adapter_batches_and_beam_search(self, model_id, config_cls, config_kwargs): + # Test generating with beam search and with mixing different adapters in a single batch by passing the + # adapter_names argument. See #2283. + if config_cls not in (LoraConfig,): + return pytest.skip(f"Mixed adapter batches not supported for {config_cls}") + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + + torch.manual_seed(0) + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config, adapter_name="adapter0").eval() + model.add_adapter("adapter1", config) + + # In contrast to forward, for generate, it can sometimes happen that we get the same results as the base model + # even with LoRA applied because the impact of LoRA is not big enough. Therefore, use this "trick" to make LoRA + # stronger. + for name, param in model.named_parameters(): + if model.base_model.prefix in name: + param.data.mul_(10.0) + + model = model.to(self.torch_device).eval() + + dummy_input = self.prepare_inputs_for_testing() + # ensure that we have at least 3 samples for this test + dummy_input = {k: torch.cat([v for _ in range(3)]) for k, v in dummy_input.items()} + + gen_kwargs = {**dummy_input, "max_length": 20, "num_beams": 10, "early_stopping": True} + with torch.inference_mode(): + with model.disable_adapter(): + gen_base = model.generate(**gen_kwargs) + + model.set_adapter("adapter0") + with torch.inference_mode(): + gen_adapter0 = model.generate(**gen_kwargs) + + model.set_adapter("adapter1") + with torch.inference_mode(): + gen_adapter1 = model.generate(**gen_kwargs) + + def remove_padding(seq, pad_value): + lst = list(seq) + while lst and (lst[-1] == pad_value): + lst.pop() + return lst + + def gens_are_same(gen0, gen1): + # Special function to compare generations. We cannot use torch.allclose it will raise an error when sequence + # lengths differ. Morevoer, we need to remove the padding from the sequences. This is because, even though + # normally identical sequences should have the same length, when we do mixed adapter batches, each sample + # will be padded to the longest sequence in that mixed batch, which can be different from the longest + # sequence without mixed adapter batches. + pad_value = model.config.eos_token_id + for sample0, sample1 in zip(gen0, gen1): + sample0 = remove_padding(sample0, pad_value) + sample1 = remove_padding(sample1, pad_value) + if (len(sample0) != len(sample1)) or (sample0 != sample1): + # at least one sample differs, the generations are not identical + return False + return True + + # sanity check that there are enough outputs and that they are different + assert len(gen_base) == len(gen_adapter0) == len(gen_adapter1) + assert len(gen_adapter1) >= 3 + assert not gens_are_same(gen_base, gen_adapter0) + assert not gens_are_same(gen_base, gen_adapter1) + assert not gens_are_same(gen_adapter0, gen_adapter1) + + # alternate between base model, adapter0, and adapter1 + adapters = ["__base__", "adapter0", "adapter1"] + gen_kwargs["adapter_names"] = [adapters[i % 3] for i in (range(len(dummy_input["input_ids"])))] + + with torch.inference_mode(): + gen_mixed = model.generate(**gen_kwargs) + + assert gens_are_same(gen_base[::3], gen_mixed[::3]) + assert gens_are_same(gen_adapter0[1::3], gen_mixed[1::3]) + assert gens_are_same(gen_adapter1[2::3], gen_mixed[2::3]) + def _test_generate(self, model_id, config_cls, config_kwargs): model = self.transformers_class.from_pretrained(model_id) config = config_cls(