Skip to content

Commit

Permalink
Refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
dialogflowchatbot committed Aug 7, 2024
1 parent ee63d19 commit 04052d9
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 64 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ futures-util = "0.3"
itoa = "1.0"
# jieba-rs = "0.6.7"
# lancedb = "0.4"
oasysdb = "0.7"
oasysdb = "0.7.3"
# once_cell = "1.19"
#ort = { version = "=2.0.0-rc.0", default-features = false }
redb = "2.1"
Expand Down
88 changes: 75 additions & 13 deletions src/db/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@ use std::sync::OnceLock;
use std::vec::Vec;

use futures_util::StreamExt;
use sqlx::{
pool::PoolOptions,
sqlite::{SqliteArguments, SqliteRow},
Sqlite, SqlitePool,
};
use oasysdb::prelude::*;
use sqlx::{pool::PoolOptions, Row, Sqlite};

use crate::result::{Error, Result};

Expand Down Expand Up @@ -39,8 +36,35 @@ macro_rules! sql_query_one (
static DATA_SOURCE: OnceLock<SqliteConnPool> = OnceLock::new();
// static DATA_SOURCES: OnceLock<Mutex<HashMap<String, SqliteConnPool>>> = OnceLock::new();

fn get_idx_db() -> Result<Database> {
let mut p = crate::db::embedding::get_sqlite_path();
p.pop();
// let dir = std::env::temp_dir();
let db = Database::open(p, Some(get_sqlite_url()?))?;
Ok(db)
}

async fn create_idx_db(robot_id: &str) -> Result<()> {
let config = SourceConfig::new(robot_id, "id", "vectors").with_metadata(vec!["intent_id"]);
let params = ParamsIVFPQ::default();
let algorithm = IndexAlgorithm::IVFPQ(params);
get_idx_db()?
.async_create_index(robot_id, algorithm, config)
.await?;
Ok(())
}

pub(crate) fn search_idx_db(robot_id: &str, search_vector: Vector) -> Result<Vec<SearchResult>> {
let r = get_idx_db()?.search_index(robot_id, search_vector, 1, "")?;
Ok(r)
}

pub(crate) fn get_sqlite_path() -> std::path::PathBuf {
Path::new(".").join("data").join("intentev").join("e.dat")
let p = Path::new(".").join("data").join("intentev");
if !p.exists() {
std::fs::create_dir(&p).expect("Create data directory failed.");
}
p.join("e.dat")
}

pub(crate) fn get_sqlite_url() -> Result<String> {
Expand Down Expand Up @@ -126,38 +150,64 @@ pub(crate) async fn create_table(robot_id: &str) -> Result<()> {
let sql = format!(
"CREATE TABLE {} (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
intent_id TEXT NOT NULL,
vectors JSON NOT NULL
);", robot_id
);",
robot_id
);
// println!("ddl = {}", ddl);
// log::info!("sql = {}", &sql);
let mut stream = sqlx::raw_sql(&sql).execute_many(DATA_SOURCE.get().unwrap());
while let Some(res) = stream.next().await {
match res {
Ok(_r) => println!("Initialized table"),
Err(e) => log::error!("{:?}", e),
Ok(_r) => log::info!("Initialized intent table"),
Err(e) => log::error!("Create table failed, err: {:?}", e),
}
}
// let dml = include_str!("../resource/sql/dml.sql");
// if let Err(e) = sqlx::query(dml).execute(&pool).await {
// panic!("{:?}", e);
// }
Ok(())
create_idx_db(robot_id).await
}

