diff --git a/Cargo.toml b/Cargo.toml index 4623413..2b50738 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ name = "dialogflow" [dependencies] # artful = "0.1.1" anyhow = "1.0" -axum = {version = "0.7", features = ["query", "tokio", "macros"]} +axum = {version = "0.7", features = ["query", "tokio", "macros", "multipart"]} bigdecimal = "0.4" # candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.1" } # candle = { version = "0.6", package = "candle-core", default-features = false } @@ -57,7 +57,7 @@ scraper = "0.20" # strsim = "0.10.0" # textdistance = "1.0.2" time = { version = "0.3", features = ["formatting"] } -tower-http = { version = "0.5", features = ["cors"] } +tower-http = { version = "0.5", features = ["cors", "limit"] } # typetag = "0.2" tokio = { version = "1", features = ["fs", "io-util", "macros", "net", "rt", "rt-multi-thread", "signal", "time"] } tokio-stream = "0.1" diff --git a/src/flow/rt/context.rs b/src/flow/rt/context.rs index 17f99ff..a35ba1e 100644 --- a/src/flow/rt/context.rs +++ b/src/flow/rt/context.rs @@ -108,6 +108,7 @@ impl Context { pub(in crate::flow::rt) fn pop_node(&mut self) -> Option { // log::info!("nodes len {}", self.nodes.len()); + let now = std::time::Instant::now(); if self.node.is_some() { let node = std::mem::replace(&mut self.node, None); let v = node.unwrap(); @@ -121,6 +122,7 @@ impl Context { if let Some(node_id) = self.nodes.pop_front() { // log::info!("main_flow_id {} node_id {}", &self.main_flow_id, &node_id); if let Ok(r) = super::crud::get_runtime_node(&self.main_flow_id, &node_id) { + log::info!("pop_node time {:?}", now.elapsed()); return r; } } diff --git a/src/flow/rt/executor.rs b/src/flow/rt/executor.rs index 4503508..6eb8a43 100644 --- a/src/flow/rt/executor.rs +++ b/src/flow/rt/executor.rs @@ -6,20 +6,20 @@ use crate::intent::detector; use crate::result::{Error, Result}; pub(in crate::flow::rt) async fn process(req: &mut Request) -> Result { - // let now = 
std::time::Instant::now(); + let now = std::time::Instant::now(); if req.session_id.is_empty() { req.session_id = scru128::new_string(); } let mut ctx = Context::get(&req.robot_id, &req.session_id); - // println!("get ctx {:?}", now.elapsed()); - // let now = std::time::Instant::now(); + log::info!("get ctx {:?}", now.elapsed()); + let now = std::time::Instant::now(); if ctx.no_node() { if ctx.main_flow_id.is_empty() { ctx.main_flow_id.push_str(&req.main_flow_id); } ctx.add_node(&req.main_flow_id); } - // println!("add_node {:?}", now.elapsed()); + log::info!("add_node {:?}", now.elapsed()); let now = std::time::Instant::now(); if req.user_input_intent.is_none() { req.user_input_intent = detector::detect(&req.robot_id, &req.user_input).await?; @@ -52,13 +52,14 @@ pub(in crate::flow::rt) async fn process(req: &mut Request) -> Result } } // println!("exec {:?}", now.elapsed()); - // let now = std::time::Instant::now(); + let now = std::time::Instant::now(); ctx.save()?; - // println!("ctx save {:?}", now.elapsed()); + log::info!("ctx save {:?}", now.elapsed()); r } pub(in crate::flow::rt) fn exec(req: &Request, ctx: &mut Context) -> Result { + let now = std::time::Instant::now(); let mut response = Response::new(req); for _i in 0..100 { // let now = std::time::Instant::now(); @@ -67,6 +68,7 @@ pub(in crate::flow::rt) fn exec(req: &Request, ctx: &mut Context) -> Result Result { + let now = std::time::Instant::now(); let mut v = AlignedVec::<256>::with_capacity(bytes.len()); v.extend_from_slice(bytes); let r = rkyv::from_bytes::(&v).unwrap(); // let archived = rkyv::access::(bytes).unwrap(); // let deserialized = rkyv::deserialize::(archived).unwrap(); + log::info!("deser_node time {:?}", now.elapsed()); return Ok(r); } diff --git a/src/intent/detector.rs b/src/intent/detector.rs index 6093a87..9bd73d6 100644 --- a/src/intent/detector.rs +++ b/src/intent/detector.rs @@ -41,6 +41,9 @@ pub(crate) async fn detect(robot_id: &str, s: &str) -> Result> { } } let embedding = 
embedding(robot_id, s).await?;
+    if embedding.0.is_empty() {
+        return Ok(None);
+    }
     // log::info!("Generate embedding cost {:?}", now.elapsed());
     // let s = format!("{:?}", &embedding);
     // let regex = regex::Regex::new(r"\s").unwrap();
