Skip to content

Commit

Permalink
style(web-api): cleanup and update
Browse files: browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 16, 2024
1 parent 759b883 commit 68b1ae5
Show file tree
Hide file tree
Showing 9 changed files with 737 additions and 79 deletions.
691 changes: 691 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion web-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ transformer = { path = "../transformer" }
service = { path = "../service" }
serde = { workspace = true, features = ["derive"] }
tokio.workspace = true
log.workspace = true
futures = "0.3"
actix-web = "4.5"
log.workspace = true
4 changes: 1 addition & 3 deletions web-api/src/handlers/cancel_handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use crate::response::Response;
use crate::schemas;
use crate::AppState;
use crate::{response::Response, schemas, AppState};
use actix_web::{post, web, HttpResponse};

#[post("/cancel")]
Expand Down
50 changes: 19 additions & 31 deletions web-api/src/handlers/infer_handler.rs
Original file line number Diff line number Diff line change
@@ -1,53 +1,41 @@
use crate::response::Response;
use crate::schemas;
use crate::AppState;
use crate::{response::Response, schemas, AppState};
use actix_web::{post, web, Error, HttpResponse};
use futures::channel::mpsc;
use futures::stream::{Stream, StreamExt};
use service::Session;
use futures::{
channel::mpsc,
stream::{Stream, StreamExt},
};

#[post("/infer")]
pub async fn infer(
app_state: web::Data<AppState>,
request: web::Json<schemas::InferRequest>,
) -> HttpResponse {
info!("Request from {}: infer", request.session_id);

match app_state.service_manager.get_session(&request) {
Ok(session) => {
let infer_stream = create_infer_stream(session, request, app_state);
Response::text_stream(infer_stream)
}
Ok(session) => Response::text_stream(create_infer_stream(session, request, app_state)),
Err(e) => Response::error(e),
}
}

fn create_infer_stream(
mut session: Session,
mut session: service::Session,
request: web::Json<schemas::InferRequest>,
app_state: web::Data<AppState>,
) -> impl Stream<Item = Result<web::Bytes, Error>> {
let (sender, receiver) = mpsc::channel(4096);
let (mut sender, receiver) = mpsc::channel(4096);

tokio::spawn(async move {
let id = request.session_id.clone();
session_async_infer(&mut session, request, sender).await;
app_state.service_manager.reset_session(&id, session);
session
.chat(&request.inputs, |s| {
sender
.try_send(s.to_string())
.expect("Failed to write data into output channel.")
})
.await;
app_state
.service_manager
.reset_session(&request.session_id, session);
});

receiver.map(|word| Ok(web::Bytes::from(word)))
}

async fn session_async_infer(
session: &mut Session,
request: web::Json<schemas::InferRequest>,
mut sender: mpsc::Sender<String>,
) {
session
.chat(&request.inputs, |s| {
sender
.try_send(s.to_string())
.expect("Failed to write data into output channel.")
})
.await
receiver.map(|word| Ok(word.into()))
}
1 change: 1 addition & 0 deletions web-api/src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod cancel_handler;
mod infer_handler;

pub use cancel_handler::cancel;
pub use infer_handler::infer;
13 changes: 6 additions & 7 deletions web-api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
use actix_web::{web, App, HttpServer};
use service::Service;
use std::{fmt::Debug, net::ToSocketAddrs, sync::Arc};

mod handlers;
mod manager;
mod response;
mod schemas;

use actix_web::{web, App, HttpServer};
use manager::ServiceManager;
use std::{fmt::Debug, net::ToSocketAddrs, sync::Arc};

#[macro_use]
extern crate log;

/// All global variables and services shared among all endpoints in this App
pub struct AppState {
struct AppState {
/// Manager of this App, which provides services such as inference, session management, etc.
pub service_manager: Arc<ServiceManager>,
service_manager: Arc<ServiceManager>,
}

pub async fn start_infer_service(
service: Service,
service: service::Service,
addrs: impl ToSocketAddrs + Debug,
) -> std::io::Result<()> {
println!("Starting service at {addrs:?}");
Expand Down
20 changes: 0 additions & 20 deletions web-api/src/main.rs

This file was deleted.

33 changes: 17 additions & 16 deletions web-api/src/manager.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use crate::schemas;
use actix_web::web;
use service::{Service, Session, SessionHandle};
use std::collections::{hash_map::Entry, HashMap};
use std::sync::{Arc, Mutex};
use std::{
collections::{hash_map::Entry, HashMap},
sync::{Arc, Mutex},
};

pub struct ServiceManager {
/// All sessions, session id as key.
/// A new session is created when a new id arrives.
Expand All @@ -17,6 +20,7 @@ pub struct ServiceManager {
/// Inference service provided by backend model
infer_service: Arc<Service>,
}

impl ServiceManager {
pub fn from(infer_service: Arc<Service>) -> Self {
ServiceManager {
Expand All @@ -42,14 +46,12 @@ impl ServiceManager {
{
// Case session id exists
Entry::Occupied(mut e) => match e.get_mut().take() {
// If session is being served
None => Err(schemas::Error::SessionBusy),
// If session available, check request
Some(s) => {
if request.first_request {
// Session id exists but user thinks otherwise, overwrite current session
Ok(
e.insert(Some(self.create_session(request.session_id.to_string())))
e.insert(Some(self.create_session(request.session_id.clone())))
.take()
.unwrap(),
)
Expand All @@ -58,13 +60,15 @@ impl ServiceManager {
Ok(s)
}
}
// If session is being served
None => Err(schemas::Error::SessionBusy),
},
// Case new session id
Entry::Vacant(e) => {
if request.first_request {
// First request, create new session
Ok(
e.insert(Some(self.create_session(request.session_id.to_string())))
e.insert(Some(self.create_session(request.session_id.clone())))
.take()
.unwrap(),
)
Expand Down Expand Up @@ -101,17 +105,14 @@ impl ServiceManager {
&self,
request: &web::Json<schemas::CancelRequest>,
) -> Result<schemas::Success, schemas::Error> {
match self
.session_handles
self.session_handles
.lock()
.unwrap()
.entry(request.session_id.to_string())
{
Entry::Occupied(handle) => {
handle.get().abort();
Ok(schemas::Success::SessionCanceled)
}
Entry::Vacant(_) => Err(schemas::Error::CancelFailed),
}
.get(&request.session_id)
.ok_or(schemas::Error::CancelFailed)
.map(|handle| {
handle.abort();
schemas::Success::SessionCanceled
})
}
}
2 changes: 1 addition & 1 deletion xtask/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn block_on(f: impl Future) {
runtime.shutdown_background();
#[cfg(feature = "nvidia")]
{
service::synchronize();
::service::synchronize();
}
}

Expand Down

0 comments on commit 68b1ae5

Please sign in to comment.