-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: YdrMaster <ydrml@hotmail.com>
Showing
11 changed files
with
326 additions
and
265 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,112 +1,171 @@ | ||
use crate::InferenceArgs; | ||
use ::service::{Service, Session}; | ||
use colored::Colorize; | ||
use service::{Service, Session}; | ||
use std::{collections::HashMap, io::Write}; | ||
use transformer::SampleArgs; | ||
|
||
#[derive(Args, Default)] | ||
pub(crate) struct ChatArgs { | ||
#[clap(flatten)] | ||
inference: InferenceArgs, | ||
} | ||
|
||
impl ChatArgs { | ||
pub fn invoke(self) { | ||
let service: Service = self.inference.into(); | ||
let mut session = service.launch(); | ||
let mut sessions = HashMap::new(); | ||
impl InferenceArgs { | ||
pub fn chat(self) { | ||
let mut chating = Chating::from(self); | ||
|
||
println!( | ||
"\ | ||
########################################### | ||
# 欢迎使用九源推理框架-大模型单机对话demo # | ||
########################################### | ||
PID = {} | ||
", | ||
std::process::id() | ||
###########################################" | ||
); | ||
println!("{}", HELP_MSG); | ||
chating.print_args(); | ||
println!(); | ||
Chating::print_help(); | ||
println!("====================================="); | ||
|
||
loop { | ||
println!("{}", format!("会话 {}:", session.id()).yellow()); | ||
chating.print_session(); | ||
let mut input = String::new(); | ||
std::io::stdin() | ||
.read_line(&mut input) | ||
.expect("Unable to read line."); | ||
|
||
// 以 / 开头则为用户指令 | ||
if input.trim_start().starts_with('/') { | ||
execute_command(&input, &mut session, &mut sessions, &service); | ||
chating.execute_command(&input); | ||
} else { | ||
infer(&input, &mut session); | ||
chating.infer(&input); | ||
} | ||
} | ||
} | ||
} | ||
|
||
const HELP_MSG: &str = "\ | ||
/create 新建会话session | ||
/switch [0-9+] 切换至指定会话 | ||
/drop [0-9+] 丢弃指定会话 | ||
/help 打印帮助信息 | ||
使用 /exit 或 Ctrl + C 结束程序"; | ||
struct Chating { | ||
service: Service, | ||
sample: SampleArgs, | ||
session: Session, | ||
sessions: HashMap<usize, Session>, | ||
} | ||
|
||
fn execute_command( | ||
command: &str, | ||
session: &mut Session, | ||
sessions: &mut HashMap<usize, Session>, | ||
service: &Service, | ||
) { | ||
match command.split_whitespace().collect::<Vec<_>>().as_slice() { | ||
["/create"] => { | ||
let old = std::mem::replace(session, service.launch()); | ||
sessions.insert(old.id(), old); | ||
impl From<InferenceArgs> for Chating { | ||
fn from(args: InferenceArgs) -> Self { | ||
let service: Service = args.into(); | ||
let session = service.launch(); | ||
let sample = service.sample_args(); | ||
Self { | ||
service, | ||
sample, | ||
session, | ||
sessions: HashMap::new(), | ||
} | ||
["/switch", n] => match n.parse() { | ||
Ok(target_id) => { | ||
if target_id == session.id() { | ||
println!("Already in session {}", target_id); | ||
} else if let Some(target) = sessions.remove(&target_id) { | ||
let old = std::mem::replace(session, target); | ||
sessions.insert(old.id(), old); | ||
} else { | ||
println!("Invalid session ID."); | ||
} | ||
} | ||
} | ||
|
||
impl Chating { | ||
fn print_help() { | ||
println!( | ||
"\ | ||
/create 新建会话session | ||
/switch [0-9+] 切换至指定会话 | ||
/drop [0-9+] 丢弃指定会话 | ||
/args 打印当前参数 | ||
/args key value 设置指定参数 | ||
/help 打印帮助信息 | ||
使用 /exit 或 Ctrl + C 结束程序" | ||
); | ||
} | ||
|
||
fn print_args(&self) { | ||
println!( | ||
"PID = {}, temperature = {}, top-k = {}, top-p = {}", | ||
std::process::id(), | ||
self.sample.temperature, | ||
self.sample.top_k, | ||
self.sample.top_p, | ||
); | ||
} | ||
|
||
fn print_session(&mut self) { | ||
println!("{}", format!("会话 {}:", self.session.id()).yellow()); | ||
} | ||
|
||
fn execute_command(&mut self, command: &str) { | ||
match command.split_whitespace().collect::<Vec<_>>().as_slice() { | ||
["/create"] => { | ||
let old = std::mem::replace(&mut self.session, self.service.launch()); | ||
self.sessions.insert(old.id(), old); | ||
} | ||
Err(_) => println!("Invalid drop command"), | ||
}, | ||
["/drop", n] => match n.parse() { | ||
Ok(target_id) => { | ||
if target_id == session.id() { | ||
if let Some((&id, _)) = sessions.iter().next() { | ||
let _ = std::mem::replace(session, sessions.remove(&id).unwrap()); | ||
["/switch", n] => match n.parse() { | ||
Ok(target_id) => { | ||
if target_id == self.session.id() { | ||
println!("Already in session {}", target_id); | ||
} else if let Some(target) = self.sessions.remove(&target_id) { | ||
let old = std::mem::replace(&mut self.session, target); | ||
self.sessions.insert(old.id(), old); | ||
} else { | ||
println!("Invalid session ID."); | ||
} | ||
} | ||
Err(_) => println!("Invalid drop command"), | ||
}, | ||
["/drop", n] => match n.parse() { | ||
Ok(target_id) => { | ||
if target_id == self.session.id() { | ||
if let Some((&id, _)) = self.sessions.iter().next() { | ||
let _ = std::mem::replace( | ||
&mut self.session, | ||
self.sessions.remove(&id).unwrap(), | ||
); | ||
} else { | ||
self.session = self.service.launch(); | ||
} | ||
println!("Session {target_id} is dropped.") | ||
} else if self.sessions.remove(&target_id).is_some() { | ||
println!("Session {target_id} is dropped."); | ||
} else { | ||
*session = service.launch(); | ||
println!("Invalid session ID."); | ||
} | ||
println!("Session {target_id} is dropped.") | ||
} else if sessions.remove(&target_id).is_some() { | ||
println!("Session {target_id} is dropped."); | ||
} | ||
Err(_) => println!("Invalid drop command"), | ||
}, | ||
["/args"] => self.print_args(), | ||
["/args", "temperature", t] => { | ||
if let Ok(t) = t.parse() { | ||
self.sample.temperature = t; | ||
self.service.set_sample_args(self.sample.clone()); | ||
} else { | ||
println!("Invalid temperature"); | ||
} | ||
} | ||
["/args", "top-k", k] => { | ||
if let Ok(k) = k.parse() { | ||
self.sample.top_k = k; | ||
self.service.set_sample_args(self.sample.clone()); | ||
} else { | ||
println!("Invalid session ID."); | ||
println!("Invalid top-k"); | ||
} | ||
} | ||
Err(_) => println!("Invalid drop command"), | ||
}, | ||
["/help"] => println!("{}", HELP_MSG), | ||
["/exit"] => std::process::exit(0), | ||
_ => println!("Unknown Command"), | ||
["/args", "top-p", p] => { | ||
if let Ok(p) = p.parse() { | ||
self.sample.top_p = p; | ||
self.service.set_sample_args(self.sample.clone()); | ||
} else { | ||
println!("Invalid top-p"); | ||
} | ||
} | ||
["/help"] => Self::print_help(), | ||
["/exit"] => std::process::exit(0), | ||
_ => println!("Unknown Command"), | ||
} | ||
println!("====================================="); | ||
} | ||
println!("====================================="); | ||
} | ||
|
||
fn infer(text: &str, session: &mut Session) { | ||
println!("{}", "AI:".green()); | ||
session.chat(text, |s| match s { | ||
"\\n" => println!(), | ||
_ => { | ||
print!("{s}"); | ||
std::io::stdout().flush().unwrap(); | ||
} | ||
}); | ||
println!(); | ||
fn infer(&mut self, text: &str) { | ||
println!("{}", "AI:".green()); | ||
self.session.chat(text, |s| match s { | ||
"\\n" => println!(), | ||
_ => { | ||
print!("{s}"); | ||
std::io::stdout().flush().unwrap(); | ||
} | ||
}); | ||
println!(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters