diff --git a/README.md b/README.md
index 3ac69fafff..04af6a195e 100644
--- a/README.md
+++ b/README.md
@@ -103,15 +103,12 @@ from datasets import load_dataset
 
 dataset = load_dataset("trl-lib/Capybara", split="train")
 
-# configure trainer
 training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
 trainer = SFTTrainer(
     args=training_args,
     model="Qwen/Qwen2.5-0.5B",
     train_dataset=dataset,
 )
-
-# train
 trainer.train()
 ```
 
@@ -121,7 +118,6 @@ Here is a basic example on how to use the `RewardTrainer`:
 
 ```python
 from trl import RewardConfig, RewardTrainer
-from trl.extras.dataset_formatting import conversations_formatting_function
 from datasets import load_dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
@@ -131,48 +127,15 @@ model = AutoModelForSequenceClassification.from_pretrained(
 )
 model.config.pad_token_id = tokenizer.pad_token_id
 
-dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
-
-def preprocess_function(examples):
-    new_examples = {
-        "input_ids_chosen": [],
-        "attention_mask_chosen": [],
-        "input_ids_rejected": [],
-        "attention_mask_rejected": [],
-    }
-    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
-        tokenized_chosen = tokenizer(chosen)
-        tokenized_rejected = tokenizer(rejected)
-        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
-        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
-        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
-        new_examples["attention_mask_rejected"].append(
-            tokenized_rejected["attention_mask"]
-        )
-
-    return new_examples
-
-chosen_fn = conversations_formatting_function(tokenizer, "chosen")
-rejected_fn = conversations_formatting_function(tokenizer, "rejected")
-dataset = dataset.map(lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)})
-dataset = dataset.map(
-    preprocess_function,
-    batched=True,
-)
+dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
 
-training_args = RewardConfig(
-    per_device_train_batch_size=2,
-    remove_unused_columns=False,
-    output_dir="Qwen2.5-0.5B-Reward",
-)
+training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
 trainer = RewardTrainer(
     args=training_args,
     model=model,
     tokenizer=tokenizer,
     train_dataset=dataset,
 )
-
-# train
 trainer.train()
 ```
 
@@ -210,7 +173,6 @@ trainer = RLOOTrainer(
     train_dataset=dataset["train"],
     eval_dataset=dataset["test"],
 )
-# train
 trainer.train()
 ```
 
@@ -219,21 +181,17 @@ trainer.train()
 
 `DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models.
 
 Here is a basic example on how to use the `DPOTrainer`:
 
 ```python
-# imports
 from trl import DPOConfig, DPOTrainer, maybe_extract_prompt, maybe_apply_chat_template
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# load model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
 model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
 
-# load preference dataset - needs to be in a specific format
 dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
 dataset = dataset.map(maybe_extract_prompt)
 dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
 
-# load trainer
 training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
 trainer = DPOTrainer(
     args=training_args,
@@ -241,8 +199,6 @@ trainer = DPOTrainer(
     tokenizer=tokenizer,
     train_dataset=dataset,
 )
-
-# train
 trainer.train()
 ```
 
diff --git a/docs/source/reward_trainer.mdx b/docs/source/reward_trainer.mdx
index 5a73217ead..42bd487bac 100644
--- a/docs/source/reward_trainer.mdx
+++ b/docs/source/reward_trainer.mdx
@@ -6,18 +6,10 @@ Check out a complete flexible example at [`examples/scripts/reward_modeling.py`]
 
 ## Expected dataset format
 
-The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
+The [`RewardTrainer`] requires an [*implicit prompt* preference dataset](dataset_formats#preference). This means that the dataset should only contain the columns `chosen` and `rejected` (and not `prompt`).
+The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
 
-
-
-
-
-Therefore the final dataset object should contain two 4 entries at least if you use the default [`RewardDataCollatorWithPadding`] data collator. The entries should be named:
-
-- `input_ids_chosen`
-- `attention_mask_chosen`
-- `input_ids_rejected`
-- `attention_mask_rejected`
+You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
 
 ## Using the `RewardTrainer`
 
diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py
index bbb2e23459..44b7c00729 100644
--- a/examples/scripts/reward_modeling.py
+++ b/examples/scripts/reward_modeling.py
@@ -19,8 +19,6 @@
     --output_dir Qwen2-0.5B-Reward \
     --per_device_train_batch_size 8 \
     --num_train_epochs 1 \
-    --gradient_accumulation_steps 1 \
-    --remove_unused_columns False \
     --gradient_checkpointing True \
     --learning_rate 1.0e-5 \
     --logging_steps 25 \
@@ -32,17 +30,15 @@
 python examples/scripts/reward_modeling.py \
     --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
     --dataset_name trl-lib/ultrafeedback_binarized \
-    --output_dir Qwen2-0.5B-Reward \
+    --output_dir Qwen2-0.5B-Reward-LoRA \
     --per_device_train_batch_size 8 \
     --num_train_epochs 1 \
-    --gradient_accumulation_steps 1 \
-    --remove_unused_columns False \
     --gradient_checkpointing True \
-    --learning_rate 1.0e-5 \
+    --learning_rate 1.0e-4 \
     --logging_steps 25 \
     --eval_strategy steps \
     --eval_steps 50 \
-    --max_length 2048 /
+    --max_length 2048 \
     --use_peft \
     --lora_r 32 \
     --lora_alpha 16
@@ -51,9 +47,7 @@
 import warnings
 
 import torch
-from accelerate import PartialState
 from datasets import load_dataset
-from tqdm import tqdm
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
 
 from trl import (
@@ -66,10 +60,6 @@
     setup_chat_format,
 )
 from trl.commands.cli_utils import RewardScriptArguments
-from trl.extras.dataset_formatting import conversations_formatting_function
-
-
-tqdm.pandas()
 
 
 if __name__ == "__main__":
@@ -90,6 +80,7 @@
         revision=model_config.model_revision,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
+        use_cache=False if training_args.gradient_checkpointing else True,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
@@ -110,49 +101,11 @@
             " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
         )
 
-    #############################
-    # Load and preprocess dataset
-    #############################
+    ##############
+    # Load dataset
+    ##############
     dataset = load_dataset(args.dataset_name)
 
-    def preprocess_function(examples):
-        new_examples = {
-            "input_ids_chosen": [],
-            "attention_mask_chosen": [],
-            "input_ids_rejected": [],
-            "attention_mask_rejected": [],
-        }
-        for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
-            tokenized_chosen = tokenizer(chosen)
-            tokenized_rejected = tokenizer(rejected)
-            new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
-            new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
-            new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
-            new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
-
-        return new_examples
-
-    with PartialState().local_main_process_first():
-        # Wrap inputs with chat template.
-        # This assumes the chosen/rejected columns are in the OpenAI messages format.
-        chosen_fn = conversations_formatting_function(tokenizer, "chosen")
-        rejected_fn = conversations_formatting_function(tokenizer, "rejected")
-        dataset = dataset.map(
-            lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=training_args.dataset_num_proc
-        )
-        # Tokenize inputs
-        dataset = dataset.map(
-            preprocess_function,
-            batched=True,
-            num_proc=training_args.dataset_num_proc,
-        )
-        # Filter out examples that are too long
-        dataset = dataset.filter(
-            lambda x: len(x["input_ids_chosen"]) <= training_args.max_length
-            and len(x["input_ids_rejected"]) <= training_args.max_length,
-            num_proc=training_args.dataset_num_proc,
-        )
-
     ##########
     # Training
     ##########
diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index 9e37502ce7..2a9ce8cea8 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -14,103 +14,96 @@
 import tempfile
 import unittest
 
-import pytest
 import torch
-from datasets import Dataset
+from datasets import Dataset, load_dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
 from transformers.testing_utils import require_peft
+from transformers.utils import is_peft_available
 
-from trl import RewardConfig, RewardTrainer
+from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template
 from trl.trainer import compute_accuracy
+from trl.trainer.reward_trainer import _tokenize
+
+
+if is_peft_available():
+    from peft import LoraConfig, TaskType
 
 
 class RewardTrainerTester(unittest.TestCase):
+    def setUp(self):
+        self.model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
+
     def test_accuracy_metrics(self):
         dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0]))
         accuracy = compute_accuracy(dummy_eval_predictions)
         assert accuracy["accuracy"] == 0.5
 
-    def test_reward_trainer(self):
-        model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
-        model = AutoModelForSequenceClassification.from_pretrained(model_id)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
-
+    def test_preprocessing_conversational(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = RewardConfig(
-                output_dir=tmp_dir,
-                per_device_train_batch_size=2,
-                max_steps=3,
-                remove_unused_columns=False,
-                gradient_accumulation_steps=4,
-                learning_rate=9e-1,
-                eval_strategy="steps",
-                report_to="none",
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+            training_args = RewardConfig(output_dir=tmp_dir, report_to="none")
+            trainer = RewardTrainer(
+                model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
             )
+            dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
+            dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
+            self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])
 
-            # fmt: off
-            dummy_dataset_dict = {
-                "input_ids_chosen": [
-                    torch.LongTensor([0, 1, 2]),
-                    torch.LongTensor([1, 2]),
-                    torch.LongTensor([0, 1, 2]),
-                    torch.LongTensor([1, 2]),
-                ],
-                "attention_mask_chosen": [
-                    torch.LongTensor([1, 1, 1]),
-                    torch.LongTensor([1, 0]),
-                    torch.LongTensor([1, 1, 1]),
-                    torch.LongTensor([1, 0]),
-                ],
-                "input_ids_rejected": [
-                    torch.LongTensor([0, 2]),
-                    torch.LongTensor([1, 2, 0]),
-                    torch.LongTensor([0, 2]),
-                    torch.LongTensor([1, 2, 0]),
-                ],
-                "attention_mask_rejected": [
-                    torch.LongTensor([1, 1]),
-                    torch.LongTensor([1, 1, 0]),
-                    torch.LongTensor([1, 1]),
-                    torch.LongTensor([1, 1, 1]),
-                ],
-            }
-            # fmt: on
-            dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+    def test_preprocessing_standard(self):
+        # No chat template, so we load a fresh tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
+            training_args = RewardConfig(output_dir=tmp_dir, report_to="none")
             trainer = RewardTrainer(
-                model=model,
-                args=training_args,
-                tokenizer=tokenizer,
-                train_dataset=dummy_dataset,
-                eval_dataset=dummy_dataset,
+                model=self.model, args=training_args, tokenizer=tokenizer, train_dataset=dummy_dataset
             )
+            dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer})
+            self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])
 
+    def test_train_full(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+            training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
+            trainer = RewardTrainer(
+                model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
+            )
             previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
             trainer.train()
-            assert trainer.state.log_history[(-1)]["train_loss"] is not None
-
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
             # check the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 # check the params have changed - ignore 0 biases
                 if param.sum() != 0:
-                    assert not torch.equal(param, new_param)
-
-            preds = trainer.predict(dummy_dataset)
-            assert preds.predictions.shape == (4, 2)
+                    self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
 
-    @require_peft
-    def test_reward_trainer_peft(self):
-        from peft import LoraConfig, TaskType
+    def test_train_full_pretokenized(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+            dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
+            dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
+            training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
+            trainer = RewardTrainer(
+                model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
+            )
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+            trainer.train()
 
-        model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
-        model = AutoModelForSequenceClassification.from_pretrained(model_id)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
 
+    @require_peft
+    def test_train_lora(self):
         peft_config = LoraConfig(
             task_type=TaskType.SEQ_CLS,
             inference_mode=False,
@@ -118,55 +111,14 @@ def test_reward_trainer_peft(self):
             lora_alpha=32,
             lora_dropout=0.1,
         )
-
         with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = RewardConfig(
-                output_dir=tmp_dir,
-                per_device_train_batch_size=2,
-                max_steps=6,
-                remove_unused_columns=False,
-                gradient_accumulation_steps=2,
-                learning_rate=9e-1,
-                eval_strategy="steps",
-                report_to="none",
-            )
-
-            # fmt: off
-            dummy_dataset_dict = {
-                "input_ids_chosen": [
-                    torch.LongTensor([0, 1, 2]),
-                    torch.LongTensor([1, 2]),
-                    torch.LongTensor([0, 1, 2]),
-                    torch.LongTensor([1, 2]),
-                ],
-                "attention_mask_chosen": [
-                    torch.LongTensor([1, 1, 1]),
-                    torch.LongTensor([1, 0]),
-                    torch.LongTensor([1, 1, 1]),
-                    torch.LongTensor([1, 0]),
-                ],
-                "input_ids_rejected": [
-                    torch.LongTensor([0, 2]),
-                    torch.LongTensor([1, 2, 0]),
-                    torch.LongTensor([0, 2]),
-                    torch.LongTensor([1, 2, 0]),
-                ],
-                "attention_mask_rejected": [
-                    torch.LongTensor([1, 1]),
-                    torch.LongTensor([1, 1, 0]),
-                    torch.LongTensor([1, 1]),
-                    torch.LongTensor([1, 1, 1]),
-                ],
-            }
-            # fmt: on
-            dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+            training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
             trainer = RewardTrainer(
-                model=model,
+                model=self.model,
                 args=training_args,
-                tokenizer=tokenizer,
+                tokenizer=self.tokenizer,
                 train_dataset=dummy_dataset,
-                eval_dataset=dummy_dataset,
                 peft_config=peft_config,
             )
             previous_trainable_params = {}
@@ -184,111 +136,68 @@ def test_reward_trainer_peft(self):
             trainer.train()
 
-            assert trainer.state.log_history[(-1)]["train_loss"] is not None
+            self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"])
 
             # check the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
-                assert not torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
+                self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
 
             # check the non trainable params have not changed
             for n, param in previous_non_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
-                assert torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
-
-            preds = trainer.predict(dummy_dataset)
-            assert preds.predictions.shape == (4, 2)
-
-    def test_reward_trainer_assert_value_error(self):
-        model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
-        model = AutoModelForSequenceClassification.from_pretrained(model_id)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
+                self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
 
+    @require_peft
+    def test_train_lora_pretokenized(self):
+        peft_config = LoraConfig(
+            task_type=TaskType.SEQ_CLS,
+            inference_mode=False,
+            r=8,
+            lora_alpha=32,
+            lora_dropout=0.1,
+        )
         with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = RewardConfig(
-                output_dir=tmp_dir,
-                per_device_train_batch_size=2,
-                max_steps=1,
-                remove_unused_columns=False,
-                report_to="none",
-            )
-
-            # fmt: off
-            dummy_dataset_dict = {
-                "input_ids_b": [
-                    torch.LongTensor([0, 1, 2]),
-                    torch.LongTensor([1, 2]),
-                    torch.LongTensor([0, 1, 2]),
-                    torch.LongTensor([1, 2]),
-                ],
-                "attention_mask_c": [
-                    torch.LongTensor([1, 1, 1]),
-                    torch.LongTensor([1, 0]),
-                    torch.LongTensor([1, 1, 1]),
-                    torch.LongTensor([1, 0]),
-                ],
-                "input_ids_f": [
-                    torch.LongTensor([0, 2]),
-                    torch.LongTensor([1, 2, 0]),
-                    torch.LongTensor([0, 2]),
-                    torch.LongTensor([1, 2, 0]),
-                ],
-                "attention_mask_g": [
-                    torch.LongTensor([1, 1]),
-                    torch.LongTensor([1, 1, 0]),
-                    torch.LongTensor([1, 1]),
-                    torch.LongTensor([1, 1, 1]),
-                ],
-            }
-            # fmt: on
-            dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+            dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
+            dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
+            training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
             trainer = RewardTrainer(
-                model=model,
+                model=self.model,
                 args=training_args,
-                tokenizer=tokenizer,
+                tokenizer=self.tokenizer,
                 train_dataset=dummy_dataset,
+                peft_config=peft_config,
             )
+            previous_trainable_params = {}
+            previous_non_trainable_params = {}
 
-            with pytest.raises(ValueError):
-                trainer.train()
+            # due to a change in the way the modules to save are dealt with in PEFT.
+            trainable_params_name = ["lora", "modules_to_save"]
 
-            training_args = RewardConfig(
-                output_dir=tmp_dir,
-                per_device_train_batch_size=2,
-                max_steps=1,
-                remove_unused_columns=True,
-                report_to="none",
-            )
+            # check gradients are not None
+            for n, param in trainer.model.named_parameters():
+                if any(t in n for t in trainable_params_name):
+                    previous_trainable_params[n] = param.clone()
+                else:
+                    previous_non_trainable_params[n] = param.clone()
 
-            with self.assertWarns(UserWarning):
-                trainer = RewardTrainer(
-                    model=model,
-                    args=training_args,
-                    tokenizer=tokenizer,
-                    train_dataset=dummy_dataset,
-                )
+            trainer.train()
 
-    def test_reward_trainer_margin(self):
-        model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
-        model = AutoModelForSequenceClassification.from_pretrained(model_id)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
+            self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"])
 
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = RewardConfig(
-                output_dir=tmp_dir,
-                per_device_train_batch_size=2,
-                max_steps=3,
-                remove_unused_columns=False,
-                gradient_accumulation_steps=4,
-                learning_rate=9e-1,
-                eval_strategy="steps",
-                report_to="none",
-            )
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+
+            # check the non trainable params have not changed
+            for n, param in previous_non_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
 
-            # fmt: off
+    def test_margin(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
             dummy_dataset_dict = {
                 "input_ids_chosen": [
                     torch.LongTensor([0, 1, 2]),
@@ -304,17 +213,12 @@ def test_reward_trainer_margin(self):
                 ],
                 "margin": [
                     torch.FloatTensor([1.0]),
-                ]
+                ],
             }
-            # fmt: on
             dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+            training_args = RewardConfig(output_dir=tmp_dir, report_to="none")
             trainer = RewardTrainer(
-                model=model,
-                args=training_args,
-                tokenizer=tokenizer,
-                train_dataset=dummy_dataset,
-                eval_dataset=dummy_dataset,
+                model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
             )
 
             batch = [dummy_dataset[0]]
@@ -326,62 +230,13 @@ def test_reward_trainer_margin(self):
                 outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"]
             ).mean()
 
-            assert abs(loss - l_val) < 1e-6
-
-    def test_reward_trainer_tags(self):
-        model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
-        model = AutoModelForSequenceClassification.from_pretrained(model_id)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
+            self.assertLess(abs(loss - l_val), 1e-6)
 
+    def test_tags(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = RewardConfig(
-                output_dir=tmp_dir,
-                per_device_train_batch_size=2,
-                max_steps=3,
-                remove_unused_columns=False,
-                gradient_accumulation_steps=4,
-                learning_rate=9e-1,
-                eval_strategy="steps",
-                report_to="none",
-            )
-
-            # fmt: off
-            dummy_dataset_dict = {
-                "input_ids_chosen": [
-                    torch.LongTensor([0, 1, 2]),
-                    torch.LongTensor([1, 2]),
-                    torch.LongTensor([0, 1, 2]),
-                    torch.LongTensor([1, 2]),
-                ],
-                "attention_mask_chosen": [
-                    torch.LongTensor([1, 1, 1]),
-                    torch.LongTensor([1, 0]),
-                    torch.LongTensor([1, 1, 1]),
-                    torch.LongTensor([1, 0]),
-                ],
-                "input_ids_rejected": [
-                    torch.LongTensor([0, 2]),
-                    torch.LongTensor([1, 2, 0]),
-                    torch.LongTensor([0, 2]),
-                    torch.LongTensor([1, 2, 0]),
-                ],
-                "attention_mask_rejected": [
-                    torch.LongTensor([1, 1]),
-                    torch.LongTensor([1, 1, 0]),
-                    torch.LongTensor([1, 1]),
-                    torch.LongTensor([1, 1, 1]),
-                ],
-            }
-            # fmt: on
-            dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
-
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+            training_args = RewardConfig(output_dir=tmp_dir, report_to="none")
             trainer = RewardTrainer(
-                model=model,
-                args=training_args,
-                tokenizer=tokenizer,
-                train_dataset=dummy_dataset,
-                eval_dataset=dummy_dataset,
+                model=self.model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset
             )
-
-            assert trainer.model.model_tags == trainer._tag_names
+            self.assertEqual(trainer.model.model_tags, trainer._tag_names)
diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py
index 8eaa0bdcba..6e3eeab372 100644
--- a/trl/trainer/reward_config.py
+++ b/trl/trainer/reward_config.py
@@ -36,8 +36,12 @@ class RewardConfig(TrainingArguments):
         center_rewards_coefficient (`float`, *optional*, defaults to `None`):
             Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
             https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
+        remove_unused_columns (`bool`, *optional*, defaults to `False`):
+            Whether or not to remove the columns that are not used by the model's forward pass. Can be `True` only if
+            the dataset is pretokenized.
""" max_length: Optional[int] = None dataset_num_proc: Optional[int] = None center_rewards_coefficient: Optional[float] = None + remove_unused_columns: bool = False diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 68f0f177d8..9dd8472235 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -21,6 +21,7 @@ import pandas as pd import torch import torch.nn as nn +from accelerate import PartialState from accelerate.utils import gather_object from datasets import Dataset from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments @@ -29,6 +30,7 @@ from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available +from ..data_utils import maybe_apply_chat_template from .reward_config import RewardConfig from .utils import ( RewardDataCollatorWithPadding, @@ -43,26 +45,26 @@ from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training -class RewardTrainer(Trainer): - r""" - The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the - `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use - an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset - of paired examples, where each example is a tuple of two sequences. The reward model should be trained to - predict which example in the pair is more relevant to the task at hand. - - The reward trainer expects a very specific format for the dataset. The dataset should contain two 4 entries at least - if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named - - `input_ids_chosen` - - `attention_mask_chosen` - - `input_ids_rejected` - - `attention_mask_rejected` - - Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the - loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/. - If you don't pass a margin, no margin will be used. - """ +def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase") -> Dict[str, List[Any]]: + """Tokenize a batch from a reward modelling dataset.""" + new_examples = { + "input_ids_chosen": [], + "attention_mask_chosen": [], + "input_ids_rejected": [], + "attention_mask_rejected": [], + } + for chosen, rejected in zip(batch["chosen"], batch["rejected"]): + tokenized_chosen = tokenizer(chosen) + tokenized_rejected = tokenizer(rejected) + new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) + new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"]) + new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) + new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"]) + + return new_examples + +class RewardTrainer(Trainer): _tag_names = ["trl", "reward-trainer"] def __init__( @@ -111,8 +113,6 @@ def __init__( The optimizer and scheduler to use for training. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): The function to use to preprocess the logits before computing the metrics. - max_length (`int`, defaults to `None`): - The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. 
             peft_config (`Dict`, defaults to `None`):
                 The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
         """
@@ -205,6 +205,41 @@ def __init__(
             self.use_reward_data_collator = True
         else:
             self.use_reward_data_collator = False
+
+        if "input_ids_chosen" not in train_dataset.column_names:
+            with PartialState().local_main_process_first():
+                fn_kwargs = {"tokenizer": tokenizer}
+                train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
+                train_dataset = train_dataset.map(
+                    _tokenize,
+                    batched=True,
+                    fn_kwargs=fn_kwargs,
+                    num_proc=args.dataset_num_proc,
+                )
+                # This filter is important because otherwise you get samples that exceed the model's context length and
+                # get truncated => noisy signal as the chosen/rejected label gets lost. The downside is that the
+                # user might get surprised if N samples are missing from training.
+                train_dataset = train_dataset.filter(
+                    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
+                    num_proc=args.dataset_num_proc,
+                )
+                if eval_dataset is not None:
+                    eval_dataset = eval_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
+                    eval_dataset = eval_dataset.map(
+                        _tokenize,
+                        fn_kwargs=fn_kwargs,
+                        batched=True,
+                        num_proc=args.dataset_num_proc,
+                    )
+                    # This filter is important because otherwise you get samples that exceed the model's context length and
+                    # get truncated => noisy signal as the chosen/rejected label gets lost. The downside is that the
+                    # user might get surprised if N samples are missing from training.
+                    eval_dataset = eval_dataset.filter(
+                        lambda x: len(x["input_ids_chosen"]) <= max_length
+                        and len(x["input_ids_rejected"]) <= max_length,
+                        num_proc=args.dataset_num_proc,
+                    )
+
         super().__init__(
             model=model,
             args=args,
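
For reference, here is a minimal sketch of the preprocessing flow that this patch moves inside `RewardTrainer` (assuming the `maybe_apply_chat_template` helper and the `_tokenize` function shown above; the dataset and tokenizer names are illustrative only):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

from trl import maybe_apply_chat_template
from trl.trainer.reward_trainer import _tokenize

# Illustrative tokenizer; any tokenizer with a chat template works here.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Implicit-prompt preference dataset: only "chosen" and "rejected" columns.
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# 1. Render conversational examples to plain text with the chat template.
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

# 2. Tokenize both completions into the columns the reward data collator expects.
dataset = dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer})

print(dataset.column_names)
# expected to include: input_ids_chosen, attention_mask_chosen,
#                      input_ids_rejected, attention_mask_rejected
```

This is the same sequence `RewardTrainer.__init__` now runs automatically when the dataset is not already pretokenized, followed by the length filter against `max_length`.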