Update TRL README.md to clean up models (#1706)
shepark authored Jan 24, 2025
1 parent fd1878f commit 929dcff
Showing 1 changed file with 2 additions and 95 deletions: examples/trl/README.md
@@ -79,103 +79,10 @@ $ pip install -U -r requirements.txt
### Training
#### For meta-llama/Llama-2-7b-hf
The following example is for the creation of StackLlaMa 2: a Stack Exchange llama-v2-7b model.
There are two main steps to the DPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
```
python ../gaudi_spawn.py --world_size 8 --use_mpi sft.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset_name "lvwerra/stack-exchange-paired" \
--output_dir="./sft" \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
--do_train \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-4 \
--lr_scheduler_type="cosine" \
--warmup_steps=100 \
--weight_decay=0.05 \
--optim="paged_adamw_32bit" \
--lora_target_modules "q_proj" "v_proj" \
--bf16 \
--remove_unused_columns=False \
--run_name="sft_llama2" \
--report_to=none \
--use_habana \
--use_lazy_mode
```
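With `--world_size 8`, `--per_device_train_batch_size=4`, and `--gradient_accumulation_steps=2`, the effective global batch size is 8 × 4 × 2 = 64 samples per optimizer step. The `--lora_target_modules` flag restricts LoRA updates to the attention query and value projections; below is a minimal sketch of the `peft` configuration these flags imply, where the rank, alpha, and dropout values are illustrative assumptions rather than the defaults in `sft.py`:
```
from peft import LoraConfig

# Sketch of the LoRA setup implied by the flags above. The r, lora_alpha,
# and lora_dropout values are illustrative assumptions -- check sft.py
# for the actual defaults.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # matches --lora_target_modules
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
```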
To merge the adapters and obtain the final merged SFT checkpoint, we can use the `merge_peft_adapter.py` helper script that comes with TRL:
```
python merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="sft" --output_name="sft/final_merged_checkpoint"
```
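Merging folds the low-rank adapter deltas back into the base weights, so the result loads like an ordinary checkpoint with no `peft` dependency at inference time. A minimal sketch of the equivalent `peft` calls (paths follow the commands above; the helper script may differ in details):
```
import torch
from peft import AutoPeftModelForCausalLM

# Load the base model with its LoRA adapter attached, fold the adapter
# weights into the base weights, and save a standalone checkpoint.
model = AutoPeftModelForCausalLM.from_pretrained("sft", torch_dtype=torch.bfloat16)
merged = model.merge_and_unload()
merged.save_pretrained("sft/final_merged_checkpoint")
```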
2. Run the DPO trainer using the model saved by the previous step:
```
python ../gaudi_spawn.py --world_size 8 --use_mpi dpo.py \
--model_name_or_path="sft/final_merged_checkpoint" \
--tokenizer_name_or_path=meta-llama/Llama-2-7b-hf \
--lora_target_modules "q_proj" "v_proj" "k_proj" "out_proj" "fc_in" "fc_out" "wte" \
--output_dir="dpo" \
--report_to=none
```
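The DPO step saves a LoRA adapter rather than full model weights. Below is a sketch of loading it for inference on top of the merged SFT checkpoint, assuming the adapter is saved directly under the `--output_dir` shown above (the script may use a subdirectory), with an illustrative prompt:
```
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Attach the DPO-trained LoRA adapter to the merged SFT base model.
base = AutoModelForCausalLM.from_pretrained(
    "sft/final_merged_checkpoint", torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base, "dpo")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

prompt = "Question: How do I merge two dicts in Python?\n\nAnswer: "
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```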
#### For mistralai/Mistral-7B-v0.1
1. Supervised fine-tuning of the base Mistral-7B-v0.1 model:
```
DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 python ../gaudi_spawn.py --world_size 8 --use_deepspeed sft.py \
--model_name_or_path mistralai/Mistral-7B-v0.1 \
--dataset_name "lvwerra/stack-exchange-paired" \
--deepspeed ../language-modeling/llama2_ds_zero3_config.json \
--output_dir="./sft" \
--do_train \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
--per_device_train_batch_size=1 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-4 \
--lr_scheduler_type="cosine" \
--warmup_steps=100 \
--weight_decay=0.05 \
--optim="paged_adamw_32bit" \
--lora_target_modules "q_proj" "v_proj" \
--bf16 \
--remove_unused_columns=False \
--run_name="sft_mistral" \
--report_to=none \
--use_habana \
--use_lazy_mode
```
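ZeRO-3 partitions parameters, gradients, and optimizer states across the eight workers, which is what lets the full model fit in per-card memory, and the `DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1` environment variable is a Gaudi-specific synchronization workaround for ZeRO-3 runs. A rough sketch of the shape of such a config, written as a Python dict purely for illustration; consult `../language-modeling/llama2_ds_zero3_config.json` for the actual values:
```
# Rough shape of a ZeRO-3 DeepSpeed config (illustrative assumption --
# consult ../language-modeling/llama2_ds_zero3_config.json for the real file).
ds_zero3_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,  # partition parameters, gradients, and optimizer states
        "contiguous_gradients": True,
        "overlap_comm": False,
    },
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}
```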
To merge the adapters and obtain the final merged SFT checkpoint, we can use the `merge_peft_adapter.py` helper script that comes with TRL:
```
python merge_peft_adapter.py --base_model_name="mistralai/Mistral-7B-v0.1" --adapter_model_name="sft" --output_name="sft/final_merged_checkpoint"
```
2. Run the DPO trainer using the model saved by the previous step:
```
DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 python ../gaudi_spawn.py --world_size 8 --use_deepspeed dpo.py \
--model_name_or_path="sft/final_merged_checkpoint" \
--tokenizer_name_or_path=mistralai/Mistral-7B-v0.1 \
--deepspeed ../language-modeling/llama2_ds_zero3_config.json \
--lora_target_modules "q_proj" "v_proj" "k_proj" "out_proj" "fc_in" "fc_out" "wte" \
--output_dir="dpo" \
--max_prompt_length=256 \
--max_length=512 \
--report_to=none
```
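Relative to the Llama-2-7b run, this adds `--max_prompt_length=256` and `--max_length=512`, which bound the tokenized prompt and the full prompt-plus-completion sequence, respectively; pairs exceeding these limits are truncated before the DPO loss is computed. A hypothetical helper (not part of the TRL scripts) showing how such a budget check could look:
```
def fits_dpo_limits(tokenizer, prompt, completion,
                    max_prompt_length=256, max_length=512):
    """Hypothetical helper (not part of the TRL scripts): check whether a
    prompt/completion pair fits the --max_prompt_length / --max_length
    budget; pairs that exceed it are truncated by the trainer."""
    prompt_ids = tokenizer(prompt)["input_ids"]
    full_ids = tokenizer(prompt + completion)["input_ids"]
    return len(prompt_ids) <= max_prompt_length and len(full_ids) <= max_length
```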
#### For meta-llama/Llama-2-70b-hf
The following example is for the creation of StackLlaMa 2: a Stack Exchange llama-v2-70b model. There are two main steps to the DPO training process.
For a large model like Llama-2-70B, we can use DeepSpeed ZeRO-3 to enable DPO training across multiple cards. The steps are:
1. Supervised fine-tuning of the base llama-v2-70b model to create llama-v2-70b-se:
