diff --git a/docs/source/conceptual_guides/adapter.md b/docs/source/conceptual_guides/adapter.md index 2ea9556c65..c30044bbde 100644 --- a/docs/source/conceptual_guides/adapter.md +++ b/docs/source/conceptual_guides/adapter.md @@ -93,6 +93,8 @@ OFT preserves the hyperspherical energy by learning an orthogonal transformation [AdaLoRA](https://hf.co/papers/2303.10512) manages the parameter budget introduced from LoRA by allocating more parameters - in other words, a higher rank `r` - for important weight matrices that are better adapted for a task and pruning less important ones. The rank is controlled by a method similar to singular value decomposition (SVD). The ∆W is parameterized with two orthogonal matrices and a diagonal matrix which contains singular values. This parametrization method avoids iteratively applying SVD which is computationally expensive. Based on this method, the rank of ∆W is adjusted according to an importance score. ∆W is divided into triplets and each triplet is scored according to its contribution to model performance. Triplets with low importance scores are pruned and triplets with high importance scores are kept for finetuning. +Training with AdaLoRA has three phases: the init phase, the budgeting phase, and the final phase. In the init phase, no budgeting is applied, so the ranks are not touched. During the budgeting phase, the process described above is applied and the rank is redistributed according to a budget, giving more important adapters a higher rank and less important ones a lower rank. Once the final phase is reached, budgeting has ended and the ranks are fixed, but we may continue training for a while with the redistributed ranks to further improve performance. + ## Llama-Adapter [Llama-Adapter](https://hf.co/papers/2303.16199) is a method for adapting Llama into a instruction-following model. To help adapt the model for instruction-following, the adapter is trained with a 52K instruction-output dataset. diff --git a/src/peft/tuners/adalora/config.py b/src/peft/tuners/adalora/config.py index 4096e9ca78..004c26b0fe 100644 --- a/src/peft/tuners/adalora/config.py +++ b/src/peft/tuners/adalora/config.py @@ -25,11 +25,29 @@ class AdaLoraConfig(LoraConfig): """ This is the configuration class to store the configuration of a [`~peft.AdaLora`]. + AdaLoRA has three phases defined by `tinit`, `tfinal` and `total_step`. + + The initial phase can be understood as pre-training the adapters so that, when their rank is later reduced, there + is already some useful information encoded in them rather than random matrices. This phase is defined by + supplying `tinit`. + + After the initial phase is over (`tinit` steps have passed) and the final phase has not begun, AdaLoRA reduces the + budget of how much rank each layer is allowed to have with each step. This is where the reduction of rank + happens. This goes on until `total_step - tfinal` steps are reached. + + The last phase, beginning once `total_step - tfinal` steps are reached, does not change the layer ranks anymore but + fine-tunes the reduced-rank layers that resulted from the previous phase. + + A practical example: `tinit` is 10, `tfinal` is 20, `total_step` is 100. We spend 10 steps doing pre-training + without rank reduction because our budget is constant (init phase), then we spend 70 (100 - 20 - 10) steps in the + reduction phase where our budget decreases step-wise and, finally, 20 steps in the final fine-tuning stage without + reduction.
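+
+    For example, a configuration matching the schedule above could look like this (a minimal sketch; the `init_r` and
+    `target_r` values are only meant as an illustration):
+
+    ```py
+    from peft import AdaLoraConfig
+
+    # 10 warm-up steps, 70 budgeting steps, 20 final fine-tuning steps
+    config = AdaLoraConfig(init_r=12, target_r=4, tinit=10, tfinal=20, total_step=100)
+    ```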
+ Args: target_r (`int`): The target average rank of incremental matrix. init_r (`int`): The initial rank for each incremental matrix. tinit (`int`): The steps of initial fine-tuning warmup. - tfinal (`int`): The step of final fine-tuning. + tfinal (`int`): The number of steps of final fine-tuning. deltaT (`int`): The time internval between two budget allocations. beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing. beta2 (`float`): The hyperparameter of EMA for undertainty quantification. @@ -79,3 +97,12 @@ def __post_init__(self): "Note that `r` is not used in AdaLora and will be ignored." "If you intended to set the initial rank, use `init_r` instead." ) + + if self.total_step is None or self.total_step <= 0: + raise ValueError("AdaLoRA does not work when `total_step` is None, supply a value > 0.") + + if self.tinit >= (self.total_step - self.tfinal): + raise ValueError( + "The supplied schedule values don't allow for a budgeting phase. Decrease `tfinal`/`tinit` or " + "increase `total_step`." + ) diff --git a/tests/regression/test_regression.py b/tests/regression/test_regression.py index b588642be3..c9fb54d4fd 100644 --- a/tests/regression/test_regression.py +++ b/tests/regression/test_regression.py @@ -328,6 +328,7 @@ def test_adalora(self): r=8, init_lora_weights=False, target_modules=["lin0"], + total_step=1, ) model = get_peft_model(base_model, config) self.assert_results_equal_or_store(model, "adalora_mlp") @@ -567,6 +568,7 @@ def test_adalora(self): config = AdaLoraConfig( r=8, init_lora_weights=False, + total_step=1, ) model = get_peft_model(base_model, config) self.assert_results_equal_or_store(model, "adalora_opt-350m") @@ -621,6 +623,7 @@ def test_adalora(self): target_r=4, tinit=50, tfinal=100, + total_step=200, deltaT=5, beta1=0.3, beta2=0.3, @@ -681,6 +684,7 @@ def test_adalora(self): target_r=4, tinit=50, tfinal=100, + total_step=200, deltaT=5, beta1=0.3, beta2=0.3, diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py index 2d1f8cba1e..b366c5e331 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -291,7 +291,7 @@ def test_adalora_bnb_quantization_from_pretrained_safetensors(self, quantization kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) - config = AdaLoraConfig(task_type=TaskType.CAUSAL_LM) + config = AdaLoraConfig(task_type=TaskType.CAUSAL_LM, total_step=1) peft_model = get_peft_model(model, config) peft_model = prepare_model_for_kbit_training(peft_model) peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) @@ -1624,7 +1624,7 @@ def test_lora_dora_add_new_adapter_does_not_change_device(self, mlp): def test_adalora_add_new_adapter_does_not_change_device(self, mlp): # same as first test, but using AdaLORA # AdaLora does not like multiple trainable adapters, hence inference_mode=True - config = AdaLoraConfig(target_modules=["lin0"], inference_mode=True) + config = AdaLoraConfig(target_modules=["lin0"], inference_mode=True, total_step=1) model = get_peft_model(mlp, config) model = model.to(self.device) model.lin0.lora_A.cpu() diff --git a/tests/test_config.py b/tests/test_config.py index bc329b1784..b29f0b71a6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -48,30 +48,31 @@ PEFT_MODELS_TO_TEST = [("lewtun/tiny-random-OPTForCausalLM-delta", "v1")] +# Config classes and their mandatory parameters ALL_CONFIG_CLASSES = ( - AdaLoraConfig, - AdaptionPromptConfig, - BOFTConfig, - FourierFTConfig, - HRAConfig, - 
IA3Config, - LNTuningConfig, - LoHaConfig, - LoKrConfig, - LoraConfig, - MultitaskPromptTuningConfig, - PolyConfig, - PrefixTuningConfig, - PromptEncoderConfig, - PromptTuningConfig, - VeraConfig, - VBLoRAConfig, + (AdaLoraConfig, {"total_step": 1}), + (AdaptionPromptConfig, {}), + (BOFTConfig, {}), + (FourierFTConfig, {}), + (HRAConfig, {}), + (IA3Config, {}), + (LNTuningConfig, {}), + (LoHaConfig, {}), + (LoKrConfig, {}), + (LoraConfig, {}), + (MultitaskPromptTuningConfig, {}), + (PolyConfig, {}), + (PrefixTuningConfig, {}), + (PromptEncoderConfig, {}), + (PromptTuningConfig, {}), + (VeraConfig, {}), + (VBLoRAConfig, {}), ) class TestPeftConfig: - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_methods(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_methods(self, config_class, mandatory_kwargs): r""" Test if all configs have the expected methods. Here we test - to_dict @@ -80,22 +81,22 @@ def test_methods(self, config_class): - from_json_file """ # test if all configs have the expected methods - config = config_class() + config = config_class(**mandatory_kwargs) assert hasattr(config, "to_dict") assert hasattr(config, "save_pretrained") assert hasattr(config, "from_pretrained") assert hasattr(config, "from_json_file") - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) @pytest.mark.parametrize("valid_task_type", list(TaskType) + [None]) - def test_valid_task_type(self, config_class, valid_task_type): + def test_valid_task_type(self, config_class, mandatory_kwargs, valid_task_type): r""" Test if all configs work correctly for all valid task types """ - config_class(task_type=valid_task_type) + config_class(task_type=valid_task_type, **mandatory_kwargs) - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_invalid_task_type(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_invalid_task_type(self, config_class, mandatory_kwargs): r""" Test if all configs correctly raise the defined error message for invalid task types. """ @@ -104,7 +105,7 @@ def test_invalid_task_type(self, config_class): ValueError, match=f"Invalid task type: '{invalid_task_type}'. 
Must be one of the following task types: {', '.join(TaskType)}.", ): - config_class(task_type=invalid_task_type) + config_class(task_type=invalid_task_type, **mandatory_kwargs) def test_from_peft_type(self): r""" @@ -115,11 +116,16 @@ def test_from_peft_type(self): for peft_type in PeftType: expected_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type] - config = PeftConfig.from_peft_type(peft_type=peft_type) + mandatory_config_kwargs = {} + + if expected_cls == AdaLoraConfig: + mandatory_config_kwargs = {"total_step": 1} + + config = PeftConfig.from_peft_type(peft_type=peft_type, **mandatory_config_kwargs) assert type(config) is expected_cls - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_from_pretrained(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_from_pretrained(self, config_class, mandatory_kwargs): r""" Test if the config is correctly loaded using: - from_pretrained @@ -128,22 +134,22 @@ def test_from_pretrained(self, config_class): # Test we can load config from delta config_class.from_pretrained(model_name, revision=revision) - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_save_pretrained(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_save_pretrained(self, config_class, mandatory_kwargs): r""" Test if the config is correctly saved and loaded using - save_pretrained """ - config = config_class() + config = config_class(**mandatory_kwargs) with tempfile.TemporaryDirectory() as tmp_dirname: config.save_pretrained(tmp_dirname) config_from_pretrained = config_class.from_pretrained(tmp_dirname) assert config.to_dict() == config_from_pretrained.to_dict() - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_from_json_file(self, config_class): - config = config_class() + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_from_json_file(self, config_class, mandatory_kwargs): + config = config_class(**mandatory_kwargs) with tempfile.TemporaryDirectory() as tmp_dirname: config.save_pretrained(tmp_dirname) @@ -159,17 +165,17 @@ def test_from_json_file(self, config_class): config_from_json = config_class.from_json_file(config_path) assert config.to_dict() == config_from_json - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_to_dict(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_to_dict(self, config_class, mandatory_kwargs): r""" Test if the config can be correctly converted to a dict using: - to_dict """ - config = config_class() + config = config_class(**mandatory_kwargs) assert isinstance(config.to_dict(), dict) - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_from_pretrained_cache_dir(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_from_pretrained_cache_dir(self, config_class, mandatory_kwargs): r""" Test if the config is correctly loaded with extra kwargs """ @@ -186,8 +192,8 @@ def test_from_pretrained_cache_dir_remote(self): PeftConfig.from_pretrained("ybelkada/test-st-lora", cache_dir=tmp_dirname) assert "models--ybelkada--test-st-lora" in os.listdir(tmp_dirname) - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_save_pretrained_with_runtime_config(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def 
test_save_pretrained_with_runtime_config(self, config_class, mandatory_kwargs): r""" Test if the config correctly removes runtime config when saving """ @@ -201,10 +207,10 @@ def test_save_pretrained_with_runtime_config(self, config_class): cfg = config_class.from_pretrained(tmp_dirname) assert not cfg.runtime_config.ephemeral_gpu_offload - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_set_attributes(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_set_attributes(self, config_class, mandatory_kwargs): # manually set attributes and check if they are correctly written - config = config_class(peft_type="test") + config = config_class(peft_type="test", **mandatory_kwargs) # save pretrained with tempfile.TemporaryDirectory() as tmp_dirname: @@ -213,24 +219,24 @@ def test_set_attributes(self, config_class): config_from_pretrained = config_class.from_pretrained(tmp_dirname) assert config.to_dict() == config_from_pretrained.to_dict() - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_config_copy(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_config_copy(self, config_class, mandatory_kwargs): # see https://github.com/huggingface/peft/issues/424 - config = config_class() + config = config_class(**mandatory_kwargs) copied = copy.copy(config) assert config.to_dict() == copied.to_dict() - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_config_deepcopy(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_config_deepcopy(self, config_class, mandatory_kwargs): # see https://github.com/huggingface/peft/issues/424 - config = config_class() + config = config_class(**mandatory_kwargs) copied = copy.deepcopy(config) assert config.to_dict() == copied.to_dict() - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_config_pickle_roundtrip(self, config_class): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_config_pickle_roundtrip(self, config_class, mandatory_kwargs): # see https://github.com/huggingface/peft/issues/424 - config = config_class() + config = config_class(**mandatory_kwargs) copied = pickle.loads(pickle.dumps(config)) assert config.to_dict() == copied.to_dict() @@ -317,7 +323,7 @@ def test_ia3_is_feedforward_subset_valid_config(self): def test_adalora_config_r_warning(self): # This test checks that a warning is raised when r is set other than default in AdaLoraConfig # No warning should be raised when initializing AdaLoraConfig with default values. 
- kwargs = {"peft_type": "ADALORA", "task_type": "SEQ_2_SEQ_LM", "init_r": 12, "lora_alpha": 32} + kwargs = {"peft_type": "ADALORA", "task_type": "SEQ_2_SEQ_LM", "init_r": 12, "lora_alpha": 32, "total_step": 1} # Test that no warning is raised with default initialization with warnings.catch_warnings(): warnings.simplefilter("error") @@ -327,16 +333,58 @@ def test_adalora_config_r_warning(self): pytest.fail("AdaLoraConfig raised a warning with default initialization.") # Test that a warning is raised when r != 8 in AdaLoraConfig with pytest.warns(UserWarning, match="Note that `r` is not used in AdaLora and will be ignored."): - AdaLoraConfig(r=10) + AdaLoraConfig(r=10, total_step=1) + + def test_adalora_config_correct_timing_still_works(self): + pass + + @pytest.mark.parametrize( + "timing_kwargs", + [ + {"total_step": 100, "tinit": 0, "tfinal": 0}, + {"total_step": 100, "tinit": 10, "tfinal": 10}, + {"total_step": 100, "tinit": 79, "tfinal": 20}, + {"total_step": 100, "tinit": 80, "tfinal": 19}, + ], + ) + def test_adalora_config_valid_timing_works(self, timing_kwargs): + # Make sure that passing correct timing values is not prevented by faulty config checks. + AdaLoraConfig(**timing_kwargs) # does not raise + + def test_adalora_config_invalid_total_step_raises(self): + with pytest.raises(ValueError) as e: + AdaLoraConfig(total_step=None) + assert "AdaLoRA does not work when `total_step` is None, supply a value > 0." in str(e) + + @pytest.mark.parametrize( + "timing_kwargs", + [ + {"total_step": 100, "tinit": 20, "tfinal": 80}, + {"total_step": 100, "tinit": 80, "tfinal": 20}, + {"total_step": 10, "tinit": 20, "tfinal": 0}, + {"total_step": 10, "tinit": 0, "tfinal": 10}, + {"total_step": 10, "tinit": 10, "tfinal": 0}, + {"total_step": 10, "tinit": 20, "tfinal": 0}, + {"total_step": 10, "tinit": 20, "tfinal": 20}, + {"total_step": 10, "tinit": 0, "tfinal": 20}, + ], + ) + def test_adalora_config_timing_bounds_error(self, timing_kwargs): + # Check if the user supplied timing values that will certainly fail because it breaks + # AdaLoRA assumptions. + with pytest.raises(ValueError) as e: + AdaLoraConfig(**timing_kwargs) + + assert "The supplied schedule values don't allow for a budgeting phase" in str(e) - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_from_pretrained_forward_compatible(self, config_class, tmp_path, recwarn): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_from_pretrained_forward_compatible(self, config_class, mandatory_kwargs, tmp_path, recwarn): """ Make it possible to load configs that contain unknown keys by ignoring them. The idea is to make PEFT configs forward-compatible with future versions of the library. 
""" - config = config_class() + config = config_class(**mandatory_kwargs) config.save_pretrained(tmp_path) # add a spurious key to the config with open(tmp_path / "adapter_config.json") as f: @@ -356,8 +404,8 @@ def test_from_pretrained_forward_compatible(self, config_class, tmp_path, recwar assert config.to_dict() == config_from_pretrained.to_dict() assert isinstance(config_from_pretrained, config_class) - @pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES) - def test_from_pretrained_sanity_check(self, config_class, tmp_path): + @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES) + def test_from_pretrained_sanity_check(self, config_class, mandatory_kwargs, tmp_path): """Following up on the previous test about forward compatibility, we *don't* want any random json to be accepted as a PEFT config. There should be a minimum set of required keys. """ diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 1b0c17d488..f86302c486 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -538,15 +538,15 @@ "AdaLora Same", "adalora", AdaLoraConfig, - {"target_modules": ["lin0"], "init_lora_weights": False, "inference_mode": True}, - {"target_modules": ["lin0"], "init_lora_weights": False, "inference_mode": True}, + {"target_modules": ["lin0"], "init_lora_weights": False, "inference_mode": True, "total_step": 1}, + {"target_modules": ["lin0"], "init_lora_weights": False, "inference_mode": True, "total_step": 1}, ), ( "AdaLora Different", "adalora", AdaLoraConfig, - {"target_modules": ["lin0"], "init_lora_weights": False, "inference_mode": True}, - {"target_modules": ["lin1"], "init_lora_weights": False, "inference_mode": True}, + {"target_modules": ["lin0"], "init_lora_weights": False, "inference_mode": True, "total_step": 1}, + {"target_modules": ["lin1"], "init_lora_weights": False, "inference_mode": True, "total_step": 1}, ), ( "FourierFT Same", @@ -1745,7 +1745,7 @@ def test_load_resized_embedding_ignore_mismatched_sizes(self): LoraConfig(target_modules=["lin0"], init_lora_weights=False), LoKrConfig(target_modules=["lin0"], init_weights=False), LoHaConfig(target_modules=["lin0"], init_weights=False), - AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False, total_step=1), IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"], init_ia3_weights=False), OFTConfig(target_modules=["lin0"], init_weights=False, r=2), BOFTConfig(target_modules=["lin0"], init_weights=False, boft_block_size=2), @@ -2425,10 +2425,10 @@ def test_requires_grad_ia3_same_targets(self): def test_requires_grad_adalora_different_targets(self): # test two different AdaLora adapters that target different modules - config0 = AdaLoraConfig(target_modules=["lin0"]) + config0 = AdaLoraConfig(target_modules=["lin0"], total_step=1) peft_model = get_peft_model(MLP(), config0) - config1 = AdaLoraConfig(target_modules=["lin1"], inference_mode=True) + config1 = AdaLoraConfig(target_modules=["lin1"], total_step=1, inference_mode=True) peft_model.add_adapter("adapter1", config1) # active adapter is still "default" @@ -2471,10 +2471,10 @@ def test_requires_grad_adalora_different_targets(self): def test_requires_grad_adalora_same_targets(self): # same as previous test, except that AdaLora adapters target the same layer - config0 = AdaLoraConfig(target_modules=["lin0"]) + config0 = AdaLoraConfig(target_modules=["lin0"], total_step=1) peft_model = get_peft_model(MLP(), config0) - 
config1 = AdaLoraConfig(target_modules=["lin0"], inference_mode=True) + config1 = AdaLoraConfig(target_modules=["lin0"], total_step=1, inference_mode=True) peft_model.add_adapter("adapter1", config1) # active adapter is still "default" diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 2a7c643f88..8b58daa99d 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -453,6 +453,7 @@ def test_generate_adalora_no_dropout(self): "target_modules": None, "task_type": "CAUSAL_LM", "lora_dropout": 0.0, + "total_step": 1, } self._test_generate(model_id, AdaLoraConfig, config_kwargs) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 3b0cde1444..83e492ccf2 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -44,6 +44,7 @@ Seq2SeqTrainer, Seq2SeqTrainingArguments, Trainer, + TrainerCallback, TrainingArguments, WhisperFeatureExtractor, WhisperForConditionalGeneration, @@ -371,6 +372,7 @@ def test_4bit_adalora_causalLM(self): target_r=4, tinit=50, tfinal=100, + total_step=200, deltaT=5, beta1=0.3, beta2=0.3, @@ -388,6 +390,12 @@ def test_4bit_adalora_causalLM(self): batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True) self._check_inference_finite(model, batch) + class OptimizerStepCallback(TrainerCallback): + def on_optimizer_step(self, args, state, control, **kwargs): + model.update_and_allocate(state.global_step) + + step_callback = OptimizerStepCallback() + with tempfile.TemporaryDirectory() as tmp_dir: trainer = Trainer( model=model, @@ -405,6 +413,7 @@ def test_4bit_adalora_causalLM(self): data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), ) model.config.use_cache = False + trainer.add_callback(step_callback) trainer.train() model.cpu().save_pretrained(tmp_dir) @@ -436,6 +445,7 @@ def test_8bit_adalora_causalLM(self): target_r=4, tinit=50, tfinal=100, + total_step=200, deltaT=5, beta1=0.3, beta2=0.3, @@ -453,6 +463,12 @@ def test_8bit_adalora_causalLM(self): batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True) self._check_inference_finite(model, batch) + class OptimizerStepCallback(TrainerCallback): + def on_optimizer_step(self, args, state, control, **kwargs): + model.update_and_allocate(state.global_step) + + step_callback = OptimizerStepCallback() + with tempfile.TemporaryDirectory() as tmp_dir: trainer = Trainer( model=model, @@ -470,6 +486,7 @@ def test_8bit_adalora_causalLM(self): data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), ) model.config.use_cache = False + trainer.add_callback(step_callback) trainer.train() model.cpu().save_pretrained(tmp_dir) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 510f12892e..ea2dca09b7 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -1352,16 +1352,16 @@ class TestAdaLoraInitialization: torch_device = infer_device() def test_adalora_target_modules_set(self): - config = AdaLoraConfig(target_modules=["linear", "embed", "conv2d"]) + config = AdaLoraConfig(target_modules=["linear", "embed", "conv2d"], total_step=1) assert config.target_modules == {"linear", "embed", "conv2d"} def test_adalora_use_dora_raises(self): with pytest.raises(ValueError, match="ADALORA does not support DoRA"): - AdaLoraConfig(use_dora=True) + AdaLoraConfig(use_dora=True, total_step=1) def test_adalora_loftq_config_raises(self): with pytest.raises(ValueError, match="ADALORA does not support LOFTQ"): - 
AdaLoraConfig(init_lora_weights="loftq", loftq_config={"loftq": "config"}) + AdaLoraConfig(init_lora_weights="loftq", loftq_config={"loftq": "config"}, total_step=1) def get_model(self): class MyModule(nn.Module): @@ -1385,7 +1385,7 @@ def test_adalora_default_init_identity(self, data): model = self.get_model() output_before = model(data) - config = AdaLoraConfig(target_modules=["linear"]) + config = AdaLoraConfig(target_modules=["linear"], total_step=1) model = get_peft_model(model, config) output_after = model(data) assert torch.allclose(output_before, output_after) diff --git a/tests/test_mixed.py b/tests/test_mixed.py index 3845046b4e..773df1b49c 100644 --- a/tests/test_mixed.py +++ b/tests/test_mixed.py @@ -395,7 +395,7 @@ def _check_loading(self, model_cls, config0, config1, input, *, is_commutative): LoraConfig(target_modules=["lin0"], init_lora_weights=False), LoHaConfig(target_modules=["lin0"], init_weights=False), LoKrConfig(target_modules=["lin0"], init_weights=False), - AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False, total_step=1), ], r=2, ), @@ -415,7 +415,7 @@ def test_target_first_layer(self, config0, config1): LoraConfig(target_modules=["lin1"], init_lora_weights=False), LoHaConfig(target_modules=["lin1"], init_weights=False), LoKrConfig(target_modules=["lin1"], init_weights=False), - AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False, total_step=1), ], r=2, ), @@ -439,7 +439,7 @@ def test_target_last_layer(self, config0, config1): LoraConfig(init_lora_weights=False), LoHaConfig(init_weights=False), LoKrConfig(init_weights=False), - AdaLoraConfig(init_lora_weights=False), + AdaLoraConfig(init_lora_weights=False, total_step=1), ], r=2, ), @@ -480,8 +480,8 @@ def test_target_different_layers(self, config0, config1): LoKrConfig(target_modules=["lin1"], init_weights=False), ), ( - AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), - AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False, total_step=1), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False, total_step=1), ), ], name_func=_param_name_func, @@ -509,8 +509,8 @@ def test_target_last_layer_same_type(self, config0, config1): LoKrConfig(target_modules=["lin0"], init_weights=False), ), ( - AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), - AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False, total_step=1), + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False, total_step=1), ), ], name_func=_param_name_func, @@ -542,7 +542,7 @@ def test_deeply_nested(self): config1 = LoHaConfig(r=4, alpha=4, target_modules=["lin0"], init_weights=False) peft_model.add_adapter("adapter1", config1) - config2 = AdaLoraConfig(r=4, lora_alpha=4, target_modules=["lin1"], init_lora_weights=False) + config2 = AdaLoraConfig(r=4, lora_alpha=4, target_modules=["lin1"], init_lora_weights=False, total_step=1) peft_model.add_adapter("adapter2", config2) config3 = LoKrConfig(r=4, alpha=4, target_modules=["lin0", "lin1"], init_weights=False) @@ -683,7 +683,7 @@ def test_get_nb_trainable_parameters(self): assert trainable_params1 == (params_lora + params_loha) assert all_param1 == ((params_base + params_lora) + params_loha) - config2 = AdaLoraConfig(target_modules=["lin0", "lin1"]) + config2 = 
AdaLoraConfig(target_modules=["lin0", "lin1"], total_step=1) peft_model.add_adapter("adapter2", config2) peft_model.set_adapter(["adapter0", "adapter1", "adapter2"]) params_adalora = sum(p.numel() for n, p in model.named_parameters() if "adapter2" in n) @@ -732,7 +732,7 @@ def test_decoder_model(self): assert not torch.allclose(output0, output1) torch.manual_seed(2) - config2 = AdaLoraConfig(task_type="CAUSAL_LM", init_lora_weights=False) + config2 = AdaLoraConfig(task_type="CAUSAL_LM", init_lora_weights=False, total_step=1) peft_model.add_adapter("adapter2", config2) peft_model.set_adapter(["adapter0", "adapter1", "adapter2"]) output2 = peft_model.generate(**input_dict) diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py index 23aec8ee1a..2aa33386d2 100644 --- a/tests/test_torch_compile.py +++ b/tests/test_torch_compile.py @@ -31,6 +31,7 @@ BitsAndBytesConfig, DataCollatorForLanguageModeling, Trainer, + TrainerCallback, TrainingArguments, ) @@ -62,7 +63,7 @@ # Mapping: name of the setting -> (Peft config instance, torch.compile kwargs) SETTINGS = { - "adalora": (AdaLoraConfig(task_type=TaskType.CAUSAL_LM), {}), + "adalora": (AdaLoraConfig(task_type=TaskType.CAUSAL_LM, total_step=5), {}), "boft": (BOFTConfig(task_type=TaskType.CAUSAL_LM), {}), "dora": (LoraConfig(task_type=TaskType.CAUSAL_LM, use_dora=True), {}), "ia3": (IA3Config(task_type=TaskType.CAUSAL_LM), {}), @@ -93,7 +94,7 @@ class TestTorchCompileCausalLM: """ Tests for using torch.compile with causal LM. - Tip: When adding a new test, set `fake_compile = False` below. With this setting, torch.compile is being skipped. + Tip: When adding a new test, set `fake_compile = True` below. With this setting, torch.compile is being skipped. This is useful for two reasons: - compile is slow, so to quickly iterate on the test, it's best to disable it and only enable it at the very end @@ -156,8 +157,6 @@ def test_causal_lm_training_trainer_compile(self, settings, tokenizer, data, tmp r"""Train a PEFT model with torch.compile using Trainer""" tmp_dir = tmp_path / "model" config, compile_kwargs = settings - if isinstance(config, AdaLoraConfig): - pytest.skip(reason="AdaLora does not work correctly with Trainer") torch.manual_seed(0) model = AutoModelForCausalLM.from_pretrained( @@ -181,6 +180,10 @@ def test_causal_lm_training_trainer_compile(self, settings, tokenizer, data, tmp "output_dir": tmp_dir, "seed": 0, } + + if isinstance(config, AdaLoraConfig): + train_kwargs["learning_rate"] = 1e-2 + training_args = TrainingArguments( torch_compile=not self.fake_compile, torch_compile_backend=compile_kwargs.get("torch_compile_backend", None), @@ -194,6 +197,15 @@ def test_causal_lm_training_trainer_compile(self, settings, tokenizer, data, tmp data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), ) model.config.use_cache = False + + if isinstance(config, AdaLoraConfig): + + class OptimizerStepCallback(TrainerCallback): + def on_optimizer_step(self, args, state, control, **kwargs): + model.update_and_allocate(state.global_step) + + trainer.add_callback(OptimizerStepCallback()) + trainer.train() model.eval() diff --git a/tests/testing_common.py b/tests/testing_common.py index a553b24747..9e3fbdc667 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -92,6 +92,7 @@ # AdaLoRA { "target_modules": None, + "total_step": 1, }, # BOFT {