fix stackllama2 sft gradient checkpointing (#906)
* fix stackllama2 sft gradient checkpointing

* stackllama2 sft use tyro as arg parser
nrailg authored Oct 25, 2023
1 parent 7de7db6 commit 02f5c1d
Showing 2 changed files with 48 additions and 58 deletions.
4 changes: 2 additions & 2 deletions examples/research_projects/stack_llama_2/scripts/README.md
@@ -17,7 +17,7 @@ $ accelerate config
 
 There were two main steps to the DPO training process:
 1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
-    - `accelerate launch examples/stack_llama_2/scripts/sft_llama2.py --output_dir="sft"`
+    - `accelerate launch examples/stack_llama_2/scripts/sft_llama2.py --training_args.output_dir="sft"`
 1. Run the DPO trainer using the model saved by the previous step:
     - `accelerate launch examples/stack_llama_2/scripts/dpo_llama2.py --model_name_or_path="sft/final_checkpoint" --output_dir="dpo"`
 
@@ -48,4 +48,4 @@ model = AutoPeftModelForCausalLM.from_pretrained(
 )
 
 model.generate(...)
-```
+```
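
The README change above follows from the switch to `tyro`: the script's training options now live in a nested `TrainingArguments` dataclass, and tyro derives dotted command-line flags from nested dataclass fields, which is what turns the flat `--output_dir` into `--training_args.output_dir`. A minimal standalone sketch of that mapping (the `TrainConfig`/`Args` names here are illustrative, not from the script, and the exact flag spelling, e.g. hyphens vs. underscores, depends on the tyro version):

```python
from dataclasses import dataclass, field

import tyro


@dataclass
class TrainConfig:
    output_dir: str = "./results"
    max_steps: int = 500


@dataclass
class Args:
    model_name: str = "meta-llama/Llama-2-7b-hf"
    # Nested dataclass: tyro exposes its fields as dotted flags,
    # e.g. `--train-config.output-dir sft` (spelling varies by tyro version).
    train_config: TrainConfig = field(default_factory=TrainConfig)


if __name__ == "__main__":
    # Replaces HfArgumentParser.parse_args_into_dataclasses()[0] in spirit.
    args = tyro.cli(Args)
    print(args.train_config.output_dir)
```

Running the sketch with `--help` lists the nested flags, which is how the dotted spelling in the README command can be discovered.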
102 changes: 46 additions & 56 deletions examples/research_projects/stack_llama_2/scripts/sft_llama2.py
@@ -4,11 +4,12 @@
 from typing import Optional
 
 import torch
+import tyro
 from accelerate import Accelerator
 from datasets import load_dataset
 from peft import AutoPeftModelForCausalLM, LoraConfig
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
 
 from trl import SFTTrainer
 from trl.trainer import ConstantLengthDataset
@@ -17,7 +18,6 @@
 @dataclass
 class ScriptArguments:
     model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
-    log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
 
     dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
     subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
@@ -28,38 +28,53 @@ class ScriptArguments:
     seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
     num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
 
-    max_steps: Optional[int] = field(default=500, metadata={"help": "the maximum number of sgd steps"})
-    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
-    save_steps: Optional[int] = field(default=10, metadata={"help": "the saving frequency"})
-    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "the per device train batch size"})
-    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "the per device eval batch size"})
-    gradient_accumulation_steps: Optional[int] = field(default=2, metadata={"help": "the gradient accumulation steps"})
-    gradient_checkpointing: Optional[bool] = field(
-        default=True, metadata={"help": "whether to use gradient checkpointing"}
+    training_args: TrainingArguments = field(
+        default_factory=lambda: TrainingArguments(
+            output_dir="./results",
+            max_steps=500,
+            logging_steps=10,
+            save_steps=10,
+            per_device_train_batch_size=4,
+            per_device_eval_batch_size=1,
+            gradient_accumulation_steps=2,
+            gradient_checkpointing=False,
+            group_by_length=False,
+            learning_rate=1e-4,
+            lr_scheduler_type="cosine",
+            warmup_steps=100,
+            weight_decay=0.05,
+            optim="paged_adamw_32bit",
+            bf16=True,
+            remove_unused_columns=False,
+            run_name="sft_llama2",
+            report_to="wandb",
+        )
     )
-    group_by_length: Optional[bool] = field(default=False, metadata={"help": "whether to group by length"})
     packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
 
-    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
-    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
-    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
-
-    learning_rate: Optional[float] = field(default=1e-4, metadata={"help": "the learning rate"})
-    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
-    num_warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
-    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
-    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
-    packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
 
-    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
-    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
+    peft_config: LoraConfig = field(
+        default_factory=lambda: LoraConfig(
+            r=8,
+            lora_alpha=16,
+            lora_dropout=0.05,
+            target_modules=["q_proj", "v_proj"],
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+    )
 
 
-parser = HfArgumentParser(ScriptArguments)
-script_args = parser.parse_args_into_dataclasses()[0]
+script_args = tyro.cli(ScriptArguments)
 
-if script_args.group_by_length and script_args.packing:
+if script_args.training_args.group_by_length and script_args.packing:
     raise ValueError("Cannot use both packing and group by length")
 
+# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used.
+# `gradient_checkpointing=True` will cause `Variable._execution_engine.run_backward`.
+if script_args.training_args.gradient_checkpointing:
+    raise ValueError("gradient_checkpointing not supported")
+
 
 def chars_token_ratio(dataset, tokenizer, nb_examples=400):
     """
@@ -155,38 +170,13 @@ def create_datasets(tokenizer, args):
 )
 base_model.config.use_cache = False
 
-peft_config = LoraConfig(
-    r=script_args.lora_r,
-    lora_alpha=script_args.lora_alpha,
-    lora_dropout=script_args.lora_dropout,
-    target_modules=["q_proj", "v_proj"],
-    bias="none",
-    task_type="CAUSAL_LM",
-)
+peft_config = script_args.peft_config
 
 tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
 tokenizer.pad_token = tokenizer.eos_token
 tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training
 
-
-training_args = TrainingArguments(
-    output_dir=script_args.output_dir,
-    per_device_train_batch_size=script_args.per_device_train_batch_size,
-    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
-    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
-    learning_rate=script_args.learning_rate,
-    logging_steps=script_args.logging_steps,
-    max_steps=script_args.max_steps,
-    report_to=script_args.log_with,
-    save_steps=script_args.save_steps,
-    group_by_length=script_args.group_by_length,
-    lr_scheduler_type=script_args.lr_scheduler_type,
-    warmup_steps=script_args.num_warmup_steps,
-    optim=script_args.optimizer_type,
-    bf16=True,
-    remove_unused_columns=False,
-    run_name="sft_llama2",
-)
+training_args = script_args.training_args
 
 train_dataset, eval_dataset = create_datasets(tokenizer, script_args)
 
@@ -201,9 +191,9 @@ def create_datasets(tokenizer, args):
     args=training_args,
 )
 trainer.train()
-trainer.save_model(script_args.output_dir)
+trainer.save_model(script_args.training_args.output_dir)
 
-output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
+output_dir = os.path.join(script_args.training_args.output_dir, "final_checkpoint")
 trainer.model.save_pretrained(output_dir)
 
 # Free memory for merging weights
@@ -213,5 +203,5 @@ def create_datasets(tokenizer, args):
 model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
 model = model.merge_and_unload()
 
-output_merged_dir = os.path.join(script_args.output_dir, "final_merged_checkpoint")
+output_merged_dir = os.path.join(script_args.training_args.output_dir, "final_merged_checkpoint")
 model.save_pretrained(output_merged_dir, safe_serialization=True)
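
For context on the `gradient_checkpointing` guard added above: the script trains a LoRA adapter on a frozen 4-bit base model, and with vanilla gradient checkpointing the recomputed activations do not require grad, so the backward pass fails inside `Variable._execution_engine.run_backward` (typically a RuntimeError along the lines of "element 0 of tensors does not require grad and does not have a grad_fn"). This commit sidesteps the problem by rejecting the flag outright. A commonly used alternative in PEFT setups, shown here as a hedged sketch rather than anything this commit does, is to make the model inputs require grads before enabling checkpointing:

```python
# Hedged sketch of the standard PEFT recipe for k-bit training with gradient
# checkpointing. NOT what this commit does (the commit rejects the flag instead).
import torch
from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Assumption: a 4-bit quantized base model, as in the script above.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", quantization_config=bnb_config, device_map="auto"
)

# Freezes the quantized weights, makes the embedding outputs require grad, and
# enables gradient checkpointing so the checkpointed segments stay differentiable.
base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)

# Roughly equivalent, done by hand:
# base_model.enable_input_require_grads()
# base_model.gradient_checkpointing_enable()
```

Either route gives the checkpointed segments a grad-requiring input to backpropagate through, which is exactly what the guarded configuration lacks.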

5 comments on commit 02f5c1d

@Elfsong
@nrailg Why import the weird library "tyro" into this project? It doesn't work now.

@younesbelkada (Contributor)

@Elfsong can you elaborate more? Why is it not working?

@Elfsong

> @Elfsong can you elaborate more? Why is it not working?

huggingface/transformers#27276. It has been fixed.

@nrailg (Contributor, Author) commented on 02f5c1d Nov 21, 2023

> @nrailg Why import the weird library "tyro" into this project? It doesn't work now.

I did as suggested.

#906

@Elfsong commented on 02f5c1d Dec 8, 2023

> > @nrailg Why import the weird library "tyro" into this project? It doesn't work now.
>
> I did as suggested.
>
> #906

This library does lead to a lot of errors...
