Append chat history to LLM prompt
dialogflowchatbot committed Aug 31, 2024
1 parent fb03763 commit a4ebea7
Showing 11 changed files with 212 additions and 106 deletions.
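
In short, the commit threads an optional `chat_history` parameter through every chat provider (HuggingFace, OpenAI, Ollama) and replaces the hand-rolled reqwest client builders with a shared HTTP-client helper. A minimal sketch of what a call site might look like after this change; the robot id, channel setup, and message texts are illustrative, not taken from this diff:

```rust
use tokio::sync::mpsc;

// Illustrative history; `Prompt` (src/ai/completion.rs) holds a role and a content string.
let history = vec![
    Prompt { role: String::from("user"), content: String::from("Hi") },
    Prompt { role: String::from("assistant"), content: String::from("Hello!") },
];
let (tx, _rx) = mpsc::channel::<String>(32);
chat(
    "my-robot",                     // robot_id (hypothetical)
    "What did I just say?",         // the current user prompt
    Some(history),                  // the new chat_history parameter
    None,                           // connect timeout: fall back to settings
    None,                           // read timeout: fall back to settings
    ResultReceiver::SseSender(&tx), // stream tokens back over SSE
)
.await?;
```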
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "dialogflow"
version = "1.16.0"
version = "1.16.1"
edition = "2021"
homepage = "https://dialogflowchatbot.github.io/"
authors = ["dialogflowchatbot <[email protected]>"]
123 changes: 80 additions & 43 deletions src/ai/chat.rs
@@ -1,18 +1,21 @@
use core::time::Duration;
use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};
use std::vec::Vec;

use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tokio::sync::mpsc::Sender;

use super::completion::Prompt;
use crate::ai::huggingface::{HuggingFaceModel, LoadedHuggingFaceModel};
use crate::man::settings;
use crate::result::{Error, Result};

static LOADED_MODELS: LazyLock<Mutex<HashMap<String, LoadedHuggingFaceModel>>> =
LazyLock::new(|| Mutex::new(HashMap::with_capacity(32)));
// static HTTP_CLIENTS: LazyLock<Mutex<HashMap<String, LoadedHuggingFaceModel>>> =
// LazyLock::new(|| Mutex::new(HashMap::with_capacity(32)));

pub(crate) enum ResultReceiver<'r> {
SseSender(&'r Sender<String>),
@@ -37,6 +40,7 @@ pub(crate) fn replace_model_cache(robot_id: &str, m: &HuggingFaceModel) -> Resul
pub(crate) async fn chat(
robot_id: &str,
prompt: &str,
chat_history: Option<Vec<Prompt>>,
connect_timeout: Option<u32>,
read_timeout: Option<u32>,
result_receiver: ResultReceiver<'_>,
@@ -49,16 +53,17 @@ pub(crate) async fn chat(
robot_id,
&m,
prompt,
chat_history,
settings.chat_provider.max_response_token_length as usize,
result_receiver,
)
.await?;
)?;
Ok(())
}
ChatProvider::OpenAI(m) => {
open_ai(
&m,
prompt,
chat_history,
connect_timeout
.unwrap_or(settings.text_generation_provider.connect_timeout_millis),
read_timeout.unwrap_or(settings.text_generation_provider.read_timeout_millis),
@@ -73,6 +78,7 @@ pub(crate) async fn chat(
&settings.text_generation_provider.api_url,
&m,
prompt,
chat_history,
connect_timeout
.unwrap_or(settings.text_generation_provider.connect_timeout_millis),
read_timeout.unwrap_or(settings.text_generation_provider.read_timeout_millis),
@@ -91,16 +97,17 @@ pub(crate) async fn chat(
}
}

async fn huggingface(
fn huggingface(
robot_id: &str,
m: &HuggingFaceModel,
prompt: &str,
chat_history: Option<Vec<Prompt>>,
sample_len: usize,
mut result_receiver: ResultReceiver<'_>,
) -> Result<()> {
let info = m.get_info();
// log::info!("model_type={:?}", &info.model_type);
let new_prompt = info.convert_prompt(prompt)?;
let new_prompt = info.convert_prompt(prompt, chat_history)?;
let mut model = LOADED_MODELS.lock().unwrap_or_else(|e| {
log::warn!("{:#?}", &e);
e.into_inner()
@@ -152,37 +159,54 @@ async fn huggingface(
async fn open_ai(
m: &str,
s: &str,
chat_history: Option<Vec<Prompt>>,
connect_timeout_millis: u32,
read_timeout_millis: u32,
proxy_url: &str,
result_receiver: ResultReceiver<'_>,
) -> Result<()> {
let mut client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(connect_timeout_millis.into()))
.read_timeout(Duration::from_millis(read_timeout_millis.into()));
if proxy_url.is_empty() {
client = client.no_proxy();
} else {
let proxy = reqwest::Proxy::http(proxy_url)?;
client = client.proxy(proxy);
}
let client = client.build()?;
let mut message0 = Map::new();
message0.insert(String::from("role"), Value::from("system"));
message0.insert(String::from("content"), Value::from("system_hint"));
let mut message1 = Map::new();
message1.insert(String::from("role"), Value::from("user"));
message1.insert(String::from("content"), Value::from(s));
let messages = Value::Array(vec![message0.into(), message1.into()]);
let mut map = Map::new();
map.insert(String::from("model"), Value::from(m));
map.insert(String::from("messages"), messages);
// let client = HTTP_CLIENTS.lock()?.entry(String::from("value")).or_insert(crate::external::http::get_client(connect_timeout_millis.into(), read_timeout_millis.into(), proxy_url)?);
let client = crate::external::http::get_client(
connect_timeout_millis.into(),
read_timeout_millis.into(),
proxy_url,
)?;
let mut req_body = Map::new();
req_body.insert(String::from("model"), Value::from(m));

let mut sys_message = Map::new();
sys_message.insert(String::from("role"), Value::from("system"));
sys_message.insert(
String::from("content"),
Value::from("You are a helpful assistant."),
);
// The system message goes first, followed by the stored history in order.
let mut messages: Vec<Value> = match chat_history {
Some(h) if !h.is_empty() => {
let mut d = Vec::with_capacity(h.len() + 2);
d.push(Value::Object(sys_message));
for p in h.into_iter() {
// Each history entry must be an OpenAI-style message object,
// i.e. {"role": ..., "content": ...}.
let mut map = Map::new();
map.insert(String::from("role"), Value::String(p.role));
map.insert(String::from("content"), Value::String(p.content));
d.push(Value::from(map));
}
d
}
_ => vec![Value::Object(sys_message)],
};
let mut user_message = Map::new();
user_message.insert(String::from("role"), Value::from("user"));
user_message.insert(String::from("content"), Value::from(s));
messages.push(Value::Object(user_message));
let messages = Value::Array(messages);
req_body.insert(String::from("messages"), messages);

let stream = match result_receiver {
ResultReceiver::SseSender(_) => true,
ResultReceiver::StrBuf(_) => false,
};
map.insert(String::from("stream"), Value::Bool(stream));
let obj = Value::Object(map);
req_body.insert(String::from("stream"), Value::Bool(stream));
let obj = Value::Object(req_body);
let req = client
.post("https://api.openai.com/v1/chat/completions")
.header("Content-Type", "application/json")
@@ -249,6 +273,7 @@ async fn ollama(
u: &str,
m: &str,
s: &str,
chat_history: Option<Vec<Prompt>>,
connect_timeout_millis: u32,
read_timeout_millis: u32,
proxy_url: &str,
@@ -268,30 +293,42 @@
if prompt.is_empty() {
return Ok(());
}
let mut client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(connect_timeout_millis.into()))
.read_timeout(Duration::from_millis(read_timeout_millis.into()));
if proxy_url.is_empty() {
client = client.no_proxy();
} else {
let proxy = reqwest::Proxy::http(proxy_url)?;
client = client.proxy(proxy);
}
let client = client.build()?;
let mut map = Map::new();
map.insert(String::from("prompt"), Value::String(prompt));
map.insert(String::from("model"), Value::String(String::from(m)));
let client = crate::external::http::get_client(
connect_timeout_millis.into(),
read_timeout_millis.into(),
proxy_url,
)?;
let mut req_body = Map::new();
req_body.insert(String::from("prompt"), Value::String(prompt));
req_body.insert(String::from("model"), Value::String(String::from(m)));
let stream = match result_receiver {
ResultReceiver::SseSender(_) => true,
ResultReceiver::StrBuf(_) => false,
};
map.insert(String::from("stream"), Value::Bool(stream));
req_body.insert(String::from("stream"), Value::Bool(stream));

let mut messages: Vec<Value> = match chat_history {
Some(h) if !h.is_empty() => {
let mut d = Vec::with_capacity(h.len() + 1);
for p in h.into_iter() {
// Ollama's chat endpoint expects {"role": ..., "content": ...}
// objects, the same shape as OpenAI's.
let mut map = Map::new();
map.insert(String::from("role"), Value::String(p.role));
map.insert(String::from("content"), Value::String(p.content));
d.push(Value::from(map));
}
d
}
_ => Vec::with_capacity(1),
};
let mut user_message = Map::new();
user_message.insert(String::from("role"), Value::from("user"));
user_message.insert(String::from("content"), Value::String(String::from(s)));
messages.push(Value::from(user_message));
req_body.insert(String::from("messages"), Value::Array(messages));

let mut num_predict = Map::new();
num_predict.insert(String::from("num_predict"), Value::from(sample_len));
req_body.insert(String::from("options"), Value::from(num_predict));

map.insert(String::from("options"), Value::from(num_predict));
let obj = Value::Object(map);
let obj = Value::Object(req_body);
let body = serde_json::to_string(&obj)?;
// log::info!("Request Ollama body {}", &body);
let req = client.post(u).body(body);
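For reference, the `messages` array assembled in `open_ai` above follows the OpenAI chat-completions shape: the system message first, then the stored history in order, then the current user turn. A sketch of the request body this produces, with an illustrative model name and history:

```rust
use serde_json::json;

// Shape of the body POSTed to /v1/chat/completions (values illustrative).
let body = json!({
    "model": "gpt-4o-mini",
    "stream": true, // true on the SSE path, false when buffering into a String
    "messages": [
        { "role": "system",    "content": "You are a helpful assistant." },
        { "role": "user",      "content": "Hi" },
        { "role": "assistant", "content": "Hello!" },
        { "role": "user",      "content": "What did I just say?" }
    ]
});
```

The Ollama request in `ollama` carries the same role/content message objects, alongside the `prompt`, `model`, `stream`, and `options.num_predict` fields shown in the diff.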
35 changes: 12 additions & 23 deletions src/ai/completion.rs
@@ -1,4 +1,3 @@
use core::time::Duration;
use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};

@@ -25,7 +24,7 @@ pub(crate) enum TextGenerationProvider {
Ollama(String),
}

#[derive(Deserialize, Serialize)]
#[derive(Clone, Deserialize, Serialize)]
pub(crate) struct Prompt {
pub(crate) role: String,
pub(crate) content: String,
@@ -184,7 +183,7 @@ async fn huggingface(
) -> Result<()> {
let info = m.get_info();
// log::info!("model_type={:?}", &info.model_type);
let new_prompt = info.convert_prompt(prompt)?;
let new_prompt = info.convert_prompt(prompt, None)?;
let mut model = LOADED_MODELS.lock().unwrap_or_else(|e| {
log::warn!("{:#?}", &e);
e.into_inner()
@@ -242,16 +241,11 @@ async fn open_ai(
proxy_url: &str,
sender: &Sender<String>,
) -> Result<()> {
let mut client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(connect_timeout_millis.into()))
.read_timeout(Duration::from_millis(read_timeout_millis.into()));
if proxy_url.is_empty() {
client = client.no_proxy();
} else {
let proxy = reqwest::Proxy::http(proxy_url)?;
client = client.proxy(proxy);
}
let client = client.build()?;
let client = crate::external::http::get_client(
connect_timeout_millis.into(),
read_timeout_millis.into(),
proxy_url,
)?;
let mut message0 = Map::new();
message0.insert(String::from("role"), Value::from("system"));
message0.insert(String::from("content"), Value::from("system_hint"));
@@ -319,16 +313,11 @@ async fn ollama(
if prompt.is_empty() {
return Ok(());
}
let mut client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(connect_timeout_millis.into()))
.read_timeout(Duration::from_millis(read_timeout_millis.into()));
if proxy_url.is_empty() {
client = client.no_proxy();
} else {
let proxy = reqwest::Proxy::http(proxy_url)?;
client = client.proxy(proxy);
}
let client = client.build()?;
let client = crate::external::http::get_client(
connect_timeout_millis.into(),
read_timeout_millis.into(),
proxy_url,
)?;
let mut map = Map::new();
map.insert(String::from("prompt"), Value::String(prompt));
map.insert(String::from("model"), Value::String(String::from(m)));
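Note that the HuggingFace path does not build JSON messages at all: it now forwards the history to `info.convert_prompt(prompt, chat_history)` (with `None` here in completion.rs, where no history applies). The body of `convert_prompt` lives in src/ai/huggingface.rs and is not part of this diff. A hypothetical sketch of the kind of template expansion such a function might perform; the actual formatting is model-specific (each model family has its own chat template), so treat every detail below as an assumption:

```rust
use crate::ai::completion::Prompt;

// Hypothetical helper: fold the history and the current turn into one text
// prompt. Real chat templates (Llama-, Gemma-, Qwen-style, etc.) differ.
fn render_prompt(prompt: &str, chat_history: Option<Vec<Prompt>>) -> String {
    let mut s = String::new();
    if let Some(history) = chat_history {
        for p in history {
            // e.g. "user: Hi\nassistant: Hello!\n"
            s.push_str(&p.role);
            s.push_str(": ");
            s.push_str(&p.content);
            s.push('\n');
        }
    }
    s.push_str("user: ");
    s.push_str(prompt);
    s.push_str("\nassistant: ");
    s
}
```

The new `Clone` derive on `Prompt` makes it cheap for callers to hand the same history to either path.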
32 changes: 10 additions & 22 deletions src/ai/embedding.rs
@@ -1,5 +1,3 @@
use core::time::Duration;

// use std::collections::VecDeque;
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
@@ -121,16 +119,11 @@ async fn open_ai(
read_timeout_millis: u32,
proxy_url: &str,
) -> Result<Vec<f32>> {
let mut client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(connect_timeout_millis.into()))
.read_timeout(Duration::from_millis(read_timeout_millis.into()));
if proxy_url.is_empty() {
client = client.no_proxy();
} else {
let proxy = reqwest::Proxy::http(proxy_url)?;
client = client.proxy(proxy);
}
let client = client.build()?;
let client = crate::external::http::get_client(
connect_timeout_millis.into(),
read_timeout_millis.into(),
proxy_url,
)?;
let mut map = Map::new();
map.insert(String::from("input"), Value::String(String::from(s)));
map.insert(String::from("model"), Value::String(String::from(m)));
@@ -174,16 +167,11 @@ async fn ollama(
read_timeout_millis: u32,
proxy_url: &str,
) -> Result<Vec<f32>> {
let mut client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(connect_timeout_millis.into()))
.read_timeout(Duration::from_millis(read_timeout_millis.into()));
if proxy_url.is_empty() {
client = client.no_proxy();
} else {
let proxy = reqwest::Proxy::http(proxy_url)?;
client = client.proxy(proxy);
}
let client = client.build()?;
let client = crate::external::http::get_client(
connect_timeout_millis.into(),
read_timeout_millis.into(),
proxy_url,
)?;
let mut map = Map::new();
map.insert(String::from("prompt"), Value::String(String::from(s)));
map.insert(String::from("model"), Value::String(String::from(m)));
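All three files above swap the same inline `reqwest::Client::builder()` block for `crate::external::http::get_client`. The helper's body is not shown anywhere in this diff; reconstructed from the code it replaces, it presumably looks roughly like this (the signature and error type are assumptions):

```rust
use core::time::Duration;

// Presumed shape of crate::external::http::get_client, inferred from the
// removed inline builder code; not shown in this commit.
pub(crate) fn get_client(
    connect_timeout_millis: u64,
    read_timeout_millis: u64,
    proxy_url: &str,
) -> crate::result::Result<reqwest::Client> {
    let mut builder = reqwest::Client::builder()
        .connect_timeout(Duration::from_millis(connect_timeout_millis))
        .read_timeout(Duration::from_millis(read_timeout_millis));
    if proxy_url.is_empty() {
        builder = builder.no_proxy();
    } else {
        builder = builder.proxy(reqwest::Proxy::http(proxy_url)?);
    }
    Ok(builder.build()?)
}
```

Centralizing client construction also opens the door to caching clients per configuration later, which the commented-out `HTTP_CLIENTS` static in src/ai/chat.rs hints at.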
(The remaining 7 of the 11 changed files are not shown here.)
