diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 11795be496..3115fab292 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -29,6 +29,8 @@ title: DPO Trainer - local: ddpo_trainer title: Denoising Diffusion Policy Optimization + - local: iterative_sft_trainer + title: Iterative Supervised Fine-Tuning - local: text_environments title: Text Environments title: API diff --git a/docs/source/iterative_sft_trainer.mdx b/docs/source/iterative_sft_trainer.mdx new file mode 100644 index 0000000000..a6eaf5c98f --- /dev/null +++ b/docs/source/iterative_sft_trainer.mdx @@ -0,0 +1,54 @@ +# Iterative Trainer + +Iterative fine-tuning is a training method that lets you perform custom actions (generation and filtering for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code. + +## Usage + +To get started quickly, instantiate a model and a tokenizer. + +```python + +model = AutoModelForCausalLM.from_pretrained(model_name) +tokenizer = AutoTokenizer.from_pretrained(model_name) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +trainer = IterativeSFTTrainer( + model, + tokenizer=tokenizer +) + +``` + +You have the choice to either provide a list of strings or a list of tensors to the step function. + +#### Using a list of tensors as input: + +```python + +inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask +} + +trainer.step(**inputs) + +``` + +#### Using a list of strings as input: + +```python + +inputs = { + "texts": texts +} + +trainer.step(**inputs) + +``` + +For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or texts_labels. 
+ +## IterativeTrainer + +[[autodoc]] IterativeSFTTrainer diff --git a/docs/source/trainer.mdx b/docs/source/trainer.mdx index bec27c7970..0d2550a6b1 100644 --- a/docs/source/trainer.mdx +++ b/docs/source/trainer.mdx @@ -36,6 +36,10 @@ We also support a `RewardTrainer` that can be used to train a reward model. [[autodoc]] DDPOTrainer +## IterativeSFTTrainer + +[[autodoc]] IterativeSFTTrainer + ## set_seed [[autodoc]] set_seed diff --git a/tests/test_iterative_sft_trainer.py b/tests/test_iterative_sft_trainer.py new file mode 100644 index 0000000000..70d5640795 --- /dev/null +++ b/tests/test_iterative_sft_trainer.py @@ -0,0 +1,106 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import tempfile +import unittest + +import torch +from datasets import Dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments + +from trl import IterativeSFTTrainer + + +class IterativeTrainerTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab" + cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + def _init_tensor_dummy_dataset(self): + # three pre-tokenized samples of unequal length (2, 3 and 2 token ids); labels are identical to input_ids + dummy_dataset_dict = { + "input_ids": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])], + "attention_mask": [torch.tensor([1, 1]), torch.tensor([1, 1, 1]), torch.tensor([1, 1])], + "labels": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])], + } + + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + dummy_dataset.set_format("torch") + return dummy_dataset + + def _init_textual_dummy_dataset(self): + # two raw-text samples; each label string is identical to its input string + dummy_dataset_dict = { + "texts": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"], + "texts_labels": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"], + } + + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + dummy_dataset.set_format("torch") + return dummy_dataset + + def setUp(self): + # put the shared GPT-2 model (loaded once in setUpClass) into training mode before each test + self.model.train() + return super().setUp() + + @parameterized.expand( + [ + ["gpt2", "tensor"], + ["gpt2", "text"], + ["t5", "tensor"], + ["t5", "text"], + ] + ) + def test_iterative_step_from_tensor(self, model_name, input_name): + # despite its name, this test is parameterized over both tensor and text inputs, for the GPT-2 and the T5 model + with 
tempfile.TemporaryDirectory() as tmp_dir: + # initialize dataset + if input_name == "tensor": + dummy_dataset = self._init_tensor_dummy_dataset() + inputs = { + "input_ids": dummy_dataset["input_ids"], + "attention_mask": dummy_dataset["attention_mask"], + "labels": dummy_dataset["labels"], + } + else: + dummy_dataset = self._init_textual_dummy_dataset() + inputs = { + "texts": dummy_dataset["texts"], + "texts_labels": dummy_dataset["texts_labels"], + } + + if model_name == "gpt2": + model = self.model + tokenizer = self.tokenizer + else: + model = self.t5_model + tokenizer = self.t5_tokenizer + + # no scheduler is passed to the trainer, so max_steps has to be set explicitly + args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=2, + ) + iterative_trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer) + + # a single call to step with the whole dummy dataset + iterative_trainer.step(**inputs) + + # the test expects every model parameter to hold a gradient once step has returned + for param in iterative_trainer.model.parameters(): + assert param.grad is not None diff --git a/trl/__init__.py b/trl/__init__.py index 87299c74fd..d28e29e9c8 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -15,6 +15,7 @@ from .trainer import ( DataCollatorForCompletionOnlyLM, DPOTrainer, + IterativeSFTTrainer, PPOConfig, PPOTrainer, RewardConfig, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index b15b7c0ada..e81705fbc2 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -36,6 +36,7 @@ from .ddpo_trainer import DDPOTrainer from .dpo_trainer import DPOTrainer +from .iterative_sft_trainer import IterativeSFTTrainer from .ppo_config import PPOConfig from .ppo_trainer import PPOTrainer from .reward_trainer import RewardTrainer, compute_accuracy diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py new file mode 100644 index 0000000000..006b02ad51 --- /dev/null +++ b/trl/trainer/iterative_sft_trainer.py @@ -0,0 +1,367 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    DataCollator,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import EvalLoopOutput

