From a4ebea7a88fd24a32dceda716817e99ea01ae8d4 Mon Sep 17 00:00:00 2001 From: dialogflowchatbot Date: Sat, 31 Aug 2024 19:16:16 +0800 Subject: [PATCH] Append chat history to LLM prompt --- Cargo.toml | 2 +- src/ai/chat.rs | 123 +++++++++++++++++++++++------------- src/ai/completion.rs | 35 ++++------ src/ai/embedding.rs | 32 +++------- src/ai/huggingface.rs | 40 +++++++++++- src/external/http/client.rs | 23 +++++++ src/external/http/mod.rs | 2 + src/flow/rt/context.rs | 3 + src/flow/rt/executor.rs | 16 +++++ src/flow/rt/node.rs | 8 +++ src/web/asset.rs | 34 +++++----- 11 files changed, 212 insertions(+), 106 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b80623b..a3eb683 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dialogflow" -version = "1.16.0" +version = "1.16.1" edition = "2021" homepage = "https://dialogflowchatbot.github.io/" authors = ["dialogflowchatbot "] diff --git a/src/ai/chat.rs b/src/ai/chat.rs index ce5fbd8..2a78221 100644 --- a/src/ai/chat.rs +++ b/src/ai/chat.rs @@ -1,18 +1,21 @@ -use core::time::Duration; use std::collections::HashMap; use std::sync::{LazyLock, Mutex}; +use std::vec::Vec; use futures_util::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use tokio::sync::mpsc::Sender; +use super::completion::Prompt; use crate::ai::huggingface::{HuggingFaceModel, LoadedHuggingFaceModel}; use crate::man::settings; use crate::result::{Error, Result}; static LOADED_MODELS: LazyLock>> = LazyLock::new(|| Mutex::new(HashMap::with_capacity(32))); +// static HTTP_CLIENTS: LazyLock>> = +// LazyLock::new(|| Mutex::new(HashMap::with_capacity(32))); pub(crate) enum ResultReceiver<'r> { SseSender(&'r Sender), @@ -37,6 +40,7 @@ pub(crate) fn replace_model_cache(robot_id: &str, m: &HuggingFaceModel) -> Resul pub(crate) async fn chat( robot_id: &str, prompt: &str, + chat_history: Option>, connect_timeout: Option, read_timeout: Option, result_receiver: ResultReceiver<'_>, @@ -49,16 +53,17 @@ 
pub(crate) async fn chat( robot_id, &m, prompt, + chat_history, settings.chat_provider.max_response_token_length as usize, result_receiver, - ) - .await?; + )?; Ok(()) } ChatProvider::OpenAI(m) => { open_ai( &m, prompt, + chat_history, connect_timeout .unwrap_or(settings.text_generation_provider.connect_timeout_millis), read_timeout.unwrap_or(settings.text_generation_provider.read_timeout_millis), @@ -73,6 +78,7 @@ pub(crate) async fn chat( &settings.text_generation_provider.api_url, &m, prompt, + chat_history, connect_timeout .unwrap_or(settings.text_generation_provider.connect_timeout_millis), read_timeout.unwrap_or(settings.text_generation_provider.read_timeout_millis), @@ -91,16 +97,17 @@ pub(crate) async fn chat( } } -async fn huggingface( +fn huggingface( robot_id: &str, m: &HuggingFaceModel, prompt: &str, + chat_history: Option>, sample_len: usize, mut result_receiver: ResultReceiver<'_>, ) -> Result<()> { let info = m.get_info(); // log::info!("model_type={:?}", &info.model_type); - let new_prompt = info.convert_prompt(prompt)?; + let new_prompt = info.convert_prompt(prompt, chat_history)?; let mut model = LOADED_MODELS.lock().unwrap_or_else(|e| { log::warn!("{:#?}", &e); e.into_inner() @@ -152,37 +159,54 @@ async fn huggingface( async fn open_ai( m: &str, s: &str, + chat_history: Option>, connect_timeout_millis: u32, read_timeout_millis: u32, proxy_url: &str, result_receiver: ResultReceiver<'_>, ) -> Result<()> { - let mut client = reqwest::Client::builder() - .connect_timeout(Duration::from_millis(connect_timeout_millis.into())) - .read_timeout(Duration::from_millis(read_timeout_millis.into())); - if proxy_url.is_empty() { - client = client.no_proxy(); - } else { - let proxy = reqwest::Proxy::http(proxy_url)?; - client = client.proxy(proxy); - } - let client = client.build()?; - let mut message0 = Map::new(); - message0.insert(String::from("role"), Value::from("system")); - message0.insert(String::from("content"), Value::from("system_hint")); - let mut 
message1 = Map::new();
-    message1.insert(String::from("role"), Value::from("user"));
-    message1.insert(String::from("content"), Value::from(s));
-    let messages = Value::Array(vec![message0.into(), message1.into()]);
-    let mut map = Map::new();
-    map.insert(String::from("model"), Value::from(m));
-    map.insert(String::from("messages"), messages);
+    // let client = HTTP_CLIENTS.lock()?.entry(String::from("value")).or_insert(crate::external::http::get_client(connect_timeout_millis.into(), read_timeout_millis.into(), proxy_url)?);
+    let client = crate::external::http::get_client(
+        connect_timeout_millis.into(),
+        read_timeout_millis.into(),
+        proxy_url,
+    )?;
+    let mut req_body = Map::new();
+    req_body.insert(String::from("model"), Value::from(m));
+
+    let mut sys_message = Map::new();
+    sys_message.insert(String::from("role"), Value::from("system"));
+    sys_message.insert(
+        String::from("content"),
+        Value::from("You are a helpful assistant."),
+    );
+    let mut messages: Vec<Value> = match chat_history {
+        Some(h) if !h.is_empty() => {
+            let mut d = Vec::with_capacity(h.len() + 1);
+            d.push(Value::Null);
+            for p in h.into_iter() {
+                let mut map = Map::new();
+                map.insert(String::from("role"), Value::String(p.role)); map.insert(String::from("content"), Value::String(p.content));
+                d.push(Value::from(map));
+            }
+            d
+        }
+        _ => vec![Value::Null],
+    };
+    messages[0] = Value::Object(sys_message);
+    let mut user_message = Map::new();
+    user_message.insert(String::from("role"), Value::from("user"));
+    user_message.insert(String::from("content"), Value::from(s));
+    messages.push(Value::Object(user_message));
+    let messages = Value::Array(messages);
+    req_body.insert(String::from("messages"), messages);
+
     let stream = match result_receiver {
         ResultReceiver::SseSender(_) => true,
         ResultReceiver::StrBuf(_) => false,
     };
-    map.insert(String::from("stream"), Value::Bool(stream));
-    let obj = Value::Object(map);
+    req_body.insert(String::from("stream"), Value::Bool(stream));
+    let obj = Value::Object(req_body);
     let req = client
.post("https://api.openai.com/v1/chat/completions")
         .header("Content-Type", "application/json")
@@ -249,6 +273,7 @@ async fn ollama(
     u: &str,
     m: &str,
     s: &str,
+    chat_history: Option<Vec<Prompt>>,
     connect_timeout_millis: u32,
     read_timeout_millis: u32,
     proxy_url: &str,
@@ -268,30 +293,42 @@ async fn ollama(
     if prompt.is_empty() {
         return Ok(());
     }
-    let mut client = reqwest::Client::builder()
-        .connect_timeout(Duration::from_millis(connect_timeout_millis.into()))
-        .read_timeout(Duration::from_millis(read_timeout_millis.into()));
-    if proxy_url.is_empty() {
-        client = client.no_proxy();
-    } else {
-        let proxy = reqwest::Proxy::http(proxy_url)?;
-        client = client.proxy(proxy);
-    }
-    let client = client.build()?;
-    let mut map = Map::new();
-    map.insert(String::from("prompt"), Value::String(prompt));
-    map.insert(String::from("model"), Value::String(String::from(m)));
+    let client = crate::external::http::get_client(
+        connect_timeout_millis.into(),
+        read_timeout_millis.into(),
+        proxy_url,
+    )?;
+    let mut req_body = Map::new();
+    req_body.insert(String::from("prompt"), Value::String(prompt));
+    req_body.insert(String::from("model"), Value::String(String::from(m)));
     let stream = match result_receiver {
         ResultReceiver::SseSender(_) => true,
         ResultReceiver::StrBuf(_) => false,
     };
-    map.insert(String::from("stream"), Value::Bool(stream));
+    req_body.insert(String::from("stream"), Value::Bool(stream));
+
+    let mut messages: Vec<Value> = match chat_history {
+        Some(h) if !h.is_empty() => {
+            let mut d = Vec::with_capacity(h.len() + 1);
+            for p in h.into_iter() {
+                let mut map = Map::new();
+                map.insert(String::from("role"), Value::String(p.role)); map.insert(String::from("content"), Value::String(p.content));
+                d.push(Value::from(map));
+            }
+            d
+        }
+        _ => Vec::with_capacity(1),
+    };
+    let mut m = Map::new();
+    m.insert(String::from("role"), Value::from("user")); m.insert(String::from("content"), Value::String(String::from(s)));
+    messages.push(Value::from(m));
+    req_body.insert(String::from("messages"), Value::Array(messages));
     let mut num_predict = Map::new();
     num_predict.insert(String::from("num_predict"),
Value::from(sample_len)); + req_body.insert(String::from("options"), Value::from(num_predict)); - map.insert(String::from("options"), Value::from(num_predict)); - let obj = Value::Object(map); + let obj = Value::Object(req_body); let body = serde_json::to_string(&obj)?; // log::info!("Request Ollama body {}", &body); let req = client.post(u).body(body); diff --git a/src/ai/completion.rs b/src/ai/completion.rs index 5199a24..7ec81d8 100644 --- a/src/ai/completion.rs +++ b/src/ai/completion.rs @@ -1,4 +1,3 @@ -use core::time::Duration; use std::collections::HashMap; use std::sync::{LazyLock, Mutex}; @@ -25,7 +24,7 @@ pub(crate) enum TextGenerationProvider { Ollama(String), } -#[derive(Deserialize, Serialize)] +#[derive(Clone, Deserialize, Serialize)] pub(crate) struct Prompt { pub(crate) role: String, pub(crate) content: String, @@ -184,7 +183,7 @@ async fn huggingface( ) -> Result<()> { let info = m.get_info(); // log::info!("model_type={:?}", &info.model_type); - let new_prompt = info.convert_prompt(prompt)?; + let new_prompt = info.convert_prompt(prompt, None)?; let mut model = LOADED_MODELS.lock().unwrap_or_else(|e| { log::warn!("{:#?}", &e); e.into_inner() @@ -242,16 +241,11 @@ async fn open_ai( proxy_url: &str, sender: &Sender, ) -> Result<()> { - let mut client = reqwest::Client::builder() - .connect_timeout(Duration::from_millis(connect_timeout_millis.into())) - .read_timeout(Duration::from_millis(read_timeout_millis.into())); - if proxy_url.is_empty() { - client = client.no_proxy(); - } else { - let proxy = reqwest::Proxy::http(proxy_url)?; - client = client.proxy(proxy); - } - let client = client.build()?; + let client = crate::external::http::get_client( + connect_timeout_millis.into(), + read_timeout_millis.into(), + proxy_url, + )?; let mut message0 = Map::new(); message0.insert(String::from("role"), Value::from("system")); message0.insert(String::from("content"), Value::from("system_hint")); @@ -319,16 +313,11 @@ async fn ollama( if prompt.is_empty() { 
return Ok(()); } - let mut client = reqwest::Client::builder() - .connect_timeout(Duration::from_millis(connect_timeout_millis.into())) - .read_timeout(Duration::from_millis(read_timeout_millis.into())); - if proxy_url.is_empty() { - client = client.no_proxy(); - } else { - let proxy = reqwest::Proxy::http(proxy_url)?; - client = client.proxy(proxy); - } - let client = client.build()?; + let client = crate::external::http::get_client( + connect_timeout_millis.into(), + read_timeout_millis.into(), + proxy_url, + )?; let mut map = Map::new(); map.insert(String::from("prompt"), Value::String(prompt)); map.insert(String::from("model"), Value::String(String::from(m))); diff --git a/src/ai/embedding.rs b/src/ai/embedding.rs index 6f263dc..e02c3b1 100644 --- a/src/ai/embedding.rs +++ b/src/ai/embedding.rs @@ -1,5 +1,3 @@ -use core::time::Duration; - // use std::collections::VecDeque; use std::collections::HashMap; use std::sync::{Mutex, OnceLock}; @@ -121,16 +119,11 @@ async fn open_ai( read_timeout_millis: u32, proxy_url: &str, ) -> Result> { - let mut client = reqwest::Client::builder() - .connect_timeout(Duration::from_millis(connect_timeout_millis.into())) - .read_timeout(Duration::from_millis(read_timeout_millis.into())); - if proxy_url.is_empty() { - client = client.no_proxy(); - } else { - let proxy = reqwest::Proxy::http(proxy_url)?; - client = client.proxy(proxy); - } - let client = client.build()?; + let client = crate::external::http::get_client( + connect_timeout_millis.into(), + read_timeout_millis.into(), + proxy_url, + )?; let mut map = Map::new(); map.insert(String::from("input"), Value::String(String::from(s))); map.insert(String::from("model"), Value::String(String::from(m))); @@ -174,16 +167,11 @@ async fn ollama( read_timeout_millis: u32, proxy_url: &str, ) -> Result> { - let mut client = reqwest::Client::builder() - .connect_timeout(Duration::from_millis(connect_timeout_millis.into())) - 
.read_timeout(Duration::from_millis(read_timeout_millis.into())); - if proxy_url.is_empty() { - client = client.no_proxy(); - } else { - let proxy = reqwest::Proxy::http(proxy_url)?; - client = client.proxy(proxy); - } - let client = client.build()?; + let client = crate::external::http::get_client( + connect_timeout_millis.into(), + read_timeout_millis.into(), + proxy_url, + )?; let mut map = Map::new(); map.insert(String::from("prompt"), Value::String(String::from(s))); map.insert(String::from("model"), Value::String(String::from(m))); diff --git a/src/ai/huggingface.rs b/src/ai/huggingface.rs index c1dc00b..04898a3 100644 --- a/src/ai/huggingface.rs +++ b/src/ai/huggingface.rs @@ -92,7 +92,11 @@ pub(crate) struct HuggingFaceModelInfo { } impl HuggingFaceModelInfo { - pub(super) fn convert_prompt(&self, s: &str) -> Result { + pub(super) fn convert_prompt( + &self, + s: &str, + history: Option>, + ) -> Result { let mut prompts: Vec = serde_json::from_str(s)?; let mut system = String::new(); let mut user = String::new(); @@ -112,6 +116,15 @@ impl HuggingFaceModelInfo { p.push_str(&system); p.push_str("\n"); } + if let Some(h) = history { + for i in h.iter() { + p.push_str("<|"); + p.push_str(&i.role); + p.push_str("|>\n"); + p.push_str(&i.content); + p.push_str("\n"); + } + } p.push_str("<|user|>\n"); p.push_str(&user); p.push_str("\n<|assistant|>"); @@ -128,7 +141,21 @@ impl HuggingFaceModelInfo { } HuggingFaceModelType::Gemma => { let mut p = String::with_capacity(s.len()); - p.push_str("user\n"); + // p.push_str(""); + if let Some(h) = history { + for i in h.iter() { + p.push_str(""); + if i.role.eq("assistant") { + p.push_str("model"); + } else { + p.push_str(&i.role); + } + p.push_str("\n"); + p.push_str(&i.content); + p.push_str("\n"); + } + } + p.push_str("user\n"); p.push_str(&user); p.push_str("\nmodel"); Ok(p) @@ -141,6 +168,15 @@ impl HuggingFaceModelInfo { p.push_str(&system); p.push_str("<|end|>\n"); } + if let Some(h) = history { + for i in h.iter() { 
+ p.push_str("<|"); + p.push_str(&i.role); + p.push_str("|>\n"); + p.push_str(&i.content); + p.push_str("<|end|>\n"); + } + } p.push_str("<|user|>\n"); p.push_str(&user); p.push_str("<|end|>\n<|assistant|>"); diff --git a/src/external/http/client.rs b/src/external/http/client.rs index 5ab0e0f..2395403 100644 --- a/src/external/http/client.rs +++ b/src/external/http/client.rs @@ -3,11 +3,34 @@ use std::time::Duration; use std::vec::Vec; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +use reqwest::Client; use reqwest::RequestBuilder; use super::dto::{HttpReqInfo, Method, PostContentType, Protocol, ResponseData, ValueSource}; +use crate::result::Result; use crate::variable::dto::VariableValue; +pub(crate) fn get_client( + connect_timeout_millis: u64, + read_timeout_millis: u64, + proxy_url: &str, +) -> Result { + let mut client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(connect_timeout_millis)) + .read_timeout(Duration::from_millis(read_timeout_millis)) + // Since can not reuse Client currently, so set pool size to 0 + .pool_max_idle_per_host(0) + .pool_idle_timeout(Duration::from_secs(1)); + if proxy_url.is_empty() { + client = client.no_proxy(); + } else { + let proxy = reqwest::Proxy::http(proxy_url)?; + client = client.proxy(proxy); + } + let client = client.build()?; + Ok(client) +} + pub(crate) async fn req_async( info: HttpReqInfo, vars: HashMap, diff --git a/src/external/http/mod.rs b/src/external/http/mod.rs index b9e82c4..817f065 100644 --- a/src/external/http/mod.rs +++ b/src/external/http/mod.rs @@ -1,3 +1,5 @@ pub(crate) mod client; pub(crate) mod crud; pub(crate) mod dto; + +pub(crate) use client::get_client; diff --git a/src/flow/rt/context.rs b/src/flow/rt/context.rs index b35c195..17f99ff 100644 --- a/src/flow/rt/context.rs +++ b/src/flow/rt/context.rs @@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize}; use tokio::time::{interval, Duration}; use super::node::RuntimeNnodeEnum; +use crate::ai::completion::Prompt; 
use crate::db; use crate::man::settings; use crate::result::Result; @@ -36,6 +37,7 @@ pub(crate) struct Context { #[serde(skip)] pub(crate) none_persistent_data: HashMap, last_active_time: u64, + pub(crate) chat_history: Vec, } impl Context { @@ -72,6 +74,7 @@ impl Context { .duration_since(UNIX_EPOCH) .unwrap() .as_secs(), + chat_history: Vec::with_capacity(16), }; ctx } diff --git a/src/flow/rt/executor.rs b/src/flow/rt/executor.rs index d01f5b3..4503508 100644 --- a/src/flow/rt/executor.rs +++ b/src/flow/rt/executor.rs @@ -1,5 +1,6 @@ use super::context::Context; use super::dto::{Request, Response}; +use crate::ai::completion::Prompt; use crate::flow::rt::node::RuntimeNode; use crate::intent::detector; use crate::result::{Error, Result}; @@ -34,7 +35,22 @@ pub(in crate::flow::rt) async fn process(req: &mut Request) -> Result } // println!("intent detect {:?}", now.elapsed()); // let now = std::time::Instant::now(); + ctx.chat_history.push(Prompt { + role: String::from("user"), + content: req.user_input.clone(), + }); let r = exec(req, &mut ctx); + if r.is_ok() { + let res = r.as_ref().unwrap(); + if !res.answers.is_empty() { + for a in res.answers.iter() { + ctx.chat_history.push(Prompt { + role: String::from("assistant"), + content: a.text.clone(), + }); + } + } + } // println!("exec {:?}", now.elapsed()); // let now = std::time::Instant::now(); ctx.save()?; diff --git a/src/flow/rt/node.rs b/src/flow/rt/node.rs index 0fb8ef9..1392ea2 100644 --- a/src/flow/rt/node.rs +++ b/src/flow/rt/node.rs @@ -1,4 +1,5 @@ use core::time::Duration; +use std::ops::DerefMut; use enum_dispatch::enum_dispatch; use lettre::transport::smtp::PoolConfig; @@ -442,6 +443,7 @@ impl RuntimeNode for LlmChatNode { if let Err(e) = crate::ai::chat::chat( &robot_id, &prompt, + None, connect_timeout, read_timeout, ResultReceiver::SseSender(&s), @@ -457,9 +459,15 @@ impl RuntimeNode for LlmChatNode { let mut s = String::with_capacity(1024); if let Err(e) = tokio::task::block_in_place(|| { // 
log::info!("prompt |{}|", &self.prompt); + let chat_history = if ctx.chat_history.is_empty() { + None + } else { + Some(ctx.chat_history.clone()) + }; tokio::runtime::Handle::current().block_on(crate::ai::chat::chat( &req.robot_id, &self.prompt, + chat_history, self.connect_timeout, self.read_timeout, ResultReceiver::StrBuf(&mut s), diff --git a/src/web/asset.rs b/src/web/asset.rs index 222d6d2..09e40fe 100644 --- a/src/web/asset.rs +++ b/src/web/asset.rs @@ -3,18 +3,22 @@ use std::collections::HashMap; use std::sync::LazyLock; pub(crate) static ASSETS_MAP: LazyLock> = LazyLock::new(|| { -HashMap::from([ -(r"/assets/inbound-bot-PJJg_rST.png", 0), -(r"/assets/index-B4PwGmOZ.css", 1), -(r"/assets/index-DaTxxXw1.js", 2), -(r"/assets/outbound-bot-EmsLuWRN.png", 3), -(r"/assets/text-bot-CWb_Poym.png", 4), -(r"/assets/usedByDialogNodeTextGeneration-DrFqkTqi.png", 5), -(r"/assets/usedByDialogNodeTextGeneration-thumbnail-C1iQCVQO.png", 6), -(r"/assets/usedByLlmChatNode-Bv2Fg5P7.png", 7), -(r"/assets/usedBySentenceEmbedding-Dmju1hVB.png", 8), -(r"/assets/usedBySentenceEmbedding-thumbnail-DVXz_sh0.png", 9), -(r"/favicon.ico", 10), -("/", 11), -(r"/index.html", 11), -])}); + HashMap::from([ + (r"/assets/inbound-bot-PJJg_rST.png", 0), + (r"/assets/index-B4PwGmOZ.css", 1), + (r"/assets/index-DaTxxXw1.js", 2), + (r"/assets/outbound-bot-EmsLuWRN.png", 3), + (r"/assets/text-bot-CWb_Poym.png", 4), + (r"/assets/usedByDialogNodeTextGeneration-DrFqkTqi.png", 5), + ( + r"/assets/usedByDialogNodeTextGeneration-thumbnail-C1iQCVQO.png", + 6, + ), + (r"/assets/usedByLlmChatNode-Bv2Fg5P7.png", 7), + (r"/assets/usedBySentenceEmbedding-Dmju1hVB.png", 8), + (r"/assets/usedBySentenceEmbedding-thumbnail-DVXz_sh0.png", 9), + (r"/favicon.ico", 10), + ("/", 11), + (r"/index.html", 11), + ]) +});