Merge pull request #114 from whn09/main
fix README bug, add multi-GPU finetune description
LlamaFamily authored Aug 16, 2023
2 parents a531146 + 2d097dd commit e5e35f9
Showing 1 changed file with 8 additions and 7 deletions.
README.md
@@ -295,7 +295,7 @@ docker-compose up -d --build

#### Step 3: Fine-tuning script

-We provide the fine-tuning script [train/sft/finetune.sh](https://github.com/FlagAlpha/Llama2-Chinese/blob/main/train/sft/finetune.sh); the model can be fine-tuned by modifying some of the script's parameters, and the actual fine-tuning code is in [train/sft/finetune_clm_lora.py](https://github.com/FlagAlpha/Llama2-Chinese/blob/main/train/sft/finetune_clm_lora.py)
+We provide the fine-tuning script [train/sft/finetune.sh](https://github.com/FlagAlpha/Llama2-Chinese/blob/main/train/sft/finetune.sh); the model can be fine-tuned by modifying some of the script's parameters, and the actual fine-tuning code is in [train/sft/finetune_clm_lora.py](https://github.com/FlagAlpha/Llama2-Chinese/blob/main/train/sft/finetune_clm_lora.py). Single-node multi-GPU fine-tuning can be enabled by modifying `--include localhost:0` in the script
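For the multi-GPU change described in the added line above, here is a minimal sketch of what such a launch might look like, assuming finetune.sh invokes the DeepSpeed launcher; the training-script arguments shown (`--model_name_or_path`, `--output_dir`) are illustrative placeholders, not necessarily the script's actual flags:

```bash
# Hypothetical launch sketch (not the repository's exact finetune.sh):
# --include localhost:0 restricts training to GPU 0 of the local machine;
# listing more device indices spreads training across several local GPUs.
deepspeed --include localhost:0,1,2,3 train/sft/finetune_clm_lora.py \
    --model_name_or_path meta-llama/Llama-2-7b-chat \
    --output_dir output_model
```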

### Chinese fine-tuning parameters
We fine-tuned the Llama2-Chat model on Chinese instruction datasets, giving the Llama2 model much stronger Chinese conversational ability. Both the LoRA parameters and the parameters merged with the base model have been uploaded to [Hugging Face](https://huggingface.co/FlagAlpha); 7B and 13B models are currently available.
@@ -308,19 +308,20 @@ docker-compose up -d --build
| Llama2-Chinese-13b-Chat | FlagAlpha/Llama2-Chinese-13b-Chat | meta-llama/Llama-2-13b-chat-hf | [Model download](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) | LoRA parameters from Chinese-instruction fine-tuning, merged with the base model parameters |

### Loading the fine-tuned model
-Load the pretrained model parameters and the fine-tuned parameters via [PEFT](https://github.com/huggingface/peft). In the example code below, base_model_name_or_path is the save path of the pretrained model parameters, and fintune_model_path is the save path of the fine-tuned parameters
+Load the pretrained model parameters and the fine-tuned parameters via [PEFT](https://github.com/huggingface/peft). In the example code below, base_model_name_or_path is the save path of the pretrained model parameters, and finetune_model_path is the save path of the fine-tuned parameters

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
-# e.g.: fintune_model_path='FlagAlpha/Llama2-Chinese-7b-Chat-LoRA'
-fintune_model_path=''
-config = PeftConfig.from_pretrained(fintune_model_path)
+# e.g.: finetune_model_path='FlagAlpha/Llama2-Chinese-7b-Chat-LoRA'
+finetune_model_path=''
+config = PeftConfig.from_pretrained(finetune_model_path)
# e.g.: base_model_name_or_path='meta-llama/Llama-2-7b-chat'
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
-model = LlamaForCausalLM.from_pretrained(config.base_model_name_or_path, device_map='auto', torch_dtype=torch.float16, load_in_8bit=True)
-model = PeftModel.from_pretrained(model, fintune_model_path, device_map={"": 0})
+model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map='auto', torch_dtype=torch.float16, load_in_8bit=True)
+model = PeftModel.from_pretrained(model, finetune_model_path, device_map={"": 0})
model = model.eval()
input_ids = tokenizer(['<s>Human: 介绍一下北京\n</s><s>Assistant: '], return_tensors="pt",add_special_tokens=False).input_ids.to('cuda')
generate_input = {
# ... (remainder of the file omitted in this diff)
```
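Since the diff view cuts the example off inside `generate_input`, here is a hedged sketch of how the generation step typically finishes for a setup like this; all sampling values below are assumptions for illustration, not the repository's exact settings:

```python
# Hypothetical completion of the truncated example (all values assumed):
# build the generation arguments, run generate(), and decode the reply.
generate_input = {
    "input_ids": input_ids,
    "max_new_tokens": 512,        # assumed budget for the reply
    "do_sample": True,            # sample instead of greedy decoding
    "top_k": 50,                  # assumed value
    "top_p": 0.95,                # assumed value
    "temperature": 0.3,           # assumed value
    "repetition_penalty": 1.3,    # assumed value
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}
generate_ids = model.generate(**generate_input)
print(tokenizer.decode(generate_ids[0], skip_special_tokens=True))
```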