pub(crate) async fn add(robot_id: &str, intent_id: &str, vector: &Vec<f32>) -> Result<i64> {
// check_datasource(robot_id, intent_id).await?;
let sql = format!("INSERT INTO {} (intent_id,vectors)VALUES(?)", robot_id);
let sql = format!("INSERT INTO {} (intent_id,vectors)VALUES(?,?)", robot_id);
let last_insert_rowid = sqlx::query::<Sqlite>(&sql)
.bind(intent_id)
.bind(intent_id)
.bind(serde_json::to_string(vector)?)
.execute(DATA_SOURCE.get().unwrap())
.await?
.last_insert_rowid();
get_idx_db()?.async_refresh_index(robot_id).await?;
Ok(last_insert_rowid)
// Ok(0i64)
}

pub(crate) async fn batch_add(
robot_id: &str,
intent_id: &str,
vectors: &Vec<Vec<f32>>,
) -> Result<Vec<i64>> {
// check_datasource(robot_id, intent_id).await?;
let sql = format!("INSERT INTO {} (intent_id,vectors)VALUES(?,?)", robot_id);
let mut ids: Vec<i64> = Vec::with_capacity(vectors.len());
for v in vectors.iter() {
let last_insert_rowid = sqlx::query::<Sqlite>(&sql)
.bind(intent_id)
.bind(serde_json::to_string(v)?)
.execute(DATA_SOURCE.get().unwrap())
.await?
.last_insert_rowid();
ids.push(last_insert_rowid);
}
get_idx_db()?.async_refresh_index(robot_id).await?;
Ok(ids)
// Ok(0i64)
}

