Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(demo): 整理应用程序并添加 drop 命令
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Mar 13, 2024
1 parent a38d0b6 commit f80bfbf
Showing 2 changed files with 68 additions and 47 deletions.
106 changes: 61 additions & 45 deletions demo/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use colored::Colorize;
use service::{Device, Service};
use std::{env, io::Write};
use service::{Device, Service, Session};
use std::{collections::HashMap, env, io::Write};

const WELCOME_MSG: &str = r#"
###########################################
# 欢迎使用九源推理框架-大模型单机对话demo #
###########################################
"#;
const HELP_MSG: &str = r#"
/create 新建会话session
/switch [0-9+] 切换至指定会话
/help 打印帮助信息
/create 新建会话session
/switch [0-9+] 切换至指定会话
/drop [0-9+] 丢弃指定会话
/help 打印帮助信息
使用 /exit 或 Ctrl + C 结束程序
"#;
@@ -23,71 +24,86 @@ fn main() {
println!("{}", HELP_MSG);
println!("=====================================");

// 当前会话ID,初始是0
let mut session_id: usize = 0;
// 全部会话
let mut sessions: Vec<service::Session> = Vec::new();
// 启动会话0
sessions.push(infer_service.launch());
let mut sessions = HashMap::new();
let mut session = infer_service.launch();

loop {
println!("{}", format!("{}{}:", "会话", &session_id).yellow());
println!("{}", format!("会话 {}:", session.id()).yellow());
let mut input = String::new();
std::io::stdin()
.read_line(&mut input)
.expect("Unable to read line.");
let input = input.trim();

// 以 / 开头则为用户指令
if input.starts_with('/') {
execute_command(input, &mut session_id, &mut sessions, &infer_service);
if input.trim_start().starts_with('/') {
execute_command(&input, &mut session, &mut sessions, &infer_service);
} else {
infer(input, session_id, &mut sessions);
infer(&input, &mut session);
}
}
}

fn execute_command(
command: &str,
session_id: &mut usize,
sessions: &mut Vec<service::Session>,
session: &mut Session,
sessions: &mut HashMap<usize, Session>,
service: &Service,
) {
match command {
"/create" => {
sessions.push(service.launch());
*session_id = sessions.len() - 1;
match command
.trim()
.split_whitespace()
.collect::<Vec<_>>()
.as_slice()
{
["/create"] => {
let old = std::mem::replace(session, service.launch());
sessions.insert(old.id(), old);
}
"/help" => println!("{}", HELP_MSG),
cmd if cmd.starts_with("/switch") => {
let target_id: usize = cmd
.split_whitespace()
.nth(1)
.expect("Invalid switch command")
.parse()
.expect("Invalid switch command");
if target_id < sessions.len() {
*session_id = target_id;
} else {
println!("Invalid session ID.")
["/switch", n] => match n.parse() {
Ok(target_id) => {
if target_id == session.id() {
println!("Already in session {}", target_id);
} else if let Some(target) = sessions.remove(&target_id) {
let old = std::mem::replace(session, target);
sessions.insert(old.id(), old);
} else {
println!("Invalid session ID.");
}
}
}
"/exit" => std::process::exit(0),
Err(_) => println!("Invalid drop command"),
},
["/drop", n] => match n.parse() {
Ok(target_id) => {
if target_id == session.id() {
if let Some((&id, _)) = sessions.iter().next() {
let _ = std::mem::replace(session, sessions.remove(&id).unwrap());
} else {
*session = service.launch();
}
println!("Session {target_id} is dropped.")
} else if sessions.remove(&target_id).is_some() {
println!("Session {target_id} is dropped.");
} else {
println!("Invalid session ID.");
}
}
Err(_) => println!("Invalid drop command"),
},
["/help"] => println!("{}", HELP_MSG),
["/exit"] => std::process::exit(0),
_ => println!("Unknown Command"),
}
println!("=====================================");
}

fn infer(text: &str, session_id: usize, sessions: &mut Vec<service::Session>) {
fn infer(text: &str, session: &mut Session) {
println!("{}", "AI:".green());
sessions[session_id].chat(text, |s| {
match s {
"</s>" => println!(""),
"\\n" => println!(""),
_ => print!("{s}"),
session.chat(text, |s| match s {
"\\n" => println!(),
_ => {
print!("{s}");
std::io::stdout().flush().unwrap();
}

std::io::stdout().flush().unwrap();
});
println!("");
println!();
}
9 changes: 7 additions & 2 deletions service/src/session.rs
Original file line number Diff line number Diff line change
@@ -23,12 +23,17 @@ impl From<Arc<SessionComponent>> for Session {

impl Session {
#[inline]
pub fn chat(&self, prompt: &str, f: impl FnMut(&str)) {
pub const fn id(&self) -> usize {
self.id
}

#[inline]
pub fn chat(&mut self, prompt: &str, f: impl FnMut(&str)) {
self.send(&self.component.template.apply_chat(prompt), f)
}

#[inline]
pub fn generate(&self, prompt: &str, f: impl FnMut(&str)) {
pub fn generate(&mut self, prompt: &str, f: impl FnMut(&str)) {
self.send(&self.component.template.normalize(prompt), f)
}

0 comments on commit f80bfbf

Please sign in to comment.