diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs
index b664c6cd..c2c97407 100644
--- a/rig-core/examples/vector_search.rs
+++ b/rig-core/examples/vector_search.rs
@@ -1,7 +1,7 @@
 use std::env;
 
 use rig::{
-    embeddings::EmbeddingsBuilder,
+    embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
     providers::openai::Client,
     vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex},
 };
@@ -24,13 +24,22 @@ async fn main() -> Result<(), anyhow::Error> {
     let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?;
 
     let results = index
-        .top_n_from_query("What is a linglingdong?", 1)
+        .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
         .await?
         .into_iter()
-        .map(|(score, doc)| (score, doc.id, doc.document))
+        .map(|(score, id, doc)| (score, id, doc.document))
         .collect::<Vec<_>>();
 
     println!("Results: {:?}", results);
 
+    let id_results = index
+        .top_n_ids("What is a linglingdong?", 1)
+        .await?
+        .into_iter()
+        .map(|(score, id)| (score, id))
+        .collect::<Vec<_>>();
+
+    println!("ID results: {:?}", id_results);
+
     Ok(())
 }
diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs
index 7e9226af..144e716f 100644
--- a/rig-core/examples/vector_search_cohere.rs
+++ b/rig-core/examples/vector_search_cohere.rs
@@ -1,7 +1,7 @@
 use std::env;
 
 use rig::{
-    embeddings::EmbeddingsBuilder,
+    embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
     providers::cohere::Client,
     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex},
 };
