Removing tyro in sft_llama2.py (#1081)
* refactor

* precommit
vwxyzjn authored Dec 11, 2023
1 parent 94fa4b0 commit 393dbf6
Showing 2 changed files with 49 additions and 50 deletions.
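
The change replaces the `tyro` CLI with `HfArgumentParser`: script-specific options stay on the `ScriptArguments` dataclass, the nested `training_args` field becomes a standard top-level `TrainingArguments`, and the LoRA hyperparameters are exposed as flat `--lora_r`, `--lora_alpha`, and `--lora_dropout` flags. A minimal, self-contained sketch of the parsing pattern the diff adopts (the file name `parse_demo.py` is hypothetical, not part of the commit):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})


if __name__ == "__main__":
    # One parser builds a combined CLI from both dataclasses; the parsed
    # instances come back in declaration order.
    parser = HfArgumentParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_into_dataclasses()
    # TrainingArguments has no default output_dir, so invoke as e.g.:
    #   python parse_demo.py --output_dir=./out --lora_r=16
    print(script_args.lora_r, training_args.output_dir)
```

Flags that were nested under `--training_args.*` with tyro (e.g. `--training_args.output_dir`) become plain `--output_dir`-style flags, which is why the README commands below change shape.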
29 changes: 27 additions & 2 deletions examples/research_projects/stack_llama_2/scripts/README.md
@@ -17,9 +17,34 @@ $ accelerate config

There were two main steps to the DPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
- `accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py --training_args.output_dir="sft"`

```
accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \
    --output_dir="./sft" \
    --max_steps=500 \
    --logging_steps=10 \
    --save_steps=10 \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --gradient_checkpointing=False \
    --group_by_length=False \
    --learning_rate=1e-4 \
    --lr_scheduler_type="cosine" \
    --warmup_steps=100 \
    --weight_decay=0.05 \
    --optim="paged_adamw_32bit" \
    --bf16=True \
    --remove_unused_columns=False \
    --run_name="sft_llama2" \
    --report_to="wandb"
```
1. Run the DPO trainer using the model saved by the previous step:
- `accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py --model_name_or_path="sft/final_checkpoint" --output_dir="dpo"`
```
accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \
    --model_name_or_path="sft/final_checkpoint" \
    --output_dir="dpo"
```
## Merging the adaptors
70 changes: 22 additions & 48 deletions examples/research_projects/stack_llama_2/scripts/sft_llama2.py
@@ -4,12 +4,11 @@
from typing import Optional

import torch
import tyro
from accelerate import Accelerator
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import SFTTrainer
from trl.import_utils import is_xpu_available
@@ -19,7 +18,6 @@
@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})

    dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
    subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
    split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
@@ -28,52 +26,31 @@ class ScriptArguments:
    shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
    seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
    num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})

    training_args: TrainingArguments = field(
        default_factory=lambda: TrainingArguments(
            output_dir="./results",
            max_steps=500,
            logging_steps=10,
            save_steps=10,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=1,
            gradient_accumulation_steps=2,
            gradient_checkpointing=False,
            group_by_length=False,
            learning_rate=1e-4,
            lr_scheduler_type="cosine",
            warmup_steps=100,
            weight_decay=0.05,
            optim="paged_adamw_32bit",
            bf16=True,
            remove_unused_columns=False,
            run_name="sft_llama2",
            report_to="wandb",
        )
    )

    packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})

    peft_config: LoraConfig = field(
        default_factory=lambda: LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
    )


script_args = tyro.cli(ScriptArguments)
    # LoraConfig
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})


parser = HfArgumentParser((ScriptArguments, TrainingArguments))
script_args, training_args = parser.parse_args_into_dataclasses()
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

if script_args.training_args.group_by_length and script_args.packing:
if training_args.group_by_length and script_args.packing:
raise ValueError("Cannot use both packing and group by length")

# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used.
# `gradient_checkpointing=True` will cause an error inside `Variable._execution_engine.run_backward`.
if script_args.training_args.gradient_checkpointing:
if training_args.gradient_checkpointing:
raise ValueError("gradient_checkpointing not supported")


@@ -171,14 +148,11 @@ def create_datasets(tokenizer, args):
)
base_model.config.use_cache = False

peft_config = script_args.peft_config

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

training_args = script_args.training_args

train_dataset, eval_dataset = create_datasets(tokenizer, script_args)

trainer = SFTTrainer(
@@ -192,9 +166,9 @@ def create_datasets(tokenizer, args):
    args=training_args,
)
trainer.train()
trainer.save_model(script_args.training_args.output_dir)
trainer.save_model(training_args.output_dir)

output_dir = os.path.join(script_args.training_args.output_dir, "final_checkpoint")
output_dir = os.path.join(training_args.output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)

# Free memory for merging weights
@@ -207,5 +181,5 @@ def create_datasets(tokenizer, args):
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()

output_merged_dir = os.path.join(script_args.training_args.output_dir, "final_merged_checkpoint")
output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint")
model.save_pretrained(output_merged_dir, safe_serialization=True)
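
A usage note (a hedged sketch, not part of the diff): the script saves the LoRA adapter to `<output_dir>/final_checkpoint` and the merged weights to `<output_dir>/final_merged_checkpoint`, so after the README invocation with `--output_dir="./sft"` the merged model loads like an ordinary checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path assumes the README invocation (--output_dir="./sft"). No tokenizer
# is written to final_merged_checkpoint, so load it from the base model.
model = AutoModelForCausalLM.from_pretrained(
    "./sft/final_merged_checkpoint", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
```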
