use SFTConfig instead of SFTTrainer keyword args
qgallouedec committed Oct 14, 2024
1 parent 749b924 commit a0c6277
Showing 6 changed files with 16 additions and 36 deletions.
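
Every file follows the same migration: options such as `dataset_text_field`, `max_seq_length`, `packing`, and `dataset_kwargs` stop being `SFTTrainer` keyword arguments and become fields of an `SFTConfig` passed as `args`. A minimal sketch of the pattern, assuming a trl release contemporary with this commit (where `SFTConfig` exposes these fields) and using placeholder dataset and output paths:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("imdb", split="train[:1%]")  # small text dataset, as in the PiSSA example

# Before this commit, dataset options were keyword arguments of SFTTrainer:
#   trainer = SFTTrainer("facebook/opt-350m", train_dataset=dataset,
#                        dataset_text_field="text", max_seq_length=512, packing=False)

# After this commit, the same options live on SFTConfig and are passed via `args`:
training_args = SFTConfig(
    output_dir="sft-output",      # placeholder; TrainingArguments expects an output directory
    dataset_text_field="text",
    max_seq_length=512,
    packing=False,
)
trainer = SFTTrainer("facebook/opt-350m", args=training_args, train_dataset=dataset)
trainer.train()
```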

docs/source/accelerate/deepspeed.md (9 changes: 1 addition & 8 deletions)

````diff
@@ -128,7 +128,7 @@ Notice that we are using LoRA with rank=8, alpha=16 and targeting all linear la

 Let's dive a little deeper into the script so you can see what's going on, and understand how it works.

-The first thing to know is that the script uses DeepSpeed for distributed training as the DeepSpeed config has been passed. The `SFTTrainer` class handles all the heavy lifting of creating the PEFT model using the peft config that is passed. After that, when you call `trainer.train()`, `SFTTrainer` internally uses 🤗 Accelerate to prepare the model, optimizer and trainer using the DeepSpeed config to create DeepSpeed engine which is then trained. The main code snippet is below:
+The first thing to know is that the script uses DeepSpeed for distributed training as the DeepSpeed config has been passed. The [`~trl.SFTTrainer`] class handles all the heavy lifting of creating the PEFT model using the peft config that is passed. After that, when you call `trainer.train()`, [`~trl.SFTTrainer`] internally uses 🤗 Accelerate to prepare the model, optimizer and trainer using the DeepSpeed config to create DeepSpeed engine which is then trained. The main code snippet is below:

 ```python
 # trainer
@@ -139,13 +139,6 @@ trainer = SFTTrainer(
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    packing=data_args.packing,
-    dataset_kwargs={
-        "append_concat_token": data_args.append_concat_token,
-        "add_special_tokens": data_args.add_special_tokens,
-    },
-    dataset_text_field=data_args.dataset_text_field,
-    max_seq_length=data_args.max_seq_length,
 )
 trainer.accelerator.print(f"{trainer.model}")
````
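
The hunks above only show the keyword arguments being removed; after the change they are carried by the `SFTConfig` that ends up in `args`. A rough, hypothetical sketch of that wiring, reusing the `data_args` names from the documented script (the added side is not part of the hunk shown, and the same removal appears in the FSDP doc below):

```python
from trl import SFTConfig, SFTTrainer

# Hypothetical sketch: the deleted keyword arguments become SFTConfig fields.
# `data_args`, `model`, `train_dataset`, `eval_dataset` and `peft_config` are the
# objects already defined in the documented script, not created here.
training_args = SFTConfig(
    output_dir="sft-output",   # placeholder; TrainingArguments expects an output directory
    packing=data_args.packing,
    dataset_kwargs={
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    },
    dataset_text_field=data_args.dataset_text_field,
    max_seq_length=data_args.max_seq_length,
)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)
```

In the actual example script the config is not built by hand but parsed from the command line, as the `examples/sft/train.py` diff further down shows.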

docs/source/accelerate/fsdp.md (9 changes: 1 addition & 8 deletions)

````diff
@@ -108,7 +108,7 @@ Notice that we are using LoRA with rank=8, alpha=16 and targeting all linear la

 Let's dive a little deeper into the script so you can see what's going on, and understand how it works.

-The first thing to know is that the script uses FSDP for distributed training as the FSDP config has been passed. The `SFTTrainer` class handles all the heavy lifting of creating PEFT model using the peft config that is passed. After that when you call `trainer.train()`, Trainer internally uses 🤗 Accelerate to prepare model, optimizer and trainer using the FSDP config to create FSDP wrapped model which is then trained. The main code snippet is below:
+The first thing to know is that the script uses FSDP for distributed training as the FSDP config has been passed. The [`~trl.SFTTrainer`] class handles all the heavy lifting of creating PEFT model using the peft config that is passed. After that when you call `trainer.train()`, Trainer internally uses 🤗 Accelerate to prepare model, optimizer and trainer using the FSDP config to create FSDP wrapped model which is then trained. The main code snippet is below:

 ```python
 # trainer
@@ -119,13 +119,6 @@ trainer = SFTTrainer(
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    packing=data_args.packing,
-    dataset_kwargs={
-        "append_concat_token": data_args.append_concat_token,
-        "add_special_tokens": data_args.add_special_tokens,
-    },
-    dataset_text_field=data_args.dataset_text_field,
-    max_seq_length=data_args.max_seq_length,
 )
 trainer.accelerator.print(f"{trainer.model}")
 if model_args.use_peft_lora:
````

examples/olora_finetuning/README.md (5 changes: 2 additions & 3 deletions)

```diff
@@ -8,7 +8,7 @@
 import torch
 from peft import LoraConfig, get_peft_model
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
 from datasets import load_dataset

 model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
@@ -18,11 +18,10 @@ lora_config = LoraConfig(
     init_lora_weights="olora"
 )
 peft_model = get_peft_model(model, lora_config)
+training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
 trainer = SFTTrainer(
     model=peft_model,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=512,
     tokenizer=tokenizer,
 )
 trainer.train()
```
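
Note that the hunk shown adds `training_args` but does not show it being handed to the trainer; an `SFTConfig` only takes effect once it is passed via `args`. A sketch of that completion, not part of the diff above, reusing the names from the README snippet:

```python
trainer = SFTTrainer(
    model=peft_model,      # the OLoRA-wrapped model from the snippet above
    args=training_args,    # hypothetical: wire the new SFTConfig into the trainer
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```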

examples/pissa_finetuning/README.md (7 changes: 3 additions & 4 deletions)

````diff
@@ -6,8 +6,7 @@ PiSSA represents a matrix $W\in\mathbb{R}^{m\times n}$ within the model by the p
 ```python
 import torch
 from peft import LoraConfig, get_peft_model
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from trl import SFTTrainer
+from transformers import AutoTokenizer, AutoModelForCausalLMfrom trl import SFTConfig, SFTTrainer
 from datasets import load_dataset

 model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
@@ -23,11 +22,11 @@ peft_model.print_trainable_parameters()

 dataset = load_dataset("imdb", split="train[:1%]")

+training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
 trainer = SFTTrainer(
     model=peft_model,
+    args=training_args,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=128,
     tokenizer=tokenizer,
 )
 trainer.train()
````
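
As rendered above, the added import line runs two statements together (`...AutoModelForCausalLMfrom trl import...`), whether that comes from the diff rendering or from the committed file itself. A runnable form of those imports splits them onto two lines:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
```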

examples/pissa_finetuning/pissa_finetuning.py (6 changes: 2 additions & 4 deletions)

```diff
@@ -19,7 +19,7 @@
 import torch
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer

 from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

@@ -53,7 +53,7 @@ class TrainingArguments(TrainingArguments):
     )


-parser = HfArgumentParser(TrainingArguments)
+parser = HfArgumentParser(SFTConfig)
 script_args = parser.parse_args_into_dataclasses()[0]
 print(script_args)

@@ -133,8 +133,6 @@ class TrainingArguments(TrainingArguments):
     model=peft_model,
     args=script_args,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=script_args.max_seq_length,
     tokenizer=tokenizer,
 )
 trainer.train()
```
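
With the parser now built over `SFTConfig`, options that previously had to be forwarded to `SFTTrainer` by hand arrive through the argument parser instead. A small sketch of that behaviour, assuming a trl version contemporary with this commit (where `dataset_text_field` and `max_seq_length` are `SFTConfig` fields); the output directory is a placeholder:

```python
from transformers import HfArgumentParser
from trl import SFTConfig

# HfArgumentParser turns SFTConfig fields into command-line flags, so the options
# removed from the SFTTrainer(...) call can be supplied at launch time.
parser = HfArgumentParser(SFTConfig)
script_args = parser.parse_args_into_dataclasses(
    args=["--output_dir", "pissa-output", "--dataset_text_field", "text", "--max_seq_length", "512"]
)[0]
print(script_args.dataset_text_field, script_args.max_seq_length)
```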

examples/sft/train.py (16 changes: 7 additions & 9 deletions)

```diff
@@ -4,7 +4,7 @@
 from typing import Optional

 from transformers import HfArgumentParser, TrainingArguments, set_seed
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
 from utils import create_and_prepare_model, create_datasets


@@ -112,6 +112,11 @@ def main(model_args, data_args, training_args):
     if training_args.gradient_checkpointing:
         training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}

+    training_args.dataset_kwargs = {
+        "append_concat_token": data_args.append_concat_token,
+        "add_special_tokens": data_args.add_special_tokens,
+    }
+
     # datasets
     train_dataset, eval_dataset = create_datasets(
         tokenizer,
@@ -128,13 +133,6 @@ def main(model_args, data_args, training_args):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         peft_config=peft_config,
-        packing=data_args.packing,
-        dataset_kwargs={
-            "append_concat_token": data_args.append_concat_token,
-            "add_special_tokens": data_args.add_special_tokens,
-        },
-        dataset_text_field=data_args.dataset_text_field,
-        max_seq_length=data_args.max_seq_length,
     )
     trainer.accelerator.print(f"{trainer.model}")
     if hasattr(trainer.model, "print_trainable_parameters"):
@@ -153,7 +151,7 @@ def main(model_args, data_args, training_args):


 if __name__ == "__main__":
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
```
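
The hunks above move `dataset_kwargs` onto `training_args` but do not show where `packing`, `dataset_text_field`, and `max_seq_length` are now supplied; since `training_args` is parsed as an `SFTConfig`, they are presumably set on the command line or in the JSON config. If one wanted to keep driving them from `data_args`, the same assignment pattern used for `dataset_kwargs` would work (a hypothetical extension, not part of this commit):

```python
# Hypothetical: mirror the dataset_kwargs assignment above for the other options
# that used to be SFTTrainer keyword arguments (assumes these fields remain on
# the script's DataTrainingArguments).
training_args.packing = data_args.packing
training_args.dataset_text_field = data_args.dataset_text_field
training_args.max_seq_length = data_args.max_seq_length
```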