@@ -29,10 +29,10 @@ async fn main() -> Result<(), anyhow::Error> {
     let index = vector_store.index(search_model);
 
     let results = index
-        .top_n_from_query("What is a linglingdong?", 1)
+        .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
         .await?
         .into_iter()
-        .map(|(score, doc)| (score, doc.id, doc.document))
+        .map(|(score, id, doc)| (score, id, doc.document))
         .collect::<Vec<_>>();
 
     println!("Results: {:?}", results);
diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs
index 1607d339..949b1114 100644
--- a/rig-core/src/agent.rs
+++ b/rig-core/src/agent.rs
@@ -172,17 +172,17 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
                 .then(|(num_sample, index)| async {
                     Ok::<_, VectorStoreError>(
                         index
-                            .top_n_from_query(prompt, *num_sample)
+                            .top_n(prompt, *num_sample)
                             .await?
                             .into_iter()
-                            .map(|(_, doc)| {
+                            .map(|(_, id, doc)| {
                                 // Pretty print the document if possible for better readability
-                                let doc_text = serde_json::to_string_pretty(&doc.document)
-                                    .unwrap_or_else(|_| doc.document.to_string());
+                                let text = serde_json::to_string_pretty(&doc)
+                                    .unwrap_or_else(|_| doc.to_string());
 
                                 Document {
-                                    id: doc.id,
-                                    text: doc_text,
+                                    id,
+                                    text,
                                     additional_props: HashMap::new(),
                                 }
                             })
@@ -200,10 +200,10 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
                 .then(|(num_sample, index)| async {
                     Ok::<_, VectorStoreError>(
                         index
-                            .top_n_ids_from_query(prompt, *num_sample)
+                            .top_n_ids(prompt, *num_sample)
                             .await?
                             .into_iter()
-                            .map(|(_, doc)| doc)
+                            .map(|(_, id)| id)
                             .collect::<Vec<_>>(),
                     )
                 })
diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs
index 02c19cf8..a5db505f 100644
--- a/rig-core/src/vector_store/in_memory_store.rs
+++ b/rig-core/src/vector_store/in_memory_store.rs
@@ -18,6 +18,53 @@ pub struct InMemoryVectorStore {
     embeddings: HashMap<String, DocumentEmbeddings>,
 }
 
+impl InMemoryVectorStore {
+    /// Implement vector search on InMemoryVectorStore.
+    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore.
+    fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking {
+        // Sort documents by best embedding distance
+        let mut docs: EmbeddingRanking = BinaryHeap::new();
+
+        for (id, doc_embeddings) in self.embeddings.iter() {
+            // Get the best context for the document given the prompt
+            if let Some((distance, embed_doc)) = doc_embeddings
+                .embeddings
+                .iter()
+                .map(|embedding| {
+                    (
+                        OrderedFloat(embedding.distance(prompt_embedding)),
+                        &embedding.document,
+                    )
+                })
+                .min_by(|a, b| a.0.cmp(&b.0))
+            {
+                docs.push(Reverse(RankingItem(
+                    distance,
+                    id,
+                    doc_embeddings,
+                    embed_doc,
+                )));
+            };
+
+            // If the heap size exceeds n, pop the least old element.
+            if docs.len() > n {
+                docs.pop();
+            }
+        }
+
+        // Log selected tools with their distances
+        tracing::info!(target: "rig",
+            "Selected documents: {}",
+            docs.iter()
+                .map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance))
+                .collect::<Vec<_>>()
+                .join(", ")
+        );
+
+        docs
+    }
+}
+
 /// RankingItem(distance, document_id, document, embed_doc)
 #[derive(Eq, PartialEq)]
 struct RankingItem<'a>(
@@ -198,63 +245,40 @@ impl<M: EmbeddingModel> InMemoryVectorIndex<M> {
 }
 
 impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorIndex<M> {
-    async fn top_n_from_query(
+    async fn top_n<T: for<'a> Deserialize<'a>>(
         &self,
         query: &str,
         n: usize,
-    ) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
-        let prompt_embedding = self.model.embed_document(query).await?;
-        self.top_n_from_embedding(&prompt_embedding, n).await
+    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
+        let prompt_embedding = &self.model.embed_document(query).await?;
+
+        let docs = self.store.vector_search(prompt_embedding, n);
+
+        // Return n best
+        docs.into_iter()
+            .map(|Reverse(RankingItem(distance, _, doc, _))| {
+                let doc_value = serde_json::to_value(doc).map_err(VectorStoreError::JsonError)?;
+                Ok((
+                    distance.0,
+                    doc.id.clone(),
+                    serde_json::from_value(doc_value).map_err(VectorStoreError::JsonError)?,
+                ))
+            })
+            .collect::<Result<Vec<_>, _>>()
     }
 
-    async fn top_n_from_embedding(
+    async fn top_n_ids(
         &self,
-        query_embedding: &Embedding,
+        query: &str,
         n: usize,
-    ) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
-        // Sort documents by best embedding distance
-        let mut docs: EmbeddingRanking = BinaryHeap::new();
-
-        for (id, doc_embeddings) in self.store.embeddings.iter() {
-            // Get the best context for the document given the prompt
-            if let Some((distance, embed_doc)) = doc_embeddings
-                .embeddings
-                .iter()
-                .map(|embedding| {
-                    (
-                        OrderedFloat(embedding.distance(query_embedding)),
-                        &embedding.document,
-                    )
-                })
-                .min_by(|a, b| a.0.cmp(&b.0))
-            {
-                docs.push(Reverse(RankingItem(
-                    distance,
-                    id,
-                    doc_embeddings,
-                    embed_doc,
-                )));
-            };
-
-            // If the heap size exceeds n, pop the least old element.
-            if docs.len() > n {
-                docs.pop();
-            }
-        }
+    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
+        let prompt_embedding = &self.model.embed_document(query).await?;
 
-        // Log selected tools with their distances
-        tracing::info!(target: "rig",
-            "Selected documents: {}",
-            docs.iter()
-                .map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance))
-                .collect::<Vec<_>>()
-                .join(", ")
-        );
+        let docs = self.store.vector_search(prompt_embedding, n);
 
         // Return n best
-        Ok(docs
-            .into_iter()
-            .map(|Reverse(RankingItem(distance, _, doc, _))| (distance.0, doc.clone()))
-            .collect())
+        docs.into_iter()
+            .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone())))
+            .collect::<Result<Vec<_>, _>>()
     }
 }
diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs
index 2e6a652f..b07d348a 100644
--- a/rig-core/src/vector_store/mod.rs
+++ b/rig-core/src/vector_store/mod.rs
@@ -1,7 +1,8 @@
 use futures::future::BoxFuture;
 use serde::Deserialize;
+use serde_json::Value;
 
-use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingError};
+use crate::embeddings::{DocumentEmbeddings, EmbeddingError};
 
 pub mod in_memory_store;
 
@@ -50,167 +51,48 @@ pub trait VectorStore: Send + Sync {
 
 /// Trait for vector store indexes
 pub trait VectorStoreIndex: Send + Sync {
-    /// Get the top n documents based on the distance to the given embedding.
-    /// The distance is calculated as the cosine distance between the prompt and
-    /// the document embedding.
-    /// The result is a list of tuples with the distance and the document.
-    fn top_n_from_query(
+    /// Get the top n documents based on the distance to the given query.
+    /// The result is a list of tuples of the form (score, id, document)
+    fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>(
         &self,
         query: &str,
         n: usize,
-    ) -> impl std::future::Future<Output = Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>> + Send;
+    ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>> + Send;
 
-    /// Same as `top_n_from_query` but returns the documents without its embeddings.
-    /// The documents are deserialized into the given type.
-    fn top_n_documents_from_query<T: for<'a> Deserialize<'a>>(
+    /// Same as `top_n` but returns the document ids only.
+    fn top_n_ids(
         &self,
         query: &str,
         n: usize,
-    ) -> impl std::future::Future<Output = Result<Vec<(f64, T)>, VectorStoreError>> + Send {
-        async move {
-            let documents = self.top_n_from_query(query, n).await?;
-            Ok(documents
-                .into_iter()
-                .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap()))
-                .collect())
-        }
-    }
-
-    /// Same as `top_n_from_query` but returns the document ids only.
-    fn top_n_ids_from_query(
-        &self,
-        query: &str,
-        n: usize,
-    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send
-    {
-        async move {
-            let documents = self.top_n_from_query(query, n).await?;
-            Ok(documents
-                .into_iter()
-                .map(|(distance, doc)| (distance, doc.id))
-                .collect())
-        }
-    }
-
-    /// Get the top n documents based on the distance to the given embedding.
-    /// The distance is calculated as the cosine distance between the prompt and
-    /// the document embedding.
-    /// The result is a list of tuples with the distance and the document.
-    fn top_n_from_embedding(
-        &self,
-        prompt_embedding: &Embedding,
-        n: usize,
-    ) -> impl std::future::Future<Output = Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>> + Send;
-
-    /// Same as `top_n_from_embedding` but returns the documents without its embeddings.
-    /// The documents are deserialized into the given type.
-    fn top_n_documents_from_embedding<T: for<'a> Deserialize<'a>>(
-        &self,
-        prompt_embedding: &Embedding,
-        n: usize,
-    ) -> impl std::future::Future<Output = Result<Vec<(f64, T)>, VectorStoreError>> + Send {
-        async move {
-            let documents = self.top_n_from_embedding(prompt_embedding, n).await?;
-            Ok(documents
-                .into_iter()
-                .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap()))
-                .collect())
-        }
-    }
-
-    /// Same as `top_n_from_embedding` but returns the document ids only.
-    fn top_n_ids_from_embedding(
-        &self,
-        prompt_embedding: &Embedding,
-        n: usize,
-    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send
-    {
-        async move {
-            let documents = self.top_n_from_embedding(prompt_embedding, n).await?;
-            Ok(documents
-                .into_iter()
-                .map(|(distance, doc)| (distance, doc.id))
-                .collect())
-        }
-    }
+    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send;
 }
 
+pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
+
 pub trait VectorStoreIndexDyn: Send + Sync {
-    fn top_n_from_query<'a>(
-        &'a self,
-        query: &'a str,
-        n: usize,
-    ) -> BoxFuture<'a, Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>>;
+    fn top_n<'a>(&'a self, query: &'a str, n: usize) -> BoxFuture<'a, TopNResults>;
 
-    fn top_n_ids_from_query<'a>(
+    fn top_n_ids<'a>(
         &'a self,
         query: &'a str,
         n: usize,
-    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
-        Box::pin(async move {
-            let documents = self.top_n_from_query(query, n).await?;
-            Ok(documents
-                .into_iter()
-                .map(|(distance, doc)| (distance, doc.id))
-                .collect())
-        })
-    }
-
-    fn top_n_from_embedding<'a>(
-        &'a self,
-        prompt_embedding: &'a Embedding,
-        n: usize,
-    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
-        Box::pin(async move {
-            let documents = self.top_n_from_embedding(prompt_embedding, n).await?;
-            Ok(documents
-                .into_iter()
-                .map(|(distance, doc)| (distance, doc.id))
-                .collect())
-        })
-    }
+    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
 }
 
 impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
