add dpo example tests
younesbelkada committed Jan 12, 2024
1 parent d82983e commit a0552d5
Showing 4 changed files with 101 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/slow-tests.yml
@@ -95,5 +95,6 @@ jobs:
      run: |
        pip install slack_sdk tabulate
        python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
-       python scripts/log_example_reports.py >> $GITHUB_STEP_SUMMARY
+       python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
+       python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
        rm *.txt
56 changes: 56 additions & 0 deletions commands/run_dpo.sh
@@ -0,0 +1,56 @@
#!/bin/bash
# This script runs a DPO example end-to-end on a tiny model using different possible configurations,
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128

# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""

# This is a hack to get the number of available GPUs
mapfile -t num_gpus < <(nvidia-smi --format=csv --query-gpu=index | tail -n+2 | wc -l)
NUM_GPUS=${num_gpus[0]}

if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi


CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
`pwd`/examples/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""

echo "Starting program..."

{ # try
    echo $CMD
    eval "$CMD"
} || { # catch
    # save log for exception
    echo "Operation Failed!"
    exit 1
}
exit 0
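For illustration only, a couple of hypothetical invocations of the new script (the DeepSpeed config path is an assumption based on the `examples/accelerate_configs` folder referenced in the comments above):

# Default: QLoRA + PEFT on the tiny test model
bash commands/run_dpo.sh

# Assumed DeepSpeed variant; the script then swaps the PEFT args for `--fp16`
TRL_ACCELERATE_CONFIG=examples/accelerate_configs/deepspeed_zero2.yaml bash commands/run_dpo.sh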
46 changes: 40 additions & 6 deletions examples/scripts/dpo.py
@@ -21,11 +21,12 @@
from typing import Dict, Optional

import torch
+from accelerate import PartialState
from datasets import Dataset, load_dataset
from peft import LoraConfig
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

-from trl import DPOTrainer
+from trl import DPOTrainer, is_xpu_available


# Define and parse arguments.
@@ -45,6 +46,9 @@ class ScriptArguments:
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
+    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
+    fp16: Optional[bool] = field(default=False, metadata={"help": "Whether to activate fp16 mixed precision"})
+    bf16: Optional[bool] = field(default=False, metadata={"help": "Whether to activate bf16 mixed precision"})
    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: Optional[int] = field(
@@ -83,6 +87,8 @@ class ScriptArguments:
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
},
)
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})


def extract_anthropic_prompt(prompt_and_response):
@@ -126,16 +132,43 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

+if script_args.load_in_8bit and script_args.load_in_4bit:
+    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
+elif script_args.load_in_8bit or script_args.load_in_4bit:
+    quantization_config = BitsAndBytesConfig(
+        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
+    )
+    # Copy the model to each device
+    device_map = (
+        {"": f"xpu:{PartialState().local_process_index}"}
+        if is_xpu_available()
+        else {"": PartialState().local_process_index}
+    )
+    torch_dtype = torch.bfloat16
+else:
+    device_map = None
+    quantization_config = None
+    torch_dtype = None

# 1. load a pretrained model
-model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
+model = AutoModelForCausalLM.from_pretrained(
+    script_args.model_name_or_path,
+    device_map=device_map,
+    quantization_config=quantization_config,
+    torch_dtype=torch_dtype,
+)

if script_args.ignore_bias_buffers:
    # torch distributed hack
    model._ddp_params_and_buffers_to_ignore = [
        name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
    ]

-model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
+if not script_args.use_peft:
+    model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
+else:
+    # If one uses PEFT, there is no need to load a reference model
+    model_ref = None

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
if tokenizer.pad_token is None:
@@ -158,11 +191,12 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
    logging_first_step=True,
    logging_steps=10,  # match results in blog post
    eval_steps=500,
-    output_dir="./test",
+    output_dir=script_args.output_dir,
    optim="rmsprop",
    warmup_steps=150,
    report_to=script_args.report_to,
-    bf16=True,
+    bf16=script_args.bf16,
+    fp16=script_args.fp16,
    gradient_checkpointing=script_args.gradient_checkpointing,
    # TODO: uncomment that on the next transformers release
    # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
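For reference, a minimal sketch of exercising the new quantization flags directly, reusing the same values as `commands/run_dpo.sh` above (this exact invocation is not part of the diff):

python examples/scripts/dpo.py \
    --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM \
    --output_dir test_dpo/ \
    --max_steps 5 \
    --per_device_train_batch_size 2 \
    --max_length 128 \
    --use_peft \
    --load_in_4bit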
4 changes: 3 additions & 1 deletion scripts/log_example_reports.py
@@ -100,12 +100,14 @@ def main(text_file_name, slack_channel_name=None):
        }
        payload.append(action_button)

+    test_type_name = text_file_name.replace(".txt", "").replace("temp_results_", "")
+
    date_report = {
        "type": "context",
        "elements": [
            {
                "type": "plain_text",
-                "text": f"Nightly {os.environ.get('TEST_TYPE')} test results for {date.today()}",
+                "text": f"Nightly {os.environ.get('TEST_TYPE', '')} {test_type_name} test results for {date.today()}",
            },
        ],
    }
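With this change, the Slack report title embeds the suite name parsed from the results filename. A sketch of the mapping (the `TEST_TYPE` value here is an assumption; the workflow sets it elsewhere):

# temp_results_dpo_tests.txt -> test_type_name == "dpo_tests"
TEST_TYPE="example" python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt
# Title becomes something like: "Nightly example dpo_tests test results for 2024-01-12"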
