Skip to content

Commit

Permalink
feat: support together AI (#230)
Browse files Browse the repository at this point in the history
* feat: support together ai

* feat: support mistral ai

* feat: support together ai

* chore: format codes

* chore: refactor together ai support

* chore: format codes

* feat: fine tune together ai's embedding model

* feat: add more chat models and embedding models of together AI

* chore: format codes

* chore: add tool example

---------

Co-authored-by: Joshua Mo <[email protected]>
  • Loading branch information
threewebcode and joshua-mo-143 authored Feb 20, 2025
1 parent 9128188 commit f3529ac
Show file tree
Hide file tree
Showing 8 changed files with 759 additions and 0 deletions.
8 changes: 8 additions & 0 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,11 @@ required-features = ["derive"]
[[example]]
name = "agent_with_moonshot"
required-features = ["derive"]

# Agent demo wired to the Together AI provider (requires TOGETHER_API_KEY at runtime).
[[example]]
name = "agent_with_together"
required-features = ["derive"]

# Embedding demo using Together AI's M2-BERT retrieval model (requires TOGETHER_API_KEY).
[[example]]
name = "together_embeddings"
required-features = ["derive"]
140 changes: 140 additions & 0 deletions rig-core/examples/agent_with_together.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
use rig::{
agent::AgentBuilder,
completion::{Prompt, ToolDefinition},
providers::together,
tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
println!("Running basic agent with together");
basic().await?;

println!("\nRunning tools agent with tools");
tools().await?;

println!("\nRunning together agent with context");
context().await?;

println!("\n\nAll agents ran successfully");
Ok(())
}

/// Minimal demo: one completion model, a comedian preamble, a single prompt.
async fn basic() -> Result<(), anyhow::Error> {
    // Client credentials come from the environment.
    let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
    let client = together::Client::new(&api_key);

    // Pick one of the chat models exposed by the provider module.
    let model = client.completion_model(rig::providers::together::MIXTRAL_8X7B_INSTRUCT_V0_1);

    let comedian = AgentBuilder::new(model)
        .preamble("You are a comedian here to entertain the user using humour and jokes.")
        .build();

    // Prompt the agent and echo the model's reply.
    println!("{}", comedian.prompt("Entertain me!").await?);

    Ok(())
}

/// Tool-calling demo: the agent is handed the `Adder` tool and asked to use it.
async fn tools() -> Result<(), anyhow::Error> {
    // Client credentials come from the environment.
    let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
    let client = together::Client::new(&api_key);

    // Pick one of the chat models exposed by the provider module.
    let model = client.completion_model(rig::providers::together::MIXTRAL_8X7B_INSTRUCT_V0_1);

    // Equip the agent with the calculator tool so it can perform arithmetic itself.
    let calculator_agent = AgentBuilder::new(model)
        .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
        .tool(Adder)
        .build();

    // Prompt the agent and echo the model's reply.
    println!("Calculate 5 + 3");
    let answer = calculator_agent.prompt("Calculate 5 + 3").await?;
    println!("Calculator Agent: {}", answer);

    Ok(())
}

/// RAG-style demo: static context documents are attached to the agent and
/// the prompt asks about a term defined only in that context.
async fn context() -> Result<(), anyhow::Error> {
    // Client credentials come from the environment.
    let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
    let client = together::Client::new(&api_key);

    // Pick one of the chat models exposed by the provider module.
    let model = client.completion_model(rig::providers::together::MIXTRAL_8X7B_INSTRUCT_V0_1);

    // Seed the agent with made-up definitions it can only answer from.
    let agent = AgentBuilder::new(model)
        .context("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
        .context("Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
        .context("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
        .build();

    // The answer should come from the second context document.
    let response = agent.prompt("What does \"glarb-glarb\" mean?").await?;

    println!("{}", response);

    Ok(())
}

/// Arguments the model supplies when it invokes the `add` tool,
/// deserialized from the tool-call JSON payload.
#[derive(Debug, Deserialize)]
struct OperationArgs {
    // First addend.
    x: i32,
    // Second addend.
    y: i32,
}

/// Error type returned by the calculator tool; its `Display`
/// message is the fixed string "Math error".
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";

type Error = MathError;
type Args = OperationArgs;
type Output = i32;

async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
}
}),
}
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("The args: {:?}", args);
let result = args.x + args.y;
Ok(result)
}
}
30 changes: 30 additions & 0 deletions rig-core/examples/together_embeddings.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use rig::providers::together;
use rig::Embed;

/// Minimal embeddable document type: only the `message` field
/// (marked with `#[embed]`) is sent to the embedding model.
#[derive(Embed, Debug)]
struct Greetings {
    #[embed]
    message: String,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Initialize the together client from the TOGETHER_API_KEY env var
    // (from_env panics if the variable is unset).
    let client = together::Client::from_env();

    // Embed two small documents with the M2-BERT retrieval model.
    // Propagate failures with `?` instead of panicking via `expect`,
    // since main already returns a Result.
    let embeddings = client
        .embeddings(together::embedding::M2_BERT_80M_8K_RETRIEVAL)
        .document(Greetings {
            message: "Hello, world!".to_string(),
        })?
        .document(Greetings {
            message: "Goodbye, world!".to_string(),
        })?
        .build()
        .await?;

    println!("{:?}", embeddings);

