diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index cfaace83c5..0fe8d1627a 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -189,7 +189,7 @@ def test_sft_trainer_uncorrect_data(self): args=training_args, train_dataset=self.dummy_dataset, formatting_func=formatting_prompts_func, - max_seq_length=1024, # make sure there is at least 1 packed sequence + max_seq_length=1024, # make sure there is NOT at least 1 packed sequence packing=True, ) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 2ccf9ddffb..d344cc1fcf 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -21,6 +21,7 @@ import torch.nn as nn from datasets import Dataset from datasets.arrow_writer import SchemaInferenceError +from datasets.builder import DatasetGenerationError from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -404,7 +405,7 @@ def data_generator(constant_length_iterator): packed_dataset = Dataset.from_generator( data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator} ) - except SchemaInferenceError: + except (DatasetGenerationError, SchemaInferenceError): raise ValueError( "Error occurred while packing the dataset. Make sure that your dataset has enough samples to at least yield one packed sequence." )