diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 0fc540f2..dc4d9412 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -257,7 +257,7 @@ async fn main() -> Result<(), anyhow::Error> { // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client - .tool_rag_agent("gpt-4") + .agent("gpt-4") .preamble( "You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations. Follow these instructions closely. diff --git a/rig-core/examples/multi_agent.rs b/rig-core/examples/multi_agent.rs index 9c26787d..e6e9ef75 100644 --- a/rig-core/examples/multi_agent.rs +++ b/rig-core/examples/multi_agent.rs @@ -4,7 +4,6 @@ use rig::{ agent::{Agent, AgentBuilder}, cli_chatbot::cli_chatbot, completion::{Chat, CompletionModel, Message, PromptError}, - model::{Model, ModelBuilder}, providers::openai::Client as OpenAIClient, }; @@ -14,7 +13,7 @@ use rig::{ /// prompt in english, before answering it with GPT-4. The answer in english is returned. struct EnglishTranslator { translator_agent: Agent, - gpt4: Model, + gpt4: Agent, } impl EnglishTranslator { @@ -29,7 +28,7 @@ impl EnglishTranslator { .build(), // Create the GPT4 model - gpt4: ModelBuilder::new(model).build() + gpt4: AgentBuilder::new(model).build() } } } diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index d6929b52..3abd8ee9 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -30,7 +30,7 @@ async fn main() -> Result<(), anyhow::Error> { // Create vector store index let index = vector_store.index(embedding_model); - let rag_agent = openai_client.context_rag_agent("gpt-4") + let rag_agent = openai_client.agent("gpt-4") .preamble(" You are a dictionary assistant here to assist the user in understanding the meaning of words. You will find additional non-standard word definitions that could be useful below. diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 348c0fcf..a0b6c997 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -170,7 +170,7 @@ async fn main() -> Result<(), anyhow::Error> { // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client - .tool_rag_agent("gpt-4") + .agent("gpt-4") .preamble("You are a calculator here to help the user perform arithmetic operations.") // Add a dynamic tool source with a sample rate of 1 (i.e.: only // 1 additional tool will be added to prompts) diff --git a/rig-core/examples/simple_model.rs b/rig-core/examples/simple_model.rs index a8141ef2..164dee09 100644 --- a/rig-core/examples/simple_model.rs +++ b/rig-core/examples/simple_model.rs @@ -5,7 +5,7 @@ async fn main() { // Create OpenAI client and model let openai_client = openai::Client::from_env(); - let gpt4 = openai_client.model("gpt-4").build(); + let gpt4 = openai_client.agent("gpt-4").build(); // Prompt the model and print its response let response = gpt4 diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 965ec474..8648a9eb 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -1,8 +1,16 @@ //! This module contains the implementation of the [Agent] struct and its builder. //! //! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt), -//! a set of context documents, and a set of static tools. The agent can be used to interact with the LLM model -//! by providing prompts and chat history without having to provide the preamble and other parameters everytime. +//! a set of context documents, and a set of tools. Note: both context documents and tools can be either +//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time). +//! +//! The [Agent] struct is highly configurable, allowing the user to define anything from +//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic +//! context documents and tools. +//! +//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating +//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to +//! be used for generating chat completions. //! //! The [AgentBuilder] implements the builder pattern for creating instances of [Agent]. //! It allows configuring the model, preamble, context documents, static tools, temperature, and additional parameters @@ -52,9 +60,55 @@ //! .await //! .expect("Failed to send completion request"); //! ``` +//! +//! RAG Agent example +//! ```rust +//! use rig::{ +//! completion::Prompt, +//! embeddings::EmbeddingsBuilder, +//! providers::openai, +//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, +//! }; +//! +//! // Initialize OpenAI client +//! let openai = openai::Client::from_env(); +//! +//! // Initialize OpenAI embedding model +//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); +//! +//! // Create vector store, compute embeddings and load them in the store +//! let mut vector_store = InMemoryVectorStore::default(); +//! +//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) +//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") +//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") +//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") +//! .build() +//! .await +//! .expect("Failed to build embeddings"); +//! +//! vector_store.add_documents(embeddings) +//! .await +//! .expect("Failed to add documents"); +//! +//! // Create vector store index +//! let index = vector_store.index(embedding_model); +//! +//! let agent = openai.agent(openai::GPT_4O) +//! .preamble(" +//! You are a dictionary assistant here to assist the user in understanding the meaning of words. +//! You will find additional non-standard word definitions that could be useful below. +//! ") +//! .dynamic_context(1, index) +//! .build(); +//! +//! // Prompt the agent and print the response +//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await +//! .expect("Failed to prompt the agent"); +//! ``` use std::collections::HashMap; -use futures::{stream, StreamExt}; +use futures::{stream, StreamExt, TryStreamExt}; use crate::{ completion::{ @@ -62,6 +116,7 @@ use crate::{ CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError, }, tool::{Tool, ToolSet}, + vector_store::{VectorStoreError, VectorStoreIndexDyn}, }; /// Struct reprensenting an LLM agent. An agent is an LLM model combined with a preamble @@ -85,52 +140,24 @@ use crate::{ /// .expect("Failed to prompt the agent"); /// ``` pub struct Agent { - /// Completion model (e.g.: OpenAI's `gpt-3.5-turbo-1106`, Cohere's `command-r`) + /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) model: M, /// System prompt preamble: String, /// Context documents always available to the agent - context: Vec, + static_context: Vec, /// Tools that are always available to the agent (identified by their name) static_tools: Vec, /// Temperature of the model temperature: Option, /// Additional parameters to be passed to the model additional_params: Option, + /// List of vector store, with the sample number + dynamic_context: Vec<(usize, Box)>, + /// Dynamic tools + dynamic_tools: Vec<(usize, Box)>, /// Actual tool implementations - tools: ToolSet, -} - -impl Agent { - /// Create a new Agent - pub fn new( - model: M, - preamble: String, - static_context: Vec, - static_tools: Vec, - temperature: Option, - additional_params: Option, - ) -> Self { - let static_tools_ids = static_tools.iter().map(|tool| tool.name()).collect(); - - Self { - model, - preamble, - context: static_context - .into_iter() - .enumerate() - .map(|(i, doc)| Document { - id: format!("static_doc_{}", i), - text: doc, - additional_props: HashMap::new(), - }) - .collect(), - tools: ToolSet::from_tools(static_tools), - static_tools: static_tools_ids, - temperature, - additional_params, - } - } + pub tools: ToolSet, } impl Completion for Agent { @@ -139,12 +166,64 @@ impl Completion for Agent { prompt: &str, chat_history: Vec, ) -> Result, CompletionError> { - let tool_definitions = stream::iter(self.static_tools.iter()) + let dynamic_context = stream::iter(self.dynamic_context.iter()) + .then(|(num_sample, index)| async { + Ok::<_, VectorStoreError>( + index + .top_n_from_query(prompt, *num_sample) + .await? + .into_iter() + .map(|(_, doc)| { + // Pretty print the document if possible for better readability + let doc_text = serde_json::to_string_pretty(&doc.document) + .unwrap_or_else(|_| doc.document.to_string()); + + Document { + id: doc.id, + text: doc_text, + additional_props: HashMap::new(), + } + }) + .collect::>(), + ) + }) + .try_fold(vec![], |mut acc, docs| async { + acc.extend(docs); + Ok(acc) + }) + .await + .map_err(|e| CompletionError::RequestError(Box::new(e)))?; + + let dynamic_tools = stream::iter(self.dynamic_tools.iter()) + .then(|(num_sample, index)| async { + Ok::<_, VectorStoreError>( + index + .top_n_ids_from_query(prompt, *num_sample) + .await? + .into_iter() + .map(|(_, doc)| doc) + .collect::>(), + ) + }) + .try_fold(vec![], |mut acc, docs| async { + for doc in docs { + if let Some(tool) = self.tools.get(&doc) { + acc.push(tool.definition(prompt.into()).await) + } else { + tracing::warn!("Tool implementation not found in toolset: {}", doc); + } + } + Ok(acc) + }) + .await + .map_err(|e| CompletionError::RequestError(Box::new(e)))?; + + let static_tools = stream::iter(self.static_tools.iter()) .filter_map(|toolname| async move { if let Some(tool) = self.tools.get(toolname) { Some(tool.definition(prompt.into()).await) } else { - tracing::error!(target: "rig", "Agent static tool {} not found", toolname); + tracing::warn!("Tool implementation not found in toolset: {}", toolname); None } }) @@ -156,8 +235,8 @@ impl Completion for Agent { .completion_request(prompt) .preamble(self.preamble.clone()) .messages(chat_history) - .documents(self.context.clone()) - .tools(tool_definitions.clone()) + .documents([self.static_context.clone(), dynamic_context].concat()) + .tools([static_tools.clone(), dynamic_tools].concat()) .temperature_opt(self.temperature) .additional_params_opt(self.additional_params.clone())) } @@ -206,12 +285,23 @@ impl Chat for Agent { /// .build(); /// ``` pub struct AgentBuilder { + /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) model: M, + /// System prompt preamble: Option, + /// Context documents always available to the agent static_context: Vec, + /// Tools that are always available to the agent (by name) static_tools: Vec, - temperature: Option, + /// Additional parameters to be passed to the model additional_params: Option, + /// List of vector store, with the sample number + dynamic_context: Vec<(usize, Box)>, + /// Dynamic tools + dynamic_tools: Vec<(usize, Box)>, + /// Temperature of the model + temperature: Option, + /// Actual tool implementations tools: ToolSet, } @@ -224,13 +314,15 @@ impl AgentBuilder { static_tools: vec![], temperature: None, additional_params: None, + dynamic_context: vec![], + dynamic_tools: vec![], tools: ToolSet::default(), } } - /// Set the preamble of the agent - pub fn preamble(mut self, doc: &str) -> Self { - self.preamble = Some(doc.into()); + /// Set the system prompt + pub fn preamble(mut self, preamble: &str) -> Self { + self.preamble = Some(preamble.into()); self } @@ -262,6 +354,31 @@ impl AgentBuilder { self } + /// Add some dynamic context to the agent. On each prompt, `sample` documents from the + /// dynamic context will be inserted in the request. + pub fn dynamic_context( + mut self, + sample: usize, + dynamic_context: impl VectorStoreIndexDyn + 'static, + ) -> Self { + self.dynamic_context + .push((sample, Box::new(dynamic_context))); + self + } + + /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the + /// dynamic toolset will be inserted in the request. + pub fn dynamic_tools( + mut self, + sample: usize, + dynamic_tools: impl VectorStoreIndexDyn + 'static, + toolset: ToolSet, + ) -> Self { + self.dynamic_tools.push((sample, Box::new(dynamic_tools))); + self.tools.add_tools(toolset); + self + } + /// Set the temperature of the model pub fn temperature(mut self, temperature: f64) -> Self { self.temperature = Some(temperature); @@ -278,12 +395,14 @@ impl AgentBuilder { pub fn build(self) -> Agent { Agent { model: self.model, - preamble: self.preamble.unwrap_or_else(|| "".into()), - context: self.static_context, - tools: self.tools, + preamble: self.preamble.unwrap_or_default(), + static_context: self.static_context, static_tools: self.static_tools, temperature: self.temperature, additional_params: self.additional_params, + dynamic_context: self.dynamic_context, + dynamic_tools: self.dynamic_tools, + tools: self.tools, } } } diff --git a/rig-core/src/model.rs b/rig-core/src/model.rs index 8534dca5..b4e873ed 100644 --- a/rig-core/src/model.rs +++ b/rig-core/src/model.rs @@ -47,56 +47,11 @@ //! .await //! .expect("Failed to send completion request"); //! ``` -use crate::completion::{ - Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, - CompletionResponse, Message, ModelChoice, Prompt, PromptError, -}; - -/// A model that can be used to prompt completions from a completion model. -/// This is the simplest building block for creating an LLM powered application. -pub struct Model { - /// Completion model (e.g.: OpenAI's `gpt-3.5-turbo-1106`, Cohere's `command-r`) - model: M, - /// Temperature of the model - temperature: Option, -} - -impl Completion for Model { - async fn completion( - &self, - prompt: &str, - chat_history: Vec, - ) -> Result, CompletionError> { - Ok(self - .model - .completion_request(prompt) - .messages(chat_history) - .temperature_opt(self.temperature)) - } -} - -impl Prompt for Model { - async fn prompt(&self, prompt: &str) -> Result { - self.chat(prompt, vec![]).await - } -} - -impl Chat for Model { - async fn chat(&self, prompt: &str, chat_history: Vec) -> Result { - match self.completion(prompt, chat_history).await?.send().await? { - CompletionResponse { - choice: ModelChoice::Message(message), - .. - } => Ok(message), - CompletionResponse { - choice: ModelChoice::ToolCall(toolname, _), - .. - } => Err(PromptError::ToolError( - crate::tool::ToolSetError::ToolNotFoundError(toolname), - )), - } - } -} +#[deprecated( + since = "0.2.0", + note = "Please use the `Agent` type directly instead of the `Model` type." +)] +pub type Model = crate::agent::Agent; /// A builder for creating a model /// @@ -113,37 +68,9 @@ impl Chat for Model { /// .temperature(0.8) /// .build(); /// ``` -pub struct ModelBuilder { - model: M, - temperature: Option, -} - -impl ModelBuilder { - /// Create a new model builder - pub fn new(model: M) -> Self { - Self { - model, - temperature: None, - } - } - - /// Set the temperature of the model - pub fn temperature(mut self, temperature: f64) -> Self { - self.temperature = Some(temperature); - self - } - - /// Set the temperature of the model (set to `None` to use the default temperature of the model) - pub fn temperature_opt(mut self, temperature: Option) -> Self { - self.temperature = temperature; - self - } - /// Build the model - pub fn build(self) -> Model { - Model { - model: self.model, - temperature: self.temperature, - } - } -} +#[deprecated( + since = "0.2.0", + note = "Please use the `AgentBuilder` type directly instead of the `ModelBuilder` type." +)] +pub type ModelBuilder = crate::agent::AgentBuilder; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 9efe4a58..5f5fc397 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -16,9 +16,6 @@ use crate::{ embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, json_utils, - model::ModelBuilder, - rag::RagAgentBuilder, - vector_store::{NoIndex, VectorStoreIndex}, }; use schemars::JsonSchema; @@ -77,8 +74,12 @@ impl Client { CompletionModel::new(self.clone(), model) } - pub fn model(&self, model: &str) -> ModelBuilder { - ModelBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `model` method." + )] + pub fn model(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } pub fn agent(&self, model: &str) -> AgentBuilder { @@ -92,25 +93,28 @@ impl Client { ExtractorBuilder::new(self.completion_model(model)) } - pub fn rag_agent( - &self, - model: &str, - ) -> RagAgentBuilder { - RagAgentBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `rag_agent` method." + )] + pub fn rag_agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } - pub fn tool_rag_agent( - &self, - model: &str, - ) -> RagAgentBuilder { - RagAgentBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `tool_rag_agent` method." + )] + pub fn tool_rag_agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } - pub fn context_rag_agent( - &self, - model: &str, - ) -> RagAgentBuilder { - RagAgentBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `context_rag_agent` method." + )] + pub fn context_rag_agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } } diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 8c12a745..87c3f557 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -14,9 +14,6 @@ use crate::{ embeddings::{self, EmbeddingError}, extractor::ExtractorBuilder, json_utils, - model::ModelBuilder, - rag::RagAgentBuilder, - vector_store::{NoIndex, VectorStoreIndex}, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -134,8 +131,12 @@ impl Client { /// .temperature(0.0) /// .build(); /// ``` - pub fn model(&self, model: &str) -> ModelBuilder { - ModelBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `model` method." + )] + pub fn model(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } /// Create an agent builder with the given completion model. @@ -164,25 +165,28 @@ impl Client { ExtractorBuilder::new(self.completion_model(model)) } - pub fn rag_agent( - &self, - model: &str, - ) -> RagAgentBuilder { - RagAgentBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `rag_agent` method." + )] + pub fn rag_agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } - pub fn tool_rag_agent( - &self, - model: &str, - ) -> RagAgentBuilder { - RagAgentBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `tool_rag_agent` method." + )] + pub fn tool_rag_agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } - pub fn context_rag_agent( - &self, - model: &str, - ) -> RagAgentBuilder { - RagAgentBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `context_rag_agent` method." + )] + pub fn context_rag_agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } } diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index 01703143..e88ffb7c 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -14,9 +14,6 @@ use crate::{ completion::{self, CompletionError}, extractor::ExtractorBuilder, json_utils, - model::ModelBuilder, - rag::RagAgentBuilder, - vector_store::{NoIndex, VectorStoreIndex}, }; use schemars::JsonSchema; @@ -74,8 +71,12 @@ impl Client { CompletionModel::new(self.clone(), model) } - pub fn model(&self, model: &str) -> ModelBuilder { - ModelBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `model` method." + )] + pub fn model(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } pub fn agent(&self, model: &str) -> AgentBuilder { @@ -89,18 +90,20 @@ impl Client { ExtractorBuilder::new(self.completion_model(model)) } - pub fn rag_agent( - &self, - model: &str, - ) -> RagAgentBuilder { - RagAgentBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `rag_agent` method." + )] + pub fn rag_agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } - pub fn context_rag_agent( - &self, - model: &str, - ) -> RagAgentBuilder { - RagAgentBuilder::new(self.completion_model(model)) + #[deprecated( + since = "0.2.0", + note = "Please use the `agent` method instead of the `context_rag_agent` method." + )] + pub fn context_rag_agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) } } diff --git a/rig-core/src/rag.rs b/rig-core/src/rag.rs index 89bb21d1..8353fc92 100644 --- a/rig-core/src/rag.rs +++ b/rig-core/src/rag.rs @@ -53,276 +53,17 @@ //! let response = rag_agent.prompt("What does \"glarb-glarb\" mean?").await //! .expect("Failed to prompt the agent"); //! ``` -use std::collections::HashMap; -use futures::{stream, StreamExt, TryStreamExt}; +use crate::agent::{Agent, AgentBuilder}; -use crate::{ - completion::{ - Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, - CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError, - }, - tool::{Tool, ToolSet, ToolSetError}, - vector_store::{NoIndex, VectorStoreError, VectorStoreIndex}, -}; +#[deprecated( + since = "0.2.0", + note = "Please use the `Agent` type directly instead of the `RagAgent` type." +)] +pub type RagAgent = Agent; -/// Struct representing a RAG agent, i.e.: an agent enhanced with two collections of -/// vector store indices, one for context documents and one for tools. -/// The ragged context and tools are used to enhance the completion model at prompt-time. -/// Note: The type of the [VectorStoreIndex] must be the same for all the dynamic context -/// and tools indices (but can be different for context and tools). -/// If you need to use a more complex combination of vector store indices, -/// you should implement a custom agent. -pub struct RagAgent { - /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) - model: M, - /// System prompt - preamble: String, - /// Context documents always available to the agent - static_context: Vec, - /// Tools that are always available to the agent (identified by their name) - static_tools: Vec, - /// Temperature of the model - temperature: Option, - /// Additional parameters to be passed to the model - additional_params: Option, - /// List of vector store, with the sample number - dynamic_context: Vec<(usize, C)>, - /// Dynamic tools - dynamic_tools: Vec<(usize, T)>, - /// Actual tool implementations - pub tools: ToolSet, -} - -pub type ToolRagAgent = RagAgent; -pub type ContextRagAgent = RagAgent; - -impl Completion - for RagAgent -{ - async fn completion( - &self, - prompt: &str, - chat_history: Vec, - ) -> Result, CompletionError> { - let dynamic_context = stream::iter(self.dynamic_context.iter()) - .then(|(num_sample, index)| async { - Ok::<_, VectorStoreError>( - index - .top_n_from_query(prompt, *num_sample) - .await? - .into_iter() - .map(|(_, doc)| { - let doc_text = serde_json::to_string_pretty(&doc.document) - .unwrap_or_else(|_| doc.document.to_string()); - - Document { - id: doc.id, - text: doc_text, - additional_props: HashMap::new(), - } - }) - .collect::>(), - ) - }) - .try_fold(vec![], |mut acc, docs| async { - acc.extend(docs); - Ok(acc) - }) - .await - .map_err(|e| CompletionError::RequestError(Box::new(e)))?; - - let dynamic_tools = stream::iter(self.dynamic_tools.iter()) - .then(|(num_sample, index)| async { - Ok::<_, VectorStoreError>( - index - .top_n_ids_from_query(prompt, *num_sample) - .await? - .into_iter() - .map(|(_, doc)| doc) - .collect::>(), - ) - }) - .try_fold(vec![], |mut acc, docs| async { - for doc in docs { - if let Some(tool) = self.tools.get(&doc) { - acc.push(tool.definition(prompt.into()).await) - } else { - tracing::warn!("Tool implementation not found in toolset: {}", doc); - } - } - Ok(acc) - }) - .await - .map_err(|e| CompletionError::RequestError(Box::new(e)))?; - - let static_tools = stream::iter(self.static_tools.iter()) - .filter_map(|toolname| async move { - if let Some(tool) = self.tools.get(toolname) { - Some(tool.definition(prompt.into()).await) - } else { - tracing::warn!("Tool implementation not found in toolset: {}", toolname); - None - } - }) - .collect::>() - .await; - - Ok(self - .model - .completion_request(prompt) - .preamble(self.preamble.clone()) - .messages(chat_history) - .documents([self.static_context.clone(), dynamic_context].concat()) - .tools([static_tools.clone(), dynamic_tools].concat()) - .temperature_opt(self.temperature) - .additional_params_opt(self.additional_params.clone())) - } -} - -impl Prompt for RagAgent { - async fn prompt(&self, prompt: &str) -> Result { - self.chat(prompt, vec![]).await - } -} - -impl Chat for RagAgent { - async fn chat(&self, prompt: &str, chat_history: Vec) -> Result { - match self.completion(prompt, chat_history).await?.send().await? { - CompletionResponse { - choice: ModelChoice::Message(msg), - .. - } => Ok(msg), - CompletionResponse { - choice: ModelChoice::ToolCall(toolname, args), - .. - } => Ok(self.tools.call(&toolname, args.to_string()).await?), - } - } -} - -impl RagAgent { - pub async fn call_tool(&self, toolname: &str, args: &str) -> Result { - self.tools.call(toolname, args.to_string()).await - } -} - -/// Builder for creating a RAG agent -/// -/// # Example -/// ``` -/// use rig::{providers::openai, rag_agent::RagAgentBuilder}; -/// use serde_json::json; -/// -/// let openai_client = openai::Client::from_env(); -/// -/// let model = openai_client.completion_model("gpt-4"); -/// -/// // Configure the agent -/// let agent = RagAgentBuilder::new(model) -/// .preamble("System prompt") -/// .static_context("Context document 1") -/// .static_context("Context document 2") -/// .dynamic_context(2, vector_index) -/// .tool(tool1) -/// .tool(tool2) -/// .temperature(0.8) -/// .additional_params(json!({"foo": "bar"})) -/// .build(); -/// ``` -pub struct RagAgentBuilder { - /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) - model: M, - /// System prompt - preamble: Option, - /// Context documents always available to the agent - static_context: Vec, - /// Tools that are always available to the agent (by name) - static_tools: Vec, - /// Additional parameters to be passed to the model - additional_params: Option, - /// List of vector store, with the sample number - dynamic_context: Vec<(usize, C)>, - /// Dynamic tools - dynamic_tools: Vec<(usize, T)>, - /// Temperature of the model - temperature: Option, - /// Actual tool implementations - tools: ToolSet, -} - -impl RagAgentBuilder { - pub fn new(model: M) -> Self { - Self { - model, - preamble: None, - static_context: vec![], - static_tools: vec![], - temperature: None, - additional_params: None, - dynamic_context: vec![], - dynamic_tools: vec![], - tools: ToolSet::default(), - } - } - - /// Set the system prompt - pub fn preamble(mut self, preamble: &str) -> Self { - self.preamble = Some(preamble.into()); - self - } - - /// Add a static context document to the RAG agent - pub fn static_context(mut self, doc: &str) -> Self { - self.static_context.push(Document { - id: format!("static_doc_{}", self.static_context.len()), - text: doc.into(), - additional_props: HashMap::new(), - }); - self - } - - /// Add a static tool to the RAG agent - pub fn static_tool(mut self, tool: impl Tool + 'static) -> Self { - let toolname = tool.name(); - self.tools.add_tool(tool); - self.static_tools.push(toolname); - self - } - - /// Add some dynamic context to the RAG agent. On each prompt, `sample` documents from the - /// dynamic context will be inserted in the request. - pub fn dynamic_context(mut self, sample: usize, dynamic_context: C) -> Self { - self.dynamic_context.push((sample, dynamic_context)); - self - } - - /// Add some dynamic tools to the RAG agent. On each prompt, `sample` tools from the - /// dynamic toolset will be inserted in the request. - pub fn dynamic_tools(mut self, sample: usize, dynamic_tools: T, toolset: ToolSet) -> Self { - self.dynamic_tools.push((sample, dynamic_tools)); - self.tools.add_tools(toolset); - self - } - - /// Set the temperature of the model - pub fn temperature(mut self, temperature: f64) -> Self { - self.temperature = Some(temperature); - self - } - - /// Build the RAG agent - pub fn build(self) -> RagAgent { - RagAgent { - model: self.model, - preamble: self.preamble.unwrap_or_default(), - static_context: self.static_context, - static_tools: self.static_tools, - temperature: self.temperature, - additional_params: self.additional_params, - dynamic_context: self.dynamic_context, - dynamic_tools: self.dynamic_tools, - tools: self.tools, - } - } -} +#[deprecated( + since = "0.2.0", + note = "Please use the `AgentBuilder` type directly instead of the `RagAgentBuilder` type." +)] +pub type RagAgentBuilder = AgentBuilder; diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 0caea0f0..2e6a652f 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -1,3 +1,4 @@ +use futures::future::BoxFuture; use serde::Deserialize; use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingError}; @@ -134,6 +135,66 @@ pub trait VectorStoreIndex: Send + Sync { } } +pub trait VectorStoreIndexDyn: Send + Sync { + fn top_n_from_query<'a>( + &'a self, + query: &'a str, + n: usize, + ) -> BoxFuture<'a, Result, VectorStoreError>>; + + fn top_n_ids_from_query<'a>( + &'a self, + query: &'a str, + n: usize, + ) -> BoxFuture<'a, Result, VectorStoreError>> { + Box::pin(async move { + let documents = self.top_n_from_query(query, n).await?; + Ok(documents + .into_iter() + .map(|(distance, doc)| (distance, doc.id)) + .collect()) + }) + } + + fn top_n_from_embedding<'a>( + &'a self, + prompt_embedding: &'a Embedding, + n: usize, + ) -> BoxFuture<'a, Result, VectorStoreError>>; + + fn top_n_ids_from_embedding<'a>( + &'a self, + prompt_embedding: &'a Embedding, + n: usize, + ) -> BoxFuture<'a, Result, VectorStoreError>> { + Box::pin(async move { + let documents = self.top_n_from_embedding(prompt_embedding, n).await?; + Ok(documents + .into_iter() + .map(|(distance, doc)| (distance, doc.id)) + .collect()) + }) + } +} + +impl VectorStoreIndexDyn for I { + fn top_n_from_query<'a>( + &'a self, + query: &'a str, + n: usize, + ) -> BoxFuture<'a, Result, VectorStoreError>> { + Box::pin(self.top_n_from_query(query, n)) + } + + fn top_n_from_embedding<'a>( + &'a self, + prompt_embedding: &'a Embedding, + n: usize, + ) -> BoxFuture<'a, Result, VectorStoreError>> { + Box::pin(self.top_n_from_embedding(prompt_embedding, n)) + } +} + pub struct NoIndex; impl VectorStoreIndex for NoIndex {