diff --git a/Cargo.lock b/Cargo.lock
index 5a35e9f3..e44053d1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2157,16 +2157,11 @@ dependencies = [
  "causal-lm",
  "colored",
  "common",
- "distributed",
- "half",
  "log",
- "search-cuda-tools",
  "tensor",
  "tokenizer",
  "tokio",
- "transformer",
  "transformer-cpu",
- "transformer-nv",
 ]

 [[package]]
@@ -2522,7 +2517,7 @@ dependencies = [

 [[package]]
 name = "transformer-cpu"
-version = "0.0.0"
+version = "0.0.1"
 dependencies = [
  "causal-lm",
  "common",
@@ -2704,13 +2699,13 @@
 name = "web-api"
 version = "0.0.1"
 dependencies = [
  "actix-web",
+ "causal-lm",
  "futures",
  "log",
  "serde",
  "serde_json",
  "service",
  "tokio",
- "transformer",
 ]

 [[package]]
@@ -2746,9 +2741,9 @@ dependencies = [

 [[package]]
 name = "winapi-util"
-version = "0.1.7"
+version = "0.1.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "134306a13c5647ad6453e8deaec55d3a44d6021970129e6188735e74bf546697"
+checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
 dependencies = [
  "windows-sys 0.52.0",
 ]
@@ -2925,6 +2920,7 @@ dependencies = [
 name = "xtask"
 version = "0.0.0"
 dependencies = [
+ "causal-lm",
  "clap",
  "colored",
  "common",
@@ -2934,6 +2930,7 @@ dependencies = [
  "tensor",
  "tokio",
  "transformer",
+ "transformer-cpu",
  "web-api",
 ]

 [[package]]
@@ -2959,9 +2956,9 @@ dependencies = [

 [[package]]
 name = "zeroize"
-version = "1.7.0"
+version = "1.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
+checksum = "63381fa6624bf92130a6b87c0d07380116f80b565c42cf0d754136f0238359ef"

 [[package]]
 name = "zstd"
diff --git a/nvidia/common/src/lib.rs b/nvidia/common/src/lib.rs
index a4dd269d..169b59d7 100644
--- a/nvidia/common/src/lib.rs
+++ b/nvidia/common/src/lib.rs
@@ -247,3 +247,12 @@ where
         })
     }
 }
+
+pub fn synchronize() {
+    cuda::init();
+    for i in 0..cuda::Device::count() {
+        cuda::Device::new(i as _)
+            .retain_primary()
+            .apply(|ctx| ctx.synchronize());
+    }
+}
diff --git a/service/Cargo.toml b/service/Cargo.toml
index ee1cba37..68d710bd 100644
--- a/service/Cargo.toml
+++ b/service/Cargo.toml
@@ -11,20 +11,9 @@ common = { path = "../common" }
 tensor = { path = "../tensor" }
 tokenizer = { path = "../tokenizer" }
 causal-lm = { path = "../causal-lm" }
-transformer = { path = "../transformer" }
-transformer-cpu = { path = "../transformer-cpu" }
-transformer-nv = { path = "../nvidia/transformer", optional = true }
-distributed = { path = "../nvidia/distributed", optional = true }
-half.workspace = true
 log.workspace = true
 tokio.workspace = true

-[build-dependencies]
-search-cuda-tools.workspace = true
-
 [dev-dependencies]
 colored = "2.1"
-tokio = { workspace = true, features = ["time"] }
-
-[features]
-nvidia = ["transformer-nv", "distributed"]
+transformer-cpu = { path = "../transformer-cpu" }
diff --git a/service/build.rs b/service/build.rs
deleted file mode 100644
index 68c62630..00000000
--- a/service/build.rs
+++ /dev/null
@@ -1,12 +0,0 @@
-fn main() {
-    use search_cuda_tools::*;
-    if !cfg!(feature = "nvidia") {
-        return;
-    }
-    if find_cuda_root().is_some() {
-        detect_cuda();
-    }
-    if find_nccl_root().is_some() {
-        detect_nccl();
-    }
-}
diff --git a/service/src/batcher.rs b/service/src/batcher.rs
index 4c1558b1..de636294 100644
--- a/service/src/batcher.rs
+++ b/service/src/batcher.rs
@@ -1,18 +1,10 @@
-use crate::session::SessionContext;
-use common::utok;
-use std::sync::{
+use std::sync::{
     atomic::{
         AtomicBool,
         Ordering::{Acquire, Release},
     },
     Condvar, Mutex,
 };
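// The hunk above trims batcher.rs down to std-only imports, but the diff cuts
// off before the struct body. As a rough sketch -- not necessarily this crate's
// actual implementation -- a Batcher consistent with how it is used elsewhere
// in this diff (new()/enq()/deq()/shutdown(), with deq() blocking until work
// arrives and returning an empty Vec once shut down) could look like this:

pub struct Batcher<T> {
    queue: Mutex<Vec<T>>,
    condvar: Condvar,
    shutdown: AtomicBool,
}

impl<T> Batcher<T> {
    pub fn new() -> Self {
        Self {
            queue: Mutex::new(Vec::new()),
            condvar: Condvar::new(),
            shutdown: AtomicBool::new(false),
        }
    }

    /// Queue a task and wake the inference thread.
    pub fn enq(&self, task: T) {
        self.queue.lock().unwrap().push(task);
        self.condvar.notify_one();
    }

    /// Block until at least one task is queued; an empty Vec signals shutdown.
    pub fn deq(&self) -> Vec<T> {
        let mut queue = self.queue.lock().unwrap();
        loop {
            if self.shutdown.load(Acquire) {
                return Vec::new();
            }
            if !queue.is_empty() {
                return std::mem::take(&mut *queue);
            }
            queue = self.condvar.wait(queue).unwrap();
        }
    }

    pub fn shutdown(&self) {
        self.shutdown.store(true, Release);
        self.condvar.notify_all();
    }
}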
-use tokio::sync::mpsc::UnboundedSender; - -pub struct Task { - pub ctx: SessionContext, - pub responsing: UnboundedSender, -} pub struct Batcher { queue: Mutex>, diff --git a/service/src/cpu.rs b/service/src/cpu.rs deleted file mode 100644 index ce9ef63e..00000000 --- a/service/src/cpu.rs +++ /dev/null @@ -1,15 +0,0 @@ -use std::{path::Path, time::Instant}; -use transformer::Memory; -use transformer_cpu::Transformer; - -pub fn transformer(model_dir: impl AsRef) -> Transformer { - let time = Instant::now(); - let model = Memory::load_safetensors(model_dir).unwrap(); - info!("load model ... {:?}", time.elapsed()); - - let time = Instant::now(); - let transformer = Transformer::new(model); - info!("build transformer ... {:?}", time.elapsed()); - - transformer -} diff --git a/service/src/dispatch.rs b/service/src/dispatch.rs deleted file mode 100644 index c67aab8b..00000000 --- a/service/src/dispatch.rs +++ /dev/null @@ -1,154 +0,0 @@ -use crate::{ - batcher::{Batcher, Task}, - session::{Command, SessionContext}, -}; -use std::{ - collections::{hash_map::Entry, HashMap, HashSet}, - sync::{Arc, Mutex}, -}; -use tokio::{ - sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, - task::{spawn_blocking, JoinSet}, -}; -use transformer::{SampleArgs, Transformer}; - -pub fn run( - transformer: T, - sample: Arc>, - commands: UnboundedReceiver, -) -> JoinSet<()> -where - T: Transformer + Send + Sync + 'static, - T::Cache: Send + 'static, -{ - let mut dispatcher = Dispatcher::new(transformer, Batcher::new()); - let messages = dispatcher.manage(); - dispatcher.forward(messages.clone(), commands); - dispatcher.decode(sample, messages); - - dispatcher.set -} - -struct Dispatcher { - transformer: Arc, - batcher: Arc>>, - set: JoinSet<()>, -} - -enum Message { - Cmd(Command), - Ctx(SessionContext), -} - -impl Dispatcher -where - T: Transformer + Send + Sync + 'static, - T::Cache: Send + 'static, -{ - #[inline] - pub fn new(transformer: T, batcher: Batcher>) -> Self { - Dispatcher { - transformer: Arc::new(transformer), - batcher: Arc::new(batcher), - set: JoinSet::new(), - } - } - - pub fn forward( - &mut self, - messages: UnboundedSender>, - mut commands: UnboundedReceiver, - ) { - self.set.spawn(async move { - while let Some(msg) = commands.recv().await { - messages.send(Message::Cmd(msg)).unwrap(); - } - }); - } - - pub fn manage(&mut self) -> UnboundedSender> { - let (sender, mut receiver) = unbounded_channel(); - let max_seq_len = self.transformer.max_position_embeddings(); - let transformer = self.transformer.clone(); - let batcher = self.batcher.clone(); - self.set.spawn(async move { - let mut sessions = HashMap::new(); - let mut removing = HashSet::new(); - while let Some(msg) = receiver.recv().await { - match msg { - Message::Cmd(Command::Infer(id, infer)) => { - let mut ctx = match sessions.entry(id) { - Entry::Occupied(ctx) => ctx.remove(), - Entry::Vacant(_) => SessionContext::new(transformer.new_cache(), id), - }; - ctx.push(&infer.prompt, max_seq_len); - batcher.enq(Task { - ctx, - responsing: infer.responsing, - }); - } - Message::Cmd(Command::Fork(id, new_id)) => { - warn!("TODO: fork session {id} to {new_id}"); - } - Message::Cmd(Command::Drop(id)) => { - if sessions.remove(&id).is_none() { - removing.insert(id); - } - } - Message::Ctx(ctx) => { - if !removing.remove(&ctx.id) { - sessions.insert(ctx.id, ctx); - } - } - } - } - }); - sender - } - - pub fn decode( - &mut self, - sample: Arc>, - sender: UnboundedSender>, - ) { - let max_seq_len = 
self.transformer.max_position_embeddings(); - let eos = self.transformer.eos_token(); - let transformer = self.transformer.clone(); - let batcher = self.batcher.clone(); - self.set.spawn_blocking(move || loop { - let mut tasks = batcher.deq(); - - let requests = tasks - .iter_mut() - .map(|task| task.ctx.request(usize::MAX)) - .collect::>(); - - let (requests, logits) = transformer.decode(requests); - let transformer = transformer.clone(); - let sender = sender.clone(); - let batcher = batcher.clone(); - let sample = sample.clone(); - spawn_blocking(move || { - let tokens = transformer - .sample(&sample.lock().unwrap(), requests, logits) - .into_iter() - .collect::>(); - - for mut task in tasks { - match tokens.get(&task.ctx.id) { - Some(&token) => { - if token != eos && task.responsing.send(token).is_ok() { - task.ctx.push(&[token], max_seq_len); - batcher.enq(task); - } else { - task.ctx.push(&[eos], max_seq_len); - sender.send(Message::Ctx(task.ctx)).unwrap(); - } - } - None => batcher.enq(task), - }; - } - }); - }); - } -} diff --git a/service/src/lib.rs b/service/src/lib.rs index 147d70e9..898c66dc 100644 --- a/service/src/lib.rs +++ b/service/src/lib.rs @@ -1,138 +1,364 @@ mod batcher; -mod cpu; -mod dispatch; -mod new; -#[cfg(detected_cuda)] -mod nvidia; -mod session; mod template; -use session::SessionComponent; +use batcher::Batcher; +use causal_lm::{CausalLM, DecodingMeta, QueryContext, SampleArgs, SampleMeta}; +use common::{upos, utok}; +use core::fmt; use std::{ + borrow::Cow, + error, + iter::zip, + mem::{replace, take}, + ops::Range, path::Path, sync::{Arc, Mutex}, }; use template::Template; +use tensor::Tensor; use tokenizer::{BPECommonNormalizer, Normalizer, Tokenizer, VocabTxt, BPE}; -use tokio::{sync::mpsc::unbounded_channel, task::JoinSet}; -use transformer::SampleArgs; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; -pub use session::Session; +/// 对话服务。 +#[repr(transparent)] +pub struct Service(Arc>); +/// 会话。 +pub struct Session { + component: Arc>, + pub sample: SampleArgs, + cache: Option>, + dialog: Vec>, + tail: Vec, +} +/// 忙会话,表示会话正在处理推理任务,并可接收推理结果。 +pub struct BusySession<'a, M: CausalLM> { + session: &'a mut Session, + receiver: Option>, + cache: Arc>>>, +} -#[macro_use] -extern crate log; +/// 对话错误类型。 +/// +/// 目前唯一可能的对话错误是增量对话中句子位置异常。 +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct ChatError; -pub struct Service { - session_component: Arc, - sample: Arc>, - _workers: JoinSet<()>, +impl error::Error for ChatError {} +impl fmt::Display for ChatError { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "chat error") + } } -#[derive(Debug)] -#[non_exhaustive] -pub enum Device { - Cpu, - NvidiaGpu(Vec), +/// 服务中不变的组件,将在所有会话之间共享。 +/// +/// 推理线程的生命周期与这个组件绑定。 +struct ServiceComponent { + handle: Arc>, + tokenizer: Box, + normalizer: Box, + template: Box, } -impl Service { - pub fn load_model(path: impl AsRef, sample: SampleArgs, device: Device) -> Self { - let model_dir = path.as_ref().to_owned(); - let sample = Arc::new(Mutex::new(sample)); - let (sender, receiver) = unbounded_channel(); - Service { - session_component: Arc::new(SessionComponent { - template: template(&model_dir), - normalizer: normalizer(&model_dir), - tokenizer: tokenizer(&model_dir), - sender, - }), - sample: sample.clone(), - _workers: match device { - Device::Cpu => dispatch::run(cpu::transformer(model_dir), sample, receiver), - #[cfg(detected_cuda)] - Device::NvidiaGpu(devices) => match devices.as_slice() { - &[] => 
dispatch::run(nvidia::transformer(model_dir, 0), sample, receiver), - &[i] => dispatch::run(nvidia::transformer(model_dir, i as _), sample, receiver), - #[cfg(detected_nccl)] - dev => dispatch::run( - nvidia::distributed(model_dir, dev.iter().map(|&d| d as _)), - sample, - receiver, - ), - #[cfg(not(detected_nccl))] - _ => panic!("No NCCL detected"), - }, - #[cfg(not(detected_cuda))] - _ => panic!("Unsupported device"), - }, - } +impl Drop for ServiceComponent { + #[inline] + fn drop(&mut self) { + self.handle.batcher.shutdown(); } +} + +struct HandleComponent { + model: M, + batcher: Batcher>, +} +struct Task { + tokens: Vec, + pos: upos, + sample: SampleArgs, + cache: Arc>>>, + sender: UnboundedSender, +} + +impl Task { #[inline] - pub fn launch(&self) -> Session { - Session::new(self.session_component.clone()) + fn range(&self) -> Range { + self.pos..self.pos + self.tokens.len() as upos } +} - #[inline] - pub fn fork(&self, session: &Session) -> Session { - let new = Session::new(self.session_component.clone()); - self.session_component - .sender - .send(session::Command::Fork(session.id(), new.id())) - .unwrap(); - new +impl Service +where + M: CausalLM + Send + Sync + 'static, + M::Storage: Send + Sync + 'static, +{ + pub fn new(model_dir: impl AsRef) -> Self { + let handle = Arc::new(HandleComponent { + model: M::load(&model_dir), + batcher: Batcher::new(), + }); + { + let handle = handle.clone(); + tokio::task::spawn_blocking(move || { + // 这个线程的生命周期不小于服务的生命周期,不占用线程池 + while let Some(tasks) = Some(handle.batcher.deq()).filter(|t| !t.is_empty()) { + let token_embedded = { + let queries = tasks.iter().flat_map(|t| &t.tokens).copied(); + handle.model.token_embed(queries) + }; + // 锁定所有请求的 cache + let hidden_state = { + let mut queries = tasks + .iter() + .map(|t| (t, t.cache.lock().unwrap())) + .collect::>(); + let queries = queries.iter_mut().map(|(task, lock)| QueryContext { + cache: lock.as_mut(), + range: task.range(), + }); + handle.model.forward(queries, token_embedded) + }; + // 为每次推理启动一个任务执行解码工作 + let handle = handle.clone(); + tokio::task::spawn_blocking(move || { + let num_decode = tasks + .iter() + .map(|t| if !t.sender.is_closed() { 1 } else { 0 }) + .collect::>(); + + let decoding = + zip(&tasks, &num_decode).map(|(t, num_decode)| DecodingMeta { + num_query: t.tokens.len(), + num_decode: *num_decode, + }); + let logits = handle.model.decode(decoding, hidden_state); + + let args = zip(&tasks, &num_decode).map(|(t, num_decode)| SampleMeta { + num_decode: *num_decode, + args: t.sample.clone(), + }); + let tokens = handle.model.sample(args, logits); + + let eos = handle.model.eos_token(); + zip(tasks, num_decode) + .filter(|(_, n)| *n > 0) + .map(|(t, _)| t) + .zip(tokens) + .filter(|(task, token)| { + *token != eos && task.sender.send(*token).is_ok() + }) + .for_each(|(mut task, token)| { + task.pos += replace(&mut task.tokens, vec![token]).len() as upos; + handle.batcher.enq(task); + }); + }); + } + }); + } + Self(Arc::new(ServiceComponent { + handle, + tokenizer: tokenizer(&model_dir), + normalizer: normalizer(&model_dir), + template: template(model_dir), + })) } +} +impl Service { + /// 从对话服务启动一个会话。 + pub fn launch(&self) -> Session { + Session { + component: self.0.clone(), + sample: SampleArgs::default(), + cache: None, + dialog: vec![], + tail: vec![], + } + } +} + +impl Session { #[inline] - pub fn sample_args(&self) -> SampleArgs { - self.sample.lock().unwrap().clone() + pub fn dialog_pos(&self) -> usize { + self.dialog.len() + } + + /// 复制当前会话。 + pub fn fork(&self) 
-> Self { + Self { + component: self.component.clone(), + sample: Default::default(), + cache: self.cache.as_ref().map(|cache| { + self.component + .handle + .model + .duplicate_cache(cache, self.pos()) + }), + dialog: self.dialog.clone(), + tail: self.tail.clone(), + } + } + + /// 用 dialog 重置会话,启动推理并返回忙会话。 + pub fn reset<'s, 'a>( + &'s mut self, + dialog: impl IntoIterator, + ) -> BusySession<'s, M> { + // 重置会话状态 + self.dialog.clear(); + self.tail.clear(); + // 填充对话 + let eos = self.component.handle.model.eos_token(); + let mut prompt = true; + let mut prefill = vec![]; + for s in dialog { + let s = if prompt { + self.component.template.apply_chat(s) + } else { + s.into() + }; + + let s = self.component.normalizer.encode(&s); + let s = self.component.tokenizer.encode(&s); + prefill.extend_from_slice(self.push_sentence(s)); + + if !prompt { + self.tail = vec![eos]; + } + prompt = !prompt; + } + self.infer(prefill, 0) + } + + /// 向对话的 `dialog_pos` 处填充 `prompt`,启动推理并返回忙会话。 + /// + /// 如果 `dialog_pos` 位置之前有未知的句子,返回 `ChatError`。 + pub fn chat(&mut self, dialog_pos: usize, prompt: &str) -> Result, ChatError> { + if dialog_pos > self.dialog.len() { + Err(ChatError) + } else { + // tokenize and normalize the prompt + let prompt = self.component.template.apply_chat(prompt); + let prompt = self.component.normalizer.encode(&prompt); + let prompt = self.component.tokenizer.encode(&prompt); + // dialog_pos 是历经的位置,需要回滚对话 + if let Some(sentence) = self.dialog.get(dialog_pos) { + let tail = sentence.head(); + self.tail = Vec::with_capacity(tail.len() + prompt.len()); + self.tail.extend(tail); + self.dialog.truncate(dialog_pos); + } + let pos = self.pos(); + let prompt = self.push_sentence(prompt).to_vec(); + Ok(self.infer(prompt, pos)) + } + } + + fn infer(&mut self, tokens: Vec, pos: upos) -> BusySession { + // 生成推理任务与会话的交互管道 + let (sender, receiver) = unbounded_channel(); + let cache = Arc::new(Mutex::new(Some( + self.cache + .take() + .unwrap_or_else(|| self.component.handle.model.new_cache()), + ))); + self.component.handle.batcher.enq(Task { + tokens, + pos, + sample: self.sample.clone(), + cache: cache.clone(), + sender, + }); + BusySession { + session: self, + receiver: Some(receiver), + cache, + } } #[inline] - pub fn set_sample_args(&self, sample: SampleArgs) { - *self.sample.lock().unwrap() = sample; + fn pos(&self) -> upos { + self.dialog + .last() + .map_or(0, |s| s.pos + s.tokens.len() as upos) } -} -fn template(model_dir: impl AsRef) -> Box { - let path: String = model_dir.as_ref().display().to_string(); - let path = path.to_ascii_lowercase(); - if path.contains("tinyllama") { - Box::new(template::ChatTinyLlama) - } else { - Box::new(template::ChatCPM) + /// 连接上一个句子的后续并构造新句子。 + fn push_sentence(&mut self, s: Vec) -> &[utok] { + let pos = self.pos(); + let head_len = self.tail.len(); + self.tail.extend(s); + self.dialog + .push(Sentence::take(&mut self.tail, pos, head_len)); + &self.dialog.last().unwrap().tokens } } -fn normalizer(model_dir: impl AsRef) -> Box { - use std::io::ErrorKind::NotFound; - match BPE::from_model_file(model_dir.as_ref().join("tokenizer.model")) { - Ok(_) => return Box::new(BPECommonNormalizer {}), - Err(e) if e.kind() == NotFound => {} - Err(e) => panic!("{e:?}"), +impl BusySession<'_, M> { + /// 接收模型解码产生的文本。 + pub async fn decode(&mut self) -> Option> { + self.receiver.as_mut().unwrap().recv().await.map(|token| { + // 记录 token + self.session.tail.push(token); + // detokenize and denormalize the token + let ServiceComponent { + normalizer, + tokenizer, + .. 
+ } = &*self.session.component; + normalizer.decode(tokenizer.decode(token)) + }) } - match VocabTxt::from_txt_file(model_dir.as_ref().join("vocabs.txt")) { - Ok(_) => return Box::new(()), - Err(e) if e.kind() == NotFound => {} - Err(e) => panic!("{e:?}"), +} + +impl Drop for BusySession<'_, M> { + fn drop(&mut self) { + let s = &mut *self.session; + // 停止响应接收 + let _ = self.receiver.take(); + // 回收 cache + s.cache = self.cache.lock().unwrap().take(); + if !s.tail.is_empty() { + // 只要忙会话收集到任何 token,就生成一个新的句子 + let answer = take(&mut s.tail); + s.push_sentence(answer); + // 无论忙会话为何丢弃,只要生成了新句子,就补充一个结束符 + s.tail = vec![s.component.handle.model.eos_token()]; + } else if let Some(last) = s.dialog.pop() { + // 否则回滚句子 + s.tail = last.head().to_vec(); + } } - panic!("Tokenizer file not found"); } -fn tokenizer(model_dir: impl AsRef) -> Box { - use std::io::ErrorKind::NotFound; - match BPE::from_model_file(model_dir.as_ref().join("tokenizer.model")) { - Ok(bpe) => return Box::new(bpe), - Err(e) if e.kind() == NotFound => {} - Err(e) => panic!("{e:?}"), +/// 对话中的一个片段。 +struct Sentence { + /// 按 token 计数,句子在对话中的位置。 + pos: upos, + /// 句子中来自上一个句子的后续 token 的数量。 + head_len: usize, + /// 句子的 token 序列。 + tokens: Vec, +} + +impl Sentence { + /// 取走 `tokens` 以构造一个位于 `pos` 处的句子, + /// 其中 `tokens` 的前 `head_len` token 是前一个句子的后续,回滚时需要重新连接。 + #[inline] + pub fn take(tokens: &mut Vec, pos: upos, head_len: usize) -> Arc { + Arc::new(Self { + pos, + head_len, + tokens: take(tokens), + }) } - match VocabTxt::from_txt_file(model_dir.as_ref().join("vocabs.txt")) { - Ok(voc) => return Box::new(voc), - Err(e) if e.kind() == NotFound => {} - Err(e) => panic!("{e:?}"), + + /// 句子中来自前一句的后续部分。 + #[inline] + pub fn head(&self) -> &[utok] { + &self.tokens[..self.head_len] } - panic!("Tokenizer file not found"); } #[test] @@ -146,21 +372,10 @@ fn test() { }; println!("model_dir: {}", model_dir.display()); - let runtime = Builder::new_current_thread().enable_time().build().unwrap(); + let runtime = Builder::new_current_thread().build().unwrap(); let _rt = runtime.enter(); - let service = Service::load_model( - model_dir, - SampleArgs { - temperature: 0., - top_k: usize::MAX, - top_p: 1., - }, - #[cfg(not(detected_cuda))] - Device::Cpu, - #[cfg(detected_cuda)] - Device::NvidiaGpu(vec![0]), - ); + let service = Service::::new(model_dir); let mut set = JoinSet::new(); let tasks = vec![ @@ -173,7 +388,7 @@ fn test() { for ((prompt, color), mut session) in zip(tasks, sessions) { set.spawn(async move { - let mut busy = session.chat(prompt); + let mut busy = session.chat(0, prompt).unwrap(); while let Some(s) = busy.decode().await { print!("{}", s.color(color)); std::io::stdout().flush().unwrap(); @@ -185,16 +400,42 @@ fn test() { runtime.shutdown_background(); } -#[cfg(feature = "nvidia")] -pub fn synchronize() { - #[cfg(detected_cuda)] - { - use transformer_nv::cuda; - cuda::init(); - for i in 0..cuda::Device::count() { - cuda::Device::new(i as _) - .retain_primary() - .apply(|ctx| ctx.synchronize()); - } +fn template(model_dir: impl AsRef) -> Box { + let path: String = model_dir.as_ref().display().to_string(); + let path = path.to_ascii_lowercase(); + if path.contains("tinyllama") { + Box::new(template::ChatTinyLlama) + } else { + Box::new(template::ChatCPM) } } + +fn normalizer(model_dir: impl AsRef) -> Box { + use std::io::ErrorKind::NotFound; + match BPE::from_model_file(model_dir.as_ref().join("tokenizer.model")) { + Ok(_) => return Box::new(BPECommonNormalizer {}), + Err(e) if e.kind() == NotFound => {} + Err(e) => 
panic!("{e:?}"), + } + match VocabTxt::from_txt_file(model_dir.as_ref().join("vocabs.txt")) { + Ok(_) => return Box::new(()), + Err(e) if e.kind() == NotFound => {} + Err(e) => panic!("{e:?}"), + } + panic!("Tokenizer file not found"); +} + +fn tokenizer(model_dir: impl AsRef) -> Box { + use std::io::ErrorKind::NotFound; + match BPE::from_model_file(model_dir.as_ref().join("tokenizer.model")) { + Ok(bpe) => return Box::new(bpe), + Err(e) if e.kind() == NotFound => {} + Err(e) => panic!("{e:?}"), + } + match VocabTxt::from_txt_file(model_dir.as_ref().join("vocabs.txt")) { + Ok(voc) => return Box::new(voc), + Err(e) if e.kind() == NotFound => {} + Err(e) => panic!("{e:?}"), + } + panic!("Tokenizer file not found"); +} diff --git a/service/src/new.rs b/service/src/new.rs deleted file mode 100644 index 0ba142bf..00000000 --- a/service/src/new.rs +++ /dev/null @@ -1,392 +0,0 @@ -use crate::{batcher::Batcher, normalizer, template, tokenizer}; -use causal_lm::{CausalLM, DecodingMeta, QueryContext, SampleArgs, SampleMeta}; -use common::{upos, utok}; -use core::fmt; -use std::{ - borrow::Cow, - error, - iter::zip, - mem::{replace, take}, - ops::Range, - path::Path, - sync::{Arc, Mutex}, -}; -use tensor::Tensor; -use tokenizer::{Normalizer, Tokenizer}; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; - -/// 对话服务。 -#[repr(transparent)] -pub struct Service(Arc>); -/// 会话。 -pub struct Session { - component: Arc>, - sample: SampleArgs, - cache: Option>, - dialog: Vec>, - tail: Vec, -} -/// 忙会话,表示会话正在处理推理任务,并可接收推理结果。 -pub struct BusySession<'a, M: CausalLM> { - session: &'a mut Session, - receiver: Option>, - cache: Arc>>>, -} - -/// 对话错误类型。 -/// -/// 目前唯一可能的对话错误是增量对话中句子位置异常。 -#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] -pub struct ChatError; - -impl error::Error for ChatError {} -impl fmt::Display for ChatError { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "chat error") - } -} - -/// 服务中不变的组件,将在所有会话之间共享。 -/// -/// 推理线程的生命周期与这个组件绑定。 -struct ServiceComponent { - handle: Arc>, - tokenizer: Box, - normalizer: Box, - template: Box, -} - -impl Drop for ServiceComponent { - #[inline] - fn drop(&mut self) { - self.handle.batcher.shutdown(); - } -} - -struct HandleComponent { - model: M, - batcher: Batcher>, -} - -struct Task { - tokens: Vec, - pos: upos, - sample: SampleArgs, - cache: Arc>>>, - sender: UnboundedSender, -} - -impl Task { - #[inline] - fn range(&self) -> Range { - self.pos..self.pos + self.tokens.len() as upos - } -} - -impl Service -where - M: CausalLM + Send + Sync + 'static, - M::Storage: Send + Sync + 'static, -{ - pub fn new(model_dir: impl AsRef) -> Self { - let handle = Arc::new(HandleComponent { - model: M::load(&model_dir), - batcher: Batcher::new(), - }); - { - let handle = handle.clone(); - tokio::task::spawn_blocking(move || { - // 这个线程的生命周期不小于服务的生命周期,不占用线程池 - while let Some(tasks) = Some(handle.batcher.deq()).filter(|t| !t.is_empty()) { - let token_embedded = { - let queries = tasks.iter().flat_map(|t| &t.tokens).copied(); - handle.model.token_embed(queries) - }; - // 锁定所有请求的 cache - let hidden_state = { - let mut queries = tasks - .iter() - .map(|t| (t, t.cache.lock().unwrap())) - .collect::>(); - let queries = queries.iter_mut().map(|(task, lock)| QueryContext { - cache: lock.as_mut(), - range: task.range(), - }); - handle.model.forward(queries, token_embedded) - }; - // 为每次推理启动一个任务执行解码工作 - let handle = handle.clone(); - tokio::task::spawn_blocking(move || { - let num_decode = tasks - .iter() - .map(|t| if 
!t.sender.is_closed() { 1 } else { 0 }) - .collect::>(); - - let decoding = - zip(&tasks, &num_decode).map(|(t, num_decode)| DecodingMeta { - num_query: t.tokens.len(), - num_decode: *num_decode, - }); - let logits = handle.model.decode(decoding, hidden_state); - - let args = zip(&tasks, &num_decode).map(|(t, num_decode)| SampleMeta { - num_decode: *num_decode, - args: t.sample.clone(), - }); - let tokens = handle.model.sample(args, logits); - - let eos = handle.model.eos_token(); - zip(tasks, num_decode) - .filter(|(_, n)| *n > 0) - .map(|(t, _)| t) - .zip(tokens) - .filter(|(task, token)| { - *token != eos && task.sender.send(*token).is_ok() - }) - .for_each(|(mut task, token)| { - task.pos += replace(&mut task.tokens, vec![token]).len() as upos; - handle.batcher.enq(task); - }); - }); - } - }); - } - Self(Arc::new(ServiceComponent { - handle, - tokenizer: tokenizer(&model_dir), - normalizer: normalizer(&model_dir), - template: template(model_dir), - })) - } -} - -impl Service { - /// 从对话服务启动一个会话。 - pub fn launch(&self) -> Session { - Session { - component: self.0.clone(), - sample: SampleArgs::default(), - cache: None, - dialog: vec![], - tail: vec![], - } - } -} - -impl Session { - /// 复制当前会话。 - pub fn fork(&self) -> Self { - Self { - component: self.component.clone(), - sample: Default::default(), - cache: self.cache.as_ref().map(|cache| { - self.component - .handle - .model - .duplicate_cache(cache, self.pos()) - }), - dialog: self.dialog.clone(), - tail: self.tail.clone(), - } - } - - /// 用 dialog 重置会话,启动推理并返回忙会话。 - pub fn reset<'s, 'a>( - &'s mut self, - dialog: impl IntoIterator, - ) -> BusySession<'s, M> { - // 重置会话状态 - self.dialog.clear(); - self.tail.clear(); - // 填充对话 - let eos = self.component.handle.model.eos_token(); - let mut prompt = true; - let mut prefill = vec![]; - for s in dialog { - let s = if prompt { - self.component.template.apply_chat(s) - } else { - s.into() - }; - - let s = self.component.normalizer.encode(&s); - let s = self.component.tokenizer.encode(&s); - prefill.extend_from_slice(self.push_sentence(s)); - - if !prompt { - self.tail = vec![eos]; - } - prompt = !prompt; - } - self.infer(prefill, 0) - } - - /// 向对话的 `dialog_pos` 处填充 `prompt`,启动推理并返回忙会话。 - /// - /// 如果 `dialog_pos` 位置之前有未知的句子,返回 `ChatError`。 - pub fn chat(&mut self, dialog_pos: usize, prompt: &str) -> Result, ChatError> { - if dialog_pos > self.dialog.len() { - Err(ChatError) - } else { - // tokenize and normalize the prompt - let prompt = self.component.template.apply_chat(prompt); - let prompt = self.component.normalizer.encode(&prompt); - let prompt = self.component.tokenizer.encode(&prompt); - // dialog_pos 是历经的位置,需要回滚对话 - if let Some(sentence) = self.dialog.get(dialog_pos) { - let tail = sentence.head(); - self.tail = Vec::with_capacity(tail.len() + prompt.len()); - self.tail.extend(tail); - self.dialog.truncate(dialog_pos); - } - let pos = self.pos(); - let prompt = self.push_sentence(prompt).to_vec(); - Ok(self.infer(prompt, pos)) - } - } - - fn infer(&mut self, tokens: Vec, pos: upos) -> BusySession { - // 生成推理任务与会话的交互管道 - let (sender, receiver) = unbounded_channel(); - let cache = Arc::new(Mutex::new(Some( - self.cache - .take() - .unwrap_or_else(|| self.component.handle.model.new_cache()), - ))); - self.component.handle.batcher.enq(Task { - tokens, - pos, - sample: self.sample.clone(), - cache: cache.clone(), - sender, - }); - BusySession { - session: self, - receiver: Some(receiver), - cache, - } - } - - #[inline] - fn pos(&self) -> upos { - self.dialog - .last() - .map_or(0, |s| 
s.pos + s.tokens.len() as upos) - } - - /// 连接上一个句子的后续并构造新句子。 - fn push_sentence(&mut self, s: Vec) -> &[utok] { - let pos = self.pos(); - let head_len = self.tail.len(); - self.tail.extend(s); - self.dialog - .push(Sentence::take(&mut self.tail, pos, head_len)); - &self.dialog.last().unwrap().tokens - } -} - -impl BusySession<'_, M> { - /// 接收模型解码产生的文本。 - pub async fn decode(&mut self) -> Option> { - self.receiver.as_mut().unwrap().recv().await.map(|token| { - // 记录 token - self.session.tail.push(token); - // detokenize and denormalize the token - let ServiceComponent { - normalizer, - tokenizer, - .. - } = &*self.session.component; - normalizer.decode(tokenizer.decode(token)) - }) - } -} - -impl Drop for BusySession<'_, M> { - fn drop(&mut self) { - let s = &mut *self.session; - // 停止响应接收 - let _ = self.receiver.take(); - // 回收 cache - s.cache = self.cache.lock().unwrap().take(); - if !s.tail.is_empty() { - // 只要忙会话收集到任何 token,就生成一个新的句子 - let answer = take(&mut s.tail); - s.push_sentence(answer); - // 无论忙会话为何丢弃,只要生成了新句子,就补充一个结束符 - s.tail = vec![s.component.handle.model.eos_token()]; - } else if let Some(last) = s.dialog.pop() { - // 否则回滚句子 - s.tail = last.head().to_vec(); - } - } -} - -/// 对话中的一个片段。 -struct Sentence { - /// 按 token 计数,句子在对话中的位置。 - pos: upos, - /// 句子中来自上一个句子的后续 token 的数量。 - head_len: usize, - /// 句子的 token 序列。 - tokens: Vec, -} - -impl Sentence { - /// 取走 `tokens` 以构造一个位于 `pos` 处的句子, - /// 其中 `tokens` 的前 `head_len` token 是前一个句子的后续,回滚时需要重新连接。 - #[inline] - pub fn take(tokens: &mut Vec, pos: upos, head_len: usize) -> Arc { - Arc::new(Self { - pos, - head_len, - tokens: take(tokens), - }) - } - - /// 句子中来自前一句的后续部分。 - #[inline] - pub fn head(&self) -> &[utok] { - &self.tokens[..self.head_len] - } -} - -#[test] -fn test() { - use colored::{Color, Colorize}; - use std::{io::Write, iter::zip}; - use tokio::{runtime::Builder, task::JoinSet}; - - let Some(model_dir) = common::test_model::find() else { - return; - }; - println!("model_dir: {}", model_dir.display()); - - let runtime = Builder::new_current_thread().enable_time().build().unwrap(); - let _rt = runtime.enter(); - - let service = Service::::new(model_dir); - - let mut set = JoinSet::new(); - let tasks = vec![ - ("Say \"Hi\" to me.", Color::Yellow), - ("Hi", Color::Red), - ("Where is the capital of France?", Color::Green), - ]; - - let sessions = tasks.iter().map(|_| service.launch()).collect::>(); - - for ((prompt, color), mut session) in zip(tasks, sessions) { - set.spawn(async move { - let mut busy = session.chat(0, prompt).unwrap(); - while let Some(s) = busy.decode().await { - print!("{}", s.color(color)); - std::io::stdout().flush().unwrap(); - } - }); - } - - runtime.block_on(async { while set.join_next().await.is_some() {} }); - runtime.shutdown_background(); -} diff --git a/service/src/nvidia.rs b/service/src/nvidia.rs deleted file mode 100644 index a977a35d..00000000 --- a/service/src/nvidia.rs +++ /dev/null @@ -1,33 +0,0 @@ -use std::{path::Path, time::Instant}; - -pub fn transformer(model_dir: impl AsRef, device: i32) -> transformer_nv::Transformer { - use transformer_nv::{cuda, Transformer}; - - let time = Instant::now(); - cuda::init(); - let dev = cuda::Device::new(device); - dev.set_mempool_threshold(u64::MAX); - let transformer = Transformer::new(model_dir, usize::MAX, dev); - info!("build transformer ... 
{:?}", time.elapsed()); - - transformer -} - -#[cfg(detected_nccl)] -pub fn distributed( - model_dir: impl AsRef, - devices: impl IntoIterator, -) -> distributed::Transformer { - use distributed::{cuda, Transformer}; - cuda::init(); - - let time = Instant::now(); - let dev = devices - .into_iter() - .map(cuda::Device::new) - .collect::>(); - let transformer = Transformer::new(model_dir, &dev); - info!("load {:?}", time.elapsed()); - - transformer -} diff --git a/service/src/session.rs b/service/src/session.rs deleted file mode 100644 index e5b00750..00000000 --- a/service/src/session.rs +++ /dev/null @@ -1,169 +0,0 @@ -use crate::template::Template; -use common::utok; -use std::{ - borrow::Cow, - sync::{ - atomic::{AtomicUsize, Ordering::Relaxed}, - Arc, - }, - time::Instant, -}; -use tokenizer::{Normalizer, Tokenizer}; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; -use transformer::{LayerCache, Request}; - -pub struct Session { - id: usize, - component: Arc, -} - -impl Session { - #[inline] - pub(crate) fn new(component: Arc) -> Self { - static ID_ROOT: AtomicUsize = AtomicUsize::new(0); - Self { - id: ID_ROOT.fetch_add(1, Relaxed), - component, - } - } - - #[inline] - pub const fn id(&self) -> usize { - self.id - } - - #[inline] - pub fn chat(&mut self, prompt: &str) -> BusySession { - self.send(&self.component.template.apply_chat(prompt)) - } - - #[inline] - pub fn generate(&mut self, prompt: &str) -> BusySession { - self.send(&self.component.template.normalize(prompt)) - } - - fn send(&mut self, prompt: &str) -> BusySession { - let _stamp = Instant::now(); - - let prompt = self.component.normalizer.encode(prompt); - let prompt = self.component.tokenizer.encode(&prompt); - - let (responsing, receiver) = unbounded_channel(); - let chat = Command::Infer( - self.id, - Box::new(Infer { - _stamp, - prompt, - responsing, - }), - ); - - self.component.sender.send(chat).unwrap(); - BusySession { - session: self, - receiver, - } - } -} - -impl Drop for Session { - #[inline] - fn drop(&mut self) { - self.component.sender.send(Command::Drop(self.id)).unwrap(); - } -} - -pub struct BusySession<'a> { - session: &'a mut Session, - receiver: UnboundedReceiver, -} - -impl BusySession<'_> { - pub async fn decode(&mut self) -> Option> { - if let Some(token) = self.receiver.recv().await { - let SessionComponent { - normalizer, - tokenizer, - .. 
- } = &*self.session.component; - Some(normalizer.decode(tokenizer.decode(token))) - } else { - None - } - } -} - -pub(crate) enum Command { - Infer(usize, Box), - Fork(usize, usize), - Drop(usize), -} - -pub(crate) struct Infer { - pub _stamp: Instant, - pub prompt: Vec, - pub responsing: UnboundedSender, -} - -pub(crate) struct SessionComponent { - pub template: Box, - pub normalizer: Box, - pub tokenizer: Box, - pub sender: UnboundedSender, -} - -pub(crate) struct SessionContext { - /// 会话标识符。 - pub id: usize, - /// 上文缓存。 - pub cache: Vec>, - /// 上文缓存对应的上文 token。 - pub cache_map: Vec, - /// 当前已经计算过上下文缓存的 token 数量。 - pub progress: usize, -} - -impl SessionContext { - #[inline] - pub fn new(cache: Vec>, id: usize) -> Self { - Self { - id, - cache, - cache_map: Vec::new(), - progress: 0, - } - } - - pub fn push(&mut self, tokens: &[utok], max_seq_len: usize) { - if self.cache_map.len() + tokens.len() > max_seq_len { - self.progress = self.progress.min(16); - if tokens.len() > max_seq_len / 2 { - let tokens = &tokens[tokens.len() - max_seq_len / 2..]; - self.cache_map.truncate(self.progress); - self.cache_map.extend_from_slice(tokens); - } else { - let tail_len = (self.cache_map.len() - self.progress).min(64); - let tail = self.cache_map.len() - tail_len; - self.cache_map.copy_within(tail.., self.progress); - self.cache_map.truncate(self.progress + tail_len); - self.cache_map.extend_from_slice(tokens); - } - } else { - self.cache_map.extend_from_slice(tokens); - } - } - - #[inline] - pub fn request(&mut self, max_tokens: usize) -> Request { - let mut tokens = &self.cache_map[self.progress..]; - let decode = tokens.len() <= max_tokens; - if !decode { - tokens = &tokens[..max_tokens]; - } - - let pos = self.progress; - self.progress += tokens.len(); - - Request::new(self.id, tokens, &mut self.cache, pos as _, decode) - } -} diff --git a/transformer-cpu/Cargo.toml b/transformer-cpu/Cargo.toml index 587bd192..d085d77a 100644 --- a/transformer-cpu/Cargo.toml +++ b/transformer-cpu/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "transformer-cpu" -version = "0.0.0" +version = "0.0.1" edition = "2021" authors = ["YdrMaster "] diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs index a280203f..b2a7eb40 100644 --- a/transformer-cpu/src/lib.rs +++ b/transformer-cpu/src/lib.rs @@ -7,7 +7,7 @@ use itertools::izip; use kernel::CpuKernels; use std::{iter::repeat, path::Path, slice::from_raw_parts}; use tensor::{reslice, slice, split, udim, DataType, LocalSplitable, Tensor}; -use transformer::{pos, Kernels, LayerBuffer, LayerCache, Llama2, Memory, Request, SampleArgs}; +use transformer::{Kernels, Llama2, Memory}; pub struct Transformer(Memory); @@ -288,320 +288,6 @@ impl CausalLM for Transformer { } } -impl transformer::Transformer for Transformer { - type Cache = Blob; - - #[inline] - fn max_position_embeddings(&self) -> usize { - self.0.max_position_embeddings() - } - - #[inline] - fn eos_token(&self) -> utok { - self.0.eos_token_id() - } - - #[inline] - fn new_cache(&self) -> Vec> { - LayerCache::new_layers(&self.0, |dt, shape| Tensor::alloc(dt, shape, Blob::new)) - } - - fn decode( - &self, - mut requests: Vec>, - ) -> (Vec, Tensor) { - // 归拢所有纯解码的请求到前面,减少批量解码的拷贝开销 - requests.sort_unstable_by_key(Request::purely_decode); - // 生成词嵌入并预分配空间 - let mut x0 = self.token_embed(&requests); - let mut x1 = Tensor::alloc(x0.data_type(), x0.shape(), Blob::new); - let mut buf = LayerBuffer::alloc(&self.0, &requests, Blob::new); - // 生成位置张量 - let nt = x0.shape()[0]; // `nt` for number of tokens - let pos 
= pos(&requests, nt); - let pos = Tensor::new(DataType::U32, &[nt], reslice(&pos)); - // 推理 - for layer in 0..self.0.num_hidden_layers() { - let (q, k, v) = self.before_att(layer, &x0, &mut x1, &mut buf.qkv, &pos); - let o = &mut x1; - self.attention( - layer, - &mut requests, - q, - k, - v, - o, - &mut buf.q_buf, - &mut buf.att_buf, - ); - self.after_att(layer, &mut x0, &mut x1, &mut buf.gate_up); - } - // 解码 - if requests[0].decode() { - let x = self.move_decode(&requests, x0); - let requests = requests.into_iter().map(Request::id).collect(); - (requests, self.logits(x)) - } else { - todo!() - } - } - - fn sample( - &self, - args: &SampleArgs, - requests: Vec, - logits: Tensor, - ) -> Vec<(Id, utok)> { - let &[_, voc] = logits.shape() else { panic!() }; - let dt = logits.data_type(); - - macro_rules! sample { - ($ty:ty) => {{ - let logits: &[$ty] = reslice(logits.as_slice()); - requests - .into_iter() - .enumerate() - .map(|(i, id)| (id, args.random(&kernel::slice!(logits; voc; [i])))) - .collect() - }}; - } - - match dt { - DataType::F16 => sample!(f16), - DataType::F32 => sample!(f32), - _ => unreachable!(), - } - } -} - -type Splitable = LocalSplitable; - -impl Transformer { - #[inline] - pub fn new(model: Memory) -> Self { - assert!(model.data_type() == DataType::F16 || model.data_type() == DataType::F32); - Self(model) - } - - fn token_embed(&self, requests: &[Request]) -> Tensor { - let dt = self.0.data_type(); - let nt = requests.iter().map(Request::seq_len).sum::(); - let d = self.0.hidden_size() as udim; - let kernels = CpuKernels::new(&self.0); - - let mut x0 = Tensor::alloc(dt, &[nt, d], Blob::new); - let tokens = requests.iter().flat_map(Request::tokens).copied(); - kernels.gather(&mut x0, &self.0.embed_tokens(), tokens); - - x0 - } - - fn before_att( - &self, - layer: usize, - x0: &Tensor, - x1: &mut Tensor, - qkv: &mut Tensor, - pos: &Tensor<&[u8]>, - ) -> (Tensor, Tensor, Tensor) { - let nt = x0.shape()[0]; - let d = self.0.hidden_size() as udim; - let nh = self.0.num_attention_heads() as udim; - let nkvh = self.0.num_key_value_heads() as udim; - let dh = d / nh; - let dkv = nkvh * dh; - let kernels = CpuKernels::new(&self.0); - - let input_layernorm = self.0.input_layernorm(layer); - kernels.rms_norm(x1, x0, &input_layernorm); - - let w_qkv = self.0.w_qkv(layer).transpose(&[1, 0]); - kernels.mat_mul(qkv, 0., x1, &w_qkv, 1.); - - let (q, k, v) = split!(qkv; [1]: d, dkv, dkv); - let mut q = q.reshape(&[nt, nh, dh]); - let mut k = k.reshape(&[nt, nkvh, dh]); - let v = v.reshape(&[nt, nkvh, dh]); - - kernels.rotary_embedding(&mut q, pos); - kernels.rotary_embedding(&mut k, pos); - - (q, k, v) - } - - fn attention( - &self, - layer: usize, - requests: &mut [Request], - q: Tensor, - k: Tensor, - v: Tensor, - o: &mut Tensor, - q_buf: &mut Blob, - att_buf: &mut Blob, - ) { - let dt = self.0.data_type(); - let nt = o.shape()[0]; - let d = self.0.hidden_size() as udim; - let nh = self.0.num_attention_heads() as udim; - let nkvh = self.0.num_key_value_heads() as udim; - let dh = d / nh; - let head_group = nh / nkvh; - let head_div = (dh as f32).sqrt().recip(); - let kernels = CpuKernels::new(&self.0); - - let q = q.as_ref().transpose(&[1, 0, 2]).map_physical(|u| &**u); - let k = k.as_ref().transpose(&[1, 0, 2]).map_physical(|u| &**u); - let v = v.as_ref().transpose(&[1, 0, 2]).map_physical(|u| &**u); - let mut o = o.as_mut().reshape(&[nt, nh, dh]).transpose(&[1, 0, 2]); - - let mut req = 0; - for r in requests.iter_mut() { - let pos = r.pos(); - let seq_len = r.seq_len(); - let 
att_len = r.att_len(); - - let req_slice = &[slice![=>], slice![req =>=> seq_len], slice![=>]]; - let cat_slice = &[slice![=>], slice![pos =>=> seq_len], slice![=>]]; - let att_slice = &[slice![=>], slice![ => att_len], slice![=>]]; - req += seq_len; - - let q = q.clone().slice(req_slice); - let k = k.clone().slice(req_slice); - let v = v.clone().slice(req_slice); - let mut o = o.as_mut().slice(req_slice).map_physical(|u| &mut ***u); - - let mut q_att = Tensor::new(dt, &[nh, seq_len, dh], &mut q_buf[..]); - let (k_cache, v_cache) = r.cache(layer); - let mut k_cat = k_cache.as_mut().slice(cat_slice).map_physical(|u| &mut **u); - let mut v_cat = v_cache.as_mut().slice(cat_slice).map_physical(|u| &mut **u); - kernels.reform(&mut q_att, &q); - kernels.reform(&mut k_cat, &k); - kernels.reform(&mut v_cat, &v); - - let q_att = q_att.reshape(&[nkvh, head_group * seq_len, dh]); - let k_att = k_cache - .as_ref() - .slice(att_slice) - .transpose(&[0, 2, 1]) - .map_physical(|u| &**u); - let v_att = v_cache.as_ref().slice(att_slice).map_physical(|u| &**u); - - let shape_att0 = &[nkvh, head_group * seq_len, att_len]; - let shape_att1 = &[nkvh * head_group, seq_len, att_len]; - - let mut att = Tensor::new(dt, shape_att0, &mut att_buf[..]); - kernels.mat_mul(&mut att, 0., &q_att, &k_att, head_div); - let mut att = att.reshape(shape_att1); - kernels.softmax(&mut att); - let mut x2 = q_att; - kernels.mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att, 1.); - - kernels.reform(&mut o, &x2.reshape(&[nh, seq_len, dh])); - } - } - - fn after_att( - &self, - layer: usize, - x0: &mut Tensor, - x1: &mut Tensor, - gate_up: &mut Tensor, - ) { - let di = self.0.intermediate_size() as udim; - let kernels = CpuKernels::new(&self.0); - - let wo = self.0.self_attn_o_proj(layer).transpose(&[1, 0]); - kernels.mat_mul(x0, 1., x1, &wo, 1.); - - let post_layernorm = self.0.post_attention_layernorm(layer); - kernels.rms_norm(x1, x0, &post_layernorm); - - let w_gate_up = self.0.mlp_gate_up(layer).transpose(&[1, 0]); - kernels.mat_mul(gate_up, 0., x1, &w_gate_up, 1.); - - let (mut gate, up) = split!(gate_up; [1]: di, di); - kernels.swiglu(&mut gate, &up); - - let mlp_down = self.0.mlp_down(layer).transpose(&[1, 0]); - kernels.mat_mul(x0, 1., &gate, &mlp_down, 1.); - } - - fn move_decode( - &self, - requests: &[Request], - mut x0: Tensor, - ) -> Tensor { - let buf = x0.as_mut_slice(); - let len = self.0.hidden_size() * self.0.data_type().size(); - - let (head, others) = requests.split_first().unwrap(); - let begin = head.seq_len() as usize - 1; - - let mut src = begin; - let mut dst = begin; - for r in others { - src += r.seq_len() as usize; - if r.decode() { - dst += 1; - if dst < src { - buf.copy_within(src * len..(src + 1) * len, dst * len); - } - } - } - - x0.slice(&[slice![begin => dst + 1], slice![=>]]) - } - - fn logits(&self, mut x: Tensor) -> Tensor { - let dt = self.0.data_type(); - let voc = self.0.vocab_size() as udim; - let kernels = CpuKernels::new(&self.0); - - let mut logits = Tensor::alloc(dt, &[x.shape()[0], voc], Blob::new); - - // 复制一个 x 以实现原地归一化 - let x_ = unsafe { - x.as_ref() - .map_physical(|u| from_raw_parts(u.as_ptr(), u.len())) - }; - kernels.rms_norm(&mut x, &x_, &self.0.model_norm()); - - let lm_head = self.0.lm_head().transpose(&[1, 0]); - kernels.mat_mul(&mut logits, 0., &x, &lm_head, 1.); - - logits - } -} - -#[test] -fn test_build() { - use common::safe_tensors::SafeTensorsError; - use std::{io::ErrorKind::NotFound, time::Instant}; - use transformer::Memory; - - let Some(model_dir) = 
common::test_model::find() else {
-        return;
-    };
-    println!("model_dir: {}", model_dir.display());
-
-    let t0 = Instant::now();
-    let safetensors = Memory::load_safetensors(model_dir);
-    let t1 = Instant::now();
-    println!("mmap {:?}", t1 - t0);
-
-    let safetensors = match safetensors {
-        Ok(m) => m,
-        Err(SafeTensorsError::Io(e)) if e.kind() == NotFound => return,
-        Err(e) => panic!("{e:?}"),
-    };
-
-    let t0 = Instant::now();
-    let _transformer = Transformer::new(safetensors);
-    let t1 = Instant::now();
-    println!("build transformer {:?}", t1 - t0);
-}
-
 #[test]
 fn test_infer() {
     use std::time::Instant;
diff --git a/web-api/Cargo.toml b/web-api/Cargo.toml
index 254ab6cd..6dbbc04c 100644
--- a/web-api/Cargo.toml
+++ b/web-api/Cargo.toml
@@ -5,7 +5,7 @@ edition = "2021"
 authors = ["Zezhong Pan "]

 [dependencies]
-transformer = { path = "../transformer" }
+causal-lm = { path = "../causal-lm" }
 service = { path = "../service" }
 serde = { workspace = true, features = ["derive"] }
 serde_json.workspace = true
diff --git a/web-api/src/lib.rs b/web-api/src/lib.rs
index 8d4e5744..2531933a 100644
--- a/web-api/src/lib.rs
+++ b/web-api/src/lib.rs
@@ -4,7 +4,8 @@ mod manager;
 mod response;
 mod schemas;

-use actix_web::{post, web, App, HttpResponse, HttpServer};
+use actix_web::{web, App, HttpResponse, HttpServer};
+use causal_lm::CausalLM;
 use futures::StreamExt;
 use manager::ServiceManager;
 use schemas::{Drop, Fork, Infer};
@@ -14,15 +15,19 @@ use std::{fmt::Debug, net::ToSocketAddrs, sync::Arc};
 extern crate log;

 /// All global variables and services shared among all endpoints in this App
-struct AppState {
+struct AppState<M: CausalLM> {
     /// Manager of this App, which provides all kinds of services such as infer, session management, etc
-    service_manager: Arc<ServiceManager>,
+    service_manager: Arc<ServiceManager<M>>,
 }

-pub async fn start_infer_service(
-    service: service::Service,
+pub async fn start_infer_service<M>(
+    service: service::Service<M>,
     addrs: impl ToSocketAddrs + Debug,
-) -> std::io::Result<()> {
+) -> std::io::Result<()>
+where
+    M: CausalLM + Send + Sync + 'static,
+    M::Storage: Send + Sync + 'static,
+{
     info!("start service at {addrs:?}");
     let app_state = web::Data::new(AppState {
         service_manager: Arc::new(service.into()),
@@ -30,17 +35,20 @@
     HttpServer::new(move || {
         App::new()
             .app_data(app_state.clone())
-            .service(infer)
-            .service(fork)
-            .service(drop)
+            .route("/infer", web::post().to(infer::<M>))
+            .route("/fork", web::post().to(fork::<M>))
+            .route("/drop", web::post().to(drop::<M>))
     })
     .bind(addrs)?
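// With the routes registered via `web::post().to(...)`, the concrete model is
// chosen at the call site. A minimal, hypothetical launch -- the model path and
// bind address below are illustrative, not part of this diff -- assuming the
// CPU backend:
//
//     let service = service::Service::<transformer_cpu::Transformer>::new("models/llama");
//     web_api::start_infer_service(service, "0.0.0.0:8080").await?;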
    .run()
    .await
 }

-#[post("/infer")]
-async fn infer(app_state: web::Data<AppState>, request: web::Json<Infer>) -> HttpResponse {
+async fn infer<M>(app_state: web::Data<AppState<M>>, request: web::Json<Infer>) -> HttpResponse
+where
+    M: CausalLM + Send + Sync + 'static,
+    M::Storage: Send + Sync + 'static,
+{
     info!("Request from {}: infer", request.session_id);
     match app_state.service_manager.infer(request.into_inner()) {
         Ok(stream) => response::text_stream(stream.map(|word| Ok(word.into()))),
@@ -48,8 +56,11 @@
     }
 }

-#[post("/fork")]
-async fn fork(app_state: web::Data<AppState>, request: web::Json<Fork>) -> HttpResponse {
+async fn fork<M>(app_state: web::Data<AppState<M>>, request: web::Json<Fork>) -> HttpResponse
+where
+    M: CausalLM + Send + Sync + 'static,
+    M::Storage: Send + Sync + 'static,
+{
     info!("Request from {}: fork", request.session_id);
     match app_state.service_manager.fork(request.into_inner()) {
         Ok(s) => response::success(s),
@@ -57,8 +68,11 @@
     }
 }

-#[post("/drop")]
-async fn drop(app_state: web::Data<AppState>, request: web::Json<Drop>) -> HttpResponse {
+async fn drop<M>(app_state: web::Data<AppState<M>>, request: web::Json<Drop>) -> HttpResponse
+where
+    M: CausalLM + Send + Sync + 'static,
+    M::Storage: Send + Sync + 'static,
+{
     info!("Request from {}: drop", request.session_id);
     match app_state.service_manager.drop_(request.into_inner()) {
         Ok(s) => response::success(s),
diff --git a/web-api/src/manager.rs b/web-api/src/manager.rs
index 291b3423..34d9f727 100644
--- a/web-api/src/manager.rs
+++ b/web-api/src/manager.rs
@@ -1,4 +1,5 @@
-use crate::schemas::{Drop, DropSuccess, Fork, ForkSuccess, Infer, SessionError};
+use crate::schemas::{Drop, DropSuccess, Error, Fork, ForkSuccess, Infer};
+use causal_lm::CausalLM;
 use futures::{
     channel::mpsc::{self, Receiver},
     SinkExt,
@@ -9,29 +10,33 @@ use std::{
     sync::{Arc, Mutex},
 };

-pub struct ServiceManager {
+pub struct ServiceManager<M: CausalLM> {
     /// Inference service provided by backend model
-    infer_service: Service,
+    infer_service: Service<M>,
     /// All sessions, session id as key.
     /// New session will be created when a new id comes.
     /// The value will become empty when that session is being served,
     /// so that a new request with the same id will not be double-served.
     /// A session must be re-inserted after being served.
-    sessions: Mutex<HashMap<String, Option<Session>>>,
+    pending_sessions: Mutex<HashMap<String, Option<Session<M>>>>,
 }

-impl From<Service> for ServiceManager {
+impl<M: CausalLM> From<Service<M>> for ServiceManager<M> {
     #[inline]
-    fn from(infer_service: Service) -> Self {
+    fn from(infer_service: Service<M>) -> Self {
         Self {
             infer_service,
-            sessions: Default::default(),
+            pending_sessions: Default::default(),
         }
     }
 }

-impl ServiceManager {
+impl<M> ServiceManager<M>
+where
+    M: CausalLM + Send + Sync + 'static,
+    M::Storage: Send + Sync + 'static,
+{
     /// Get an existing or create a new infer session for an infer request.
     /// Return session or error.
    pub fn infer(
         &self,
         Infer {
             session_id,
             inputs,
-            first_request,
+            dialog_pos,
         }: Infer,
-    ) -> Result<Receiver<String>, SessionError> {
-        let mut session = match self.sessions.lock().unwrap().entry(session_id.clone()) {
-            // Case session id exists
-            Entry::Occupied(mut e) => match e.get_mut().take() {
-                // Session id exists but user thinks otherwise
-                Some(_) if first_request => {
-                    e.insert(Some(self.infer_service.launch())); // drop the old session
-                    e.get_mut().take().unwrap()
+    ) -> Result<Receiver<String>, Error> {
+        if inputs.is_empty() {
+            return Err(Error::EmptyInput);
+        }
+        let mut session = match self
+            .pending_sessions
+            .lock()
+            .unwrap()
+            .entry(session_id.clone())
+        {
+            Entry::Occupied(mut e) => match e.get() {
+                Some(session)
+                    if dialog_pos == 0
+                        || (inputs.len() == 1 && dialog_pos <= session.dialog_pos()) =>
+                {
+                    Ok(e.get_mut().take().unwrap())
                 }
-                // take the existing session
-                Some(session) => session,
-                // If session is being served
-                None => return Err(SessionError::Busy),
+                Some(_) => Err(Error::InvalidDialogPos),
+                None => Err(Error::SessionBusy),
             },
-            // First request, create new session
-            Entry::Vacant(e) if first_request => {
-                e.insert(Some(self.infer_service.launch())).take().unwrap()
+            Entry::Vacant(e) if dialog_pos == 0 => {
+                Ok(e.insert(Some(self.infer_service.launch())).take().unwrap())
             }
-            // Session id does not exist but user thinks otherwise, histroy lost
-            _ => return Err(SessionError::NotFound),
-        };
-
+            Entry::Vacant(_) => Err(Error::SessionNotFound),
+        }?;
         let (mut sender, receiver) = mpsc::channel(4096);
         let self_ = self.clone();
         tokio::spawn(async move {
-            let mut busy = session.chat(&inputs);
-            while let Some(s) = busy.decode().await {
-                if let Err(e) = sender.send(s.into_owned()).await {
-                    warn!("Failed to send piece to {session_id} with error \"{e}\"");
-                    break;
+            {
+                let mut busy = if dialog_pos == 0 {
+                    session.reset(inputs.iter().map(|s| s.content.as_str()))
+                } else {
+                    session.chat(dialog_pos, &inputs[0].content).unwrap()
+                };
+                while let Some(s) = busy.decode().await {
+                    if let Err(e) = sender.send(s.into_owned()).await {
+                        warn!("Failed to send piece to {session_id} with error \"{e}\"");
+                        break;
+                    }
                 }
             }
-            if let Some(opt) = self_.sessions.lock().unwrap().get_mut(&session_id) {
-                opt.get_or_insert(session);
+            if let Some(container) = self_.pending_sessions.lock().unwrap().get_mut(&session_id) {
+                container.get_or_insert(session);
             }
         });

@@ -88,33 +102,28 @@
             session_id,
             new_session_id,
         }: Fork,
-    ) -> Result<ForkSuccess, SessionError> {
-        let mut sessions = self.sessions.lock().unwrap();
-        let session = sessions
+    ) -> Result<ForkSuccess, Error> {
+        let mut sessions = self.pending_sessions.lock().unwrap();
+        if sessions.contains_key(&new_session_id) {
+            warn!("Failed to fork because \"{new_session_id}\" already exists");
+            return Err(Error::SessionDuplicate);
+        }
+        let new = sessions
             .get_mut(&session_id)
-            .ok_or(SessionError::NotFound)?
-            .take()
-            .ok_or(SessionError::Busy)?;
-        let result = match sessions.entry(new_session_id) {
-            Entry::Occupied(e) => {
-                warn!("Failed to fork because \"{}\" already exists", e.key());
-                Err(SessionError::Duplicate)
-            }
-            Entry::Vacant(e) => {
-                e.insert(Some(self.infer_service.fork(&session)));
-                Ok(ForkSuccess)
-            }
-        };
-        sessions.get_mut(&session_id).unwrap().replace(session);
-        result
+            .ok_or(Error::SessionNotFound)?
+            .as_ref()
+            .ok_or(Error::SessionBusy)?
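            // (`as_ref()` only borrows the pending session, so `Session::fork`
            // below clones state for the new id while the original session
            // stays available in the map under its old id.)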
+            .fork();
+        sessions.insert(new_session_id, Some(new));
+        Ok(ForkSuccess)
     }

-    pub fn drop_(&self, Drop { session_id }: Drop) -> Result<DropSuccess, SessionError> {
-        self.sessions
+    pub fn drop_(&self, Drop { session_id }: Drop) -> Result<DropSuccess, Error> {
+        self.pending_sessions
             .lock()
             .unwrap()
             .remove(&session_id)
             .map(|_| DropSuccess)
-            .ok_or(SessionError::NotFound)
+            .ok_or(Error::SessionNotFound)
     }
 }
diff --git a/web-api/src/response.rs b/web-api/src/response.rs
index e07a7a52..e70abfab 100644
--- a/web-api/src/response.rs
+++ b/web-api/src/response.rs
@@ -31,7 +31,7 @@ pub fn success(s: impl schemas::Success) -> HttpResponse {
 }

 #[inline]
-pub fn error(e: schemas::SessionError) -> HttpResponse {
+pub fn error(e: schemas::Error) -> HttpResponse {
     #[derive(Serialize)]
     struct ErrorResponse {
         error: String,
diff --git a/web-api/src/schemas.rs b/web-api/src/schemas.rs
index 165a366e..788e9dd4 100644
--- a/web-api/src/schemas.rs
+++ b/web-api/src/schemas.rs
@@ -1,8 +1,15 @@
 #[derive(serde::Deserialize)]
 pub(crate) struct Infer {
     pub session_id: String,
-    pub inputs: String,
-    pub first_request: bool,
+    pub inputs: Vec<Sentence>,
+    pub dialog_pos: usize,
+}
+
+#[derive(serde::Deserialize)]
+pub(crate) struct Sentence {
+    #[allow(unused)]
+    pub role: String,
+    pub content: String,
 }

 #[derive(serde::Deserialize)]
@@ -37,19 +44,24 @@ impl Success for DropSuccess {
     }
 }

-pub(crate) enum SessionError {
-    Busy,
-    Duplicate,
-    NotFound,
+#[derive(Debug)]
+pub(crate) enum Error {
+    SessionBusy,
+    SessionDuplicate,
+    SessionNotFound,
+    EmptyInput,
+    InvalidDialogPos,
 }

-impl SessionError {
+impl Error {
     #[inline]
     pub fn msg(&self) -> &'static str {
         match self {
-            Self::Busy => "Session is busy",
-            Self::Duplicate => "Session already exists",
-            Self::NotFound => "Session histroy is lost",
+            Self::SessionBusy => "Session is busy",
+            Self::SessionDuplicate => "Session already exists",
+            Self::SessionNotFound => "Session history is lost",
+            Self::EmptyInput => "Input is empty",
+            Self::InvalidDialogPos => "Invalid dialog position",
         }
     }
 }
diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml
index d925d1f9..9c859a15 100644
--- a/xtask/Cargo.toml
+++ b/xtask/Cargo.toml
@@ -9,7 +9,9 @@ authors = ["YdrMaster "]

 [dependencies]
 common = { path = "../common" }
 tensor = { path = "../tensor" }
+causal-lm = { path = "../causal-lm" }
 transformer = { path = "../transformer" }
+transformer-cpu = { path = "../transformer-cpu" }
 service = { path = "../service" }
 web-api = { path = "../web-api" }
 log.workspace = true
@@ -20,4 +22,4 @@ clap = { version = "4.5", features = ["derive"] }

 [features]
 default = ["nvidia"]
-nvidia = ["service/nvidia"]
+nvidia = []
diff --git a/xtask/src/chat.rs b/xtask/src/chat.rs
index 5ed784e0..ee6b0db8 100644
--- a/xtask/src/chat.rs
+++ b/xtask/src/chat.rs
@@ -1,27 +1,74 @@
-use crate::InferenceArgs;
+use crate::{init_log, InferenceArgs};
+use causal_lm::CausalLM;
 use colored::Colorize;
 use service::{Service, Session};
 use std::{collections::HashMap, io::Write};
-use transformer::SampleArgs;

 impl InferenceArgs {
     pub async fn chat(self) {
-        let mut chatting = Chatting::from(self);
+        init_log(self.log.as_ref().map(String::as_str));
+        let service = Service::<transformer_cpu::Transformer>::new(self.model);
+        let sessions = HashMap::from([(0, service.launch())]);
+        Chatting {
+            service,
+            current: 0,
+            next_id: 1,
+            sessions,
+        }
+        .chat()
+        .await;
+    }
+}
+
+struct Chatting<M: CausalLM> {
+    service: Service<M>,
+    current: usize,
+    next_id: usize,
+    sessions: HashMap<usize, Session<M>>,
+}
+
+macro_rules! print_now {
print_now { + ($($arg:tt)*) => {{ + print!($($arg)*); + std::io::stdout().flush().unwrap(); + }}; +} + +fn print_splitter() { + println!("====================================="); +} + +fn print_help() { + println!( + "\ +/create 新建会话session +/switch [0-9+] 切换至指定会话 +/drop [0-9+] 丢弃指定会话 +/args 打印当前参数 +/args key value 设置指定参数 +/help 打印帮助信息 + +使用 /exit 或 Ctrl + C 结束程序" + ); +} + +impl Chatting { + async fn chat(mut self) { println!( "\ ########################################### # 欢迎使用九源推理框架-大模型单机对话demo # ###########################################" ); - chatting.print_args(); + self.print_args(); println!(); print_help(); print_splitter(); let mut input = String::new(); loop { - chatting.print_session(); + self.print_session(); input.clear(); std::io::stdin() .read_line(&mut input) @@ -30,94 +77,52 @@ impl InferenceArgs { if !input.is_empty() { // 以 / 开头则为用户指令 if input.starts_with('/') { - if !chatting.execute_command(input) { + if !self.execute_command(input) { break; } } else { - chatting.infer(input).await; + self.infer(input).await; } } } } -} - -struct Chatting { - service: Service, - sample: SampleArgs, - session: Session, - sessions: HashMap, -} -impl From for Chatting { - fn from(args: InferenceArgs) -> Self { - let service: Service = args.into(); - let session = service.launch(); - let sample = service.sample_args(); - Self { - service, - sample, - session, - sessions: HashMap::new(), - } + #[inline] + fn session(&self) -> &Session { + self.sessions.get(&self.current).unwrap() } -} -macro_rules! print_now { - ($($arg:tt)*) => {{ - print!($($arg)*); - std::io::stdout().flush().unwrap(); - }}; -} -fn print_splitter() { - println!("====================================="); -} -fn print_help() { - println!( - "\ -/create 新建会话session -/switch [0-9+] 切换至指定会话 -/drop [0-9+] 丢弃指定会话 -/args 打印当前参数 -/args key value 设置指定参数 -/help 打印帮助信息 - -使用 /exit 或 Ctrl + C 结束程序" - ); -} + #[inline] + fn session_mut(&mut self) -> &mut Session { + self.sessions.get_mut(&self.current).unwrap() + } -impl Chatting { fn print_args(&self) { - println!( - "PID = {}, temperature = {}, top-k = {}, top-p = {}", - std::process::id(), - self.sample.temperature, - self.sample.top_k, - self.sample.top_p, - ); + println!("PID = {}", std::process::id()); + println!("Current session = {}", self.current); + let args = &self.session().sample; + println!("temperature = {}", args.temperature); + println!("top-k = {}", args.top_k); + println!("top-p = {}", args.top_p); } fn print_session(&mut self) { - print_now!( - "{}{}{}", - "User[".yellow(), - self.session.id(), - "]: ".yellow() - ); + print_now!("{}{}{}", "User[".yellow(), self.current, "]: ".yellow()); } fn execute_command(&mut self, command: &str) -> bool { match command.split_whitespace().collect::>().as_slice() { ["/create"] => { - let old = std::mem::replace(&mut self.session, self.service.launch()); - self.sessions.insert(old.id(), old); + self.current = self.next_id; + self.next_id += 1; + self.sessions.insert(self.current, self.service.launch()); } ["/switch", n] => match n.parse() { Ok(target_id) => { - if target_id == self.session.id() { + if target_id == self.current { println!("Already in session {}", target_id); - } else if let Some(target) = self.sessions.remove(&target_id) { - let old = std::mem::replace(&mut self.session, target); - self.sessions.insert(old.id(), old); + } else if self.sessions.contains_key(&target_id) { + self.current = target_id; } else { println!("Invalid session ID."); } @@ -126,20 +131,23 @@ impl Chatting { }, ["/drop", n] => match n.parse() 
diff --git a/xtask/src/main.rs b/xtask/src/main.rs
index 3f0d6da4..c72bc83a 100644
--- a/xtask/src/main.rs
+++ b/xtask/src/main.rs
@@ -1,13 +1,8 @@
 mod cast;
 mod chat;
-mod generate;
-mod service;

-use ::service::{Device, Service};
 use clap::Parser;
-use service::ServiceArgs;
 use std::future::Future;
-use transformer::SampleArgs;

 #[macro_use]
 extern crate clap;

@@ -16,9 +11,9 @@ fn main() {
     use Commands::*;
     match Cli::parse().command {
         Cast(cast) => cast.invode(),
-        Generate(args) => block_on(args.inference.generate(&args.prompt)),
+        // Generate(args) => block_on(args.inference.generate(&args.prompt)),
         Chat(chat) => block_on(chat.chat()),
-        Service(service) => block_on(service.serve()),
+        // Service(service) => block_on(service.serve()),
     }
 }

@@ -27,10 +22,6 @@ fn block_on(f: impl Future<Output = ()>) {
     let runtime = tokio::runtime::Runtime::new().unwrap();
     runtime.block_on(f);
     runtime.shutdown_background();
-    #[cfg(feature = "nvidia")]
-    {
-        ::service::synchronize();
-    }
 }

 #[derive(Parser)]
@@ -45,12 +36,12 @@ struct Cli {
 enum Commands {
     /// Cast model
     Cast(cast::CastArgs),
-    /// Generate following text
-    Generate(generate::GenerateArgs),
+    // /// Generate following text
+    // Generate(generate::GenerateArgs),
     /// Chat locally
     Chat(InferenceArgs),
-    /// Start the service
-    Service(ServiceArgs),
+    // /// Start the service
+    // Service(ServiceArgs),
 }

 #[derive(Args, Default)]
 struct InferenceArgs {
     /// Model directory.
     #[clap(short, long)]
     model: String,
-    /// Temperature for random sampling.
-    #[clap(long)]
-    temperature: Option<f32>,
-    /// Top-k for random sampling.
-    #[clap(long)]
-    top_k: Option<usize>,
-    /// Top-p for random sampling.
-    #[clap(long)]
-    top_p: Option<f32>,
     /// Log level, may be "off", "trace", "debug", "info" or "error".
     #[clap(long)]
     log: Option<String>,
@@ -76,58 +58,20 @@ struct InferenceArgs {
     nvidia: Option<String>,
 }

-impl From<InferenceArgs> for Service {
-    fn from(args: InferenceArgs) -> Self {
-        use log::LevelFilter;
-        use simple_logger::SimpleLogger;
-
-        let InferenceArgs {
-            model,
-            temperature,
-            top_k,
-            top_p,
-            #[cfg(feature = "nvidia")]
-            nvidia,
-            log,
-        } = args;
+fn init_log(log: Option<&str>) {
+    use log::LevelFilter;
+    use simple_logger::SimpleLogger;

-        let log = log
-            .as_ref()
-            .and_then(|log| match log.to_lowercase().as_str() {
-                "off" | "none" => Some(LevelFilter::Off),
-                "trace" => Some(LevelFilter::Trace),
-                "debug" => Some(LevelFilter::Debug),
-                "info" => Some(LevelFilter::Info),
-                "error" => Some(LevelFilter::Error),
-                _ => None,
-            })
-            .unwrap_or(LevelFilter::Warn);
-        SimpleLogger::new().with_level(log).init().unwrap();
-
-        Service::load_model(
-            model,
-            SampleArgs {
-                temperature: temperature.unwrap_or(0.),
-                top_k: top_k.unwrap_or(usize::MAX),
-                top_p: top_p.unwrap_or(1.),
-            },
-            #[cfg(feature = "nvidia")]
-            {
-                if let Some(devices) = nvidia {
-                    Device::NvidiaGpu(
-                        devices
-                            .split(',')
-                            .map(|d| d.trim().parse::<u32>().unwrap())
-                            .collect(),
-                    )
-                } else {
-                    Device::Cpu
-                }
-            },
-            #[cfg(not(feature = "nvidia"))]
-            {
-                Device::Cpu
-            },
-        )
-    }
+    let log = log
+        .as_ref()
+        .and_then(|log| match log.to_lowercase().as_str() {
+            "off" | "none" => Some(LevelFilter::Off),
+            "trace" => Some(LevelFilter::Trace),
+            "debug" => Some(LevelFilter::Debug),
+            "info" => Some(LevelFilter::Info),
+            "error" => Some(LevelFilter::Error),
+            _ => None,
+        })
+        .unwrap_or(LevelFilter::Warn);
+    SimpleLogger::new().with_level(log).init().unwrap();
 }
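Taking the xtask changes together: the CLI now builds a CPU-backed generic service directly, and each launched session carries its own dialog and sampling state. A condensed sketch of the flow the chat demo drives (signatures are approximated from the hunks above, not from documented APIs, and the `transformer_cpu::Transformer` backend type is the one assumed by this diff):

```rust
use service::Service;

// Hypothetical driver, assuming the signatures shown in this diff.
async fn demo(model_dir: String) {
    // The backend parameter implements `causal_lm::CausalLM`;
    // `transformer-cpu` supplies the CPU implementation.
    let service = Service::<transformer_cpu::Transformer>::new(model_dir);

    // A session owns its dialog history and its own sampling parameters.
    let mut session = service.launch();
    session.sample.temperature = 0.9;

    // `chat(dialog_pos, text)` continues the dialog at a given position;
    // the chat demo always passes 0, while the web API validates the
    // requested position against `session.dialog_pos()`.
    let mut busy = session.chat(0, "Hello").unwrap();
    while let Some(piece) = busy.decode().await {
        print!("{piece}");
    }
}
```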