diff --git a/README.md b/README.md
index 3ac69fafff..04af6a195e 100644
--- a/README.md
+++ b/README.md
@@ -103,15 +103,12 @@ from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
-# configure trainer
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
-
-# train
trainer.train()
```
@@ -121,7 +118,6 @@ Here is a basic example on how to use the `RewardTrainer`:
```python
from trl import RewardConfig, RewardTrainer
-from trl.extras.dataset_formatting import conversations_formatting_function
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@@ -131,48 +127,15 @@ model = AutoModelForSequenceClassification.from_pretrained(
)
model.config.pad_token_id = tokenizer.pad_token_id
-dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
-
-def preprocess_function(examples):
- new_examples = {
- "input_ids_chosen": [],
- "attention_mask_chosen": [],
- "input_ids_rejected": [],
- "attention_mask_rejected": [],
- }
- for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
- tokenized_chosen = tokenizer(chosen)
- tokenized_rejected = tokenizer(rejected)
- new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
- new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
- new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
- new_examples["attention_mask_rejected"].append(
- tokenized_rejected["attention_mask"]
- )
-
- return new_examples
-
-chosen_fn = conversations_formatting_function(tokenizer, "chosen")
-rejected_fn = conversations_formatting_function(tokenizer, "rejected")
-dataset = dataset.map(lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)})
-dataset = dataset.map(
- preprocess_function,
- batched=True,
-)
+dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
-training_args = RewardConfig(
- per_device_train_batch_size=2,
- remove_unused_columns=False,
- output_dir="Qwen2.5-0.5B-Reward",
-)
+training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
args=training_args,
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)
-
-# train
trainer.train()
```
@@ -210,7 +173,6 @@ trainer = RLOOTrainer(
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)
-# train
trainer.train()
```
@@ -219,21 +181,17 @@ trainer.train()
`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example on how to use the `DPOTrainer`:
```python
-# imports
from trl import DPOConfig, DPOTrainer, maybe_extract_prompt, maybe_apply_chat_template
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
-# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
-# load preference dataset - needs to be in a specific format
dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
dataset = dataset.map(maybe_extract_prompt)
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
-# load trainer
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
args=training_args,
@@ -241,8 +199,6 @@ trainer = DPOTrainer(
tokenizer=tokenizer,
train_dataset=dataset,
)
-
-# train
trainer.train()
```
diff --git a/docs/source/reward_trainer.mdx b/docs/source/reward_trainer.mdx
index 5a73217ead..42bd487bac 100644
--- a/docs/source/reward_trainer.mdx
+++ b/docs/source/reward_trainer.mdx
@@ -6,18 +6,10 @@ Check out a complete flexible example at [`examples/scripts/reward_modeling.py`]
## Expected dataset format
-The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
+The [`RewardTrainer`] requires an [*implicit prompt* preference dataset](dataset_formats#preference), meaning that the dataset should contain only the columns `chosen` and `rejected` (and not `prompt`).
+The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
-
-
-
-
-Therefore the final dataset object should contain two 4 entries at least if you use the default [`RewardDataCollatorWithPadding`] data collator. The entries should be named:
-
-- `input_ids_chosen`
-- `attention_mask_chosen`
-- `input_ids_rejected`
-- `attention_mask_rejected`
+You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
## Using the `RewardTrainer`
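To make the formats described in this documentation change concrete, here is a small illustrative sketch (not part of the diff itself) of what a single example can look like in the conversational, standard, and pretokenized cases; the strings are made up and the token ids below are dummy values borrowed from the tests:

```python
# Conversational, implicit-prompt preference example: "chosen" and "rejected"
# each contain the full list of messages (no separate "prompt" column).
conversational_example = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
}

# Standard, implicit-prompt preference example: plain strings.
standard_example = {
    "chosen": "The sky is blue.",
    "rejected": "The sky is green.",
}

# Pretokenized example: already contains the four expected columns.
pretokenized_example = {
    "input_ids_chosen": [0, 1, 2],
    "attention_mask_chosen": [1, 1, 1],
    "input_ids_rejected": [0, 2],
    "attention_mask_rejected": [1, 1],
}
```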
diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py
index bbb2e23459..44b7c00729 100644
--- a/examples/scripts/reward_modeling.py
+++ b/examples/scripts/reward_modeling.py
@@ -19,8 +19,6 @@
--output_dir Qwen2-0.5B-Reward \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
- --gradient_accumulation_steps 1 \
- --remove_unused_columns False \
--gradient_checkpointing True \
--learning_rate 1.0e-5 \
--logging_steps 25 \
@@ -32,17 +30,15 @@
python examples/scripts/reward_modeling.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
- --output_dir Qwen2-0.5B-Reward \
+ --output_dir Qwen2-0.5B-Reward-LoRA \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
- --gradient_accumulation_steps 1 \
- --remove_unused_columns False \
--gradient_checkpointing True \
- --learning_rate 1.0e-5 \
+ --learning_rate 1.0e-4 \
--logging_steps 25 \
--eval_strategy steps \
--eval_steps 50 \
- --max_length 2048 /
+ --max_length 2048 \
--use_peft \
--lora_r 32 \
--lora_alpha 16
@@ -51,9 +47,7 @@
import warnings
import torch
-from accelerate import PartialState
from datasets import load_dataset
-from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
from trl import (
@@ -66,10 +60,6 @@
setup_chat_format,
)
from trl.commands.cli_utils import RewardScriptArguments
-from trl.extras.dataset_formatting import conversations_formatting_function
-
-
-tqdm.pandas()
if __name__ == "__main__":
@@ -90,6 +80,7 @@
revision=model_config.model_revision,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
+ use_cache=False if training_args.gradient_checkpointing else True,
)
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
@@ -110,49 +101,11 @@
" Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
)
- #############################
- # Load and preprocess dataset
- #############################
+ ##############
+ # Load dataset
+ ##############
dataset = load_dataset(args.dataset_name)
- def preprocess_function(examples):
- new_examples = {
- "input_ids_chosen": [],
- "attention_mask_chosen": [],
- "input_ids_rejected": [],
- "attention_mask_rejected": [],
- }
- for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
- tokenized_chosen = tokenizer(chosen)
- tokenized_rejected = tokenizer(rejected)
- new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
- new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
- new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
- new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
-
- return new_examples
-
- with PartialState().local_main_process_first():
- # Wrap inputs with chat template.
- # This assumes the chosen/rejected columns are in the OpenAI messages format.
- chosen_fn = conversations_formatting_function(tokenizer, "chosen")
- rejected_fn = conversations_formatting_function(tokenizer, "rejected")
- dataset = dataset.map(
- lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=training_args.dataset_num_proc
- )
- # Tokenize inputs
- dataset = dataset.map(
- preprocess_function,
- batched=True,
- num_proc=training_args.dataset_num_proc,
- )
- # Filter out examples that are too long
- dataset = dataset.filter(
- lambda x: len(x["input_ids_chosen"]) <= training_args.max_length
- and len(x["input_ids_rejected"]) <= training_args.max_length,
- num_proc=training_args.dataset_num_proc,
- )
-
##########
# Training
##########
diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index 9e37502ce7..2a9ce8cea8 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -14,103 +14,96 @@
import tempfile
import unittest
-import pytest
import torch
-from datasets import Dataset
+from datasets import Dataset, load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
from transformers.testing_utils import require_peft
+from transformers.utils import is_peft_available
-from trl import RewardConfig, RewardTrainer
+from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template
from trl.trainer import compute_accuracy
+from trl.trainer.reward_trainer import _tokenize
+
+
+if is_peft_available():
+ from peft import LoraConfig, TaskType
class RewardTrainerTester(unittest.TestCase):
+ def setUp(self):
+ self.model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+ self.tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+ self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
+
def test_accuracy_metrics(self):
dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0]))
accuracy = compute_accuracy(dummy_eval_predictions)
assert accuracy["accuracy"] == 0.5
- def test_reward_trainer(self):
- model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- tokenizer.pad_token = tokenizer.eos_token
-
+ def test_preprocessing_conversational(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = RewardConfig(
- output_dir=tmp_dir,
- per_device_train_batch_size=2,
- max_steps=3,
- remove_unused_columns=False,
- gradient_accumulation_steps=4,
- learning_rate=9e-1,
- eval_strategy="steps",
- report_to="none",
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+ training_args = RewardConfig(output_dir=tmp_dir, report_to="none")
+ trainer = RewardTrainer(
+ model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
)
+ dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
+ dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
+ self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])
- # fmt: off
- dummy_dataset_dict = {
- "input_ids_chosen": [
- torch.LongTensor([0, 1, 2]),
- torch.LongTensor([1, 2]),
- torch.LongTensor([0, 1, 2]),
- torch.LongTensor([1, 2]),
- ],
- "attention_mask_chosen": [
- torch.LongTensor([1, 1, 1]),
- torch.LongTensor([1, 0]),
- torch.LongTensor([1, 1, 1]),
- torch.LongTensor([1, 0]),
- ],
- "input_ids_rejected": [
- torch.LongTensor([0, 2]),
- torch.LongTensor([1, 2, 0]),
- torch.LongTensor([0, 2]),
- torch.LongTensor([1, 2, 0]),
- ],
- "attention_mask_rejected": [
- torch.LongTensor([1, 1]),
- torch.LongTensor([1, 1, 0]),
- torch.LongTensor([1, 1]),
- torch.LongTensor([1, 1, 1]),
- ],
- }
- # fmt: on
- dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+ def test_preprocessing_standard(self):
+ # No chat template, so we load a fresh tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
+ training_args = RewardConfig(output_dir=tmp_dir, report_to="none")
trainer = RewardTrainer(
- model=model,
- args=training_args,
- tokenizer=tokenizer,
- train_dataset=dummy_dataset,
- eval_dataset=dummy_dataset,
+ model=self.model, args=training_args, tokenizer=tokenizer, train_dataset=dummy_dataset
)
+ dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer})
+ self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])
+ def test_train_full(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+ training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
+ trainer = RewardTrainer(
+ model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
+ )
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
-
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.equal(param, new_param)
-
- preds = trainer.predict(dummy_dataset)
- assert preds.predictions.shape == (4, 2)
+ self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
- @require_peft
- def test_reward_trainer_peft(self):
- from peft import LoraConfig, TaskType
+ def test_train_full_pretokenized(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+ dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
+ dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
+ training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
+ trainer = RewardTrainer(
+ model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
+ )
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+ trainer.train()
- model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- tokenizer.pad_token = tokenizer.eos_token
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ # check the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ # check the params have changed - ignore 0 biases
+ if param.sum() != 0:
+ self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+ @require_peft
+ def test_train_lora(self):
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
@@ -118,55 +111,14 @@ def test_reward_trainer_peft(self):
lora_alpha=32,
lora_dropout=0.1,
)
-
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = RewardConfig(
- output_dir=tmp_dir,
- per_device_train_batch_size=2,
- max_steps=6,
- remove_unused_columns=False,
- gradient_accumulation_steps=2,
- learning_rate=9e-1,
- eval_strategy="steps",
- report_to="none",
- )
-
- # fmt: off
- dummy_dataset_dict = {
- "input_ids_chosen": [
- torch.LongTensor([0, 1, 2]),
- torch.LongTensor([1, 2]),
- torch.LongTensor([0, 1, 2]),
- torch.LongTensor([1, 2]),
- ],
- "attention_mask_chosen": [
- torch.LongTensor([1, 1, 1]),
- torch.LongTensor([1, 0]),
- torch.LongTensor([1, 1, 1]),
- torch.LongTensor([1, 0]),
- ],
- "input_ids_rejected": [
- torch.LongTensor([0, 2]),
- torch.LongTensor([1, 2, 0]),
- torch.LongTensor([0, 2]),
- torch.LongTensor([1, 2, 0]),
- ],
- "attention_mask_rejected": [
- torch.LongTensor([1, 1]),
- torch.LongTensor([1, 1, 0]),
- torch.LongTensor([1, 1]),
- torch.LongTensor([1, 1, 1]),
- ],
- }
- # fmt: on
- dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+ training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
trainer = RewardTrainer(
- model=model,
+ model=self.model,
args=training_args,
- tokenizer=tokenizer,
+ tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
- eval_dataset=dummy_dataset,
peft_config=peft_config,
)
previous_trainable_params = {}
@@ -184,111 +136,68 @@ def test_reward_trainer_peft(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- assert not torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
+ self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
# check the non trainable params have not changed
for n, param in previous_non_trainable_params.items():
new_param = trainer.model.get_parameter(n)
- assert torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
-
- preds = trainer.predict(dummy_dataset)
- assert preds.predictions.shape == (4, 2)
-
- def test_reward_trainer_assert_value_error(self):
- model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- tokenizer.pad_token = tokenizer.eos_token
+ self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+ @require_peft
+ def test_train_lora_pretokenized(self):
+ peft_config = LoraConfig(
+ task_type=TaskType.SEQ_CLS,
+ inference_mode=False,
+ r=8,
+ lora_alpha=32,
+ lora_dropout=0.1,
+ )
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = RewardConfig(
- output_dir=tmp_dir,
- per_device_train_batch_size=2,
- max_steps=1,
- remove_unused_columns=False,
- report_to="none",
- )
-
- # fmt: off
- dummy_dataset_dict = {
- "input_ids_b": [
- torch.LongTensor([0, 1, 2]),
- torch.LongTensor([1, 2]),
- torch.LongTensor([0, 1, 2]),
- torch.LongTensor([1, 2]),
- ],
- "attention_mask_c": [
- torch.LongTensor([1, 1, 1]),
- torch.LongTensor([1, 0]),
- torch.LongTensor([1, 1, 1]),
- torch.LongTensor([1, 0]),
- ],
- "input_ids_f": [
- torch.LongTensor([0, 2]),
- torch.LongTensor([1, 2, 0]),
- torch.LongTensor([0, 2]),
- torch.LongTensor([1, 2, 0]),
- ],
- "attention_mask_g": [
- torch.LongTensor([1, 1]),
- torch.LongTensor([1, 1, 0]),
- torch.LongTensor([1, 1]),
- torch.LongTensor([1, 1, 1]),
- ],
- }
- # fmt: on
- dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+ dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
+ dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
+ training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
trainer = RewardTrainer(
- model=model,
+ model=self.model,
args=training_args,
- tokenizer=tokenizer,
+ tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
+ peft_config=peft_config,
)
+ previous_trainable_params = {}
+ previous_non_trainable_params = {}
- with pytest.raises(ValueError):
- trainer.train()
+ # "modules_to_save" is included due to a change in how modules to save are handled in PEFT.
+ trainable_params_name = ["lora", "modules_to_save"]
- training_args = RewardConfig(
- output_dir=tmp_dir,
- per_device_train_batch_size=2,
- max_steps=1,
- remove_unused_columns=True,
- report_to="none",
- )
+ # snapshot trainable and non-trainable params before training
+ for n, param in trainer.model.named_parameters():
+ if any(t in n for t in trainable_params_name):
+ previous_trainable_params[n] = param.clone()
+ else:
+ previous_non_trainable_params[n] = param.clone()
- with self.assertWarns(UserWarning):
- trainer = RewardTrainer(
- model=model,
- args=training_args,
- tokenizer=tokenizer,
- train_dataset=dummy_dataset,
- )
+ trainer.train()
- def test_reward_trainer_margin(self):
- model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- tokenizer.pad_token = tokenizer.eos_token
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = RewardConfig(
- output_dir=tmp_dir,
- per_device_train_batch_size=2,
- max_steps=3,
- remove_unused_columns=False,
- gradient_accumulation_steps=4,
- learning_rate=9e-1,
- eval_strategy="steps",
- report_to="none",
- )
+ # check the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+
+ # check the non trainable params have not changed
+ for n, param in previous_non_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
- # fmt: off
+ def test_margin(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
dummy_dataset_dict = {
"input_ids_chosen": [
torch.LongTensor([0, 1, 2]),
@@ -304,17 +213,12 @@ def test_reward_trainer_margin(self):
],
"margin": [
torch.FloatTensor([1.0]),
- ]
+ ],
}
- # fmt: on
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+ training_args = RewardConfig(output_dir=tmp_dir, report_to="none")
trainer = RewardTrainer(
- model=model,
- args=training_args,
- tokenizer=tokenizer,
- train_dataset=dummy_dataset,
- eval_dataset=dummy_dataset,
+ model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
)
batch = [dummy_dataset[0]]
@@ -326,62 +230,13 @@ def test_reward_trainer_margin(self):
outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"]
).mean()
- assert abs(loss - l_val) < 1e-6
-
- def test_reward_trainer_tags(self):
- model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- tokenizer.pad_token = tokenizer.eos_token
+ self.assertLess(abs(loss - l_val), 1e-6)
+ def test_tags(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = RewardConfig(
- output_dir=tmp_dir,
- per_device_train_batch_size=2,
- max_steps=3,
- remove_unused_columns=False,
- gradient_accumulation_steps=4,
- learning_rate=9e-1,
- eval_strategy="steps",
- report_to="none",
- )
-
- # fmt: off
- dummy_dataset_dict = {
- "input_ids_chosen": [
- torch.LongTensor([0, 1, 2]),
- torch.LongTensor([1, 2]),
- torch.LongTensor([0, 1, 2]),
- torch.LongTensor([1, 2]),
- ],
- "attention_mask_chosen": [
- torch.LongTensor([1, 1, 1]),
- torch.LongTensor([1, 0]),
- torch.LongTensor([1, 1, 1]),
- torch.LongTensor([1, 0]),
- ],
- "input_ids_rejected": [
- torch.LongTensor([0, 2]),
- torch.LongTensor([1, 2, 0]),
- torch.LongTensor([0, 2]),
- torch.LongTensor([1, 2, 0]),
- ],
- "attention_mask_rejected": [
- torch.LongTensor([1, 1]),
- torch.LongTensor([1, 1, 0]),
- torch.LongTensor([1, 1]),
- torch.LongTensor([1, 1, 1]),
- ],
- }
- # fmt: on
- dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+ training_args = RewardConfig(output_dir=tmp_dir, report_to="none")
trainer = RewardTrainer(
- model=model,
- args=training_args,
- tokenizer=tokenizer,
- train_dataset=dummy_dataset,
- eval_dataset=dummy_dataset,
+ model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
)
-
- assert trainer.model.model_tags == trainer._tag_names
+ self.assertEqual(trainer.model.model_tags, trainer._tag_names)
diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py
index 8eaa0bdcba..6e3eeab372 100644
--- a/trl/trainer/reward_config.py
+++ b/trl/trainer/reward_config.py
@@ -36,8 +36,12 @@ class RewardConfig(TrainingArguments):
center_rewards_coefficient (`float`, *optional*, defaults to `None`):
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
+ Whether or not to remove the columns that are not used by the model's forward pass. Can be `True` only if
+ the dataset is pretokenized.
"""
max_length: Optional[int] = None
dataset_num_proc: Optional[int] = None
center_rewards_coefficient: Optional[float] = None
+ remove_unused_columns: bool = False
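As a brief usage sketch of the option added above (illustrative only, using just the `RewardConfig` fields visible in this diff): keeping the new default lets the trainer tokenize raw `chosen`/`rejected` columns internally, while a pretokenized dataset may re-enable column pruning.

```python
from trl import RewardConfig

# Default after this change: keep all columns so the trainer can apply the chat
# template and tokenize the "chosen"/"rejected" columns itself.
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward")
assert training_args.remove_unused_columns is False

# If the dataset already contains input_ids_chosen / attention_mask_chosen /
# input_ids_rejected / attention_mask_rejected, column pruning is safe again.
training_args_pretokenized = RewardConfig(
    output_dir="Qwen2.5-0.5B-Reward",
    remove_unused_columns=True,
)
```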
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 68f0f177d8..9dd8472235 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -21,6 +21,7 @@
import pandas as pd
import torch
import torch.nn as nn
+from accelerate import PartialState
from accelerate.utils import gather_object
from datasets import Dataset
from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
@@ -29,6 +30,7 @@
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
+from ..data_utils import maybe_apply_chat_template
from .reward_config import RewardConfig
from .utils import (
RewardDataCollatorWithPadding,
@@ -43,26 +45,26 @@
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
-class RewardTrainer(Trainer):
- r"""
- The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
- `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use
- an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset
- of paired examples, where each example is a tuple of two sequences. The reward model should be trained to
- predict which example in the pair is more relevant to the task at hand.
-
- The reward trainer expects a very specific format for the dataset. The dataset should contain two 4 entries at least
- if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named
- - `input_ids_chosen`
- - `attention_mask_chosen`
- - `input_ids_rejected`
- - `attention_mask_rejected`
-
- Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the
- loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/.
- If you don't pass a margin, no margin will be used.
- """
+def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase") -> Dict[str, List[Any]]:
+ """Tokenize a batch from a reward modelling dataset."""
+ new_examples = {
+ "input_ids_chosen": [],
+ "attention_mask_chosen": [],
+ "input_ids_rejected": [],
+ "attention_mask_rejected": [],
+ }
+ for chosen, rejected in zip(batch["chosen"], batch["rejected"]):
+ tokenized_chosen = tokenizer(chosen)
+ tokenized_rejected = tokenizer(rejected)
+ new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
+ new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
+ new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
+ new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
+
+ return new_examples
+
+
+class RewardTrainer(Trainer):
_tag_names = ["trl", "reward-trainer"]
def __init__(
@@ -111,8 +113,6 @@ def __init__(
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
- max_length (`int`, defaults to `None`):
- The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
peft_config (`Dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
"""
@@ -205,6 +205,41 @@ def __init__(
self.use_reward_data_collator = True
else:
self.use_reward_data_collator = False
+
+ if "input_ids_chosen" not in train_dataset.column_names:
+ with PartialState().local_main_process_first():
+ fn_kwargs = {"tokenizer": tokenizer}
+ train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
+ train_dataset = train_dataset.map(
+ _tokenize,
+ batched=True,
+ fn_kwargs=fn_kwargs,
+ num_proc=args.dataset_num_proc,
+ )
+ # This filter is important because otherwise you get samples that exceed the model's context length and
+ # get truncated => noisy signal, as the chosen/rejected label is effectively lost. The downside is that
+ # the user might be surprised if N samples are missing from training.
+ train_dataset = train_dataset.filter(
+ lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
+ num_proc=args.dataset_num_proc,
+ )
+ if eval_dataset is not None:
+ eval_dataset = eval_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
+ eval_dataset = eval_dataset.map(
+ _tokenize,
+ fn_kwargs=fn_kwargs,
+ batched=True,
+ num_proc=args.dataset_num_proc,
+ )
+ # This filter is important because otherwise you get samples that exceed the model's context length and
+ # get truncated => noisy signal, as the chosen/rejected label is effectively lost. The downside is that
+ # the user might be surprised if N samples are missing from evaluation.
+ eval_dataset = eval_dataset.filter(
+ lambda x: len(x["input_ids_chosen"]) <= max_length
+ and len(x["input_ids_rejected"]) <= max_length,
+ num_proc=args.dataset_num_proc,
+ )
+
super().__init__(
model=model,
args=args,
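For reference, the preprocessing path added to `__init__` above can be reproduced outside the trainer with the same helpers; this is an illustrative sketch that mirrors the new test code, with the dataset, tokenizer, and `max_length` value taken as placeholders from the examples elsewhere in this diff:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

from trl import maybe_apply_chat_template
from trl.trainer.reward_trainer import _tokenize

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# 1. Render conversational "chosen"/"rejected" columns into plain text.
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

# 2. Tokenize into input_ids_* / attention_mask_* columns.
dataset = dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer})

# 3. Drop pairs that would be truncated, since truncation loses the chosen/rejected signal.
max_length = 2048
dataset = dataset.filter(
    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length
)
```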