Commit

first modifications in the documentation

qgallouedec committed Sep 16, 2024
1 parent 07f0e68 commit 70f4019
Showing 2 changed files with 53 additions and 38 deletions.
docs/source/dataset_formats.mdx (2 changes: 1 addition & 1 deletion)

@@ -209,7 +209,7 @@ Choosing the right dataset format depends on the task you are working on and the

<Tip>

-TRL trainers only support standard dataset formats. If you have a conversational dataset, you must first convert it into a standard format.
+TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format (a sketch follows below this tip).
For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.

</Tip>
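
For illustration, here is a minimal sketch of such a conversion, assuming a conversational prompt-only dataset such as [trl-lib/ultrafeedback-prompt](https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt) and a tokenizer with a chat template (the helper name `to_standard` is ours):

```python
# Sketch: flatten a conversational dataset into the standard format by
# rendering the chat messages with the tokenizer's chat template.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

def to_standard(example):
    # "prompt" holds a list of chat messages; render it into a single string
    example["prompt"] = tokenizer.apply_chat_template(
        example["prompt"], tokenize=False, add_generation_prompt=True
    )
    return example

dataset = dataset.map(to_standard)
```
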
docs/source/online_dpo_trainer.md (89 changes: 52 additions & 37 deletions)

@@ -17,44 +17,75 @@ This post-training method was contributed by [Michael Noukhovitch](https://huggi
> [!WARNING]
> Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training. A quick way to check this is sketched below.
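
A quick sanity check (a minimal sketch, assuming both tokenizers define an explicit `chat_template` attribute):

```python
# Sketch: verify that the SFT model and the reward model share a chat template.
from transformers import AutoTokenizer

sft_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
rm_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")
assert sft_tokenizer.chat_template == rm_tokenizer.chat_template
```
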
+## Expected dataset format
+
+Online DPO only requires a [prompt-only dataset](dataset_format#preference) (unlike offline DPO, which expects a [preference dataset](dataset_format#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_format#conversational-dataset-format) and [standard](dataset_format#standard-dataset-format) dataset formats.
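
For illustration, a prompt-only sample carries just the prompt; consistent with the dataset format guide, it looks like this in each format:

```python
# Standard format: the prompt is a plain string.
standard_example = {"prompt": "The sky is"}

# Conversational format: the prompt is a list of chat messages.
conversational_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
```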

+## Quick start
+
+This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the [Qwen 0.5B reward model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) as the reward model. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback):
+
+<iframe
+  src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
+  frameborder="0"
+  width="100%"
+  height="560px"
+></iframe>
-The basic API is as follows:

```python
-from datasets import Dataset
+# train_online_dpo.py
+from datasets import load_dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer
-from transformers import (
-    AutoModelForCausalLM,
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-)
-NUM_DUMMY_SAMPLES = 100
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-# The reference model to calculate the KL divergence against
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-# The model to score completions with.
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
+train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

-train_dataset = Dataset.from_dict(
-    {"prompt": ["Q: Hi how are you? A:"] * NUM_DUMMY_SAMPLES})
-eval_dataset = Dataset.from_dict(
-    {"prompt": ["Q: What do you like to eat A:"] * NUM_DUMMY_SAMPLES})

-args = OnlineDPOConfig(output_dir="online-dpo-model")
+args = OnlineDPOConfig(output_dir="online-dpo-qwen2", logging_steps=10)
trainer = OnlineDPOTrainer(
    model=model,
    ref_model=ref_model,
    reward_model=reward_model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
-    eval_dataset=eval_dataset,
)
trainer.train()
```

+We run this script with the following command:
+
+```bash
+accelerate launch train_online_dpo.py
+```

+After approximately 1 hour of training, we can load the trained model and inspect its completions:

+````python
+>>> from transformers import pipeline
+>>> generator = pipeline("text-generation", model="online-dpo-qwen2/checkpoint-500", device="cuda")
+>>> question = "Can you tell me which shell command can be used to display the CPU usage of a specific process in Linux? And what is the syntax for this command?"
+>>> output = generator([{"role": "user", "content": question}], max_new_tokens=200, return_full_text=False)[0]
+>>> print(output["generated_text"])
+Yes, you can use the `top` command in Linux to display the CPU usage of a specific process.
+The syntax for the `top` command depends on your version of Linux. Here's an example of how to run `top` with the `-b` option:
+```
+top -b
+```
+
+This will display the CPU usage of all processes in the system at the top level. You can also specify additional options by adding them after the `-b` option using square brackets (`[]`). For example, to display only the running processes and their CPU usage, you would add the following options to the command:
+```
+top -b --running=1
+```
+
+Note that some versions of Linux may require you to set up a user account or enable logging before running the `top` command.
+````

### Example script

To test the online DPO script with 1B parameter models, run:

```bash
@@ -78,18 +109,7 @@ Tips:
* `objective/rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up.
* We recommend using the "EOS trick" via the `--missing_eos_penalty` argument, which subtracts a fixed scalar penalty from the reward of any completion that does not end with an EOS token; a configuration sketch follows this list. This can help the model learn to generate more coherent completions.
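
A minimal configuration sketch (the penalty value `1.0` is an illustrative choice, not a tuned default):

```python
# Sketch: enable the EOS trick by penalizing completions without an EOS token.
from trl import OnlineDPOConfig

args = OnlineDPOConfig(
    output_dir="online-dpo-qwen2",
    missing_eos_penalty=1.0,  # subtracted from the reward when no EOS is generated
)
```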

-### Expected dataset format
-
-Unlike offline DPO, where one provides a dataset with chosen and rejected columns, online DPO only requires a dataset of prompts to generate the completions from. The [`OnlineDPOTrainer`] assumes that the dataset is preprocessed for model inference, so typically you will need to wrap your prompts in the messages format and then apply the chat template as follows:
-
-```python
-def prepare_dataset(row):
-    """Apply chat template to messages"""
-    row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False, add_generation_prompt=True)
-    return row
-
-dataset = prepare_dataset(dataset)
-```

### Explanation of the logged metrics

@@ -112,12 +132,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an

## What is my model doing exactly?

-To help you understand what your model is doing, we periodically log some sample completions from the model via [`LogCompletionsCallback`]. You can find an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/hlzevfro?nw=nwuserlewtun), which allows you to see the model's response at different stages of training. By default we generate during training, but you can customize the number of prompts to generate completions for in [`LogCompletionsCallback`].
-
-## Implementation details
-
-Many online implementation details are borrowed from the [`PPOv2Trainer`], which is itself based on [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
+To help you understand what your model is doing, we periodically log some sample completions from the model via [`LogCompletionsCallback`]. You can find an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/hlzevfro), which allows you to see the model's response at different stages of training. By default we generate during training, but you can customize the number of prompts to generate completions for in [`LogCompletionsCallback`], as sketched below.
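
As a sketch of how the callback can be attached, following the pattern of TRL's example scripts (treat the exact constructor signature as an assumption to verify against your TRL version):

```python
# Sketch: log completions for a handful of prompts during training.
# `trainer` is the OnlineDPOTrainer from the quick start above.
from transformers import GenerationConfig
from trl import LogCompletionsCallback

generation_config = GenerationConfig(max_new_tokens=64)  # illustrative settings
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
trainer.add_callback(completions_callback)
```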


## Benchmark experiments
