From 5b4b637dace8f1d55883bb103994c681d1416486 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Dec 2023 16:26:17 +0000 Subject: [PATCH 1/2] add unsloth --- docs/source/sft_trainer.mdx | 55 +++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx index a6ca3e439c..e4115dbbde 100644 --- a/docs/source/sft_trainer.mdx +++ b/docs/source/sft_trainer.mdx @@ -410,6 +410,61 @@ We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssis Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains. + +### Accelerate fine-tuning using `unsloth` library + +You can further accelerate QLoRA / LoRA and even full-finetuning using [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `SFTTrainer`. Currently `unsloth` supports only Llama and Mistral architectures. +First install `unsloth` according to the official documentation. Once installed you can incorporate unsloth into your workflow with a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows: + +```python +import torch + +from transformers import TrainingArguments +from trl import SFTTrainer +from unsloth import FastLlamaModel, FastMistralModel + +max_seq_length = 2048 # Can change to any number <= 4096 +dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ +load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False. + +# Load Llama model +model, tokenizer = FastLlamaModel.from_pretrained( + model_name = "unsloth/llama-2-7b", # Supports any llama model eg meta-llama/Llama-2-7b-hf + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf +) + +# Do model patching and add fast LoRA weights +model = FastLlamaModel.get_peft_model( + model, + r = 16, + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj",], + lora_alpha = 16, + lora_dropout = 0, # Currently only supports dropout = 0 + bias = "none", # Currently only supports bias = "none" + use_gradient_checkpointing = True, + random_state = 3407, + max_seq_length = max_seq_length, +) + +args = TrainingArguments(output_dir="./output") + +trainer = SFTTrainer( + model = model, + args=args, + train_dataset = dataset, + dataset_text_field = "text", + max_seq_length = max_seq_length, +) + +trainer.train() +``` + +The saved model is compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth). + ## Best practices Pay attention to the following best practices when training a model with that trainer: From a38e1ad786ba9b4d86fc1c438caf44cad56c2b37 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 21 Dec 2023 16:35:05 +0100 Subject: [PATCH 2/2] Update sft_trainer.mdx (#1124) Co-authored-by: Daniel Han --- docs/source/sft_trainer.mdx | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx index e4115dbbde..4c0c1abeac 100644 --- a/docs/source/sft_trainer.mdx +++ b/docs/source/sft_trainer.mdx @@ -411,10 +411,10 @@ We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssis Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains. -### Accelerate fine-tuning using `unsloth` library +### Accelerate fine-tuning 2x using `unsloth` -You can further accelerate QLoRA / LoRA and even full-finetuning using [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `SFTTrainer`. Currently `unsloth` supports only Llama and Mistral architectures. -First install `unsloth` according to the official documentation. Once installed you can incorporate unsloth into your workflow with a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows: +You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full-finetuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama as well) and Mistral architectures. +First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows: ```python import torch @@ -423,7 +423,7 @@ from transformers import TrainingArguments from trl import SFTTrainer from unsloth import FastLlamaModel, FastMistralModel -max_seq_length = 2048 # Can change to any number <= 4096 +max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number. dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False. @@ -454,7 +454,7 @@ args = TrainingArguments(output_dir="./output") trainer = SFTTrainer( model = model, - args=args, + args = args, train_dataset = dataset, dataset_text_field = "text", max_seq_length = max_seq_length, @@ -463,7 +463,7 @@ trainer = SFTTrainer( trainer.train() ``` -The saved model is compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth). +The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth). ## Best practices