    Ok(())
}
1 change: 1 addition & 0 deletions rig-core/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,5 @@ pub mod hyperbolic;
pub mod moonshot;
pub mod openai;
pub mod perplexity;
pub mod together;
pub mod xai;
178 changes: 178 additions & 0 deletions rig-core/src/providers/together/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use crate::{
agent::AgentBuilder,
embeddings::{self},
extractor::ExtractorBuilder,
Embed,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::{completion::CompletionModel, embedding::EmbeddingModel, M2_BERT_80M_8K_RETRIEVAL};

// ================================================================
// Together AI Client
// ================================================================
// Default endpoint for the Together AI REST API.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";

/// Together AI client: a base URL plus a reqwest client pre-configured
/// (in `from_url`) with the bearer token and JSON content-type headers.
/// Cloning is cheap — `reqwest::Client` is internally reference-counted.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    http_client: reqwest::Client,
}

impl Client {
    /// Create a new Together AI client with the given API key.
    pub fn new(api_key: &str) -> Self {
        Self::from_url(api_key, TOGETHER_AI_BASE_URL)
    }

    /// Build a client targeting `base_url`, attaching the bearer token and a
    /// JSON content type to every request via default headers.
    fn from_url(api_key: &str, base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            http_client: reqwest::Client::builder()
                .default_headers({
                    let mut headers = reqwest::header::HeaderMap::new();
                    headers.insert(
                        reqwest::header::CONTENT_TYPE,
                        "application/json".parse().unwrap(),
                    );
                    headers.insert(
                        "Authorization",
                        format!("Bearer {}", api_key)
                            .parse()
                            .expect("Bearer token should parse"),
                    );
                    headers
                })
                .build()
                .expect("Together AI reqwest client should build"),
        }
    }

    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
    pub fn from_env() -> Self {
        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
        Self::new(&api_key)
    }

    /// Start a POST request to `path`, joined onto the configured base URL.
    ///
    /// Slashes are trimmed at the join point so that a trailing slash on the
    /// base URL or a leading slash on the path never produces a doubled
    /// separator. (A global `.replace("//", "/")` over the full URL would
    /// also mangle the `https://` scheme separator, so it must be avoided.)
    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );

        tracing::debug!("POST {}", url);
        self.http_client.post(url)
    }

    /// Create an embedding model with the given name.
    /// Note: default embedding dimension of 0 will be used if model is not known.
    /// If this is the case, it's better to use function `embedding_model_with_ndims`
    ///
    /// # Example
    /// ```no_run
    /// use rig::providers::together::{self, Client};
    ///
    /// // Initialize the Together AI client
    /// let together_client = Client::new("your-together-ai-api-key");
    ///
    /// let embedding_model = together_client.embedding_model(together::M2_BERT_80M_8K_RETRIEVAL);
    /// ```
    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
        // Known models get their documented dimension; unknown models fall
        // back to 0 and should use `embedding_model_with_ndims` instead.
        let ndims = match model {
            M2_BERT_80M_8K_RETRIEVAL => 8192,
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Create an embedding model with the given name and the number of dimensions in the embedding
    /// generated by the model.
    ///
    /// # Example
    /// ```no_run
    /// use rig::providers::together::Client;
    ///
    /// // Initialize the Together AI client
    /// let together_client = Client::new("your-together-ai-api-key");
    ///
    /// let embedding_model = together_client.embedding_model_with_ndims("model-unknown-to-rig", 1024);
    /// ```
    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Create an embedding builder with the given embedding model.
    ///
    /// # Example
    /// ```no_run
    /// use rig::providers::together::{self, Client};
    ///
    /// // Initialize the Together AI client
    /// let together_client = Client::new("your-together-ai-api-key");
    ///
    /// let builder = together_client
    ///     .embeddings(together::M2_BERT_80M_8K_RETRIEVAL)
    ///     .document("Hello, world!".to_string())
    ///     .expect("should queue document")
    ///     .document("Goodbye, world!".to_string())
    ///     .expect("should queue document");
    /// ```
    pub fn embeddings<D: Embed>(
        &self,
        model: &str,
    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
    }

    /// Create a completion model with the given name.
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Create an agent builder with the given completion model.
    /// # Example
    /// ```no_run
    /// use rig::providers::together::{self, Client};
    ///
    /// // Initialize the Together AI client
    /// let together_client = Client::new("your-together-ai-api-key");
    ///
    /// let agent = together_client.agent(together::MIXTRAL_8X7B_INSTRUCT_V0_1)
    ///    .preamble("You are comedian AI with a mission to make people laugh.")
    ///    .temperature(0.0)
    ///    .build();
    /// ```
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Create an extractor builder with the given completion model.
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }
}

/// Serde representations of Together AI's wire-level response shapes.
pub mod together_ai_api_types {
    use serde::Deserialize;

    /// Error payload returned by the Together AI API.
    #[derive(Debug, Deserialize)]
    pub struct ApiErrorResponse {
        pub error: String,
        pub code: String,
    }

    impl ApiErrorResponse {
        /// Render the error code and message as one human-readable string.
        pub fn message(&self) -> String {
            format!("Code `{}`: {}", self.code, self.error)
        }
    }

    /// A response that is either the expected payload `T` or an API error;
    /// `untagged` lets serde decide by which shape deserializes.
    #[derive(Debug, Deserialize)]
    #[serde(untagged)]
    pub enum ApiResponse<T> {
        Ok(T),
        Error(ApiErrorResponse),
    }
}
Loading

0 comments on commit f3529ac

Please sign in to comment.