Skip to content

Commit

Permalink
Update KTO example to use better model and ChatML support (#1485)
Browse files Browse the repository at this point in the history
* Update KTO example

* Tweak params

* Fix values

* Fix LoRA params
  • Loading branch information
lewtun authored Mar 27, 2024
1 parent 7ff6206 commit 0ee349d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
32 changes: 18 additions & 14 deletions examples/scripts/kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,38 @@
# Full training:
python examples/scripts/kto.py \
--model_name_or_path=stabilityai/stablelm-2-zephyr-1_6b \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 16 \
--num_train_epochs 1 \
--learning_rate 2e-5 \
--learning_rate 1e-5 \
--lr_scheduler_type=cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="kto-aligned-model" \
--warmup_steps 150 \
--output_dir=kto-aligned-model \
--warmup_ratio 0.1 \
--report_to wandb \
--bf16 \
--logging_first_step
# LoRA:
# QLoRA:
python examples/scripts/kto.py \
--model_name_or_path=stabilityai/stablelm-2-zephyr-1_6b \
--per_device_train_batch_size 16 \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
--learning_rate 2e-4 \
--learning_rate 1e-4 \
--lr_scheduler_type=cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="kto-aligned-model-lora" \
--warmup_steps 150 \
--output_dir=kto-aligned-model-lora \
--warmup_ratio 0.1 \
--report_to wandb \
--bf16 \
--logging_first_step \
--use_peft \
--load_in_4bit \
--lora_target_modules=all-linear \
--lora_r=16 \
--lora_alpha=16
"""
Expand All @@ -54,7 +58,7 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format


# Define and parse arguments.
Expand All @@ -78,10 +82,10 @@ class ScriptArguments:
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# If we are aligning a base model, we use ChatML as the default template
if tokenizer.chat_template is None:
raise ValueError(
"Tokenizer must have a chat template in order to format the examples. Alternatively, adjust this script to format the examples differently."
)
model, tokenizer = setup_chat_format(model, tokenizer)

# Load the dataset
dataset = load_dataset(script_args.dataset_name)
Expand Down
7 changes: 6 additions & 1 deletion trl/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ def setup_chat_format(
model.resize_token_embeddings(
len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
)
# Make sure to update the generation config to use the new eos & bos token
# Update the model config to use the new eos & bos tokens
if getattr(model, "config", None) is not None:
model.config.pad_token_id = tokenizer.pad_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# Update the generation config to use the new eos & bos token
if getattr(model, "generation_config", None) is not None:
model.generation_config.bos_token_id = tokenizer.bos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id
Expand Down

0 comments on commit 0ee349d

Please sign in to comment.