diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs new file mode 100644 index 00000000..9a91432e --- /dev/null +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -0,0 +1,67 @@ +use std::sync::Arc; + +use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; +use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; +use rig::embeddings::DocumentEmbeddings; + +// Schema of table in LanceDB. +pub fn schema(dims: usize) -> Schema { + Schema::new(Fields::from(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("content", DataType::Utf8, false), + Field::new( + "embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float64, true)), + dims as i32, + ), + false, + ), + ])) +} + +// Convert DocumentEmbeddings objects to a RecordBatch. +pub fn as_record_batch( + records: Vec, + dims: usize, +) -> Result { + let id = StringArray::from_iter_values( + records + .iter() + .flat_map(|record| (0..record.embeddings.len()).map(|i| format!("{}-{i}", record.id))) + .collect::>(), + ); + + let content = StringArray::from_iter_values( + records + .iter() + .flat_map(|record| { + record + .embeddings + .iter() + .map(|embedding| embedding.document.clone()) + }) + .collect::>(), + ); + + let embedding = FixedSizeListArray::from_iter_primitive::( + records + .into_iter() + .flat_map(|record| { + record + .embeddings + .into_iter() + .map(|embedding| embedding.vec.into_iter().map(Some).collect::>()) + .map(Some) + .collect::>() + }) + .collect::>(), + dims as i32, + ); + + RecordBatch::try_from_iter(vec![ + ("id", Arc::new(id) as ArrayRef), + ("content", Arc::new(content) as ArrayRef), + ("embedding", Arc::new(embedding) as ArrayRef), + ]) +} diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index ce1dd612..358ead03 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,13 +1,25 @@ -use std::env; +use std::{env, sync::Arc}; +use arrow_array::RecordBatchIterator; +use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ completion::Prompt, - embeddings::EmbeddingsBuilder, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndexDyn}, + vector_store::VectorStoreIndexDyn, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; + +#[path = "./fixtures/lib.rs"] +mod fixture; + +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub content: String, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -15,15 +27,9 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - // Select the embedding model and generate our embeddings + // Select an embedding model. let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - - // Initialize LanceDB locally. - let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?; - // Generate test data for RAG demo let agent = openai_client .agent("gpt-4o") @@ -39,6 +45,7 @@ async fn main() -> Result<(), anyhow::Error> { definitions.extend(definitions.clone()); definitions.extend(definitions.clone()); + // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") @@ -47,17 +54,35 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Add embeddings to vector store - // vector_store.add_documents(embeddings).await?; + // Define search_params params that will be used by the vector store to perform the vector search. + let search_params = SearchParams::default().distance_type(DistanceType::Cosine); + + // Initialize LanceDB locally. + let db = lancedb::connect("data/lancedb-store").execute().await?; + + // Create table with embeddings. + let record_batch = as_record_batch(embeddings, model.ndims()); + let table = db + .create_table( + "definitions", + RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + ) + .execute() + .await?; + + let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information vector_store - .create_index(lancedb::index::Index::IvfPq( - IvfPqIndexBuilder::default() - // This overrides the default distance type of L2. - // Needs to be the same distance type as the one used in search params. - .distance_type(DistanceType::Cosine), - )) + .create_index( + lancedb::index::Index::IvfPq( + IvfPqIndexBuilder::default() + // This overrides the default distance type of L2. + // Needs to be the same distance type as the one used in search params. + .distance_type(DistanceType::Cosine), + ), + &["embedding"], + ) .await?; // Query the index @@ -65,8 +90,14 @@ async fn main() -> Result<(), anyhow::Error> { .top_n("My boss says I zindle too much, what does that mean?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc)) - .collect::>(); + .map(|(score, id, doc)| { + anyhow::Ok(( + score, + id, + serde_json::from_value::(doc)?, + )) + }) + .collect::, _>>()?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 7c099807..1ca2971d 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,11 +1,17 @@ -use std::env; +use std::{env, sync::Arc}; +use arrow_array::RecordBatchIterator; +use fixture::{as_record_batch, schema}; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndexDyn}, + vector_store::VectorStoreIndexDyn, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; + +#[path = "./fixtures/lib.rs"] +mod fixture; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -16,10 +22,6 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Initialize LanceDB locally. - let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model, &SearchParams::default()).await?; - let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") @@ -27,16 +29,28 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Add embeddings to vector store - // vector_store.add_documents(embeddings).await?; + // Define search_params params that will be used by the vector store to perform the vector search. + let search_params = SearchParams::default(); + + // Initialize LanceDB locally. + let db = lancedb::connect("data/lancedb-store").execute().await?; + + // Create table with embeddings. + let record_batch = as_record_batch(embeddings, model.ndims()); + let table = db + .create_table( + "definitions", + RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + ) + .execute() + .await?; + + let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; // Query the index let results = vector_store - .top_n("My boss says I zindle too much, what does that mean?", 1) - .await? - .into_iter() - .map(|(score, id, doc)| (score, id, doc)) - .collect::>(); + .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 00ba96ea..b56d9156 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,17 +1,28 @@ -use std::env; +use std::{env, sync::Arc}; +use arrow_array::RecordBatchIterator; +use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ completion::Prompt, - embeddings::EmbeddingsBuilder, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndexDyn}, + vector_store::VectorStoreIndexDyn, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; + +#[path = "./fixtures/lib.rs"] +mod fixture; + +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub content: String, +} // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. // https://lancedb.github.io/lancedb/guides/storage/ - #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). @@ -21,23 +32,13 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - - // Initialize LanceDB on S3. - // Note: see below docs for more options and IAM permission required to read/write to S3. - // https://lancedb.github.io/lancedb/guides/storage/#aws-s3 - let db = lancedb::connect("s3://lancedb-test-829666124233") - .execute() - .await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?; - // Generate test data for RAG demo let agent = openai_client .agent("gpt-4o") .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. Return ONLY the JSON string generated, nothing else.") .build(); let response = agent - .prompt("Invent at least 100 words and their definitions") + .prompt("Invent 100 words and their definitions") .await?; let mut definitions: Vec = serde_json::from_str(&response)?; @@ -46,7 +47,8 @@ async fn main() -> Result<(), anyhow::Error> { definitions.extend(definitions.clone()); definitions.extend(definitions.clone()); - let embeddings: Vec = EmbeddingsBuilder::new(model.clone()) + // Generate embeddings for the test data. + let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") @@ -54,26 +56,53 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Add embeddings to vector store - // vector_store.add_documents(embeddings).await?; + // Define search_params params that will be used by the vector store to perform the vector search. + let search_params = SearchParams::default().distance_type(DistanceType::Cosine); + + // Initialize LanceDB on S3. + // Note: see below docs for more options and IAM permission required to read/write to S3. + // https://lancedb.github.io/lancedb/guides/storage/#aws-s3 + let db = lancedb::connect("s3://lancedb-test-829666124233") + .execute() + .await?; + // Create table with embeddings. + let record_batch = as_record_batch(embeddings, model.ndims()); + let table = db + .create_table( + "definitions", + RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + ) + .execute() + .await?; + + let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information vector_store - .create_index(lancedb::index::Index::IvfPq( - IvfPqIndexBuilder::default() - // This overrides the default distance type of L2. - // Needs to be the same distance type as the one used in search params. - .distance_type(DistanceType::Cosine), - )) + .create_index( + lancedb::index::Index::IvfPq( + IvfPqIndexBuilder::default() + // This overrides the default distance type of L2. + // Needs to be the same distance type as the one used in search params. + .distance_type(DistanceType::Cosine), + ), + &["embedding"], + ) .await?; // Query the index let results = vector_store - .top_n("My boss says I zindle too much, what does that mean?", 1) + .top_n("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc)) - .collect::>(); + .map(|(score, id, doc)| { + anyhow::Ok(( + score, + id, + serde_json::from_value::(doc)?, + )) + }) + .collect::, _>>()?; println!("Results: {:?}", results); diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index a141d8dc..2b4e596d 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -1,9 +1,6 @@ -use std::sync::Arc; - use lancedb::{ - arrow::arrow_schema::{DataType, Field, Fields, Schema}, index::Index, - query::QueryBase, + query::{QueryBase, VectorQuery}, DistanceType, }; use rig::{ @@ -14,7 +11,6 @@ use serde::Deserialize; use serde_json::Value; use utils::Query; -mod table_schemas; mod utils; fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { @@ -34,6 +30,41 @@ pub struct LanceDbVectorStore { search_params: SearchParams, } +impl LanceDbVectorStore { + fn build_query(&self, mut query: VectorQuery) -> VectorQuery { + let SearchParams { + distance_type, + search_type, + nprobes, + refine_factor, + post_filter, + } = self.search_params.clone(); + + if let Some(distance_type) = distance_type { + query = query.distance_type(distance_type); + } + + if let Some(SearchType::Flat) = search_type { + query = query.bypass_vector_index(); + } + + if let Some(SearchType::Approximate) = search_type { + if let Some(nprobes) = nprobes { + query = query.nprobes(nprobes); + } + if let Some(refine_factor) = refine_factor { + query = query.refine_factor(refine_factor); + } + } + + if let Some(true) = post_filter { + query = query.postfilter(); + } + + query + } +} + /// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. #[derive(Debug, Clone)] pub enum SearchType { @@ -94,19 +125,19 @@ impl LanceDbVectorStore { pub async fn new( table: lancedb::Table, model: M, - id_field: String, + id_field: &str, search_params: SearchParams, ) -> Result { Ok(Self { table, model, - id_field, + id_field: id_field.to_string(), search_params, }) } /// Define index on document table `id` field for search optimization. - pub async fn create_document_index( + pub async fn create_index( &self, index: Index, field_names: &[impl AsRef], @@ -123,48 +154,19 @@ impl VectorStoreIndex for LanceDbV ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - let mut query = self + let query = self .table .vector_search(prompt_embedding.vec.clone()) .map_err(lancedb_to_rig_error)? .limit(n); - let SearchParams { - distance_type, - search_type, - nprobes, - refine_factor, - post_filter, - } = self.search_params.clone(); - - if let Some(distance_type) = distance_type { - query = query.distance_type(distance_type); - } - - if let Some(SearchType::Flat) = search_type { - query = query.bypass_vector_index(); - } - - if let Some(SearchType::Approximate) = search_type { - if let Some(nprobes) = nprobes { - query = query.nprobes(nprobes); - } - if let Some(refine_factor) = refine_factor { - query = query.refine_factor(refine_factor); - } - } - - if let Some(true) = post_filter { - query = query.postfilter(); - } - - query + self.build_query(query) .execute_query() .await? .into_iter() .map(|value| { Ok(( - match value.get("distance") { + match value.get("_distance") { Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(), _ => 0.0, }, @@ -183,6 +185,32 @@ impl VectorStoreIndex for LanceDbV query: &str, n: usize, ) -> Result, VectorStoreError> { - todo!() + let prompt_embedding = self.model.embed_document(query).await?; + + let query = self + .table + .query() + .select(lancedb::query::Select::Columns(vec![self.id_field.clone()])) + .nearest_to(prompt_embedding.vec.clone()) + .map_err(lancedb_to_rig_error)? + .limit(n); + + self.build_query(query) + .execute_query() + .await? + .into_iter() + .map(|value| { + Ok(( + match value.get("distance") { + Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(), + _ => 0.0, + }, + match value.get(self.id_field.clone()) { + Some(Value::String(id)) => id.to_string(), + _ => "".to_string(), + }, + )) + }) + .collect() } } diff --git a/rig-lancedb/src/table_schemas/document.rs b/rig-lancedb/src/table_schemas/document.rs deleted file mode 100644 index 384eb4bf..00000000 --- a/rig-lancedb/src/table_schemas/document.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::sync::Arc; - -use arrow_array::{types::Utf8Type, ArrayRef, RecordBatch, StringArray}; -use lancedb::arrow::arrow_schema::ArrowError; -use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; - -use crate::utils::DeserializeByteArray; - -/// Schema of `documents` table in LanceDB defined as a struct. -#[derive(Clone, Debug)] -pub struct DocumentRecord { - pub id: String, - pub document: String, -} - -/// Wrapper around `Vec` -#[derive(Debug)] -pub struct DocumentRecords(Vec); - -impl DocumentRecords { - fn new() -> Self { - Self(Vec::new()) - } - - fn records(&self) -> Vec { - self.0.clone() - } - - fn add_records(&mut self, records: Vec) { - self.0.extend(records); - } - - fn documents(&self) -> impl Iterator + '_ { - self.as_iter().map(|doc| doc.document.clone()) - } - - pub fn ids(&self) -> impl Iterator + '_ { - self.as_iter().map(|doc| doc.id.clone()) - } - - pub fn as_iter(&self) -> impl Iterator { - self.0.iter() - } -} - -/// Converts a `DocumentEmbeddings` object to a `DocumentRecord` object. -/// The `DocumentRecord` contains the correct schema required by the `documents` table. -impl TryFrom for DocumentRecord { - type Error = serde_json::Error; - - fn try_from(document: DocumentEmbeddings) -> Result { - Ok(DocumentRecord { - id: document.id, - document: serde_json::to_string(&document.document)?, - }) - } -} - -/// Converts a list of `DocumentEmbeddings` objects to a list of `DocumentRecord` objects. -/// This is useful when we need to write many `DocumentEmbeddings` items to the `documents` table at once. -impl TryFrom> for DocumentRecords { - type Error = serde_json::Error; - - fn try_from(documents: Vec) -> Result { - Ok(Self( - documents - .into_iter() - .map(DocumentRecord::try_from) - .collect::, _>>()?, - )) - } -} - -/// Convert a list of documents (`DocumentRecords`) to a `RecordBatch`, the data structure that needs ot be written to LanceDB. -/// All documents will be written to the database as part of the same batch. -impl TryFrom for RecordBatch { - type Error = ArrowError; - - fn try_from(document_records: DocumentRecords) -> Result { - let id = Arc::new(StringArray::from_iter_values(document_records.ids())) as ArrayRef; - let document = - Arc::new(StringArray::from_iter_values(document_records.documents())) as ArrayRef; - - RecordBatch::try_from_iter(vec![("id", id), ("document", document)]) - } -} - -impl From for Vec> { - fn from(documents: DocumentRecords) -> Self { - vec![RecordBatch::try_from(documents)] - } -} - -/// Convert a `RecordBatch` object, read from a lanceDb table, to a list of `DocumentRecord` objects. -/// This allows us to convert the query result to our data format. -impl TryFrom for DocumentRecords { - type Error = ArrowError; - - fn try_from(record_batch: RecordBatch) -> Result { - let binding_0 = record_batch.column(0); - let ids = binding_0.to_str::()?; - - let binding_1 = record_batch.column(1); - let documents = binding_1.to_str::()?; - - Ok(DocumentRecords( - ids.into_iter() - .zip(documents) - .map(|(id, document)| DocumentRecord { - id: id.to_string(), - document: document.to_string(), - }) - .collect(), - )) - } -} - -/// Convert a list of `RecordBatch` objects, read from a lanceDb table, to a list of `DocumentRecord` objects. -impl TryFrom> for DocumentRecords { - type Error = VectorStoreError; - - fn try_from(record_batches: Vec) -> Result { - let documents = record_batches - .into_iter() - .map(DocumentRecords::try_from) - .collect::, _>>() - .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - - Ok(documents - .into_iter() - .fold(DocumentRecords::new(), |mut acc, document| { - acc.add_records(document.records()); - acc - })) - } -} - -#[cfg(test)] -mod tests { - use arrow_array::RecordBatch; - - use crate::table_schemas::document::{DocumentRecord, DocumentRecords}; - - #[tokio::test] - async fn test_record_batch_conversion() { - let document_records = DocumentRecords(vec![ - DocumentRecord { - id: "ABC".to_string(), - document: serde_json::json!({ - "title": "Hello world", - "body": "Greetings", - }) - .to_string(), - }, - DocumentRecord { - id: "DEF".to_string(), - document: serde_json::json!({ - "title": "Sup dog", - "body": "Greetings", - }) - .to_string(), - }, - ]); - - let record_batch = RecordBatch::try_from(document_records).unwrap(); - - let deserialized_record_batch = DocumentRecords::try_from(record_batch).unwrap(); - - assert_eq!(deserialized_record_batch.0.len(), 2); - - assert_eq!(deserialized_record_batch.0[0].id, "ABC"); - assert_eq!( - deserialized_record_batch.0[0].document, - serde_json::json!({ - "title": "Hello world", - "body": "Greetings", - }) - .to_string() - ); - } -} diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs deleted file mode 100644 index 7f74dd12..00000000 --- a/rig-lancedb/src/table_schemas/embedding.rs +++ /dev/null @@ -1,299 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use arrow_array::{ - builder::{FixedSizeListBuilder, Float64Builder}, - types::{Float32Type, Float64Type, Utf8Type}, - ArrayRef, RecordBatch, StringArray, -}; -use lancedb::arrow::arrow_schema::ArrowError; -use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; - -use crate::utils::{DeserializeByteArray, DeserializeListArray, DeserializePrimitiveArray}; - -/// Data format in the LanceDB table `embeddings` -#[derive(Clone, Debug, PartialEq)] -pub struct EmbeddingRecord { - pub id: String, - pub document_id: String, - pub content: String, - pub embedding: Vec, - /// Distance from prompt. - /// This value is only present after vector search executes and determines the distance - pub distance: Option, -} - -/// Group of EmbeddingRecord objects. This represents the list of embedding objects in a `DocumentEmbeddings` object. -#[derive(Clone, Debug)] -pub struct EmbeddingRecords { - records: Vec, - dimension: i32, -} - -impl EmbeddingRecords { - fn new(records: Vec, dimension: i32) -> Self { - EmbeddingRecords { records, dimension } - } - - fn add_record(&mut self, record: EmbeddingRecord) { - self.records.push(record); - } - - pub fn as_iter(&self) -> impl Iterator { - self.records.iter() - } -} - -/// HashMap where the key is the `DocumentEmbeddings` id -/// and the value is the`EmbeddingRecords` object that corresponds to the document. -#[derive(Debug)] -pub struct EmbeddingRecordsBatch(HashMap); - -impl EmbeddingRecordsBatch { - fn as_iter(&self) -> impl Iterator { - self.0.clone().into_values().collect::>().into_iter() - } - - pub fn get_by_id(&self, id: &str) -> Option { - self.0.get(id).cloned() - } - - pub fn document_ids(&self) -> String { - self.0 - .clone() - .into_keys() - .map(|id| format!("'{id}'")) - .collect::>() - .join(",") - } -} - -/// Convert from a `DocumentEmbeddings` to an `EmbeddingRecords` object (a list of `EmbeddingRecord` objects) -impl From for EmbeddingRecords { - fn from(document: DocumentEmbeddings) -> Self { - EmbeddingRecords::new( - document - .embeddings - .clone() - .into_iter() - .enumerate() - .map(move |(i, embedding)| EmbeddingRecord { - id: format!("{}-{i}", document.id), - document_id: document.id.clone(), - content: embedding.document, - embedding: embedding.vec, - distance: None, - }) - .collect(), - document - .embeddings - .first() - .map(|embedding| embedding.vec.len() as i32) - .unwrap_or(0), - ) - } -} - -/// Convert from a list of `DocumentEmbeddings` to an `EmbeddingRecordsBatch` object -/// For each `DocumentEmbeddings`, we create an `EmbeddingRecords` and add it to the -/// hashmap with its corresponding `DocumentEmbeddings` id. -impl From> for EmbeddingRecordsBatch { - fn from(documents: Vec) -> Self { - EmbeddingRecordsBatch( - documents - .into_iter() - .fold(HashMap::new(), |mut acc, document| { - acc.insert(document.id.clone(), EmbeddingRecords::from(document)); - acc - }), - ) - } -} - -/// Convert a list of embeddings (`EmbeddingRecords`) to a `RecordBatch`, the data structure that needs ot be written to LanceDB. -/// All embeddings related to a document will be written to the database as part of the same batch. -impl TryFrom for RecordBatch { - fn try_from(embedding_records: EmbeddingRecords) -> Result { - let id = StringArray::from_iter_values( - embedding_records.as_iter().map(|record| record.id.clone()), - ); - let document_id = StringArray::from_iter_values( - embedding_records - .as_iter() - .map(|record| record.document_id.clone()), - ); - let content = StringArray::from_iter_values( - embedding_records - .as_iter() - .map(|record| record.content.clone()), - ); - - let mut builder = - FixedSizeListBuilder::new(Float64Builder::new(), embedding_records.dimension); - embedding_records.as_iter().for_each(|record| { - record - .embedding - .iter() - .for_each(|value| builder.values().append_value(*value)); - builder.append(true); - }); - - RecordBatch::try_from_iter(vec![ - ("id", Arc::new(id) as ArrayRef), - ("document_id", Arc::new(document_id) as ArrayRef), - ("content", Arc::new(content) as ArrayRef), - ("embedding", Arc::new(builder.finish()) as ArrayRef), - ]) - } - - type Error = ArrowError; -} - -impl From for Vec> { - fn from(embeddings: EmbeddingRecordsBatch) -> Self { - embeddings.as_iter().map(RecordBatch::try_from).collect() - } -} - -impl TryFrom for EmbeddingRecords { - type Error = ArrowError; - - fn try_from(record_batch: RecordBatch) -> Result { - let binding_0 = record_batch.column(0); - let ids = binding_0.to_str::()?; - - let binding_1 = record_batch.column(1); - let document_ids = binding_1.to_str::()?; - - let binding_2 = record_batch.column(2); - let contents = binding_2.to_str::()?; - - let embeddings = record_batch.column(3).to_float_list::()?; - - // There is a `_distance` field in the response if the executed query was a VectorQuery - // Otherwise, for normal queries, the `_distance` field is not present in the response. - let distances = if record_batch.num_columns() == 5 { - record_batch - .column(4) - .to_float::()? - .into_iter() - .map(Some) - .collect() - } else { - vec![None; record_batch.num_rows()] - }; - - Ok(EmbeddingRecords::new( - ids.into_iter() - .zip(document_ids) - .zip(contents) - .zip(embeddings.clone()) - .zip(distances) - .map( - |((((id, document_id), content), embedding), distance)| EmbeddingRecord { - id: id.to_string(), - document_id: document_id.to_string(), - content: content.to_string(), - embedding, - distance, - }, - ) - .collect(), - embeddings - .iter() - .map(|embedding| embedding.len() as i32) - .next() - .unwrap_or(0), - )) - } -} - -impl TryFrom> for EmbeddingRecordsBatch { - type Error = VectorStoreError; - - fn try_from(record_batches: Vec) -> Result { - let embedding_records = record_batches - .into_iter() - .map(EmbeddingRecords::try_from) - .collect::, _>>() - .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - - let grouped_records = - embedding_records - .into_iter() - .fold(HashMap::new(), |mut acc, records| { - records.as_iter().for_each(|record| { - acc.entry(record.document_id.clone()) - .and_modify(|item: &mut EmbeddingRecords| { - item.add_record(record.clone()) - }) - .or_insert(EmbeddingRecords::new( - vec![record.clone()], - record.embedding.len() as i32, - )); - }); - acc - }); - - Ok(EmbeddingRecordsBatch(grouped_records)) - } -} - -#[cfg(test)] -mod tests { - use arrow_array::RecordBatch; - - use crate::table_schemas::embedding::{EmbeddingRecord, EmbeddingRecords}; - - #[tokio::test] - async fn test_record_batch_conversion() { - let embedding_records = EmbeddingRecords::new( - vec![ - EmbeddingRecord { - id: "some_id".to_string(), - document_id: "ABC".to_string(), - content: serde_json::json!({ - "title": "Hello world", - "body": "Greetings", - }) - .to_string(), - embedding: vec![1.0, 2.0, 3.0], - distance: None, - }, - EmbeddingRecord { - id: "another_id".to_string(), - document_id: "DEF".to_string(), - content: serde_json::json!({ - "title": "Sup dog", - "body": "Greetings", - }) - .to_string(), - embedding: vec![4.0, 5.0, 6.0], - distance: None, - }, - ], - 3, - ); - - let record_batch = RecordBatch::try_from(embedding_records).unwrap(); - - let deserialized_record_batch = EmbeddingRecords::try_from(record_batch).unwrap(); - - assert_eq!(deserialized_record_batch.as_iter().count(), 2); - assert_eq!( - deserialized_record_batch.as_iter().nth(0).unwrap().clone(), - EmbeddingRecord { - id: "some_id".to_string(), - document_id: "ABC".to_string(), - content: serde_json::json!({ - "title": "Hello world", - "body": "Greetings", - }) - .to_string(), - embedding: vec![1.0, 2.0, 3.0], - distance: None - } - ); - - assert!(false) - } -} diff --git a/rig-lancedb/src/table_schemas/mod.rs b/rig-lancedb/src/table_schemas/mod.rs deleted file mode 100644 index bd24dd65..00000000 --- a/rig-lancedb/src/table_schemas/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod document; -pub mod embedding; diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index a35d7e1b..e8db4559 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -1,78 +1,12 @@ pub mod deserializer; -use std::sync::Arc; -use arrow_array::{ - types::ByteArrayType, Array, ArrowPrimitiveType, FixedSizeListArray, GenericByteArray, - PrimitiveArray, RecordBatch, RecordBatchIterator, -}; use deserializer::RecordBatchDeserializer; use futures::TryStreamExt; -use lancedb::{ - arrow::arrow_schema::{ArrowError, Schema}, - query::ExecutableQuery, -}; +use lancedb::query::ExecutableQuery; use rig::vector_store::VectorStoreError; use crate::lancedb_to_rig_error; -/// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. -pub trait DeserializePrimitiveArray { - fn to_float( - &self, - ) -> Result::Native>, ArrowError>; -} - -impl DeserializePrimitiveArray for &Arc { - fn to_float( - &self, - ) -> Result::Native>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } - } -} - -/// Trait used to "deserialize" an arrow_array::Array as as list of byte objects. -pub trait DeserializeByteArray { - fn to_str(&self) -> Result::Native>, ArrowError>; -} - -impl DeserializeByteArray for &Arc { - fn to_str(&self) -> Result::Native>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } - } -} - -/// Trait used to "deserialize" an arrow_array::Array as as list of lists of primitive objects. -pub trait DeserializeListArray { - fn to_float_list( - &self, - ) -> Result::Native>>, ArrowError>; -} - -impl DeserializeListArray for &Arc { - fn to_float_list( - &self, - ) -> Result::Native>>, ArrowError> { - match self.as_any().downcast_ref::() { - Some(list_array) => (0..list_array.len()) - .map(|j| (&list_array.value(j)).to_float::()) - .collect::, _>>(), - None => Err(ArrowError::CastError(format!( - "Can't cast column {self:?} to fixed size list array" - ))), - } - } -} - /// Trait that facilitates the conversion of columnar data returned by a lanceDb query to the desired struct. /// Used whenever a lanceDb table is queried. /// First, execute the query and get the result as a list of RecordBatches (columnar data). @@ -96,18 +30,3 @@ impl Query for lancedb::query::VectorQuery { record_batches.deserialize() } } - -/// Trait that facilitate inserting data defined as Rust structs into lanceDB table which contains columnar data. -pub trait Insert { - async fn insert(&self, data: T, schema: Schema) -> Result<(), lancedb::Error>; -} - -impl>>> Insert for lancedb::Table { - async fn insert(&self, data: T, schema: Schema) -> Result<(), lancedb::Error> { - self.add(RecordBatchIterator::new(data.into(), Arc::new(schema))) - .execute() - .await?; - - Ok(()) - } -}