Allow swapping PEFT adapters for target/ref model. (#1193)

* Allow swapping PEFT adapters for target/ref model.

* Update DPOTrainer docs.

* python format

* isort

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <[email protected]>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <[email protected]>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <[email protected]>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <[email protected]>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <[email protected]>

---------

Co-authored-by: Kashif Rasul <[email protected]>
jondurbin and kashif authored Jan 8, 2024
1 parent dbcb2f0 commit 3267be0
Showing 2 changed files with 86 additions and 8 deletions.
64 changes: 62 additions & 2 deletions docs/source/dpo_trainer.mdx
@@ -101,7 +101,7 @@ While training and evaluating we record the following reward metrics:
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards

-### Accelerate DPO fine-tuning using `unsloth`
+## Accelerate DPO fine-tuning using `unsloth`

You can further accelerate QLoRA / LoRA fine-tuning (2x faster, 60% less memory) and even full fine-tuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library, which is compatible with `DPOTrainer`. Currently `unsloth` supports only the Llama (including Yi and TinyLlama) and Mistral architectures.
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, incorporating unsloth into your workflow is straightforward: instead of loading `AutoModelForCausalLM`, you load a `FastLlamaModel` or `FastMistralModel` as follows:
@@ -156,6 +156,66 @@ dpo_trainer.train()

The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).

## Reference model considerations with PEFT

Assuming the model you would like to further enhance with DPO was tuned using (Q)LoRA, you have three main options (plus several variants) for how the reference model is handled when using PEFT:

1. Simply create two instances of the model, each loading your adapter. This works fine but is very inefficient (see the sketch after this list).
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case `DPOTrainer` will unload the adapter for reference inference. This is efficient, but has the potential downsides discussed below.
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO-trained and the reference adapter. This is slightly less efficient than option 2 (roughly the adapter size in extra VRAM), but avoids those pitfalls.
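
Option 1 is not elaborated later in this document, so here is a minimal sketch of it. It reuses the example checkpoint from the block further below together with a hypothetical adapter path `/path/to/peft`, and it omits the remaining `DPOTrainer` arguments; it simply keeps two full copies of the model in memory, one trainable and one frozen reference:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Trainable policy: base model plus the fine-tuned adapter, loaded as trainable.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/mixtral-8x7b-v0.1", torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(model, "/path/to/peft", is_trainable=True)

# Frozen reference: a second, independent copy of the base model with the same adapter.
ref_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/mixtral-8x7b-v0.1", torch_dtype=torch.bfloat16, device_map="auto"
)
ref_model = PeftModel.from_pretrained(ref_model, "/path/to/peft")

# Pass both to the trainer; the remaining arguments (args, datasets, tokenizer, ...) are omitted here.
# dpo_trainer = DPOTrainer(model, ref_model=ref_model, ...)
```

This is the easiest setup to reason about, but it holds the base weights twice, which is why the adapter-based options are usually preferable.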

### Downsides to merging QLoRA before DPO (approach 2)

As suggested by [Tim Dettmers](https://twitter.com/Tim_Dettmers/status/1694654191325573456), the best option for merging QLoRA adapters is to first quantize the base model, merge the adapter, then convert back to bf16; see [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py) for an example.

You can also merge the adapter the standard way, without quantizing the base model, but this typically costs 1-2% in performance (and, evidently, leads to more issues with empty responses).

If you use the recommended approach, which quantizes the base model before merging, you are then in a situation where, to use QLoRA for DPO, you need to re-quantize the merged model or fall back to an unquantized merge with lower overall performance.
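
To make that round trip concrete, here is a rough sketch of the standard (unquantized) merge followed by loading the merged model in 4-bit for a QLoRA-based DPO run. The paths are placeholders, and this does not reproduce the quantize-merge-dequantize flow from the linked script:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Standard merge: load the base model unquantized, merge the (Q)LoRA adapter into it, save the result.
base = AutoModelForCausalLM.from_pretrained(
    "mistralai/mixtral-8x7b-v0.1", torch_dtype=torch.bfloat16, device_map="auto"
)
merged = PeftModel.from_pretrained(base, "/path/to/peft").merge_and_unload()
merged.save_pretrained("/path/to/merged")

# For DPO with QLoRA (option 2), the merged checkpoint is quantized again at load time,
# and a fresh adapter is then created on top of it (e.g. by passing a peft_config to DPOTrainer).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/merged", quantization_config=bnb_config, device_map="auto"
)
```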

### Using option 3 - load the adapter twice

To avoid the downsides of option 2, at the expense of slightly increased VRAM, you can load your fine-tuned adapter into the model twice with different names, and set the model/ref adapter names in `DPOTrainer`.

For example:
```python
# Imports needed for this snippet.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOTrainer

# Load the base model in 4-bit.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/mixtral-8x7b-v0.1",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.config.use_cache = False

# Load the adapter that will be trained.
model = PeftModel.from_pretrained(
    model,
    "/path/to/peft",
    is_trainable=True,
    adapter_name="train",
)
# Load the adapter a second time, with a different name, which will be our reference model.
model.load_adapter("/path/to/peft", adapter_name="reference")

# Initialize the trainer without a ref_model; the "reference" adapter is used instead.
dpo_trainer = DPOTrainer(
    model,
    ...
    model_adapter_name="train",
    ref_adapter_name="reference",
)
```

## DPOTrainer

[[autodoc]] DPOTrainer
30 changes: 24 additions & 6 deletions trl/trainer/dpo_trainer.py
Expand Up @@ -16,7 +16,7 @@
import random
import warnings
from collections import defaultdict
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@@ -126,6 +126,10 @@ class DPOTrainer(Trainer):
            Dict of Optional kwargs to pass when instantiating the model from a string
        ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
            Dict of Optional kwargs to pass when instantiating the ref model from a string
        model_adapter_name (`str`, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
        ref_adapter_name (`str`, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple adapters.
    """

    _tag_names = ["trl", "dpo"]
@@ -160,6 +164,8 @@ def __init__(
        precompute_ref_log_probs: bool = False,
        model_init_kwargs: Optional[Dict] = None,
        ref_model_init_kwargs: Optional[Dict] = None,
        model_adapter_name: str = None,
        ref_adapter_name: str = None,
    ):
        if model_init_kwargs is None:
            model_init_kwargs = {}
@@ -253,6 +259,8 @@ def make_inputs_require_grad(module, input, output):
        self.is_encoder_decoder = is_encoder_decoder

        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
        self.model_adapter_name = model_adapter_name
        self.ref_adapter_name = ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
@@ -704,14 +712,24 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None)

        return batch

    @contextmanager
    def null_ref_context(self):
        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
        with self.accelerator.unwrap_model(
            self.model
        ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
            if self.ref_adapter_name:
                self.model.set_adapter(self.ref_adapter_name)
            yield
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")

    def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
        """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
        # compute reference logps
        with torch.no_grad():
            if self.ref_model is None:
-                with self.accelerator.unwrap_model(
-                    self.model
-                ).disable_adapter() if self.is_peft_model else nullcontext():
+                with self.null_ref_context():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
@@ -976,7 +994,7 @@ def get_batch_loss_metrics(
        else:
            with torch.no_grad():
                if self.ref_model is None:
-                    with self.accelerator.unwrap_model(self.model).disable_adapter():
+                    with self.null_ref_context():
                        (
                            reference_chosen_logps,
                            reference_rejected_logps,
@@ -1048,7 +1066,7 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[
            reference_output = batch["reference_output"]
        else:
            if self.ref_model is None:
-                with self.accelerator.unwrap_model(self.model).disable_adapter():
+                with self.null_ref_context():
                    reference_output = self.model.generate(
                        input_ids=batch["prompt_input_ids"],
                        attention_mask=batch["prompt_attention_mask"],