diff --git a/src/kb/crud.rs b/src/kb/crud.rs
new file mode 100644
index 0000000..cfc5b04
--- /dev/null
+++ b/src/kb/crud.rs
@@ -0,0 +1,45 @@
+use axum::{extract::Multipart, response::IntoResponse};
+
+use crate::result::Error;
+use crate::web::server::to_res;
+
+/// Receives a `multipart/form-data` upload and reports the size of each part.
+///
+/// Every field of the request is read in full and its name, file name,
+/// content type and byte length are printed to stdout. A field that omits any
+/// of that metadata is reported with an empty string instead of panicking,
+/// because the request body is controlled by the client.
+///
+/// Returns the success message once all fields have been consumed, or an
+/// `Error::ErrorWithMessage` response as soon as a field cannot be parsed or
+/// its body cannot be read.
+pub(crate) async fn upload(mut multipart: Multipart) -> impl IntoResponse {
+    loop {
+        let field = match multipart.next_field().await {
+            Ok(Some(field)) => field,
+            // No more fields: the whole request was consumed.
+            Ok(None) => return to_res(Ok("Upload successfully.")),
+            Err(e) => {
+                let m = format!("Upload failed, err: {:?}.", e);
+                return to_res(Err(Error::ErrorWithMessage(m)));
+            }
+        };
+        // Part metadata is optional; fall back to "" rather than unwrapping.
+        let name = field.name().unwrap_or_default().to_string();
+        let file_name = field.file_name().unwrap_or_default().to_string();
+        let content_type = field.content_type().unwrap_or_default().to_string();
+        // Reading the body can fail; report it the same way as a parse failure.
+        let data = match field.bytes().await {
+            Ok(data) => data,
+            Err(e) => {
+                let m = format!("Upload failed, err: {:?}.", e);
+                return to_res(Err(Error::ErrorWithMessage(m)));
+            }
+        };
+
+        println!(
+            "Length of `{name}` (`{file_name}`: `{content_type}`) is {} bytes",
+            data.len()
+        );
+    }
+}
\ No newline at end of file
diff --git a/src/kb/mod.rs b/src/kb/mod.rs
new file mode 100644
index 0000000..9706a3c
--- /dev/null
+++ b/src/kb/mod.rs
@@ -0,0 +1 @@
+pub(crate) mod crud;
\ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index d2d6829..996ab38 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,6 +6,7 @@ pub(crate) mod db;
 pub(crate) mod external;
 pub(crate) mod flow;
 pub(crate) mod intent;
+pub(crate) mod kb;
 pub(crate) mod man;
 pub(crate) mod result;
 pub(crate) mod robot;
