Merge pull request #42 from 0xPlaygrounds/refactor(vector-store)/update-vector-store-index-trait

refactor: Update VectorStoreIndex trait
marieaurore123 authored Oct 1, 2024
2 parents 4042f5d + 8cd7565 commit 04469df
Showing 7 changed files with 205 additions and 227 deletions.
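For orientation, the refactored VectorStoreIndex surface, as implied by the in_memory_store.rs hunks below, looks roughly like this. This is a reconstruction from the implementing code; the trait's actual definition lives in a file not shown in this diff:

    pub trait VectorStoreIndex {
        /// Top n documents matching the query, as (score, id, document) triples,
        /// with the document deserialized into any caller-chosen T.
        async fn top_n<T: for<'a> Deserialize<'a>>(
            &self,
            query: &str,
            n: usize,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError>;

        /// Top n matching document ids, as (score, id) pairs.
        async fn top_n_ids(
            &self,
            query: &str,
            n: usize,
        ) -> Result<Vec<(f64, String)>, VectorStoreError>;
    }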
15 changes: 12 additions & 3 deletions rig-core/examples/vector_search.rs
@@ -1,7 +1,7 @@
 use std::env;
 
 use rig::{
-    embeddings::EmbeddingsBuilder,
+    embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
     providers::openai::Client,
     vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex},
 };
@@ -24,13 +24,22 @@ async fn main() -> Result<(), anyhow::Error> {
     let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?;
 
     let results = index
-        .top_n_from_query("What is a linglingdong?", 1)
+        .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
         .await?
         .into_iter()
-        .map(|(score, doc)| (score, doc.id, doc.document))
+        .map(|(score, id, doc)| (score, id, doc.document))
         .collect::<Vec<_>>();
 
     println!("Results: {:?}", results);
 
+    let id_results = index
+        .top_n_ids("What is a linglingdong?", 1)
+        .await?
+        .into_iter()
+        .map(|(score, id)| (score, id))
+        .collect::<Vec<_>>();
+
+    println!("ID results: {:?}", id_results);
+
     Ok(())
 }
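Since top_n is now generic over any T: for<'a> Deserialize<'a>, callers are not tied to DocumentEmbeddings. A minimal sketch, continuing inside the same async fn main as above (serde_json::Value always deserializes successfully, since the store round-trips the document through serde_json):

    let raw = index
        .top_n::<serde_json::Value>("What is a linglingdong?", 1)
        .await?;

    println!("Raw results: {:?}", raw);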
6 changes: 3 additions & 3 deletions rig-core/examples/vector_search_cohere.rs
@@ -1,7 +1,7 @@
 use std::env;
 
 use rig::{
-    embeddings::EmbeddingsBuilder,
+    embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
     providers::cohere::Client,
     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex},
 };
@@ -29,10 +29,10 @@ async fn main() -> Result<(), anyhow::Error> {
     let index = vector_store.index(search_model);
 
     let results = index
-        .top_n_from_query("What is a linglingdong?", 1)
+        .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
         .await?
         .into_iter()
-        .map(|(score, doc)| (score, doc.id, doc.document))
+        .map(|(score, id, doc)| (score, id, doc.document))
         .collect::<Vec<_>>();
 
     println!("Results: {:?}", results);
16 changes: 8 additions & 8 deletions rig-core/src/agent.rs
@@ -172,17 +172,17 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
             .then(|(num_sample, index)| async {
                 Ok::<_, VectorStoreError>(
                     index
-                        .top_n_from_query(prompt, *num_sample)
+                        .top_n(prompt, *num_sample)
                         .await?
                         .into_iter()
-                        .map(|(_, doc)| {
+                        .map(|(_, id, doc)| {
                             // Pretty print the document if possible for better readability
-                            let doc_text = serde_json::to_string_pretty(&doc.document)
-                                .unwrap_or_else(|_| doc.document.to_string());
+                            let text = serde_json::to_string_pretty(&doc)
+                                .unwrap_or_else(|_| doc.to_string());
 
                             Document {
-                                id: doc.id,
-                                text: doc_text,
+                                id,
+                                text,
                                 additional_props: HashMap::new(),
                             }
                         })
@@ -200,10 +200,10 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
             .then(|(num_sample, index)| async {
                 Ok::<_, VectorStoreError>(
                     index
-                        .top_n_ids_from_query(prompt, *num_sample)
+                        .top_n_ids(prompt, *num_sample)
                         .await?
                         .into_iter()
-                        .map(|(_, doc)| doc)
+                        .map(|(_, id)| id)
                         .collect::<Vec<_>>(),
                 )
             })
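The pretty-print-or-fallback step above is self-contained enough to show in isolation. A small sketch of the same pattern (the function name and sample value are illustrative, not part of this PR):

    use serde_json::json;

    /// Prefer pretty-printed JSON; fall back to the compact Display
    /// form if serialization fails.
    fn pretty_or_plain(doc: &serde_json::Value) -> String {
        serde_json::to_string_pretty(doc).unwrap_or_else(|_| doc.to_string())
    }

    fn main() {
        let doc = json!({ "id": "doc1", "text": "a linglingdong" });
        println!("{}", pretty_or_plain(&doc));
    }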
120 changes: 72 additions & 48 deletions rig-core/src/vector_store/in_memory_store.rs
@@ -18,6 +18,53 @@ pub struct InMemoryVectorStore {
     embeddings: HashMap<String, DocumentEmbeddings>,
 }
 
+impl InMemoryVectorStore {
+    /// Vector search over this InMemoryVectorStore's documents.
+    /// Shared by the top_n and top_n_ids implementations of VectorStoreIndex for InMemoryVectorStore.
+    fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking {
+        // Rank documents by embedding distance to the prompt
+        let mut docs: EmbeddingRanking = BinaryHeap::new();
+
+        for (id, doc_embeddings) in self.embeddings.iter() {
+            // Get the best context for the document given the prompt
+            if let Some((distance, embed_doc)) = doc_embeddings
+                .embeddings
+                .iter()
+                .map(|embedding| {
+                    (
+                        OrderedFloat(embedding.distance(prompt_embedding)),
+                        &embedding.document,
+                    )
+                })
+                .min_by(|a, b| a.0.cmp(&b.0))
+            {
+                docs.push(Reverse(RankingItem(
+                    distance,
+                    id,
+                    doc_embeddings,
+                    embed_doc,
+                )));
+            };
+
+            // If the heap grows past n entries, pop the lowest-ranked one
+            if docs.len() > n {
+                docs.pop();
+            }
+        }
+
+        // Log the selected documents with their distances
+        tracing::info!(target: "rig",
+            "Selected documents: {}",
+            docs.iter()
+                .map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance))
+                .collect::<Vec<String>>()
+                .join(", ")
+        );
+
+        docs
+    }
+}
+
 /// RankingItem(distance, document_id, document, embed_doc)
 #[derive(Eq, PartialEq)]
 struct RankingItem<'a>(
@@ -198,63 +245,40 @@ impl<M: EmbeddingModel> InMemoryVectorIndex<M> {
 }
 
 impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorIndex<M> {
-    async fn top_n_from_query(
+    async fn top_n<T: for<'a> Deserialize<'a>>(
         &self,
         query: &str,
         n: usize,
-    ) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
-        let prompt_embedding = self.model.embed_document(query).await?;
-        self.top_n_from_embedding(&prompt_embedding, n).await
+    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
+        let prompt_embedding = &self.model.embed_document(query).await?;
+
+        let docs = self.store.vector_search(prompt_embedding, n);
+
+        // Return n best
+        docs.into_iter()
+            .map(|Reverse(RankingItem(distance, _, doc, _))| {
+                let doc_value = serde_json::to_value(doc).map_err(VectorStoreError::JsonError)?;
+                Ok((
+                    distance.0,
+                    doc.id.clone(),
+                    serde_json::from_value(doc_value).map_err(VectorStoreError::JsonError)?,
+                ))
+            })
+            .collect::<Result<Vec<_>, _>>()
     }
 
-    async fn top_n_from_embedding(
+    async fn top_n_ids(
         &self,
-        query_embedding: &Embedding,
+        query: &str,
         n: usize,
-    ) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
-        // Sort documents by best embedding distance
-        let mut docs: EmbeddingRanking = BinaryHeap::new();
-
-        for (id, doc_embeddings) in self.store.embeddings.iter() {
-            // Get the best context for the document given the prompt
-            if let Some((distance, embed_doc)) = doc_embeddings
-                .embeddings
-                .iter()
-                .map(|embedding| {
-                    (
-                        OrderedFloat(embedding.distance(query_embedding)),
-                        &embedding.document,
-                    )
-                })
-                .min_by(|a, b| a.0.cmp(&b.0))
-            {
-                docs.push(Reverse(RankingItem(
-                    distance,
-                    id,
-                    doc_embeddings,
-                    embed_doc,
-                )));
-            };
-
-            // If the heap size exceeds n, pop the least old element.
-            if docs.len() > n {
-                docs.pop();
-            }
-        }
+    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
+        let prompt_embedding = &self.model.embed_document(query).await?;
 
-        // Log selected tools with their distances
-        tracing::info!(target: "rig",
-            "Selected documents: {}",
-            docs.iter()
-                .map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance))
-                .collect::<Vec<String>>()
-                .join(", ")
-        );
+        let docs = self.store.vector_search(prompt_embedding, n);
 
         // Return n best
-        Ok(docs
-            .into_iter()
-            .map(|Reverse(RankingItem(distance, _, doc, _))| (distance.0, doc.clone()))
-            .collect())
+        docs.into_iter()
+            .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone())))
+            .collect::<Result<Vec<_>, _>>()
     }
 }
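The refactor funnels both trait methods through the single vector_search helper, which keeps a bounded BinaryHeap rather than sorting the whole store. The same top-n selection pattern in isolation, a sketch using plain integer scores in place of rig's OrderedFloat-based RankingItem:

    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    /// Keep the n highest-scoring entries in O(len * log n).
    fn top_n(scores: &[(String, u32)], n: usize) -> Vec<(String, u32)> {
        let mut heap: BinaryHeap<Reverse<(u32, String)>> = BinaryHeap::new();
        for (id, score) in scores {
            heap.push(Reverse((*score, id.clone())));
            // Popping removes the smallest score currently held, so after
            // the scan the heap contains exactly the n best entries seen.
            if heap.len() > n {
                heap.pop();
            }
        }
        // into_sorted_vec yields ascending Reverse order, i.e. best score first.
        heap.into_sorted_vec()
            .into_iter()
            .map(|Reverse((score, id))| (id, score))
            .collect()
    }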
