Skip to content

Commit

Permalink
Added more QnA restful APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
dialogflowchatbot committed Nov 30, 2024
1 parent c9577ba commit efd4bf5
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 78 deletions.
2 changes: 1 addition & 1 deletion src/ai/huggingface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ pub(crate) fn check_model_files(info: &HuggingFaceModelInfo) -> Result<()> {
.open(&f)?;
let mut bytes = Vec::with_capacity(4096);
file.read_to_end(&mut bytes)?;
let _ = serde_json::from_slice(&bytes)?;
let _: serde::de::IgnoredAny = serde_json::from_slice(&bytes)?;
} else if ext.eq("safetensors") {
let metadata = std::fs::metadata(&p)?;
if metadata.len() < 62914560u64 {
Expand Down
2 changes: 1 addition & 1 deletion src/flow/rt/convertor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ fn convert_node(main_flow_id: &str, node: &mut Node) -> Result<()> {
context_len: n.context_length,
cur_run_times: 0,
exit_condition: n.exit_condition.clone(),
when_timeout_then: n.when_timeout_then.clone(),
answer_timeout_then: n.when_timeout_then.clone(),
streaming: n.response_streaming,
connect_timeout: n.connect_timeout,
read_timeout: n.read_timeout,
Expand Down
68 changes: 62 additions & 6 deletions src/flow/rt/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ pub(crate) enum LlmChatNodeExitCondition {

#[derive(Archive, Clone, Deserialize, Serialize, serde::Deserialize)]
#[rkyv(compare(PartialEq))]
pub(crate) enum LlmChatNodeWhenTimeoutThen {
pub(crate) enum LlmChatAnswerTimeoutThen {
GotoAnotherNode,
ResponseAlternateText(String),
DoNothing,
Expand All @@ -386,7 +386,7 @@ pub(crate) struct LlmChatNode {
pub(super) context_len: u8,
pub(super) cur_run_times: u8,
pub(super) exit_condition: LlmChatNodeExitCondition,
pub(super) when_timeout_then: LlmChatNodeWhenTimeoutThen,
pub(super) answer_timeout_then: LlmChatAnswerTimeoutThen,
pub(super) streaming: bool,
pub(crate) connect_timeout: Option<u32>,
pub(crate) read_timeout: Option<u32>,
Expand Down Expand Up @@ -474,14 +474,14 @@ impl RuntimeNode for LlmChatNode {
))
}) {
log::error!("LlmChatNode response failed, err: {:?}", &e);
match &self.when_timeout_then {
LlmChatNodeWhenTimeoutThen::GotoAnotherNode => {
match &self.answer_timeout_then {
LlmChatAnswerTimeoutThen::GotoAnotherNode => {
ctx.node = None;
add_next_node(ctx, &self.next_node_id);
return false;
}
LlmChatNodeWhenTimeoutThen::ResponseAlternateText(t) => s.push_str(t),
LlmChatNodeWhenTimeoutThen::DoNothing => return false,
LlmChatAnswerTimeoutThen::ResponseAlternateText(t) => s.push_str(t),
LlmChatAnswerTimeoutThen::DoNothing => return false,
}
}
log::info!("LLM response |{}|", &s);
Expand Down Expand Up @@ -537,6 +537,62 @@ impl RuntimeNode for LlmChatNode {
}
}

#[derive(Archive, Clone, Deserialize, Serialize, serde::Deserialize)]
#[rkyv(compare(PartialEq))]
pub(crate) enum KnowledgeBaseAnswerNoRecallThen {
GotoAnotherNode,
ReturnAlternativeAnswerInstead(String),
}

#[derive(Archive, Clone, Deserialize, Serialize)]
#[rkyv(compare(PartialEq))]
pub(crate) struct KnowledgeBaseAnswerNode {
pub(super) recall_thresholds: f64,
pub(super) no_recall_then: KnowledgeBaseAnswerNoRecallThen,
pub(super) alternative_answer: String,
pub(super) next_node_id: String,
}

impl RuntimeNode for KnowledgeBaseAnswerNode {
fn exec(&mut self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool {
// log::info!("Into LlmChaKnowledgeBaseAnswerNodetNode");
let result = tokio::runtime::Handle::current().block_on(crate::kb::qa::retrieve_answer(
&req.robot_id,
&req.user_input,
));
match result {
Ok((answer, thresholds)) => {
if answer.is_some() && thresholds >= self.recall_thresholds {
response.answers.push(AnswerData {
text: answer.unwrap().answer,
answer_type: AnswerType::TextPlain,
});
true
} else {
match &self.no_recall_then {
KnowledgeBaseAnswerNoRecallThen::GotoAnotherNode => {
add_next_node(ctx, &self.next_node_id);
false
}
KnowledgeBaseAnswerNoRecallThen::ReturnAlternativeAnswerInstead(s) => {
response.answers.push(AnswerData {
text: s.clone(),
answer_type: AnswerType::TextPlain,
});
true
}
}
}
}
Err(e) => {
log::error!("KnowledgeBaseAnswerNode answer failed: {:?}", &e);
add_next_node(ctx, &self.next_node_id);
false
}
}
}
}

pub(crate) fn deser_node(bytes: &[u8]) -> Result<RuntimeNnodeEnum> {
let now = std::time::Instant::now();
let mut v = AlignedVec::<256>::with_capacity(bytes.len());
Expand Down
2 changes: 1 addition & 1 deletion src/flow/subflow/dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ pub(crate) struct LlmChatNode {
#[serde(rename = "exitCondition")]
pub(crate) exit_condition: crate::flow::rt::node::LlmChatNodeExitCondition,
#[serde(rename = "whenTimeoutThen")]
pub(crate) when_timeout_then: crate::flow::rt::node::LlmChatNodeWhenTimeoutThen,
pub(crate) when_timeout_then: crate::flow::rt::node::LlmChatAnswerTimeoutThen,
#[serde(rename = "responseStreaming")]
pub(crate) response_streaming: bool,
pub(crate) branches: Vec<Branch>,
Expand Down
27 changes: 8 additions & 19 deletions src/intent/dto.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use std::sync::Mutex;

use serde::{Deserialize, Serialize};

use crate::result::Result;

#[derive(Deserialize, Debug)]
pub(crate) struct IntentFormData {
#[serde(rename = "robotId")]
Expand All @@ -18,15 +14,15 @@ pub(crate) struct IntentsData {
pub(crate) intents: Vec<Intent>,
}

static VEC_ROW_ID_LOCKER: Mutex<()> = Mutex::new(());
// static VEC_ROW_ID_LOCKER: Mutex<()> = Mutex::new(());

impl IntentsData {
pub(crate) fn inc_get_phrase_vec_id(&mut self) -> Result<i64> {
let _l = VEC_ROW_ID_LOCKER.lock()?;
self.phrase_vec_row_id = self.phrase_vec_row_id + 1;
Ok(self.phrase_vec_row_id)
}
}
// impl IntentsData {
// pub(crate) fn inc_get_phrase_vec_id(&mut self) -> Result<i64> {
// let _l = VEC_ROW_ID_LOCKER.lock()?;
// self.phrase_vec_row_id = self.phrase_vec_row_id + 1;
// Ok(self.phrase_vec_row_id)
// }
// }

#[derive(Clone, Serialize, Deserialize, Debug)]
pub(crate) struct Intent {
Expand Down Expand Up @@ -67,8 +63,6 @@ pub(crate) struct IntentDetail {
pub(crate) phrases: Vec<IntentPhraseData>,
}

static PHRASE_ID_LOCKER: Mutex<()> = Mutex::new(());

impl IntentDetail {
pub(crate) fn new(intent_idx: usize, intent_id: String, intent_name: String) -> Self {
IntentDetail {
Expand All @@ -81,9 +75,4 @@ impl IntentDetail {
phrases: vec![],
}
}
pub(crate) fn inc_get_phrase_vec_id(&mut self) -> Result<i64> {
let _l = PHRASE_ID_LOCKER.lock()?;
self.phrase_vec_row_id = self.phrase_vec_row_id + 1;
Ok(self.phrase_vec_row_id)
}
}
2 changes: 1 addition & 1 deletion src/intent/phrase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub(crate) async fn init_tables(robot_id: &str) -> Result<()> {

pub(crate) async fn search(robot_id: &str, vectors: &Vec<f32>) -> Result<Vec<(String, f64)>> {
let sql = format!(
"SELECT intent_id, intent_name, distance FROM {} WHERE vectors MATCH ? ORDER BY distance LIMIT 1",
"SELECT intent_id, intent_name, distance FROM {} WHERE vectors MATCH ? ORDER BY distance ASC LIMIT 1",
robot_id
);
let results = sqlx::query::<Sqlite>(&sql)
Expand Down
35 changes: 30 additions & 5 deletions src/kb/crud.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::path::Path;

use axum::{
Expand All @@ -12,14 +13,17 @@ use crate::result::{Error, Result};
use crate::robot::dto::RobotQuery;
use crate::web::server::to_res;

pub(crate) async fn upload(Query(q): Query<RobotQuery>, multipart: Multipart) -> impl IntoResponse {
if let Err(e) = do_uploading(&q.robot_id, multipart).await {
pub(crate) async fn upload_doc(
Query(q): Query<RobotQuery>,
multipart: Multipart,
) -> impl IntoResponse {
if let Err(e) = upload_doc_inner(&q.robot_id, multipart).await {
return to_res(Err(e));
}
to_res(Ok(()))
}

async fn do_uploading(robot_id: &str, mut multipart: Multipart) -> Result<()> {
async fn upload_doc_inner(robot_id: &str, mut multipart: Multipart) -> Result<()> {
let p = Path::new(".")
.join("data")
.join(robot_id)
Expand Down Expand Up @@ -68,11 +72,32 @@ pub(crate) async fn list_qa(Query(q): Query<RobotQuery>) -> impl IntoResponse {
to_res(r)
}

pub(crate) async fn add_qa(
pub(crate) async fn save_qa(
Query(q): Query<RobotQuery>,
Json(d): Json<QuestionAnswerPair>,
) -> impl IntoResponse {
let r = super::qa::add(&q.robot_id, d).await;
let r = super::qa::save(&q.robot_id, d).await;
// let r = sqlite_trans!(super::qa::add, &q.robot_id, d).await;
to_res(r)
}

pub(crate) async fn delete_qa(
Query(q): Query<RobotQuery>,
Json(d): Json<QuestionAnswerPair>,
) -> impl IntoResponse {
let r = super::qa::delete(&q.robot_id, d).await;
to_res(r)
}

pub(crate) async fn qa_dryrun(Query(q): Query<HashMap<String, String>>) -> impl IntoResponse {
let r = q.get("robotId");
let t = q.get("text");
if r.is_none() || t.is_none() {
let res = Err(Error::ErrorWithMessage(String::from(
"robotId or text was missing.",
)));
return to_res(res);
}
let r = super::qa::retrieve_answer(r.unwrap(), t.unwrap()).await;
to_res(r)
}
17 changes: 9 additions & 8 deletions src/kb/dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ use std::vec::Vec;

use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize)]
pub(crate) struct QuestionAnswerData {
pub(super) id: String,
#[serde(rename = "qaData")]
pub(super) qa_data: QuestionAnswerPair,
}
// #[derive(Deserialize, Serialize)]
// pub(crate) struct QuestionAnswerData {
// pub(super) id: Option<String>,
// #[serde(rename = "qaData")]
// pub(super) qa_data: QuestionAnswerPair,
// }

#[derive(Deserialize, Serialize)]
pub(crate) struct QuestionAnswerPair {
pub(super) id: Option<String>,
pub(super) question: QuestionData,
#[serde(rename = "similarQuestions")]
pub(super) similar_questions: Option<Vec<QuestionData>>,
pub(super) answer: String,
pub(super) similar_questions: Vec<QuestionData>,
pub(crate) answer: String,
}

#[derive(Deserialize, Serialize)]
Expand Down
Loading

0 comments on commit efd4bf5

Please sign in to comment.