🔬 SFT simplification #2405

Merged
merged 75 commits into from
Feb 7, 2025
6c413ae
initial commit
qgallouedec Oct 2, 2024
8bc9c7b
Merge branch 'main' into sft-refactor
qgallouedec Nov 28, 2024
0db6f6a
update
qgallouedec Nov 30, 2024
95245cd
Refactor SFTTrainer and SFTConfig
qgallouedec Nov 30, 2024
61ddc84
Update SFTConfig class in sft_config.py
qgallouedec Dec 1, 2024
ddfaf47
Fix SFTConfig torch_dtype validation and dataset preprocessing flag
qgallouedec Dec 1, 2024
6810572
Refactor dataset mapping and conversion
qgallouedec Dec 1, 2024
ec8a2df
Refactor dataset mapping in SFTTrainer
qgallouedec Dec 1, 2024
a7ef4ab
Fix SFTTrainerTester unit test by removing unnecessary code
qgallouedec Dec 1, 2024
e951d6d
Remove unused variables and update tokenization logic
qgallouedec Dec 1, 2024
81f303b
Remove pack_dataset function
qgallouedec Dec 1, 2024
284a9b0
Merge branch 'main' into sft-refactor
qgallouedec Dec 1, 2024
21ddf4c
Add deprecation warning for tokenizer in SFTTrainer constructor
qgallouedec Dec 1, 2024
ef06dfe
add docstring back
qgallouedec Dec 1, 2024
1c41eb3
Update model parameter type annotation
qgallouedec Dec 1, 2024
ae15a92
Merge branch 'sft-refactor' of https://github.com/huggingface/trl int…
qgallouedec Dec 1, 2024
86fbbfc
Update SFTTrainer class definition [ci skip]
qgallouedec Dec 1, 2024
34a1234
style [ci skip]
qgallouedec Dec 1, 2024
73702f7
preprocess_dataset -> _prepare_dataset
qgallouedec Dec 1, 2024
c835031
Retro compat
qgallouedec Dec 1, 2024
39a5e6c
Update formatting_func type hint in SFTTrainer constructor
qgallouedec Dec 1, 2024
56ab848
typo [ci skip]
qgallouedec Dec 1, 2024
0b3a54b
better comment [skip ci]
qgallouedec Dec 1, 2024
6024e34
simplify tokenize row
qgallouedec Dec 1, 2024
a6b8fd5
Fix type hint for peft_config
qgallouedec Dec 1, 2024
770950a
fix doc [ci skip]
qgallouedec Dec 1, 2024
906f5a7
Add pack_examples function to `test_data_utils.py`
qgallouedec Dec 1, 2024
f89d2b2
promote pack_examples and document
qgallouedec Dec 1, 2024
a464c27
improve doc
qgallouedec Dec 1, 2024
d55c84b
Merge branch 'main' into sft-refactor
qgallouedec Dec 3, 2024
812827a
Add new SFTTrainerTester2 class for testing
qgallouedec Dec 4, 2024
c28823b
Merge branch 'main' into sft-refactor
qgallouedec Dec 5, 2024
5dc0f27
test was reversed [ci skip]
qgallouedec Dec 9, 2024
be36cc1
Merge branch 'sft-refactor' of https://github.com/huggingface/trl int…
qgallouedec Dec 9, 2024
49877a4
Merge branch 'main' into sft-refactor
qgallouedec Dec 9, 2024
97c3ec1
©️ Copyrights update (#2454)
qgallouedec Dec 10, 2024
0682ccd
💬 Fix chat for windows (#2443)
qgallouedec Dec 10, 2024
9cbca0c
🆔 Add `datast_config` to `ScriptArguments` (#2440)
qgallouedec Dec 10, 2024
f09971c
🏎 Fix deepspeed preparation of `ref_model` in `OnlineDPOTrainer` (#2417)
qgallouedec Dec 10, 2024
59865e8
👯 Standardize `model_args` (#2442)
qgallouedec Dec 10, 2024
0b5ccf9
Merge branch 'main' into sft-refactor
qgallouedec Dec 10, 2024
0b5360a
Merge branch 'main' into sft-refactor
qgallouedec Dec 13, 2024
faffd0e
Merge branch 'main' into sft-refactor
qgallouedec Dec 21, 2024
f732aa8
refactor config
qgallouedec Dec 21, 2024
d2ee070
drop skip prepare dataset
qgallouedec Dec 21, 2024
dc84d08
add sep to packing
qgallouedec Dec 21, 2024
16ef195
drop prompt-completion for now
qgallouedec Dec 21, 2024
26c2a20
Revert "drop prompt-completion for now"
qgallouedec Jan 8, 2025
a4e186d
Revert "add sep to packing"
qgallouedec Jan 8, 2025
ce0320e
Revert "drop skip prepare dataset"
qgallouedec Jan 8, 2025
9515d83
Revert "refactor config"
qgallouedec Jan 8, 2025
f98a45a
Merge branch 'main' into sft-refactor
qgallouedec Jan 8, 2025
fea7094
Format
qgallouedec Jan 8, 2025
8414e1d
Merge branch 'main' into sft-refactor
qgallouedec Jan 19, 2025
b654dec
Update doc-builder workflow to use specific commit sha
qgallouedec Jan 20, 2025
0ddcf31
Merge branch 'main' into sft-refactor
qgallouedec Jan 20, 2025
d6f5188
Merge branch 'main' into sft-refactor
qgallouedec Feb 7, 2025
205f39d
Merge branch 'main' into sft-refactor
qgallouedec Feb 7, 2025
b4c9df8
add peft edge cases
kashif Feb 7, 2025
fab6ba9
no logits when using liger
kashif Feb 7, 2025
af73d13
remove unused columns
kashif Feb 7, 2025
85f1c5a
proper handle of prompt-completion
qgallouedec Feb 7, 2025
4e5363e
trick to keep messages
qgallouedec Feb 7, 2025
e6abc24
fix messages missing
qgallouedec Feb 7, 2025
da3d7a6
for Liger kernel, ensure only input_ids is present
kashif Feb 7, 2025
7781ef2
packing and liger are compatible
kashif Feb 7, 2025
7e47b12
shinny doc and final nits
qgallouedec Feb 7, 2025
2f2dba1
Merge branch 'sft-refactor' of https://github.com/huggingface/trl int…
qgallouedec Feb 7, 2025
813977e
another nit
qgallouedec Feb 7, 2025
3c430f8
refactor config and doc
qgallouedec Feb 7, 2025
efba9be
re add truncation
qgallouedec Feb 7, 2025
cf10d1a
fix ci
qgallouedec Feb 7, 2025
b968003
drop deprecated params in tests
qgallouedec Feb 7, 2025
07f6a1c
fix link [ci skip]
qgallouedec Feb 7, 2025
84e01e3
fix config docstring [ci skip]
qgallouedec Feb 7, 2025
preprocess_dataset -> _prepare_dataset
qgallouedec committed Dec 1, 2024
commit 73702f7d441c2c88a78fd2b49d8292ce175431a8
13 changes: 6 additions & 7 deletions trl/trainer/sft_trainer.py
@@ -16,7 +16,6 @@
 import warnings
 from typing import Any, Callable, Optional, Type, Union

-import datasets
 import torch
 import torch.nn as nn
 from accelerate import PartialState
@@ -107,8 +106,8 @@ def __init__(
         model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
         args: Optional[Union[SFTConfig, TrainingArguments]] = None,
         data_collator: Optional[DataCollator] = None,  # type: ignore
-        train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
-        eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
         processing_class: Optional[
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ] = None,
@@ -182,18 +181,18 @@ def __init__(
         # 4. Handle the dataset
         preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
         if preprocess_dataset:
-            train_dataset = self.preprocess_dataset(
+            train_dataset = self._prepare_dataset(
                 train_dataset, processing_class, args, args.packing, formatting_func, "train"
             )
             if eval_dataset is not None:
                 packing = args.packing if args.eval_packing is None else args.eval_packing
                 if isinstance(eval_dataset, dict):
                     eval_dataset = {
-                        key: self.preprocess_dataset(dataset, processing_class, args, packing, formatting_func, key)
+                        key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
                         for key, dataset in eval_dataset.items()
                     }
                 else:
-                    eval_dataset = self.preprocess_dataset(
+                    eval_dataset = self._prepare_dataset(
                         eval_dataset, processing_class, args, packing, formatting_func, "eval"
                     )

@@ -221,7 +220,7 @@ def __init__(
         if hasattr(self.model, "add_model_tags"):
             self.model.add_model_tags(self._tag_names)

-    def preprocess_dataset(
+    def _prepare_dataset(
         self,
         dataset: Union[Dataset, IterableDataset],
         processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
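
Reading note on the "# 4. Handle the dataset" hunk: the control flow around the renamed method can be hard to follow in diff form, so here is a minimal standalone sketch of it. The helper names `should_prepare`, `resolve_eval_packing`, and `prepare_eval`, and the `prepare` callback signature, are hypothetical and are not part of TRL. They only mirror the three decisions visible in the diff: the `skip_prepare_dataset` opt-out, the `eval_packing` fallback to `packing`, and the per-key handling of a dict of eval datasets.

```python
from typing import Any, Callable, Optional, Union


def should_prepare(dataset_kwargs: Optional[dict[str, Any]]) -> bool:
    # Preparation runs unless dataset_kwargs explicitly sets skip_prepare_dataset=True.
    return dataset_kwargs is None or not dataset_kwargs.get("skip_prepare_dataset", False)


def resolve_eval_packing(packing: bool, eval_packing: Optional[bool]) -> bool:
    # eval_packing=None means "reuse the training packing setting".
    return packing if eval_packing is None else eval_packing


def prepare_eval(
    eval_dataset: Union[Any, dict[str, Any]],
    prepare: Callable[[Any, bool, str], Any],  # hypothetical stand-in for _prepare_dataset
    packing: bool,
) -> Union[Any, dict[str, Any]]:
    # A dict of eval datasets is prepared entry by entry, using each key as the
    # dataset name; a single dataset is prepared under the name "eval".
    if isinstance(eval_dataset, dict):
        return {key: prepare(ds, packing, key) for key, ds in eval_dataset.items()}
    return prepare(eval_dataset, packing, "eval")


# Usage with a dummy prepare function that just records what it was called with.
record = lambda ds, packing, name: (ds, packing, name)
packing = resolve_eval_packing(packing=True, eval_packing=None)
result = prepare_eval({"val": [1], "test": [2]}, record, packing)
```

In the commit itself these decisions stay inline in `__init__`; only the method name changes from `preprocess_dataset` to `_prepare_dataset`, and the leading underscore marks it as private by Python convention.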