From efd4bf56b38d30e6d662de0485be795e7a5a3b6c Mon Sep 17 00:00:00 2001 From: dialogflowchatbot Date: Sat, 30 Nov 2024 21:11:07 +0800 Subject: [PATCH] Added more QnA restful APIs --- src/ai/huggingface.rs | 2 +- src/flow/rt/convertor.rs | 2 +- src/flow/rt/node.rs | 68 ++++++++++++++++-- src/flow/subflow/dto.rs | 2 +- src/intent/dto.rs | 27 +++----- src/intent/phrase.rs | 2 +- src/kb/crud.rs | 35 ++++++++-- src/kb/dto.rs | 17 ++--- src/kb/qa.rs | 146 ++++++++++++++++++++++++++++++--------- src/web/server.rs | 8 ++- 10 files changed, 231 insertions(+), 78 deletions(-) diff --git a/src/ai/huggingface.rs b/src/ai/huggingface.rs index 73002f5..a55128d 100644 --- a/src/ai/huggingface.rs +++ b/src/ai/huggingface.rs @@ -694,7 +694,7 @@ pub(crate) fn check_model_files(info: &HuggingFaceModelInfo) -> Result<()> { .open(&f)?; let mut bytes = Vec::with_capacity(4096); file.read_to_end(&mut bytes)?; - let _ = serde_json::from_slice(&bytes)?; + let _: serde::de::IgnoredAny = serde_json::from_slice(&bytes)?; } else if ext.eq("safetensors") { let metadata = std::fs::metadata(&p)?; if metadata.len() < 62914560u64 { diff --git a/src/flow/rt/convertor.rs b/src/flow/rt/convertor.rs index a03fe64..e5d87c9 100644 --- a/src/flow/rt/convertor.rs +++ b/src/flow/rt/convertor.rs @@ -172,7 +172,7 @@ fn convert_node(main_flow_id: &str, node: &mut Node) -> Result<()> { context_len: n.context_length, cur_run_times: 0, exit_condition: n.exit_condition.clone(), - when_timeout_then: n.when_timeout_then.clone(), + answer_timeout_then: n.when_timeout_then.clone(), streaming: n.response_streaming, connect_timeout: n.connect_timeout, read_timeout: n.read_timeout, diff --git a/src/flow/rt/node.rs b/src/flow/rt/node.rs index 6649147..fb097e1 100644 --- a/src/flow/rt/node.rs +++ b/src/flow/rt/node.rs @@ -373,7 +373,7 @@ pub(crate) enum LlmChatNodeExitCondition { #[derive(Archive, Clone, Deserialize, Serialize, serde::Deserialize)] #[rkyv(compare(PartialEq))] -pub(crate) enum LlmChatNodeWhenTimeoutThen { +pub(crate) enum LlmChatAnswerTimeoutThen { GotoAnotherNode, ResponseAlternateText(String), DoNothing, @@ -386,7 +386,7 @@ pub(crate) struct LlmChatNode { pub(super) context_len: u8, pub(super) cur_run_times: u8, pub(super) exit_condition: LlmChatNodeExitCondition, - pub(super) when_timeout_then: LlmChatNodeWhenTimeoutThen, + pub(super) answer_timeout_then: LlmChatAnswerTimeoutThen, pub(super) streaming: bool, pub(crate) connect_timeout: Option, pub(crate) read_timeout: Option, @@ -474,14 +474,14 @@ impl RuntimeNode for LlmChatNode { )) }) { log::error!("LlmChatNode response failed, err: {:?}", &e); - match &self.when_timeout_then { - LlmChatNodeWhenTimeoutThen::GotoAnotherNode => { + match &self.answer_timeout_then { + LlmChatAnswerTimeoutThen::GotoAnotherNode => { ctx.node = None; add_next_node(ctx, &self.next_node_id); return false; } - LlmChatNodeWhenTimeoutThen::ResponseAlternateText(t) => s.push_str(t), - LlmChatNodeWhenTimeoutThen::DoNothing => return false, + LlmChatAnswerTimeoutThen::ResponseAlternateText(t) => s.push_str(t), + LlmChatAnswerTimeoutThen::DoNothing => return false, } } log::info!("LLM response |{}|", &s); @@ -537,6 +537,62 @@ impl RuntimeNode for LlmChatNode { } } +#[derive(Archive, Clone, Deserialize, Serialize, serde::Deserialize)] +#[rkyv(compare(PartialEq))] +pub(crate) enum KnowledgeBaseAnswerNoRecallThen { + GotoAnotherNode, + ReturnAlternativeAnswerInstead(String), +} + +#[derive(Archive, Clone, Deserialize, Serialize)] +#[rkyv(compare(PartialEq))] +pub(crate) struct KnowledgeBaseAnswerNode { + pub(super) recall_thresholds: f64, + pub(super) no_recall_then: KnowledgeBaseAnswerNoRecallThen, + pub(super) alternative_answer: String, + pub(super) next_node_id: String, +} + +impl RuntimeNode for KnowledgeBaseAnswerNode { + fn exec(&mut self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool { + // log::info!("Into LlmChaKnowledgeBaseAnswerNodetNode"); + let result = tokio::runtime::Handle::current().block_on(crate::kb::qa::retrieve_answer( + &req.robot_id, + &req.user_input, + )); + match result { + Ok((answer, thresholds)) => { + if answer.is_some() && thresholds >= self.recall_thresholds { + response.answers.push(AnswerData { + text: answer.unwrap().answer, + answer_type: AnswerType::TextPlain, + }); + true + } else { + match &self.no_recall_then { + KnowledgeBaseAnswerNoRecallThen::GotoAnotherNode => { + add_next_node(ctx, &self.next_node_id); + false + } + KnowledgeBaseAnswerNoRecallThen::ReturnAlternativeAnswerInstead(s) => { + response.answers.push(AnswerData { + text: s.clone(), + answer_type: AnswerType::TextPlain, + }); + true + } + } + } + } + Err(e) => { + log::error!("KnowledgeBaseAnswerNode answer failed: {:?}", &e); + add_next_node(ctx, &self.next_node_id); + false + } + } + } +} + pub(crate) fn deser_node(bytes: &[u8]) -> Result { let now = std::time::Instant::now(); let mut v = AlignedVec::<256>::with_capacity(bytes.len()); diff --git a/src/flow/subflow/dto.rs b/src/flow/subflow/dto.rs index 053ae5d..4f83c0c 100644 --- a/src/flow/subflow/dto.rs +++ b/src/flow/subflow/dto.rs @@ -349,7 +349,7 @@ pub(crate) struct LlmChatNode { #[serde(rename = "exitCondition")] pub(crate) exit_condition: crate::flow::rt::node::LlmChatNodeExitCondition, #[serde(rename = "whenTimeoutThen")] - pub(crate) when_timeout_then: crate::flow::rt::node::LlmChatNodeWhenTimeoutThen, + pub(crate) when_timeout_then: crate::flow::rt::node::LlmChatAnswerTimeoutThen, #[serde(rename = "responseStreaming")] pub(crate) response_streaming: bool, pub(crate) branches: Vec, diff --git a/src/intent/dto.rs b/src/intent/dto.rs index cf6970c..31178cf 100644 --- a/src/intent/dto.rs +++ b/src/intent/dto.rs @@ -1,9 +1,5 @@ -use std::sync::Mutex; - use serde::{Deserialize, Serialize}; -use crate::result::Result; - #[derive(Deserialize, Debug)] pub(crate) struct IntentFormData { #[serde(rename = "robotId")] @@ -18,15 +14,15 @@ pub(crate) struct IntentsData { pub(crate) intents: Vec, } -static VEC_ROW_ID_LOCKER: Mutex<()> = Mutex::new(()); +// static VEC_ROW_ID_LOCKER: Mutex<()> = Mutex::new(()); -impl IntentsData { - pub(crate) fn inc_get_phrase_vec_id(&mut self) -> Result { - let _l = VEC_ROW_ID_LOCKER.lock()?; - self.phrase_vec_row_id = self.phrase_vec_row_id + 1; - Ok(self.phrase_vec_row_id) - } -} +// impl IntentsData { +// pub(crate) fn inc_get_phrase_vec_id(&mut self) -> Result { +// let _l = VEC_ROW_ID_LOCKER.lock()?; +// self.phrase_vec_row_id = self.phrase_vec_row_id + 1; +// Ok(self.phrase_vec_row_id) +// } +// } #[derive(Clone, Serialize, Deserialize, Debug)] pub(crate) struct Intent { @@ -67,8 +63,6 @@ pub(crate) struct IntentDetail { pub(crate) phrases: Vec, } -static PHRASE_ID_LOCKER: Mutex<()> = Mutex::new(()); - impl IntentDetail { pub(crate) fn new(intent_idx: usize, intent_id: String, intent_name: String) -> Self { IntentDetail { @@ -81,9 +75,4 @@ impl IntentDetail { phrases: vec![], } } - pub(crate) fn inc_get_phrase_vec_id(&mut self) -> Result { - let _l = PHRASE_ID_LOCKER.lock()?; - self.phrase_vec_row_id = self.phrase_vec_row_id + 1; - Ok(self.phrase_vec_row_id) - } } diff --git a/src/intent/phrase.rs b/src/intent/phrase.rs index d43839a..b138380 100644 --- a/src/intent/phrase.rs +++ b/src/intent/phrase.rs @@ -79,7 +79,7 @@ pub(crate) async fn init_tables(robot_id: &str) -> Result<()> { pub(crate) async fn search(robot_id: &str, vectors: &Vec) -> Result> { let sql = format!( - "SELECT intent_id, intent_name, distance FROM {} WHERE vectors MATCH ? ORDER BY distance LIMIT 1", + "SELECT intent_id, intent_name, distance FROM {} WHERE vectors MATCH ? ORDER BY distance ASC LIMIT 1", robot_id ); let results = sqlx::query::(&sql) diff --git a/src/kb/crud.rs b/src/kb/crud.rs index 19a9382..12282d6 100644 --- a/src/kb/crud.rs +++ b/src/kb/crud.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::path::Path; use axum::{ @@ -12,14 +13,17 @@ use crate::result::{Error, Result}; use crate::robot::dto::RobotQuery; use crate::web::server::to_res; -pub(crate) async fn upload(Query(q): Query, multipart: Multipart) -> impl IntoResponse { - if let Err(e) = do_uploading(&q.robot_id, multipart).await { +pub(crate) async fn upload_doc( + Query(q): Query, + multipart: Multipart, +) -> impl IntoResponse { + if let Err(e) = upload_doc_inner(&q.robot_id, multipart).await { return to_res(Err(e)); } to_res(Ok(())) } -async fn do_uploading(robot_id: &str, mut multipart: Multipart) -> Result<()> { +async fn upload_doc_inner(robot_id: &str, mut multipart: Multipart) -> Result<()> { let p = Path::new(".") .join("data") .join(robot_id) @@ -68,11 +72,32 @@ pub(crate) async fn list_qa(Query(q): Query) -> impl IntoResponse { to_res(r) } -pub(crate) async fn add_qa( +pub(crate) async fn save_qa( Query(q): Query, Json(d): Json, ) -> impl IntoResponse { - let r = super::qa::add(&q.robot_id, d).await; + let r = super::qa::save(&q.robot_id, d).await; // let r = sqlite_trans!(super::qa::add, &q.robot_id, d).await; to_res(r) } + +pub(crate) async fn delete_qa( + Query(q): Query, + Json(d): Json, +) -> impl IntoResponse { + let r = super::qa::delete(&q.robot_id, d).await; + to_res(r) +} + +pub(crate) async fn qa_dryrun(Query(q): Query>) -> impl IntoResponse { + let r = q.get("robotId"); + let t = q.get("text"); + if r.is_none() || t.is_none() { + let res = Err(Error::ErrorWithMessage(String::from( + "robotId or text was missing.", + ))); + return to_res(res); + } + let r = super::qa::retrieve_answer(r.unwrap(), t.unwrap()).await; + to_res(r) +} diff --git a/src/kb/dto.rs b/src/kb/dto.rs index 66e98eb..b4a8ee4 100644 --- a/src/kb/dto.rs +++ b/src/kb/dto.rs @@ -2,19 +2,20 @@ use std::vec::Vec; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Serialize)] -pub(crate) struct QuestionAnswerData { - pub(super) id: String, - #[serde(rename = "qaData")] - pub(super) qa_data: QuestionAnswerPair, -} +// #[derive(Deserialize, Serialize)] +// pub(crate) struct QuestionAnswerData { +// pub(super) id: Option, +// #[serde(rename = "qaData")] +// pub(super) qa_data: QuestionAnswerPair, +// } #[derive(Deserialize, Serialize)] pub(crate) struct QuestionAnswerPair { + pub(super) id: Option, pub(super) question: QuestionData, #[serde(rename = "similarQuestions")] - pub(super) similar_questions: Option>, - pub(super) answer: String, + pub(super) similar_questions: Vec, + pub(crate) answer: String, } #[derive(Deserialize, Serialize)] diff --git a/src/kb/qa.rs b/src/kb/qa.rs index e757698..6021b6d 100644 --- a/src/kb/qa.rs +++ b/src/kb/qa.rs @@ -1,19 +1,10 @@ -use core::time::Duration; - -use std::fs::OpenOptions; -use std::path::Path; use std::sync::OnceLock; use std::vec::Vec; -use axum::{ - extract::{Multipart, Query}, - response::IntoResponse, - Json, -}; use futures_util::StreamExt; -use sqlx::{pool::PoolOptions, Row, Sqlite}; +use sqlx::{Row, Sqlite}; -use super::dto::{QuestionAnswerData, QuestionAnswerPair, QuestionData}; +use super::dto::{QuestionAnswerPair, QuestionData}; use crate::ai::embedding::embedding; use crate::result::{Error, Result}; @@ -96,28 +87,25 @@ pub(crate) async fn init_tables(robot_id: &str) -> Result<()> { // } // ); -pub(crate) async fn list(robot_id: &str) -> Result> { +pub(crate) async fn list(robot_id: &str) -> Result> { let sql = format!( - "SELECT id, qa_data FROM {}_qa ORDER BY created_at DESC", + "SELECT qa_data FROM {}_qa ORDER BY created_at DESC", robot_id ); let results = sqlx::query::(&sql) .fetch_all(DATA_SOURCE.get().unwrap()) .await?; - let mut d: Vec = Vec::with_capacity(results.len()); + let mut d: Vec = Vec::with_capacity(results.len()); for r in results.iter() { - d.push(QuestionAnswerData { - id: r.try_get(0)?, - qa_data: serde_json::from_str(dbg!(r.try_get(1)?))?, - }); + d.push(serde_json::from_str(dbg!(r.try_get(0)?))?); } Ok(d) } -pub(crate) async fn add(robot_id: &str, d: QuestionAnswerPair) -> Result { +pub(crate) async fn save(robot_id: &str, d: QuestionAnswerPair) -> Result { let ds = DATA_SOURCE.get().unwrap(); let mut transaction = ds.begin().await?; - let r = add_inner(robot_id, d, &mut transaction).await; + let r = save_inner(robot_id, d, &mut transaction).await; if r.is_ok() { transaction.commit().await?; } else { @@ -126,19 +114,23 @@ pub(crate) async fn add(robot_id: &str, d: QuestionAnswerPair) -> Result r } -async fn add_inner( +async fn save_inner( robot_id: &str, mut d: QuestionAnswerPair, transaction: &mut sqlx::Transaction<'_, sqlx::Sqlite>, ) -> Result { let mut questions: Vec<&mut QuestionData> = Vec::with_capacity(5); questions.push(&mut d.question); - if d.similar_questions.is_some() { - let similar_questions: Vec<&mut QuestionData> = - d.similar_questions.as_mut().unwrap().iter_mut().collect(); + if !d.similar_questions.is_empty() { + let similar_questions: Vec<&mut QuestionData> = d.similar_questions.iter_mut().collect(); questions.extend(similar_questions); } - let qa_id = scru128::new_string(); + let mut new_record = false; + if d.id.is_none() { + d.id = Some(scru128::new_string()); + new_record = true; + } + let mut vec_row_ids: Vec = Vec::with_capacity(questions.len()); for q in questions.iter_mut() { let vectors = embedding(robot_id, &q.question).await?; if vectors.0.is_empty() { @@ -159,7 +151,7 @@ async fn add_inner( .last_insert_rowid(); let sql = format!( "CREATE VIRTUAL TABLE IF NOT EXISTS {} USING vec0 ( - +qa_id TEXT NOT NULL, + qa_id TEXT NOT NULL, vectors float[{}] ); INSERT INTO {} (rowid, qa_id, vectors)VALUES(?, ?, ?)", @@ -170,29 +162,115 @@ async fn add_inner( ); sqlx::query::(&sql) .bind(last_insert_rowid) - .bind(&qa_id) + .bind(d.id.as_ref().unwrap()) .bind(serde_json::to_string(&vectors.0)?) .execute(&mut **transaction) .await?; q.vec_row_id = Some(last_insert_rowid); } else { - let sql = format!("UPDATE {} SET vectors = ? WHERE = ?", robot_id); + let sql = format!("UPDATE {} SET vectors = ? WHERE rowid = ?", robot_id); let vec_row_id = q.vec_row_id.unwrap(); sqlx::query::(&sql) .bind(serde_json::to_string(&vectors.0)?) .bind(vec_row_id) .execute(&mut **transaction) .await?; + vec_row_ids.push(vec_row_id); }; } + if !vec_row_ids.is_empty() { + let params = format!("?{}", ", ?".repeat(vec_row_ids.len() - 1)); + let sql = format!("DELETE FROM {} WHERE rowid NOT IN ({})", robot_id, params); + let mut query = sqlx::query(&sql); + for i in vec_row_ids { + query = query.bind(i); + } + query.fetch_all(&mut **transaction).await?; + } + if new_record { + let sql = format!( + "INSERT INTO {}_qa(id, qa_data, created_at)VALUES(?, ?, unixepoch())", + robot_id + ); + sqlx::query::(&sql) + .bind(d.id.as_ref().unwrap()) + .bind(dbg!(serde_json::to_string(&d)?)) + .execute(&mut **transaction) + .await?; + } else { + let sql = format!("UPDATE {}_qa SET qa_data = ? WHERE id = ?", robot_id); + sqlx::query::(&sql) + .bind(dbg!(serde_json::to_string(&d)?)) + .bind(d.id.as_ref().unwrap()) + .execute(&mut **transaction) + .await?; + } + Ok(d.id.unwrap()) +} + +pub(crate) async fn delete(robot_id: &str, d: QuestionAnswerPair) -> Result<()> { + let ds = DATA_SOURCE.get().unwrap(); + let mut transaction = ds.begin().await?; + let r = delete_inner(robot_id, d, &mut transaction).await; + if r.is_ok() { + transaction.commit().await?; + } else { + transaction.rollback().await?; + } + r +} + +async fn delete_inner( + robot_id: &str, + d: QuestionAnswerPair, + transaction: &mut sqlx::Transaction<'_, sqlx::Sqlite>, +) -> Result<()> { + //todo sqlx prepare statement let sql = format!( - "INSERT INTO {}_qa(id, qa_data, created_at)VALUES(?, ?, unixepoch())", - robot_id + " + DELETE FROM {} WHERE qa_id = ?; + DELETE FROM {}_qa WHERE id = ?; + ", + robot_id, robot_id ); - sqlx::query::(&sql) - .bind(&qa_id) - .bind(dbg!(serde_json::to_string(&d)?)) + let qa_id = d.id.as_ref().unwrap(); + let r = sqlx::query(&sql) + .bind(qa_id) + .bind(qa_id) .execute(&mut **transaction) .await?; - Ok(qa_id) + log::info!("{}", r.rows_affected()); + Ok(()) +} + +pub(crate) async fn retrieve_answer( + robot_id: &str, + question: &str, +) -> Result<(Option, f64)> { + let vectors = embedding(robot_id, question).await?; + if vectors.0.is_empty() { + let err = format!("{} embedding data is empty", question); + log::warn!("{}", &err); + return Err(Error::ErrorWithMessage(err)); + } + + let sql = format!( + " + SELECT qa_data, v.distance FROM {}_qa q INNER JOIN + (SELECT qa_id, distance FROM {} WHERE vectors MATCH ? ORDER BY distance ASC LIMIT 1) v + ON q.id = v.qa_id + ", + robot_id, robot_id + ); + let results = sqlx::query::(&sql) + .bind(serde_json::to_string(&vectors.0)?) + .fetch_all(DATA_SOURCE.get().unwrap()) + .await?; + if results.len() > 0 { + return Ok(( + Some(serde_json::from_str(results[0].try_get(0)?)?), + results[0].try_get(1)?, + )); + } + Ok((None, 1.0)) } diff --git a/src/web/server.rs b/src/web/server.rs index 735988f..da6e2a4 100644 --- a/src/web/server.rs +++ b/src/web/server.rs @@ -247,8 +247,12 @@ fn gen_router() -> Router { "/management/settings/model/check/embedding", get(settings::check_embedding_model), ) - .route("/kb/qa", get(kb::list_qa).post(kb::add_qa)) - .route("/kb/doc/upload", post(kb::upload)) + .route( + "/kb/qa", + get(kb::list_qa).post(kb::save_qa).delete(kb::delete_qa), + ) + .route("/kb/qa/dryrun", get(kb::qa_dryrun)) + .route("/kb/doc/upload", post(kb::upload_doc)) .route("/management/settings/smtp/test", post(settings::smtp_test)) .route("/flow/answer", post(rt::answer)) .route("/flow/answer/sse", post(rt::answer_sse))