Skip to content

Commit

Permalink
style: cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jul 5, 2024
1 parent cf575e9 commit b2c33b1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 17 deletions.
2 changes: 1 addition & 1 deletion service/src/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<SessionId: Eq + Hash + Debug, M: CausalLM> SessionManager<SessionId, M> {
.ok_or(SessionError::Busy)
}

pub fn get_or_insert(
pub fn take_or_register(
&self,
session_id: SessionId,
f: impl FnOnce() -> Session<M>,
Expand Down
25 changes: 9 additions & 16 deletions web-api/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::schemas::{
AnonymousSessionId, DropSuccess, Drop_, Error, Fork, ForkSuccess, Infer, Sentence, SessionId,
};
use causal_lm::CausalLM;
use service::{Service, Session, SessionError, SessionManager};
use service::{Service, Session, SessionManager};
use std::sync::Arc;
use tokio::sync::mpsc::{self, UnboundedReceiver};

Expand Down Expand Up @@ -77,7 +77,7 @@ where
let session_id = SessionId::Permanent(session_id_str);
let mut session = self
.session_manager
.get_or_insert(session_id.clone(), || self.service.launch())
.take_or_register(session_id.clone(), || self.service.launch())
.map_err(Error::Session)?;
let (sender, receiver) = mpsc::unbounded_channel();
let self_ = self.clone();
Expand Down Expand Up @@ -135,7 +135,7 @@ where
let session_id = SessionId::Temporary(AnonymousSessionId::new());
let mut session = self
.session_manager
.get_or_insert(session_id.clone(), || self.service.launch())
.take_or_register(session_id.clone(), || self.service.launch())
.map_err(Error::Session)?;
let (sender, receiver) = mpsc::unbounded_channel();
let self_ = self.clone();
Expand All @@ -151,7 +151,7 @@ where
sender,
)
.await;
self_.drop_with_session_id(session_id).unwrap();
self_.session_manager.drop_(&session_id).unwrap();
});
}
Ok(receiver)
Expand All @@ -170,23 +170,16 @@ where
new_session_id,
}: Fork,
) -> Result<ForkSuccess, Error> {
let new_session_id = SessionId::Permanent(new_session_id);
let session_id = SessionId::Permanent(session_id);
self.session_manager
.fork(session_id, new_session_id)
.map_err(Error::Session)?;
Ok(ForkSuccess)
.fork(session_id.into(), new_session_id.into())
.map(|()| ForkSuccess)
.map_err(Error::Session)
}

pub fn drop_(&self, Drop_ { session_id }: Drop_) -> Result<DropSuccess, Error> {
self.drop_with_session_id(SessionId::Permanent(session_id))
.map_err(Error::Session)?;
Ok(DropSuccess)
}

fn drop_with_session_id(&self, session_id: SessionId) -> Result<DropSuccess, SessionError> {
self.session_manager
.drop_(&session_id)
.drop_(&session_id.into())
.map(|()| DropSuccess)
.map_err(Error::Session)
}
}
14 changes: 14 additions & 0 deletions web-api/src/schemas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ pub enum SessionId {
Temporary(AnonymousSessionId),
}

impl From<String> for SessionId {
#[inline]
fn from(value: String) -> Self {
Self::Permanent(value)
}
}

impl From<AnonymousSessionId> for SessionId {
#[inline]
fn from(value: AnonymousSessionId) -> Self {
Self::Temporary(value)
}
}

#[derive(serde::Deserialize)]
pub(crate) struct Fork {
pub session_id: String,
Expand Down

0 comments on commit b2c33b1

Please sign in to comment.