Commit 783e5cd (1 parent: df2fb05)

Showing 11 changed files with 490 additions and 1 deletion.
@@ -24,7 +24,6 @@ Current supports datasets:
- For finetune:
  - openassistant

## Build Starwhale Runtime

```bash
@@ -0,0 +1,40 @@
LLM Finetune
======

LLM finetuning is a state-of-the-art task for large language models.

In these examples, we will use Starwhale to finetune a set of LLM base models, then evaluate and release the models. The demos are in the [starwhale/llm-finetuning](https://cloud.starwhale.cn/projects/401/overview) project on Starwhale Cloud.
What we learn
------

- use the `@starwhale.finetune` decorator to define a finetune handler that runs the LLM finetuning for a Starwhale Model (a minimal sketch of the three decorators follows this list).
- use the `@starwhale.evaluation.predict` decorator to define a model evaluation for the LLM.
- use the `@starwhale.handler` decorator to define a web handler for LLM online evaluation.
- use one Starwhale Runtime to run all the models.
- build Starwhale Datasets with a one-line command from Hugging Face, no code required.
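
As a rough, hypothetical sketch of how these three decorators fit together (the handler names, bodies, `prompt` field, and port are illustrative assumptions, not the example's actual code; the real handlers live in each model's `evaluation.py` and `finetune.py`):

```python
import starwhale
from starwhale import evaluation, handler


@starwhale.finetune
def my_finetune() -> None:
    # hypothetical body: a real handler loads the base model, trains it on
    # the mounted Starwhale dataset, and saves the tuned weights
    print("finetuning ...")


@evaluation.predict(replicas=1)
def my_predict(data: dict) -> str:
    # hypothetical echo "model"; the real handlers call model.chat() instead
    return "echo: " + data["prompt"]


@handler(expose=7860)
def my_chatbot() -> None:
    # hypothetical body: a real handler serves a web UI (e.g. gradio) on the exposed port
    print("serving chatbot ...")
```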

Models
------

- [Baichuan2](https://github.com/baichuan-inc/Baichuan2): Baichuan 2 is the new generation of open-source large language models launched by Baichuan Intelligent Technology. It was trained on a high-quality corpus with 2.6 trillion tokens.
- [ChatGLM3](https://github.com/THUDM/ChatGLM3): ChatGLM3 is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series.
Datasets
------

- [Belle multiturn chat](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M): approx. 0.8M Chinese multi-turn dialogs between human and assistant, from the BELLE Group.

```bash
# build the original dataset from Hugging Face
swcli dataset build -hf BelleGroup/multiturn_chat_0.8M --name belle-multiturn-chat

# build the 10k randomly sampled items provided by the Baichuan2 repository
swcli dataset build --json https://raw.githubusercontent.com/baichuan-inc/Baichuan2/main/fine-tune/data/belle_chat_ramdon_10k.json --name belle_chat_random_10k
```

- [COIG](https://huggingface.co/datasets/BAAI/COIG): The Chinese Open Instruction Generalist (COIG) project is a harmless, helpful, and diverse set of Chinese instruction corpora.

```bash
swcli dataset build -hf BAAI/COIG --name coig
```
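
Once built, a dataset can also be consumed from the Starwhale Python SDK. A minimal sketch, assuming the `coig` dataset built above and my reading of the SDK's row API (`head()`, `index`, `features`):

```python
from starwhale import dataset

# load the locally built dataset and peek at the first few rows
ds = dataset("coig")
for row in ds.head(3):
    print(row.index, list(row.features.keys()))
```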
@@ -0,0 +1,2 @@
pretrain/
.cache/
@@ -0,0 +1 @@
.cache/
@@ -0,0 +1,57 @@
Baichuan2 Finetune with Starwhale
======

- 🍬 Parameters: 7b
- 🔆 Github: https://github.com/baichuan-inc/Baichuan2
- 🥦 Author: Baichuan Inc.
- 📝 License: baichuan
- 🐱 Starwhale Example: https://github.com/star-whale/starwhale/tree/main/example/llm-finetune/models/baichuan2
- 🌽 Introduction: Baichuan 2 is the new generation of large-scale open-source language models launched by Baichuan Intelligent Technology. It was trained on a high-quality corpus with 2.6 trillion tokens and achieves the best performance among models of the same size on authoritative Chinese and English benchmarks. Baichuan2-7B-Chat is the chat model of the Baichuan 2 series, with 7 billion parameters.

In this example, we will use Baichuan2-7B-Chat as the base model to finetune and evaluate:

- Evaluate the baichuan2-7b-chat model.
- Provide baichuan2-7b-chat multi-turn chat online evaluation.
- Fine-tune the baichuan2-7b-chat model with the belle-multiturn-chat dataset.
- Evaluate the fine-tuned model.
- Provide multi-turn chat online evaluation for the fine-tuned model.
- Fine-tune the already fine-tuned baichuan2-7b-chat model again.

Thanks to 4-bit quantization, a single T4/A10/A100 GPU card is enough for both evaluation and finetuning (see the loading sketch below).
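
A minimal sketch of the 4-bit loading that makes this possible, mirroring the `BitsAndBytesConfig` used in this example's `evaluation.py` (the Hugging Face repo id here is a stand-in; the example loads from a local directory):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# NF4 4-bit weights with fp16 compute keep the 7B model within a single card's memory
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",  # stand-in: the example uses its local BASE_MODEL_DIR
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    ),
)
```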
Build Starwhale Model
------

```bash
python3 build.py
```

The `build.py` script (shown below) downloads the Baichuan2-7B-Chat weights from Hugging Face and packages them, together with the evaluation and finetune handlers, into a Starwhale Model named `baichuan2-7b-chat`.
Run Online Evaluation in the Standalone instance
------

```bash
# for source code
swcli model run -w . -m evaluation --handler evaluation:chatbot

# for model package with runtime
swcli model run --uri baichuan2-7b-chat --handler evaluation:chatbot --runtime llm-finetune
```
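
Both commands start the gradio chatbot defined in `evaluation.py`; once running, the web UI is served on port 17860, the port exposed by the `chatbot` handler.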
Run Starwhale Model for evaluation in the Standalone instance
------

```bash
# copy the z-bench-common dataset from Starwhale Cloud
swcli dataset cp https://cloud.starwhale.cn/projects/401/datasets/161/versions/223/ .

# evaluate the first 3 rows of the dataset
swcli -vvv model run -w . -m evaluation --handler evaluation:copilot_predict --dataset z-bench-common --dataset-head 3
```
Finetune the base model
------

```bash
# build the finetune dataset from the Baichuan2 repository
swcli dataset build --json https://raw.githubusercontent.com/baichuan-inc/Baichuan2/main/fine-tune/data/belle_chat_ramdon_10k.json --name belle_chat_random_10k

swcli -vvv model run -w . -m finetune --dataset belle_chat_random_10k --handler finetune:lora_finetune
```
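
The `finetune:lora_finetune` handler itself is not part of this diff; as a hedged sketch of the LoRA setup such a handler typically wraps (ranks, dropout, and target modules are illustrative guesses, not the example's actual values):

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# hypothetical LoRA wiring for the base model
base = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",  # stand-in for the local base model dir
    trust_remote_code=True,
)
lora = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["W_pack"],  # assumption: Baichuan2's fused qkv projection
)
model = get_peft_model(base, lora)
model.print_trainable_parameters()  # only the adapter weights are trainable
```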
@@ -0,0 +1,32 @@
from huggingface_hub import snapshot_download

import starwhale

try:
    from .utils import BASE_MODEL_DIR
    from .finetune import lora_finetune
    from .evaluation import chatbot, copilot_predict
except ImportError:
    from utils import BASE_MODEL_DIR
    from finetune import lora_finetune
    from evaluation import chatbot, copilot_predict

starwhale.init_logger(3)


def build_starwhale_model() -> None:
    BASE_MODEL_DIR.mkdir(parents=True, exist_ok=True)

    # download the base model weights from Hugging Face
    snapshot_download(
        repo_id="baichuan-inc/Baichuan2-7B-Chat",
        local_dir=BASE_MODEL_DIR,
    )

    # package the weights and the predict/chat/finetune handlers into a Starwhale Model
    starwhale.model.build(
        name="baichuan2-7b-chat",
        modules=[copilot_predict, chatbot, lora_finetune],
    )


if __name__ == "__main__":
    build_starwhale_model()
@@ -0,0 +1,138 @@
from __future__ import annotations

import os
import typing as t

import torch
import gradio
from peft import PeftModel
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig

from starwhale import handler, evaluation

try:
    from .utils import BASE_MODEL_DIR, ADAPTER_MODEL_DIR
except ImportError:
    from utils import BASE_MODEL_DIR, ADAPTER_MODEL_DIR

_g_model = None
_g_tokenizer = None


def _load_model_and_tokenizer() -> t.Tuple:
    global _g_model, _g_tokenizer

    if _g_model is None:
        print(f"load model from {BASE_MODEL_DIR} ...")
        _g_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_DIR,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            load_in_4bit=True,  # for lower gpu memory usage
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            ),
        )
        _g_model.generation_config = GenerationConfig.from_pretrained(BASE_MODEL_DIR)

        # if finetuning has produced a LoRA adapter, stack it on the base model
        if (ADAPTER_MODEL_DIR / "adapter_config.json").exists():
            print(f"load adapter from {ADAPTER_MODEL_DIR} ...")
            _g_model = PeftModel.from_pretrained(
                _g_model, str(ADAPTER_MODEL_DIR), is_trainable=False
            )

    if _g_tokenizer is None:
        print(f"load tokenizer from {BASE_MODEL_DIR} ...")
        _g_tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_DIR, use_fast=False, trust_remote_code=True
        )

    return _g_model, _g_tokenizer


@evaluation.predict(
    resources={"nvidia.com/gpu": 1},
    replicas=1,
    log_mode="plain",
)
def copilot_predict(data: dict) -> str:
    model, tokenizer = _load_model_and_tokenizer()
    # support z-bench-common dataset: https://cloud.starwhale.cn/projects/401/datasets/161/versions/223/files
    messages = [{"role": "user", "content": data["prompt"]}]

    config_dict = model.generation_config.to_dict()
    # TODO: use arguments
    config_dict.update(
        max_new_tokens=int(os.environ.get("MAX_MODEL_LENGTH", 512)),
        do_sample=True,
        temperature=float(os.environ.get("TEMPERATURE", 0.7)),
        top_p=float(os.environ.get("TOP_P", 0.9)),
        top_k=int(os.environ.get("TOP_K", 30)),
        repetition_penalty=float(os.environ.get("REPETITION_PENALTY", 1.3)),
    )
    return model.chat(
        tokenizer,
        messages=messages,
        generation_config=GenerationConfig.from_dict(config_dict),
    )


@handler(expose=17860)
def chatbot() -> None:
    with gradio.Blocks() as server:
        chatbot = gradio.Chatbot(height=800)
        msg = gradio.Textbox(label="chat", show_label=True)
        _max_gen_len = gradio.Slider(
            0, 1024, value=256, step=1.0, label="Max Gen Len", interactive=True
        )
        _top_p = gradio.Slider(
            0, 1, value=0.7, step=0.01, label="Top P", interactive=True
        )
        _temperature = gradio.Slider(
            0, 1, value=0.95, step=0.01, label="Temperature", interactive=True
        )
        gradio.ClearButton([msg, chatbot])

        def response(
            from_user: str,
            chat_history: t.List,
            max_gen_len: int,
            top_p: float,
            temperature: float,
        ) -> t.Tuple[str, t.List]:
            # rebuild the multi-turn dialog from gradio's (user, assistant) history pairs
            dialog = []
            for _user, _assistant in chat_history:
                dialog.append({"role": "user", "content": _user})
                if _assistant:
                    dialog.append({"role": "assistant", "content": _assistant})
            dialog.append({"role": "user", "content": from_user})

            model, tokenizer = _load_model_and_tokenizer()
            from_assistant = model.chat(
                tokenizer,
                messages=dialog,
                generation_config=GenerationConfig(
                    max_new_tokens=max_gen_len,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                ),
            )

            chat_history.append((from_user, from_assistant))
            return "", chat_history

        msg.submit(
            response,
            [msg, chatbot, _max_gen_len, _top_p, _temperature],
            [msg, chatbot],
        )

    server.launch(server_name="0.0.0.0", server_port=17860, share=True)