Skip to content

Commit

Permalink
Can now regenerate all embeddings for a specific intent
Browse files Browse the repository at this point in the history
  • Loading branch information
dialogflowchatbot committed Jul 22, 2024
1 parent 208505d commit a0e80f8
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 11 deletions.
3 changes: 1 addition & 2 deletions src/ai/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use std::vec::Vec;

use candle::{IndexOp, Tensor};
use candle_transformers::models::bert::BertModel;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tokenizers::Tokenizer;
Expand Down Expand Up @@ -88,7 +87,7 @@ fn hugging_face(robot_id: &str, info: &HuggingFaceModelInfo, s: &str) -> Result<
let outputs = m.forward(&token_ids, &token_type_ids)?;
let (_n_sentence, n_tokens, _hidden_size) = outputs.dims3()?;
let embeddings = (outputs.sum(1)? / (n_tokens as f64))?;
// let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
let r = embeddings.i(0)?.to_vec1::<f32>()?;
Ok(r)
}
Expand Down
11 changes: 10 additions & 1 deletion src/ai/huggingface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ pub(crate) fn load_llama_model_files(
info: &HuggingFaceModelInfo,
device: &Device,
) -> Result<(Llama, LlamaCache, Tokenizer, Option<u32>)> {
log::info!("load_llama_model_files start");
let tokenizer = init_tokenizer(&info.repository)?;

let config_filename = construct_model_file_path(&info.repository, "config.json");
Expand All @@ -823,7 +824,9 @@ pub(crate) fn load_llama_model_files(
let dtype = DType::F16;
let cache = LlamaCache::new(true, dtype, &config, device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, device)? };
Ok((Llama::load(vb, &config)?, cache, tokenizer, eos_token_id))
let m = Llama::load(vb, &config)?;
log::info!("load_llama_model_files end");
Ok((m, cache, tokenizer, eos_token_id))
}

