Skip to content

Commit

Permalink
refactor(web-api): 重构 web-api 实现,使逻辑更清晰
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed May 17, 2024
1 parent b6b431b commit 388d156
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 87 deletions.
2 changes: 1 addition & 1 deletion causal-lm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub fn pos<'a, S: 'a>(
) -> Tensor<Vec<upos>> {
let mut ans = Vec::with_capacity(nt_hint as usize);
for query in queries {
ans.extend(query.range.clone().into_iter());
ans.extend(query.range.clone());
}
Tensor::new(tensor::DataType::U32, &[ans.len() as _], ans)
}
Expand Down
220 changes: 134 additions & 86 deletions web-api/src/manager.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::schemas::{Drop, DropSuccess, Error, Fork, ForkSuccess, Infer};
use crate::schemas::{Drop, DropSuccess, Error, Fork, ForkSuccess, Infer, Sentence};
use causal_lm::CausalLM;
use lru::LruCache;
use service::{Service, Session};
Expand Down Expand Up @@ -41,96 +41,144 @@ where
top_p,
}: Infer,
) -> Result<UnboundedReceiver<String>, Error> {
let dialog_pos = dialog_pos.unwrap_or(0);
let (sender, receiver) = mpsc::unbounded_channel();

macro_rules! set_sample {
($session:expr) => {
if let Some(temperature) = temperature {
$session.sample.temperature = temperature;
}
if let Some(top_k) = top_k {
$session.sample.top_k = top_k;
}
if let Some(top_p) = top_p {
$session.sample.top_p = top_p;
}
};
}

if let Some(session_id) = session_id {
let mut lru = self.pending.lock().unwrap();
let mut session = match lru.get_mut(&session_id) {
Some(option) => {
if let Some(session) = option.as_mut() {
if session.revert(dialog_pos).is_ok() {
info!("Session {session_id} reverted to {dialog_pos}, inference ready");
Ok(option.take().unwrap())
} else {
let current = session.dialog_pos();
warn!(
"Session {session_id} failed to revert from {current} to {dialog_pos}"
);
Err(Error::InvalidDialogPos(current))
}
} else {
warn!("Session {session_id} busy");
Err(Error::SessionBusy)
}
}
None if dialog_pos == 0 => {
info!("Session {session_id} created, inference ready");
if let Some((out, _)) = lru.push(session_id.clone(), None) {
warn!("Session {out} dropped because LRU cache is full");
}
Ok(self.service.launch())
}
None => {
warn!("Session {session_id} not found");
Err(Error::SessionNotFound)
}
}?;
async fn infer<M: CausalLM>(
session_id: &str,
session: &mut Session<M>,
messages: Vec<Sentence>,
temperature: Option<f32>,
top_k: Option<usize>,
top_p: Option<f32>,
sender: mpsc::UnboundedSender<String>,
) {
if let Some(temperature) = temperature {
session.sample.temperature = temperature;
}
if let Some(top_k) = top_k {
session.sample.top_k = top_k;
}
if let Some(top_p) = top_p {
session.sample.top_p = top_p;
}

session.extend(messages.iter().map(|s| s.content.as_str()));
set_sample!(session);

let self_ = self.clone();
tokio::spawn(async move {
if session.dialog_pos() % 2 == 1 {
let mut busy = session.chat();
while let Some(s) = busy.decode().await {
if let Err(e) = sender.send(s.into_owned()) {
warn!("Failed to send piece to {session_id} with error \"{e}\"");
break;
}
}
} else {
warn!("Only revert, no inference for session {session_id}");
}
if let Some(container) = self_.pending.lock().unwrap().get_mut(&session_id) {
container.get_or_insert(session);
}
});
} else if dialog_pos != 0 {
warn!("Temporary session must be created with zero dialog position");
return Err(Error::InvalidDialogPos(0));
} else if messages.len() % 2 == 1 {
info!("Temporary session created, inference ready");
let mut session = self.service.launch();
session.extend(messages.iter().map(|s| s.content.as_str()));
set_sample!(session);

tokio::spawn(async move {
if session.dialog_pos() % 2 == 1 {
info!("{session_id} inference started");
let mut busy = session.chat();
while let Some(s) = busy.decode().await {
if let Err(e) = sender.send(s.into_owned()) {
warn!("Failed to send piece to temporary session with error \"{e}\"");
warn!("Failed to send piece to {session_id} with error \"{e}\"");
break;
}
}
});
info!("{session_id} inference stopped");
} else {
info!("{session_id} inference skipped");
}
}

match (session_id, dialog_pos.unwrap_or(0)) {
(Some(session_id), 0) => {
let mut session = self
.pending
.lock()
.unwrap()
.get_or_insert_mut(session_id.clone(), || {
info!("{session_id} created");
Some(self.service.launch())
})
.take()
.ok_or(Error::SessionBusy)?;

let (sender, receiver) = mpsc::unbounded_channel();
let self_ = self.clone();
tokio::spawn(async move {
session.revert(0).unwrap();

infer(
&session_id,
&mut session,
messages,
temperature,
top_k,
top_p,
sender,
)
.await;

self_.restore(session_id, session);
});

Ok(receiver)
}
(Some(session_id), p) => {
let mut session = self
.pending
.lock()
.unwrap()
.get_mut(&session_id)
.ok_or(Error::SessionNotFound)?
.take()
.ok_or(Error::SessionBusy)?;

if session.revert(p).is_err() {
let current = session.dialog_pos();
warn!("Failed to revert {session_id} from {current} to {p}, session restored");
self.restore(session_id, session);
return Err(Error::InvalidDialogPos(current));
}

let (sender, receiver) = mpsc::unbounded_channel();
let self_ = self.clone();
tokio::spawn(async move {
info!("{session_id} reverted to {p}");

infer(
&session_id,
&mut session,
messages,
temperature,
top_k,
top_p,
sender,
)
.await;

self_.restore(session_id, session);
});

Ok(receiver)
}
(None, 0) => {
let (sender, receiver) = mpsc::unbounded_channel();
if messages.len() % 2 == 1 {
let self_ = self.clone();
tokio::spawn(async move {
infer(
"Temporary session",
&mut self_.service.launch(),
messages,
temperature,
top_k,
top_p,
sender,
)
.await;
});
}
Ok(receiver)
}
(None, _) => {
warn!("Temporary session must be created with zero dialog position");
Err(Error::InvalidDialogPos(0))
}
}
}

#[inline]
fn restore(&self, session_id: String, session: Session<M>) {
if let Some(option) = self.pending.lock().unwrap().get_mut(&session_id) {
assert!(option.replace(session).is_none());
}
Ok(receiver)
}

pub fn fork(
Expand All @@ -149,20 +197,20 @@ where
.ok_or(Error::SessionBusy)?
.fork();

info!("Session \"{new_session_id}\" is forked from \"{session_id}\"");
info!("{new_session_id} is forked from {session_id}");
if let Some((out, _)) = sessions.push(new_session_id, Some(new)) {
warn!("Session {out} dropped because LRU cache is full");
warn!("{out} dropped because LRU cache is full");
}
Ok(ForkSuccess)
} else {
warn!("Session fork failed because \"{new_session_id}\" already exists");
warn!("Fork failed because {new_session_id} already exists");
Err(Error::SessionDuplicate)
}
}

pub fn drop_(&self, Drop { session_id }: Drop) -> Result<DropSuccess, Error> {
if self.pending.lock().unwrap().pop(&session_id).is_some() {
info!("Session \"{session_id}\" dropped");
info!("{session_id} dropped");
Ok(DropSuccess)
} else {
Err(Error::SessionNotFound)
Expand Down

0 comments on commit 388d156

Please sign in to comment.