diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs
new file mode 100644
index 00000000..1e0b9c48
--- /dev/null
+++ b/rig-core/src/providers/azure.rs
@@ -0,0 +1,592 @@
+//! Azure OpenAI API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::azure;
+//!
+//! let client = azure::Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
+//!
+//! let gpt4o = client.completion_model(azure::GPT_4O);
+//! ```
+use crate::{
+    agent::AgentBuilder,
+    completion::{self, CompletionError, CompletionRequest},
+    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
+    extractor::ExtractorBuilder,
+    json_utils, Embed,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+// ================================================================
+// Main Azure OpenAI Client
+// ================================================================
+
+#[derive(Clone)]
+pub struct Client {
+    api_version: String,
+    azure_endpoint: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    /// Creates a new Azure OpenAI client.
+    ///
+    /// # Arguments
+    ///
+    /// * `api_key` - Azure OpenAI API key required for authentication
+    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
+    /// * `azure_endpoint` - Azure OpenAI endpoint URL, e.g. `https://{your-resource-name}.openai.azure.com`
+    pub fn new(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
+        Self {
+            api_version: api_version.to_string(),
+            // Trim any trailing slash so the URLs built below don't contain `//`
+            azure_endpoint: azure_endpoint.trim_end_matches('/').to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert("api-key", api_key.parse().expect("API key should parse"));
+                    headers
+                })
+                .build()
+                .expect("Azure OpenAI reqwest client should build"),
+        }
+    }
+
+    /// Creates a new Azure OpenAI client from the `AZURE_API_KEY`, `AZURE_API_VERSION`,
+    /// and `AZURE_ENDPOINT` environment variables.
+    /// Panics if any of these environment variables are not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("AZURE_API_KEY").expect("AZURE_API_KEY not set");
+        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
+        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
+        Self::new(&api_key, &api_version, &azure_endpoint)
+    }
+
+    fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
+        let url = format!(
+            "{}/openai/deployments/{}/embeddings?api-version={}",
+            self.azure_endpoint, deployment_id, self.api_version
+        );
+        self.http_client.post(url)
+    }
+
+    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
+        let url = format!(
+            "{}/openai/deployments/{}/chat/completions?api-version={}",
+            self.azure_endpoint, deployment_id, self.api_version
+        );
+        self.http_client.post(url)
+    }
+
+    /// Create an embedding model with the given name.
+    /// Note: a default embedding dimension of 0 will be used if the model is not known.
+    /// If this is the case, it's better to use the `embedding_model_with_ndims` function.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::azure::{Client, self};
+    ///
+    /// // Initialize the Azure OpenAI client
+    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
+    ///
+    /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE);
+    /// ```
+    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
+        let ndims = match model {
+            TEXT_EMBEDDING_3_LARGE => 3072,
+            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
+            _ => 0,
+        };
+        EmbeddingModel::new(self.clone(), model, ndims)
+    }
+
+    /// Create an embedding model with the given name and the number of dimensions
+    /// in the embedding generated by the model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::azure::{Client, self};
+    ///
+    /// // Initialize the Azure OpenAI client
+    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
+    ///
+    /// let embedding_model = azure.embedding_model_with_ndims("model-unknown-to-rig", 3072);
+    /// ```
+    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, ndims)
+    }
+
+    /// Create an embedding builder with the given embedding model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::azure::{Client, self};
+    ///
+    /// // Initialize the Azure OpenAI client
+    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
+    ///
+    /// let embeddings = azure.embeddings(azure::TEXT_EMBEDDING_3_LARGE)
+    ///     .simple_document("doc0", "Hello, world!")
+    ///     .simple_document("doc1", "Goodbye, world!")
+    ///     .build()
+    ///     .await
+    ///     .expect("Failed to embed documents");
+    /// ```
+    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
+        EmbeddingsBuilder::new(self.embedding_model(model))
+    }
+
+    /// Create a completion model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::azure::{Client, self};
+    ///
+    /// // Initialize the Azure OpenAI client
+    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
+    ///
+    /// let gpt4 = azure.completion_model(azure::GPT_4);
+    /// ```
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    /// Create an agent builder with the given completion model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::azure::{Client, self};
+    ///
+    /// // Initialize the Azure OpenAI client
+    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
+    ///
+    /// let agent = azure.agent(azure::GPT_4)
+    ///     .preamble("You are comedian AI with a mission to make people laugh.")
+    ///     .temperature(0.0)
+    ///     .build();
+    /// ```
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    /// Create an extractor builder with the given completion model.
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct ApiErrorResponse {
+    message: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum ApiResponse<T> {
+    Ok(T),
+    Err(ApiErrorResponse),
+}
+
+// ================================================================
+// Azure OpenAI Embedding API
+// ================================================================
+/// `text-embedding-3-large` embedding model
+pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
+/// `text-embedding-3-small` embedding model
+pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
+/// `text-embedding-ada-002` embedding model
+pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingResponse {
+    pub object: String,
+    pub data: Vec<EmbeddingData>,
+    pub model: String,
+    pub usage: Usage,
+}
+
+impl From<ApiErrorResponse> for EmbeddingError {
+    fn from(err: ApiErrorResponse) -> Self {
+        EmbeddingError::ProviderError(err.message)
+    }
+}
+
+impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
+    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
+        match value {
+            ApiResponse::Ok(response) => Ok(response),
+            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingData {
+    pub object: String,
+    pub embedding: Vec<f64>,
+    pub index: usize,
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub total_tokens: usize,
+}
+
+impl std::fmt::Display for Usage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Prompt tokens: {} Total tokens: {}",
+            self.prompt_tokens, self.total_tokens
+        )
+    }
+}
+
+#[derive(Clone)]
+pub struct EmbeddingModel {
+    client: Client,
+    pub model: String,
+    ndims: usize,
+}
+
+impl embeddings::EmbeddingModel for EmbeddingModel {
+    const MAX_DOCUMENTS: usize = 1024;
+
+    fn ndims(&self) -> usize {
+        self.ndims
+    }
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn embed_texts(
+        &self,
+        documents: impl IntoIterator<Item = String> + Send,
+    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
+        let documents = documents.into_iter().collect::<Vec<_>>();
+
+        let response = self
+            .client
+            .post_embedding(&self.model)
+            .json(&json!({
+                "input": documents,
+            }))
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
+                ApiResponse::Ok(response) => {
+                    tracing::info!(target: "rig",
+                        "Azure embedding token usage: {}",
+                        response.usage
+                    );
+
+                    if response.data.len() != documents.len() {
+                        return Err(EmbeddingError::ResponseError(
+                            "Response data length does not match input length".into(),
+                        ));
+                    }
+
+                    Ok(response
+                        .data
+                        .into_iter()
+                        .zip(documents.into_iter())
+                        .map(|(embedding, document)| embeddings::Embedding {
+                            document,
+                            vec: embedding.embedding,
+                        })
+                        .collect())
+                }
+                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
+            }
+        } else {
+            Err(EmbeddingError::ProviderError(response.text().await?))
+        }
+    }
+}
+
+impl EmbeddingModel {
+    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+            ndims,
+        }
+    }
+}
+
+// ================================================================
+// Azure OpenAI Completion API
+// ================================================================
+/// `o1` completion model
+pub const O1: &str = "o1";
+/// `o1-preview` completion model
+pub const O1_PREVIEW: &str = "o1-preview";
+/// `o1-mini` completion model
+pub const O1_MINI: &str = "o1-mini";
+/// `gpt-4o` completion model
+pub const GPT_4O: &str = "gpt-4o";
+/// `gpt-4o-mini` completion model
+pub const GPT_4O_MINI: &str = "gpt-4o-mini";
+/// `gpt-4o-realtime-preview` completion model
+pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
+/// `gpt-4-turbo` completion model
+pub const GPT_4_TURBO: &str = "gpt-4-turbo";
+/// `gpt-4` completion model
+pub const GPT_4: &str = "gpt-4";
+/// `gpt-4-32k` completion model
+pub const GPT_4_32K: &str = "gpt-4-32k";
+/// `gpt-4-32k-0613` completion model
+pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
+/// `gpt-3.5-turbo` completion model
+pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
+/// `gpt-3.5-turbo-instruct` completion model
+pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
+/// `gpt-3.5-turbo-16k` completion model
+pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
+
+#[derive(Debug, Deserialize)]
+pub struct CompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub system_fingerprint: Option<String>,
+    pub choices: Vec<Choice>,
+    pub usage: Option<Usage>,
+}
+
+impl From<ApiErrorResponse> for CompletionError {
+    fn from(err: ApiErrorResponse) -> Self {
+        CompletionError::ProviderError(err.message)
+    }
+}
+
+impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
+    type Error = CompletionError;
+
+    fn try_from(value: CompletionResponse) -> Result<Self, Self::Error> {
+        match value.choices.as_slice() {
+            [Choice {
+                message:
+                    Message {
+                        tool_calls: Some(calls),
+                        ..
+                    },
+                ..
+            }, ..]
+                if !calls.is_empty() =>
+            {
+                // The `is_empty` guard above guarantees at least one call
+                let call = calls.first().unwrap();
+
+                Ok(completion::CompletionResponse {
+                    choice: completion::ModelChoice::ToolCall(
+                        call.function.name.clone(),
+                        "".to_owned(),
+                        serde_json::from_str(&call.function.arguments)?,
+                    ),
+                    raw_response: value,
+                })
+            }
+            [Choice {
+                message:
+                    Message {
+                        content: Some(content),
+                        ..
+                    },
+                ..
+            }, ..]
+            => Ok(completion::CompletionResponse {
+                choice: completion::ModelChoice::Message(content.to_string()),
+                raw_response: value,
+            }),
+            _ => Err(CompletionError::ResponseError(
+                "Response did not contain a message or tool call".into(),
+            )),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Choice {
+    pub index: usize,
+    pub message: Message,
+    pub logprobs: Option<serde_json::Value>,
+    pub finish_reason: String,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Message {
+    pub role: String,
+    pub content: Option<String>,
+    pub tool_calls: Option<Vec<ToolCall>>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ToolCall {
+    pub id: String,
+    pub r#type: String,
+    pub function: Function,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct ToolDefinition {
+    pub r#type: String,
+    pub function: completion::ToolDefinition,
+}
+
+impl From<completion::ToolDefinition> for ToolDefinition {
+    fn from(tool: completion::ToolDefinition) -> Self {
+        Self {
+            r#type: "function".into(),
+            function: tool,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Function {
+    pub name: String,
+    pub arguments: String,
+}
+
+#[derive(Clone)]
+pub struct CompletionModel {
+    client: Client,
+    /// Name of the model (e.g.: `gpt-4o-mini`)
+    pub model: String,
+}
+
+impl CompletionModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn completion(
+        &self,
+        mut completion_request: CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        // Add preamble to chat history (if available)
+        // NOTE: Azure o1-preview models do not support system messages
+        let mut full_history = if let Some(preamble) = &completion_request.preamble {
+            vec![completion::Message {
+                role: "system".into(),
+                content: preamble.clone(),
+            }]
+        } else {
+            vec![]
+        };
+
+        // Extend existing chat history
+        full_history.append(&mut completion_request.chat_history);
+
+        // Compose the prompt, attaching any context documents
+        let prompt_with_context = completion_request.prompt_with_context();
+
+        // Add the prompt (with context) to the chat history
+        full_history.push(completion::Message {
+            role: "user".into(),
+            content: prompt_with_context,
+        });
+
+        let request = if completion_request.tools.is_empty() {
+            json!({
+                "messages": full_history,
+                "temperature": completion_request.temperature,
+            })
+        } else {
+            json!({
+                "messages": full_history,
+                "temperature": completion_request.temperature,
+                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
+                "tool_choice": "auto",
+            })
+        };
+
+        let response = self
+            .client
+            .post_chat_completion(&self.model)
+            .json(
+                &if let Some(params) = completion_request.additional_params {
+                    json_utils::merge(request, params)
+                } else {
+                    request
+                },
+            )
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            match response.json::<ApiResponse<CompletionResponse>>().await? {
+                ApiResponse::Ok(response) => {
+                    tracing::info!(target: "rig",
+                        "Azure completion token usage: {:?}",
+                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
+                    );
+                    response.try_into()
+                }
+                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
+            }
+        } else {
+            Err(CompletionError::ProviderError(response.text().await?))
+        }
+    }
+}
+
+#[cfg(test)]
+mod azure_tests {
+    use super::*;
+
+    use crate::completion::CompletionModel;
+    use crate::embeddings::EmbeddingModel;
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_azure_embedding() {
+        let _ = tracing_subscriber::fmt::try_init();
+
+        let client = Client::from_env();
+        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
+        let embeddings = model
+            .embed_texts(vec!["Hello, world!".to_string()])
+            .await
+            .unwrap();
+
+        tracing::info!("Azure embedding: {:?}", embeddings);
+    }
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_azure_completion() {
+        let _ = tracing_subscriber::fmt::try_init();
+
+        let client = Client::from_env();
+        let model = client.completion_model(GPT_4O_MINI);
+        let completion = model
+            .completion(CompletionRequest {
+                preamble: Some("You are a helpful assistant.".to_string()),
+                chat_history: vec![],
+                prompt: "Hello, world!".to_string(),
+                documents: vec![],
+                max_tokens: Some(100),
+                temperature: Some(0.0),
+                tools: vec![],
+                additional_params: None,
+            })
+            .await
+            .unwrap();
+
+        tracing::info!("Azure completion: {:?}", completion);
+    }
+}
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 2ac67f9a..d0260000 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -9,6 +9,7 @@
 //! - xAI
 //! - EternalAI
 //! - DeepSeek
+//! - Azure OpenAI
 //!
 //! Each provider has its own module, which contains a `Client` implementation that can
 //! be used to initialize completion and embedding models and execute requests to those models.
@@ -44,6 +45,7 @@
 //! Note: The example above uses the OpenAI provider client, but the same pattern can
 //! be used with the Cohere provider client.
 pub mod anthropic;
+pub mod azure;
 pub mod cohere;
 pub mod deepseek;
 pub mod eternalai;
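A minimal end-to-end sketch of how downstream code would use the new provider, assembled from the doc examples and environment variables defined in this diff. It assumes rig's `Prompt` trait for the final call and a tokio runtime; on Azure, the model string passed to `agent` must match the name of a deployment in your resource, so `azure::GPT_4O` only works if your deployment is named accordingly.

```rust
use rig::{completion::Prompt, providers::azure};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Reads AZURE_API_KEY, AZURE_API_VERSION, and AZURE_ENDPOINT (panics if unset)
    let client = azure::Client::from_env();

    // Build an agent on top of the Azure completion model
    let comedian = client
        .agent(azure::GPT_4O)
        .preamble("You are comedian AI with a mission to make people laugh.")
        .temperature(0.9)
        .build();

    // Send a single prompt and print the model's reply
    let response = comedian.prompt("Tell me a joke about Rust.").await?;
    println!("{response}");
    Ok(())
}
```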