Conversational dataset support for DPOTrainer #2131

Merged on Oct 2, 2024 (39 commits). Changes shown from 26 commits.

Commits
ad6b9a8
conversational dataset support for dpo
qgallouedec Sep 26, 2024
94a29ef
support standard dataset for extract prompt
qgallouedec Sep 26, 2024
69a6933
test standard dataset for extract prompt
qgallouedec Sep 26, 2024
e479c85
fix maybe
qgallouedec Sep 26, 2024
4301724
fix maybe apply prompt
qgallouedec Sep 26, 2024
2fbca62
Merge branch 'main' into dpo-conversational-dataset
qgallouedec Sep 26, 2024
ead4114
style
qgallouedec Sep 26, 2024
6f058df
Merge branch 'main' into dpo-conversational-dataset
qgallouedec Sep 27, 2024
3428813
overwrite default learning rate of DPO
qgallouedec Sep 27, 2024
61c589f
style
qgallouedec Sep 27, 2024
9c6769b
rlaif script
qgallouedec Sep 27, 2024
c656c99
`writer_batch_size` in `train_test_split`
qgallouedec Sep 27, 2024
8ddf39e
initial dpo doc refactoring
qgallouedec Sep 27, 2024
d461963
vision data section in doc
qgallouedec Sep 27, 2024
e513cfe
lil format modif
qgallouedec Sep 27, 2024
dbf003e
Merge branch 'main' into dpo-conversational-dataset
qgallouedec Sep 27, 2024
b22bb82
refine Vision datasets
qgallouedec Sep 28, 2024
5b8e75f
refine doc
qgallouedec Sep 28, 2024
93f87b8
test new loss type format
qgallouedec Sep 28, 2024
0671ab5
restrcture loss function
qgallouedec Sep 28, 2024
840db37
table loss type
qgallouedec Sep 28, 2024
08b21b1
simplify `unsloth`
qgallouedec Sep 30, 2024
083aeb5
improve doc
qgallouedec Sep 30, 2024
92bed88
looged metrics up
qgallouedec Sep 30, 2024
985227e
refine loss section
qgallouedec Sep 30, 2024
9ba55e8
Fix label_smoothing parameter in DPOConfig
qgallouedec Sep 30, 2024
208f34e
Merge branch 'main' into dpo-conversational-dataset
qgallouedec Sep 30, 2024
c2d1836
dataset for test
qgallouedec Sep 30, 2024
9869467
update readme
qgallouedec Sep 30, 2024
063628d
Merge branch 'dpo-conversational-dataset' of https://github.com/huggi…
qgallouedec Sep 30, 2024
f50a4bb
Update docs/source/dpo_trainer.mdx
qgallouedec Oct 1, 2024
df7cb6a
try colorized code block
qgallouedec Oct 1, 2024
8749c70
Merge branch 'main' into dpo-conversational-dataset
qgallouedec Oct 1, 2024
bb2b368
refine doc style
qgallouedec Oct 1, 2024
3d8e0b6
further refine doc
qgallouedec Oct 1, 2024
2f591f5
Update docs/source/dpo_trainer.mdx
qgallouedec Oct 1, 2024
4ac091b
Merge branch 'main' into dpo-conversational-dataset
qgallouedec Oct 1, 2024
a55c8ec
re add pali gemma test
qgallouedec Oct 2, 2024
94a31e4
Add missing period
qgallouedec Oct 2, 2024
34 changes: 34 additions & 0 deletions docs/source/dataset_formats.mdx
Expand Up @@ -180,6 +180,8 @@ preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": "
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
```

Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.

### Unpaired preference

An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
Expand Down Expand Up @@ -710,3 +712,35 @@ dataset = dataset.remove_columns(["completion", "label"])
>>> dataset[0]
{'prompt': 'The sky is'}
```

## Vision datasets

Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.

A conversational vision dataset differs from a standard conversational dataset in two key ways:

1. The dataset must contain the key `images` with the image data.
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.

Example:

```python
# Textual dataset format:
"content": "What color is the sky?"

# Vision dataset format:
"content": [
    {"type": "image"},
    {"type": "text", "text": "What color is the sky in the image?"}
]
```

An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly:

<iframe
src="https://huggingface.co/datasets/trl-lib/rlaif-v/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>
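
To make the two requirements above concrete, here is a minimal sketch of one complete conversational vision preference example, together with a small checker. The helper names `is_vision_example` and `count_image_placeholders` are hypothetical and not part of TRL, and the image is a placeholder string standing in for real image data.

```python
# Toy example: the image entry is a placeholder string, not a real image object.
example = {
    "images": ["<image placeholder>"],
    "prompt": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What color is the sky in the image?"},
            ],
        }
    ],
    "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
    "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "It is green."}]}],
}


def count_image_placeholders(messages):
    """Count the {"type": "image"} entries across all messages."""
    return sum(
        1
        for message in messages
        for part in message["content"]
        if part["type"] == "image"
    )


def is_vision_example(example):
    """Check the two requirements: an "images" key and list-of-dict "content" fields.

    Hypothetical helper for illustration only; TRL does not expose this function.
    """
    if "images" not in example:
        return False
    for key in ("prompt", "chosen", "rejected"):
        for message in example.get(key, []):
            if not isinstance(message["content"], list):
                return False
            if any(part.get("type") not in ("image", "text") for part in message["content"]):
                return False
    return True
```

In this sketch the number of `{"type": "image"}` placeholders in the prompt matches the length of `images`, which is the natural pairing to expect, although the documentation above does not state it as a rule.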

300 changes: 136 additions & 164 deletions docs/source/dpo_trainer.mdx

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions docs/source/online_dpo_trainer.md
Expand Up @@ -38,11 +38,7 @@ train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = OnlineDPOConfig(output_dir="online-dpo-qwen2", logging_steps=10)
trainer = OnlineDPOTrainer(
-    model=model,
-    reward_model=reward_model,
-    args=training_args,
-    tokenizer=tokenizer,
-    train_dataset=train_dataset,
+    model=model, reward_model=reward_model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset
)
trainer.train()
```
Expand Down
73 changes: 73 additions & 0 deletions examples/datasets/rlaif-v.py
@@ -0,0 +1,73 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

from datasets import features, load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
r"""
Arguments for the script.

Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/rlaif-v"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/rlaif-v"
dataset_num_proc: Optional[int] = None


def to_conversational(example):
"""
Convert prompt from "xxx" to [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "xxx"}]}]
and chosen and rejected from "xxx" to [{"role": "assistant", "content": [{"type": "text", "text": "xxx"}]}].
Images are wrapped into a list.
"""
prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}]
chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}]
rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}]
return {"prompt": prompt, "images": [example["image"]], "chosen": chosen, "rejected": rejected}


if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train")
dataset = dataset.map(
to_conversational,
num_proc=script_args.dataset_num_proc,
remove_columns=dataset.column_names,
writer_batch_size=128,
)

# Cast the images to Sequence[Image] to avoid bytes format
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True))
dataset = dataset.cast(f)

dataset = dataset.train_test_split(test_size=0.01, writer_batch_size=128)

if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)
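
As a quick illustration of what the script's `to_conversational` does to a single row, the sketch below repeats the function and applies it to a toy record. The toy row, its placeholder image string, and the extra `ds_name` column are assumptions made for the example; in the script itself the original columns are dropped by `remove_columns` in `dataset.map`.

```python
def to_conversational(example):
    """Same conversion as in the script above, repeated so the sketch is self-contained."""
    prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}]
    chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}]
    rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}]
    return {"prompt": prompt, "images": [example["image"]], "chosen": chosen, "rejected": rejected}


# Hypothetical row: a placeholder string stands in for the real image.
row = {
    "question": "What color is the sky?",
    "image": "<image placeholder>",
    "chosen": "It is blue.",
    "rejected": "It is green.",
    "ds_name": "toy",  # extra column, not returned by the conversion
}

converted = to_conversational(row)
print(sorted(converted))  # the four keys returned by the conversion
print(converted["prompt"][0]["content"])
```

The single image is wrapped in a list under `images`, and the question becomes a user turn whose content starts with an image placeholder.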
9 changes: 0 additions & 9 deletions examples/scripts/dpo.py
Expand Up @@ -47,7 +47,6 @@
"""

import torch
-from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

Expand All @@ -60,8 +59,6 @@
get_kbit_device_map,
get_peft_config,
get_quantization_config,
-    maybe_apply_chat_template,
-    maybe_extract_prompt,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

Expand Down Expand Up @@ -115,12 +112,6 @@
################
dataset = load_dataset(script_args.dataset_name)

-    with PartialState().local_main_process_first():
-        dataset = dataset.map(maybe_extract_prompt, num_proc=training_args.dataset_num_proc)
-        dataset = dataset.map(
-            maybe_apply_chat_template, num_proc=training_args.dataset_num_proc, fn_kwargs={"tokenizer": tokenizer}
-        )

##########
# Training
################
Expand Down
60 changes: 49 additions & 11 deletions tests/test_data_utils.py
Expand Up @@ -268,7 +268,7 @@ def test_maybe_unpair_preference_dataset_dict_already_paired(self):


class ExtractPromptTester(unittest.TestCase):
-    example_implicit_prompt = {
+    example_implicit_prompt_conversational = {
"chosen": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
Expand All @@ -279,7 +279,7 @@ class ExtractPromptTester(unittest.TestCase):
],
}

-    example_explicit_prompt = {
+    example_explicit_prompt_conversational = {
"prompt": [
{"role": "user", "content": "What color is the sky?"},
],
Expand All @@ -291,30 +291,68 @@ class ExtractPromptTester(unittest.TestCase):
],
}

-    def test_extract_prompt(self):
example_implicit_prompt_standard = {
"chosen": "The sky is blue.",
"rejected": "The sky is green.",
}

example_explicit_prompt_standard = {
"prompt": "The sky is",
"chosen": " blue.",
"rejected": " green.",
}

def test_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset
example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt_conversational,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt_conversational,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt_conversational_already_explicit(self):
# Test that the prompt remains unchanged with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt_conversational,
"The prompt should remain unchanged.",
)

+    def test_extract_prompt_standard(self):
         # Test that the prompt is correctly extracted from the dataset
-        example_extracted_prompt = extract_prompt(self.example_implicit_prompt)
+        example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard)
         self.assertEqual(
             example_extracted_prompt,
-            self.example_explicit_prompt,
+            self.example_explicit_prompt_standard,
             "The prompt is not correctly extracted from the dataset.",
         )

-    def test_maybe_extract_prompt(self):
+    def test_maybe_extract_prompt_standard(self):
         # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
-        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt)
+        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard)
         self.assertEqual(
             example_extracted_prompt,
-            self.example_explicit_prompt,
+            self.example_explicit_prompt_standard,
             "The prompt is not correctly extracted from the dataset.",
         )

-    def test_maybe_extract_prompt_already_explicit(self):
+    def test_maybe_extract_prompt_standard_already_explicit(self):
         # Test that the prompt remains unchanged with maybe_extract_prompt
-        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt)
+        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard)
         self.assertEqual(
             example_extracted_prompt,
-            self.example_explicit_prompt,
+            self.example_explicit_prompt_standard,
             "The prompt should remain unchanged.",
         )

Expand Down
22 changes: 13 additions & 9 deletions trl/data_utils.py
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, TypeVar
+from typing import Any, Dict, List, Optional, Sequence, TypeVar

from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer
Expand Down Expand Up @@ -280,15 +280,17 @@ def maybe_unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int
return dataset


-def extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
+def extract_prompt(example: Dict[str, Sequence]) -> Dict[str, Sequence]:
r"""
Extracts the shared prompt from a preference data example, where the prompt is implicit within both
the chosen and rejected completions.

For more details, see [`maybe_extract_prompt`].
"""
for idx in range(min(len(example["chosen"]), len(example["rejected"]))):
-        if example["chosen"][idx]["content"] != example["rejected"][idx]["content"]:
+        if example["chosen"][idx] != example["rejected"][idx]:
+            if example["chosen"][idx - 1] == " ":  # remove space before the prompt
+                idx -= 1
Comment on lines +291 to +293
Member Author

str1 = "I am Quentin"
str2 = "I am in Lyon"
# What we want:
prompt = "I am"
# What we don't want:
prompt = "I am "

That's why, when the prompt ends with a space, we take idx-1 instead.
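
The trailing-space rule described in this comment can be sketched for plain strings as follows. This is an illustrative re-implementation, not the `extract_prompt` code from the diff: the name `split_shared_prefix` is hypothetical, and the `idx > 0` guard and the `while` loop are choices made for this sketch.

```python
def split_shared_prefix(chosen: str, rejected: str) -> dict:
    """Split two completions into a shared prompt and the remaining suffixes.

    Illustrative sketch for strings only; not the TRL implementation.
    """
    idx = 0
    limit = min(len(chosen), len(rejected))
    # Advance while both strings agree character by character.
    while idx < limit and chosen[idx] == rejected[idx]:
        idx += 1
    # If the shared prefix ends with a space, hand that space back to the completions.
    if idx > 0 and chosen[idx - 1] == " ":
        idx -= 1
    return {"prompt": chosen[:idx], "chosen": chosen[idx:], "rejected": rejected[idx:]}


result = split_shared_prefix("I am Quentin", "I am in Lyon")
print(result)
```

With the two strings from the comment, the prompt comes out as `"I am"` and the space moves to the start of each completion, which matches the `" blue."` / `" green."` convention used in the standard-format tests.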

break
return {
"prompt": example["chosen"][:idx],
Expand All @@ -303,15 +305,14 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
the chosen and rejected completions.

If the example already contains a `"prompt"` key, the function returns the example as is. Else, the function

identifies the longest common sequence (prefix) of conversation turns between the "chosen" and "rejected"
completions and extracts this as the prompt. It then removes this prompt from the respective "chosen" and
"rejected" completions.

Args:
example (`Dict[str, List]`):
A dictionary representing a single data entry in the preference dataset. It must contain the keys
-            `"chosen"` and `"rejected"`, where each value is a list.
+            `"chosen"` and `"rejected"`, where each value is either conversational or standard (`str`).

Returns:
`Dict[str, List]`: A dictionary containing:
Expand Down Expand Up @@ -379,7 +380,10 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
# "chosen": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
# "rejected": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}]}
# That's why we check if the prompt is also conversational before deciding not to extract it.
-    if "prompt" in example and is_conversational({"prompt": example["prompt"]}):
-        return example
-    else:
-        return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]})
+    if "prompt" in example:
+        # Both conversational or both non-conversational
+        chosen_conv = is_conversational({"chosen": example["chosen"]})
+        prompt_conv = is_conversational({"prompt": example["prompt"]})
+        if (chosen_conv and prompt_conv) or (not chosen_conv and not prompt_conv):
+            return example
+    return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]})
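
The decision made in this hunk (keep the example only when `prompt` and `chosen` share a format, otherwise re-extract) can be sketched in isolation. `looks_conversational` is a toy stand-in for TRL's `is_conversational`, and `needs_prompt_extraction` is a hypothetical name; both are assumptions for illustration.

```python
def looks_conversational(value) -> bool:
    """Toy stand-in for `is_conversational`: a list of {"role", "content"} dicts."""
    return isinstance(value, list) and all(
        isinstance(message, dict) and "role" in message and "content" in message
        for message in value
    )


def needs_prompt_extraction(example: dict) -> bool:
    """Return True when the prompt is missing or its format differs from "chosen"."""
    if "prompt" not in example:
        return True
    # Both conversational or both standard: the prompt is already explicit.
    return looks_conversational(example["prompt"]) != looks_conversational(example["chosen"])


explicit_standard = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
implicit_standard = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
mixed = {
    "prompt": "What color is the sky?",  # string prompt next to conversational completions
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
}

for name, ex in [("explicit", explicit_standard), ("implicit", implicit_standard), ("mixed", mixed)]:
    print(name, needs_prompt_extraction(ex))
```

The mixed case is the one the comment in the source calls out: a string `prompt` next to conversational completions is treated as if the prompt still had to be extracted.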
6 changes: 5 additions & 1 deletion trl/trainer/dpo_config.py
Expand Up @@ -40,6 +40,9 @@ class DPOConfig(TrainingArguments):
command line.

Parameters:
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
beta (`float`, *optional*, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
Expand Down Expand Up @@ -110,7 +113,7 @@ class DPOConfig(TrainingArguments):
f_divergence_type (`str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
Type of f-divergence regularization function to compute divergence between policy and reference model.
f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
-            α coefficient in the α-divergence \\(u^{-\\alpha}\\) regularization function for DPO loss.
+            α coefficient in the α-divergence u^-α regularization function for DPO loss.
sync_ref_model (`bool`, *optional*, defaults to `False`):
When set to `True`, the reference model is synchronized with the active model every `ref_model_sync_steps`
steps, using the `ref_model_mixup_alpha` parameter. This synchronization originites from the
Expand All @@ -130,6 +133,7 @@ class DPOConfig(TrainingArguments):
DPO loss. The paper recommends `rpo_alpha=1.0`.
"""

learning_rate: float = 1e-6
beta: float = 0.1
label_smoothing: float = 0.0
loss_type: Literal[
Expand Down
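
The new `learning_rate: float = 1e-6` field works by redeclaring an inherited dataclass field with a different default. A minimal sketch of that mechanism, using toy classes rather than the real `TrainingArguments` and `DPOConfig` (the names `BaseArguments` and `ToyDPOConfig` are hypothetical):

```python
from dataclasses import dataclass


@dataclass
class BaseArguments:
    """Toy stand-in for `transformers.TrainingArguments`, whose default learning rate is 5e-5."""

    output_dir: str
    learning_rate: float = 5e-5


@dataclass
class ToyDPOConfig(BaseArguments):
    """Toy stand-in for `DPOConfig`: redeclaring the field replaces the inherited default."""

    learning_rate: float = 1e-6
    beta: float = 0.1


default_config = ToyDPOConfig(output_dir="out")
custom_config = ToyDPOConfig(output_dir="out", learning_rate=5e-6)
print(default_config.learning_rate, custom_config.learning_rate)
```

Users who pass `learning_rate` explicitly are unaffected; only runs that relied on the inherited default see a different value.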
12 changes: 12 additions & 0 deletions trl/trainer/dpo_trainer.py
Expand Up @@ -44,6 +44,7 @@
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper, create_reference_model
from .callbacks import SyncRefModelCallback
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
Expand Down Expand Up @@ -815,6 +816,17 @@ def make_inputs_require_grad(module, input, output):
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# Extract the prompt if needed, and apply the chat template if needed
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
train_dataset = train_dataset.map(
maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer}, num_proc=args.dataset_num_proc
)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
eval_dataset = eval_dataset.map(
maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer}, num_proc=args.dataset_num_proc
)

# tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models)
fn_kwargs = {
"tokenizer": self.tokenizer,
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/kto_config.py
Expand Up @@ -28,7 +28,8 @@ class KTOConfig(TrainingArguments):

Parameters:
learning_rate (`float`, *optional*, defaults to `5e-7`):
-            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`].
+            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
+            [`~transformers.TrainingArguments`].
max_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
Expand Down