from ..core import PPODecorators
from ..import_utils import is_peft_available


if is_peft_available():
    from peft import PeftModel


class IterativeSFTTrainer(Trainer):
    """
    The IterativeSFTTrainer can be used to finetune models with methods that require some steps between optimization.

    Optimization is driven by the caller: every call to `step` trains on the samples passed to it, so custom work
    (generation, filtering, ...) can run between two calls.

    Attributes:
        **model** (`PreTrainedModel`) -- Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'.
            Check the documentation of `PreTrainedModel` for more details.
        **args** (`transformers.TrainingArguments`) -- The arguments to use for training.
        **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
            data. Check the documentation of `transformers.PreTrainedTokenizer` and
            `transformers.PreTrainedTokenizerFast` for more details.
        **optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`) -- The optimizer and scheduler to use for training.
        **data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*) -- Data collator to be used for training and
            passed along the dataloader.
        **eval_dataset** (`datasets.Dataset`) -- The dataset to use for evaluation.
        **max_length** (`int`, defaults to `None`) -- The maximum length of the input.
        **truncation_mode** (`str`, defaults to `keep_end`) -- The truncation mode to use, either `keep_end` or `keep_start`.
        **preprocess_logits_for_metrics** (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`) -- The function to use to preprocess the logits before computing the metrics.
        **compute_metrics** (`Callable[[EvalPrediction], Dict]`, *optional*) -- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values.
        **optimize_device_cache** (`bool`, *optional*, defaults to `False`) -- Optimize CUDA cache for slightly more memory-efficient training.
    """

    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        tokenizer: PreTrainedTokenizerBase = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        data_collator: Optional[DataCollator] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        max_length: Optional[int] = None,
        truncation_mode: Optional[str] = "keep_end",
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
        optimize_device_cache: Optional[bool] = False,
    ):
        """
        Validate the arguments, pick a default data collator if none is given, and prepare model, optimizer and
        scheduler with the accelerator.

        Raises:
            ValueError: if `tokenizer` or `model` has the wrong type, or if no scheduler is provided and
                `max_steps` is not set in `args` (including when `args` is `None`).
            AttributeError: if the parent `Trainer` did not create an `accelerator`.
        """
        # Step 0: check positional arguments validity
        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise ValueError(
                f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
            )
        if not isinstance(model, PreTrainedModel):
            raise ValueError(f"model must be a PreTrainedModel, got {type(model)}")
        if not model.can_generate():
            warnings.warn(
                f"The current model class {type(model)} is not compatible with `.generate()`. "
                "Please make sure that this is intended."
            )
        # `args` defaults to None; treat that like an unset `max_steps` instead of failing on attribute access.
        if optimizers[1] is None and (args is None or args.max_steps == -1):
            raise ValueError(
                "When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`"
            )

        self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)

        self.tokenizer = tokenizer

        if data_collator is None:
            if self.is_encoder_decoder:
                warnings.warn(
                    "No data collator is provided. Using 'DataCollatorForSeq2Seq' with "
                    "'labels_pad_token_id' set to '-100' and 'pad_to_multiple_of' set to 8."
                )
                self.data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100, pad_to_multiple_of=8)
            else:
                warnings.warn("No data collator is provided. Using 'DataCollatorForLanguageModeling'")
                self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        else:
            self.data_collator = data_collator

        self.max_length = max_length
        self.truncation_mode = truncation_mode
        self.optimize_device_cache = optimize_device_cache

        super().__init__(
            model=model,
            args=args,
            data_collator=self.data_collator,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Check for the accelerator *before* its first use below, otherwise this guard could never trigger.
        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        self.create_optimizer_and_scheduler(self.args.max_steps)

        # prepare model, optimizer and lr_scheduler
        self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.lr_scheduler
        )

        # "keep_end" drops tokens from the left when the tokenizer truncates, anything else from the right
        self.tokenizer.truncation_side = "left" if self.truncation_mode == "keep_end" else "right"

        # NOTE: this sets a class-level flag on PPODecorators, so it is shared by everything using that decorator
        PPODecorators.optimize_device_cache = self.optimize_device_cache

    def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
        """
        Collate per-sample tensors into one batch on the model's device.

        Args:
            input_ids: Per-sample token id tensors.
            attention_mask: Per-sample attention masks; when `None`, a mask of ones is created for every sample.
            labels: Per-sample label tensors. Only handed to the collator for encoder-decoder models; for other
                models the collator receives `input_ids` and `attention_mask` only.

        Returns:
            A mapping from input name to batched tensor. When `max_length` is set, every tensor is truncated to
            at most `max_length` positions along its last dimension.

        Raises:
            ValueError: if `max_length` is set and `truncation_mode` is neither `keep_start` nor `keep_end`.
        """
        if attention_mask is None:
            attention_mask = [torch.ones_like(ids) for ids in input_ids]

        if self.is_encoder_decoder:
            input_data = self.data_collator(
                [
                    {"input_ids": ids, "attention_mask": att, "labels": lab}
                    for ids, att, lab in zip(input_ids, attention_mask, labels)
                ]
            ).to(self.model.device)

            input_data.pop("decoder_input_ids", None)  # This is directly computed inside the model

            # padded label positions must not contribute to the loss
            input_data["labels"][input_data["labels"] == self.tokenizer.pad_token_id] = -100

        else:
            input_data = self.data_collator(
                [{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]
            ).to(self.model.device)

        # truncate in case the user has provided input_ids, attention_mask and labels
        if self.max_length is not None:
            # Slice the last (sequence) dimension. Slicing dim 0 would drop whole samples from the batch
            # instead of shortening sequences.
            if self.truncation_mode == "keep_start":
                input_data = {k: v[..., : self.max_length] for k, v in input_data.items()}
            elif self.truncation_mode == "keep_end":
                input_data = {k: v[..., -self.max_length :] for k, v in input_data.items()}
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        return input_data

    @staticmethod
    def _step_safety_checker(
        input_ids: List[torch.LongTensor],
        attention_mask: List[torch.LongTensor],
        labels: List[torch.LongTensor],
        texts: List[str],
        texts_labels: List[str],
    ):
        """
        Check if the input data is valid for training.

        When `texts` is `None`, `input_ids` and `labels` (and `attention_mask` if given) must be non-empty lists
        of tensors. Otherwise `texts` (and `texts_labels` if given) must be non-empty lists of strings.

        Args:
            input_ids (List[`torch.LongTensor`]):
                List of tensors containing the input_ids
            attention_mask (List[`torch.LongTensor`]):
                List of tensors containing the attention_mask
            labels (List[`torch.FloatTensor`]):
                List of tensors containing the labels
            texts (List[`str`]):
                List of string containing the text input.
            texts_labels (List[`str`]):
                List of string containing the text labels.
        Returns:
            `tuple`: The input data.
        Raises:
            ValueError: if an argument is not a list, is empty, or holds an element of the wrong type.
        """

        def _require_list_of(display_name, values, element_type, plural):
            # Shared validation: container type, non-emptiness, then the type of every element.
            if not isinstance(values, list):
                raise ValueError(f"{display_name} must be a list of {plural} - got {type(values)}")
            if not values:
                raise ValueError(f"{display_name} must not be an empty list")
            for element in values:
                if not isinstance(element, element_type):
                    raise ValueError(f"Elements in {display_name} must be {plural} - got {type(element)}")

        if texts is None:
            _require_list_of("input_ids", input_ids, torch.Tensor, "tensors")
            if attention_mask is not None:
                _require_list_of("attention_mask", attention_mask, torch.Tensor, "tensors")
            _require_list_of("labels", labels, torch.Tensor, "tensors")
        else:
            _require_list_of("'text'", texts, str, "strings")
            if texts_labels is not None:
                _require_list_of("'text_labels'", texts_labels, str, "strings")

        return input_ids, attention_mask, labels, texts, texts_labels

    @PPODecorators.empty_device_cache()
    def step(
        self,
        input_ids: Optional[List[torch.LongTensor]] = None,
        attention_mask: Optional[List[torch.LongTensor]] = None,
        labels: Optional[List[torch.LongTensor]] = None,
        texts: Optional[List[str]] = None,
        texts_labels: Optional[List[str]] = None,
    ):
        """
        Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and text_labels.

        The samples are split into mini-batches of `args.per_device_train_batch_size`; the optimizer (and the
        scheduler, if any) is stepped once per mini-batch and `state.global_step` is incremented each time.

        Args:
            input_ids (List[`torch.LongTensor`]):
                List of tensors containing the input_ids (if not provided, text will be used)
            attention_mask (List[`torch.LongTensor`], *optional*):
                List of tensors containing the attention_mask
            labels (List[`torch.FloatTensor`], *optional*):
                List of tensors containing the labels (if set to None, will default to input_ids)
            texts (List[`str`], *optional*):
                List of strings containing the text input (if not provided, input_ids will directly be used)
            texts_labels (List[`str`], *optional*):
                List of strings containing the text labels (if set to None, will default to text)
        Returns:
            `None`. The loss is accumulated in `self.tr_loss` and reported through the regular logging.
        Raises:
            ValueError: if neither `input_ids` nor `texts` is given, if an encoder-decoder model gets neither
                `labels` nor `texts_labels`, or if the inputs fail validation.
        """
        self.model.train()

        if self.state.global_step == 0:
            self.tr_loss = torch.tensor(0.0).to(self.args.device)
            self._globalstep_last_logged = self.state.global_step

        if input_ids is None and texts is None:
            raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.")
        elif input_ids is not None and texts is not None:
            warnings.warn(
                "Both 'input_ids' and 'texts' are provided. 'input_ids' will be overwritten using inputs provided by the 'texts' keyword argument."
            )

        if labels is None and texts_labels is None and self.is_encoder_decoder:
            raise ValueError(
                "No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed."
            )

        # For tensor inputs on non encoder-decoder models, default the labels *before* validation.
        # Otherwise the documented call `step(input_ids=..., attention_mask=...)` fails the safety check.
        if texts is None and labels is None and not self.is_encoder_decoder:
            warnings.warn("No labels are provided. Setting labels to input_ids")
            labels = input_ids

        input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(
            input_ids, attention_mask, labels, texts, texts_labels
        )

        if texts is not None:
            model_inputs = self.tokenizer(
                texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
            )

            input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"]

            if texts_labels is not None:
                # tokenize the label strings themselves (not `texts`) to build the labels
                labels = self.tokenizer(
                    texts_labels, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
                )["input_ids"]

        if labels is None:
            warnings.warn("No labels are provided. Setting labels to input_ids")
            labels = input_ids

        model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels)

        model_inputs_names = list(model_inputs.keys())

        batch_dict = {}
        batch_dict.update(model_inputs)

        def collator(data):
            # stack the already padded rows back into batch tensors on the model's device
            return_dict = dict()
            for key in data[0]:
                if key in ["input_ids", "attention_mask", "labels"]:
                    return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device)
            return return_dict

        batch_data = Dataset.from_dict(batch_dict)
        batch_data.set_format("torch")

        step_dataloader = DataLoader(
            batch_data,
            batch_size=self.args.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collator,
        )

        for batch in step_dataloader:
            with self.accelerator.accumulate(self.model):
                model_inputs = {k: batch[k] for k in model_inputs_names}
                loss = self.compute_loss(self.model, model_inputs)

                if self.args.n_gpu > 1:
                    loss = loss.mean()

                tr_loss_step = loss.detach()

                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients and self.args.max_grad_norm is not None:
                    self.accelerator.clip_grad_norm_(
                        self.model.parameters(),
                        self.args.max_grad_norm,
                    )

                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

                self.state.global_step += 1

                # update stats etc
                self.tr_loss += tr_loss_step

                self._maybe_log_save_evaluate()

    def _maybe_log_save_evaluate(self):
        """
        Evaluate and/or log when `state.global_step` is a non-zero multiple of `args.eval_steps` /
        `args.logging_steps`. Logging reports the mean loss since the last log and resets the accumulator.
        """
        # check if eval is required
        if self.args.eval_steps is not None:
            if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0:
                self.evaluate(self.eval_dataset)

        # check if logging is required
        if self.args.logging_steps is not None:
            if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0:
                logs: Dict[str, float] = {}

                tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item()

                # reset tr_loss to zero (in place, so the tensor keeps its device)
                self.tr_loss -= self.tr_loss

                logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
                logs["learning_rate"] = self._get_learning_rate()

                self._globalstep_last_logged = self.state.global_step

                self.log(logs)