Skip to content

Commit

Permalink
Merge pull request #1 from InfiniTensor/demo
Browse files Browse the repository at this point in the history
feat (demo): 单机多轮会话demo
  • Loading branch information
YdrMaster authored Mar 13, 2024
2 parents 530545a + bcba7c0 commit ffc603d
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 42 deletions.
4 changes: 2 additions & 2 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[alias]
xtask = "run --package xtask --release --"
debug = "run --package xtask --"
cast = "xtask cast"
generate = "xtask generate"
service = "xtask service"
chat = "xtask chat"
cast = "xtask cast"
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 17 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@

> 推荐测试模型:[TinyLlama-1.1B-Chat](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
### 启动对话服务

```plaintext
cargo chat --model <model>
```

必要参数:

- `model`: 存放模型文件的目录,至少包含以下 3 个文件:
- `config.json`: 模型配置文件;
- `model.safetensors`: 模型参数文件;
- `tokenizer.model`/`vocab.txt`: 分词器词表;

其他参数参见 `cargo chat --help`

### 启动文本生成

```plaintext
Expand All @@ -17,8 +32,7 @@ cargo generate --model <model> --prompt <prompt>
- `model`: 存放模型文件的目录,至少包含以下 3 个文件:
- `config.json`: 模型配置文件;
- `model.safetensors`: 模型参数文件;
- `tokenizer.model`: 分词器词表;
> 目前仅支持 32000 词 BPE Tonkenizer。
- `tokenizer.model`/`vocab.txt`: 分词器词表;
- `prompt`: 生成文本的起始文本。

其他参数参见 `cargo generate --help`
Expand All @@ -36,7 +50,7 @@ cargo cast --model <model> --dt <date_type>
- `model`: 存放模型文件的目录,至少包含以下 3 个文件:
- `config.json`: 模型配置文件;
- `model.safetensors`: 模型参数文件;
- `tokenizer.model`: 分词器词表;
- `tokenizer.model`/`vocab.txt`: 分词器词表;

生成的模型会存放在 `model` 同级目录下,并添加 `_<date_type>` 后缀。

Expand Down
9 changes: 7 additions & 2 deletions service/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@ impl From<Arc<SessionComponent>> for Session {

impl Session {
#[inline]
pub fn chat(&self, prompt: &str, f: impl FnMut(&str)) {
pub const fn id(&self) -> usize {
self.id
}

#[inline]
pub fn chat(&mut self, prompt: &str, f: impl FnMut(&str)) {
self.send(&self.component.template.apply_chat(prompt), f)
}

#[inline]
pub fn generate(&self, prompt: &str, f: impl FnMut(&str)) {
pub fn generate(&mut self, prompt: &str, f: impl FnMut(&str)) {
self.send(&self.component.template.normalize(prompt), f)
}

Expand Down
1 change: 1 addition & 0 deletions xtask/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ transformer = { path = "../transformer" }
service = { path = "../service" }
log.workspace = true
simple_logger = "4.3"
colored = "2.1"
clap = { version = "4.5", features = ["derive"] }
119 changes: 119 additions & 0 deletions xtask/src/chat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
use crate::{init_logger, service};
use ::service::{Service, Session};
use colored::Colorize;
use std::{collections::HashMap, io::Write};

/// Command-line arguments for the interactive multi-session chat demo.
#[derive(Args, Default)]
pub(crate) struct ChatArgs {
    /// Model directory.
    #[clap(short, long)]
    model: String,
    /// Log level, may be "off", "trace", "debug", "info" or "error".
    #[clap(long)]
    log: Option<String>,

    /// Use Nvidia GPU.
    #[clap(long)]
    nvidia: bool,
}

impl ChatArgs {
    /// Runs the interactive chat REPL: loads the model, opens an initial
    /// session, then reads lines from stdin until EOF or `/exit`.
    ///
    /// Lines starting with `/` are dispatched as commands; anything else
    /// is sent to the active session as a chat prompt.
    pub fn invoke(self) {
        init_logger(self.log);
        let service = service(&self.model, self.nvidia);
        let mut session = service.launch();
        // Suspended sessions, keyed by session id; the active one lives
        // in `session` and is swapped in and out by commands.
        let mut sessions = HashMap::new();

        println!("{}", WELCOME_MSG);
        println!("{}", HELP_MSG);
        println!("=====================================");
        loop {
            println!("{}", format!("会话 {}:", session.id()).yellow());
            let mut input = String::new();
            let n = std::io::stdin()
                .read_line(&mut input)
                .expect("Unable to read line.");
            // read_line returns 0 bytes at EOF (Ctrl+D / closed stdin);
            // without this check the loop would spin forever on EOF.
            if n == 0 {
                return;
            }

            // 以 / 开头则为用户指令
            if input.trim_start().starts_with('/') {
                execute_command(&input, &mut session, &mut sessions, &service);
            } else {
                infer(&input, &mut session);
            }
        }
    }
}

/// Banner printed once at startup.
const WELCOME_MSG: &str = r#"
###########################################
# 欢迎使用九源推理框架-大模型单机对话demo #
###########################################
"#;
/// Command reference, printed at startup and on `/help`.
const HELP_MSG: &str = r#"
/create 新建会话session
/switch [0-9+] 切换至指定会话
/drop [0-9+] 丢弃指定会话
/help 打印帮助信息
使用 /exit 或 Ctrl + C 结束程序
"#;

/// Executes one user command (an input line starting with `/`).
///
/// - `command`: the raw input line, tokenized on whitespace.
/// - `session`: the active session; may be replaced in place.
/// - `sessions`: suspended sessions, keyed by session id.
/// - `service`: used to launch fresh sessions.
fn execute_command(
    command: &str,
    session: &mut Session,
    sessions: &mut HashMap<usize, Session>,
    service: &Service,
) {
    match command.split_whitespace().collect::<Vec<_>>().as_slice() {
        ["/create"] => {
            // Park the current session and activate a brand-new one.
            let old = std::mem::replace(session, service.launch());
            sessions.insert(old.id(), old);
        }
        ["/switch", n] => match n.parse() {
            Ok(target_id) => {
                if target_id == session.id() {
                    println!("Already in session {}", target_id);
                } else if let Some(target) = sessions.remove(&target_id) {
                    // Swap the target in and park the previous session.
                    let old = std::mem::replace(session, target);
                    sessions.insert(old.id(), old);
                } else {
                    println!("Invalid session ID.");
                }
            }
            // Fix: previously reported "Invalid drop command" here,
            // copy-pasted from the `/drop` arm below.
            Err(_) => println!("Invalid switch command"),
        },
        ["/drop", n] => match n.parse() {
            Ok(target_id) => {
                if target_id == session.id() {
                    // Dropping the active session: activate any parked
                    // session, or start fresh if none remain.
                    if let Some((&id, _)) = sessions.iter().next() {
                        *session = sessions.remove(&id).unwrap();
                    } else {
                        *session = service.launch();
                    }
                    println!("Session {target_id} is dropped.")
                } else if sessions.remove(&target_id).is_some() {
                    println!("Session {target_id} is dropped.");
                } else {
                    println!("Invalid session ID.");
                }
            }
            Err(_) => println!("Invalid drop command"),
        },
        ["/help"] => println!("{}", HELP_MSG),
        ["/exit"] => std::process::exit(0),
        _ => println!("Unknown Command"),
    }
    println!("=====================================");
}

/// Streams one chat turn: prints the AI banner, then every generated
/// piece as it arrives, flushing stdout so output appears immediately.
fn infer(text: &str, session: &mut Session) {
    println!("{}", "AI:".green());
    session.chat(text, |piece| {
        // The stream delivers newlines as the two-character escape `\n`;
        // render that token as an actual line break.
        if piece == "\\n" {
            println!();
        } else {
            print!("{piece}");
            std::io::stdout().flush().unwrap();
        }
    });
    println!();
}
40 changes: 5 additions & 35 deletions xtask/src/generate.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use log::LevelFilter;
use service::{Device, Service};
use simple_logger::SimpleLogger;
use std::{io::Write, path::PathBuf};
use crate::{init_logger, service};
use std::io::Write;

#[derive(Args, Default)]
pub(crate) struct GenerateArgs {
Expand All @@ -11,12 +9,6 @@ pub(crate) struct GenerateArgs {
/// Prompt.
#[clap(short, long)]
prompt: String,
/// Tokenizer file.
#[clap(short, long)]
tokenizer: Option<String>,
/// Max steps.
#[clap(short, long)]
step: Option<usize>,
/// Log level, may be "off", "trace", "debug", "info" or "error".
#[clap(long)]
log: Option<String>,
Expand All @@ -28,33 +20,11 @@ pub(crate) struct GenerateArgs {

impl GenerateArgs {
pub fn invoke(self) {
let log = self
.log
.as_ref()
.and_then(|log| match log.to_lowercase().as_str() {
"off" | "none" => Some(LevelFilter::Off),
"trace" => Some(LevelFilter::Trace),
"debug" => Some(LevelFilter::Debug),
"info" => Some(LevelFilter::Info),
"error" => Some(LevelFilter::Error),
_ => None,
})
.unwrap_or(LevelFilter::Warn);
SimpleLogger::new().with_level(log).init().unwrap();
init_logger(self.log);
let service = service(&self.model, self.nvidia);

let model_dir = PathBuf::from(self.model);
print!("{}", self.prompt);
let service = Service::load_model(
model_dir,
if self.nvidia {
Device::NvidiaGpu(0)
} else {
Device::Cpu
},
);

let session = service.launch();
session.generate(&self.prompt, |piece| {
service.launch().generate(&self.prompt, |piece| {
print!("{piece}");
std::io::stdout().flush().unwrap();
});
Expand Down
33 changes: 33 additions & 0 deletions xtask/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod cast;
mod chat;
mod generate;

use ::service::{Device, Service};
use clap::Parser;

#[macro_use]
Expand All @@ -11,6 +13,7 @@ fn main() {
match Cli::parse().command {
Cast(cast) => cast.invode(),
Generate(generate) => generate.invoke(),
Chat(chat) => chat.invoke(),
}
}

Expand All @@ -28,4 +31,34 @@ enum Commands {
Cast(cast::CastArgs),
/// Generate following text
Generate(generate::GenerateArgs),
/// Start an interactive chat session
Chat(chat::ChatArgs),
}

/// Initializes the global logger.
///
/// `log` selects the level, case-insensitively: "off"/"none", "trace",
/// "debug", "info", "warn" or "error". An absent or unrecognized value
/// falls back to `Warn`.
fn init_logger(log: Option<String>) {
    use log::LevelFilter;
    use simple_logger::SimpleLogger;
    let log = log
        .as_ref()
        .and_then(|log| match log.to_lowercase().as_str() {
            "off" | "none" => Some(LevelFilter::Off),
            "trace" => Some(LevelFilter::Trace),
            "debug" => Some(LevelFilter::Debug),
            "info" => Some(LevelFilter::Info),
            // Fix: "warn" was missing — the default level was the one
            // value a user could not request explicitly.
            "warn" => Some(LevelFilter::Warn),
            "error" => Some(LevelFilter::Error),
            _ => None,
        })
        .unwrap_or(LevelFilter::Warn);
    SimpleLogger::new().with_level(log).init().unwrap();
}

/// Loads the model found in `model_dir` and builds the inference
/// service, targeting the first Nvidia GPU when `nvidia` is set and the
/// CPU otherwise.
fn service(model_dir: &str, nvidia: bool) -> Service {
    let device = if nvidia {
        Device::NvidiaGpu(0)
    } else {
        Device::Cpu
    };
    Service::load_model(model_dir, device)
}

0 comments on commit ffc603d

Please sign in to comment.