diff --git a/src/man/settings.rs b/src/man/settings.rs
index 4706941..ac0f040 100644
--- a/src/man/settings.rs
+++ b/src/man/settings.rs
@@ -45,6 +45,7 @@ pub(crate) struct GlobalSettings {
 #[derive(Clone, Deserialize, Serialize)]
 pub(crate) struct Settings {
+    settings_version: u8,
#[serde(rename = "maxSessionIdleSec")] pub(crate) max_session_idle_sec: u32, #[serde(rename = "chatProvider")] @@ -206,6 +207,7 @@ impl Default for GlobalSettings { impl Default for Settings { fn default() -> Self { Settings { + settings_version: 1u8, max_session_idle_sec: 1800, chat_provider: ChatProvider { provider: chat::ChatProvider::HuggingFace( diff --git a/src/web/asset.rs b/src/web/asset.rs index 09e40fe..222d6d2 100644 --- a/src/web/asset.rs +++ b/src/web/asset.rs @@ -3,22 +3,18 @@ use std::collections::HashMap; use std::sync::LazyLock; pub(crate) static ASSETS_MAP: LazyLock> = LazyLock::new(|| { - HashMap::from([ - (r"/assets/inbound-bot-PJJg_rST.png", 0), - (r"/assets/index-B4PwGmOZ.css", 1), - (r"/assets/index-DaTxxXw1.js", 2), - (r"/assets/outbound-bot-EmsLuWRN.png", 3), - (r"/assets/text-bot-CWb_Poym.png", 4), - (r"/assets/usedByDialogNodeTextGeneration-DrFqkTqi.png", 5), - ( - r"/assets/usedByDialogNodeTextGeneration-thumbnail-C1iQCVQO.png", - 6, - ), - (r"/assets/usedByLlmChatNode-Bv2Fg5P7.png", 7), - (r"/assets/usedBySentenceEmbedding-Dmju1hVB.png", 8), - (r"/assets/usedBySentenceEmbedding-thumbnail-DVXz_sh0.png", 9), - (r"/favicon.ico", 10), - ("/", 11), - (r"/index.html", 11), - ]) -}); +HashMap::from([ +(r"/assets/inbound-bot-PJJg_rST.png", 0), +(r"/assets/index-B4PwGmOZ.css", 1), +(r"/assets/index-DaTxxXw1.js", 2), +(r"/assets/outbound-bot-EmsLuWRN.png", 3), +(r"/assets/text-bot-CWb_Poym.png", 4), +(r"/assets/usedByDialogNodeTextGeneration-DrFqkTqi.png", 5), +(r"/assets/usedByDialogNodeTextGeneration-thumbnail-C1iQCVQO.png", 6), +(r"/assets/usedByLlmChatNode-Bv2Fg5P7.png", 7), +(r"/assets/usedBySentenceEmbedding-Dmju1hVB.png", 8), +(r"/assets/usedBySentenceEmbedding-thumbnail-DVXz_sh0.png", 9), +(r"/favicon.ico", 10), +("/", 11), +(r"/index.html", 11), +])}); diff --git a/src/web/server.rs b/src/web/server.rs index 102fd97..0786c0b 100644 --- a/src/web/server.rs +++ b/src/web/server.rs @@ -2,6 +2,7 @@ use std::sync::LazyLock; use std::vec::Vec; 
+use axum::extract::DefaultBodyLimit; use axum::http::{header, HeaderMap, HeaderValue, Method, StatusCode, Uri}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; @@ -9,6 +10,7 @@ use axum::Router; use colored::Colorize; use serde::{Deserialize, Serialize}; use tower_http::cors::{AllowOrigin, CorsLayer}; +use tower_http::limit::RequestBodyLimitLayer; use super::asset::ASSETS_MAP; use crate::ai::crud as ai; @@ -17,6 +19,7 @@ use crate::flow::mainflow::crud as mainflow; use crate::flow::rt::facade as rt; use crate::flow::subflow::crud as subflow; use crate::intent::crud as intent; +use crate::kb::crud as kb; use crate::man::settings; use crate::result::Error; use crate::robot::crud as robot; @@ -230,6 +233,10 @@ fn gen_router() -> Router { "/management/settings/model/check/embedding", get(settings::check_embedding_model), ) + .route( + "/kb/doc/upload", + post(kb::upload), + ) .route("/management/settings/smtp/test", post(settings::smtp_test)) .route("/flow/answer", post(rt::answer)) .route("/flow/answer/sse", post(rt::answer_sse)) @@ -237,6 +244,10 @@ fn gen_router() -> Router { .route("/version.json", get(version)) .route("/check-new-version.json", get(check_new_version)) // .route("/o", get(subflow::output)) + .layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new( + 250 * 1024 * 1024, /* 250mb */ + )) .layer( CorsLayer::new() .allow_origin(AllowOrigin::predicate(