From 3068982eed3e43237f893b936b77527d7bcbabb7 Mon Sep 17 00:00:00 2001 From: Vaibhav Satija Date: Fri, 7 Feb 2025 21:24:02 +0530 Subject: [PATCH 1/6] feat: Add Mira AI provider integration --- rig-core/src/providers/mira.rs | 220 +++++++++++++++++++++++++++++++++ rig-core/src/providers/mod.rs | 2 + 2 files changed, 222 insertions(+) create mode 100644 rig-core/src/providers/mira.rs diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs new file mode 100644 index 00000000..9c4e26db --- /dev/null +++ b/rig-core/src/providers/mira.rs @@ -0,0 +1,220 @@ +//! Mira API client and Rig integration +//! +//! # Example +//! ``` +//! use rig::providers::mira; +//! +//! let client = mira::Client::new("YOUR_API_KEY"); +//! +//! ``` +use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::string::FromUtf8Error; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum MiraError { + #[error("Invalid API key")] + InvalidApiKey, + #[error("API error: {0}")] + ApiError(u16), + #[error("Request error: {0}")] + RequestError(#[from] reqwest::Error), + #[error("UTF-8 error: {0}")] + Utf8Error(#[from] FromUtf8Error), +} + +#[derive(Debug, Serialize)] +pub struct AiRequest { + pub model: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +#[derive(Debug, Deserialize)] +pub struct ChatResponse { + pub choices: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct ChatChoice { + pub message: ChatMessage, +} + +#[derive(Debug, Deserialize)] +struct ModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct ModelInfo { + id: String, +} + +/// Client for interacting with the Mira API +pub struct Client { + base_url: String, + client: reqwest::Client, + headers: HeaderMap, +} + +impl Client { + /// Create a new Mira client with the given API key + pub fn new(api_key: impl AsRef) -> Result { + let mut headers = HeaderMap::new(); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + headers.insert( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {}", api_key.as_ref())) + .map_err(|_| MiraError::InvalidApiKey)?, + ); + + Ok(Self { + base_url: "https://apis.mira.network".to_string(), + client: reqwest::Client::new(), + headers, + }) + } + + /// Create a new Mira client with a custom base URL and API key + pub fn new_with_base_url( + api_key: impl AsRef, + base_url: impl Into, + ) -> Result { + let mut client = Self::new(api_key)?; + client.base_url = base_url.into(); + Ok(client) + } + + /// Generate a chat completion + pub async fn generate(&self, request: AiRequest) -> Result { + let response = self + .client + .post(format!("{}/v1/chat/completions", self.base_url)) + .headers(self.headers.clone()) + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + return Err(MiraError::ApiError(response.status().as_u16())); + } + + Ok(response.json().await?) + } + + /// List available models + pub async fn list_models(&self) -> Result, MiraError> { + let response = self + .client + .get(format!("{}/v1/models", self.base_url)) + .headers(self.headers.clone()) + .send() + .await?; + + if !response.status().is_success() { + return Err(MiraError::ApiError(response.status().as_u16())); + } + + let models: ModelsResponse = response.json().await?; + Ok(models.data.into_iter().map(|model| model.id).collect()) + } + + /// Get user credits information + pub async fn get_user_credits(&self) -> Result { + let response = self + .client + .get(format!("{}/user-credits", self.base_url)) + .headers(self.headers.clone()) + .send() + .await?; + + if !response.status().is_success() { + return Err(MiraError::ApiError(response.status().as_u16())); + } + + Ok(response.json().await?) + } + + /// Get credits history + pub async fn get_credits_history(&self) -> Result, MiraError> { + let response = self + .client + .get(format!("{}/user-credits-history", self.base_url)) + .headers(self.headers.clone()) + .send() + .await?; + + if !response.status().is_success() { + return Err(MiraError::ApiError(response.status().as_u16())); + } + + Ok(response.json().await?) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_generate() { + let client = Client::new("mira-api-key").unwrap(); + + // First get available models to ensure we use a valid one + let _models = client.list_models().await.unwrap(); + // println!("Available models: {:?}", models); + + let request = AiRequest { + model: "deepseek-r1".to_string(), + messages: vec![ChatMessage { + role: "user".to_string(), + content: "Hello, What can you do?".to_string(), + }], + temperature: Some(0.7), + max_tokens: Some(100), + stream: None, + }; + + let response = client.generate(request).await.unwrap(); + println!("Response: {:?}", response); + assert!(!response.choices.is_empty()); + } + + #[tokio::test] + async fn test_list_models() { + let client = Client::new("mira-api-key").unwrap(); + let models = client.list_models().await.unwrap(); + println!("Models: {:?}", models); + assert!(!models.is_empty()); + assert!(models.iter().any(|model| model == "gpt-4o" + || model == "deepseek-r1" + || model == "claude-3.5-sonnet")); + } + + #[tokio::test] + async fn test_get_user_credits() { + let client = Client::new("mira-api-key").unwrap(); + let credits = client.get_user_credits().await.unwrap(); + println!("Credits: {:?}", credits); + } + + #[tokio::test] + async fn test_get_credits_history() { + let client = Client::new("mira-api-key").unwrap(); + let history = client.get_credits_history().await.unwrap(); + println!("History: {:?}", history); + } +} diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs index 138f0a59..fbab51fd 100644 --- a/rig-core/src/providers/mod.rs +++ b/rig-core/src/providers/mod.rs @@ -10,6 +10,7 @@ //! - EternalAI //! - DeepSeek //! - Azure OpenAI +//! - Mira //! //! Each provider has its own module, which contains a `Client` implementation that can //! be used to initialize completion and embedding models and execute requests to those models. @@ -51,6 +52,7 @@ pub mod deepseek; pub mod galadriel; pub mod gemini; pub mod hyperbolic; +pub mod mira; pub mod moonshot; pub mod openai; pub mod perplexity; From b51e5188574fe7c52096b3eb0fd8edfd2622282c Mon Sep 17 00:00:00 2001 From: Vaibhav Satija Date: Fri, 7 Feb 2025 21:25:16 +0530 Subject: [PATCH 2/6] example: Add Mira AI provider example - generation request - available model list - user credits balance --- rig-core/examples/agent_with_mira.rs | 41 ++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 rig-core/examples/agent_with_mira.rs diff --git a/rig-core/examples/agent_with_mira.rs b/rig-core/examples/agent_with_mira.rs new file mode 100644 index 00000000..f90846df --- /dev/null +++ b/rig-core/examples/agent_with_mira.rs @@ -0,0 +1,41 @@ +use rig::providers::mira::{self, AiRequest, ChatMessage}; +use std::error::Error; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize the Mira client with your API key + let client = mira::Client::new("mira-api-key")?; + + // Create a chat request + let request = AiRequest { + model: "claude-3.5-sonnet".to_string(), + messages: vec![ChatMessage { + role: "user".to_string(), + content: "What are the three laws of robotics?".to_string(), + }], + temperature: Some(0.7), + max_tokens: Some(500), + stream: None, + }; + + // Generate a response + let response = client.generate(request).await?; + + // Print the response + if let Some(choice) = response.choices.first() { + println!("Assistant: {}", choice.message.content); + } + + // List available models + println!("\nAvailable models:"); + let models = client.list_models().await?; + for model in models { + println!("- {}", model); + } + + // Get user credits + let credits = client.get_user_credits().await?; + println!("\nUser credits: {:?}", credits); + + Ok(()) +} From c07a17c3f5e5f33c1ee4a6b136c544f0633f3637 Mon Sep 17 00:00:00 2001 From: Vaibhav Satija Date: Tue, 18 Feb 2025 21:00:17 +0530 Subject: [PATCH 3/6] refactor: Improve Mira provider with agent and completion model support --- rig-core/examples/agent_with_mira.rs | 31 ++-- rig-core/src/providers/mira.rs | 259 ++++++++++++++++++++++++--- 2 files changed, 252 insertions(+), 38 deletions(-) diff --git a/rig-core/examples/agent_with_mira.rs b/rig-core/examples/agent_with_mira.rs index f90846df..6f19b5ff 100644 --- a/rig-core/examples/agent_with_mira.rs +++ b/rig-core/examples/agent_with_mira.rs @@ -1,30 +1,25 @@ -use rig::providers::mira::{self, AiRequest, ChatMessage}; +use rig::completion::Prompt; use std::error::Error; #[tokio::main] async fn main() -> Result<(), Box> { // Initialize the Mira client with your API key - let client = mira::Client::new("mira-api-key")?; + let client = rig::providers::mira::Client::new( + "mira-api-key", + )?; - // Create a chat request - let request = AiRequest { - model: "claude-3.5-sonnet".to_string(), - messages: vec![ChatMessage { - role: "user".to_string(), - content: "What are the three laws of robotics?".to_string(), - }], - temperature: Some(0.7), - max_tokens: Some(500), - stream: None, - }; + // Create an agent with the Mira model + let agent = client + .agent("claude-3.5-sonnet") + .preamble("You are a helpful AI assistant.") + .temperature(0.7) + .build(); - // Generate a response - let response = client.generate(request).await?; + // Send a message to the agent + let response = agent.prompt("What are the 7 wonders of the world?").await?; // Print the response - if let Some(choice) = response.choices.first() { - println!("Assistant: {}", choice.message.content); - } + println!("Assistant: {}", response); // List available models println!("\nAvailable models:"); diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 9c4e26db..410e6f95 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -7,7 +7,15 @@ //! let client = mira::Client::new("YOUR_API_KEY"); //! //! ``` +use crate::{ + agent::AgentBuilder, + completion::{self, CompletionError, CompletionRequest}, + extractor::ExtractorBuilder, + message::{self, Message}, + OneOrMany, +}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE}; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::string::FromUtf8Error; @@ -25,6 +33,18 @@ pub enum MiraError { Utf8Error(#[from] FromUtf8Error), } +#[derive(Debug, Deserialize)] +struct ApiErrorResponse { + message: String, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum ApiResponse { + Ok(T), + Err, +} + #[derive(Debug, Serialize)] pub struct AiRequest { pub model: String, @@ -44,8 +64,9 @@ pub struct ChatMessage { } #[derive(Debug, Deserialize)] -pub struct ChatResponse { +pub struct CompletionResponse { pub choices: Vec, + pub usage: Option, } #[derive(Debug, Deserialize)] @@ -63,6 +84,7 @@ struct ModelInfo { id: String, } +#[derive(Clone)] /// Client for interacting with the Mira API pub struct Client { base_url: String, @@ -99,7 +121,7 @@ impl Client { } /// Generate a chat completion - pub async fn generate(&self, request: AiRequest) -> Result { + pub async fn generate(&self, request: AiRequest) -> Result { let response = self .client .post(format!("{}/v1/chat/completions", self.base_url)) @@ -163,19 +185,196 @@ impl Client { Ok(response.json().await?) } + + /// Create a completion model with the given name. + pub fn completion_model(&self, model: &str) -> CompletionModel { + CompletionModel::new(self.to_owned(), model) + } + + /// Create an agent builder with the given completion model. + pub fn agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) + } + + /// Create an extractor builder with the given completion model. + pub fn extractor Deserialize<'a> + Serialize + Send + Sync>( + &self, + model: &str, + ) -> ExtractorBuilder { + ExtractorBuilder::new(self.completion_model(model)) + } +} + +#[derive(Clone)] +pub struct CompletionModel { + client: Client, + /// Name of the model + pub model: String, +} + +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { + Self { + client, + model: model.to_string(), + } + } +} + +impl completion::CompletionModel for CompletionModel { + type Response = CompletionResponse; + + #[cfg_attr(feature = "worker", worker::send)] + async fn completion( + &self, + completion_request: CompletionRequest, + ) -> Result, CompletionError> { + // Convert messages to Mira format + let mut messages = Vec::new(); + + // Add preamble as system message if available + if let Some(preamble) = &completion_request.preamble { + messages.push(ChatMessage { + role: "system".to_string(), + content: preamble.to_string(), + }); + } + + // Add prompt first + let prompt = completion_request.prompt_with_context(); + let prompt_str = match prompt { + Message::User { content } => content + .into_iter() + .filter_map(|c| match c { + message::UserContent::Text(text) => Some(text.text), + _ => None, + }) + .collect::>() + .join(" "), + _ => String::new(), + }; + + if !prompt_str.is_empty() { + messages.push(ChatMessage { + role: "user".to_string(), + content: prompt_str, + }); + } + + // Add chat history + for message in completion_request.chat_history { + match message { + Message::User { content } => { + // Convert user content to string + let content_str = content + .into_iter() + .filter_map(|c| match c { + message::UserContent::Text(text) => Some(text.text), + _ => None, // Skip other content types + }) + .collect::>() + .join(" "); + + if !content_str.is_empty() { + messages.push(ChatMessage { + role: "user".to_string(), + content: content_str, + }); + } + } + Message::Assistant { content } => { + // Convert assistant content to string + let content_str = content + .into_iter() + .filter_map(|c| match c { + message::AssistantContent::Text(text) => Some(text.text), + _ => None, // Skip tool calls + }) + .collect::>() + .join(" "); + + if !content_str.is_empty() { + messages.push(ChatMessage { + role: "assistant".to_string(), + content: content_str, + }); + } + } + } + } + + let request = AiRequest { + model: self.model.clone(), + messages, + temperature: Some(completion_request.temperature.unwrap_or(0.7) as f32), + max_tokens: None, + stream: None, + }; + + let response = self + .client + .generate(request) + .await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + + response.try_into() + } +} + +impl From for CompletionError { + fn from(err: ApiErrorResponse) -> Self { + CompletionError::ProviderError(err.message) + } +} + +impl TryFrom for completion::CompletionResponse { + type Error = CompletionError; + + fn try_from(response: CompletionResponse) -> Result { + let choice = response.choices.first().ok_or_else(|| { + CompletionError::ResponseError("Response contained no choices".to_owned()) + })?; + + let content = vec![completion::AssistantContent::text(&choice.message.content)]; + + let choice = OneOrMany::many(content).map_err(|_| { + CompletionError::ResponseError( + "Response contained no message or tool call (empty)".to_owned(), + ) + })?; + + Ok(completion::CompletionResponse { + choice, + raw_response: response, + }) + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Usage { + pub prompt_tokens: usize, + pub total_tokens: usize, +} + +impl std::fmt::Display for Usage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Prompt tokens: {} Total tokens: {}", + self.prompt_tokens, self.total_tokens + ) + } } #[cfg(test)] mod tests { use super::*; + use crate::message::{Text, UserContent}; #[tokio::test] async fn test_generate() { - let client = Client::new("mira-api-key").unwrap(); - - // First get available models to ensure we use a valid one - let _models = client.list_models().await.unwrap(); - // println!("Available models: {:?}", models); + let client = + Client::new("mira-api-key").unwrap(); let request = AiRequest { model: "deepseek-r1".to_string(), @@ -188,33 +387,53 @@ mod tests { stream: None, }; - let response = client.generate(request).await.unwrap(); - println!("Response: {:?}", response); - assert!(!response.choices.is_empty()); + let _response = client.generate(request).await.unwrap(); + } + + #[tokio::test] + async fn test_completion_model() { + let client = + Client::new("mira-api-key").unwrap(); + let model = client.completion_model("deepseek-r1"); + + let request = CompletionRequest { + prompt: Message::User { + content: OneOrMany::one(UserContent::Text(Text { + text: "Hello, what can you do?".to_string(), + })), + }, + temperature: Some(0.7), + preamble: None, + chat_history: Vec::new(), + additional_params: None, + documents: Vec::new(), + tools: Vec::new(), + max_tokens: None, + }; + + let _response = completion::CompletionModel::completion(&model, request) + .await + .unwrap(); } #[tokio::test] async fn test_list_models() { let client = Client::new("mira-api-key").unwrap(); let models = client.list_models().await.unwrap(); - println!("Models: {:?}", models); assert!(!models.is_empty()); - assert!(models.iter().any(|model| model == "gpt-4o" - || model == "deepseek-r1" - || model == "claude-3.5-sonnet")); } #[tokio::test] async fn test_get_user_credits() { - let client = Client::new("mira-api-key").unwrap(); - let credits = client.get_user_credits().await.unwrap(); - println!("Credits: {:?}", credits); + let client = + Client::new("mira-api-key").unwrap(); + let _credits = client.get_user_credits().await.unwrap(); } #[tokio::test] async fn test_get_credits_history() { - let client = Client::new("mira-api-key").unwrap(); - let history = client.get_credits_history().await.unwrap(); - println!("History: {:?}", history); + let client = + Client::new("mira-api-key").unwrap(); + let _history = client.get_credits_history().await.unwrap(); } } From a4a52667a3cacd0ac7e5cf7fcaf8641e47dd9290 Mon Sep 17 00:00:00 2001 From: Vaibhav Satija Date: Sun, 2 Mar 2025 13:33:27 +0530 Subject: [PATCH 4/6] refactor: Mira Integration - Addressed reviewer comments - updated tests --- rig-core/examples/agent_with_mira.rs | 152 +++++- rig-core/src/providers/mira.rs | 733 ++++++++++++++++++++------- 2 files changed, 675 insertions(+), 210 deletions(-) diff --git a/rig-core/examples/agent_with_mira.rs b/rig-core/examples/agent_with_mira.rs index 6f19b5ff..1c97aac2 100644 --- a/rig-core/examples/agent_with_mira.rs +++ b/rig-core/examples/agent_with_mira.rs @@ -1,36 +1,146 @@ -use rig::completion::Prompt; -use std::error::Error; +use rig::{ + completion::{Prompt, ToolDefinition}, + providers, + tool::Tool, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; #[tokio::main] -async fn main() -> Result<(), Box> { - // Initialize the Mira client with your API key - let client = rig::providers::mira::Client::new( - "mira-api-key", - )?; +async fn main() -> Result<(), anyhow::Error> { + // Initialize logging + tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .with_target(false) + .init(); - // Create an agent with the Mira model + // Initialize the Mira client using environment variables + let client = providers::mira::Client::from_env() + .map_err(|e| anyhow::anyhow!("Failed to initialize Mira client: {}", e))?; + + // Test API connection first by listing models + println!("\nTesting API connection by listing models..."); + match client.list_models().await { + Ok(models) => { + println!("Successfully connected to Mira API!"); + println!("Available models:"); + for model in models { + println!("- {}", model); + } + println!("\nProceeding with chat completion...\n"); + } + Err(e) => { + return Err(anyhow::anyhow!("Failed to connect to Mira API: {}. Please verify your API key and network connection.", e)); + } + } + + // Create a basic agent for general conversation let agent = client - .agent("claude-3.5-sonnet") + .agent("gpt-4o") .preamble("You are a helpful AI assistant.") .temperature(0.7) .build(); - // Send a message to the agent + // Send a message and get response let response = agent.prompt("What are the 7 wonders of the world?").await?; + println!("Basic Agent Response: {}", response); + + // Create a calculator agent with tools + let calculator_agent = client + .agent("claude-3.5-sonnet") + .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.") + .max_tokens(1024) + .tool(Adder) + .tool(Subtract) + .build(); + + // Test the calculator agent + println!("\nTesting Calculator Agent:"); + println!( + "Mira Calculator Agent: {}", + calculator_agent.prompt("Calculate 15 - 7").await? + ); + + Ok(()) +} - // Print the response - println!("Assistant: {}", response); +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} - // List available models - println!("\nAvailable models:"); - let models = client.list_models().await?; - for model in models { - println!("- {}", model); +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize, Serialize)] +struct Adder; +impl Tool for Adder { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + } + }), + } } - // Get user credits - let credits = client.get_user_credits().await?; - println!("\nUser credits: {:?}", credits); + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} - Ok(()) +#[derive(Deserialize, Serialize)] +struct Subtract; +impl Tool for Subtract { + const NAME: &'static str = "subtract"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "subtract", + "description": "Subtract y from x (i.e.: x - y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The number to subtract from" + }, + "y": { + "type": "number", + "description": "The number to subtract" + } + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x - args.y; + Ok(result) + } } diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 410e6f95..979f4dd7 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -11,15 +11,15 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, extractor::ExtractorBuilder, - message::{self, Message}, + message::{self, AssistantContent, Message, UserContent}, OneOrMany, }; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use serde_json::Value; use std::string::FromUtf8Error; use thiserror::Error; +use tracing; #[derive(Debug, Error)] pub enum MiraError { @@ -31,6 +31,8 @@ pub enum MiraError { RequestError(#[from] reqwest::Error), #[error("UTF-8 error: {0}")] Utf8Error(#[from] FromUtf8Error), + #[error("JSON error: {0}")] + JsonError(#[from] serde_json::Error), } #[derive(Debug, Deserialize)] @@ -39,39 +41,65 @@ struct ApiErrorResponse { } #[derive(Debug, Deserialize)] -#[serde(untagged)] -enum ApiResponse { - Ok(T), - Err, +struct RawMessage { + role: String, + content: String, } -#[derive(Debug, Serialize)] -pub struct AiRequest { - pub model: String, - pub messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, -} +const MIRA_API_BASE_URL: &str = "https://api.mira.network"; + +impl TryFrom for message::Message { + type Error = CompletionError; -#[derive(Debug, Serialize, Deserialize)] -pub struct ChatMessage { - pub role: String, - pub content: String, + fn try_from(raw: RawMessage) -> Result { + match raw.role.as_str() { + "user" => Ok(message::Message::User { + content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })), + }), + "assistant" => Ok(message::Message::Assistant { + content: OneOrMany::one(AssistantContent::Text(message::Text { + text: raw.content, + })), + }), + _ => Err(CompletionError::ResponseError(format!( + "Unsupported message role: {}", + raw.role + ))), + } + } } #[derive(Debug, Deserialize)] -pub struct CompletionResponse { - pub choices: Vec, - pub usage: Option, +#[serde(untagged)] +pub enum CompletionResponse { + Structured { + id: String, + object: String, + created: u64, + model: String, + choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + usage: Option, + }, + Simple(String), } #[derive(Debug, Deserialize)] pub struct ChatChoice { - pub message: ChatMessage, + #[serde(deserialize_with = "deserialize_message")] + pub message: message::Message, + #[serde(default)] + pub finish_reason: Option, + #[serde(default)] + pub index: Option, +} + +fn deserialize_message<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let raw = RawMessage::deserialize(deserializer)?; + message::Message::try_from(raw).map_err(serde::de::Error::custom) } #[derive(Debug, Deserialize)] @@ -94,25 +122,44 @@ pub struct Client { impl Client { /// Create a new Mira client with the given API key - pub fn new(api_key: impl AsRef) -> Result { + pub fn new(api_key: &str) -> Result { let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); headers.insert( AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {}", api_key.as_ref())) + HeaderValue::from_str(&format!("Bearer {}", api_key)) .map_err(|_| MiraError::InvalidApiKey)?, ); + headers.insert( + reqwest::header::ACCEPT, + HeaderValue::from_static("application/json"), + ); + headers.insert( + reqwest::header::USER_AGENT, + HeaderValue::from_static("rig-client/1.0"), + ); Ok(Self { - base_url: "https://apis.mira.network".to_string(), - client: reqwest::Client::new(), + base_url: MIRA_API_BASE_URL.to_string(), + client: reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .connect_timeout(std::time::Duration::from_secs(30)) + .build() + .expect("Failed to build HTTP client"), headers, }) } + /// Create a new Mira client from the `MIRA_API_KEY` environment variable. + /// Panics if the environment variable is not set. + pub fn from_env() -> Result { + let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set"); + Self::new(&api_key) + } + /// Create a new Mira client with a custom base URL and API key pub fn new_with_base_url( - api_key: impl AsRef, + api_key: &str, base_url: impl Into, ) -> Result { let mut client = Self::new(api_key)?; @@ -121,69 +168,120 @@ impl Client { } /// Generate a chat completion - pub async fn generate(&self, request: AiRequest) -> Result { - let response = self - .client - .post(format!("{}/v1/chat/completions", self.base_url)) - .headers(self.headers.clone()) - .json(&request) - .send() - .await?; + pub async fn generate( + &self, + model: &str, + request: CompletionRequest, + ) -> Result { + let mut messages = Vec::new(); - if !response.status().is_success() { - return Err(MiraError::ApiError(response.status().as_u16())); - } + // Add prompt first + let prompt_text = match &request.prompt { + Message::User { content } => content + .iter() + .map(|c| match c { + UserContent::Text(text) => &text.text, + _ => "", + }) + .collect::>() + .join(" "), + _ => return Err(MiraError::ApiError(422)), + }; - Ok(response.json().await?) - } + messages.push(serde_json::json!({ + "role": "user", + "content": prompt_text + })); - /// List available models - pub async fn list_models(&self) -> Result, MiraError> { - let response = self - .client - .get(format!("{}/v1/models", self.base_url)) - .headers(self.headers.clone()) - .send() - .await?; - - if !response.status().is_success() { - return Err(MiraError::ApiError(response.status().as_u16())); + // Then add chat history + for msg in request.chat_history { + let (role, content) = match msg { + Message::User { content } => { + let text = content + .iter() + .map(|c| match c { + UserContent::Text(text) => &text.text, + _ => "", + }) + .collect::>() + .join(" "); + ("user", text) + } + Message::Assistant { content } => { + let text = content + .iter() + .map(|c| match c { + AssistantContent::Text(text) => &text.text, + _ => "", + }) + .collect::>() + .join(" "); + ("assistant", text) + } + }; + messages.push(serde_json::json!({ + "role": role, + "content": content + })); } - let models: ModelsResponse = response.json().await?; - Ok(models.data.into_iter().map(|model| model.id).collect()) - } + let mira_request = serde_json::json!({ + "model": model, + "messages": messages, + "temperature": request.temperature.map(|t| t as f32), + "max_tokens": request.max_tokens.map(|t| t as u32), + "stream": false + }); - /// Get user credits information - pub async fn get_user_credits(&self) -> Result { let response = self .client - .get(format!("{}/user-credits", self.base_url)) + .post(format!("{}/v1/chat/completions", self.base_url)) .headers(self.headers.clone()) + .json(&mira_request) .send() .await?; if !response.status().is_success() { - return Err(MiraError::ApiError(response.status().as_u16())); + let status = response.status(); + return Err(MiraError::ApiError(status.as_u16())); } - Ok(response.json().await?) + // Parse the response + let response_text = response.text().await?; + let parsed_response: CompletionResponse = serde_json::from_str(&response_text)?; + Ok(parsed_response) } - /// Get credits history - pub async fn get_credits_history(&self) -> Result, MiraError> { + /// List available models + pub async fn list_models(&self) -> Result, MiraError> { + let url = format!("{}/v1/models", self.base_url); + tracing::debug!("Requesting models from: {}", url); + tracing::debug!("Headers: {:?}", self.headers); + let response = self .client - .get(format!("{}/user-credits-history", self.base_url)) + .get(&url) .headers(self.headers.clone()) .send() .await?; - if !response.status().is_success() { - return Err(MiraError::ApiError(response.status().as_u16())); + let status = response.status(); + + if !status.is_success() { + // Log the error text but don't store it in an unused variable + let _error_text = response.text().await.unwrap_or_default(); + tracing::error!("Error response: {}", _error_text); + return Err(MiraError::ApiError(status.as_u16())); } - Ok(response.json().await?) + let response_text = response.text().await?; + + let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| { + tracing::error!("Failed to parse response: {}", e); + MiraError::JsonError(e) + })?; + + Ok(models.data.into_iter().map(|model| model.id).collect()) } /// Create a completion model with the given name. @@ -229,91 +327,103 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - // Convert messages to Mira format + if !completion_request.tools.is_empty() { + tracing::warn!(target: "rig", + "Tool calls are not supported by the Mira provider. {} tools will be ignored.", + completion_request.tools.len() + ); + } + let mut messages = Vec::new(); - // Add preamble as system message if available + // Add preamble as user message if available if let Some(preamble) = &completion_request.preamble { - messages.push(ChatMessage { - role: "system".to_string(), - content: preamble.to_string(), - }); + messages.push(serde_json::json!({ + "role": "user", + "content": preamble.to_string() + })); } - // Add prompt first - let prompt = completion_request.prompt_with_context(); - let prompt_str = match prompt { - Message::User { content } => content - .into_iter() - .filter_map(|c| match c { - message::UserContent::Text(text) => Some(text.text), - _ => None, + // Add prompt + messages.push(match &completion_request.prompt { + Message::User { content } => { + let text = content + .iter() + .map(|c| match c { + UserContent::Text(text) => &text.text, + _ => "", + }) + .collect::>() + .join(" "); + serde_json::json!({ + "role": "user", + "content": text }) - .collect::>() - .join(" "), - _ => String::new(), - }; - - if !prompt_str.is_empty() { - messages.push(ChatMessage { - role: "user".to_string(), - content: prompt_str, - }); - } + } + _ => unreachable!(), + }); // Add chat history - for message in completion_request.chat_history { - match message { + for msg in completion_request.chat_history { + let (role, content) = match msg { Message::User { content } => { - // Convert user content to string - let content_str = content - .into_iter() - .filter_map(|c| match c { - message::UserContent::Text(text) => Some(text.text), - _ => None, // Skip other content types + let text = content + .iter() + .map(|c| match c { + UserContent::Text(text) => &text.text, + _ => "", }) .collect::>() .join(" "); - - if !content_str.is_empty() { - messages.push(ChatMessage { - role: "user".to_string(), - content: content_str, - }); - } + ("user", text) } Message::Assistant { content } => { - // Convert assistant content to string - let content_str = content - .into_iter() - .filter_map(|c| match c { - message::AssistantContent::Text(text) => Some(text.text), - _ => None, // Skip tool calls + let text = content + .iter() + .map(|c| match c { + AssistantContent::Text(text) => &text.text, + _ => "", }) .collect::>() .join(" "); - - if !content_str.is_empty() { - messages.push(ChatMessage { - role: "assistant".to_string(), - content: content_str, - }); - } + ("assistant", text) } - } + }; + messages.push(serde_json::json!({ + "role": role, + "content": content + })); } - let request = AiRequest { - model: self.model.clone(), - messages, - temperature: Some(completion_request.temperature.unwrap_or(0.7) as f32), - max_tokens: None, - stream: None, - }; + let mira_request = serde_json::json!({ + "model": self.model, + "messages": messages, + "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7), + "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100), + "stream": false + }); let response = self .client - .generate(request) + .client + .post(format!("{}/v1/chat/completions", self.client.base_url)) + .headers(self.client.headers.clone()) + .json(&mira_request) + .send() + .await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + + if !response.status().is_success() { + let status = response.status().as_u16(); + let error_text = response.text().await.unwrap_or_default(); + return Err(CompletionError::ProviderError(format!( + "API error: {} - {}", + status, error_text + ))); + } + + let response: CompletionResponse = response + .json() .await .map_err(|e| CompletionError::ProviderError(e.to_string()))?; @@ -331,11 +441,50 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; - - let content = vec![completion::AssistantContent::text(&choice.message.content)]; + let content = match &response { + CompletionResponse::Structured { choices, .. } => { + let choice = choices.first().ok_or_else(|| { + CompletionError::ResponseError("Response contained no choices".to_owned()) + })?; + + match &choice.message { + Message::Assistant { content } => { + if content.is_empty() { + return Err(CompletionError::ResponseError( + "Response contained empty content".to_owned(), + )); + } + + // Log warning for unsupported content types + for c in content.iter() { + if !matches!(c, AssistantContent::Text(_)) { + tracing::warn!(target: "rig", + "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c + ); + } + } + + content.iter().map(|c| { + match c { + AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)), + other => Err(CompletionError::ResponseError( + format!("Unsupported content type: {:?}. The Mira provider currently only supports text content", other) + )) + } + }).collect::, _>>()? + } + Message::User { .. } => { + tracing::warn!(target: "rig", "Received user message in response where assistant message was expected"); + return Err(CompletionError::ResponseError( + "Received user message in response where assistant message was expected".to_owned() + )); + } + } + } + CompletionResponse::Simple(text) => { + vec![completion::AssistantContent::text(text)] + } + }; let choice = OneOrMany::many(content).map_err(|_| { CompletionError::ResponseError( @@ -366,74 +515,280 @@ impl std::fmt::Display for Usage { } } +impl From for serde_json::Value { + fn from(msg: Message) -> Self { + match msg { + Message::User { content } => { + let text = content + .iter() + .map(|c| match c { + UserContent::Text(text) => &text.text, + _ => "", + }) + .collect::>() + .join(" "); + serde_json::json!({ + "role": "user", + "content": text + }) + } + Message::Assistant { content } => { + let text = content + .iter() + .map(|c| match c { + AssistantContent::Text(text) => &text.text, + _ => "", + }) + .collect::>() + .join(" "); + serde_json::json!({ + "role": "assistant", + "content": text + }) + } + } + } +} + +impl TryFrom for Message { + type Error = CompletionError; + + fn try_from(value: serde_json::Value) -> Result { + let role = value["role"].as_str().ok_or_else(|| { + CompletionError::ResponseError("Message missing role field".to_owned()) + })?; + + // Handle both string and array content formats + let content = match value.get("content") { + Some(content) => match content { + serde_json::Value::String(s) => s.clone(), + serde_json::Value::Array(arr) => arr + .iter() + .filter_map(|c| { + c.get("text") + .and_then(|t| t.as_str()) + .map(|text| text.to_string()) + }) + .collect::>() + .join(" "), + _ => { + return Err(CompletionError::ResponseError( + "Message content must be string or array".to_owned(), + )) + } + }, + None => { + return Err(CompletionError::ResponseError( + "Message missing content field".to_owned(), + )) + } + }; + + match role { + "user" => Ok(Message::User { + content: OneOrMany::one(UserContent::Text(message::Text { text: content })), + }), + "assistant" => Ok(Message::Assistant { + content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })), + }), + _ => Err(CompletionError::ResponseError(format!( + "Unsupported message role: {}", + role + ))), + } + } +} + #[cfg(test)] mod tests { use super::*; - use crate::message::{Text, UserContent}; + use crate::message::UserContent; + use serde_json::json; + + #[test] + fn test_deserialize_message() { + // Test string content format + let assistant_message_json = json!({ + "role": "assistant", + "content": "Hello there, how may I assist you today?" + }); + + let user_message_json = json!({ + "role": "user", + "content": "What can you help me with?" + }); + + // Test array content format + let assistant_message_array_json = json!({ + "role": "assistant", + "content": [{ + "type": "text", + "text": "Hello there, how may I assist you today?" + }] + }); + + let assistant_message = Message::try_from(assistant_message_json).unwrap(); + let user_message = Message::try_from(user_message_json).unwrap(); + let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap(); + + // Test string content format + match assistant_message { + Message::Assistant { content } => { + assert_eq!( + content.first(), + AssistantContent::Text(message::Text { + text: "Hello there, how may I assist you today?".to_string() + }) + ); + } + _ => panic!("Expected assistant message"), + } - #[tokio::test] - async fn test_generate() { - let client = - Client::new("mira-api-key").unwrap(); + match user_message { + Message::User { content } => { + assert_eq!( + content.first(), + UserContent::Text(message::Text { + text: "What can you help me with?".to_string() + }) + ); + } + _ => panic!("Expected user message"), + } - let request = AiRequest { - model: "deepseek-r1".to_string(), - messages: vec![ChatMessage { - role: "user".to_string(), - content: "Hello, What can you do?".to_string(), - }], - temperature: Some(0.7), - max_tokens: Some(100), - stream: None, + // Test array content format + match assistant_message_array { + Message::Assistant { content } => { + assert_eq!( + content.first(), + AssistantContent::Text(message::Text { + text: "Hello there, how may I assist you today?".to_string() + }) + ); + } + _ => panic!("Expected assistant message"), + } + } + + #[test] + fn test_message_conversion() { + // Test converting from our Message type to Mira's format and back + let original_message = message::Message::User { + content: OneOrMany::one(message::UserContent::text("Hello")), }; - let _response = client.generate(request).await.unwrap(); - } + // Convert to Mira format + let mira_value: serde_json::Value = original_message.clone().try_into().unwrap(); - #[tokio::test] - async fn test_completion_model() { - let client = - Client::new("mira-api-key").unwrap(); - let model = client.completion_model("deepseek-r1"); + // Convert back to our Message type + let converted_message: Message = mira_value.try_into().unwrap(); - let request = CompletionRequest { - prompt: Message::User { - content: OneOrMany::one(UserContent::Text(Text { - text: "Hello, what can you do?".to_string(), - })), - }, - temperature: Some(0.7), - preamble: None, - chat_history: Vec::new(), - additional_params: None, - documents: Vec::new(), - tools: Vec::new(), - max_tokens: None, - }; + // Convert back to original format + let final_message: message::Message = converted_message.try_into().unwrap(); - let _response = completion::CompletionModel::completion(&model, request) - .await - .unwrap(); + assert_eq!(original_message, final_message); } - #[tokio::test] - async fn test_list_models() { - let client = Client::new("mira-api-key").unwrap(); - let models = client.list_models().await.unwrap(); - assert!(!models.is_empty()); - } + #[test] + fn test_completion_response_deserialization() { + // Test structured response + let structured_json = json!({ + "id": "resp_123", + "object": "chat.completion", + "created": 1234567890, + "model": "deepseek-r1", + "choices": [{ + "message": { + "role": "assistant", + "content": "I can help you with various tasks." + }, + "finish_reason": "stop", + "index": 0 + }], + "usage": { + "prompt_tokens": 10, + "total_tokens": 20 + } + }); + + // Test simple response + let simple_json = json!("Simple response text"); + + // Try both formats + let structured: CompletionResponse = serde_json::from_value(structured_json).unwrap(); + let simple: CompletionResponse = serde_json::from_value(simple_json).unwrap(); + + match structured { + CompletionResponse::Structured { + id, + object, + created, + model, + choices, + usage, + } => { + assert_eq!(id, "resp_123"); + assert_eq!(object, "chat.completion"); + assert_eq!(created, 1234567890); + assert_eq!(model, "deepseek-r1"); + assert!(!choices.is_empty()); + assert!(usage.is_some()); + + let choice = &choices[0]; + match &choice.message { + Message::Assistant { content } => { + assert_eq!( + content.first(), + AssistantContent::Text(message::Text { + text: "I can help you with various tasks.".to_string() + }) + ); + } + _ => panic!("Expected assistant message"), + } + + assert_eq!(choice.finish_reason.as_deref(), Some("stop")); + assert_eq!(choice.index, Some(0)); + } + CompletionResponse::Simple(_) => panic!("Expected structured response"), + } - #[tokio::test] - async fn test_get_user_credits() { - let client = - Client::new("mira-api-key").unwrap(); - let _credits = client.get_user_credits().await.unwrap(); + match simple { + CompletionResponse::Simple(text) => { + assert_eq!(text, "Simple response text"); + } + CompletionResponse::Structured { .. } => panic!("Expected simple response"), + } } - #[tokio::test] - async fn test_get_credits_history() { - let client = - Client::new("mira-api-key").unwrap(); - let _history = client.get_credits_history().await.unwrap(); + #[test] + fn test_completion_response_conversion() { + let mira_response = CompletionResponse::Structured { + id: "resp_123".to_string(), + object: "chat.completion".to_string(), + created: 1234567890, + model: "deepseek-r1".to_string(), + choices: vec![ChatChoice { + message: Message::Assistant { + content: OneOrMany::one(AssistantContent::Text(message::Text { + text: "Test response".to_string(), + })), + }, + finish_reason: Some("stop".to_string()), + index: Some(0), + }], + usage: Some(Usage { + prompt_tokens: 10, + total_tokens: 20, + }), + }; + + let completion_response: completion::CompletionResponse = + mira_response.try_into().unwrap(); + + assert_eq!( + completion_response.choice.first(), + completion::AssistantContent::text("Test response") + ); } } From ebf15537e38e060e1aa5716aca70942081be0142 Mon Sep 17 00:00:00 2001 From: Vaibhav Satija Date: Wed, 5 Mar 2025 12:46:42 +0530 Subject: [PATCH 5/6] refactor: Remove redundant timeout setting in Mira client --- rig-core/src/providers/mira.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 979f4dd7..0364c36b 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -142,7 +142,6 @@ impl Client { Ok(Self { base_url: MIRA_API_BASE_URL.to_string(), client: reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(60)) .connect_timeout(std::time::Duration::from_secs(30)) .build() .expect("Failed to build HTTP client"), From 8d4cafcdf02ba9ba297ba770dfbae3085a569293 Mon Sep 17 00:00:00 2001 From: Vaibhav Satija Date: Fri, 7 Mar 2025 16:44:01 +0530 Subject: [PATCH 6/6] refactor: Improve Mira provider message handling, serialization and resolve comments --- rig-core/src/providers/mira.rs | 203 +++------------------------------ 1 file changed, 18 insertions(+), 185 deletions(-) diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 0364c36b..dc97e914 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -40,10 +40,10 @@ struct ApiErrorResponse { message: String, } -#[derive(Debug, Deserialize)] -struct RawMessage { - role: String, - content: String, +#[derive(Debug, Deserialize, Clone)] +pub struct RawMessage { + pub role: String, + pub content: String, } const MIRA_API_BASE_URL: &str = "https://api.mira.network"; @@ -86,22 +86,13 @@ pub enum CompletionResponse { #[derive(Debug, Deserialize)] pub struct ChatChoice { - #[serde(deserialize_with = "deserialize_message")] - pub message: message::Message, + pub message: RawMessage, #[serde(default)] pub finish_reason: Option, #[serde(default)] pub index: Option, } -fn deserialize_message<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - let raw = RawMessage::deserialize(deserializer)?; - message::Message::try_from(raw).map_err(serde::de::Error::custom) -} - #[derive(Debug, Deserialize)] struct ModelsResponse { data: Vec, @@ -142,7 +133,6 @@ impl Client { Ok(Self { base_url: MIRA_API_BASE_URL.to_string(), client: reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(30)) .build() .expect("Failed to build HTTP client"), headers, @@ -166,96 +156,9 @@ impl Client { Ok(client) } - /// Generate a chat completion - pub async fn generate( - &self, - model: &str, - request: CompletionRequest, - ) -> Result { - let mut messages = Vec::new(); - - // Add prompt first - let prompt_text = match &request.prompt { - Message::User { content } => content - .iter() - .map(|c| match c { - UserContent::Text(text) => &text.text, - _ => "", - }) - .collect::>() - .join(" "), - _ => return Err(MiraError::ApiError(422)), - }; - - messages.push(serde_json::json!({ - "role": "user", - "content": prompt_text - })); - - // Then add chat history - for msg in request.chat_history { - let (role, content) = match msg { - Message::User { content } => { - let text = content - .iter() - .map(|c| match c { - UserContent::Text(text) => &text.text, - _ => "", - }) - .collect::>() - .join(" "); - ("user", text) - } - Message::Assistant { content } => { - let text = content - .iter() - .map(|c| match c { - AssistantContent::Text(text) => &text.text, - _ => "", - }) - .collect::>() - .join(" "); - ("assistant", text) - } - }; - messages.push(serde_json::json!({ - "role": role, - "content": content - })); - } - - let mira_request = serde_json::json!({ - "model": model, - "messages": messages, - "temperature": request.temperature.map(|t| t as f32), - "max_tokens": request.max_tokens.map(|t| t as u32), - "stream": false - }); - - let response = self - .client - .post(format!("{}/v1/chat/completions", self.base_url)) - .headers(self.headers.clone()) - .json(&mira_request) - .send() - .await?; - - if !response.status().is_success() { - let status = response.status(); - return Err(MiraError::ApiError(status.as_u16())); - } - - // Parse the response - let response_text = response.text().await?; - let parsed_response: CompletionResponse = serde_json::from_str(&response_text)?; - Ok(parsed_response) - } - /// List available models pub async fn list_models(&self) -> Result, MiraError> { let url = format!("{}/v1/models", self.base_url); - tracing::debug!("Requesting models from: {}", url); - tracing::debug!("Headers: {:?}", self.headers); let response = self .client @@ -353,7 +256,7 @@ impl completion::CompletionModel for CompletionModel { _ => "", }) .collect::>() - .join(" "); + .join("\n"); serde_json::json!({ "role": "user", "content": text @@ -373,7 +276,7 @@ impl completion::CompletionModel for CompletionModel { _ => "", }) .collect::>() - .join(" "); + .join("\n"); ("user", text) } Message::Assistant { content } => { @@ -384,7 +287,7 @@ impl completion::CompletionModel for CompletionModel { _ => "", }) .collect::>() - .join(" "); + .join("\n"); ("assistant", text) } }; @@ -446,7 +349,10 @@ impl TryFrom for completion::CompletionResponse { if content.is_empty() { return Err(CompletionError::ResponseError( @@ -525,7 +431,7 @@ impl From for serde_json::Value { _ => "", }) .collect::>() - .join(" "); + .join("\n"); serde_json::json!({ "role": "user", "content": text @@ -539,7 +445,7 @@ impl From for serde_json::Value { _ => "", }) .collect::>() - .join(" "); + .join("\n"); serde_json::json!({ "role": "assistant", "content": text @@ -569,7 +475,7 @@ impl TryFrom for Message { .map(|text| text.to_string()) }) .collect::>() - .join(" "), + .join("\n"), _ => { return Err(CompletionError::ResponseError( "Message content must be string or array".to_owned(), @@ -688,78 +594,6 @@ mod tests { assert_eq!(original_message, final_message); } - #[test] - fn test_completion_response_deserialization() { - // Test structured response - let structured_json = json!({ - "id": "resp_123", - "object": "chat.completion", - "created": 1234567890, - "model": "deepseek-r1", - "choices": [{ - "message": { - "role": "assistant", - "content": "I can help you with various tasks." - }, - "finish_reason": "stop", - "index": 0 - }], - "usage": { - "prompt_tokens": 10, - "total_tokens": 20 - } - }); - - // Test simple response - let simple_json = json!("Simple response text"); - - // Try both formats - let structured: CompletionResponse = serde_json::from_value(structured_json).unwrap(); - let simple: CompletionResponse = serde_json::from_value(simple_json).unwrap(); - - match structured { - CompletionResponse::Structured { - id, - object, - created, - model, - choices, - usage, - } => { - assert_eq!(id, "resp_123"); - assert_eq!(object, "chat.completion"); - assert_eq!(created, 1234567890); - assert_eq!(model, "deepseek-r1"); - assert!(!choices.is_empty()); - assert!(usage.is_some()); - - let choice = &choices[0]; - match &choice.message { - Message::Assistant { content } => { - assert_eq!( - content.first(), - AssistantContent::Text(message::Text { - text: "I can help you with various tasks.".to_string() - }) - ); - } - _ => panic!("Expected assistant message"), - } - - assert_eq!(choice.finish_reason.as_deref(), Some("stop")); - assert_eq!(choice.index, Some(0)); - } - CompletionResponse::Simple(_) => panic!("Expected structured response"), - } - - match simple { - CompletionResponse::Simple(text) => { - assert_eq!(text, "Simple response text"); - } - CompletionResponse::Structured { .. } => panic!("Expected simple response"), - } - } - #[test] fn test_completion_response_conversion() { let mira_response = CompletionResponse::Structured { @@ -768,10 +602,9 @@ mod tests { created: 1234567890, model: "deepseek-r1".to_string(), choices: vec![ChatChoice { - message: Message::Assistant { - content: OneOrMany::one(AssistantContent::Text(message::Text { - text: "Test response".to_string(), - })), + message: RawMessage { + role: "assistant".to_string(), + content: "Test response".to_string(), }, finish_reason: Some("stop".to_string()), index: Some(0),