-    fn top_n_from_query<'a>(
+    fn top_n<'a>(
         &'a self,
         query: &'a str,
         n: usize,
-    ) -> BoxFuture<'a, Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>> {
-        Box::pin(self.top_n_from_query(query, n))
+    ) -> BoxFuture<'a, Result<Vec<(f64, String, Value)>, VectorStoreError>> {
+        Box::pin(self.top_n(query, n))
     }
 
-    fn top_n_from_embedding<'a>(
+    fn top_n_ids<'a>(
         &'a self,
-        prompt_embedding: &'a Embedding,
+        query: &'a str,
         n: usize,
-    ) -> BoxFuture<'a, Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>> {
-        Box::pin(self.top_n_from_embedding(prompt_embedding, n))
-    }
-}
-
-pub struct NoIndex;
-
-impl VectorStoreIndex for NoIndex {
-    async fn top_n_from_query(
-        &self,
-        _query: &str,
-        _n: usize,
-    ) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
-        Ok(vec![])
-    }
-
-    async fn top_n_from_embedding(
-        &self,
-        _prompt_embedding: &Embedding,
-        _n: usize,
-    ) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
-        Ok(vec![])
+    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
+        Box::pin(self.top_n_ids(query, n))
     }
 }
diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs
index a39e7c93..5bdc26e6 100644
--- a/rig-mongodb/examples/vector_search_mongodb.rs
+++ b/rig-mongodb/examples/vector_search_mongodb.rs
@@ -49,17 +49,26 @@ async fn main() -> Result<(), anyhow::Error> {
 
     // Create a vector index on our vector store
     // IMPORTANT: Reuse the same model that was used to generate the embeddings
-    let index = vector_store.index(model, "context_vector_index", doc! {});
+    let index = vector_store.index(model, "vector_index", doc! {});
 
     // Query the index
     let results = index
-        .top_n_from_query("What is a linglingdong?", 1)
+        .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
        .await?
         .into_iter()
-        .map(|(score, doc)| (score, doc.id, doc.document))
+        .map(|(score, id, doc)| (score, id, doc.document))
         .collect::<Vec<_>>();
 
     println!("Results: {:?}", results);
 
+    let id_results = index
+        .top_n_ids("What is a linglingdong?", 1)
+        .await?
+        .into_iter()
+        .map(|(score, id)| (score, id))
+        .collect::<Vec<_>>();
+
+    println!("ID results: {:?}", id_results);
+
     Ok(())
 }
diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs
index 7d85201f..9939d668 100644
--- a/rig-mongodb/src/lib.rs
+++ b/rig-mongodb/src/lib.rs
@@ -1,10 +1,11 @@
 use futures::StreamExt;
-use mongodb::bson::doc;
+use mongodb::bson::{self, doc};
 use rig::{
     embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel},
     vector_store::{VectorStore, VectorStoreError, VectorStoreIndex},
 };
