Skip to content

Commit

Permalink
feat(xtask): 支持对话过程中修改参数
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Mar 18, 2024
1 parent ff1ae69 commit 1784077
Showing 11 changed files with 326 additions and 265 deletions.
35 changes: 21 additions & 14 deletions service/src/cpu.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
use crate::Command;
use crate::{session, Command};
use common::utok;
use std::{collections::HashMap, path::Path, time::Instant};
use std::{
collections::HashMap,
path::Path,
sync::{Arc, Mutex},
time::Instant,
};
use transformer_cpu::{LayerCache, Memory, Request, SampleArgs, Transformer};

pub struct CpuTask {
transformer: Transformer,
sessions: HashMap<usize, SessionContext>,
sample: SampleArgs,
sample: Arc<Mutex<SampleArgs>>,
}

impl CpuTask {
pub fn new(model_dir: impl AsRef<Path>, sample: SampleArgs) -> Self {
pub fn new(model_dir: impl AsRef<Path>, sample: Arc<Mutex<SampleArgs>>) -> Self {
let time = Instant::now();
let model = Box::new(Memory::load_safetensors_from_dir(model_dir).unwrap());
info!("load model ... {:?}", time.elapsed());
@@ -43,18 +48,20 @@ impl CpuTask {
let eos = self.transformer.eos_token_id();

let time = Instant::now();
let mut token = self
.transformer
.decode(vec![ctx.request(&prompt, max_seq_len)], &self.sample)[0]
.1;
let mut token = self.transformer.decode(
vec![ctx.request(&prompt, max_seq_len)],
&*self.sample.lock().unwrap(),
)[0]
.1;
info!("prefill transformer ... {:?}", time.elapsed());

while token != eos {
responsing.send(token).unwrap();
token = self
.transformer
.decode(vec![ctx.request(&[token], max_seq_len)], &self.sample)[0]
.1;
token = self.transformer.decode(
vec![ctx.request(&[token], max_seq_len)],
&*self.sample.lock().unwrap(),
)[0]
.1;
}
}
Command::Drop { id } => {
@@ -64,12 +71,12 @@ impl CpuTask {
}
}

struct SessionContext(super::SessionContext<LayerCache>);
struct SessionContext(session::SessionContext<LayerCache>);

impl SessionContext {
#[inline]
fn new(transformer: &Transformer, id: usize) -> Self {
Self(super::SessionContext::new(transformer.new_cache(), id))
Self(session::SessionContext::new(transformer.new_cache(), id))
}

#[inline]
55 changes: 14 additions & 41 deletions service/src/lib.rs
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ use std::{
path::Path,
sync::{
mpsc::{channel, Sender},
Arc,
Arc, Mutex,
},
thread::{self, JoinHandle},
};
@@ -26,6 +26,7 @@ extern crate log;

pub struct Service {
session_component: Arc<SessionComponent>,
sample: Arc<Mutex<SampleArgs>>,
_manager: JoinHandle<()>,
}

@@ -39,6 +40,7 @@ pub enum Device {
impl Service {
pub fn load_model(path: impl AsRef<Path>, sample: SampleArgs, device: Device) -> Self {
let model_dir = path.as_ref().to_owned();
let sample = Arc::new(Mutex::new(sample));
let (sender, receiver) = channel();
Service {
session_component: Arc::new(SessionComponent {
@@ -47,6 +49,7 @@ impl Service {
tokenizer: tokenizer(&model_dir),
sender,
}),
sample: sample.clone(),
_manager: thread::spawn(move || match device {
Device::Cpu => {
let mut task = CpuTask::new(model_dir, sample);
@@ -75,6 +78,16 @@ impl Service {
pub fn launch(&self) -> Session {
self.session_component.clone().into()
}

#[inline]
pub fn sample_args(&self) -> SampleArgs {
self.sample.lock().unwrap().clone()
}

#[inline]
pub fn set_sample_args(&self, sample: SampleArgs) {
*self.sample.lock().unwrap() = sample;
}
}

enum Command {
@@ -127,43 +140,3 @@ fn tokenizer(model_dir: impl AsRef<Path>) -> Box<dyn Tokenizer + Send + Sync> {
}
panic!("Tokenizer file not found");
}

/// Per-session inference state: a numeric session id, the token history fed to
/// the model so far, and one `Cache` entry per transformer layer (instantiated
/// with `LayerCache` by the CPU/NVIDIA tasks — see their wrapper types).
struct SessionContext<Cache> {
    // Unique id assigned when the session is launched.
    id: usize,
    // All tokens currently represented in `cache`, oldest first.
    tokens: Vec<utok>,
    // KV-cache storage; presumably one element per model layer — confirm at call sites.
    cache: Vec<Cache>,
}

impl<Cache> SessionContext<Cache> {
    /// Creates a fresh context with an empty token history over `cache`.
    #[inline]
    fn new(cache: Vec<Cache>, id: usize) -> Self {
        Self {
            id,
            tokens: Vec::new(),
            cache,
        }
    }

    /// Appends `tokens` to the history and returns the position (index into the
    /// sequence) at which decoding should resume, i.e. how much of the prefix is
    /// assumed to still be valid in the cache.
    ///
    /// If the combined length would exceed `max_seq_len`, the history is
    /// compacted instead of growing:
    /// - at most the first 16 history tokens are kept as a fixed prefix (`pos`);
    /// - if the new `tokens` alone exceed half the window, only their last
    ///   `max_seq_len / 2` entries are appended after that prefix;
    /// - otherwise up to the last 64 history tokens are slid down to follow the
    ///   prefix (`copy_within`), then the new tokens are appended.
    ///
    /// NOTE(review): the returned `pos` is the point from which the cache is
    /// recomputed after compaction — callers pass it as the request position;
    /// verify this matches the transformer's cache-invalidation contract.
    #[inline]
    fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> usize {
        if self.tokens.len() + tokens.len() > max_seq_len {
            let pos = self.tokens.len().min(16);
            if tokens.len() > max_seq_len / 2 {
                // New input dominates the window: keep only its tail.
                let tokens = &tokens[tokens.len() - max_seq_len / 2..];
                self.tokens.truncate(pos);
                self.tokens.extend_from_slice(tokens);
            } else {
                // Keep the prefix plus a bounded tail of recent history.
                let tail_len = (self.tokens.len() - pos).min(64);
                let tail = self.tokens.len() - tail_len;
                self.tokens.copy_within(tail.., pos);
                self.tokens.truncate(pos + tail_len);
                self.tokens.extend_from_slice(tokens);
            }
            pos
        } else {
            // Fast path: history fits, just append and report the old length.
            let pos = self.tokens.len();
            self.tokens.extend_from_slice(tokens);
            pos
        }
    }
}
19 changes: 12 additions & 7 deletions service/src/nvidia.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use crate::Command;
use crate::{session, Command};
use common::utok;
use std::{
collections::HashMap, fs::File, io::Read, path::Path, sync::mpsc::Receiver, time::Instant,
collections::HashMap,
fs::File,
io::Read,
path::Path,
sync::{mpsc::Receiver, Arc, Mutex},
time::Instant,
};
use transformer_cpu::{Llama2, Memory, SampleArgs};
use transformer_nvidia::{
@@ -11,7 +16,7 @@ use transformer_nvidia::{

pub fn task(
model_dir: impl AsRef<Path>,
sample: SampleArgs,
sample: Arc<Mutex<SampleArgs>>,
receiver: Receiver<Command>,
ctx: &ContextGuard,
) {
@@ -54,7 +59,7 @@ pub fn task(
let time = Instant::now();
let mut token = transformer.decode(
vec![ctx.request(&prompt, max_seq_len)],
&sample,
&*sample.lock().unwrap(),
&compute,
&transfer,
)[0]
@@ -65,7 +70,7 @@ pub fn task(
responsing.send(token).unwrap();
token = transformer.decode(
vec![ctx.request(&[token], max_seq_len)],
&sample,
&*sample.lock().unwrap(),
&compute,
&transfer,
)[0]
@@ -79,12 +84,12 @@ pub fn task(
}
}

struct SessionContext<'a>(super::SessionContext<LayerCache<'a>>);
struct SessionContext<'a>(session::SessionContext<LayerCache<'a>>);

impl<'a> SessionContext<'a> {
#[inline]
fn new(transformer: &Transformer, id: usize, stream: &'a Stream) -> Self {
Self(super::SessionContext::new(
Self(session::SessionContext::new(
transformer.new_cache(stream),
id,
))
41 changes: 41 additions & 0 deletions service/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{template::Template, Command};
use common::utok;
use std::sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
mpsc::{channel, Sender},
@@ -72,3 +73,43 @@ pub(crate) struct SessionComponent {
pub tokenizer: Box<dyn Tokenizer + Send + Sync>,
pub sender: Sender<Command>,
}

/// Per-session inference state shared by the CPU and NVIDIA backends: a numeric
/// session id, the token history fed to the model so far, and one `Cache`
/// entry per transformer layer (instantiated with `LayerCache` by the tasks).
pub(crate) struct SessionContext<Cache> {
    // Unique id assigned when the session is launched.
    pub id: usize,
    // All tokens currently represented in `cache`, oldest first.
    pub tokens: Vec<utok>,
    // KV-cache storage; presumably one element per model layer — confirm at call sites.
    pub cache: Vec<Cache>,
}

impl<Cache> SessionContext<Cache> {
    /// Creates a fresh context with an empty token history over `cache`.
    #[inline]
    pub fn new(cache: Vec<Cache>, id: usize) -> Self {
        Self {
            id,
            tokens: Vec::new(),
            cache,
        }
    }

    /// Appends `tokens` to the history and returns the position (index into the
    /// sequence) at which decoding should resume, i.e. how much of the prefix is
    /// assumed to still be valid in the cache.
    ///
    /// If the combined length would exceed `max_seq_len`, the history is
    /// compacted instead of growing:
    /// - at most the first 16 history tokens are kept as a fixed prefix (`pos`);
    /// - if the new `tokens` alone exceed half the window, only their last
    ///   `max_seq_len / 2` entries are appended after that prefix;
    /// - otherwise up to the last 64 history tokens are slid down to follow the
    ///   prefix (`copy_within`), then the new tokens are appended.
    ///
    /// NOTE(review): the returned `pos` is the point from which the cache is
    /// recomputed after compaction — callers pass it as the request position;
    /// verify this matches the transformer's cache-invalidation contract.
    #[inline]
    pub fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> usize {
        if self.tokens.len() + tokens.len() > max_seq_len {
            let pos = self.tokens.len().min(16);
            if tokens.len() > max_seq_len / 2 {
                // New input dominates the window: keep only its tail.
                let tokens = &tokens[tokens.len() - max_seq_len / 2..];
                self.tokens.truncate(pos);
                self.tokens.extend_from_slice(tokens);
            } else {
                // Keep the prefix plus a bounded tail of recent history.
                let tail_len = (self.tokens.len() - pos).min(64);
                let tail = self.tokens.len() - tail_len;
                self.tokens.copy_within(tail.., pos);
                self.tokens.truncate(pos + tail_len);
                self.tokens.extend_from_slice(tokens);
            }
            pos
        } else {
            // Fast path: history fits, just append and report the old length.
            let pos = self.tokens.len();
            self.tokens.extend_from_slice(tokens);
            pos
        }
    }
}
16 changes: 1 addition & 15 deletions transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
@@ -9,7 +9,6 @@ use tensor::{reslice, slice, udim, DataType, Tensor};

pub type Request<'a, Id> = transformer::Request<'a, Id, Storage>;
pub type LayerCache = transformer::LayerCache<Storage>;
use transformer::{argmax, random};
pub use transformer::{save, Llama2, Memory, SampleArgs};

pub struct Transformer(Box<dyn Llama2>);
@@ -236,20 +235,7 @@ impl Transformer {
requests
.into_iter()
.enumerate()
.map(|(i, r)| {
let logits = &kernel::slice!(logits; voc; [i]);
(
r.id,
match sample {
SampleArgs::Top => argmax(logits),
SampleArgs::Random {
temperature,
top_k,
top_p,
} => random(logits, *temperature, *top_k, *top_p),
},
)
})
.map(|(i, r)| (r.id, sample.random(&kernel::slice!(logits; voc; [i]))))
.collect()
}
}
12 changes: 2 additions & 10 deletions transformer-nvidia/src/lib.rs
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ use kernel::{gather, mat_mul, FusedSoftmax, Reform, RmsNormalization, RotaryEmbe
use parameters::{LayersParameters, ModelParameters};
use storage::Storage;
use tensor::{reslice, slice, udim, DataType, Tensor};
use transformer::{argmax, random, SampleArgs};
use transformer::SampleArgs;

pub type Request<'a, 'b, Id> = transformer::Request<'a, Id, Storage<'b>>;
pub type LayerCache<'a> = transformer::LayerCache<Storage<'a>>;
@@ -295,17 +295,9 @@ impl<'ctx> Transformer<'ctx> {
.into_iter()
.enumerate()
.map(|(i, r)| {
let logits = &logits[i * voc as usize..][..voc as usize];
(
r.id,
match sample {
SampleArgs::Top => argmax(logits),
SampleArgs::Random {
temperature,
top_k,
top_p,
} => random(logits, *temperature, *top_k, *top_p),
},
sample.random(&logits[i * voc as usize..][..voc as usize]),
)
})
.collect()
2 changes: 1 addition & 1 deletion transformer/src/lib.rs
Original file line number Diff line number Diff line change
@@ -12,4 +12,4 @@ pub use cache::LayerCache;
pub use host_memory::HostMemory;
pub use parameters::{save, Llama2, Memory, SafeTensorError};
pub use request::Request;
pub use sample::{argmax, random, SampleArgs};
pub use sample::SampleArgs;
166 changes: 85 additions & 81 deletions transformer/src/sample.rs
Original file line number Diff line number Diff line change
@@ -4,95 +4,99 @@ use common::utok;
use half::f16;

#[derive(Clone, PartialEq, Debug)]
pub enum SampleArgs {
Top,
Random {
temperature: f32,
top_k: usize,
top_p: f32,
},
pub struct SampleArgs {
pub temperature: f32,
pub top_k: usize,
pub top_p: f32,
}

pub fn argmax<T: PartialOrd>(logits: &[T]) -> utok {
logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0 as _
}

pub fn random(logits: &[f16], temperature: f32, top_k: usize, top_p: f32) -> utok {
#[derive(Clone, Copy, PartialEq, Debug)]
struct Probability {
val: f32,
tok: utok,
impl SampleArgs {
#[inline]
fn is_argmax(&self) -> bool {
self.temperature <= 0. || self.top_k < 2 || self.top_p <= 0.
}
impl Eq for Probability {}
impl PartialOrd for Probability {
#[inline]
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))

pub fn random(&self, logits: &[f16]) -> utok {
if self.is_argmax() {
return logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0 as _;
}
}
impl Ord for Probability {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
match self.val.partial_cmp(&other.val).unwrap() {
Ordering::Equal => self.tok.cmp(&other.tok),
ord => ord.reverse(),
}

#[derive(Clone, Copy, PartialEq, Debug)]
struct Probability {
val: f32,
tok: utok,
}
}
// top-k & max
let (logits, max) = {
let mut buf = BinaryHeap::with_capacity(top_k + 1);
let mut max = f32::NEG_INFINITY;
for (i, p) in logits.iter().enumerate() {
let val = p.to_f32();
max = max.max(val);
buf.push(Probability { val, tok: i as _ });
if buf.len() > top_k {
buf.pop();
impl Eq for Probability {}
impl PartialOrd for Probability {
#[inline]
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
(buf.into_vec(), max)
};
// temperature & sum
let (logits, sum) = {
let mut logits = logits;
let mut sum = 0.;
for pi in logits.iter_mut() {
pi.val = ((pi.val - max) / temperature).exp();
sum += pi.val;
impl Ord for Probability {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
match self.val.partial_cmp(&other.val).unwrap() {
Ordering::Equal => self.tok.cmp(&other.tok),
ord => ord.reverse(),
}
}
}
(logits, sum)
};
// top p
let logits = if (0. ..1.).contains(&top_p) {
let i = logits
.iter()
.scan(top_p * sum, |top_p, pi| {
if *top_p > 0. {
*top_p -= pi.val;
Some(())
} else {
None
// top-k & max
let (logits, max) = {
let mut buf = BinaryHeap::with_capacity(self.top_k + 1);
let mut max = f32::NEG_INFINITY;
for (i, p) in logits.iter().enumerate() {
let val = p.to_f32();
max = max.max(val);
buf.push(Probability { val, tok: i as _ });
if buf.len() > self.top_k {
buf.pop();
}
}
(buf.into_vec(), max)
};
// temperature & sum
let (logits, sum) = {
let mut logits = logits;
let mut sum = 0.;
for pi in logits.iter_mut() {
pi.val = ((pi.val - max) / self.temperature).exp();
sum += pi.val;
}
(logits, sum)
};
// top p
let logits = if self.top_p < 1. {
let i = logits
.iter()
.scan(self.top_p * sum, |top_p, pi| {
if *top_p > 0. {
*top_p -= pi.val;
Some(())
} else {
None
}
})
.count();
&logits[..i]
} else {
&logits[..]
};
// random
let mut rand = rand::random::<f32>() * sum;
logits
.iter()
.find(|pi| {
rand -= pi.val;
rand <= 0.
})
.count();
&logits[..i]
} else {
&logits[..]
};
// random
let mut rand = rand::random::<f32>() * sum;
logits
.iter()
.find(|pi| {
rand -= pi.val;
rand <= 0.
})
.unwrap_or(logits.last().unwrap())
.tok
.unwrap_or(logits.last().unwrap())
.tok
}
}
211 changes: 135 additions & 76 deletions xtask/src/chat.rs
Original file line number Diff line number Diff line change
@@ -1,112 +1,171 @@
use crate::InferenceArgs;
use ::service::{Service, Session};
use colored::Colorize;
use service::{Service, Session};
use std::{collections::HashMap, io::Write};
use transformer::SampleArgs;

#[derive(Args, Default)]
pub(crate) struct ChatArgs {
#[clap(flatten)]
inference: InferenceArgs,
}

impl ChatArgs {
pub fn invoke(self) {
let service: Service = self.inference.into();
let mut session = service.launch();
let mut sessions = HashMap::new();
impl InferenceArgs {
pub fn chat(self) {
let mut chating = Chating::from(self);

println!(
"\
###########################################
# 欢迎使用九源推理框架-大模型单机对话demo #
###########################################
PID = {}
",
std::process::id()
###########################################"
);
println!("{}", HELP_MSG);
chating.print_args();
println!();
Chating::print_help();
println!("=====================================");

loop {
println!("{}", format!("会话 {}:", session.id()).yellow());
chating.print_session();
let mut input = String::new();
std::io::stdin()
.read_line(&mut input)
.expect("Unable to read line.");

// 以 / 开头则为用户指令
if input.trim_start().starts_with('/') {
execute_command(&input, &mut session, &mut sessions, &service);
chating.execute_command(&input);
} else {
infer(&input, &mut session);
chating.infer(&input);
}
}
}
}

const HELP_MSG: &str = "\
/create 新建会话session
/switch [0-9+] 切换至指定会话
/drop [0-9+] 丢弃指定会话
/help 打印帮助信息
使用 /exit 或 Ctrl + C 结束程序";
struct Chating {
service: Service,
sample: SampleArgs,
session: Session,
sessions: HashMap<usize, Session>,
}

fn execute_command(
command: &str,
session: &mut Session,
sessions: &mut HashMap<usize, Session>,
service: &Service,
) {
match command.split_whitespace().collect::<Vec<_>>().as_slice() {
["/create"] => {
let old = std::mem::replace(session, service.launch());
sessions.insert(old.id(), old);
impl From<InferenceArgs> for Chating {
fn from(args: InferenceArgs) -> Self {
let service: Service = args.into();
let session = service.launch();
let sample = service.sample_args();
Self {
service,
sample,
session,
sessions: HashMap::new(),
}
["/switch", n] => match n.parse() {
Ok(target_id) => {
if target_id == session.id() {
println!("Already in session {}", target_id);
} else if let Some(target) = sessions.remove(&target_id) {
let old = std::mem::replace(session, target);
sessions.insert(old.id(), old);
} else {
println!("Invalid session ID.");
}
}
}

impl Chating {
fn print_help() {
println!(
"\
/create 新建会话session
/switch [0-9+] 切换至指定会话
/drop [0-9+] 丢弃指定会话
/args 打印当前参数
/args key value 设置指定参数
/help 打印帮助信息
使用 /exit 或 Ctrl + C 结束程序"
);
}

fn print_args(&self) {
println!(
"PID = {}, temperature = {}, top-k = {}, top-p = {}",
std::process::id(),
self.sample.temperature,
self.sample.top_k,
self.sample.top_p,
);
}

fn print_session(&mut self) {
println!("{}", format!("会话 {}:", self.session.id()).yellow());
}

fn execute_command(&mut self, command: &str) {
match command.split_whitespace().collect::<Vec<_>>().as_slice() {
["/create"] => {
let old = std::mem::replace(&mut self.session, self.service.launch());
self.sessions.insert(old.id(), old);
}
Err(_) => println!("Invalid drop command"),
},
["/drop", n] => match n.parse() {
Ok(target_id) => {
if target_id == session.id() {
if let Some((&id, _)) = sessions.iter().next() {
let _ = std::mem::replace(session, sessions.remove(&id).unwrap());
["/switch", n] => match n.parse() {
Ok(target_id) => {
if target_id == self.session.id() {
println!("Already in session {}", target_id);
} else if let Some(target) = self.sessions.remove(&target_id) {
let old = std::mem::replace(&mut self.session, target);
self.sessions.insert(old.id(), old);
} else {
println!("Invalid session ID.");
}
}
Err(_) => println!("Invalid drop command"),
},
["/drop", n] => match n.parse() {
Ok(target_id) => {
if target_id == self.session.id() {
if let Some((&id, _)) = self.sessions.iter().next() {
let _ = std::mem::replace(
&mut self.session,
self.sessions.remove(&id).unwrap(),
);
} else {
self.session = self.service.launch();
}
println!("Session {target_id} is dropped.")
} else if self.sessions.remove(&target_id).is_some() {
println!("Session {target_id} is dropped.");
} else {
*session = service.launch();
println!("Invalid session ID.");
}
println!("Session {target_id} is dropped.")
} else if sessions.remove(&target_id).is_some() {
println!("Session {target_id} is dropped.");
}
Err(_) => println!("Invalid drop command"),
},
["/args"] => self.print_args(),
["/args", "temperature", t] => {
if let Ok(t) = t.parse() {
self.sample.temperature = t;
self.service.set_sample_args(self.sample.clone());
} else {
println!("Invalid temperature");
}
}
["/args", "top-k", k] => {
if let Ok(k) = k.parse() {
self.sample.top_k = k;
self.service.set_sample_args(self.sample.clone());
} else {
println!("Invalid session ID.");
println!("Invalid top-k");
}
}
Err(_) => println!("Invalid drop command"),
},
["/help"] => println!("{}", HELP_MSG),
["/exit"] => std::process::exit(0),
_ => println!("Unknown Command"),
["/args", "top-p", p] => {
if let Ok(p) = p.parse() {
self.sample.top_p = p;
self.service.set_sample_args(self.sample.clone());
} else {
println!("Invalid top-p");
}
}
["/help"] => Self::print_help(),
["/exit"] => std::process::exit(0),
_ => println!("Unknown Command"),
}
println!("=====================================");
}
println!("=====================================");
}

fn infer(text: &str, session: &mut Session) {
println!("{}", "AI:".green());
session.chat(text, |s| match s {
"\\n" => println!(),
_ => {
print!("{s}");
std::io::stdout().flush().unwrap();
}
});
println!();
fn infer(&mut self, text: &str) {
println!("{}", "AI:".green());
self.session.chat(text, |s| match s {
"\\n" => println!(),
_ => {
print!("{s}");
std::io::stdout().flush().unwrap();
}
});
println!();
}
}
14 changes: 7 additions & 7 deletions xtask/src/generate.rs
Original file line number Diff line number Diff line change
@@ -5,18 +5,18 @@ use std::io::Write;
#[derive(Args, Default)]
pub(crate) struct GenerateArgs {
#[clap(flatten)]
inference: InferenceArgs,
pub inference: InferenceArgs,
/// Prompt.
#[clap(short, long)]
prompt: String,
pub prompt: String,
}

impl GenerateArgs {
pub fn invoke(self) {
let service: Service = self.inference.into();
impl InferenceArgs {
pub fn generate(self, prompt: &str) {
let service: Service = self.into();

print!("{}", self.prompt);
service.launch().generate(&self.prompt, |piece| {
print!("{prompt}");
service.launch().generate(prompt, |piece| {
print!("{piece}");
std::io::stdout().flush().unwrap();
});
20 changes: 7 additions & 13 deletions xtask/src/main.rs
Original file line number Diff line number Diff line change
@@ -13,8 +13,8 @@ fn main() {
use Commands::*;
match Cli::parse().command {
Cast(cast) => cast.invode(),
Generate(generate) => generate.invoke(),
Chat(chat) => chat.invoke(),
Generate(args) => args.inference.generate(&args.prompt),
Chat(chat) => chat.chat(),
}
}

@@ -33,7 +33,7 @@ enum Commands {
/// Generate following text
Generate(generate::GenerateArgs),
/// Start service
Chat(chat::ChatArgs),
Chat(InferenceArgs),
}

#[derive(Args, Default)]
@@ -85,19 +85,13 @@ impl From<InferenceArgs> for Service {
.unwrap_or(LevelFilter::Warn);
SimpleLogger::new().with_level(log).init().unwrap();

let sample = if temperature <= 0. || top_k < 2 || top_p <= 0. {
SampleArgs::Top
} else {
SampleArgs::Random {
Service::load_model(
model,
SampleArgs {
temperature,
top_k,
top_p,
}
};

Service::load_model(
model,
sample,
},
if nvidia {
Device::NvidiaGpu(0)
} else {

0 comments on commit 1784077

Please sign in to comment.