Skip to content

Commit

Permalink
refactor&fix: make PR review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
marieaurore123 committed Oct 17, 2024
1 parent 564bef4 commit 04f1f3e
Show file tree
Hide file tree
Showing 21 changed files with 244 additions and 256 deletions.
4 changes: 2 additions & 2 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"

[features]
rig_derive = ["dep:rig-derive"]
derive = ["dep:rig-derive"]

[[test]]
name = "embeddable_macro"
required-features = ["rig_derive"]
required-features = ["derive"]
2 changes: 1 addition & 1 deletion rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use rig::{
cli_chatbot::cli_chatbot,
completion::ToolDefinition,
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
vector_store::in_memory_store::InMemoryVectorStore,
EmbeddingsBuilder,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/rag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::env;
use rig::{
completion::Prompt,
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::in_memory_store::InMemoryVectorStore,
EmbeddingsBuilder,
};

#[tokio::main]
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use anyhow::Result;
use rig::{
completion::{Prompt, ToolDefinition},
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
vector_store::in_memory_store::InMemoryVectorStore,
EmbeddingsBuilder,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use std::env;

use rig::{
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
EmbeddingsBuilder,
};

#[tokio::main]
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/vector_search_cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use std::env;

use rig::{
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::cohere::{Client, EMBED_ENGLISH_V3},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
EmbeddingsBuilder,
};

#[tokio::main]
Expand Down
7 changes: 4 additions & 3 deletions rig-core/src/embeddings/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ use std::{cmp::max, collections::HashMap};
use futures::{stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};

use crate::tool::{ToolEmbedding, ToolSet, ToolType};

use super::embedding::{Embedding, EmbeddingError, EmbeddingModel};
use crate::{
embeddings::{Embedding, EmbeddingError, EmbeddingModel},
tool::{ToolEmbedding, ToolSet, ToolType},
};

/// Struct that holds a document and its embeddings.
///
Expand Down
238 changes: 16 additions & 222 deletions rig-core/src/embeddings/embeddable.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded.
//! //! # Example
//! # Example
//! ```rust
//! use std::env;
//!
Expand All @@ -8,7 +8,7 @@
//! struct FakeDefinition {
//! id: String,
//! word: String,
//! definitions: Vec<String>,
//! definition: String,
//! }
//!
//! let fake_definition = FakeDefinition {
Expand All @@ -28,151 +28,26 @@
//! }
//! ```
use crate::vec_utils::OneOrMany;

/// Error type used for when the `embeddable` method fails.
/// Used by default implementations of `Embeddable` for common types.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddableError {
#[error("SerdeError: {0}")]
SerdeError(#[from] serde_json::Error),
}
#[error("{0}")]
pub struct EmbeddableError(#[from] Box<dyn std::error::Error + Send + Sync>);

/// Trait for types that can be embedded.
/// The `embeddable` method returns a OneOrMany<String> which contains strings for which embeddings will be generated by the embeddings builder.
/// The `embeddable` method returns a `OneOrMany<String>` which contains strings for which embeddings will be generated by the embeddings builder.
/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`.
pub trait Embeddable {
type Error: std::error::Error;
type Error: std::error::Error + Sync + Send + 'static;

fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error>;
}

/// Struct containing either a single item or a list of items of type T.
/// If a single item is present, `first` will contain it and `rest` will be empty.
/// If multiple items are present, `first` will contain the first item and `rest` will contain the rest.
/// IMPORTANT: this struct cannot be created with an empty vector.
/// OneOrMany objects can only be created using OneOrMany::one() or OneOrMany::many().
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct OneOrMany<T> {
/// First item in the list.
first: T,
/// Rest of the items in the list.
rest: Vec<T>,
}

impl<T: Clone> OneOrMany<T> {
/// Get the first item in the list.
pub fn first(&self) -> T {
self.first.clone()
}

/// Get the rest of the items in the list (excluding the first one).
pub fn rest(&self) -> Vec<T> {
self.rest.clone()
}

/// Create a OneOrMany object with a single item of any type.
pub fn one(item: T) -> Self {
OneOrMany {
first: item,
rest: vec![],
}
}

/// Create a OneOrMany object with a vector of items of any type.
pub fn many(items: Vec<T>) -> Self {
let mut iter = items.into_iter();
OneOrMany {
first: match iter.next() {
Some(item) => item,
None => panic!("Cannot create OneOrMany with an empty vector."),
},
rest: iter.collect(),
}
}

/// Use the Iterator trait on OneOrMany
pub fn iter(&self) -> OneOrManyIterator<T> {
OneOrManyIterator {
one_or_many: self,
index: 0,
}
}
}

/// Implement Iterator for OneOrMany.
/// Iterates over all items in both `first` and `rest`.
/// Borrows the OneOrMany object that is being iterator over.
pub struct OneOrManyIterator<'a, T> {
one_or_many: &'a OneOrMany<T>,
index: usize,
}

impl<'a, T> Iterator for OneOrManyIterator<'a, T> {
type Item = &'a T;

fn next(&mut self) -> Option<Self::Item> {
let mut item = None;
if self.index == 0 {
item = Some(&self.one_or_many.first)
} else if self.index - 1 < self.one_or_many.rest.len() {
item = Some(&self.one_or_many.rest[self.index - 1]);
};

self.index += 1;
item
}
}

/// Implement IntoIterator for OneOrMany.
/// Iterates over all items in both `first` and `rest`.
/// Takes ownership the OneOrMany object that is being iterator over.
pub struct OneOrManyIntoIterator<T> {
one_or_many: OneOrMany<T>,
index: usize,
}

impl<T: Clone> IntoIterator for OneOrMany<T> {
type Item = T;
type IntoIter = OneOrManyIntoIterator<T>;

fn into_iter(self) -> OneOrManyIntoIterator<T> {
OneOrManyIntoIterator {
one_or_many: self,
index: 0,
}
}
}

impl<T: Clone> Iterator for OneOrManyIntoIterator<T> {
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
let mut item = None;
if self.index == 0 {
item = Some(self.one_or_many.first())
} else if self.index - 1 < self.one_or_many.rest.len() {
item = Some(self.one_or_many.rest[self.index - 1].clone());
};

self.index += 1;
item
}
}

/// Merge a list of OneOrMany items into a single OneOrMany item.
impl<T: Clone> From<Vec<OneOrMany<T>>> for OneOrMany<T> {
fn from(value: Vec<OneOrMany<T>>) -> Self {
let items = value
.into_iter()
.flat_map(|one_or_many| one_or_many.into_iter())
.collect::<Vec<_>>();

OneOrMany::many(items)
}
}

//////////////////////////////////////////////////////
/// Implementations of Embeddable for common types ///
//////////////////////////////////////////////////////
// ================================================================
// Implementations of Embeddable for common types
// ================================================================
impl Embeddable for String {
type Error = EmbeddableError;

Expand Down Expand Up @@ -258,102 +133,21 @@ impl Embeddable for serde_json::Value {

fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> {
Ok(OneOrMany::one(
serde_json::to_string(self).map_err(EmbeddableError::SerdeError)?,
serde_json::to_string(self).map_err(|e| EmbeddableError(Box::new(e)))?,
))
}
}

impl<T: Embeddable> Embeddable for Vec<T> {
type Error = T::Error;
type Error = EmbeddableError;

fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> {
let items = self
.iter()
.map(|item| item.embeddable())
.collect::<Result<Vec<_>, _>>()?;

Ok(OneOrMany::from(items))
}
}

#[cfg(test)]
mod test {
use super::OneOrMany;

#[test]
fn test_one_or_many_iter_single() {
let one_or_many = OneOrMany::one("hello".to_string());

assert_eq!(one_or_many.iter().count(), 1);

one_or_many.iter().for_each(|i| {
assert_eq!(i, "hello");
});
}

#[test]
fn test_one_or_many_iter() {
let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]);

assert_eq!(one_or_many.iter().count(), 2);

one_or_many.iter().enumerate().for_each(|(i, item)| {
if i == 0 {
assert_eq!(item, "hello");
}
if i == 1 {
assert_eq!(item, "word");
}
});
}

#[test]
fn test_one_or_many_into_iter_single() {
let one_or_many = OneOrMany::one("hello".to_string());

assert_eq!(one_or_many.clone().into_iter().count(), 1);

one_or_many.into_iter().for_each(|i| {
assert_eq!(i, "hello".to_string());
});
}

#[test]
fn test_one_or_many_into_iter() {
let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]);

assert_eq!(one_or_many.clone().into_iter().count(), 2);

one_or_many.into_iter().enumerate().for_each(|(i, item)| {
if i == 0 {
assert_eq!(item, "hello".to_string());
}
if i == 1 {
assert_eq!(item, "word".to_string());
}
});
}

#[test]
fn test_one_or_many_merge() {
let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]);

let one_or_many_2 = OneOrMany::one("sup".to_string());

let merged = OneOrMany::from(vec![one_or_many_1, one_or_many_2]);

assert_eq!(merged.iter().count(), 3);
.collect::<Result<Vec<_>, _>>()
.map_err(|e| EmbeddableError(Box::new(e)))?;

merged.iter().enumerate().for_each(|(i, item)| {
if i == 0 {
assert_eq!(item, "hello");
}
if i == 1 {
assert_eq!(item, "word");
}
if i == 2 {
assert_eq!(item, "sup");
}
});
OneOrMany::merge(items).map_err(|e| EmbeddableError(Box::new(e)))
}
}
2 changes: 1 addition & 1 deletion rig-core/src/embeddings/embedding.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! The module defines the [EmbeddingModel] trait, which represents an embedding model that can
//! generate embeddings for documents. It also provides an implementation of the [EmbeddingsBuilder]
//! generate embeddings for documents. It also provides an implementation of the [embeddings::EmbeddingsBuilder]
//! struct, which allows users to build collections of document embeddings using different embedding
//! models and document sources.
//!
Expand Down
4 changes: 4 additions & 0 deletions rig-core/src/embeddings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@
pub mod builder;
pub mod embeddable;
pub mod embedding;

pub use builder::EmbeddingsBuilder;
pub use embeddable::Embeddable;
pub use embedding::{Embedding, EmbeddingError, EmbeddingModel};
4 changes: 2 additions & 2 deletions rig-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ pub mod extractor;
pub mod json_utils;
pub mod providers;
pub mod tool;
mod vec_utils;
pub mod vector_store;

// Re-export commonly used types and traits
pub use embeddings::builder::EmbeddingsBuilder;
pub use embeddings::embeddable::Embeddable;

#[cfg(feature = "rig_derive")]
#[cfg(feature = "derive")]
pub use rig_derive::Embeddable;
Loading

0 comments on commit 04f1f3e

Please sign in to comment.