Refactor: Merge and generalize Agent types #21

Merged: 3 commits, Sep 17, 2024
2 changes: 1 addition & 1 deletion rig-core/examples/calculator_chatbot.rs
@@ -257,7 +257,7 @@ async fn main() -> Result<(), anyhow::Error> {

// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client
.tool_rag_agent("gpt-4")
.agent("gpt-4")
.preamble(
"You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
Follow these instructions closely.
5 changes: 2 additions & 3 deletions rig-core/examples/multi_agent.rs
@@ -4,7 +4,6 @@ use rig::{
agent::{Agent, AgentBuilder},
cli_chatbot::cli_chatbot,
completion::{Chat, CompletionModel, Message, PromptError},
model::{Model, ModelBuilder},
providers::openai::Client as OpenAIClient,
};

@@ -14,7 +13,7 @@ use rig::{
/// prompt in english, before answering it with GPT-4. The answer in english is returned.
struct EnglishTranslator<M: CompletionModel> {
translator_agent: Agent<M>,
gpt4: Model<M>,
gpt4: Agent<M>,
}

impl<M: CompletionModel> EnglishTranslator<M> {
@@ -29,7 +28,7 @@ impl<M: CompletionModel> EnglishTranslator<M> {
.build(),

// Create the GPT4 model
gpt4: ModelBuilder::new(model).build()
gpt4: AgentBuilder::new(model).build()
}
}
}
2 changes: 1 addition & 1 deletion rig-core/examples/rag.rs
@@ -30,7 +30,7 @@ async fn main() -> Result<(), anyhow::Error> {
// Create vector store index
let index = vector_store.index(embedding_model);

let rag_agent = openai_client.context_rag_agent("gpt-4")
let rag_agent = openai_client.agent("gpt-4")
.preamble("
You are a dictionary assistant here to assist the user in understanding the meaning of words.
You will find additional non-standard word definitions that could be useful below.
2 changes: 1 addition & 1 deletion rig-core/examples/rag_dynamic_tools.rs
@@ -170,7 +170,7 @@ async fn main() -> Result<(), anyhow::Error> {

// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client
.tool_rag_agent("gpt-4")
.agent("gpt-4")
.preamble("You are a calculator here to help the user perform arithmetic operations.")
// Add a dynamic tool source with a sample rate of 1 (i.e.: only
// 1 additional tool will be added to prompts)
2 changes: 1 addition & 1 deletion rig-core/examples/simple_model.rs
@@ -5,7 +5,7 @@ async fn main() {
// Create OpenAI client and model
let openai_client = openai::Client::from_env();

let gpt4 = openai_client.model("gpt-4").build();
let gpt4 = openai_client.agent("gpt-4").build();

// Prompt the model and print its response
let response = gpt4
216 changes: 167 additions & 49 deletions rig-core/src/agent.rs
@@ -1,8 +1,16 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of static tools. The agent can be used to interact with the LLM model
//! by providing prompts and chat history without having to provide the preamble and other parameters every time.
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completion responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
//! be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, static tools, temperature, and additional parameters
@@ -52,16 +60,63 @@
//! .await
//! .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//! completion::Prompt,
//! embeddings::EmbeddingsBuilder,
//! providers::openai,
//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//! .build()
//! .await
//! .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//! .await
//! .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let rag_agent = openai.context_rag_agent(openai::GPT_4O)

Review comment (Contributor): Is this not outdated? Should it not be openai.agent?

//! .preamble("
//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
//! You will find additional non-standard word definitions that could be useful below.
//! ")
//! .dynamic_context(1, index)
//! .build();
//!
//! // Prompt the agent and print the response
//! let response = rag_agent.prompt("What does \"glarb-glarb\" mean?").await
//! .expect("Failed to prompt the agent");
//! ```
use std::collections::HashMap;

use futures::{stream, StreamExt};
use futures::{stream, StreamExt, TryStreamExt};

use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
},
tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
@@ -85,52 +140,24 @@ use crate::{
/// .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's `gpt-3.5-turbo-1106`, Cohere's `command-r`)
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: String,
/// Context documents always available to the agent
context: Vec<Document>,
static_context: Vec<Document>,
/// Tools that are always available to the agent (identified by their name)
static_tools: Vec<String>,
/// Temperature of the model
temperature: Option<f64>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// Vector store indices, each with the number of documents to sample per prompt
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Actual tool implementations
tools: ToolSet,
}

impl<M: CompletionModel> Agent<M> {
/// Create a new Agent
pub fn new(
model: M,
preamble: String,
static_context: Vec<String>,
static_tools: Vec<impl Tool + 'static>,
temperature: Option<f64>,
additional_params: Option<serde_json::Value>,
) -> Self {
let static_tools_ids = static_tools.iter().map(|tool| tool.name()).collect();

Self {
model,
preamble,
context: static_context
.into_iter()
.enumerate()
.map(|(i, doc)| Document {
id: format!("static_doc_{}", i),
text: doc,
additional_props: HashMap::new(),
})
.collect(),
tools: ToolSet::from_tools(static_tools),
static_tools: static_tools_ids,
temperature,
additional_params,
}
}
pub tools: ToolSet,
}

impl<M: CompletionModel> Completion<M> for Agent<M> {
@@ -139,12 +166,63 @@
prompt: &str,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let tool_definitions = stream::iter(self.static_tools.iter())
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_from_query(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, doc)| {
let doc_text = serde_json::to_string_pretty(&doc.document)
.unwrap_or_else(|_| doc.document.to_string());

Review comment (Contributor): Is there a reason why this is prettified, or is this standard?

Document {
id: doc.id,
text: doc_text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids_from_query(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, doc)| doc)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(prompt.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(prompt.into()).await)
} else {
tracing::error!(target: "rig", "Agent static tool {} not found", toolname);
tracing::warn!("Tool implementation not found in toolset: {}", toolname);
None
}
})
@@ -156,8 +234,8 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
.completion_request(prompt)
.preamble(self.preamble.clone())
.messages(chat_history)
.documents(self.context.clone())
.tools(tool_definitions.clone())
.documents([self.static_context.clone(), dynamic_context].concat())
.tools([static_tools.clone(), dynamic_tools].concat())
.temperature_opt(self.temperature)
.additional_params_opt(self.additional_params.clone()))
}
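The two folds above gather per-prompt samples, and the final request concatenates them with the always-on static sets. Below is a minimal, self-contained sketch of that static-plus-dynamic merge; the `Document`, `DynIndex`, and `KeywordIndex` types are hypothetical stand-ins for illustration, not rig's actual API.

```rust
#[derive(Clone, Debug, PartialEq)]
struct Document {
    id: String,
    text: String,
}

// A dynamic source samples up to `n` documents relevant to the prompt.
trait DynIndex {
    fn top_n(&self, prompt: &str, n: usize) -> Vec<Document>;
}

struct KeywordIndex(Vec<Document>);

impl DynIndex for KeywordIndex {
    fn top_n(&self, prompt: &str, n: usize) -> Vec<Document> {
        // Toy relevance: keep documents sharing a whitespace-separated
        // word with the prompt (a real index would use embeddings).
        self.0
            .iter()
            .filter(|d| prompt.split_whitespace().any(|w| d.text.contains(w)))
            .take(n)
            .cloned()
            .collect()
    }
}

// Mirrors `[self.static_context.clone(), dynamic_context].concat()` in the
// diff: static documents always ship; dynamic ones are sampled per prompt.
fn build_context(
    static_context: &[Document],
    dynamic: &[(usize, Box<dyn DynIndex>)],
    prompt: &str,
) -> Vec<Document> {
    let sampled: Vec<Document> = dynamic
        .iter()
        .flat_map(|(n, index)| index.top_n(prompt, *n))
        .collect();
    [static_context.to_vec(), sampled].concat()
}

fn main() {
    let statics = vec![Document {
        id: "static_doc_0".into(),
        text: "Always included".into(),
    }];
    let index = KeywordIndex(vec![Document {
        id: "doc0".into(),
        text: "Definition of a flurbo".into(),
    }]);
    let dynamic: Vec<(usize, Box<dyn DynIndex>)> = vec![(1, Box::new(index))];
    let ctx = build_context(&statics, &dynamic, "flurbo");
    assert_eq!(ctx.len(), 2);
}
```

The real implementation does this asynchronously over boxed `VectorStoreIndexDyn` values and applies the same static-plus-sampled pattern to tool definitions.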
@@ -206,12 +284,23 @@ impl<M: CompletionModel> Chat for Agent<M> {
/// .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: Option<String>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (by name)
static_tools: Vec<String>,
temperature: Option<f64>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// Vector store indices, each with the number of documents to sample per prompt
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Temperature of the model
temperature: Option<f64>,
/// Actual tool implementations
tools: ToolSet,
}

@@ -224,13 +313,15 @@ impl<M: CompletionModel> AgentBuilder<M> {
static_tools: vec![],
temperature: None,
additional_params: None,
dynamic_context: vec![],
dynamic_tools: vec![],
tools: ToolSet::default(),
}
}

/// Set the preamble of the agent
pub fn preamble(mut self, doc: &str) -> Self {
self.preamble = Some(doc.into());
/// Set the system prompt
pub fn preamble(mut self, preamble: &str) -> Self {
self.preamble = Some(preamble.into());
self
}

@@ -262,6 +353,31 @@ impl<M: CompletionModel> AgentBuilder<M> {
self
}

/// Add some dynamic context to the agent. On each prompt, `sample` documents from the
/// dynamic context will be inserted in the request.
pub fn dynamic_context(
mut self,
sample: usize,
dynamic_context: impl VectorStoreIndexDyn + 'static,
) -> Self {
self.dynamic_context
.push((sample, Box::new(dynamic_context)));
self
}

/// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
/// dynamic toolset will be inserted in the request.
pub fn dynamic_tools(
mut self,
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + 'static,
toolset: ToolSet,
) -> Self {
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
self.tools.add_tools(toolset);
self
}

/// Set the temperature of the model
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
@@ -278,12 +394,14 @@ impl<M: CompletionModel> AgentBuilder<M> {
pub fn build(self) -> Agent<M> {
Agent {
model: self.model,
preamble: self.preamble.unwrap_or_else(|| "".into()),
context: self.static_context,
tools: self.tools,
preamble: self.preamble.unwrap_or_default(),
static_context: self.static_context,
static_tools: self.static_tools,
temperature: self.temperature,
additional_params: self.additional_params,
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
tools: self.tools,
}
}
}
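The builder accumulates dynamic sources alongside the static configuration and only assembles the agent in `build`. Here is a trimmed-down, self-contained sketch of that shape; the types are hypothetical simplifications, with a `String` name standing in for a boxed vector store index.

```rust
// Hypothetical, trimmed-down mirror of the AgentBuilder in this diff.
#[derive(Default)]
struct AgentBuilder {
    preamble: Option<String>,
    static_context: Vec<String>,
    dynamic_context: Vec<(usize, String)>,
}

#[derive(Debug)]
struct Agent {
    preamble: String,
    static_context: Vec<String>,
    dynamic_context: Vec<(usize, String)>,
}

impl AgentBuilder {
    fn new() -> Self {
        Self::default()
    }

    fn preamble(mut self, preamble: &str) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    fn context(mut self, doc: &str) -> Self {
        self.static_context.push(doc.into());
        self
    }

    // Each call pushes a (sample size, source) pair, as in the diff.
    fn dynamic_context(mut self, sample: usize, source: &str) -> Self {
        self.dynamic_context.push((sample, source.into()));
        self
    }

    fn build(self) -> Agent {
        Agent {
            // A missing preamble defaults to "", matching `unwrap_or_default()`.
            preamble: self.preamble.unwrap_or_default(),
            static_context: self.static_context,
            dynamic_context: self.dynamic_context,
        }
    }
}

fn main() {
    let agent = AgentBuilder::new()
        .preamble("You are a dictionary assistant.")
        .dynamic_context(1, "definitions_index")
        .build();
    assert_eq!(agent.dynamic_context.len(), 1);
}
```

Consuming `self` and returning it from each setter is what lets callers chain `.preamble(..).dynamic_context(..).build()` in one expression, as the updated examples in this PR do.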