Refactor: Merge and generalize Agent types #21
Merged
@@ -1,8 +1,16 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of static tools. The agent can be used to interact with the LLM model
//! by providing prompts and chat history without having to provide the preamble and other parameters everytime.
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
//! be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, static tools, temperature, and additional parameters
@@ -52,16 +60,63 @@
//!     .await
//!     .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//!     completion::Prompt,
//!     embeddings::EmbeddingsBuilder,
//!     providers::openai,
//!     vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//!     .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//!     .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//!     .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//!     .build()
//!     .await
//!     .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//!     .await
//!     .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let rag_agent = openai.context_rag_agent(openai::GPT_4O)
//!     .preamble("
//!         You are a dictionary assistant here to assist the user in understanding the meaning of words.
//!         You will find additional non-standard word definitions that could be useful below.
//!     ")
//!     .dynamic_context(1, index)
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = rag_agent.prompt("What does \"glarb-glarb\" mean?").await
//!     .expect("Failed to prompt the agent");
//! ```
use std::collections::HashMap;

use futures::{stream, StreamExt};
use futures::{stream, StreamExt, TryStreamExt};

use crate::{
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
        CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
    },
    tool::{Tool, ToolSet},
    vector_store::{VectorStoreError, VectorStoreIndexDyn},
};

/// Struct reprensenting an LLM agent. An agent is an LLM model combined with a preamble
@@ -85,52 +140,24 @@ use crate::{
/// .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's `gpt-3.5-turbo-1106`, Cohere's `command-r`)
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: String,
    /// Context documents always available to the agent
    context: Vec<Document>,
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    static_tools: Vec<String>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// List of vector store, with the sample number
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tools
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Actual tool implementations
    tools: ToolSet,
}

impl<M: CompletionModel> Agent<M> {
    /// Create a new Agent
    pub fn new(
        model: M,
        preamble: String,
        static_context: Vec<String>,
        static_tools: Vec<impl Tool + 'static>,
        temperature: Option<f64>,
        additional_params: Option<serde_json::Value>,
    ) -> Self {
        let static_tools_ids = static_tools.iter().map(|tool| tool.name()).collect();

        Self {
            model,
            preamble,
            context: static_context
                .into_iter()
                .enumerate()
                .map(|(i, doc)| Document {
                    id: format!("static_doc_{}", i),
                    text: doc,
                    additional_props: HashMap::new(),
                })
                .collect(),
            tools: ToolSet::from_tools(static_tools),
            static_tools: static_tools_ids,
            temperature,
            additional_params,
        }
    }
    pub tools: ToolSet,
}

impl<M: CompletionModel> Completion<M> for Agent<M> {
@@ -139,12 +166,63 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
        prompt: &str,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let tool_definitions = stream::iter(self.static_tools.iter())
        let dynamic_context = stream::iter(self.dynamic_context.iter())
            .then(|(num_sample, index)| async {
                Ok::<_, VectorStoreError>(
                    index
                        .top_n_from_query(prompt, *num_sample)
                        .await?
                        .into_iter()
                        .map(|(_, doc)| {
                            let doc_text = serde_json::to_string_pretty(&doc.document)
                                .unwrap_or_else(|_| doc.document.to_string());
Review comment: Is there a reason why this is prettified or is this standard?
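For context on the question above: `serde_json::to_string_pretty` renders the retrieved document as indented, multi-line JSON, whereas `serde_json::to_string` produces a compact single-line string. A standalone sketch (not part of this diff) illustrating the difference:

```rust
// Standalone illustration of compact vs. pretty serialization with serde_json.
use serde_json::json;

fn main() {
    let doc = json!({ "id": "doc1", "definition": "an ancient farming tool" });

    // Compact form: a single line with no extra whitespace.
    println!("{}", serde_json::to_string(&doc).unwrap());

    // Pretty form: multi-line, indented JSON (what the code above inserts into the prompt).
    println!("{}", serde_json::to_string_pretty(&doc).unwrap());
}
```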
                            Document {
                                id: doc.id,
                                text: doc_text,
                                additional_props: HashMap::new(),
                            }
                        })
                        .collect::<Vec<_>>(),
                )
            })
            .try_fold(vec![], |mut acc, docs| async {
                acc.extend(docs);
                Ok(acc)
            })
            .await
            .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

        let dynamic_tools = stream::iter(self.dynamic_tools.iter())
            .then(|(num_sample, index)| async {
                Ok::<_, VectorStoreError>(
                    index
                        .top_n_ids_from_query(prompt, *num_sample)
                        .await?
                        .into_iter()
                        .map(|(_, doc)| doc)
                        .collect::<Vec<_>>(),
                )
            })
            .try_fold(vec![], |mut acc, docs| async {
                for doc in docs {
                    if let Some(tool) = self.tools.get(&doc) {
                        acc.push(tool.definition(prompt.into()).await)
                    } else {
                        tracing::warn!("Tool implementation not found in toolset: {}", doc);
                    }
                }
                Ok(acc)
            })
            .await
            .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

        let static_tools = stream::iter(self.static_tools.iter())
            .filter_map(|toolname| async move {
                if let Some(tool) = self.tools.get(toolname) {
                    Some(tool.definition(prompt.into()).await)
                } else {
                    tracing::error!(target: "rig", "Agent static tool {} not found", toolname);
                    tracing::warn!("Tool implementation not found in toolset: {}", toolname);
                    None
                }
            })
@@ -156,8 +234,8 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
            .completion_request(prompt)
            .preamble(self.preamble.clone())
            .messages(chat_history)
            .documents(self.context.clone())
            .tools(tool_definitions.clone())
            .documents([self.static_context.clone(), dynamic_context].concat())
            .tools([static_tools.clone(), dynamic_tools].concat())
            .temperature_opt(self.temperature)
            .additional_params_opt(self.additional_params.clone()))
    }
@@ -206,12 +284,23 @@ impl<M: CompletionModel> Chat for Agent<M> {
/// .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    model: M,
    /// System prompt
    preamble: Option<String>,
    /// Context documents always available to the agent
    static_context: Vec<Document>,
    /// Tools that are always available to the agent (by name)
    static_tools: Vec<String>,
    temperature: Option<f64>,
    /// Additional parameters to be passed to the model
    additional_params: Option<serde_json::Value>,
    /// List of vector store, with the sample number
    dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Dynamic tools
    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
    /// Temperature of the model
    temperature: Option<f64>,
    /// Actual tool implementations
    tools: ToolSet,
}

@@ -224,13 +313,15 @@ impl<M: CompletionModel> AgentBuilder<M> {
            static_tools: vec![],
            temperature: None,
            additional_params: None,
            dynamic_context: vec![],
            dynamic_tools: vec![],
            tools: ToolSet::default(),
        }
    }

    /// Set the preamble of the agent
    pub fn preamble(mut self, doc: &str) -> Self {
        self.preamble = Some(doc.into());
    /// Set the system prompt
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

@@ -262,6 +353,31 @@ impl<M: CompletionModel> AgentBuilder<M> {
        self
    }

    /// Add some dynamic context to the agent. On each prompt, `sample` documents from the
    /// dynamic context will be inserted in the request.
    pub fn dynamic_context(
        mut self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + 'static,
    ) -> Self {
        self.dynamic_context
            .push((sample, Box::new(dynamic_context)));
        self
    }

    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
    /// dynamic toolset will be inserted in the request.
    pub fn dynamic_tools(
        mut self,
        sample: usize,
        dynamic_tools: impl VectorStoreIndexDyn + 'static,
        toolset: ToolSet,
    ) -> Self {
        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
        self.tools.add_tools(toolset);
        self
    }

    /// Set the temperature of the model
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
@@ -278,12 +394,14 @@ impl<M: CompletionModel> AgentBuilder<M> {
    pub fn build(self) -> Agent<M> {
        Agent {
            model: self.model,
            preamble: self.preamble.unwrap_or_else(|| "".into()),
            context: self.static_context,
            tools: self.tools,
            preamble: self.preamble.unwrap_or_default(),
            static_context: self.static_context,
            static_tools: self.static_tools,
            temperature: self.temperature,
            additional_params: self.additional_params,
            dynamic_context: self.dynamic_context,
            dynamic_tools: self.dynamic_tools,
            tools: self.tools,
        }
    }
}
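Beyond the `dynamic_context` call shown in the module docs, here is a rough usage sketch of the new `dynamic_tools` builder method (not part of this diff). `WordCounter` and `tool_index` are hypothetical placeholders, and the sketch assumes the pre-existing `ToolSet::from_tools` helper and the `context_rag_agent` constructor used in the docs above:

```rust
// Hypothetical sketch: RAG the most relevant tool definitions into each request.
// `WordCounter` (a Tool impl) and `tool_index` (a vector store index over tool
// embeddings) are illustrative names, not items defined in this PR.
let toolset = ToolSet::from_tools(vec![WordCounter]);

let agent = openai.context_rag_agent(openai::GPT_4O)
    .preamble("You are an assistant with access to a large library of specialized tools.")
    // On each prompt, the 2 most relevant tool definitions are pulled from
    // `tool_index` and their implementations are looked up in `toolset`.
    .dynamic_tools(2, tool_index, toolset)
    .build();

let response = agent.prompt("How many words are in this sentence?").await
    .expect("Failed to prompt the agent");
```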
Review comment: Is this not outdated? Should it not be `openai.agent`?
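The line in question is `openai.context_rag_agent(openai::GPT_4O)` in the module-level RAG example. If the agent constructors were indeed merged by this refactor, the example would presumably start from a single generic builder; a hedged sketch of what that would look like (assuming a `Client::agent` method returning the same `AgentBuilder`):

```rust
// Hypothetical form of the doc example, assuming `agent` replaces `context_rag_agent`.
let rag_agent = openai.agent(openai::GPT_4O)
    .preamble("You are a dictionary assistant ...")
    .dynamic_context(1, index)
    .build();
```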