Skip to content

Commit

Permalink
move session manager to service
Browse files Browse the repository at this point in the history
  • Loading branch information
qinyiqun committed Jul 5, 2024
1 parent 6321c96 commit 794f1e3
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 137 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@
name: CI

on:
pull_request:
push:
paths-ignore:
- '**.md'
- 'LICENSE'
pull_request:
paths:
- '**.md'
- 'LICENSE'

jobs:
rust-clippy-analyze:
Expand All @@ -32,6 +29,9 @@ jobs:
- name: Check format
run: cargo fmt --check

- name: Update to latest deps
run: cargo update

- name: Run test
run: cargo test

Expand Down
56 changes: 28 additions & 28 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ tokenizer = { path = "../tokenizer" }
causal-lm = { path = "../causal-lm" }
log.workspace = true
tokio.workspace = true
lru = "0.12"

[dev-dependencies]
colored = "2.1"
Expand Down
2 changes: 2 additions & 0 deletions service/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![deny(warnings)]

mod session;
mod session_manager;
mod template;

use causal_lm::{CausalLM, SampleArgs};
Expand All @@ -11,6 +12,7 @@ use tokenizer::{BPECommonNormalizer, Normalizer, Tokenizer, VocabTxt, BPE};
use tokio::task::JoinHandle;

pub use session::{BusySession, ChatError, Session};
pub use session_manager::{SessionError, SessionManager};

/// 对话服务。
pub struct Service<M: CausalLM> {
Expand Down
88 changes: 88 additions & 0 deletions service/src/session_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use crate::Session;
use causal_lm::CausalLM;
use log::warn;
use lru::LruCache;
use std::{fmt::Debug, hash::Hash, num::NonZeroUsize, sync::Mutex};

pub struct SessionManager<SessionId, M: CausalLM> {
pending: Mutex<LruCache<SessionId, Option<Session<M>>>>,
}

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SessionError {
Busy,
Duplicate,
NotFound,
}

impl<SessionId: Eq + Hash + Debug, M: CausalLM> SessionManager<SessionId, M> {
pub fn new(capacity: Option<usize>) -> Self {
let cache = capacity
.map(|c| NonZeroUsize::new(c).expect("Session capacity must be non-zero"))
.map(LruCache::new)
.unwrap_or_else(LruCache::unbounded);
Self {
pending: Mutex::new(cache),
}
}

pub fn take(&self, k: &SessionId) -> Result<Session<M>, SessionError> {
self.pending
.lock()
.unwrap()
.get_mut(k)
.ok_or(SessionError::NotFound)?
.take()
.ok_or(SessionError::Busy)
}

pub fn get_or_insert(
&self,
session_id: SessionId,
f: impl FnOnce() -> Session<M>,
) -> Result<Session<M>, SessionError> {
self.pending
.lock()
.unwrap()
.get_or_insert_mut(session_id, || Some(f()))
.take()
.ok_or(SessionError::Busy)
}

pub fn drop_(&self, session_id: &SessionId) -> Result<(), SessionError> {
if self.pending.lock().unwrap().pop(session_id).is_some() {
Ok(())
} else {
Err(SessionError::NotFound)
}
}

pub fn fork(
&self,
session_id: SessionId,
new_session_id: SessionId,
) -> Result<(), SessionError> {
let mut sessions = self.pending.lock().unwrap();

if !sessions.contains(&new_session_id) {
let new = sessions
.get_mut(&session_id)
.ok_or(SessionError::NotFound)?
.as_ref()
.ok_or(SessionError::Busy)?
.fork();
if let Some((out, _)) = sessions.push(new_session_id, Some(new)) {
warn!("{out:?} dropped because LRU cache is full");
}
Ok(())
} else {
Err(SessionError::Duplicate)
}
}

pub fn restore(&self, session_id: &SessionId, session: Session<M>) {
if let Some(option) = self.pending.lock().unwrap().get_mut(session_id) {
assert!(option.replace(session).is_none());
}
}
}
1 change: 0 additions & 1 deletion web-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ serde_json.workspace = true
tokio = { workspace = true, features = ["net"] }
log.workspace = true

lru = "0.12"
hyper = { version = "1.3", features = ["http1", "server"] }
hyper-util = { version = "0.1", features = ["http1", "tokio", "server"] }
http-body-util = "0.1"
Expand Down
Loading

0 comments on commit 794f1e3

Please sign in to comment.