Skip to content

Commit

Permalink
feat(web-api): 增加取消推理端口
Browse files Browse the repository at this point in the history
  • Loading branch information
PanZezhong1725 authored and YdrMaster committed Apr 16, 2024
1 parent 81b0333 commit 759b883
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 5 deletions.
1 change: 1 addition & 0 deletions service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use tokio::{sync::mpsc::unbounded_channel, task::JoinSet};
use transformer::SampleArgs;

pub use session::Session;
pub use session::SessionHandle;

#[macro_use]
extern crate log;
Expand Down
1 change: 1 addition & 0 deletions web-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ serde = { workspace = true, features = ["derive"] }
tokio.workspace = true
futures = "0.3"
actix-web = "4.5"
log.workspace = true
16 changes: 16 additions & 0 deletions web-api/src/handlers/cancel_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use crate::response::Response;
use crate::schemas;
use crate::AppState;
use actix_web::{post, web, HttpResponse};

#[post("/cancel")]
pub async fn cancel(
app_state: web::Data<AppState>,
request: web::Json<schemas::CancelRequest>,
) -> HttpResponse {
info!("Request from {}: cancel infer", request.session_id);
match app_state.service_manager.cancel_session(&request) {
Ok(s) => Response::success(s),
Err(e) => Response::error(e),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use futures::stream::{Stream, StreamExt};
use service::Session;

#[post("/infer")]
async fn infer(
pub async fn infer(
app_state: web::Data<AppState>,
request: web::Json<schemas::InferRequest>,
) -> HttpResponse {
println!("Request from {}: infer", request.session_id);
info!("Request from {}: infer", request.session_id);

match app_state.service_manager.get_session(&request) {
Ok(session) => {
Expand Down
4 changes: 4 additions & 0 deletions web-api/src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod cancel_handler;
mod infer_handler;
pub use cancel_handler::cancel;
pub use infer_handler::infer;
4 changes: 4 additions & 0 deletions web-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ mod response;
mod schemas;
use manager::ServiceManager;

#[macro_use]
extern crate log;

/// All global variables and services shared among all endpoints in this App
pub struct AppState {
/// Manager of this App, which provides all kinds of services such as infer, session management, etc
Expand All @@ -26,6 +29,7 @@ pub async fn start_infer_service(
App::new()
.app_data(app_state.clone())
.service(handlers::infer)
.service(handlers::cancel)
})
.bind(addrs)?
.run()
Expand Down
49 changes: 46 additions & 3 deletions web-api/src/manager.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::schemas;
use actix_web::web;
use service::{Service, Session};
use service::{Service, Session, SessionHandle};
use std::collections::{hash_map::Entry, HashMap};
use std::sync::{Arc, Mutex};
pub struct ServiceManager {
Expand All @@ -11,13 +11,17 @@ pub struct ServiceManager {
/// A session must be re-inserted after being served.
sessions: Arc<Mutex<HashMap<String, Option<Session>>>>,

/// Abort handles of all existing sessions, keyed by session id.
session_handles: Arc<Mutex<HashMap<String, SessionHandle>>>,

/// Inference service provided by backend model
infer_service: Arc<Service>,
}
impl ServiceManager {
/// Builds a manager backed by `infer_service`, starting with empty session and
/// abort-handle tables.
pub fn from(infer_service: Arc<Service>) -> Self {
    ServiceManager {
        sessions: Default::default(),
        session_handles: Default::default(),
        // The `Arc` is already owned by this function, so it is moved in instead of
        // being cloned and then dropped.
        infer_service,
    }
}
Expand All @@ -44,7 +48,11 @@ impl ServiceManager {
Some(s) => {
if request.first_request {
// Session id exists but user thinks otherwise, overwrite current session
Ok(e.insert(Some(self.infer_service.launch())).take().unwrap())
Ok(
e.insert(Some(self.create_session(request.session_id.to_string())))
.take()
.unwrap(),
)
} else {
// take the existing session
Ok(s)
Expand All @@ -55,7 +63,11 @@ impl ServiceManager {
Entry::Vacant(e) => {
if request.first_request {
// First request, create new session
Ok(e.insert(Some(self.infer_service.launch())).take().unwrap())
Ok(
e.insert(Some(self.create_session(request.session_id.to_string())))
.take()
.unwrap(),
)
} else {
// Session id does not exist but user thinks otherwise, history lost
Err(schemas::Error::SessionNotFound)
Expand All @@ -64,11 +76,42 @@ impl ServiceManager {
}
}

/// Launches a new inference session and registers its abort handle under `session_id`.
///
/// A handle already stored under the same id is replaced by the new one.
///
/// # Panics
/// Panics if the `session_handles` mutex is poisoned.
fn create_session(&self, session_id: String) -> Session {
    // Launch the new inference session.
    let session = self.infer_service.launch();
    // Obtain the handle before locking so the mutex is held only for the insert itself.
    let handle = session.handle();
    self.session_handles
        .lock()
        .unwrap()
        .insert(session_id, handle);
    session
}

/// Puts a previously taken `session` back into the session table under `id`.
///
/// Should be called every time a request is done.
pub fn reset_session(&self, id: &str, session: Session) {
    let mut table = self.sessions.lock().unwrap();
    let key = id.to_string();
    table.insert(key, Some(session));
}

/// Signal the backend model to stop the current inferrence task given session id
pub fn cancel_session(
&self,
request: &web::Json<schemas::CancelRequest>,
) -> Result<schemas::Success, schemas::Error> {
match self
.session_handles
.lock()
.unwrap()
.entry(request.session_id.to_string())
{
Entry::Occupied(handle) => {
handle.get().abort();
Ok(schemas::Success::SessionCanceled)
}
Entry::Vacant(_) => Err(schemas::Error::CancelFailed),
}
}
}
7 changes: 7 additions & 0 deletions web-api/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,11 @@ impl Response {
.content_type("application/json")
.json(err)
}

/// Builds a `200 OK` JSON response whose `result` field carries the message of `s`.
pub fn success(s: schemas::Success) -> HttpResponse {
    let body = schemas::SuccessResponse { result: s.msg() };
    let mut builder = HttpResponse::Ok();
    builder.content_type("application/json").json(body)
}
}
32 changes: 32 additions & 0 deletions web-api/src/schemas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,48 @@ impl From<web::Json<InferRequest>> for InferRequest {
}
}

/// JSON body accepted by the cancel endpoint.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct CancelRequest {
/// Id of the session whose running inference should be canceled.
pub session_id: String,
}

impl From<web::Json<CancelRequest>> for CancelRequest {
    /// Unwraps the JSON extractor into the request it carries.
    fn from(request: web::Json<CancelRequest>) -> Self {
        // The extractor is received by value, so the payload is moved out instead of
        // being rebuilt from cloned fields.
        request.into_inner()
    }
}

/// Successful outcomes the web API can report to a client.
pub enum Success {
    /// Returned by the manager after the abort handle of a session has been triggered.
    SessionCanceled,
}

impl Success {
    /// Returns the message text sent to the client for this outcome.
    pub fn msg(&self) -> String {
        // NOTE(review): the spelling of the message is runtime text clients may match on,
        // so it is reproduced unchanged here.
        let text = match *self {
            Success::SessionCanceled => "Inferrence canceled",
        };
        String::from(text)
    }
}

/// JSON body sent back for a successful request.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct SuccessResponse {
/// Outcome message; `Response::success` fills it from `Success::msg`.
pub result: String,
}

/// Failures the web API can report to a client.
pub enum Error {
    /// Reported with the message "Session is busy".
    SessionBusy,
    /// The request refers to a session id the manager has no session for.
    SessionNotFound,
    /// No abort handle is registered for the session that should be canceled.
    CancelFailed,
}

impl Error {
    /// Returns the message text sent to the client for this error.
    pub fn msg(&self) -> String {
        // NOTE(review): the spelling of these messages is runtime text clients may match
        // on, so it is reproduced unchanged here.
        let text = match *self {
            Error::CancelFailed => "Failed to cancel inferrence",
            Error::SessionNotFound => "Session histroy is lost",
            Error::SessionBusy => "Session is busy",
        };
        String::from(text)
    }
}
Expand Down

0 comments on commit 759b883

Please sign in to comment.