diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6c99efe6..3ca6fb77 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,14 +10,11 @@ name: CI on: + pull_request: push: paths-ignore: - '**.md' - 'LICENSE' - pull_request: - paths: - - '**.md' - - 'LICENSE' jobs: rust-clippy-analyze: @@ -32,6 +29,9 @@ jobs: - name: Check format run: cargo fmt --check + - name: Update to latest deps + run: cargo update + - name: Run test run: cargo test diff --git a/Cargo.lock b/Cargo.lock index ef60c46d..55f639e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -858,7 +858,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" dependencies = [ "cfg-if", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1484,6 +1484,7 @@ dependencies = [ "common 0.0.0", "llama-cpu", "log", + "lru", "tensor", "tokenizer", "tokio", @@ -1716,7 +1717,6 @@ dependencies = [ "hyper", "hyper-util", "log", - "lru", "serde", "serde_json", "service", @@ -1770,7 +1770,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1790,18 +1790,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -1812,9 +1812,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -1824,9 +1824,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -1836,15 +1836,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -1854,9 +1854,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -1866,9 +1866,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -1878,9 +1878,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -1890,9 +1890,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "xtask" diff --git a/service/Cargo.toml b/service/Cargo.toml index 8274aaba..523d3d61 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -13,6 +13,7 @@ tokenizer = { path = "../tokenizer" } causal-lm = { path = "../causal-lm" } log.workspace = true tokio.workspace = true +lru = "0.12" [dev-dependencies] colored = "2.1" diff --git a/service/src/lib.rs b/service/src/lib.rs index fedd4a82..753390c7 100644 --- a/service/src/lib.rs +++ b/service/src/lib.rs @@ -1,6 +1,7 @@ #![deny(warnings)] mod session; +mod session_manager; mod template; use causal_lm::{CausalLM, SampleArgs}; @@ -11,6 +12,7 @@ use tokenizer::{BPECommonNormalizer, Normalizer, Tokenizer, VocabTxt, BPE}; use tokio::task::JoinHandle; pub use session::{BusySession, ChatError, Session}; +pub use session_manager::{SessionError, SessionManager}; /// 对话服务。 pub struct Service { diff --git a/service/src/session_manager.rs b/service/src/session_manager.rs new file mode 100644 index 00000000..9ebb643f --- /dev/null +++ b/service/src/session_manager.rs @@ -0,0 +1,88 @@ +use crate::Session; +use causal_lm::CausalLM; +use log::warn; +use lru::LruCache; +use std::{fmt::Debug, hash::Hash, num::NonZeroUsize, sync::Mutex}; + +pub struct SessionManager { + pending: Mutex>>>, +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum SessionError { + Busy, + Duplicate, + NotFound, +} + +impl SessionManager { + pub fn new(capacity: Option) -> Self { + let cache = capacity + .map(|c| NonZeroUsize::new(c).expect("Session capacity must be non-zero")) + .map(LruCache::new) + .unwrap_or_else(LruCache::unbounded); + Self { + pending: Mutex::new(cache), + } + } + + pub fn take(&self, k: &SessionId) -> Result, SessionError> { + self.pending + .lock() + .unwrap() + .get_mut(k) + .ok_or(SessionError::NotFound)? + .take() + .ok_or(SessionError::Busy) + } + + pub fn get_or_insert( + &self, + session_id: SessionId, + f: impl FnOnce() -> Session, + ) -> Result, SessionError> { + self.pending + .lock() + .unwrap() + .get_or_insert_mut(session_id, || Some(f())) + .take() + .ok_or(SessionError::Busy) + } + + pub fn drop_(&self, session_id: &SessionId) -> Result<(), SessionError> { + if self.pending.lock().unwrap().pop(session_id).is_some() { + Ok(()) + } else { + Err(SessionError::NotFound) + } + } + + pub fn fork( + &self, + session_id: SessionId, + new_session_id: SessionId, + ) -> Result<(), SessionError> { + let mut sessions = self.pending.lock().unwrap(); + + if !sessions.contains(&new_session_id) { + let new = sessions + .get_mut(&session_id) + .ok_or(SessionError::NotFound)? + .as_ref() + .ok_or(SessionError::Busy)? + .fork(); + if let Some((out, _)) = sessions.push(new_session_id, Some(new)) { + warn!("{out:?} dropped because LRU cache is full"); + } + Ok(()) + } else { + Err(SessionError::Duplicate) + } + } + + pub fn restore(&self, session_id: &SessionId, session: Session) { + if let Some(option) = self.pending.lock().unwrap().get_mut(session_id) { + assert!(option.replace(session).is_none()); + } + } +} diff --git a/web-api/Cargo.toml b/web-api/Cargo.toml index 2885255b..1b651ac4 100644 --- a/web-api/Cargo.toml +++ b/web-api/Cargo.toml @@ -12,7 +12,6 @@ serde_json.workspace = true tokio = { workspace = true, features = ["net"] } log.workspace = true -lru = "0.12" hyper = { version = "1.3", features = ["http1", "server"] } hyper-util = { version = "0.1", features = ["http1", "tokio", "server"] } http-body-util = "0.1" diff --git a/web-api/src/manager.rs b/web-api/src/manager.rs index 98f14ed3..f4e3fd06 100644 --- a/web-api/src/manager.rs +++ b/web-api/src/manager.rs @@ -1,45 +1,22 @@ -use crate::schemas::{Drop, DropSuccess, Error, Fork, ForkSuccess, Infer, Sentence}; -use causal_lm::CausalLM; -use lru::LruCache; -use service::{Service, Session}; -use std::{ - num::NonZeroUsize, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex, - }, +use crate::schemas::{ + AnonymousSessionId, DropSuccess, Drop_, Error, Fork, ForkSuccess, Infer, Sentence, SessionId, }; +use causal_lm::CausalLM; +use service::{Service, Session, SessionError, SessionManager}; +use std::sync::Arc; use tokio::sync::mpsc::{self, UnboundedReceiver}; pub(crate) struct ServiceManager { service: Service, - pending: Mutex>>>, -} - -#[derive(Eq, PartialEq, Hash, Clone, Debug)] -struct AnonymousSessionId(usize); - -impl AnonymousSessionId { - fn new() -> Self { - static NEXT: AtomicUsize = AtomicUsize::new(0); - Self(NEXT.fetch_add(1, Ordering::Relaxed)) - } -} - -#[derive(PartialEq, Eq, Hash, Clone, Debug)] -enum SessionId { - Permanent(String), - Temporary(AnonymousSessionId), + session_manager: SessionManager, } impl ServiceManager { #[inline] pub fn new(service: Service, capacity: Option) -> Self { - let cap = - capacity.map(|c| NonZeroUsize::new(c).expect("Session capacity must be non-zero")); Self { service, - pending: Mutex::new(cap.map(LruCache::new).unwrap_or_else(LruCache::unbounded)), + session_manager: SessionManager::new(capacity), } } } @@ -99,21 +76,13 @@ where (Some(session_id_str), 0) => { let session_id = SessionId::Permanent(session_id_str); let mut session = self - .pending - .lock() - .unwrap() - .get_or_insert_mut(session_id.clone(), || { - info!("{:?} created", &session_id); - Some(self.service.launch()) - }) - .take() - .ok_or(Error::SessionBusy)?; - + .session_manager + .get_or_insert(session_id.clone(), || self.service.launch()) + .map_err(|e| Error::Session(e))?; let (sender, receiver) = mpsc::unbounded_channel(); let self_ = self.clone(); tokio::spawn(async move { session.revert(0).unwrap(); - infer( &session_id, &mut session, @@ -127,20 +96,14 @@ where self_.restore(&session_id, session); }); - Ok(receiver) } (Some(session_id_str), p) => { let session_id = SessionId::Permanent(session_id_str); let mut session = self - .pending - .lock() - .unwrap() - .get_mut(&session_id) - .ok_or(Error::SessionNotFound)? - .take() - .ok_or(Error::SessionBusy)?; - + .session_manager + .take(&session_id) + .map_err(|e| Error::Session(e))?; if session.revert(p).is_err() { let current = session.dialog_pos(); warn!( @@ -149,12 +112,10 @@ where self.restore(&session_id, session); return Err(Error::InvalidDialogPos(current)); } - let (sender, receiver) = mpsc::unbounded_channel(); let self_ = self.clone(); tokio::spawn(async move { info!("{session_id:?} reverted to {p}"); - infer( &session_id, &mut session, @@ -168,21 +129,14 @@ where self_.restore(&session_id, session); }); - Ok(receiver) } (None, 0) => { let session_id = SessionId::Temporary(AnonymousSessionId::new()); let mut session = self - .pending - .lock() - .unwrap() - .get_or_insert_mut(session_id.clone(), || { - info!("{:?} created", &session_id); - Some(self.service.launch()) - }) - .take() - .ok_or(Error::SessionNotFound)?; + .session_manager + .get_or_insert(session_id.clone(), || self.service.launch()) + .map_err(|e| Error::Session(e))?; let (sender, receiver) = mpsc::unbounded_channel(); let self_ = self.clone(); if messages.len() % 2 == 1 { @@ -211,9 +165,7 @@ where #[inline] fn restore(&self, session_id: &SessionId, session: Session) { - if let Some(option) = self.pending.lock().unwrap().get_mut(session_id) { - assert!(option.replace(session).is_none()); - } + self.session_manager.restore(session_id, session); } pub fn fork( @@ -223,38 +175,26 @@ where new_session_id, }: Fork, ) -> Result { - let mut sessions = self.pending.lock().unwrap(); - let new_session_id_warped = SessionId::Permanent(new_session_id.clone()); - if !sessions.contains(&new_session_id_warped) { - let new = sessions - .get_mut(&SessionId::Permanent(session_id.clone())) - .ok_or(Error::SessionNotFound)? - .as_ref() - .ok_or(Error::SessionBusy)? - .fork(); - - info!("{new_session_id} is forked from {session_id:?}"); - if let Some((out, _)) = sessions.push(new_session_id_warped, Some(new)) { - warn!("{out:?} dropped because LRU cache is full"); - } - Ok(ForkSuccess) - } else { - warn!("Fork failed because {new_session_id} already exists"); - Err(Error::SessionDuplicate) + let new_session_id_ = SessionId::Permanent(new_session_id.clone()); + let session_id_ = SessionId::Permanent(session_id.clone()); + let result = self.session_manager.fork(session_id_, new_session_id_); + match result { + Ok(()) => Ok(ForkSuccess), + Err(e) => Err(Error::Session(e)), } } - pub fn drop_(&self, Drop { session_id }: Drop) -> Result { - let session_id = SessionId::Permanent(session_id); - self.drop_with_session_id(session_id) + pub fn drop_(&self, Drop_ { session_id }: Drop_) -> Result { + let result = self.drop_with_session_id(SessionId::Permanent(session_id)); + match result { + Ok(DropSuccess) => Ok(DropSuccess), + Err(e) => Err(Error::Session(e)), + } } - fn drop_with_session_id(&self, session_id: SessionId) -> Result { - if self.pending.lock().unwrap().pop(&session_id).is_some() { - info!("{session_id:?} dropped in drop function"); - Ok(DropSuccess) - } else { - Err(Error::SessionNotFound) - } + fn drop_with_session_id(&self, session_id: SessionId) -> Result { + self.session_manager + .drop_(&session_id) + .map(|()| DropSuccess) } } diff --git a/web-api/src/schemas.rs b/web-api/src/schemas.rs index 678ef6a7..e691af93 100644 --- a/web-api/src/schemas.rs +++ b/web-api/src/schemas.rs @@ -1,4 +1,6 @@ use hyper::StatusCode; +use service::SessionError; +use std::sync::atomic::{AtomicUsize, Ordering}; #[derive(serde::Deserialize)] pub(crate) struct Infer { @@ -17,6 +19,22 @@ pub(crate) struct Sentence { pub content: String, } +#[derive(Eq, PartialEq, Hash, Clone, Debug)] +pub struct AnonymousSessionId(usize); + +impl AnonymousSessionId { + pub(crate) fn new() -> Self { + static NEXT: AtomicUsize = AtomicUsize::new(0); + Self(NEXT.fetch_add(1, Ordering::Relaxed)) + } +} + +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +pub enum SessionId { + Permanent(String), + Temporary(AnonymousSessionId), +} + #[derive(serde::Deserialize)] pub(crate) struct Fork { pub session_id: String, @@ -24,14 +42,14 @@ pub(crate) struct Fork { } #[derive(serde::Deserialize)] -pub(crate) struct Drop { +pub(crate) struct Drop_ { pub session_id: String, } pub(crate) struct ForkSuccess; pub(crate) struct DropSuccess; -pub(crate) trait Success { +pub trait Success { fn msg(&self) -> &str; } @@ -48,9 +66,7 @@ impl Success for DropSuccess { #[derive(Debug)] pub(crate) enum Error { - SessionBusy, - SessionDuplicate, - SessionNotFound, + Session(SessionError), WrongJson(serde_json::Error), InvalidDialogPos(usize), } @@ -65,10 +81,11 @@ struct ErrorBody { impl Error { #[inline] pub const fn status(&self) -> StatusCode { + use SessionError::*; match self { - Self::SessionNotFound => StatusCode::NOT_FOUND, - Self::SessionBusy => StatusCode::NOT_ACCEPTABLE, - Self::SessionDuplicate => StatusCode::CONFLICT, + Self::Session(NotFound) => StatusCode::NOT_FOUND, + Self::Session(Busy) => StatusCode::NOT_ACCEPTABLE, + Self::Session(Duplicate) => StatusCode::CONFLICT, Self::WrongJson(_) => StatusCode::BAD_REQUEST, Self::InvalidDialogPos(_) => StatusCode::RANGE_NOT_SATISFIABLE, } @@ -91,10 +108,11 @@ impl Error { serde_json::to_value(v).unwrap() } + use SessionError::*; match self { - Self::SessionNotFound => json(error!(0, "Session not found")), - Self::SessionBusy => json(error!(0, "Session is busy")), - Self::SessionDuplicate => json(error!(0, "Session ID already exists")), + Self::Session(NotFound) => json(error!(0, "Session not found")), + Self::Session(Busy) => json(error!(0, "Session is busy")), + Self::Session(Duplicate) => json(error!(0, "Session ID already exists")), Self::WrongJson(e) => json(error!(0, e.to_string())), &Self::InvalidDialogPos(current_dialog_pos) => { #[derive(serde::Serialize)]