More convenient way to initialize LoftQ (#1543)
Related to #1532

At the moment, using LoftQ is quite cumbersome, as shown in this
example:

https://github.com/huggingface/peft/tree/7e84dec20b3106bdd0a90ba8e80187f0aec835b7/examples/loftq_finetuning

Essentially, users have to:

1. Load the non-quantized model with LoftQ (which can be quite huge)
2. Modify the PEFT config
3. Save the adapter
4. Unwrap the base model
5. Save the base model with modified weights (i.e. a whole copy of the
   base model)
6. Load the base model from step 5 with bnb quantization
7. Load the adapter from step 3

Yes, there is a helper script to do this, but it still has the
disadvantage that we need to load the non-quantized model and that we
have to create a completely new model checkpoint with the modified
weights.
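
For concreteness, here is a condensed sketch of that old workflow
(identifiers such as "model-id" and the output paths are placeholders,
and the unwrapping details are elided):

```python
from transformers import AutoModelForCausalLM
from peft import LoftQConfig, LoraConfig, get_peft_model

# 1. load the full-precision base model (potentially very large)
base_model = AutoModelForCausalLM.from_pretrained("model-id")
loftq_config = LoftQConfig(loftq_bits=4)
lora_config = LoraConfig(init_lora_weights="loftq", loftq_config=loftq_config)
peft_model = get_peft_model(base_model, lora_config)  # LoftQ init runs here
# 2. modify the PEFT config so LoftQ is not re-applied when loading the adapter
peft_model.peft_config["default"].init_lora_weights = True
# 3. save the adapter
peft_model.save_pretrained("loftq-adapter")
# 4./5. unwrap the LoRA layers and save the base model with the modified
#       weights (a whole extra copy; the helper script handles this part)
# 6./7. reload that checkpoint with bnb quantization and load the adapter
```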

This PR aims to make this process more convenient by adding a single
function replace_lora_weights_loftq. This function takes the
bnb-quantized LoRA model as input. Then it goes through each module with
LoRA weights, lazily loads the corresponding non-quantized weights one
at a time using safetensors, computes the quantization error, and
replaces the LoRA weights with LoftQ-initialized LoRA weights.
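
In its simplest form, a call might look like this (a sketch: peft_model
is assumed to be a bnb-4bit quantized model with LoRA attached, and
model_path to point to the safetensors checkpoint of the non-quantized
weights):

```python
from peft import replace_lora_weights_loftq

# replaces the LoRA weights in-place with LoftQ-initialized ones
replace_lora_weights_loftq(peft_model, model_path="model.safetensors")
```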

This is much more convenient because, thanks to lazy loading, we
require very little extra memory and don't have to keep an extra copy
of the weights.

While working on this, I found that LoftQ initialization often did not
seem to help much, as mentioned in #1532. I measured this by
creating (1) logits with the base model, (2) with the quantized+LoRA
model, and (3) with the quantized+LoRA+LoftQ model. The expectation is
that (1) should be closer to (3) than to (2). This was often not the
case.

I therefore added the possibility to run a check each time that we
replace a LoRA weight with the LoftQ weights. If this check returns
True, we proceed to the next weight, otherwise we discard the change.
That way, we only make the replacement with LoftQ weights if we see a
real improvement. Of course, this is only a form of greedy optimization,
but it seems to work in practice. And since it's optional, users can
choose not to use it.
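
As a rough sketch, such a check could be implemented via the callback
mechanism along these lines (hypothetical names; peft_model, inputs,
and logits_base are assumed to exist already, with logits_base holding
pre-computed logits of the non-quantized base model):

```python
import torch

from peft import replace_lora_weights_loftq

current_mse = float("inf")

def loftq_callback(model, module_name):
    """Keep the LoftQ replacement for this module only if it improves the logits."""
    global current_mse
    with torch.inference_mode():
        logits = model(**inputs).logits
    mse = torch.nn.functional.mse_loss(logits, logits_base).item()
    if mse < current_mse:
        current_mse = mse
        return True  # improvement: keep the new LoRA weights
    return False  # no improvement: discard the change for this module

replace_lora_weights_loftq(peft_model, callback=loftq_callback)
```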

This doesn't support 8bit quantization or the num_iter argument of LoftQ.
However, the replace_lora_weights_loftq function can be called multiple
times in a row for slightly improved results.
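
For example, reusing the callback sketched above:

```python
# each call performs one more LoftQ step; a few passes can help slightly
for _ in range(3):
    replace_lora_weights_loftq(peft_model, callback=loftq_callback)
```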

---------

Co-authored-by: Steven Liu <[email protected]>
BenjaminBossan and stevhliu authored Mar 20, 2024
1 parent a86b29a commit 8e979fc
Showing 8 changed files with 1,086 additions and 2 deletions.
27 changes: 27 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -42,10 +42,37 @@ config = LoraConfig(init_lora_weights=False, ...)

### LoftQ

#### Standard approach

When quantizing the base model for QLoRA training, consider using the [LoftQ initialization](https://arxiv.org/abs/2310.08659), which has been shown to improve performance when training quantized models. The idea is that the LoRA weights are initialized such that the quantization error is minimized. To use LoftQ, follow [these instructions](https://github.com/huggingface/peft/tree/main/examples/loftq_finetuning).

In general, for LoftQ to work best, it is recommended to target as many layers with LoRA as possible, since those not targeted cannot have LoftQ applied. This means that passing `LoraConfig(..., target_modules="all-linear")` will most likely give the best results. Also, you should use `nf4` as quant type in your quantization config when using 4bit quantization, i.e. `BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")`.
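
Putting these recommendations together, the relevant config objects might look like this (a sketch; wiring them into the full LoftQ workflow follows the instructions linked above):

```python
from transformers import BitsAndBytesConfig
from peft import LoftQConfig, LoraConfig

# nf4 is the recommended 4bit quant type when combining bnb quantization with LoftQ
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")

# target as many layers as possible so that LoftQ can be applied to them
loftq_config = LoftQConfig(loftq_bits=4)
lora_config = LoraConfig(
    init_lora_weights="loftq",
    loftq_config=loftq_config,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
```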

#### A more convenient way

An easier but more limited way to apply LoftQ initialization is to use the convenience function `replace_lora_weights_loftq`. This takes the quantized PEFT model as input and replaces the LoRA weights in-place with their LoftQ-initialized counterparts.

```python
from peft import LoraConfig, get_peft_model, replace_lora_weights_loftq
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True, ...)
base_model = AutoModelForCausalLM.from_pretrained(..., quantization_config=bnb_config)
# note: don't pass init_lora_weights="loftq" or loftq_config!
lora_config = LoraConfig(task_type="CAUSAL_LM")
peft_model = get_peft_model(base_model, lora_config)
replace_lora_weights_loftq(peft_model)
```

`replace_lora_weights_loftq` also allows you to pass a `callback` argument to give you more control over which layers should be modified, which empirically can improve the results quite a lot. To see a more elaborate example of this, check out [this notebook](https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/LoftQ_weight_replacement.ipynb).
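
For instance, here is a minimal sketch of a callback that restricts the replacement to the first few decoder layers (assuming the callback receives the model and the name of the module currently being processed, and returns whether to keep the replacement):

```python
def my_callback(model, module_name):
    # keep the LoftQ replacement only for modules in the first eight layers
    return any(f"layers.{i}." in module_name for i in range(8))

replace_lora_weights_loftq(peft_model, callback=my_callback)
```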

`replace_lora_weights_loftq` implements only one iteration step of LoftQ. This means that only the LoRA weights are updated, instead of iteratively updating LoRA weights and quantized base model weights. This may lead to lower performance but has the advantage that we can use the original quantized weights derived from the base model, instead of having to keep an extra copy of modified quantized weights. Whether this tradeoff is worthwhile depends on the use case.

At the moment, `replace_lora_weights_loftq` has these additional limitations:

- Model files must be stored as a `safetensors` file.
- Only bitsandbytes 4bit quantization is supported.

<Tip>

Learn more about how PEFT works with quantization in the [Quantization](quantization) guide.
6 changes: 5 additions & 1 deletion docs/source/package_reference/lora.md
@@ -28,4 +28,8 @@ The abstract from the paper is:

## LoraModel

[[autodoc]] tuners.lora.model.LoraModel

## Utility

[[autodoc]] utils.loftq_utils.replace_lora_weights_loftq
