Skip to content

Commit

Permalink
LLM chat node: 20%
Browse files Browse the repository at this point in the history
  • Loading branch information
dialogflowchatbot committed Aug 12, 2024
1 parent 72e24e0 commit d2c3f47
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dialogflow"
version = "1.15.0"
version = "1.16.0"
edition = "2021"
homepage = "https://dialogflowchatbot.github.io/"
authors = ["dialogflowchatbot <[email protected]>"]
Expand Down
6 changes: 3 additions & 3 deletions src/ai/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ pub(crate) enum TextGenerationProvider {
}

#[derive(Deserialize, Serialize)]
pub(in crate::ai) struct Prompt {
pub(in crate::ai) role: String,
pub(in crate::ai) content: String,
pub(crate) struct Prompt {
pub(crate) role: String,
pub(crate) content: String,
}

static LOADED_MODELS: LazyLock<Mutex<HashMap<String, LoadedHuggingFaceModel>>> =
Expand Down
13 changes: 13 additions & 0 deletions src/flow/rt/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub(crate) struct Context {
robot_id: String,
pub(in crate::flow::rt) main_flow_id: String,
session_id: String,
pub(in crate::flow::rt) node: Option<Vec<u8>>,
pub(in crate::flow::rt) nodes: LinkedList<String>,
pub(crate) vars: HashMap<String, VariableValue>,
#[serde(skip)]
Expand Down Expand Up @@ -62,6 +63,7 @@ impl Context {
robot_id: String::from(robot_id),
main_flow_id: String::with_capacity(64),
session_id: String::from(session_id),
node: None,
nodes: LinkedList::new(),
vars: HashMap::with_capacity(16),
none_persistent_vars: HashMap::with_capacity(16),
Expand Down Expand Up @@ -103,6 +105,17 @@ impl Context {

pub(in crate::flow::rt) fn pop_node(&mut self) -> Option<RuntimeNnodeEnum> {
// println!("nodes len {}", self.nodes.len());

if self.node.is_some() {
let node = std::mem::replace(&mut self.node, None);
let v = node.unwrap();
match crate::flow::rt::node::deser_node(v.as_ref()) {
Ok(n) => return Some(n),
Err(e) => {
log::error!("pop_node failed err: {:?}", &e);
},
}
}
if let Some(node_id) = self.nodes.pop_front() {
// println!("main_flow_id {} node_id {}", &self.main_flow_id, &node_id);
if let Ok(r) = super::crud::get_runtime_node(&self.main_flow_id, &node_id) {
Expand Down
18 changes: 17 additions & 1 deletion src/flow/rt/convertor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::vec::Vec;
use super::condition::ConditionData;
use super::node::{
CollectNode, ConditionNode, ExternalHttpCallNode, GotoAnotherNode, GotoMainFlowNode,
RuntimeNnodeEnum, SendEmailNode, TerminateNode, TextNode,
LlmChatNode, RuntimeNnodeEnum, SendEmailNode, TerminateNode, TextNode,
};
use crate::db;
use crate::db_executor;
Expand Down Expand Up @@ -70,6 +70,7 @@ fn check_first_node(
if node.get_node_id().eq(&id) {
match node {
Node::DialogNode(ref mut n) => n.node_id = String::from(first_node_id),
Node::LlmChatNode(n) => n.node_id = String::from(first_node_id),
Node::ConditionNode(n) => n.node_id = String::from(first_node_id),
Node::CollectNode(n) => n.node_id = String::from(first_node_id),
Node::GotoNode(n) => n.node_id = String::from(first_node_id),
Expand Down Expand Up @@ -154,6 +155,7 @@ fn convert_node(main_flow_id: &str, node: &mut Node) -> Result<()> {
Node::DialogNode(n) => {
let node = TextNode {
text: n.dialog_text.clone(),
text_type: n.dialog_text_type.clone(),
ret: NextActionType::WaitUserResponse == n.next_step,
next_node_id: n.branches[0].target_node_id.clone(),
};
Expand All @@ -164,6 +166,19 @@ fn convert_node(main_flow_id: &str, node: &mut Node) -> Result<()> {
// bytes.push(RuntimeNodeTypeId::TextNode as u8);
nodes.push((n.node_id.clone(), bytes));
}
Node::LlmChatNode(n) => {
let node = LlmChatNode {
prompt: n.prompt.clone(),
context_len: n.context_length,
cur_run_times: 0,
exit_condition: n.exit_condition.clone(),
streaming: n.response_streaming,
next_node_id: n.branches[0].target_node_id.clone(),
};
let r = RuntimeNnodeEnum::LlmChatNode(node);
let bytes = rkyv::to_bytes::<_, 256>(&r).unwrap();
nodes.push((n.node_id.clone(), bytes));
}
Node::ConditionNode(n) => {
// println!("Condition {}", &n.node_id);
let mut cnt = 1u8;
Expand Down Expand Up @@ -316,6 +331,7 @@ fn convert_node(main_flow_id: &str, node: &mut Node) -> Result<()> {
if !n.ending_text.is_empty() {
let node = TextNode {
text: n.ending_text.clone(),
text_type: super::dto::AnswerType::TextPlain,
ret: false,
next_node_id: end_node_id,
};
Expand Down
3 changes: 2 additions & 1 deletion src/flow/rt/dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ pub(crate) struct CollectData {
pub(crate) value: String,
}

#[derive(Serialize)]
#[derive(Clone, Deserialize, Serialize, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[archive(compare(PartialEq), check_bytes)]
pub(crate) enum AnswerType {
TextPlain,
TextHtml,
Expand Down
2 changes: 1 addition & 1 deletion src/flow/rt/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub(in crate::flow::rt) fn exec(req: &Request, ctx: &mut Context) -> Result<Resp
let mut response = Response::new(req);
for _i in 0..100 {
// let now = std::time::Instant::now();
if let Some(n) = ctx.pop_node() {
if let Some(mut n) = ctx.pop_node() {
// println!("pop node {:?}", now.elapsed());
let ret = n.exec(&req, ctx, &mut response);
// println!("node exec {:?}", now.elapsed());
Expand Down
45 changes: 29 additions & 16 deletions src/flow/rt/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::ai::chat::ResultReceiver;
use crate::external::http::client as http;
use crate::flow::rt::collector;
use crate::flow::subflow::dto::NextActionType;
use crate::man::settings::get_settings;
use crate::result::Result;
use crate::variable::crud as variable;
use crate::variable::dto::{VariableType, VariableValue};
Expand Down Expand Up @@ -39,11 +40,12 @@ pub(crate) enum RuntimeNnodeEnum {
ExternalHttpCallNode,
TerminateNode,
SendEmailNode,
LlmChatNode,
}

#[enum_dispatch(RuntimeNnodeEnum)]
pub(crate) trait RuntimeNode {
fn exec(&self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool;
fn exec(&mut self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool;
}

fn replace_vars(text: &str, req: &Request, ctx: &mut Context) -> Result<String> {
Expand Down Expand Up @@ -88,18 +90,19 @@ fn add_next_node(ctx: &mut Context, next_node_id: &str) {
#[archive(compare(PartialEq), check_bytes)]
pub(crate) struct TextNode {
pub(super) text: String,
pub(crate) text_type: AnswerType,
pub(super) ret: bool,
pub(super) next_node_id: String,
}

impl RuntimeNode for TextNode {
fn exec(&self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool {
fn exec(&mut self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool {
// println!("Into TextNode");
// let now = std::time::Instant::now();
match replace_vars(&self.text, &req, ctx) {
Ok(answer) => response.answers.push(AnswerData {
text: answer,
answer_type: AnswerType::TextPlain,
answer_type: self.text_type.clone(),
}),
Err(e) => log::error!("{:?}", e),
};
Expand All @@ -117,7 +120,7 @@ pub(crate) struct GotoMainFlowNode {
}

impl RuntimeNode for GotoMainFlowNode {
fn exec(&self, _req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
fn exec(&mut self, _req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
// println!("Into GotoMainFlowNode");
ctx.main_flow_id.clear();
ctx.main_flow_id.push_str(&self.main_flow_id);
Expand All @@ -133,7 +136,7 @@ pub(crate) struct GotoAnotherNode {
}

impl RuntimeNode for GotoAnotherNode {
fn exec(&self, _req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
fn exec(&mut self, _req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
// println!("Into GotoAnotherNode");
add_next_node(ctx, &self.next_node_id);
false
Expand All @@ -150,7 +153,7 @@ pub(crate) struct CollectNode {
}

impl RuntimeNode for CollectNode {
fn exec(&self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool {
fn exec(&mut self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool {
// println!("Into CollectNode");
if let Some(r) = collector::collect(&req.user_input, &self.collect_type) {
// println!("{} {}", &self.var_name, r);
Expand Down Expand Up @@ -179,7 +182,7 @@ pub(crate) struct ConditionNode {
}

impl RuntimeNode for ConditionNode {
fn exec(&self, req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
fn exec(&mut self, req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
// println!("Into ConditionNode");
let mut r = false;
for and_conditions in self.conditions.iter() {
Expand All @@ -204,7 +207,7 @@ impl RuntimeNode for ConditionNode {
pub(crate) struct TerminateNode {}

impl RuntimeNode for TerminateNode {
fn exec(&self, _req: &Request, _ctx: &mut Context, response: &mut Response) -> bool {
fn exec(&mut self, _req: &Request, _ctx: &mut Context, response: &mut Response) -> bool {
// println!("Into TerminateNode");
response.next_action = NextActionType::Terminate;
true
Expand All @@ -219,7 +222,7 @@ pub(crate) struct ExternalHttpCallNode {
}

impl RuntimeNode for ExternalHttpCallNode {
fn exec(&self, req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
fn exec(&mut self, req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
// println!("Into ExternalHttpCallNode");
if let Ok(op) =
crate::external::http::crud::get_detail(&req.robot_id, self.http_api_id.as_str())
Expand Down Expand Up @@ -342,9 +345,8 @@ impl SendEmailNode {
}

impl RuntimeNode for SendEmailNode {
fn exec(&self, req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
fn exec(&mut self, req: &Request, ctx: &mut Context, _response: &mut Response) -> bool {
// println!("Into SendEmailNode");
use crate::man::settings::get_settings;
if let Ok(op) = get_settings(&req.robot_id) {
if let Some(settings) = op {
if !settings.smtp_host.is_empty() {
Expand All @@ -359,27 +361,29 @@ impl RuntimeNode for SendEmailNode {
}
}

#[derive(Archive, Deserialize, Serialize)]
#[derive(Archive, Clone, Deserialize, Serialize, serde::Deserialize)]
#[archive(compare(PartialEq), check_bytes)]
pub(crate) enum LlmChatNodeExitCondition {
Intent(String),
SpecialInputs(String),
MaxChatTimes(u32),
MaxChatTimes(u8),
}

#[derive(Archive, Deserialize, Serialize)]
#[derive(Archive, Clone, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)]
pub(crate) struct LlmChatNode {
pub(super) prompt: String,
pub(super) context_len: u8,
pub(super) cur_run_times: u8,
pub(super) exit_condition: LlmChatNodeExitCondition,
pub(super) streaming: bool,
pub(super) next_node_id: String,
}

impl RuntimeNode for LlmChatNode {
fn exec(&self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool {
fn exec(&mut self, req: &Request, ctx: &mut Context, response: &mut Response) -> bool {
// println!("Into LlmChatNode");
self.cur_run_times = self.cur_run_times + 1;
match &self.exit_condition {
LlmChatNodeExitCondition::Intent(i) => {
if req.user_input_intent.is_some() && req.user_input_intent.as_ref().unwrap().eq(i)
Expand All @@ -394,8 +398,16 @@ impl RuntimeNode for LlmChatNode {
return false;
}
}
LlmChatNodeExitCondition::MaxChatTimes(t) => todo!(),
LlmChatNodeExitCondition::MaxChatTimes(t) => {
if self.cur_run_times > *t {
add_next_node(ctx, &self.next_node_id);
return false;
}
},
}
let r = RuntimeNnodeEnum::LlmChatNode(self.clone());
let bytes = rkyv::to_bytes::<_, 256>(&r).unwrap();
ctx.node = Some(bytes.into_vec());
if self.streaming {
let r = super::facade::get_sender(&req.session_id);
if r.is_err() {
Expand Down Expand Up @@ -425,6 +437,7 @@ impl RuntimeNode for LlmChatNode {
}) {
log::info!("LlmChatNode response failed, err: {:?}", &e);
} else {
log::info!("LLM response {}", &s);
response.answers.push(AnswerData {
text: s,
answer_type: AnswerType::TextPlain,
Expand Down
42 changes: 42 additions & 0 deletions src/flow/subflow/dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub(crate) struct CanvasCell {
#[serde(tag = "nodeType")]
pub(crate) enum Node {
DialogNode(DialogNode),
LlmChatNode(LlmChatNode),
ConditionNode(ConditionNode),
CollectNode(CollectNode),
GotoNode(GotoNode),
Expand Down Expand Up @@ -94,6 +95,20 @@ impl Node {
Ok(())
}
}
Node::LlmChatNode(n) => {
let t = "Dialog";
if !n.valid {
Self::err(f, t, &n.node_name, "verification failed")
} else if n.node_name.is_empty() {
Self::err(f, t, &n.node_name, "node name not filled in")
} else if n.prompt.is_empty() {
Self::err(f, t, &n.node_name, "No prompt filled in")
} else if n.branches.len() != 1 {
Self::err(f, t, &n.node_name, "Branch information is incorrect")
} else {
Ok(())
}
}
Node::ConditionNode(n) => {
let t = "Condition";
if !n.valid {
Expand Down Expand Up @@ -205,6 +220,7 @@ impl Node {
pub(crate) fn get_node_id(&self) -> String {
match self {
Self::DialogNode(n) => n.node_id.clone(),
Self::LlmChatNode(n) => n.node_id.clone(),
Self::ConditionNode(n) => n.node_id.clone(),
Self::CollectNode(n) => n.node_id.clone(),
Self::GotoNode(n) => n.node_id.clone(),
Expand All @@ -222,6 +238,11 @@ impl Node {
.iter()
.for_each(|b| ids.push(b.target_node_id.clone()));
}
Self::LlmChatNode(n) => {
n.branches
.iter()
.for_each(|b| ids.push(b.target_node_id.clone()));
}
Self::EndNode(_) | Self::GotoNode(_) => {}
Self::ConditionNode(n) => {
n.branches
Expand Down Expand Up @@ -250,6 +271,7 @@ impl Node {
pub(crate) fn get_branches(&mut self) -> Option<&mut Vec<Branch>> {
match self {
Self::DialogNode(n) => Some(&mut n.branches),
Self::LlmChatNode(n) => Some(&mut n.branches),
Self::EndNode(_) | Self::GotoNode(_) => None,
Self::ConditionNode(n) => Some(&mut n.branches),
Self::CollectNode(n) => Some(&mut n.branches),
Expand Down Expand Up @@ -306,11 +328,31 @@ pub(crate) struct DialogNode {
pub(crate) node_name: String,
#[serde(rename = "dialogText")]
pub(crate) dialog_text: String,
#[serde(rename = "dialogTextType")]
pub(crate) dialog_text_type: crate::flow::rt::dto::AnswerType,
#[serde(rename = "nextStep")]
pub(crate) next_step: NextActionType,
pub(crate) branches: Vec<Branch>,
}

#[derive(Deserialize)]
pub(crate) struct LlmChatNode {
pub(crate) valid: bool,
#[serde(rename = "nodeId")]
pub(crate) node_id: String,
#[serde(rename = "nodeName")]
pub(crate) node_name: String,
#[serde(rename = "prompt")]
pub(crate) prompt: String,
#[serde(rename = "contextLength")]
pub(crate) context_length: u8,
#[serde(rename = "exitCondition")]
pub(crate) exit_condition: crate::flow::rt::node::LlmChatNodeExitCondition,
#[serde(rename = "responseStreaming")]
pub(crate) response_streaming: bool,
pub(crate) branches: Vec<Branch>,
}

#[derive(Deserialize)]
pub(crate) struct ConditionNode {
pub(crate) valid: bool,
Expand Down
4 changes: 2 additions & 2 deletions src/web/asset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use std::sync::LazyLock;
pub(crate) static ASSETS_MAP: LazyLock<HashMap<&str, usize>> = LazyLock::new(|| {
HashMap::from([
(r"/assets/inbound-bot-PJJg_rST.png", 0),
(r"/assets/index-BjYMHWgz.css", 1),
(r"/assets/index-CQUEiyGj.js", 2),
(r"/assets/index-BisrZMCw.js", 1),
(r"/assets/index-DpRbUEHB.css", 2),
(r"/assets/outbound-bot-EmsLuWRN.png", 3),
(r"/assets/text-bot-CWb_Poym.png", 4),
(r"/assets/usedByDialogNodeTextGeneration-DrFqkTqi.png", 5),
Expand Down
Loading

0 comments on commit d2c3f47

Please sign in to comment.