Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(loaders): loaders for files and pdfs #55

Merged
merged 15 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
907 changes: 540 additions & 367 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ futures = "0.3.29"
ordered-float = "4.2.0"
schemars = "0.8.16"
thiserror = "1.0.61"
glob = "0.3.1"
lopdf = { version = "0.34.0", optional = true }

[dev-dependencies]
anyhow = "1.0.75"
assert_fs = "1.1.2"
tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"

[features]
pdf = ["dep:lopdf"]
38 changes: 38 additions & 0 deletions rig-core/examples/agent_with_loaders.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::env;

use rig::{
agent::AgentBuilder,
completion::Prompt,
loaders::FileLoader,
providers::openai::{self, GPT_4O},
};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
let openai_client =
openai::Client::new(&env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"));

let model = openai_client.completion_model(GPT_4O);

// Load in all the rust examples
let examples = FileLoader::with_glob("rig-core/examples/*.rs")?
.read_with_path()
.ignore_errors()
.into_iter();

// Create an agent with multiple context documents
let agent = examples
.fold(AgentBuilder::new(model), |builder, (path, content)| {
builder.context(format!("Rust Example {:?}:\n{}", path, content).as_str())
})
.build();

// Prompt the agent and print the response
let response = agent
.prompt("Which rust example is best suited for the operation 1 + 2")
.await?;

println!("{}", response);

Ok(())
}
14 changes: 14 additions & 0 deletions rig-core/examples/loaders.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use rig::loaders::FileLoader;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
FileLoader::with_glob("cargo.toml")?
.read()
.into_iter()
.for_each(|result| match result {
Ok(content) => println!("{}", content),
Err(e) => eprintln!("Error reading file: {}", e),
});

Ok(())
}
1 change: 1 addition & 0 deletions rig-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub mod completion;
pub mod embeddings;
pub mod extractor;
pub(crate) mod json_utils;
pub mod loaders;
pub mod providers;
pub mod tool;
pub mod vector_store;
260 changes: 260 additions & 0 deletions rig-core/src/loaders/file.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
use std::{fs, path::PathBuf};

use glob::glob;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum FileLoaderError {
#[error("Invalid glob pattern: {0}")]
InvalidGlobPattern(String),

#[error("IO error: {0}")]
IoError(#[from] std::io::Error),

#[error("Pattern error: {0}")]
PatternError(#[from] glob::PatternError),

#[error("Glob error: {0}")]
GlobError(#[from] glob::GlobError),
}

// Implementing Readable trait for reading file contents
pub(crate) trait Readable {
fn read(self) -> Result<String, FileLoaderError>;
fn read_with_path(self) -> Result<(PathBuf, String), FileLoaderError>;
}

impl<'a> FileLoader<'a, PathBuf> {
pub fn read(self) -> FileLoader<'a, Result<String, FileLoaderError>> {
FileLoader {
iterator: Box::new(self.iterator.map(|res| res.read())),
}
}
pub fn read_with_path(self) -> FileLoader<'a, Result<(PathBuf, String), FileLoaderError>> {
FileLoader {
iterator: Box::new(self.iterator.map(|res| res.read_with_path())),
}
}
}

impl Readable for PathBuf {
fn read(self) -> Result<String, FileLoaderError> {
fs::read_to_string(self).map_err(FileLoaderError::IoError)
}
fn read_with_path(self) -> Result<(PathBuf, String), FileLoaderError> {
let contents = fs::read_to_string(&self);
Ok((self, contents?))
}
}
impl<T: Readable> Readable for Result<T, FileLoaderError> {
fn read(self) -> Result<String, FileLoaderError> {
self.map(|t| t.read())?
}
fn read_with_path(self) -> Result<(PathBuf, String), FileLoaderError> {
self.map(|t| t.read_with_path())?
}
}

// ## FileLoader definitions and implementations ##

/// `FileLoader` is a utility for loading files from the filesystem using glob patterns or directory paths.
/// It provides methods to read file contents and handle errors gracefully.
///
/// # Errors
///
/// This module defines a custom error type `FileLoaderError` which can represent various errors that might occur
/// during file loading operations, such as invalid glob patterns, IO errors, and glob errors.
///
/// # Example Usage
///
/// ```rust
/// use rig:loaders::FileLoader;
///
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // Create a FileLoader using a glob pattern
/// let loader = FileLoader::with_glob("path/to/files/*.txt")?;
///
/// // Read file contents, ignoring any errors
/// let contents: Vec<String> = loader
/// .read()
/// .ignore_errors()
/// .into_iter()
/// .collect();
///
/// for content in contents {
/// println!("{}", content);
/// }
///
/// Ok(())
/// }
/// ```
///
/// `FileLoader` uses strict typing between the iterator methods to ensure that transitions between
/// different implementations of the loaders and it's methods are handled properly by the compiler.
pub struct FileLoader<'a, T> {
iterator: Box<dyn Iterator<Item = T> + 'a>,
}

