Skip to content

Commit

Permalink
feat(service): 支持中断正在解码的会话
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 3, 2024
1 parent e7deb99 commit 56fe5b8
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ rayon = "1.9"
serde_json = "1.0"
serde = "1.0"
log = "0.4"
tokio = { version = "1.37", features = ["rt", "rt-multi-thread", "sync"] }
tokio = { version = "1.37", features = ["rt-multi-thread", "sync"] }

cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "764733" }
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "764733" }
Expand Down
2 changes: 1 addition & 1 deletion service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ transformer-cpu = { path = "../transformer-cpu" }
transformer-nv = { path = "../nvidia/transformer", optional = true }
half.workspace = true
log.workspace = true
tokio.workspace = true
tokio = { workspace = true, features = ["time"] }

[build-dependencies]
search-cuda-tools.workspace = true
Expand Down
5 changes: 2 additions & 3 deletions service/src/batcher.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use crate::session::SessionContext;
use common::utok;
use crate::session::{Respond, SessionContext};
use std::sync::{Condvar, Mutex};
use tokio::sync::mpsc::UnboundedSender;

pub struct Task<Cache> {
pub ctx: SessionContext<Cache>,
pub responsing: UnboundedSender<utok>,
pub responsing: UnboundedSender<Respond>,
}

pub struct Batcher<Cache> {
Expand Down
5 changes: 2 additions & 3 deletions service/src/dispatch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
batcher::{Batcher, Task},
session::{Command, SessionContext},
session::{Command, Respond::Token, SessionContext},
};
use std::{
collections::{hash_map::Entry, HashMap, HashSet},
Expand Down Expand Up @@ -134,8 +134,7 @@ where
for mut task in tasks {
match tokens.get(&task.ctx.id) {
Some(&token) => {
if token != eos {
task.responsing.send(token).unwrap();
if token != eos && task.responsing.send(Token(token)).is_ok() {
task.ctx.push(&[token], max_seq_len);
batcher.enq(task);
} else {
Expand Down
18 changes: 13 additions & 5 deletions service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,16 @@ fn tokenizer(model_dir: impl AsRef<Path>) -> Box<dyn Tokenizer + Send + Sync> {
#[test]
fn test() {
use colored::{Color, Colorize};
use std::{io::Write, path::Path};
use tokio::{runtime::Builder, task::JoinSet};
use std::{io::Write, iter::zip, path::Path, time::Duration};
use tokio::{runtime::Builder, task::JoinSet, time::sleep};

let model_dir = "../../TinyLlama-1.1B-Chat-v1.0_F16/";
if !Path::new(model_dir).exists() {
println!("model not exist");
return;
}

let runtime = Builder::new_current_thread().build().unwrap();
let runtime = Builder::new_current_thread().enable_time().build().unwrap();
let _rt = runtime.enter();

let service = Service::load_model(
Expand All @@ -149,8 +150,15 @@ fn test() {
("Where is the capital of France?", Color::Green),
];

for (prompt, color) in tasks {
let mut session = service.launch();
let sessions = tasks.iter().map(|_| service.launch()).collect::<Vec<_>>();

let handle = sessions.last().unwrap().handle();
set.spawn(async move {
sleep(Duration::from_secs(3)).await;
handle.abort();
});

for ((prompt, color), mut session) in zip(tasks, sessions) {
set.spawn(async move {
session
.chat(prompt, |s| {
Expand Down
53 changes: 49 additions & 4 deletions service/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@ use common::utok;
use std::{
sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
Arc,
Arc, Mutex,
},
time::Instant,
};
use tokenizer::{Normalizer, Tokenizer};
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender, WeakUnboundedSender};
use transformer::{LayerCache, Request};

pub struct Session {
id: usize,
component: Arc<SessionComponent>,
abort_handle: Arc<Mutex<Option<WeakUnboundedSender<Respond>>>>,
}

impl Session {
Expand All @@ -23,6 +24,7 @@ impl Session {
Self {
id: ID_ROOT.fetch_add(1, Relaxed),
component,
abort_handle: Default::default(),
}
}

Expand All @@ -31,6 +33,14 @@ impl Session {
self.id
}

#[inline]
pub fn handle(&self) -> SessionHandle {
SessionHandle {
id: self.id,
abort_handle: self.abort_handle.clone(),
}
}

#[inline]
pub async fn chat(&mut self, prompt: &str, f: impl FnMut(&str)) {
self.send(&self.component.template.apply_chat(prompt), f)
Expand All @@ -50,6 +60,11 @@ impl Session {
let prompt = self.component.tokenizer.encode(&prompt);

let (responsing, mut receiver) = unbounded_channel();
self.abort_handle
.lock()
.unwrap()
.replace(responsing.downgrade());

let chat = Command::Infer(
self.id,
Box::new(Infer {
Expand All @@ -60,7 +75,7 @@ impl Session {
);

self.component.sender.send(chat).unwrap();
while let Some(token) = receiver.recv().await {
while let Some(Respond::Token(token)) = receiver.recv().await {
let piece = self.component.tokenizer.decode(token);
let piece = self.component.normalizer.decode(piece);
f(&piece);
Expand All @@ -75,15 +90,45 @@ impl Drop for Session {
}
}

pub struct SessionHandle {
id: usize,
abort_handle: Arc<Mutex<Option<WeakUnboundedSender<Respond>>>>,
}

impl SessionHandle {
#[inline]
pub const fn id(&self) -> usize {
self.id
}

#[inline]
pub fn abort(&self) {
if let Some(sender) = self
.abort_handle
.lock()
.unwrap()
.as_ref()
.and_then(WeakUnboundedSender::upgrade)
{
let _ = sender.send(Respond::Abort);
}
}
}

pub(crate) enum Command {
Infer(usize, Box<Infer>),
Drop(usize),
}

pub(crate) enum Respond {
Token(utok),
Abort,
}

pub(crate) struct Infer {
pub _stamp: Instant,
pub prompt: Vec<utok>,
pub responsing: UnboundedSender<utok>,
pub responsing: UnboundedSender<Respond>,
}

pub(crate) struct SessionComponent {
Expand Down

0 comments on commit 56fe5b8

Please sign in to comment.