From 759b8835fd00020c67137689814dba8eaf64b6fb Mon Sep 17 00:00:00 2001 From: panzezhong Date: Mon, 8 Apr 2024 16:46:02 +0800 Subject: [PATCH] =?UTF-8?q?feat(web-api):=20=E5=A2=9E=E5=8A=A0=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E6=8E=A8=E7=90=86=E7=AB=AF=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- service/src/lib.rs | 1 + web-api/Cargo.toml | 1 + web-api/src/handlers/cancel_handler.rs | 16 ++++++ .../infer_handler.rs} | 4 +- web-api/src/handlers/mod.rs | 4 ++ web-api/src/lib.rs | 4 ++ web-api/src/manager.rs | 49 +++++++++++++++++-- web-api/src/response.rs | 7 +++ web-api/src/schemas.rs | 32 ++++++++++++ 9 files changed, 113 insertions(+), 5 deletions(-) create mode 100644 web-api/src/handlers/cancel_handler.rs rename web-api/src/{handlers.rs => handlers/infer_handler.rs} (95%) create mode 100644 web-api/src/handlers/mod.rs diff --git a/service/src/lib.rs b/service/src/lib.rs index 237c993d..b750f2fd 100644 --- a/service/src/lib.rs +++ b/service/src/lib.rs @@ -17,6 +17,7 @@ use tokio::{sync::mpsc::unbounded_channel, task::JoinSet}; use transformer::SampleArgs; pub use session::Session; +pub use session::SessionHandle; #[macro_use] extern crate log; diff --git a/web-api/Cargo.toml b/web-api/Cargo.toml index 6339b2aa..1b76db81 100644 --- a/web-api/Cargo.toml +++ b/web-api/Cargo.toml @@ -11,3 +11,4 @@ serde = { workspace = true, features = ["derive"] } tokio.workspace = true futures = "0.3" actix-web = "4.5" +log.workspace = true diff --git a/web-api/src/handlers/cancel_handler.rs b/web-api/src/handlers/cancel_handler.rs new file mode 100644 index 00000000..ca7ed4e8 --- /dev/null +++ b/web-api/src/handlers/cancel_handler.rs @@ -0,0 +1,16 @@ +use crate::response::Response; +use crate::schemas; +use crate::AppState; +use actix_web::{post, web, HttpResponse}; + +#[post("/cancel")] +pub async fn cancel( + app_state: web::Data, + request: web::Json, +) -> HttpResponse { + info!("Request from {}: cancel infer", request.session_id); + match app_state.service_manager.cancel_session(&request) { + Ok(s) => Response::success(s), + Err(e) => Response::error(e), + } +} diff --git a/web-api/src/handlers.rs b/web-api/src/handlers/infer_handler.rs similarity index 95% rename from web-api/src/handlers.rs rename to web-api/src/handlers/infer_handler.rs index 100fe2f5..aa702870 100644 --- a/web-api/src/handlers.rs +++ b/web-api/src/handlers/infer_handler.rs @@ -7,11 +7,11 @@ use futures::stream::{Stream, StreamExt}; use service::Session; #[post("/infer")] -async fn infer( +pub async fn infer( app_state: web::Data, request: web::Json, ) -> HttpResponse { - println!("Request from {}: infer", request.session_id); + info!("Request from {}: infer", request.session_id); match app_state.service_manager.get_session(&request) { Ok(session) => { diff --git a/web-api/src/handlers/mod.rs b/web-api/src/handlers/mod.rs new file mode 100644 index 00000000..26a4df8c --- /dev/null +++ b/web-api/src/handlers/mod.rs @@ -0,0 +1,4 @@ +mod cancel_handler; +mod infer_handler; +pub use cancel_handler::cancel; +pub use infer_handler::infer; diff --git a/web-api/src/lib.rs b/web-api/src/lib.rs index af4914a8..82a60c52 100644 --- a/web-api/src/lib.rs +++ b/web-api/src/lib.rs @@ -8,6 +8,9 @@ mod response; mod schemas; use manager::ServiceManager; +#[macro_use] +extern crate log; + /// All global variables and services shared among all endpoints in this App pub struct AppState { /// Manager of this App, which provides all kinds of services such as infer, session management, etc @@ -26,6 +29,7 @@ pub async fn start_infer_service( App::new() .app_data(app_state.clone()) .service(handlers::infer) + .service(handlers::cancel) }) .bind(addrs)? .run() diff --git a/web-api/src/manager.rs b/web-api/src/manager.rs index e01466bd..7be7d533 100644 --- a/web-api/src/manager.rs +++ b/web-api/src/manager.rs @@ -1,6 +1,6 @@ use crate::schemas; use actix_web::web; -use service::{Service, Session}; +use service::{Service, Session, SessionHandle}; use std::collections::{hash_map::Entry, HashMap}; use std::sync::{Arc, Mutex}; pub struct ServiceManager { @@ -11,6 +11,9 @@ pub struct ServiceManager { /// A session must be re-inserted after being served. sessions: Arc>>>, + /// The abort handle for all existing sessions, session id as key. + session_handles: Arc>>, + /// Inference service provided by backend model infer_service: Arc, } @@ -18,6 +21,7 @@ impl ServiceManager { pub fn from(infer_service: Arc) -> Self { ServiceManager { sessions: Default::default(), + session_handles: Default::default(), infer_service: infer_service.clone(), } } @@ -44,7 +48,11 @@ impl ServiceManager { Some(s) => { if request.first_request { // Session id exists but user thinks otherwise, overwrite current session - Ok(e.insert(Some(self.infer_service.launch())).take().unwrap()) + Ok( + e.insert(Some(self.create_session(request.session_id.to_string()))) + .take() + .unwrap(), + ) } else { // take the existing session Ok(s) @@ -55,7 +63,11 @@ impl ServiceManager { Entry::Vacant(e) => { if request.first_request { // First request, create new session - Ok(e.insert(Some(self.infer_service.launch())).take().unwrap()) + Ok( + e.insert(Some(self.create_session(request.session_id.to_string()))) + .take() + .unwrap(), + ) } else { // Session id does not exist but user thinks otherwise, histroy lost Err(schemas::Error::SessionNotFound) @@ -64,6 +76,18 @@ impl ServiceManager { } } + /// Create a new infer session + fn create_session(&self, session_id: String) -> Session { + // launch new infer session + let session = self.infer_service.launch(); + // register session abort handle + self.session_handles + .lock() + .unwrap() + .insert(session_id, session.handle()); + return session; + } + /// Return the taken-away session, should be called every time a request is done pub fn reset_session(&self, id: &str, session: Session) { self.sessions @@ -71,4 +95,23 @@ impl ServiceManager { .unwrap() .insert(id.to_string(), Some(session)); } + + /// Signal the backend model to stop the current inferrence task given session id + pub fn cancel_session( + &self, + request: &web::Json, + ) -> Result { + match self + .session_handles + .lock() + .unwrap() + .entry(request.session_id.to_string()) + { + Entry::Occupied(handle) => { + handle.get().abort(); + Ok(schemas::Success::SessionCanceled) + } + Entry::Vacant(_) => Err(schemas::Error::CancelFailed), + } + } } diff --git a/web-api/src/response.rs b/web-api/src/response.rs index f6ce62b6..3be5cf8b 100644 --- a/web-api/src/response.rs +++ b/web-api/src/response.rs @@ -20,4 +20,11 @@ impl Response { .content_type("application/json") .json(err) } + + pub fn success(s: schemas::Success) -> HttpResponse { + let success = schemas::SuccessResponse { result: s.msg() }; + HttpResponse::Ok() + .content_type("application/json") + .json(success) + } } diff --git a/web-api/src/schemas.rs b/web-api/src/schemas.rs index 55eac885..6efbeb7e 100644 --- a/web-api/src/schemas.rs +++ b/web-api/src/schemas.rs @@ -18,9 +18,40 @@ impl From> for InferRequest { } } +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct CancelRequest { + pub session_id: String, +} + +impl From> for CancelRequest { + fn from(request: web::Json) -> Self { + CancelRequest { + session_id: request.session_id.clone(), + } + } +} + +pub enum Success { + SessionCanceled, +} + +impl Success { + pub fn msg(&self) -> String { + match self { + Success::SessionCanceled => "Inferrence canceled".to_string(), + } + } +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct SuccessResponse { + pub result: String, +} + pub enum Error { SessionBusy, SessionNotFound, + CancelFailed, } impl Error { @@ -28,6 +59,7 @@ impl Error { match self { Error::SessionBusy => "Session is busy".to_string(), Error::SessionNotFound => "Session histroy is lost".to_string(), + Error::CancelFailed => "Failed to cancel inferrence".to_string(), } } }