diff --git a/Cargo.lock b/Cargo.lock index e994818b..7b4374a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,7 +46,7 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "const-random", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy 0.7.35", @@ -773,7 +773,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "tracing", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -1086,7 +1086,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" dependencies = [ "futures-core", - "getrandom", + "getrandom 0.2.15", "instant", "pin-project-lite", "rand", @@ -1328,7 +1328,7 @@ dependencies = [ "serde_bytes", "serde_json", "time", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -1701,7 +1701,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom", + "getrandom 0.2.15", "once_cell", "tiny-keccak", ] @@ -1976,7 +1976,7 @@ dependencies = [ "tempfile", "tokio", "url", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -2068,7 +2068,7 @@ dependencies = [ "rand", "regex", "unicode-segmentation", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -3379,10 +3379,22 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", +] + [[package]] name = "gif" version = "0.13.1" @@ -4538,7 +4550,7 @@ dependencies = [ "tokio", "tracing", "url", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -4553,7 +4565,7 @@ dependencies = [ "arrow-data", "arrow-schema", "arrow-select", - "getrandom", + "getrandom 0.2.15", "half", "num-traits", "rand", @@ -4744,7 +4756,7 @@ dependencies = [ "tempfile", "tokio", "tracing", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -4851,7 +4863,7 @@ dependencies = [ "tokio", "tracing", "url", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -5272,7 +5284,7 @@ checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ "hermit-abi 0.3.9", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -5307,7 +5319,7 @@ dependencies = [ "tagptr", "thiserror 1.0.69", "triomphe", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -5354,7 +5366,7 @@ dependencies = [ "tokio-rustls 0.24.1", "tokio-util", "typed-builder", - "uuid 1.11.0", + "uuid 1.13.1", "webpki-roots 0.25.4", ] @@ -6470,19 +6482,22 @@ dependencies = [ [[package]] name = "qdrant-client" -version = "1.12.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbff72d38eac3860f5888f02d4688690de6cdc77c901112d3b0788694afc1d37" +checksum = "9b585625d1ef06478e97fe8d7170a3f32a1cba5dbf986ff136095a85a0ec3d91" dependencies = [ "anyhow", "derive_builder", + "futures", "futures-util", "prost 0.13.3", "prost-types 0.13.3", "reqwest 0.12.9", + "semver", "serde", "serde_json", "thiserror 1.0.69", + "tokio", "tonic", ] @@ -6506,7 +6521,7 @@ dependencies = [ "mach2", "once_cell", "raw-cpuid", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "web-sys", "winapi", ] @@ -6558,7 +6573,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", - "getrandom", + "getrandom 0.2.15", "rand", "ring 0.17.8", "rustc-hash 2.0.0", @@ -6627,7 +6642,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -6775,7 +6790,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", "thiserror 1.0.69", ] @@ -7091,7 +7106,7 @@ dependencies = [ "tokio-test", "tracing", "tracing-subscriber", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -7106,6 +7121,7 @@ dependencies = [ "serde_json", "testcontainers", "tokio", + "uuid 1.13.1", ] [[package]] @@ -7150,7 +7166,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", "spin 0.9.8", "untrusted 0.9.0", @@ -7600,9 +7616,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.23" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" dependencies = [ "serde", ] @@ -8127,7 +8143,7 @@ dependencies = [ "tokio-stream", "tracing", "url", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -8208,7 +8224,7 @@ dependencies = [ "stringprep", "thiserror 2.0.3", "tracing", - "uuid 1.11.0", + "uuid 1.13.1", "whoami", ] @@ -8246,7 +8262,7 @@ dependencies = [ "stringprep", "thiserror 2.0.3", "tracing", - "uuid 1.11.0", + "uuid 1.13.1", "whoami", ] @@ -8271,7 +8287,7 @@ dependencies = [ "sqlx-core", "tracing", "url", - "uuid 1.11.0", + "uuid 1.13.1", ] [[package]] @@ -8540,7 +8556,7 @@ dependencies = [ "tempfile", "thiserror 1.0.69", "time", - "uuid 1.11.0", + "uuid 1.13.1", "winapi", ] @@ -8872,7 +8888,7 @@ dependencies = [ "aho-corasick", "derive_builder", "esaxx-rs", - "getrandom", + "getrandom 0.2.15", "itertools 0.12.1", "lazy_static", "log", @@ -9454,17 +9470,17 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" dependencies = [ - "getrandom", + "getrandom 0.2.15", "serde", ] [[package]] name = "uuid" -version = "1.11.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +checksum = "ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0" dependencies = [ - "getrandom", + "getrandom 0.3.1", "serde", ] @@ -9546,6 +9562,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasite" version = "0.1.0" @@ -9926,6 +9951,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "worker" version = "0.5.0" diff --git a/rig-qdrant/Cargo.toml b/rig-qdrant/Cargo.toml index 65fec42c..1a90e604 100644 --- a/rig-qdrant/Cargo.toml +++ b/rig-qdrant/Cargo.toml @@ -11,7 +11,8 @@ repository = "https://github.com/0xPlaygrounds/rig" rig-core = { path = "../rig-core", version = "0.9.0" } serde_json = "1.0.128" serde = "1.0.210" -qdrant-client = "1.12.1" +qdrant-client = "1.13.0" +uuid = { version = "1.13.1", features = ["v4"] } [dev-dependencies] tokio = { version = "1.40.0", features = ["rt-multi-thread"] } diff --git a/rig-qdrant/examples/qdrant_vector_search.rs b/rig-qdrant/examples/qdrant_vector_search.rs index 7ce9679e..d68bd328 100644 --- a/rig-qdrant/examples/qdrant_vector_search.rs +++ b/rig-qdrant/examples/qdrant_vector_search.rs @@ -8,12 +8,10 @@ use std::env; +use anyhow::anyhow; use qdrant_client::{ - qdrant::{ - CreateCollectionBuilder, Distance, PointStruct, QueryPointsBuilder, UpsertPointsBuilder, - VectorParamsBuilder, - }, - Payload, Qdrant, + qdrant::{CreateCollectionBuilder, Distance, QueryPointsBuilder, VectorParamsBuilder}, + Qdrant, }; use rig::{ embeddings::EmbeddingsBuilder, @@ -71,25 +69,14 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let points: Vec = documents - .into_iter() - .map(|(d, embeddings)| { - let vec: Vec = embeddings.first().vec.iter().map(|&x| x as f32).collect(); - PointStruct::new( - d.id.clone(), - vec, - Payload::try_from(serde_json::to_value(&d).unwrap()).unwrap(), - ) - }) - .collect(); - - client - .upsert_points(UpsertPointsBuilder::new(COLLECTION_NAME, points)) - .await?; - let query_params = QueryPointsBuilder::new(COLLECTION_NAME).with_payload(true); let vector_store = QdrantVectorStore::new(client, model, query_params.build()); + vector_store + .insert_documents(documents) + .await + .map_err(|err| anyhow!("Couldn't insert documents: {err}"))?; + let results = vector_store .top_n::("What is a linglingdong?", 1) .await?; diff --git a/rig-qdrant/src/lib.rs b/rig-qdrant/src/lib.rs index 666d0a4f..f97bad10 100644 --- a/rig-qdrant/src/lib.rs +++ b/rig-qdrant/src/lib.rs @@ -1,12 +1,16 @@ use qdrant_client::{ - qdrant::{point_id::PointIdOptions, PointId, Query, QueryPoints}, - Qdrant, + qdrant::{ + point_id::PointIdOptions, PointId, PointStruct, Query, QueryPoints, UpsertPointsBuilder, + }, + Payload, Qdrant, }; use rig::{ - embeddings::EmbeddingModel, + embeddings::{Embedding, EmbeddingModel}, vector_store::{VectorStoreError, VectorStoreIndex}, + Embed, OneOrMany, }; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; /// Represents a vector store implementation using Qdrant - as the backend. pub struct QdrantVectorStore { @@ -34,6 +38,10 @@ impl QdrantVectorStore { } } + pub fn client(&self) -> &Qdrant { + &self.client + } + /// Embed query based on `QdrantVectorStore` model and modify the vector in the required format. async fn generate_query_vector(&self, query: &str) -> Result, VectorStoreError> { let embedding = self.model.embed_text(query).await?; @@ -47,6 +55,38 @@ impl QdrantVectorStore { params.limit = Some(limit as u64); params } + + pub async fn insert_documents( + &self, + documents: Vec<(Doc, OneOrMany)>, + ) -> Result<(), VectorStoreError> { + let collection_name = self.query_params.collection_name.clone(); + + for (document, embeddings) in documents { + let json_document = serde_json::to_value(&document).unwrap(); + let doc_as_payload = Payload::try_from(json_document).unwrap(); + + let embeddings_as_point_structs = embeddings + .into_iter() + .map(|embedding| { + let embedding_as_f32: Vec = + embedding.vec.into_iter().map(|x| x as f32).collect(); + PointStruct::new( + Uuid::new_v4().to_string(), + embedding_as_f32, + doc_as_payload.clone(), + ) + }) + .collect::>(); + + let request = UpsertPointsBuilder::new(&collection_name, embeddings_as_point_structs); + self.client.upsert_points(request).await.map_err(|err| { + VectorStoreError::DatastoreError(format!("Error while upserting: {err}").into()) + })?; + } + + Ok(()) + } } /// Converts a `PointId` to its string representation.