Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Docs] Add unsloth optimizations in TRL's documentation #1119

Merged
merged 2 commits into from
Dec 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions docs/source/sft_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,61 @@ We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssis
</div>

Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains.

### Accelerate fine-tuning 2x using `unsloth`

You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full-finetuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama as well) and Mistral architectures.
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows:

```python
import torch

from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLlamaModel, FastMistralModel

max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# Load Llama model
model, tokenizer = FastLlamaModel.from_pretrained(
model_name = "unsloth/llama-2-7b", # Supports any llama model eg meta-llama/Llama-2-7b-hf
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Do model patching and add fast LoRA weights
model = FastLlamaModel.get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Currently only supports dropout = 0
bias = "none", # Currently only supports bias = "none"
use_gradient_checkpointing = True,
random_state = 3407,
max_seq_length = max_seq_length,
)

args = TrainingArguments(output_dir="./output")

trainer = SFTTrainer(
model = model,
args = args,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
)

trainer.train()
```

The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).

## Best practices

Pay attention to the following best practices when training a model with that trainer:
Expand Down
Loading