pub(crate) fn load_gemma_model_files(
Expand All @@ -844,3 +847,9 @@ pub(crate) fn load_gemma_model_files(
let model = GemmaModel::new(device.is_cuda(), &config, vb)?;
Ok((model, tokenizer))
}

/// Loads PyTorch-format (`pytorch_model.bin`) weights for the given model repository
/// into a `VarBuilder` on `device` using BF16.
///
/// NOTE(review): currently a stub — the weights are mapped into a `VarBuilder`
/// but no model is constructed from it yet, so this only validates that the
/// weight file exists and parses. TODO confirm intended follow-up.
pub(crate) fn load_pytorch_mode_files(info: &HuggingFaceModelInfo, device: &Device) -> Result<()> {
    let weights_filename = construct_model_file_path(&info.repository, "pytorch_model.bin");
    // Bind as `_vb` to make the intentionally-unused builder explicit and
    // silence the unused-variable warning until a model is built from it.
    let _vb = VarBuilder::from_pth(&weights_filename, DType::BF16, device)?;
    Ok(())
}
23 changes: 21 additions & 2 deletions src/intent/crud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ pub(crate) async fn add_phrase(
if r.is_err() {
return to_res(r.map(|_| ()));
}
println!("1");
let r = r.unwrap();
if r.is_none() {
return to_res(Err(Error::ErrorWithMessage(String::from(
Expand All @@ -382,7 +381,6 @@ pub(crate) async fn add_phrase(
if r.is_err() {
return to_res(r.map(|_| ()));
}
println!("2");
d.phrases.push(IntentPhraseData {
id: r.unwrap(),
phrase: String::from(params.data.as_str()),
Expand Down Expand Up @@ -437,3 +435,24 @@ pub(crate) async fn remove_phrase(Json(params): Json<IntentFormData>) -> impl In
/// HTTP handler: runs intent detection over the submitted text for the given robot
/// and wraps the detector's result into an HTTP response.
pub(crate) async fn detect(Json(params): Json<IntentFormData>) -> impl IntoResponse {
    let detection = detector::detect(&params.robot_id, &params.data).await;
    to_res(detection)
}

/// HTTP handler: regenerates the vector embeddings for every phrase of one
/// intent.
///
/// Looks up the intent detail record keyed by `params.id`, then rebuilds all
/// phrase embeddings via `detector::save_intent_embeddings`.
///
/// NOTE(review): the detail record is fetched by `params.id` but the embeddings
/// are saved under `params.data` as the intent/collection id — presumably the
/// caller sends the collection id in `data`; verify against the frontend.
pub(crate) async fn regenerate_embeddings(
    Query(params): Query<IntentFormData>,
) -> impl IntoResponse {
    let key = params.id.as_str();
    let r: Result<Option<IntentDetail>> =
        db_executor!(db::query, &params.robot_id, TABLE_SUFFIX, key);
    // Single exhaustive match instead of is_err()/unwrap() + is_none()/unwrap().
    let d = match r {
        Err(e) => return to_res(Err::<(), _>(e)),
        Ok(None) => {
            return to_res(Err(Error::ErrorWithMessage(String::from(
                "Can NOT find intention detail",
            ))))
        }
        Ok(Some(d)) => d,
    };
    let array: Vec<&str> = d.phrases.iter().map(|v| v.phrase.as_ref()).collect();
    let r = detector::save_intent_embeddings(&params.robot_id, &params.data, array).await;
    to_res(r)
}
42 changes: 39 additions & 3 deletions src/intent/detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,11 @@ pub(crate) async fn detect(robot_id: &str, s: &str) -> Result<Option<String>> {
search_vector = Some(embedding(robot_id, s).await?.into());
}
if search_vector.is_some() {
println!("1");
let results = collection.search(search_vector.as_ref().unwrap(), 5)?;
println!("{}", results.len());
for r in results.iter() {
log::info!("r.distance={}", r.distance);
if r.distance >= 0.9 {
if r.distance <= 0.15 {
return Ok(Some(i.name.clone()));
}
}
Expand Down Expand Up @@ -100,7 +99,9 @@ pub(crate) async fn save_intent_embedding(
if is_col_not_found_err(&e) {
let mut config = Config::default();
config.distance = Distance::Cosine;
Collection::new(&config)
let mut collection = Collection::new(&config);
collection.set_dimension(embedding.len())?;
collection
} else {
return Err(e.into());
}
Expand All @@ -118,6 +119,41 @@ pub(crate) async fn save_intent_embedding(
Ok(r.to_usize())
}

/// Recomputes and persists the embedding collection for one intent.
///
/// Deletes any previously stored vectors for `intent_id`, embeds every phrase
/// in `array` (skipping, with a warning, phrases that produce an empty
/// embedding), and saves the rebuilt collection.
///
/// # Errors
/// Returns an error if deletion, embedding, collection building, or any
/// database operation fails, or if no phrase produced a usable embedding.
pub(crate) async fn save_intent_embeddings(
    robot_id: &str,
    intent_id: &str,
    array: Vec<&str>,
) -> Result<()> {
    // Rebuild from scratch: drop whatever vectors were stored before.
    delete_all_embeddings(robot_id, intent_id)?;
    let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(array.len());
    for &s in array.iter() {
        let embedding = embedding(robot_id, s).await?;
        if embedding.is_empty() {
            // Best effort: skip this phrase but keep processing the rest.
            log::warn!("{s} embedding data is empty");
        } else {
            embeddings.push(embedding);
        }
    }
    if embeddings.is_empty() {
        return Err(Error::ErrorWithMessage(String::from(
            "No embeddings were generated.",
        )));
    }
    let mut db = Database::open(&format!("{}{}", SAVING_PATH_ROOT, robot_id))?;
    let vectors: Vec<Vector> = embeddings.iter().map(|d| d.into()).collect();
    // Record payloads are unused here; only the vectors matter for detection.
    let records: Vec<Record> = vectors.iter().map(|v| Record::new(v, &"".into())).collect();
    let mut config = Config::default();
    config.distance = Distance::Cosine;
    // Propagate build failures instead of panicking (was `.unwrap()`).
    let collection = Collection::build(&config, &records)?;
    db.save_collection(intent_id, &collection)?;
    db.flush()?;
    Ok(())
}

pub(crate) fn delete_intent_embedding(robot_id: &str, intent_id: &str, id: usize) -> Result<()> {
let mut db = Database::open(&format!("{}{}", SAVING_PATH_ROOT, robot_id))?;
let mut collection = db.get_collection(intent_id)?;
Expand Down
6 changes: 3 additions & 3 deletions src/man/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl Default for GlobalSettings {
port: 12715,
select_random_port_when_conflict: false,
hf_model_download: HfModelDownload {
connect_timeout_millis: 1000,
connect_timeout_millis: 2000,
read_timeout_millis: 10000,
access_token: String::new(),
},
Expand All @@ -148,7 +148,7 @@ impl Default for Settings {
api_url: String::new(),
api_key: String::new(),
model: String::new(),
connect_timeout_millis: 200,
connect_timeout_millis: 2000,
read_timeout_millis: 10000,
max_response_token_length: 10,
},
Expand All @@ -159,7 +159,7 @@ impl Default for Settings {
api_url: String::new(),
api_key: String::new(),
model: String::new(),
connect_timeout_millis: 200,
connect_timeout_millis: 2000,
read_timeout_millis: 10000,
},
smtp_host: String::new(),
Expand Down
4 changes: 4 additions & 0 deletions src/web/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ fn gen_router() -> Router {
"/intent/phrase",
post(intent::add_phrase).delete(intent::remove_phrase),
)
.route(
"/intent/phrase/regenerate-all",
get(intent::regenerate_embeddings),
)
.route(
"/variable",
get(variable::list)
Expand Down

0 comments on commit a0e80f8

Please sign in to comment.