
Different fine-tuning speed in DPO task between peft and ms-swift (600 s/iter vs 30 s/iter) #2536

Open
7 of 9 tasks
maoulee opened this issue Jan 2, 2025 · 5 comments
Labels
🏋 DPO Related to DPO 🙋 help from community wanted Open invitation for community members to contribute ⚡ PEFT Related to PEFT

Comments


maoulee commented Jan 2, 2025

System Info

  • transformers version: 4.45.0
  • Platform: Linux-5.15.0-60-generic-x86_64-with-glibc2.31
  • Python version: 3.10.14
  • Huggingface_hub version: 0.25.1
  • Safetensors version: 0.4.5
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100-SXM4-40GB

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Swift CLI Configuration:

USE_HF=1
CUDA_VISIBLE_DEVICES=0,1
swift rlhf \
--rlhf_type dpo \
--model_type qwen2_5 \
--model /root/.cache/modelscope/hub/unsloth/Qwen2___5-32B-Instruct-bnb-4bit/ \
--train_type lora \
--tuner_backend peft \
--dataset llamafactory/ultrafeedback_binarized#2000 \
--num_train_epochs 2 \
--learning_rate 5e-6 \
--lora_rank 8 \
--lora_alpha 32 \
--gradient_accumulation_steps 16 \
--gradient_checkpointing_kwargs '{"use_reentrant": false}' \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--lora_dropout 0.05 \
--logging_steps 100 \
--quant_method bnb \
--quant_bit 4 \
--max_new_tokens 1500

Fine-tuning Speed:

Train: 7%|█████▏ | 17/246 [10:46<2:15:35, 35.53s/it]
Train: 28%|█████████████████████ | 69/246 [41:48<1:47:15, 36.36s/it]

PEFT Configuration:

train_dataset = load_dataset("llamafactory/ultrafeedback_binarized", split="train")
train_dataset = train_dataset.shuffle(seed=42)
train_dataset = train_dataset.select(range(2000))
train_dataset = train_dataset.map(map_instruction)

test_dataset = load_dataset("llamafactory/ultrafeedback_binarized", split="test")
test_dataset = test_dataset.shuffle(seed=42)
test_dataset = test_dataset.select(range(200))
test_dataset = test_dataset.map(map_instruction)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="/root/.cache/modelscope/hub/unsloth/Qwen2___5-32B-Instruct-bnb-4bit/",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    use_cache=False,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)
model.gradient_checkpointing_enable()

peft_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

tokenizer = AutoTokenizer.from_pretrained("/root/.cache/modelscope/hub/unsloth/Qwen2___5-32B-Instruct-bnb-4bit/")
EOS_TOKEN = tokenizer.eos_token

# Tokenizer settings
if tokenizer.chat_template is None:
    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

training_args = DPOConfig(
    output_dir="/llm/checkpoint/",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=16,
    num_train_epochs=4,
    learning_rate=5e-7,
    logging_dir="./logs",
    logging_steps=500,
    save_steps=500,
    eval_strategy="no",
    beta=0.1,
    loss_type="sigmoid",
    optim="adamw_torch",
    max_length=2048,
    max_prompt_length=500
)

model.enable_input_require_grads()
trainer = DPOTrainer(
    model=model,
    args=training_args,
    peft_config=peft_config,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer
)

# Configure generation for evaluation
if training_args.eval_strategy != "no":
    generation_config = GenerationConfig(
        max_new_tokens=2048,
        do_sample=True,
        temperature=1.0
    )
    completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
    trainer.add_callback(completions_callback)

# Train the model
trainer.train()

Fine-tuning Speed:

Could not estimate the number of tokens of the input, floating-point operations will not be computed
0%|▎ | 1/248 [10:40<43:57:17, 640.64s/it]

Expected behavior

Is there something wrong with my peft setup?

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
@August-murr August-murr added 🏋 DPO Related to DPO ⚡ PEFT Related to PEFT labels Jan 3, 2025

qgallouedec commented Jan 8, 2025

It's not very clear what code you're using: you seem to be running a command (swift rlhf) that I'm not familiar with, and the code you provide doesn't take any arguments.
Also, the system info you provide isn't enough (I don't see the trl version, among others). Can you copy-paste the output of trl env?
What is map_instruction? What model are you using? Qwen2 doesn't have a 32B version. Is it Qwen2.5?
Currently it's very hard for me to reproduce this.


maoulee commented Jan 8, 2025

Here is the trl env info:

  • Platform: Linux-5.15.0-60-generic-x86_64-with-glibc2.31
  • Python version: 3.10.14
  • PyTorch version: 2.4.1
  • CUDA device(s): NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB
  • Transformers version: 4.47.1
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • Datasets version: 3.0.1
  • HF Hub version: 0.25.1
  • TRL version: 0.13.0
  • bitsandbytes version: 0.45.0
  • DeepSpeed version: 0.15.4
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.50.2
  • PEFT version: 0.13.2

To print trl env, I upgraded:
TRL: 0.11.4 -> 0.13.0
transformers: 4.45.2 -> 4.47.1
tokenizers: 0.20.3 -> 0.20.1

The map_instruction function maps each example's instruction field to a prompt field.
Here are the model and dataset used:
model: unsloth/Qwen2.5-32B-Instruct-bnb-4bit
dataset: llamafactory/ultrafeedback_binarized

Here is the complete code:
import torch
import os
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    BitsAndBytesConfig,
)
from trl import (
    LogCompletionsCallback,
    ModelConfig,
    DPOConfig,
    DPOTrainer,
    TrlParser,
    get_peft_config,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import data  # local helper module; not used in this script
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
from dataclasses import dataclass, field
from typing import Optional, List, Dict
from datasets import load_dataset

os.environ["WANDB_DISABLED"] = "true"


def map_instruction(example):
    instruction = example['instruction']
    return {'prompt': instruction}


def main():
    train_dataset = load_dataset("/root/.cache/modelscope/hub/datasets/llamafactory___ultrafeedback_binarized/", split="train")
    train_dataset = train_dataset.shuffle(seed=42)
    train_dataset = train_dataset.select(range(2000))
    train_dataset = train_dataset.map(map_instruction)

    test_dataset = load_dataset("/root/.cache/modelscope/hub/datasets/llamafactory___ultrafeedback_binarized/", split="test")
    test_dataset = test_dataset.shuffle(seed=42)
    test_dataset = test_dataset.select(range(200))
    test_dataset = test_dataset.map(map_instruction)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path="/root/.cache/modelscope/hub/unsloth/Qwen2___5-32B-Instruct-bnb-4bit/",
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        use_cache=True,
        device_map="auto",
    )
    model = prepare_model_for_kbit_training(model)
    # model.gradient_checkpointing_enable()
    peft_config = LoraConfig(
        r=4,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    tokenizer = AutoTokenizer.from_pretrained("/root/.cache/modelscope/hub/unsloth/Qwen2___5-32B-Instruct-bnb-4bit/")
    EOS_TOKEN = tokenizer.eos_token

    # Tokenizer settings
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    training_args = DPOConfig(
        output_dir="/llm/checkpoint/",  # Output directory for checkpoints and final model
        per_device_train_batch_size=2,  # Batch size per device during training
        gradient_accumulation_steps=8,  # Number of gradient accumulation steps
        num_train_epochs=4,  # Total number of training epochs
        learning_rate=5e-7,  # Learning rate
        logging_dir="./logs",  # Directory for storing logs
        logging_steps=500,  # Log every X update steps
        save_steps=500,  # Save checkpoint every X update steps
        eval_strategy="no",  # Evaluation is done (and logged) every `eval_steps`
        beta=0.1,  # The beta parameter for DPO loss
        loss_type="sigmoid",
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        max_prompt_length=500,
        max_target_length=1500,
    )
    model.enable_input_require_grads()
    trainer = DPOTrainer(
        model=model,
        args=training_args,
        peft_config=peft_config,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # Configure generation for evaluation
    if training_args.eval_strategy != "no":
        generation_config = GenerationConfig(
            max_new_tokens=2048,
            do_sample=True,
            temperature=1.0,
        )
        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
        trainer.add_callback(completions_callback)

    # Train the model
    trainer.train()

    # Save the final model
    final_model_path = training_args.output_dir
    trainer.save_model(final_model_path)
    print(f"Final model saved to {final_model_path}")


if __name__ == "__main__":
    main()


qgallouedec commented Jan 8, 2025

I was able to reproduce the speed. I don't know how swift differs from trl (it's built upon trl as far as I understand). You should probably ask the swift community here.

@qgallouedec qgallouedec added the 🙋 help from community wanted Open invitation for community members to contribute label Jan 8, 2025

maoulee commented Jan 8, 2025

Thank you for your response. I have identified the key issue:

When I load the model and pass the peft_config directly into DPOTrainer, the fine-tuning speed is 600 seconds per iteration.

However, when I use model = get_peft_model(model, peft_config) before passing it to the trainer, the fine-tuning speed improves significantly to 30.2 seconds per iteration.

The two setups seem logically equivalent, yet the speed difference is large (a sketch of both configurations follows below).
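
For reference, a minimal sketch of the two setups being compared, assuming the same model, tokenizer, datasets, and training_args as in the script above (variable names are illustrative, not taken from my exact run):

# Variant A (~600 s/it in my runs): pass peft_config and let DPOTrainer wrap the 4-bit model
trainer_a = DPOTrainer(
    model=model,                  # quantized base model
    args=training_args,
    peft_config=peft_config,      # trainer applies the LoRA adapter internally
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
)

# Variant B (~30 s/it in my runs): wrap the model with PEFT before handing it to the trainer
peft_model = get_peft_model(model, peft_config)
trainer_b = DPOTrainer(
    model=peft_model,             # already a PeftModel
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
)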


qgallouedec commented Jan 8, 2025

It's probably because when you pass a peft model, it gets merged and unloaded (merge_and_unload). Those two settings should be equivalent though. It's probably an issue with the DPOTrainer. If you manage to fix it, feel free to open a PR.
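
For context, a minimal sketch of what merge_and_unload does (illustrative only, with a hypothetical base_model; this is not DPOTrainer's internal code path):

from peft import LoraConfig, get_peft_model

# Hypothetical base_model for illustration. merge_and_unload() folds the LoRA
# update (lora_B @ lora_A, scaled by lora_alpha / r) into the base weights and
# strips the adapter modules, returning a plain transformers model.
peft_model = get_peft_model(base_model, LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"))
merged_model = peft_model.merge_and_unload()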
