Skip to content

Commit

Permalink
fix(service): 缓冲不完整的 utf8 字符
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed May 28, 2024
1 parent 31af6b3 commit e5f57a0
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 29 deletions.
28 changes: 14 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 43 additions & 11 deletions service/src/session/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ use crate::ServiceComponent;
use causal_lm::{CausalLM, DecodingMeta, SampleArgs, SampleMeta};
use common::utok;
use std::{
borrow::Cow,
iter::zip,
mem::{replace, size_of},
str,
sync::{Arc, Mutex},
};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};

/// Per-task state that `ServiceComponent::decode` reads from to produce text.
pub(super) struct TaskHandle<M: CausalLM> {
/// Receiving end of an unbounded channel that yields generated tokens (`utok`).
/// NOTE(review): the reason this is an `Option` is not visible here — confirm.
receiver: Option<UnboundedReceiver<utok>>,
/// Mutex-guarded, reference-counted slot holding an optional model cache.
/// NOTE(review): presumably shared with the task that runs inference — confirm.
cache: Arc<Mutex<Option<Cache<M::Storage>>>>,
/// Holds decoded bytes between `decode` calls until they form complete UTF-8 text.
buffer: Utf8Buffer,
}

impl<M: CausalLM> TaskHandle<M> {
Expand All @@ -37,19 +39,26 @@ impl<M: CausalLM> ServiceComponent<M> {
TaskHandle {
receiver: Some(receiver),
cache,
buffer: Default::default(),
}
}

pub(super) async fn decode(&self, x: &mut TaskHandle<M>) -> Option<Cow<str>> {
x.receiver.as_mut().unwrap().recv().await.map(|token| {
// detokenize and denormalize the token
let ServiceComponent {
normalizer,
tokenizer,
..
} = self;
normalizer.decode(tokenizer.decode(token))
})
/// Waits for generated tokens and returns the next non-empty piece of text.
///
/// Each received token is detokenized, denormalized and appended to the
/// task's `Utf8Buffer`. Tokens whose bytes do not yet yield any text are
/// absorbed silently and the next token is awaited. Returns `None` once the
/// channel yields no more tokens.
///
/// # Panics
///
/// Panics if `x.receiver` is `None`.
pub(super) async fn decode(&self, x: &mut TaskHandle<M>) -> Option<String> {
    let ServiceComponent {
        normalizer,
        tokenizer,
        ..
    } = self;
    while let Some(token) = x.receiver.as_mut().unwrap().recv().await {
        // detokenize and denormalize the token
        let piece = normalizer.decode(tokenizer.decode(token));
        let text = x.buffer.push(piece.as_bytes());
        if !text.is_empty() {
            return Some(text);
        }
    }
    None
}
}

Expand Down Expand Up @@ -143,3 +152,26 @@ where
}
}
}

/// Accumulates raw bytes and releases them only as valid UTF-8 text.
///
/// Bytes that form the beginning of a multi-byte character are retained
/// until the remaining bytes arrive in a later [`Utf8Buffer::push`] call.
#[derive(Clone, Default, Debug)]
struct Utf8Buffer(Vec<u8>);

impl Utf8Buffer {
    /// Appends `bytes` to the buffer and returns all text that can be decoded.
    ///
    /// * Complete, valid UTF-8 is returned as-is.
    /// * Byte sequences that can never become valid UTF-8 are replaced with
    ///   U+FFFD (the replacement character), so the returned `String` is
    ///   always well-formed.
    /// * A trailing sequence that is merely incomplete stays in the buffer and
    ///   is prepended to the bytes of the next call.
    ///
    /// Returns an empty string when nothing can be decoded yet.
    fn push(&mut self, bytes: impl AsRef<[u8]>) -> String {
        self.0.extend_from_slice(bytes.as_ref());

        let mut text = String::with_capacity(self.0.len());
        let mut rest = self.0.as_slice();
        // Number of trailing bytes that must be kept for the next call.
        let pending = loop {
            let error = match str::from_utf8(rest) {
                Ok(valid) => {
                    text.push_str(valid);
                    break 0;
                }
                Err(error) => error,
            };

            let (valid, after) = rest.split_at(error.valid_up_to());
            // SAFETY: `from_utf8` reported that the first `valid_up_to()` bytes
            // of `rest` are valid UTF-8, and `valid` is exactly that prefix.
            text.push_str(unsafe { str::from_utf8_unchecked(valid) });

            match error.error_len() {
                // Definitely malformed: substitute it and keep decoding after it.
                Some(skip) => {
                    text.push(char::REPLACEMENT_CHARACTER);
                    rest = &after[skip..];
                }
                // Input ended in the middle of a character: wait for more bytes.
                None => break after.len(),
            }
        };

        let consumed = self.0.len() - pending;
        self.0.drain(..consumed);
        text
    }
}
5 changes: 2 additions & 3 deletions service/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use dialog::Dialog;
use dispatch::TaskHandle;
use log::info;
use std::{
borrow::Cow,
cmp::Ordering::{Equal, Greater, Less},
error, fmt,
sync::Arc,
Expand Down Expand Up @@ -157,7 +156,7 @@ pub struct BusySession<'a, M: CausalLM> {
impl<M: CausalLM> BusySession<'_, M> {
/// 接收模型解码产生的文本。
#[inline]
pub async fn decode(&mut self) -> Option<Cow<str>> {
pub async fn decode(&mut self) -> Option<String> {
self.session.component.decode(&mut self.handle).await
}
}
Expand Down Expand Up @@ -190,7 +189,7 @@ impl<M: CausalLM> Generator<M> {

/// 接收模型解码产生的文本。
#[inline]
pub async fn decode(&mut self) -> Option<Cow<str>> {
pub async fn decode(&mut self) -> Option<String> {
self.component.decode(&mut self.handle).await
}
}
Expand Down
2 changes: 1 addition & 1 deletion web-api/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ where
info!("{session_id} inference started");
let mut busy = session.chat();
while let Some(s) = busy.decode().await {
if let Err(e) = sender.send(s.into_owned()) {
if let Err(e) = sender.send(s) {
warn!("Failed to send piece to {session_id} with error \"{e}\"");
break;
}
Expand Down

0 comments on commit e5f57a0

Please sign in to comment.