Introducing the Iterative Trainer #737

Merged
merged 39 commits on Nov 2, 2023
Changes from all commits
39 commits
7827003
initial skeleton
gaetanlop Sep 4, 2023
0442b6b
iterative trainer for decoder only
gaetanlop Sep 4, 2023
4255b0b
iterative trainer unittest
gaetanlop Sep 4, 2023
5648602
encoder_decoder support
gaetanlop Sep 4, 2023
faaae0a
fix typo in unittest
gaetanlop Sep 4, 2023
2836807
init
gaetanlop Sep 4, 2023
3063a2f
fix typo
gaetanlop Sep 4, 2023
23888e1
fix init typo
gaetanlop Sep 5, 2023
ddc573a
Merge branch 'main' into iterativetrainer
gaetanlop Sep 6, 2023
3705578
adding loggings and safety checker
gaetanlop Sep 15, 2023
aa842c3
fixed minor issues
gaetanlop Sep 15, 2023
f2673e7
doc
gaetanlop Sep 15, 2023
2de2388
table of contents update
gaetanlop Sep 15, 2023
b61cc9a
add test for seq2seq2 models
gaetanlop Sep 18, 2023
a8eee4b
change year
gaetanlop Sep 18, 2023
917c2eb
adding text as step input
gaetanlop Oct 8, 2023
bae3ca3
precommit
gaetanlop Oct 8, 2023
173a34b
fixing typo
gaetanlop Oct 8, 2023
f6e188e
run precommit
gaetanlop Oct 8, 2023
cfd835f
fixing typo in safety checker
gaetanlop Oct 8, 2023
ede0ccc
fix text tokenization issue
gaetanlop Oct 8, 2023
11c69e2
add truncate and inherit from trainer
gaetanlop Nov 1, 2023
57622f8
remove iterative config from tests
gaetanlop Nov 1, 2023
c21091b
remove iterative config from init
gaetanlop Nov 1, 2023
b4600fc
fix peft model
gaetanlop Nov 1, 2023
4121cbe
change truncation side based on truncation_mode
gaetanlop Nov 1, 2023
441b05c
removed iterativeconfig autodoc
gaetanlop Nov 1, 2023
3b28163
fixed typo in trainer.mdx
gaetanlop Nov 1, 2023
3907102
remove mention of iterative config in docs
gaetanlop Nov 1, 2023
e31c3a5
make sure optimizer and scheduler are created
gaetanlop Nov 1, 2023
c4cd798
adding max_steps to test
gaetanlop Nov 1, 2023
78d5879
remove log_stats fn
gaetanlop Nov 1, 2023
5b2da4e
remove compute loss
gaetanlop Nov 1, 2023
5de3ddf
fixing encoder decoder detection
gaetanlop Nov 1, 2023
4e68b76
fix PPODecorator
gaetanlop Nov 1, 2023
122494e
run precommit
gaetanlop Nov 1, 2023
71c28fd
fix testing
gaetanlop Nov 2, 2023
00a89ca
fix small typos in iterative trainer
gaetanlop Nov 2, 2023
aa70ca3
adapted function log and eval
gaetanlop Nov 2, 2023
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -29,6 +29,8 @@
    title: DPO Trainer
  - local: ddpo_trainer
    title: Denoising Diffusion Policy Optimization
  - local: iterative_sft_trainer
    title: Iterative Supervised Fine-Tuning
  - local: text_environments
    title: Text Environments
  title: API
54 changes: 54 additions & 0 deletions docs/source/iterative_sft_trainer.mdx
@@ -0,0 +1,54 @@
# Iterative Trainer

Iterative fine-tuning is a training method that lets you perform custom actions (for example, generation and filtering) between optimization steps. In TRL, we provide an easy-to-use API to fine-tune your models iteratively in just a few lines of code.

## Usage

To get started quickly, instantiate a model and a tokenizer, then build the trainer from them.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import IterativeSFTTrainer

model_name = "gpt2"  # any causal language model checkpoint works here
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

trainer = IterativeSFTTrainer(model=model, tokenizer=tokenizer)
```
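
Once the trainer is created, training alternates between your own custom steps (for example, generating completions and filtering or ranking them) and calls to `trainer.step`. Below is a minimal sketch of such a loop; `prompts`, `reward_fn`, and `num_iterations` are placeholders you would define yourself, and the generation and filtering logic is only illustrative.

```python
for _ in range(num_iterations):
    # Custom action 1: generate completions for your prompts.
    batch = tokenizer(prompts, return_tensors="pt", padding=True)
    generations = model.generate(**batch, max_new_tokens=64)
    texts = tokenizer.batch_decode(generations, skip_special_tokens=True)

    # Custom action 2: keep only the generations that pass your filter.
    kept = [text for text in texts if reward_fn(text) > 0.5]

    # Optimization step on the filtered samples.
    if kept:
        trainer.step(texts=kept)
```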

You have the choice to either provide a list of strings or a list of tensors to the step function.

#### Using a list of tensors as input:

```python
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}

trainer.step(**inputs)
```

#### Using a list of strings as input:

```python
inputs = {
    "texts": texts,
}

trainer.step(**inputs)
```

For causal language models, labels are created automatically from `input_ids` or from `texts`. When using sequence-to-sequence models, you have to provide your own `labels` or `texts_labels`.
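
For example, with an encoder-decoder model such as T5, a sketch of a single step might look like the following (the input and target strings are placeholders):

```python
# Sequence-to-sequence models need explicit targets, here passed as texts_labels.
inputs = {
    "texts": ["summarize: The cat sat on the mat all day long."],
    "texts_labels": ["The cat sat on the mat."],
}

trainer.step(**inputs)
```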

## IterativeTrainer

[[autodoc]] IterativeSFTTrainer
4 changes: 4 additions & 0 deletions docs/source/trainer.mdx
@@ -36,6 +36,10 @@ We also support a `RewardTrainer` that can be used to train a reward model.

[[autodoc]] DDPOTrainer

## IterativeSFTTrainer

[[autodoc]] IterativeSFTTrainer

## set_seed

[[autodoc]] set_seed
106 changes: 106 additions & 0 deletions tests/test_iterative_sft_trainer.py
@@ -0,0 +1,106 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import torch
from datasets import Dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments

from trl import IterativeSFTTrainer


class IterativeTrainerTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
        cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
        cls.tokenizer.pad_token = cls.tokenizer.eos_token

        # get t5 as seq2seq example:
        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
        cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)

    def _init_tensor_dummy_dataset(self):
        dummy_dataset_dict = {
            "input_ids": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
            "attention_mask": [torch.tensor([1, 1]), torch.tensor([1, 1, 1]), torch.tensor([1, 1])],
            "labels": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
        }

        dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
        dummy_dataset.set_format("torch")
        return dummy_dataset

    def _init_textual_dummy_dataset(self):
        dummy_dataset_dict = {
            "texts": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"],
            "texts_labels": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"],
        }

        dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
        dummy_dataset.set_format("torch")
        return dummy_dataset

    def setUp(self):
        # initialize trainer
        self.model.train()
        return super().setUp()

    @parameterized.expand(
        [
            ["gpt2", "tensor"],
            ["gpt2", "text"],
            ["t5", "tensor"],
            ["t5", "text"],
        ]
    )
    def test_iterative_step_from_tensor(self, model_name, input_name):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # initialize dataset
            if input_name == "tensor":
                dummy_dataset = self._init_tensor_dummy_dataset()
                inputs = {
                    "input_ids": dummy_dataset["input_ids"],
                    "attention_mask": dummy_dataset["attention_mask"],
                    "labels": dummy_dataset["labels"],
                }
            else:
                dummy_dataset = self._init_textual_dummy_dataset()
                inputs = {
                    "texts": dummy_dataset["texts"],
                    "texts_labels": dummy_dataset["texts_labels"],
                }

            if model_name == "gpt2":
                model = self.model
                tokenizer = self.tokenizer
            else:
                model = self.t5_model
                tokenizer = self.t5_tokenizer

            args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=2,
            )
            iterative_trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer)

            iterative_trainer.step(**inputs)

            for param in iterative_trainer.model.parameters():
                assert param.grad is not None
1 change: 1 addition & 0 deletions trl/__init__.py
@@ -15,6 +15,7 @@
from .trainer import (
    DataCollatorForCompletionOnlyLM,
    DPOTrainer,
    IterativeSFTTrainer,
    PPOConfig,
    PPOTrainer,
    RewardConfig,
1 change: 1 addition & 0 deletions trl/trainer/__init__.py
@@ -36,6 +36,7 @@
from .ddpo_trainer import DDPOTrainer

from .dpo_trainer import DPOTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
from .reward_trainer import RewardTrainer, compute_accuracy