diff --git a/commands/run_sft.sh b/commands/run_sft.sh
index ab67222618..a418bb058c 100644
--- a/commands/run_sft.sh
+++ b/commands/run_sft.sh
@@ -41,7 +41,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
     --dataset_name $DATASET_NAME \
     --output_dir $OUTPUT_DIR \
     --max_steps $MAX_STEPS \
-    --batch_size $BATCH_SIZE \
+    --per_device_train_batch_size $BATCH_SIZE \
     --seq_length $SEQ_LEN \
     $EXTRA_TRAINING_ARGS
 """
diff --git a/docs/source/lora_tuning_peft.mdx b/docs/source/lora_tuning_peft.mdx
index 4b4345bc5f..582f1ac9d8 100644
--- a/docs/source/lora_tuning_peft.mdx
+++ b/docs/source/lora_tuning_peft.mdx
@@ -71,7 +71,7 @@ The `trl` library is powered by `accelerate`. As such it is best to configure an

 ```bash
 accelerate config # will prompt you to define the training configuration
-accelerate launch scripts/gpt2-sentiment_peft.py # launches training
+accelerate launch examples/scripts/ppo.py --use_peft # launches training
 ```

 ## Using `trl` + `peft` and Data Parallelism
@@ -140,5 +140,5 @@ python PATH_TO_SCRIPT
 You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):

 ```bash
-python examples/scripts/sft.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2
+python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
 ```
diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py
index d60af083e1..5072920cb4 100644
--- a/examples/scripts/sft.py
+++ b/examples/scripts/sft.py
@@ -88,6 +88,7 @@ class ScriptArguments:
         quantization_config=quantization_config,
     )
     tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
+    tokenizer.pad_token = tokenizer.eos_token

     ################
     # Dataset
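
Why the `tokenizer.pad_token = tokenizer.eos_token` line matters: many base checkpoints (GPT-2 and Llama 2 among them) ship a tokenizer with no pad token, so any batched-padding step fails until one is assigned, and reusing the EOS token is the common convention for causal LMs. Below is a minimal sketch of the failure and the fix, using `gpt2` as a stand-in checkpoint (the script itself loads whichever model `--model_name` points at):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
assert tokenizer.pad_token is None  # GPT-2 defines no pad token

# Without a pad token, padding a batch raises:
#   ValueError: Asking to pad but the tokenizer does not have a padding token.
# tokenizer(["short", "a longer example"], padding=True)

# The one-line fix mirrored in the patch above:
tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(["short", "a longer example"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # both sequences padded to a common length
```

The `--batch_size` to `--per_device_train_batch_size` renames in the same patch follow the standard `transformers.TrainingArguments` spelling, so the documented commands match the flags the example script actually parses.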