-
Notifications
You must be signed in to change notification settings - Fork 324
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: support together ai * feat: support mistral ai * feat: support together ai * chore: format codes * chore: refactor together ai support * chore: format codes * feat: fine tune together ai's embedding model * feat: add more chat models and embedding models of together AI * chore: format codes * chore: add tool example --------- Co-authored-by: Joshua Mo <[email protected]>
- Loading branch information
1 parent
9128188
commit f3529ac
Showing
8 changed files
with
759 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
use rig::{ | ||
agent::AgentBuilder, | ||
completion::{Prompt, ToolDefinition}, | ||
providers::together, | ||
tool::Tool, | ||
}; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::json; | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), anyhow::Error> { | ||
println!("Running basic agent with together"); | ||
basic().await?; | ||
|
||
println!("\nRunning tools agent with tools"); | ||
tools().await?; | ||
|
||
println!("\nRunning together agent with context"); | ||
context().await?; | ||
|
||
println!("\n\nAll agents ran successfully"); | ||
Ok(()) | ||
} | ||
|
||
async fn basic() -> Result<(), anyhow::Error> { | ||
let together_ai_client = together::Client::new( | ||
&std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set"), | ||
); | ||
|
||
// Choose a model, replace "together-model-v1" with an actual Together AI model name | ||
let model = | ||
together_ai_client.completion_model(rig::providers::together::MIXTRAL_8X7B_INSTRUCT_V0_1); | ||
|
||
let agent = AgentBuilder::new(model) | ||
.preamble("You are a comedian here to entertain the user using humour and jokes.") | ||
.build(); | ||
|
||
// Prompt the agent and print the response | ||
let response = agent.prompt("Entertain me!").await?; | ||
println!("{}", response); | ||
|
||
Ok(()) | ||
} | ||
|
||
async fn tools() -> Result<(), anyhow::Error> { | ||
// Create Together AI client | ||
let together_ai_client = together::Client::new( | ||
&std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set"), | ||
); | ||
|
||
// Choose a model, replace "together-model-v1" with an actual Together AI model name | ||
let model = | ||
together_ai_client.completion_model(rig::providers::together::MIXTRAL_8X7B_INSTRUCT_V0_1); | ||
|
||
// Create an agent with multiple context documents | ||
let calculator_agent = AgentBuilder::new(model) | ||
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.") | ||
.tool(Adder) | ||
.build(); | ||
|
||
// Prompt the agent and print the response | ||
println!("Calculate 5 + 3"); | ||
println!( | ||
"Calculator Agent: {}", | ||
calculator_agent.prompt("Calculate 5 + 3").await? | ||
); | ||
|
||
Ok(()) | ||
} | ||
|
||
async fn context() -> Result<(), anyhow::Error> { | ||
// Create Together AI client | ||
let together_ai_client = together::Client::new( | ||
&std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set"), | ||
); | ||
|
||
// Choose a model, replace "together-model-v1" with an actual Together AI model name | ||
let model = | ||
together_ai_client.completion_model(rig::providers::together::MIXTRAL_8X7B_INSTRUCT_V0_1); | ||
|
||
// Create an agent with multiple context documents | ||
let agent = AgentBuilder::new(model) | ||
.context("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") | ||
.context("Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") | ||
.context("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") | ||
.build(); | ||
|
||
// Prompt the agent and print the response | ||
let response = agent.prompt("What does \"glarb-glarb\" mean?").await?; | ||
|
||
println!("{}", response); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
struct OperationArgs { | ||
x: i32, | ||
y: i32, | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
#[error("Math error")] | ||
struct MathError; | ||
|
||
#[derive(Deserialize, Serialize)] | ||
struct Adder; | ||
impl Tool for Adder { | ||
const NAME: &'static str = "add"; | ||
|
||
type Error = MathError; | ||
type Args = OperationArgs; | ||
type Output = i32; | ||
|
||
async fn definition(&self, _prompt: String) -> ToolDefinition { | ||
ToolDefinition { | ||
name: "add".to_string(), | ||
description: "Add x and y together".to_string(), | ||
parameters: json!({ | ||
"type": "object", | ||
"properties": { | ||
"x": { | ||
"type": "number", | ||
"description": "The first number to add" | ||
}, | ||
"y": { | ||
"type": "number", | ||
"description": "The second number to add" | ||
} | ||
} | ||
}), | ||
} | ||
} | ||
|
||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { | ||
println!("The args: {:?}", args); | ||
let result = args.x + args.y; | ||
Ok(result) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
use rig::providers::together; | ||
use rig::Embed; | ||
|
||
#[derive(Embed, Debug)] | ||
struct Greetings { | ||
#[embed] | ||
message: String, | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), anyhow::Error> { | ||
// Initialize the together client | ||
let client = together::Client::from_env(); | ||
|
||
let embeddings = client | ||
.embeddings(together::embedding::M2_BERT_80M_8K_RETRIEVAL) | ||
.document(Greetings { | ||
message: "Hello, world!".to_string(), | ||
})? | ||
.document(Greetings { | ||
message: "Goodbye, world!".to_string(), | ||
})? | ||
.build() | ||
.await | ||
.expect("Failed to embed documents"); | ||
|
||
println!("{:?}", embeddings); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,4 +55,5 @@ pub mod hyperbolic; | |
pub mod moonshot; | ||
pub mod openai; | ||
pub mod perplexity; | ||
pub mod together; | ||
pub mod xai; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
use crate::{ | ||
agent::AgentBuilder, | ||
embeddings::{self}, | ||
extractor::ExtractorBuilder, | ||
Embed, | ||
}; | ||
use schemars::JsonSchema; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use super::{completion::CompletionModel, embedding::EmbeddingModel, M2_BERT_80M_8K_RETRIEVAL}; | ||
|
||
// ================================================================ | ||
// Together AI Client | ||
// ================================================================ | ||
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz"; | ||
|
||
#[derive(Clone)] | ||
pub struct Client { | ||
base_url: String, | ||
http_client: reqwest::Client, | ||
} | ||
|
||
impl Client { | ||
/// Create a new Together AI client with the given API key. | ||
pub fn new(api_key: &str) -> Self { | ||
Self::from_url(api_key, TOGETHER_AI_BASE_URL) | ||
} | ||
|
||
fn from_url(api_key: &str, base_url: &str) -> Self { | ||
Self { | ||
base_url: base_url.to_string(), | ||
http_client: reqwest::Client::builder() | ||
.default_headers({ | ||
let mut headers = reqwest::header::HeaderMap::new(); | ||
headers.insert( | ||
reqwest::header::CONTENT_TYPE, | ||
"application/json".parse().unwrap(), | ||
); | ||
headers.insert( | ||
"Authorization", | ||
format!("Bearer {}", api_key) | ||
.parse() | ||
.expect("Bearer token should parse"), | ||
); | ||
headers | ||
}) | ||
.build() | ||
.expect("Together AI reqwest client should build"), | ||
} | ||
} | ||
|
||
/// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable. | ||
/// Panics if the environment variable is not set. | ||
pub fn from_env() -> Self { | ||
let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set"); | ||
Self::new(&api_key) | ||
} | ||
|
||
pub fn post(&self, path: &str) -> reqwest::RequestBuilder { | ||
let url = format!("{}/{}", self.base_url, path).replace("//", "/"); | ||
|
||
tracing::debug!("POST {}", url); | ||
self.http_client.post(url) | ||
} | ||
|
||
/// Create an embedding model with the given name. | ||
/// Note: default embedding dimension of 0 will be used if model is not known. | ||
/// If this is the case, it's better to use function `embedding_model_with_ndims` | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// use rig::providers::together_ai::{Client, self}; | ||
/// | ||
/// // Initialize the Together AI client | ||
/// let together_ai = Client::new("your-together-ai-api-key"); | ||
/// | ||
/// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1); | ||
/// ``` | ||
pub fn embedding_model(&self, model: &str) -> EmbeddingModel { | ||
let ndims = match model { | ||
M2_BERT_80M_8K_RETRIEVAL => 8192, | ||
_ => 0, | ||
}; | ||
EmbeddingModel::new(self.clone(), model, ndims) | ||
} | ||
|
||
/// Create an embedding model with the given name and the number of dimensions in the embedding | ||
/// generated by the model. | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// use rig::providers::together_ai::{Client, self}; | ||
/// | ||
/// // Initialize the Together AI client | ||
/// let together_ai = Client::new("your-together-ai-api-key"); | ||
/// | ||
/// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024); | ||
/// ``` | ||
pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { | ||
EmbeddingModel::new(self.clone(), model, ndims) | ||
} | ||
|
||
/// Create an embedding builder with the given embedding model. | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// use rig::providers::together_ai::{Client, self}; | ||
/// | ||
/// // Initialize the Together AI client | ||
/// let together_ai = Client::new("your-together-ai-api-key"); | ||
/// | ||
/// let embeddings = together_ai.embeddings(together_ai::embedding::EMBEDDING_V1) | ||
/// .simple_document("doc0", "Hello, world!") | ||
/// .simple_document("doc1", "Goodbye, world!") | ||
/// .build() | ||
/// .await | ||
/// .expect("Failed to embed documents"); | ||
/// ``` | ||
pub fn embeddings<D: Embed>( | ||
&self, | ||
model: &str, | ||
) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> { | ||
embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) | ||
} | ||
|
||
/// Create a completion model with the given name. | ||
pub fn completion_model(&self, model: &str) -> CompletionModel { | ||
CompletionModel::new(self.clone(), model) | ||
} | ||
|
||
/// Create an agent builder with the given completion model. | ||
/// # Example | ||
/// ``` | ||
/// use rig::providers::together_ai::{Client, self}; | ||
/// | ||
/// // Initialize the Together AI client | ||
/// let together_ai = Client::new("your-together-ai-api-key"); | ||
/// | ||
/// let agent = together_ai.agent(together_ai::completion::MODEL_NAME) | ||
/// .preamble("You are comedian AI with a mission to make people laugh.") | ||
/// .temperature(0.0) | ||
/// .build(); | ||
/// ``` | ||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> { | ||
AgentBuilder::new(self.completion_model(model)) | ||
} | ||
|
||
/// Create an extractor builder with the given completion model. | ||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>( | ||
&self, | ||
model: &str, | ||
) -> ExtractorBuilder<T, CompletionModel> { | ||
ExtractorBuilder::new(self.completion_model(model)) | ||
} | ||
} | ||
|
||
pub mod together_ai_api_types { | ||
use serde::Deserialize; | ||
|
||
impl ApiErrorResponse { | ||
pub fn message(&self) -> String { | ||
format!("Code `{}`: {}", self.code, self.error) | ||
} | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
pub struct ApiErrorResponse { | ||
pub error: String, | ||
pub code: String, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
#[serde(untagged)] | ||
pub enum ApiResponse<T> { | ||
Ok(T), | ||
Error(ApiErrorResponse), | ||
} | ||
} |
Oops, something went wrong.