impl<'a> FileLoader<'a, Result<PathBuf, FileLoaderError>> {
/// Reads the contents of the files within the iterator returned by `with_glob` or `with_dir`.
///
/// # Example
/// Read files in directory "files/*.txt" and print the content for each file
///
/// ```rust
/// let content = FileLoader::with_glob(...)?.read().into_iter();
/// for result in content {
/// match result {
/// Ok(content) => println!("{}", content),
/// Err(e) => eprintln!("Error reading file: {}", e),
/// }
/// }
/// ```
pub fn read(self) -> FileLoader<'a, Result<String, FileLoaderError>> {
FileLoader {
iterator: Box::new(self.iterator.map(|res| res.read())),
}
}
/// Reads the contents of the files within the iterator returned by `with_glob` or `with_dir`
/// and returns the path along with the content.
///
/// # Example
/// Read files in directory "files/*.txt" and print the content for cooresponding path for each
/// file.
///
/// ```rust
/// let content = FileLoader::with_glob("files/*.txt")?.read().into_iter();
/// for (path, result) in content {
/// match result {
/// Ok((path, content)) => println!("{:?} {}", path, content),
/// Err(e) => eprintln!("Error reading file: {}", e),
/// }
/// }
/// ```
pub fn read_with_path(self) -> FileLoader<'a, Result<(PathBuf, String), FileLoaderError>> {
FileLoader {
iterator: Box::new(self.iterator.map(|res| res.read_with_path())),
}
}
}

impl<'a, T: 'a> FileLoader<'a, Result<T, FileLoaderError>> {
/// Ignores errors in the iterator, returning only successful results. This can be used on any
/// `FileLoader` state of iterator whose items are results.
///
/// # Example
/// Read files in directory "files/*.txt" and ignore errors from unreadable files.
///
/// ```rust
/// let content = FileLoader::with_glob("files/*.txt")?.read().ignore_errors().into_iter();
/// for result in content {
/// println!("{}", content)
/// }
/// ```
pub fn ignore_errors(self) -> FileLoader<'a, T> {
FileLoader {
iterator: Box::new(self.iterator.filter_map(|res| res.ok())),
}
}
}

impl<'a> FileLoader<'a, Result<PathBuf, FileLoaderError>> {
/// Creates a new `FileLoader` using a glob pattern to match files.
///
/// # Example
/// Create a `FileLoader` for all `.txt` files that match the glob "files/*.txt".
///
/// ```rust
/// let loader = FileLoader::with_glob("files/*.txt")?;
/// ```
pub fn with_glob(
pattern: &str,
) -> Result<FileLoader<Result<PathBuf, FileLoaderError>>, FileLoaderError> {
let paths = glob(pattern)?;
Ok(FileLoader {
iterator: Box::new(
paths
.into_iter()
.map(|path| path.map_err(FileLoaderError::GlobError)),
),
})
}

/// Creates a new `FileLoader` on all files within a directory.
///
/// # Example
/// Create a `FileLoader` for all files that are in the directory "files".
///
/// ```rust
/// let loader = FileLoader::with_dir("files")?;
/// ```
pub fn with_dir(
directory: &str,
) -> Result<FileLoader<Result<PathBuf, FileLoaderError>>, FileLoaderError> {
Ok(FileLoader {
iterator: Box::new(fs::read_dir(directory)?.map(|entry| Ok(entry?.path()))),
})
}
}

// Iterators for FileLoader

pub struct IntoIter<'a, T> {
iterator: Box<dyn Iterator<Item = T> + 'a>,
}

impl<'a, T> IntoIterator for FileLoader<'a, T> {
type Item = T;
type IntoIter = IntoIter<'a, T>;

fn into_iter(self) -> Self::IntoIter {
IntoIter {
iterator: self.iterator,
}
}
}

impl<'a, T> Iterator for IntoIter<'a, T> {
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
self.iterator.next()
}
}

#[cfg(test)]
mod tests {
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};

use super::FileLoader;

#[test]
fn test_file_loader() {
let temp = assert_fs::TempDir::new().expect("Failed to create temp dir");
let foo_file = temp.child("foo.txt");
let bar_file = temp.child("bar.txt");

foo_file.touch().expect("Failed to create foo.txt");
bar_file.touch().expect("Failed to create bar.txt");

foo_file.write_str("foo").expect("Failed to write to foo");
bar_file.write_str("bar").expect("Failed to write to bar");

let glob = temp.path().to_string_lossy().to_string() + "/*.txt";

let loader = FileLoader::with_glob(&glob).unwrap();
let mut actual = loader
.ignore_errors()
.read()
.ignore_errors()
.into_iter()
.collect::<Vec<_>>();
let mut expected = vec!["foo".to_string(), "bar".to_string()];

actual.sort();
expected.sort();

assert!(!actual.is_empty());
assert!(expected == actual)
}
}
9 changes: 9 additions & 0 deletions rig-core/src/loaders/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pub mod file;

pub use file::FileLoader;

#[cfg(feature = "pdf")]
pub mod pdf;

#[cfg(feature = "pdf")]
pub use pdf::PdfFileLoader;
Loading
Loading