From 4ea8c97dfdbaf40b45cb61f82d61a2006fa653ab Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 30 Sep 2024 16:51:18 -0400 Subject: [PATCH 1/4] feat: update VectorStoreIndex trait --- rig-core/examples/vector_search.rs | 6 +- rig-core/examples/vector_search_cohere.rs | 6 +- rig-core/src/agent.rs | 26 ++-- rig-core/src/vector_store/in_memory_store.rs | 31 ++-- rig-core/src/vector_store/mod.rs | 142 ++++-------------- rig-mongodb/examples/vector_search_mongodb.rs | 4 +- rig-mongodb/src/lib.rs | 21 +-- 7 files changed, 71 insertions(+), 165 deletions(-) diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index b664c6cd..0c0f5b0e 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::Client, vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex}, }; @@ -24,10 +24,10 @@ async fn main() -> Result<(), anyhow::Error> { let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; let results = index - .top_n_from_query("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, doc)| (score, doc.id, doc.document)) + .map(|(score, id, doc)| (score, id, doc.document)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 7e9226af..144e716f 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::cohere::Client, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; @@ -29,10 +29,10 @@ async fn main() -> Result<(), anyhow::Error> { let index = vector_store.index(search_model); let results = index - .top_n_from_query("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, doc)| (score, doc.id, doc.document)) + .map(|(score, id, doc)| (score, id, doc.document)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 1607d339..376f18a1 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -172,16 +172,16 @@ impl Completion for Agent { .then(|(num_sample, index)| async { Ok::<_, VectorStoreError>( index - .top_n_from_query(prompt, *num_sample) + .top_n(prompt, *num_sample) .await? .into_iter() - .map(|(_, doc)| { + .map(|(_, id, doc)| { // Pretty print the document if possible for better readability - let doc_text = serde_json::to_string_pretty(&doc.document) - .unwrap_or_else(|_| doc.document.to_string()); + let doc_text = serde_json::to_string_pretty(&doc) + .unwrap_or_else(|_| doc.to_string()); Document { - id: doc.id, + id, text: doc_text, additional_props: HashMap::new(), } @@ -198,14 +198,14 @@ impl Completion for Agent { let dynamic_tools = stream::iter(self.dynamic_tools.iter()) .then(|(num_sample, index)| async { - Ok::<_, VectorStoreError>( - index - .top_n_ids_from_query(prompt, *num_sample) - .await? - .into_iter() - .map(|(_, doc)| doc) - .collect::>(), - ) + index + .top_n(prompt, *num_sample) + .await? + .into_iter() + .map(|(_, _, doc)| { + serde_json::to_string(&doc).map_err(VectorStoreError::JsonError) + }) + .collect::, _>>() }) .try_fold(vec![], |mut acc, docs| async { for doc in docs { diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 02c19cf8..859b4eb1 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,7 +8,7 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStore, VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder}; +use crate::embeddings::{DocumentEmbeddings, EmbeddingModel, EmbeddingsBuilder}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. @@ -198,20 +198,13 @@ impl InMemoryVectorIndex { } impl VectorStoreIndex for InMemoryVectorIndex { - async fn top_n_from_query( + async fn top_n Deserialize<'a>>( &self, query: &str, n: usize, - ) -> Result, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n).await - } + ) -> Result, VectorStoreError> { + let prompt_embedding = &self.model.embed_document(query).await?; - async fn top_n_from_embedding( - &self, - query_embedding: &Embedding, - n: usize, - ) -> Result, VectorStoreError> { // Sort documents by best embedding distance let mut docs: EmbeddingRanking = BinaryHeap::new(); @@ -222,7 +215,7 @@ impl VectorStoreIndex for InMemoryVectorI .iter() .map(|embedding| { ( - OrderedFloat(embedding.distance(query_embedding)), + OrderedFloat(embedding.distance(prompt_embedding)), &embedding.document, ) }) @@ -252,9 +245,15 @@ impl VectorStoreIndex for InMemoryVectorI ); // Return n best - Ok(docs - .into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| (distance.0, doc.clone())) - .collect()) + docs.into_iter() + .map(|Reverse(RankingItem(distance, _, doc, _))| { + let doc_value = serde_json::to_value(doc).map_err(VectorStoreError::JsonError)?; + Ok(( + distance.0, + doc.id.clone(), + serde_json::from_value(doc_value).map_err(VectorStoreError::JsonError)?, + )) + }) + .collect::, _>>() } } diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 2e6a652f..c743bd03 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -1,7 +1,8 @@ use futures::future::BoxFuture; use serde::Deserialize; +use serde_json::Value; -use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingError}; +use crate::embeddings::{DocumentEmbeddings, EmbeddingError}; pub mod in_memory_store; @@ -50,167 +51,78 @@ pub trait VectorStore: Send + Sync { /// Trait for vector store indexes pub trait VectorStoreIndex: Send + Sync { - /// Get the top n documents based on the distance to the given embedding. - /// The distance is calculated as the cosine distance between the prompt and - /// the document embedding. + /// Get the top n documents based on the distance to the given query. /// The result is a list of tuples with the distance and the document. - fn top_n_from_query( + fn top_n Deserialize<'a> + std::marker::Send>( &self, query: &str, n: usize, - ) -> impl std::future::Future, VectorStoreError>> + Send; - - /// Same as `top_n_from_query` but returns the documents without its embeddings. - /// The documents are deserialized into the given type. - fn top_n_documents_from_query Deserialize<'a>>( - &self, - query: &str, - n: usize, - ) -> impl std::future::Future, VectorStoreError>> + Send { - async move { - let documents = self.top_n_from_query(query, n).await?; - Ok(documents - .into_iter() - .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap())) - .collect()) - } - } + ) -> impl std::future::Future, VectorStoreError>> + Send; /// Same as `top_n_from_query` but returns the document ids only. - fn top_n_ids_from_query( + fn top_n_ids Deserialize<'a> + std::marker::Send>( &self, query: &str, n: usize, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self.top_n_from_query(query, n).await?; - Ok(documents - .into_iter() - .map(|(distance, doc)| (distance, doc.id)) - .collect()) - } - } - - /// Get the top n documents based on the distance to the given embedding. - /// The distance is calculated as the cosine distance between the prompt and - /// the document embedding. - /// The result is a list of tuples with the distance and the document. - fn top_n_from_embedding( - &self, - prompt_embedding: &Embedding, - n: usize, - ) -> impl std::future::Future, VectorStoreError>> + Send; - - /// Same as `top_n_from_embedding` but returns the documents without its embeddings. - /// The documents are deserialized into the given type. - fn top_n_documents_from_embedding Deserialize<'a>>( - &self, - prompt_embedding: &Embedding, - n: usize, - ) -> impl std::future::Future, VectorStoreError>> + Send { - async move { - let documents = self.top_n_from_embedding(prompt_embedding, n).await?; - Ok(documents - .into_iter() - .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap())) - .collect()) - } - } - - /// Same as `top_n_from_embedding` but returns the document ids only. - fn top_n_ids_from_embedding( - &self, - prompt_embedding: &Embedding, - n: usize, - ) -> impl std::future::Future, VectorStoreError>> + Send - { - async move { - let documents = self.top_n_from_embedding(prompt_embedding, n).await?; - Ok(documents + Ok(self + .top_n::(query, n) + .await? .into_iter() - .map(|(distance, doc)| (distance, doc.id)) + .map(|(distance, id, _)| (distance, id)) .collect()) } } } +pub type TopNResults = Result, VectorStoreError>; + pub trait VectorStoreIndexDyn: Send + Sync { - fn top_n_from_query<'a>( - &'a self, - query: &'a str, - n: usize, - ) -> BoxFuture<'a, Result, VectorStoreError>>; + fn top_n<'a>(&'a self, query: &'a str, n: usize) -> BoxFuture<'a, TopNResults>; - fn top_n_ids_from_query<'a>( + fn top_n_ids<'a>( &'a self, query: &'a str, n: usize, - ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(async move { - let documents = self.top_n_from_query(query, n).await?; - Ok(documents - .into_iter() - .map(|(distance, doc)| (distance, doc.id)) - .collect()) - }) - } - - fn top_n_from_embedding<'a>( - &'a self, - prompt_embedding: &'a Embedding, - n: usize, - ) -> BoxFuture<'a, Result, VectorStoreError>>; - - fn top_n_ids_from_embedding<'a>( - &'a self, - prompt_embedding: &'a Embedding, - n: usize, - ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(async move { - let documents = self.top_n_from_embedding(prompt_embedding, n).await?; - Ok(documents - .into_iter() - .map(|(distance, doc)| (distance, doc.id)) - .collect()) - }) - } + ) -> BoxFuture<'a, Result, VectorStoreError>>; } impl VectorStoreIndexDyn for I { - fn top_n_from_query<'a>( + fn top_n<'a>( &'a self, query: &'a str, n: usize, - ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(self.top_n_from_query(query, n)) + ) -> BoxFuture<'a, Result, VectorStoreError>> { + Box::pin(self.top_n(query, n)) } - fn top_n_from_embedding<'a>( + fn top_n_ids<'a>( &'a self, - prompt_embedding: &'a Embedding, + query: &'a str, n: usize, - ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(self.top_n_from_embedding(prompt_embedding, n)) + ) -> BoxFuture<'a, Result, VectorStoreError>> { + Box::pin(self.top_n_ids::(query, n)) } } pub struct NoIndex; impl VectorStoreIndex for NoIndex { - async fn top_n_from_query( + async fn top_n Deserialize<'a>>( &self, _query: &str, _n: usize, - ) -> Result, VectorStoreError> { + ) -> Result, VectorStoreError> { Ok(vec![]) } - async fn top_n_from_embedding( + async fn top_n_ids Deserialize<'a>>( &self, - _prompt_embedding: &Embedding, + _query: &str, _n: usize, - ) -> Result, VectorStoreError> { + ) -> Result, VectorStoreError> { Ok(vec![]) } } diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index a39e7c93..0acb6d39 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -53,10 +53,10 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = index - .top_n_from_query("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, doc)| (score, doc.id, doc.document)) + .map(|(score, id, doc)| (score, id, doc.document)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 7d85201f..596fea98 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,9 +2,10 @@ use futures::StreamExt; use mongodb::bson::doc; use rig::{ - embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}, + embeddings::{DocumentEmbeddings, EmbeddingModel}, vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, }; +use serde::Deserialize; /// A MongoDB vector store. pub struct MongoDbVectorStore { @@ -124,20 +125,13 @@ impl MongoDbVectorIndex { } impl VectorStoreIndex for MongoDbVectorIndex { - async fn top_n_from_query( + async fn top_n Deserialize<'a> + std::marker::Send>( &self, query: &str, n: usize, - ) -> Result, VectorStoreError> { + ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n).await - } - async fn top_n_from_embedding( - &self, - prompt_embedding: &Embedding, - n: usize, - ) -> Result, VectorStoreError> { let mut cursor = self .collection .aggregate( @@ -168,14 +162,15 @@ impl VectorStoreIndex for MongoDbV while let Some(doc) = cursor.next().await { let doc = doc.map_err(mongodb_to_rig_error)?; let score = doc.get("score").expect("score").as_f64().expect("f64"); - let document: DocumentEmbeddings = serde_json::from_value(doc).expect("document"); - results.push((score, document)); + let id = doc.get("_id").expect("_id").to_string(); + let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?; + results.push((score, id, doc_t)); } tracing::info!(target: "rig", "Selected documents: {}", results.iter() - .map(|(distance, doc)| format!("{} ({})", doc.id, distance)) + .map(|(distance, id, _)| format!("{} ({})", id, distance)) .collect::>() .join(", ") ); From abc565fad37676a54719388e8ec1c8743080cb59 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 1 Oct 2024 09:57:43 -0400 Subject: [PATCH 2/4] refactor: rename variable in completion trait --- rig-core/src/agent.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 376f18a1..8b9c1a23 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -177,12 +177,12 @@ impl Completion for Agent { .into_iter() .map(|(_, id, doc)| { // Pretty print the document if possible for better readability - let doc_text = serde_json::to_string_pretty(&doc) + let text = serde_json::to_string_pretty(&doc) .unwrap_or_else(|_| doc.to_string()); Document { id, - text: doc_text, + text, additional_props: HashMap::new(), } }) From 0671f56ff93b0bea291e3d285bd7769130d0786f Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 1 Oct 2024 14:19:44 -0400 Subject: [PATCH 3/4] refactor: make PR requested changes --- Cargo.lock | 19 ++++ rig-core/examples/vector_search.rs | 9 ++ rig-core/src/agent.rs | 16 +-- rig-core/src/vector_store/in_memory_store.rs | 103 +++++++++++------- rig-core/src/vector_store/mod.rs | 40 +------ rig-mongodb/Cargo.toml | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 11 +- rig-mongodb/src/lib.rs | 93 +++++++++++++--- 8 files changed, 192 insertions(+), 101 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 76e03fd9..67227bf3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -639,6 +639,20 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http", + "hyper", + "rustls", + "tokio", + "tokio-rustls", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -872,6 +886,7 @@ dependencies = [ "pbkdf2", "percent-encoding", "rand", + "reqwest", "rustc_version_runtime", "rustls", "rustls-pemfile", @@ -1174,6 +1189,7 @@ dependencies = [ "http", "http-body", "hyper", + "hyper-rustls", "hyper-tls", "ipnet", "js-sys", @@ -1183,6 +1199,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", + "rustls", "rustls-pemfile", "serde", "serde_json", @@ -1191,11 +1208,13 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", + "tokio-rustls", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", + "webpki-roots", "winreg", ] diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 0c0f5b0e..c2c97407 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -32,5 +32,14 @@ async fn main() -> Result<(), anyhow::Error> { println!("Results: {:?}", results); + let id_results = index + .top_n_ids("What is a linglingdong?", 1) + .await? + .into_iter() + .map(|(score, id)| (score, id)) + .collect::>(); + + println!("ID results: {:?}", id_results); + Ok(()) } diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 8b9c1a23..949b1114 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -198,14 +198,14 @@ impl Completion for Agent { let dynamic_tools = stream::iter(self.dynamic_tools.iter()) .then(|(num_sample, index)| async { - index - .top_n(prompt, *num_sample) - .await? - .into_iter() - .map(|(_, _, doc)| { - serde_json::to_string(&doc).map_err(VectorStoreError::JsonError) - }) - .collect::, _>>() + Ok::<_, VectorStoreError>( + index + .top_n_ids(prompt, *num_sample) + .await? + .into_iter() + .map(|(_, id)| id) + .collect::>(), + ) }) .try_fold(vec![], |mut acc, docs| async { for doc in docs { diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 859b4eb1..a5db505f 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,7 +8,7 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStore, VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{DocumentEmbeddings, EmbeddingModel, EmbeddingsBuilder}; +use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. @@ -18,6 +18,53 @@ pub struct InMemoryVectorStore { embeddings: HashMap, } +impl InMemoryVectorStore { + /// Implement vector search on InMemoryVectorStore. + /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. + fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { + // Sort documents by best embedding distance + let mut docs: EmbeddingRanking = BinaryHeap::new(); + + for (id, doc_embeddings) in self.embeddings.iter() { + // Get the best context for the document given the prompt + if let Some((distance, embed_doc)) = doc_embeddings + .embeddings + .iter() + .map(|embedding| { + ( + OrderedFloat(embedding.distance(prompt_embedding)), + &embedding.document, + ) + }) + .min_by(|a, b| a.0.cmp(&b.0)) + { + docs.push(Reverse(RankingItem( + distance, + id, + doc_embeddings, + embed_doc, + ))); + }; + + // If the heap size exceeds n, pop the least old element. + if docs.len() > n { + docs.pop(); + } + } + + // Log selected tools with their distances + tracing::info!(target: "rig", + "Selected documents: {}", + docs.iter() + .map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance)) + .collect::>() + .join(", ") + ); + + docs + } +} + /// RankingItem(distance, document_id, document, embed_doc) #[derive(Eq, PartialEq)] struct RankingItem<'a>( @@ -205,44 +252,7 @@ impl VectorStoreIndex for InMemoryVectorI ) -> Result, VectorStoreError> { let prompt_embedding = &self.model.embed_document(query).await?; - // Sort documents by best embedding distance - let mut docs: EmbeddingRanking = BinaryHeap::new(); - - for (id, doc_embeddings) in self.store.embeddings.iter() { - // Get the best context for the document given the prompt - if let Some((distance, embed_doc)) = doc_embeddings - .embeddings - .iter() - .map(|embedding| { - ( - OrderedFloat(embedding.distance(prompt_embedding)), - &embedding.document, - ) - }) - .min_by(|a, b| a.0.cmp(&b.0)) - { - docs.push(Reverse(RankingItem( - distance, - id, - doc_embeddings, - embed_doc, - ))); - }; - - // If the heap size exceeds n, pop the least old element. - if docs.len() > n { - docs.pop(); - } - } - - // Log selected tools with their distances - tracing::info!(target: "rig", - "Selected documents: {}", - docs.iter() - .map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance)) - .collect::>() - .join(", ") - ); + let docs = self.store.vector_search(prompt_embedding, n); // Return n best docs.into_iter() @@ -256,4 +266,19 @@ impl VectorStoreIndex for InMemoryVectorI }) .collect::, _>>() } + + async fn top_n_ids( + &self, + query: &str, + n: usize, + ) -> Result, VectorStoreError> { + let prompt_embedding = &self.model.embed_document(query).await?; + + let docs = self.store.vector_search(prompt_embedding, n); + + // Return n best + docs.into_iter() + .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone()))) + .collect::, _>>() + } } diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index c743bd03..b07d348a 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -52,29 +52,19 @@ pub trait VectorStore: Send + Sync { /// Trait for vector store indexes pub trait VectorStoreIndex: Send + Sync { /// Get the top n documents based on the distance to the given query. - /// The result is a list of tuples with the distance and the document. + /// The result is a list of tuples of the form (score, id, document) fn top_n Deserialize<'a> + std::marker::Send>( &self, query: &str, n: usize, ) -> impl std::future::Future, VectorStoreError>> + Send; - /// Same as `top_n_from_query` but returns the document ids only. - fn top_n_ids Deserialize<'a> + std::marker::Send>( + /// Same as `top_n` but returns the document ids only. + fn top_n_ids( &self, query: &str, n: usize, - ) -> impl std::future::Future, VectorStoreError>> + Send - { - async move { - Ok(self - .top_n::(query, n) - .await? - .into_iter() - .map(|(distance, id, _)| (distance, id)) - .collect()) - } - } + ) -> impl std::future::Future, VectorStoreError>> + Send; } pub type TopNResults = Result, VectorStoreError>; @@ -103,26 +93,6 @@ impl VectorStoreIndexDyn for I { query: &'a str, n: usize, ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(self.top_n_ids::(query, n)) - } -} - -pub struct NoIndex; - -impl VectorStoreIndex for NoIndex { - async fn top_n Deserialize<'a>>( - &self, - _query: &str, - _n: usize, - ) -> Result, VectorStoreError> { - Ok(vec![]) - } - - async fn top_n_ids Deserialize<'a>>( - &self, - _query: &str, - _n: usize, - ) -> Result, VectorStoreError> { - Ok(vec![]) + Box::pin(self.top_n_ids(query, n)) } } diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 05d259ac..e61e137b 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -11,7 +11,7 @@ repository = "https://github.com/0xPlaygrounds/rig" [dependencies] futures = "0.3.30" -mongodb = "2.8.2" +mongodb = { version = "2.8.2", features = ["aws-auth"] } rig-core = { path = "../rig-core", version = "0.1.0" } serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 0acb6d39..5bdc26e6 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -49,7 +49,7 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "context_vector_index", doc! {}); + let index = vector_store.index(model, "vector_index", doc! {}); // Query the index let results = index @@ -61,5 +61,14 @@ async fn main() -> Result<(), anyhow::Error> { println!("Results: {:?}", results); + let id_results = index + .top_n_ids("What is a linglingdong?", 1) + .await? + .into_iter() + .map(|(score, id)| (score, id)) + .collect::>(); + + println!("ID results: {:?}", id_results); + Ok(()) } diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 596fea98..9939d668 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -1,8 +1,8 @@ use futures::StreamExt; -use mongodb::bson::doc; +use mongodb::bson::{self, doc}; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingModel}, + embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}, vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; @@ -108,6 +108,33 @@ pub struct MongoDbVectorIndex { filter: mongodb::bson::Document, } +impl MongoDbVectorIndex { + /// Vector search stage of aggregation pipeline of mongoDB collection. + /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex. + fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document { + doc! { + "$vectorSearch": { + "index": &self.index_name, + "path": "embeddings.vec", + "queryVector": &prompt_embedding.vec, + "numCandidates": (n * 10) as u32, + "limit": n as u32, + "filter": &self.filter, + } + } + } + + /// Score declaration stage of aggregation pipeline of mongoDB collection. + /// /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex. + fn pipeline_score_stage(&self) -> bson::Document { + doc! { + "$addFields": { + "score": { "$meta": "vectorSearchScore" } + } + } + } +} + impl MongoDbVectorIndex { pub fn new( collection: mongodb::Collection, @@ -136,21 +163,8 @@ impl VectorStoreIndex for MongoDbV .collection .aggregate( [ - doc! { - "$vectorSearch": { - "index": &self.index_name, - "path": "embeddings.vec", - "queryVector": &prompt_embedding.vec, - "numCandidates": (n * 10) as u32, - "limit": n as u32, - "filter": &self.filter, - } - }, - doc! { - "$addFields": { - "score": { "$meta": "vectorSearchScore" } - } - }, + self.pipeline_search_stage(&prompt_embedding, n), + self.pipeline_score_stage(), ], None, ) @@ -177,4 +191,49 @@ impl VectorStoreIndex for MongoDbV Ok(results) } + + async fn top_n_ids( + &self, + query: &str, + n: usize, + ) -> Result, VectorStoreError> { + let prompt_embedding = self.model.embed_document(query).await?; + + let mut cursor = self + .collection + .aggregate( + [ + self.pipeline_search_stage(&prompt_embedding, n), + self.pipeline_score_stage(), + doc! { + "$project": { + "_id": 1, + "score": 1 + }, + }, + ], + None, + ) + .await + .map_err(mongodb_to_rig_error)? + .with_type::(); + + let mut results = Vec::new(); + while let Some(doc) = cursor.next().await { + let doc = doc.map_err(mongodb_to_rig_error)?; + let score = doc.get("score").expect("score").as_f64().expect("f64"); + let id = doc.get("_id").expect("_id").to_string(); + results.push((score, id)); + } + + tracing::info!(target: "rig", + "Selected documents: {}", + results.iter() + .map(|(distance, id)| format!("{} ({})", id, distance)) + .collect::>() + .join(", ") + ); + + Ok(results) + } } From 8cd7565416ca58aada8b9ab0deb81fd1e52f1e5d Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 1 Oct 2024 15:47:48 -0400 Subject: [PATCH 4/4] fix: remove mongodb aws-auth feature --- Cargo.lock | 19 ------------------- rig-mongodb/Cargo.toml | 2 +- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 67227bf3..76e03fd9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -639,20 +639,6 @@ dependencies = [ "want", ] -[[package]] -name = "hyper-rustls" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" -dependencies = [ - "futures-util", - "http", - "hyper", - "rustls", - "tokio", - "tokio-rustls", -] - [[package]] name = "hyper-tls" version = "0.5.0" @@ -886,7 +872,6 @@ dependencies = [ "pbkdf2", "percent-encoding", "rand", - "reqwest", "rustc_version_runtime", "rustls", "rustls-pemfile", @@ -1189,7 +1174,6 @@ dependencies = [ "http", "http-body", "hyper", - "hyper-rustls", "hyper-tls", "ipnet", "js-sys", @@ -1199,7 +1183,6 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", "rustls-pemfile", "serde", "serde_json", @@ -1208,13 +1191,11 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", - "tokio-rustls", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", "winreg", ] diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index e61e137b..05d259ac 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -11,7 +11,7 @@ repository = "https://github.com/0xPlaygrounds/rig" [dependencies] futures = "0.3.30" -mongodb = { version = "2.8.2", features = ["aws-auth"] } +mongodb = "2.8.2" rig-core = { path = "../rig-core", version = "0.1.0" } serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117"