example: add llm-finetune example: baichuan2 #2973

Merged · 1 commit · Nov 21, 2023

1 change: 0 additions & 1 deletion example/LLM/llama/README.md
@@ -24,7 +24,6 @@ Current supports datasets:
- For finetune:
  - openassistant


## Build Starwhale Runtime

```bash
40 changes: 40 additions & 0 deletions example/llm-finetune/README.md
@@ -0,0 +1,40 @@
LLM Finetune
======

LLM finetuning is a state-of-the-art task for adapting large language models.

In these examples, we use Starwhale to finetune a set of LLM base models, then evaluate and release the resulting models. The demos are in the [starwhale/llm-finetuning](https://cloud.starwhale.cn/projects/401/overview) project on Starwhale Cloud.

What we learn
------

- use the `@starwhale.finetune` decorator to define a finetune handler so a Starwhale Model can run the LLM finetune job.
- use the `@starwhale.evaluation.predict` decorator to define a model evaluation for the LLM.
- use the `@starwhale.handler` decorator to define a web handler for LLM online evaluation.
- use one Starwhale Runtime to run all the models.
- build a Starwhale Dataset from Huggingface with a one-line command, no code required. A minimal sketch of how these pieces fit together follows this list.
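
The snippet below is only a minimal sketch of how these pieces fit together; the function bodies are placeholders and the `@finetune` decorator arguments are assumptions, so see the model directories (e.g. `models/baichuan2`) for the real code.

```python
import starwhale
from starwhale import evaluation, finetune, handler


@evaluation.predict(resources={"nvidia.com/gpu": 1}, replicas=1)
def copilot_predict(data: dict) -> str:
    ...  # run inference for one dataset row and return the generated text


@handler(expose=17860)
def chatbot() -> None:
    ...  # serve a web UI (e.g. gradio) for multi-turn online evaluation


@finetune()  # assumed form of the decorator; see the model's finetune.py for the real handler
def lora_finetune(*datasets) -> None:
    ...  # train a LoRA adapter on the given Starwhale datasets


# a build script packages all three entry points into one Starwhale Model
starwhale.model.build(
    name="baichuan2-7b-chat",
    modules=[copilot_predict, chatbot, lora_finetune],
)
```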

Models
------

- [Baichuan2](https://github.com/baichuan-inc/Baichuan2): Baichuan 2 is the new generation of open-source large language models launched by Baichuan Intelligent Technology. It was trained on a high-quality corpus with 2.6 trillion tokens.
- [ChatGLM3](https://github.com/THUDM/ChatGLM3): ChatGLM3 is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series.

Datasets
------

- [Belle multiturn chat](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M): The dataset includes approximately 0.8M Chinese multi-turn dialogs between humans and an assistant, released by the BELLE Group.

```bash
# build the original dataset from Huggingface
swcli dataset build -hf BelleGroup/multiturn_chat_0.8M --name belle-multiturn-chat

# build a random 10k-item dataset from the Baichuan2 repo
swcli dataset build --json https://raw.githubusercontent.com/baichuan-inc/Baichuan2/main/fine-tune/data/belle_chat_ramdon_10k.json --name belle_chat_random_10k
```

- [COIG](https://huggingface.co/datasets/BAAI/COIG): The Chinese Open Instruction Generalist (COIG) project is a harmless, helpful, and diverse set of Chinese instruction corpora.

```bash
swcli dataset build -hf BAAI/COIG --name coig
```
2 changes: 2 additions & 0 deletions example/llm-finetune/models/baichuan2/.gitignore
@@ -0,0 +1,2 @@
pretrain/
.cache/
1 change: 1 addition & 0 deletions example/llm-finetune/models/baichuan2/.swignore
@@ -0,0 +1 @@
.cache/
57 changes: 57 additions & 0 deletions example/llm-finetune/models/baichuan2/README.md
@@ -0,0 +1,57 @@
Baichuan2 Finetune with Starwhale
======

- 🍬 Parameters: 7b
- 🔆 Github: https://github.com/baichuan-inc/Baichuan2
- 🥦 Author: Baichuan Inc.
- 📝 License: baichuan
- 🐱 Starwhale Example: https://github.com/star-whale/starwhale/tree/main/example/llm-finetune/models/baichuan2
- 🌽 Introduction: Baichuan 2 is the new generation of large-scale open-source language models launched by Baichuan Intelligence Inc. It is trained on a high-quality corpus with 2.6 trillion tokens and achieves the best performance among models of the same size on authoritative Chinese and English benchmarks. Baichuan2-7B-Chat is the chat model of the Baichuan 2 series and contains 7 billion parameters.

In this example, we use Baichuan2-7B-Chat as the base model to finetune and evaluate.

- Evaluate the baichuan2-7b-chat model.
- Provide multi-turn chat online evaluation for baichuan2-7b-chat.
- Fine-tune the baichuan2-7b-chat model with the belle-multiturn-chat dataset.
- Evaluate the fine-tuned model.
- Provide multi-turn chat online evaluation for the fine-tuned model.
- Fine-tune the already fine-tuned baichuan2-7b-chat model again.

Thanks to 4-bit quantization, a single T4/A10/A100 GPU card is enough for both evaluation and finetuning.
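
For reference, the sketch below shows the kind of 4-bit loading this relies on; it mirrors the `BitsAndBytesConfig` used in `evaluation.py`, and the local model path is illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# load the base model with 4-bit NF4 quantization so it fits on a single T4/A10/A100 card
model = AutoModelForCausalLM.from_pretrained(
    "pretrain/",  # illustrative local path to the downloaded base model
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    ),
)
```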

Build Starwhale Model
------

```bash
# download the base model weights and build the Starwhale Model package
python3 build.py
```

Run Online Evaluation in the Standalone instance
------

```bash
# run from the source code in the current working directory
swcli model run -w . -m evaluation --handler evaluation:chatbot

# run the packaged Starwhale Model with the llm-finetune runtime
swcli model run --uri baichuan2-7b-chat --handler evaluation:chatbot --runtime llm-finetune
```

Run Starwhale Model for evaluation in the Standalone instance
------

```bash
# copy the z-bench-common dataset from Starwhale Cloud to the local Standalone instance
swcli dataset cp https://cloud.starwhale.cn/projects/401/datasets/161/versions/223/ .

# run evaluation on the first 3 dataset rows
swcli -vvv model run -w . -m evaluation --handler evaluation:copilot_predict --dataset z-bench-common --dataset-head 3
```

Finetune base model
------

```bash
# build the finetune dataset from the Baichuan2 repo
swcli dataset build --json https://raw.githubusercontent.com/baichuan-inc/Baichuan2/main/fine-tune/data/belle_chat_ramdon_10k.json --name belle_chat_random_10k

# run LoRA finetuning on the built dataset
swcli -vvv model run -w . -m finetune --dataset belle_chat_random_10k --handler finetune:lora_finetune
```
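
For orientation, here is a rough sketch of what a LoRA finetune handler can look like. It is not the example's `finetune.py`; the decorator arguments, dataset handling, paths, and `target_modules` are assumptions.

```python
import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM
from starwhale import finetune


@finetune()  # assumed minimal form of the Starwhale finetune decorator
def lora_finetune(train_datasets) -> None:
    model = AutoModelForCausalLM.from_pretrained(
        "pretrain/",  # illustrative path to the downloaded base model
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["W_pack"],  # Baichuan2 packed QKV projection; an assumption
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # ... iterate over train_datasets, tokenize the dialogs, train with transformers.Trainer,
    # then save the adapter so evaluation.py can pick it up
    model.save_pretrained("adapter/")  # illustrative adapter output dir
```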

32 changes: 32 additions & 0 deletions example/llm-finetune/models/baichuan2/build.py
@@ -0,0 +1,32 @@
from huggingface_hub import snapshot_download

import starwhale

try:
    from .utils import BASE_MODEL_DIR
    from .finetune import lora_finetune
    from .evaluation import chatbot, copilot_predict
except ImportError:
    from utils import BASE_MODEL_DIR
    from finetune import lora_finetune
    from evaluation import chatbot, copilot_predict

starwhale.init_logger(3)  # verbose logging


def build_starwhale_model() -> None:
    BASE_MODEL_DIR.mkdir(parents=True, exist_ok=True)

    # download the Baichuan2-7B-Chat weights from Huggingface Hub into the local pretrain dir
    snapshot_download(
        repo_id="baichuan-inc/Baichuan2-7B-Chat",
        local_dir=BASE_MODEL_DIR,
    )

    # package the predict, chat, and finetune entry points into one Starwhale Model
    starwhale.model.build(
        name="baichuan2-7b-chat",
        modules=[copilot_predict, chatbot, lora_finetune],
    )


if __name__ == "__main__":
    build_starwhale_model()
145 changes: 145 additions & 0 deletions example/llm-finetune/models/baichuan2/evaluation.py
@@ -0,0 +1,145 @@
from __future__ import annotations

import os
import typing as t

import torch
import gradio
from peft import PeftModel
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig

from starwhale import handler, evaluation

try:
    from .utils import BASE_MODEL_DIR, ADAPTER_MODEL_DIR
except ImportError:
    from utils import BASE_MODEL_DIR, ADAPTER_MODEL_DIR

# lazily initialized singletons shared by the predict and chat handlers
_g_model = None
_g_tokenizer = None


def _load_model_and_tokenizer() -> t.Tuple:
    global _g_model, _g_tokenizer

    if _g_model is None:
        print(f"load model from {BASE_MODEL_DIR} ...")
        _g_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_DIR,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            load_in_4bit=True,  # for lower gpu memory usage
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            ),
        )
        _g_model.generation_config = GenerationConfig.from_pretrained(BASE_MODEL_DIR)

        # if a LoRA adapter has been produced by finetuning, load it on top of the base model
        if (ADAPTER_MODEL_DIR / "adapter_config.json").exists():
            print(f"load adapter from {ADAPTER_MODEL_DIR} ...")
            _g_model = PeftModel.from_pretrained(
                _g_model, str(ADAPTER_MODEL_DIR), is_trainable=False
            )

    if _g_tokenizer is None:
        print(f"load tokenizer from {BASE_MODEL_DIR} ...")
        _g_tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_DIR, use_fast=False, trust_remote_code=True
        )

    return _g_model, _g_tokenizer


@evaluation.predict(
    resources={"nvidia.com/gpu": 1},
    replicas=1,
    log_mode="plain",
)
def copilot_predict(data: dict) -> str:
    model, tokenizer = _load_model_and_tokenizer()
    # support z-bench-common dataset: https://cloud.starwhale.cn/projects/401/datasets/161/versions/223/files
    messages = [{"role": "user", "content": data["prompt"]}]

    config_dict = model.generation_config.to_dict()
    # TODO: use arguments
    # generation settings can be overridden through environment variables
    config_dict.update(
        max_new_tokens=int(os.environ.get("MAX_MODEL_LENGTH", 512)),
        do_sample=True,
        temperature=float(os.environ.get("TEMPERATURE", 0.7)),
        top_p=float(os.environ.get("TOP_P", 0.9)),
        top_k=int(os.environ.get("TOP_K", 30)),
        repetition_penalty=float(os.environ.get("REPETITION_PENALTY", 1.3)),
    )
    return model.chat(
        tokenizer,
        messages=messages,
        generation_config=GenerationConfig.from_dict(config_dict),
    )


@handler(expose=17860)
def chatbot() -> None:
    with gradio.Blocks() as server:
        chatbot = gradio.Chatbot(height=800)
        msg = gradio.Textbox(label="chat", show_label=True)
        _max_gen_len = gradio.Slider(
            0, 1024, value=256, step=1.0, label="Max Gen Len", interactive=True
        )
        _top_p = gradio.Slider(
            0, 1, value=0.7, step=0.01, label="Top P", interactive=True
        )
        _temperature = gradio.Slider(
            0, 1, value=0.95, step=0.01, label="Temperature", interactive=True
        )
        gradio.ClearButton([msg, chatbot])

        def response(
            from_user: str,
            chat_history: t.List,
            max_gen_len: int,
            top_p: float,
            temperature: float,
        ) -> t.Tuple[str, t.List]:
            # rebuild the multi-turn dialog from the chat history before appending the new message
            dialog = []
            for _user, _assistant in chat_history:
                dialog.append({"role": "user", "content": _user})
                if _assistant:
                    dialog.append({"role": "assistant", "content": _assistant})
            dialog.append({"role": "user", "content": from_user})

            model, tokenizer = _load_model_and_tokenizer()
            from_assistant = model.chat(
                tokenizer,
                messages=dialog,
                generation_config=GenerationConfig(
                    max_new_tokens=max_gen_len,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                ),
            )

            chat_history.append((from_user, from_assistant))
            return "", chat_history

        msg.submit(
            response,
            [msg, chatbot, _max_gen_len, _top_p, _temperature],
            [msg, chatbot],
        )

    server.launch(
        server_name="0.0.0.0",
        server_port=17860,
        share=True,
        root_path=os.environ.get(
            "SW_ONLINE_SERVING_ROOT_PATH"
        ),  # workaround for the embedded web page in starwhale server
    )