Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
leandro committed Nov 24, 2023
1 parent 3f3fd56 commit e7618ec
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_sft_trainer_uncorrect_data(self):
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
- max_seq_length=1024, # make sure there is at least 1 packed sequence
+ max_seq_length=1024, # make sure there is NOT at least 1 packed sequence
packing=True,
)

Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch.nn as nn
from datasets import Dataset
from datasets.arrow_writer import SchemaInferenceError
+ from datasets.builder import DatasetGenerationError
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Expand Down Expand Up @@ -404,7 +405,7 @@ def data_generator(constant_length_iterator):
packed_dataset = Dataset.from_generator(
data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator}
)
- except SchemaInferenceError:
+ except (DatasetGenerationError, SchemaInferenceError):
raise ValueError(
"Error occurred while packing the dataset. Make sure that your dataset has enough samples to at least yield one packed sequence."
)
Expand Down

0 comments on commit e7618ec

Please sign in to comment.