[SFT Trainer] precompute packed iterable into a dataset (#979)
* precompute packed iterable into a dataset

* add generator function

* fix typo

* fix style

* fix test

* fix style

* add test

* minor refactor

* fix test

* Apply suggestions from code review

Co-authored-by: lewtun <[email protected]>

* style

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: lewtun <[email protected]>
Co-authored-by: younesbelkada <[email protected]>
4 people authored Dec 4, 2023
1 parent 4cdc03a commit f06f357
Showing 2 changed files with 74 additions and 25 deletions.
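
The core of the change is to drain the lazily packed iterable once, at trainer construction time, into an Arrow-backed `datasets.Dataset`. Below is a minimal standalone sketch of that pattern; it is illustrative rather than the exact trainer code, and the helper name `materialize_packed` is invented for this example:

from datasets import Dataset
from datasets.arrow_writer import SchemaInferenceError
from datasets.builder import DatasetGenerationError


def materialize_packed(packed_iterable):
    # A generator wrapper lets Dataset.from_generator consume the iterable.
    def data_generator(constant_length_iterator):
        yield from constant_length_iterator

    try:
        # from_generator drains the iterator once and caches the rows, so
        # every training epoch reuses the same precomputed packed sequences.
        return Dataset.from_generator(
            data_generator, gen_kwargs={"constant_length_iterator": packed_iterable}
        )
    except (DatasetGenerationError, SchemaInferenceError):
        # If the iterable yields zero rows, no schema can be inferred;
        # surface a clearer error than the raw datasets exception.
        raise ValueError("Not enough data to yield a single packed sequence.")
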
14 changes: 13 additions & 1 deletion tests/test_sft_trainer.py
@@ -178,9 +178,21 @@ def test_sft_trainer_uncorrect_data(self):
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
+                max_seq_length=32,  # make sure there is at least 1 packed sequence
                 packing=True,
             )

+        with self.assertRaises(ValueError):
+            # This should not work because there is not enough data for one sample
+            _ = SFTTrainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dummy_dataset,
+                formatting_func=formatting_prompts_func,
+                max_seq_length=1024,  # make sure there is NOT at least 1 packed sequence
+                packing=True,
+            )

        # This should not work either
with self.assertRaises(ValueError):
_ = SFTTrainer(
@@ -191,7 +203,7 @@ def test_sft_trainer_uncorrect_data(self):
packing=False,
)

-        # but this shpuld work
+        # but this should work
_ = SFTTrainer(
model=self.model,
args=training_args,
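
From the user's side, the new check means a dataset too small to pack now fails at construction time rather than producing an empty dataloader later. A hedged sketch of that behavior (the model id, texts, and sizes are illustrative, and the token counts are approximate):

from datasets import Dataset
from trl import SFTTrainer

# ~16 examples of roughly 50 GPT-2 tokens each: ~800 tokens in total.
tiny = Dataset.from_dict({"text": ["a short example sentence. " * 8] * 16})

# Packs fine: ~800 tokens fill at least one 32-token sequence.
trainer = SFTTrainer(
    model="gpt2",
    train_dataset=tiny,
    dataset_text_field="text",
    max_seq_length=32,
    packing=True,
)

# Fails fast: ~800 tokens cannot fill a single 1024-token sequence, so the
# precomputation yields no rows and SFTTrainer raises ValueError at init.
try:
    SFTTrainer(
        model="gpt2",
        train_dataset=tiny,
        dataset_text_field="text",
        max_seq_length=1024,
        packing=True,
    )
except ValueError as err:
    print(err)
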
85 changes: 61 additions & 24 deletions trl/trainer/sft_trainer.py
@@ -20,6 +20,8 @@
import torch
import torch.nn as nn
from datasets import Dataset
+from datasets.arrow_writer import SchemaInferenceError
+from datasets.builder import DatasetGenerationError
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -128,7 +130,7 @@ def __init__(
packing: Optional[bool] = False,
formatting_func: Optional[Callable] = None,
max_seq_length: Optional[int] = None,
-        infinite: Optional[bool] = False,
+        infinite: Optional[bool] = None,
num_of_sequences: Optional[int] = 1024,
chars_per_token: Optional[float] = 3.6,
dataset_num_proc: Optional[int] = None,
@@ -141,6 +143,11 @@
elif not isinstance(model, str):
raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.")

+        if infinite is not None:
+            warnings.warn(
+                "The `infinite` argument is deprecated and will be removed in a future version of TRL. Use `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length."
+            )

if isinstance(model, str):
warnings.warn(
"You passed a model_id to the SFTTrainer. This will automatically create an "
@@ -236,7 +243,6 @@ def make_inputs_require_grad(module, input, output):
dataset_text_field,
max_seq_length,
formatting_func,
-                infinite,
num_of_sequences,
chars_per_token,
)
@@ -248,7 +254,6 @@ def make_inputs_require_grad(module, input, output):
dataset_text_field,
max_seq_length,
formatting_func,
-                infinite,
num_of_sequences,
chars_per_token,
)
@@ -311,7 +316,6 @@ def _prepare_dataset(
dataset_text_field,
max_seq_length,
formatting_func,
-        infinite,
num_of_sequences,
chars_per_token,
):
@@ -327,30 +331,19 @@
tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func
)

-        if dataset_text_field is not None or formatting_func is not None:
-            if tokenizer is None:
-                raise ValueError(
-                    "You need to pass a tokenizer when using the SFT Trainer when passing a `dataset_text_field`."
-                )
-
-            return ConstantLengthDataset(
+        else:
+            return self._prepare_packed_dataloader(
                 tokenizer,
                 dataset,
-                dataset_text_field=dataset_text_field,
-                formatting_func=formatting_func,
-                seq_length=max_seq_length,
-                infinite=infinite,
-                num_of_sequences=num_of_sequences,
-                chars_per_token=chars_per_token,
-                eos_token_id=tokenizer.eos_token_id,
+                dataset_text_field,
+                max_seq_length,
+                num_of_sequences,
+                chars_per_token,
+                formatting_func,
             )
-
-        raise ValueError(
-            "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`."
-        )

def _prepare_non_packed_dataloader(
-        self, tokenizer, dataset, dataset_text_field, max_seq_len, formatting_func=None
+        self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func=None
):
use_formatting_func = formatting_func is not None and dataset_text_field is None
self._dataset_sanity_checked = False
@@ -361,7 +354,7 @@ def tokenize(element):
element[dataset_text_field] if not use_formatting_func else formatting_func(element),
truncation=True,
padding=False,
-                max_length=max_seq_len,
+                max_length=max_seq_length,
return_overflowing_tokens=False,
return_length=False,
)
@@ -386,6 +379,50 @@

return tokenized_dataset

+    def _prepare_packed_dataloader(
+        self,
+        tokenizer,
+        dataset,
+        dataset_text_field,
+        max_seq_length,
+        num_of_sequences,
+        chars_per_token,
+        formatting_func=None,
+    ):
+        if dataset_text_field is not None or formatting_func is not None:
+            if tokenizer is None:
+                raise ValueError("You need to pass a tokenizer when using `dataset_text_field` with `SFTTrainer`.")
+
+            constant_length_iterator = ConstantLengthDataset(
+                tokenizer,
+                dataset,
+                dataset_text_field=dataset_text_field,
+                formatting_func=formatting_func,
+                seq_length=max_seq_length,
+                infinite=False,
+                num_of_sequences=num_of_sequences,
+                chars_per_token=chars_per_token,
+                eos_token_id=tokenizer.eos_token_id,
+            )
+
+            def data_generator(constant_length_iterator):
+                for i in constant_length_iterator:
+                    yield i
+
+            try:
+                packed_dataset = Dataset.from_generator(
+                    data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator}
+                )
+            except (DatasetGenerationError, SchemaInferenceError):
+                raise ValueError(
+                    "Error occurred while packing the dataset. Make sure that your dataset has enough samples to at least yield one packed sequence."
+                )
+            return packed_dataset
+        else:
+            raise ValueError(
+                "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`."
+            )

def _trl_activate_neftune(self, model):
r"""
        Activates NEFTune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
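
The same pattern can be exercised outside the trainer, which also shows why the `infinite` flag becomes redundant: the iterable is drained exactly once, and training length is bounded by `TrainingArguments` instead. A standalone sketch under those assumptions (model id, texts, and sizes are illustrative, not part of the commit):

from datasets import Dataset
from transformers import AutoTokenizer, TrainingArguments
from trl.trainer import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
raw = Dataset.from_dict({"text": ["The quick brown fox jumps over the lazy dog. " * 10] * 32})

packed_iterable = ConstantLengthDataset(
    tokenizer,
    raw,
    dataset_text_field="text",
    seq_length=64,
    infinite=False,  # finite iteration is what makes one-shot materialization possible
    eos_token_id=tokenizer.eos_token_id,
)

def data_generator(constant_length_iterator):
    yield from constant_length_iterator  # items look like {"input_ids": ..., "labels": ...}

packed = Dataset.from_generator(
    data_generator, gen_kwargs={"constant_length_iterator": packed_iterable}
)
print(len(packed))  # a map-style Dataset with a length, unlike the raw iterable

# With `infinite` deprecated, bound the run length via TrainingArguments instead:
args = TrainingArguments(output_dir="out", max_steps=500)  # or num_train_epochs=...
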