+use serde::Deserialize;
 
 /// A MongoDB vector store.
 pub struct MongoDbVectorStore {
@@ -107,6 +108,33 @@ pub struct MongoDbVectorIndex<M: EmbeddingModel> {
     filter: mongodb::bson::Document,
 }
 
+impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
+    /// Vector search stage of aggregation pipeline of mongoDB collection.
+    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
+    fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document {
+        doc! {
+            "$vectorSearch": {
+                "index": &self.index_name,
+                "path": "embeddings.vec",
+                "queryVector": &prompt_embedding.vec,
+                "numCandidates": (n * 10) as u32,
+                "limit": n as u32,
+                "filter": &self.filter,
+            }
+        }
+    }
+
+    /// Score declaration stage of aggregation pipeline of mongoDB collection.
+    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
+    fn pipeline_score_stage(&self) -> bson::Document {
+        doc! {
+            "$addFields": {
+                "score": { "$meta": "vectorSearchScore" }
+            }
+        }
+    }
+}
+
 impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
     pub fn new(
         collection: mongodb::Collection<DocumentEmbeddings>,
@@ -124,38 +152,64 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
 }
 
 impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbVectorIndex<M> {
-    async fn top_n_from_query(
+    async fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>(
         &self,
         query: &str,
         n: usize,
-    ) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
+    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
         let prompt_embedding = self.model.embed_document(query).await?;
-        self.top_n_from_embedding(&prompt_embedding, n).await
+
+        let mut cursor = self
+            .collection
+            .aggregate(
+                [
+                    self.pipeline_search_stage(&prompt_embedding, n),
+                    self.pipeline_score_stage(),
+                ],
+                None,
+            )
+            .await
+            .map_err(mongodb_to_rig_error)?
+            .with_type::<serde_json::Value>();
+
+        let mut results = Vec::new();
+        while let Some(doc) = cursor.next().await {
+            let doc = doc.map_err(mongodb_to_rig_error)?;
+            let score = doc.get("score").expect("score").as_f64().expect("f64");
+            let id = doc.get("_id").expect("_id").to_string();
+            let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
+            results.push((score, id, doc_t));
+        }
+
+        tracing::info!(target: "rig",
+            "Selected documents: {}",
+            results.iter()
+                .map(|(distance, id, _)| format!("{} ({})", id, distance))
+                .collect::<Vec<_>>()
+                .join(", ")
+        );
+
+        Ok(results)
     }
 
-    async fn top_n_from_embedding(
+    async fn top_n_ids(
         &self,
-        prompt_embedding: &Embedding,
+        query: &str,
         n: usize,
-    ) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
+    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
+        let prompt_embedding = self.model.embed_document(query).await?;
+
         let mut cursor = self
             .collection
             .aggregate(
                 [
+                    self.pipeline_search_stage(&prompt_embedding, n),
+                    self.pipeline_score_stage(),
                     doc! {
-                        "$vectorSearch": {
-                            "index": &self.index_name,
-                            "path": "embeddings.vec",
-                            "queryVector": &prompt_embedding.vec,
-                            "numCandidates": (n * 10) as u32,
-                            "limit": n as u32,
-                            "filter": &self.filter,
-                        }
-                    },
-                    doc! {
-                        "$addFields": {
-                            "score": { "$meta": "vectorSearchScore" }
-                        }
+                        "$project": {
+                            "_id": 1,
+                            "score": 1
+                        },
                     },
                 ],
                 None,
@@ -168,14 +222,14 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbVectorIndex<M> {
         while let Some(doc) = cursor.next().await {
             let doc = doc.map_err(mongodb_to_rig_error)?;
             let score = doc.get("score").expect("score").as_f64().expect("f64");
-            let document: DocumentEmbeddings = serde_json::from_value(doc).expect("document");
-            results.push((score, document));
+            let id = doc.get("_id").expect("_id").to_string();
+            results.push((score, id));
         }
 
         tracing::info!(target: "rig",
             "Selected documents: {}",
             results.iter()
-                .map(|(distance, doc)| format!("{} ({})", doc.id, distance))
+                .map(|(distance, id)| format!("{} ({})", id, distance))
                 .collect::<Vec<_>>()
                 .join(", ")
         );