Documentation & error checking for AdaLoRA timing (#2341)
The documentation about how AdaLoRA works was a bit unclear, especially
the fact that `tfinal` is not a point in time but a duration.

It was also possible to build schedules that never enter the budgeting
phase and therefore lead to an exception, because the code does not
expect this case (which is OK). We now prevent such a scenario by
treating this configuration as invalid. (Issue #2337)

We also check that `total_step` is not None, since leaving it unset is also a guaranteed error in the code.
githubnemo authored Jan 24, 2025
1 parent 6538e56 commit 9c25d94
Showing 12 changed files with 203 additions and 91 deletions.
2 changes: 2 additions & 0 deletions docs/source/conceptual_guides/adapter.md
@@ -93,6 +93,8 @@ OFT preserves the hyperspherical energy by learning an orthogonal transformation

[AdaLoRA](https://hf.co/papers/2303.10512) manages the parameter budget introduced from LoRA by allocating more parameters - in other words, a higher rank `r` - for important weight matrices that are better adapted for a task and pruning less important ones. The rank is controlled by a method similar to singular value decomposition (SVD). The ∆W is parameterized with two orthogonal matrices and a diagonal matrix which contains singular values. This parametrization method avoids iteratively applying SVD which is computationally expensive. Based on this method, the rank of ∆W is adjusted according to an importance score. ∆W is divided into triplets and each triplet is scored according to its contribution to model performance. Triplets with low importance scores are pruned and triplets with high importance scores are kept for finetuning.

Training with AdaLoRA has three phases: the init phase, the budgeting phase and the final phase. In the init phase, no budgeting is applied, so the ranks are not touched. During the budgeting phase, the process described above is applied and the rank is redistributed according to a budget, aiming to give more important adapters more rank and less important ones less. When the final phase is reached, budgeting has ended, the ranks stay at their redistributed values, and training may continue for a while with these ranks to further improve performance.
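
To make the three phases concrete, here is a minimal sketch of such a schedule (the values are only illustrative; `tinit`, `tfinal`, `total_step`, `init_r` and `target_r` are the `AdaLoraConfig` fields described in this commit):

```python
from peft import AdaLoraConfig

# Illustrative schedule: 10 init steps, 70 budgeting steps (steps 10-80),
# and 20 final fine-tuning steps with the redistributed ranks.
config = AdaLoraConfig(
    init_r=12,       # rank each incremental matrix starts with
    target_r=4,      # target average rank after budgeting
    tinit=10,        # length of the init phase
    tfinal=20,       # length of the final phase
    total_step=100,  # total number of training steps
)
```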

## Llama-Adapter

[Llama-Adapter](https://hf.co/papers/2303.16199) is a method for adapting Llama into an instruction-following model. To help adapt the model for instruction-following, the adapter is trained with a 52K instruction-output dataset.
29 changes: 28 additions & 1 deletion src/peft/tuners/adalora/config.py
@@ -25,11 +25,29 @@ class AdaLoraConfig(LoraConfig):
"""
This is the configuration class to store the configuration of a [`~peft.AdaLora`].
AdaLoRA has three phases defined by `tinit`, `tfinal` and `total_step`.
The initial phase can be understood as a pre-training step for the adapters so that, when their rank is later reduced,
there is already some useful information encoded that can be pruned instead of randomly initialized matrices. The
length of this phase is set by `tinit`.
After the initial phase is over (`tinit` steps have passed) and before the final phase begins, AdaLoRA reduces the
budget of how much rank each layer is allowed to have with each step. This is where the rank reduction happens, and it
goes on until `total_step - tfinal` steps are reached.
The last phase, beginning once `total_step - tfinal` steps are reached, does not change the layer ranks anymore but
fine-tunes the reduced-rank layers that resulted from the previous phase.
A practical example: `tinit` is 10, `tfinal` is 20, `total_step` is 100. We spend 10 steps doing pre-training
without rank reduction because our budget is constant (init phase), then we spend 70 (100 - 20 - 10) steps in the
reduction phase where our budget decreases step-wise and, finally, 20 steps in the final fine-tuning stage without
reduction.
Args:
target_r (`int`): The target average rank of incremental matrix.
init_r (`int`): The initial rank for each incremental matrix.
tinit (`int`): The number of steps of initial fine-tuning warmup.
tfinal (`int`): The step of final fine-tuning.
tfinal (`int`): The number of steps of final fine-tuning.
deltaT (`int`): The time interval between two budget allocations.
beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing.
beta2 (`float`): The hyperparameter of EMA for uncertainty quantification.
@@ -79,3 +97,12 @@ def __post_init__(self):
"Note that `r` is not used in AdaLora and will be ignored."
"If you intended to set the initial rank, use `init_r` instead."
)

if self.total_step is None or self.total_step <= 0:
raise ValueError("AdaLoRA does not work when `total_step` is None, supply a value > 0.")

if self.tinit >= (self.total_step - self.tfinal):
raise ValueError(
"The supplied schedule values don't allow for a budgeting phase. Decrease `tfinal`/`tinit` or "
"increase `total_step`."
)
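
For illustration, a minimal sketch of how these new checks behave (the numbers below are made up; only the field names `tinit`, `tfinal` and `total_step` come from the config above):

```python
from peft import AdaLoraConfig

# A valid schedule: 10 init steps, 70 budgeting steps, 20 final steps.
AdaLoraConfig(tinit=10, tfinal=20, total_step=100)

# No budgeting phase left: tinit (50) >= total_step - tfinal (100 - 60 = 40).
try:
    AdaLoraConfig(tinit=50, tfinal=60, total_step=100)
except ValueError as err:
    print(err)

# Leaving `total_step` unset (None) is now rejected as well.
try:
    AdaLoraConfig(tinit=10, tfinal=20)
except ValueError as err:
    print(err)
```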
4 changes: 4 additions & 0 deletions tests/regression/test_regression.py
@@ -328,6 +328,7 @@ def test_adalora(self):
r=8,
init_lora_weights=False,
target_modules=["lin0"],
total_step=1,
)
model = get_peft_model(base_model, config)
self.assert_results_equal_or_store(model, "adalora_mlp")
@@ -567,6 +568,7 @@ def test_adalora(self):
config = AdaLoraConfig(
r=8,
init_lora_weights=False,
total_step=1,
)
model = get_peft_model(base_model, config)
self.assert_results_equal_or_store(model, "adalora_opt-350m")
@@ -621,6 +623,7 @@ def test_adalora(self):
target_r=4,
tinit=50,
tfinal=100,
total_step=200,
deltaT=5,
beta1=0.3,
beta2=0.3,
@@ -681,6 +684,7 @@ def test_adalora(self):
target_r=4,
tinit=50,
tfinal=100,
total_step=200,
deltaT=5,
beta1=0.3,
beta2=0.3,
4 changes: 2 additions & 2 deletions tests/test_common_gpu.py
@@ -291,7 +291,7 @@ def test_adalora_bnb_quantization_from_pretrained_safetensors(self, quantization
kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
config = AdaLoraConfig(task_type=TaskType.CAUSAL_LM)
config = AdaLoraConfig(task_type=TaskType.CAUSAL_LM, total_step=1)
peft_model = get_peft_model(model, config)
peft_model = prepare_model_for_kbit_training(peft_model)
peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
@@ -1624,7 +1624,7 @@ def test_lora_dora_add_new_adapter_does_not_change_device(self, mlp):
def test_adalora_add_new_adapter_does_not_change_device(self, mlp):
# same as first test, but using AdaLORA
# AdaLora does not like multiple trainable adapters, hence inference_mode=True
config = AdaLoraConfig(target_modules=["lin0"], inference_mode=True)
config = AdaLoraConfig(target_modules=["lin0"], inference_mode=True, total_step=1)
model = get_peft_model(mlp, config)
model = model.to(self.device)
model.lin0.lora_A.cpu()