pub(crate) async fn remove(robot_id: &str, id: i64) -> Result<()> {
get_idx_db()?.delete_from_index(robot_id, vec![RecordID(id as u32)])?;
let sql = format!("DELETE FROM {} WHERE id=?", robot_id);
sqlx::query::<Sqlite>(&sql)
.bind(id)
Expand All @@ -167,6 +217,17 @@ pub(crate) async fn remove(robot_id: &str, id: i64) -> Result<()> {
}

pub(crate) async fn remove_by_intent_id(robot_id: &str, intent_id: &str) -> Result<()> {
let sql = format!("SELECT id FROM {} WHERE intent_id=?", robot_id);
let results = sqlx::query::<Sqlite>(&sql)
.bind(intent_id)
.fetch_all(DATA_SOURCE.get().unwrap())
.await?;
let mut ids: Vec<RecordID> = Vec::with_capacity(results.len());
for r in results.iter() {
ids.push(RecordID(r.try_get(0)?));
}
get_idx_db()?.delete_from_index(robot_id, ids)?;

let sql = format!("DELETE FROM {} WHERE intent_id=?", robot_id);
sqlx::query::<Sqlite>(&sql)
.bind(intent_id)
Expand All @@ -176,6 +237,7 @@ pub(crate) async fn remove_by_intent_id(robot_id: &str, intent_id: &str) -> Resu
}

pub(crate) async fn remove_table(robot_id: &str) -> Result<()> {
get_idx_db()?.delete_index(robot_id)?;
let sql = format!("DROP TABLE {}", robot_id);
sqlx::query::<Sqlite>(&sql)
.execute(DATA_SOURCE.get().unwrap())
Expand Down
6 changes: 3 additions & 3 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ macro_rules! db_executor (
);

pub(crate) static DB: LazyLock<Database> = LazyLock::new(|| {
let data_folder = std::path::Path::new("./data");
let data_folder = std::path::Path::new(".").join("data");
if !data_folder.exists() {
std::fs::create_dir(data_folder).expect("Create data directory failed.");
}
Expand All @@ -45,7 +45,7 @@ pub(crate) static DB: LazyLock<Database> = LazyLock::new(|| {
}
});

pub(crate) fn init() -> Result<GlobalSettings> {
pub(crate) async fn init() -> Result<GlobalSettings> {
let is_en = *server::IS_EN;

// Settings
Expand All @@ -55,7 +55,7 @@ pub(crate) fn init() -> Result<GlobalSettings> {
return Ok(settings::get_global_settings()?.unwrap());
}
let settings = settings::init_global()?;
robot::init(is_en)?;
robot::init(is_en).await?;
// 流程上下文
context::init()?;
Ok(settings)
Expand Down
27 changes: 17 additions & 10 deletions src/intent/crud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ pub(crate) fn init(robot_id: &str, is_en: bool) -> Result<()> {
]
};
let regexes: Vec<&str> = vec![];
let mut intent = Intent::new(if is_en { "Positive" } else { "肯定" });
let intent_detail = IntentDetail {
intent_idx: 0,
intent_id: intent.id.clone(),
intent_name: intent.name.clone(),
keywords: keywords.into_iter().map(|s| String::from(s)).collect(),
regexes: regexes.into_iter().map(|s| String::from(s)).collect(),
phrases: vec![],
};
let mut intent = Intent::new(if is_en { "Positive" } else { "肯定" });
intent.keyword_num = intent_detail.keywords.len();
intent.regex_num = intent_detail.regexes.len();

Expand Down Expand Up @@ -116,13 +118,15 @@ pub(crate) fn init(robot_id: &str, is_en: bool) -> Result<()> {
]
};
let regexes: Vec<&str> = vec![];
let mut intent = Intent::new(if is_en { "Negative" } else { "否定" });
let intent_detail = IntentDetail {
intent_idx: 1,
intent_id: intent.id.clone(),
intent_name: intent.name.clone(),
keywords: keywords.into_iter().map(|s| String::from(s)).collect(),
regexes: regexes.into_iter().map(|s| String::from(s)).collect(),
phrases: vec![],
};
let mut intent = Intent::new(if is_en { "Negative" } else { "否定" });
intent.keyword_num = intent_detail.keywords.len();
intent.regex_num = intent_detail.regexes.len();

Expand Down Expand Up @@ -171,6 +175,8 @@ fn add_intent(robot_id: &str, intent_name: &str) -> Result<()> {

let intent_detail = IntentDetail {
intent_idx,
intent_id: intent.id.clone(),
intent_name: intent.name.clone(),
keywords: vec![],
regexes: vec![],
phrases: vec![],
Expand Down Expand Up @@ -236,6 +242,13 @@ pub(crate) async fn remove(Json(params): Json<IntentFormData>) -> impl IntoRespo
to_res(r)
}

pub(in crate::intent) fn get_detail_by_id(
robot_id: &str,
intent_id: &str,
) -> Result<Option<IntentDetail>> {
db_executor!(db::query, robot_id, TABLE_SUFFIX, intent_id)
}

pub(crate) async fn detail(Query(params): Query<IntentFormData>) -> impl IntoResponse {
// let mut od: Option<IntentDetail> = None;
// let r = db::process_data(dbg!(params.id.as_str()), |d: &mut IntentDetail| {
Expand All @@ -244,12 +257,7 @@ pub(crate) async fn detail(Query(params): Query<IntentFormData>) -> impl IntoRes
// }).map(|_| od);
// to_res(r)
// let r: Result<Option<IntentDetail>> = db::query(TABLE, params.id.as_str());
let r: Result<Option<IntentDetail>> = db_executor!(
db::query,
&params.robot_id,
TABLE_SUFFIX,
params.id.as_str()
);
let r = get_detail_by_id(&params.robot_id, params.id.as_str());
to_res(r)
}

Expand Down Expand Up @@ -387,8 +395,7 @@ pub(crate) async fn add_phrase(
let mut d = r.unwrap();
let r = detector::save_intent_embedding(&params.robot_id, key, &params.data).await;
if r.is_err() {
// return to_res(r.map(|_| ()));
return to_res(Ok(()));
return to_res(r.map(|_| ()));
}
// let r:Result<i64> = Ok(0i64);
d.phrases.push(IntentPhraseData {
Expand Down
35 changes: 9 additions & 26 deletions src/intent/detector.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
use oasysdb::prelude::*;

use crate::ai::embedding::embedding;
use crate::db::embedding as embedding_db;
use crate::result::{Error, Result};

fn get_embedding_db() -> Result<Database> {
log::info!("1111111");
let mut p = crate::db::embedding::get_sqlite_path();
p.pop();
let dir = std::env::temp_dir();
println!("Temporary directory: {}", dir.display());
let db = Database::open(p, Some(embedding_db::get_sqlite_url()?))?;
log::info!("2222222");
Ok(db)
}

pub(crate) async fn detect(robot_id: &str, s: &str) -> Result<Option<String>> {
// let now = std::time::Instant::now();
let embedding = embedding(robot_id, s).await?;
Expand All @@ -23,7 +10,7 @@ pub(crate) async fn detect(robot_id: &str, s: &str) -> Result<Option<String>> {
// log::info!("detect embedding {}", regex.replace_all(&s, ""));
let search_vector: Vec<f32> = embedding.0.into();
let similarity_threshold = embedding.1;
let result = get_embedding_db()?.search_index(robot_id, search_vector, 1, "")?;
let result = embedding_db::search_idx_db(robot_id, search_vector.into())?;
// println!("inner intent detect {:?}", now.elapsed());
if result.len() == 0 {
if let Some(record) = result.get(0) {
Expand All @@ -32,7 +19,8 @@ pub(crate) async fn detect(robot_id: &str, s: &str) -> Result<Option<String>> {
if let Some(data) = record.data.get("intent_id") {
if let Some(metadata) = data {
if let oasysdb::types::record::DataValue::String(s) = metadata {
return Ok(Some(String::from(s)));
let intent = super::crud::get_detail_by_id(robot_id, s)?;
return Ok(intent.map(|i| i.intent_name));
}
}
}
Expand All @@ -50,7 +38,6 @@ pub(crate) async fn save_intent_embedding(robot_id: &str, intent_id: &str, s: &s
return Err(Error::ErrorWithMessage(err));
}
let id = embedding_db::add(robot_id, intent_id, &embedding.0).await?;
get_embedding_db()?.refresh_index(robot_id)?;
Ok(id)
}

Expand All @@ -59,29 +46,25 @@ pub(crate) async fn save_intent_embeddings(
intent_id: &str,
array: Vec<&str>,
) -> Result<()> {
// let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(array.len());
embedding_db::remove_by_intent_id(robot_id, intent_id).await?;
let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(array.len());
for &s in array.iter() {
let embedding = embedding(robot_id, s).await?;
if embedding.0.is_empty() {
let err = format!("{s} embedding data is empty");
log::warn!("{}", &err);
} else {
// embeddings.push(embedding.0);
embedding_db::add(robot_id, intent_id, &embedding.0).await?;
embeddings.push(embedding.0);
}
}
// if embeddings.is_empty() {
// return Err(Error::ErrorWithMessage(String::from(
// "No embeddings were generated.",
// )));
// }
embedding_db::remove_by_intent_id(robot_id, intent_id).await?;
let db = get_embedding_db()?;
db.delete_index(robot_id)?;
let config = SourceConfig::new(robot_id, "id", "vectors").with_metadata(vec!["intent_id"]);
let params = ParamsIVFPQ::default();
let algorithm = IndexAlgorithm::IVFPQ(params);
db.create_index(robot_id, algorithm, config)?;
if !embeddings.is_empty() {
embedding_db::batch_add(robot_id, intent_id, &embeddings).await?;
}
Ok(())
}

Expand Down
2 changes: 2 additions & 0 deletions src/intent/dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ pub(crate) struct IntentPhraseData {
#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct IntentDetail {
pub(crate) intent_idx: usize,
pub(crate) intent_id: String,
pub(crate) intent_name: String,
pub(crate) keywords: Vec<String>,
pub(crate) regexes: Vec<String>,
pub(crate) phrases: Vec<IntentPhraseData>,
Expand Down
Loading

0 comments on commit 04052d9

Please sign in to comment.