diff --git a/Cargo.toml b/Cargo.toml index d3435564..0e5b9965 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ rayon = "1.9" serde_json = "1.0" serde = "1.0" log = "0.4" -tokio = { version = "1.37", features = ["rt", "rt-multi-thread", "sync"] } +tokio = { version = "1.37", features = ["rt-multi-thread", "sync"] } cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "764733" } cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "764733" } diff --git a/service/Cargo.toml b/service/Cargo.toml index 6674ee88..0bc3d21b 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -15,7 +15,7 @@ transformer-cpu = { path = "../transformer-cpu" } transformer-nv = { path = "../nvidia/transformer", optional = true } half.workspace = true log.workspace = true -tokio.workspace = true +tokio = { workspace = true, features = ["time"] } [build-dependencies] search-cuda-tools.workspace = true diff --git a/service/src/batcher.rs b/service/src/batcher.rs index 914cc802..4cfc6bde 100644 --- a/service/src/batcher.rs +++ b/service/src/batcher.rs @@ -1,11 +1,10 @@ -use crate::session::SessionContext; -use common::utok; +use crate::session::{Respond, SessionContext}; use std::sync::{Condvar, Mutex}; use tokio::sync::mpsc::UnboundedSender; pub struct Task { pub ctx: SessionContext, - pub responsing: UnboundedSender, + pub responsing: UnboundedSender, } pub struct Batcher { diff --git a/service/src/dispatch.rs b/service/src/dispatch.rs index b46df1bf..f5ff2d8c 100644 --- a/service/src/dispatch.rs +++ b/service/src/dispatch.rs @@ -1,6 +1,6 @@ use crate::{ batcher::{Batcher, Task}, - session::{Command, SessionContext}, + session::{Command, Respond::Token, SessionContext}, }; use std::{ collections::{hash_map::Entry, HashMap, HashSet}, @@ -134,8 +134,7 @@ where for mut task in tasks { match tokens.get(&task.ctx.id) { Some(&token) => { - if token != eos { - task.responsing.send(token).unwrap(); + if token != eos && task.responsing.send(Token(token)).is_ok() { task.ctx.push(&[token], max_seq_len); batcher.enq(task); } else { diff --git a/service/src/lib.rs b/service/src/lib.rs index e9ba0492..03dad571 100644 --- a/service/src/lib.rs +++ b/service/src/lib.rs @@ -118,15 +118,16 @@ fn tokenizer(model_dir: impl AsRef) -> Box { #[test] fn test() { use colored::{Color, Colorize}; - use std::{io::Write, path::Path}; - use tokio::{runtime::Builder, task::JoinSet}; + use std::{io::Write, iter::zip, path::Path, time::Duration}; + use tokio::{runtime::Builder, task::JoinSet, time::sleep}; let model_dir = "../../TinyLlama-1.1B-Chat-v1.0_F16/"; if !Path::new(model_dir).exists() { + println!("model not exist"); return; } - let runtime = Builder::new_current_thread().build().unwrap(); + let runtime = Builder::new_current_thread().enable_time().build().unwrap(); let _rt = runtime.enter(); let service = Service::load_model( @@ -149,8 +150,15 @@ fn test() { ("Where is the capital of France?", Color::Green), ]; - for (prompt, color) in tasks { - let mut session = service.launch(); + let sessions = tasks.iter().map(|_| service.launch()).collect::>(); + + let handle = sessions.last().unwrap().handle(); + set.spawn(async move { + sleep(Duration::from_secs(3)).await; + handle.abort(); + }); + + for ((prompt, color), mut session) in zip(tasks, sessions) { set.spawn(async move { session .chat(prompt, |s| { diff --git a/service/src/session.rs b/service/src/session.rs index db75ac36..f50c02c0 100644 --- a/service/src/session.rs +++ b/service/src/session.rs @@ -3,17 +3,18 @@ use common::utok; use std::{ sync::{ atomic::{AtomicUsize, Ordering::Relaxed}, - Arc, + Arc, Mutex, }, time::Instant, }; use tokenizer::{Normalizer, Tokenizer}; -use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender, WeakUnboundedSender}; use transformer::{LayerCache, Request}; pub struct Session { id: usize, component: Arc, + abort_handle: Arc>>>, } impl Session { @@ -23,6 +24,7 @@ impl Session { Self { id: ID_ROOT.fetch_add(1, Relaxed), component, + abort_handle: Default::default(), } } @@ -31,6 +33,14 @@ impl Session { self.id } + #[inline] + pub fn handle(&self) -> SessionHandle { + SessionHandle { + id: self.id, + abort_handle: self.abort_handle.clone(), + } + } + #[inline] pub async fn chat(&mut self, prompt: &str, f: impl FnMut(&str)) { self.send(&self.component.template.apply_chat(prompt), f) @@ -50,6 +60,11 @@ impl Session { let prompt = self.component.tokenizer.encode(&prompt); let (responsing, mut receiver) = unbounded_channel(); + self.abort_handle + .lock() + .unwrap() + .replace(responsing.downgrade()); + let chat = Command::Infer( self.id, Box::new(Infer { @@ -60,7 +75,7 @@ impl Session { ); self.component.sender.send(chat).unwrap(); - while let Some(token) = receiver.recv().await { + while let Some(Respond::Token(token)) = receiver.recv().await { let piece = self.component.tokenizer.decode(token); let piece = self.component.normalizer.decode(piece); f(&piece); @@ -75,15 +90,45 @@ impl Drop for Session { } } +pub struct SessionHandle { + id: usize, + abort_handle: Arc>>>, +} + +impl SessionHandle { + #[inline] + pub const fn id(&self) -> usize { + self.id + } + + #[inline] + pub fn abort(&self) { + if let Some(sender) = self + .abort_handle + .lock() + .unwrap() + .as_ref() + .and_then(WeakUnboundedSender::upgrade) + { + let _ = sender.send(Respond::Abort); + } + } +} + pub(crate) enum Command { Infer(usize, Box), Drop(usize), } +pub(crate) enum Respond { + Token(utok), + Abort, +} + pub(crate) struct Infer { pub _stamp: Instant, pub prompt: Vec, - pub responsing: UnboundedSender, + pub responsing: UnboundedSender, } pub(crate) struct SessionComponent {