From e5f57a049909e99a951d8646306dff494cd526e4 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 29 May 2024 06:37:15 +0800 Subject: [PATCH] =?UTF-8?q?fix(service):=20=E7=BC=93=E5=86=B2=E4=B8=8D?= =?UTF-8?q?=E5=AE=8C=E6=95=B4=E7=9A=84=20utf8=20=E5=AD=97=E7=AC=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- Cargo.lock | 28 ++++++++--------- service/src/session/dispatch.rs | 54 ++++++++++++++++++++++++++------- service/src/session/mod.rs | 5 ++- web-api/src/manager.rs | 2 +- 4 files changed, 60 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 83b8fea2..10eb8e21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.21.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" dependencies = [ "gimli", ] @@ -131,9 +131,9 @@ checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" [[package]] name = "backtrace" -version = "0.3.71" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +checksum = "17c6a35df3749d2e8bb1b7b21a976d82b15548788d2735b9d82f329268f71a11" dependencies = [ "addr2line", "cc", @@ -217,9 +217,9 @@ dependencies = [ [[package]] name = "bytemuck_derive" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "369cfaf2a5bed5d8f8202073b2e093c9f508251de1551a0deb4253e4c7d80909" +checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" dependencies = [ "proc-macro2", "quote", @@ -857,9 +857,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.1" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "glob" @@ -983,9 +983,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d8d52be92d09acc2e01dddb7fde3ad983fc6489c7db4837e605bc3fca4cb63e" +checksum = "7b875924a60b96e5d7b9ae7b066540b1dd1cbd90d1828f54c92e02a283351c56" dependencies = [ "bytes", "futures-util", @@ -1408,9 +1408,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.2" +version = "0.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +checksum = "b8ec7ab813848ba4522158d5517a6093db1ded27575b070f4177b8d12b41db5e" dependencies = [ "memchr", ] @@ -2515,9 +2515,9 @@ checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = "winnow" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3c52e9c97a68071b23e836c9380edae937f17b9c4667bd021973efc689f618d" +checksum = "86c949fede1d13936a99f14fafd3e76fd642b556dd2ce96287fbe2e0151bfac6" dependencies = [ "memchr", ] diff --git a/service/src/session/dispatch.rs b/service/src/session/dispatch.rs index 95182e19..04d3426e 100644 --- a/service/src/session/dispatch.rs +++ b/service/src/session/dispatch.rs @@ -3,8 +3,9 @@ use crate::ServiceComponent; use causal_lm::{CausalLM, DecodingMeta, SampleArgs, SampleMeta}; use common::utok; use std::{ - borrow::Cow, iter::zip, + mem::{replace, size_of}, + str, sync::{Arc, Mutex}, }; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; @@ -12,6 +13,7 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; pub(super) struct TaskHandle { receiver: Option>, cache: Arc>>>, + buffer: Utf8Buffer, } impl TaskHandle { @@ -37,19 +39,26 @@ impl ServiceComponent { TaskHandle { receiver: Some(receiver), cache, + buffer: Default::default(), } } - pub(super) async fn decode(&self, x: &mut TaskHandle) -> Option> { - x.receiver.as_mut().unwrap().recv().await.map(|token| { - // detokenize and denormalize the token - let ServiceComponent { - normalizer, - tokenizer, - .. - } = self; - normalizer.decode(tokenizer.decode(token)) - }) + pub(super) async fn decode(&self, x: &mut TaskHandle) -> Option { + loop { + let s = x.receiver.as_mut().unwrap().recv().await.map(|token| { + // detokenize and denormalize the token + let ServiceComponent { + normalizer, + tokenizer, + .. + } = self; + normalizer.decode(tokenizer.decode(token)) + })?; + let s = x.buffer.push(s.as_bytes()); + if !s.is_empty() { + return Some(s); + } + } } } @@ -143,3 +152,26 @@ where } } } + +#[derive(Clone, Default, Debug)] +struct Utf8Buffer(Vec); + +impl Utf8Buffer { + fn push(&mut self, bytes: impl AsRef<[u8]>) -> String { + self.0.extend_from_slice(bytes.as_ref()); + let mut len = match str::from_utf8(&self.0) { + Ok(_) => self.0.len(), + Err(e) => e.valid_up_to(), + }; + while len + size_of::() <= self.0.len() { + len += 1; + match str::from_utf8(&self.0[len..]) { + Ok(s) => len += s.as_bytes().len(), + Err(e) => len += e.valid_up_to(), + } + } + let s = self.0.split_off(len); + let s = replace(&mut self.0, s); + unsafe { String::from_utf8_unchecked(s) } + } +} diff --git a/service/src/session/mod.rs b/service/src/session/mod.rs index 873eeaff..2873dbed 100644 --- a/service/src/session/mod.rs +++ b/service/src/session/mod.rs @@ -11,7 +11,6 @@ use dialog::Dialog; use dispatch::TaskHandle; use log::info; use std::{ - borrow::Cow, cmp::Ordering::{Equal, Greater, Less}, error, fmt, sync::Arc, @@ -157,7 +156,7 @@ pub struct BusySession<'a, M: CausalLM> { impl BusySession<'_, M> { /// 接收模型解码产生的文本。 #[inline] - pub async fn decode(&mut self) -> Option> { + pub async fn decode(&mut self) -> Option { self.session.component.decode(&mut self.handle).await } } @@ -190,7 +189,7 @@ impl Generator { /// 接收模型解码产生的文本。 #[inline] - pub async fn decode(&mut self) -> Option> { + pub async fn decode(&mut self) -> Option { self.component.decode(&mut self.handle).await } } diff --git a/web-api/src/manager.rs b/web-api/src/manager.rs index fc2a5cce..3cf4dd22 100644 --- a/web-api/src/manager.rs +++ b/web-api/src/manager.rs @@ -65,7 +65,7 @@ where info!("{session_id} inference started"); let mut busy = session.chat(); while let Some(s) = busy.decode().await { - if let Err(e) = sender.send(s.into_owned()) { + if let Err(e) = sender.send(s) { warn!("Failed to send piece to {session_id} with error \"{e}\""); break; }