diff --git a/src/models/chat.rs b/src/models/chat.rs
index 70fed08..59c4246 100644
--- a/src/models/chat.rs
+++ b/src/models/chat.rs
@@ -29,10 +29,24 @@ pub struct ChatCompletionRequest {
     pub user: Option<String>,
 }
 
+// OpenAI-compatible message content: either a plain string or an array of
+// typed content parts. `untagged` lets serde accept both JSON shapes.
+#[derive(Deserialize, Serialize, Clone)]
+#[serde(untagged)]
+pub enum ChatMessageContent {
+    String(String),
+    Array(Vec<ChatMessageContentPart>),
+}
+
+#[derive(Deserialize, Serialize, Clone)]
+pub struct ChatMessageContentPart {
+    // `type` is a Rust keyword, hence the rename + raw identifier.
+    #[serde(rename = "type")]
+    pub r#type: String,
+    pub text: String,
+}
+
 #[derive(Deserialize, Serialize, Clone)]
 pub struct ChatCompletionMessage {
     pub role: String,
-    pub content: String,
+    pub content: ChatMessageContent,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub name: Option<String>,
 }
@@ -40,8 +54,10 @@ pub struct ChatCompletionMessage {
 #[derive(Deserialize, Serialize, Clone)]
 pub struct ChatCompletionResponse {
     pub id: String,
-    pub object: String,
-    pub created: u64,
+    // Optional because non-OpenAI providers (e.g. Anthropic) do not return them.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub object: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub created: Option<u64>,
     pub model: String,
     pub choices: Vec<ChatCompletionChoice>,
     pub usage: Usage,
diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs
index b19678c..8ba9d6d 100644
--- a/src/providers/anthropic.rs
+++ b/src/providers/anthropic.rs
@@ -1,14 +1,16 @@
 use axum::async_trait;
 use axum::http::StatusCode;
+use serde::{Deserialize, Serialize};
 
 use super::provider::Provider;
 use crate::config::models::{ModelConfig, Provider as ProviderConfig};
-use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
-use crate::models::common::Usage;
-use crate::models::completion::{CompletionChoice, CompletionRequest, CompletionResponse};
-use crate::models::embeddings::{
-    Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse,
+use crate::models::chat::{
+    ChatCompletionChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
+    ChatMessageContentPart,
 };
+use crate::models::common::Usage;
+use crate::models::completion::{CompletionRequest, CompletionResponse};
+use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
 use reqwest::Client;
 
 pub struct AnthropicProvider {
@@ -17,6 +19,27 @@ pub struct AnthropicProvider {
     http_client: Client,
 }
 
+// Wire types for Anthropic's /v1/messages response, mapped below onto the
+// OpenAI-style ChatCompletionResponse.
+#[derive(Deserialize, Serialize, Clone)]
+struct AnthropicContent {
+    pub text: String,
+    #[serde(rename = "type")]
+    pub r#type: String,
+}
+
+#[derive(Deserialize, Serialize, Clone)]
+struct AnthropicChatCompletionResponse {
+    pub id: String,
+    pub model: String,
+    pub content: Vec<AnthropicContent>,
+    pub usage: AnthropicUsage,
+}
+
+#[derive(Deserialize, Serialize, Clone)]
+struct AnthropicUsage {
+    pub input_tokens: u32,
+    pub output_tokens: u32,
+}
+
 #[async_trait]
 impl Provider for AnthropicProvider {
     fn new(config: &ProviderConfig) -> Self {
@@ -43,7 +66,7 @@ impl Provider for AnthropicProvider {
         let response = self
             .http_client
             .post("https://api.anthropic.com/v1/messages")
-            .header("Authorization", format!("Bearer {}", self.api_key))
+            // Anthropic authenticates with `x-api-key`, not a Bearer token.
+            .header("x-api-key", &self.api_key)
             .header("anthropic-version", "2023-06-01")
             .json(&payload)
             .send()
@@ -52,10 +75,44 @@ impl Provider for AnthropicProvider {
 
         let status = response.status();
         if status.is_success() {
-            response
-                .json()
-                .await
-                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
+            // NOTE(review): propagate parse failures with `?` rather than
+            // `.expect(...)` — a malformed upstream body must map to a 500,
+            // not panic the handler task.
+            let anthropic_response: AnthropicChatCompletionResponse = response
+                .json()
+                .await
+                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+
+            // Translate Anthropic's content-block list into the OpenAI-style
+            // single-choice response shape.
+            Ok(ChatCompletionResponse {
+                id: anthropic_response.id,
+                object: None,
+                created: None,
+                model: anthropic_response.model,
+                choices: vec![ChatCompletionChoice {
+                    index: 0,
+                    message: ChatCompletionMessage {
+                        name: None,
+                        role: "assistant".to_string(),
+                        content: crate::models::chat::ChatMessageContent::Array(
+                            anthropic_response
+                                .content
+                                .into_iter()
+                                .map(|content| ChatMessageContentPart {
+                                    r#type: content.r#type,
+                                    text: content.text,
+                                })
+                                .collect(),
+                        ),
+                    },
+                    finish_reason: Some("stop".to_string()),
+                    logprobs: None,
+                }],
+                usage: Usage {
+                    prompt_tokens: anthropic_response.usage.input_tokens,
+                    completion_tokens: anthropic_response.usage.output_tokens,
+                    total_tokens: anthropic_response.usage.input_tokens
+                        + anthropic_response.usage.output_tokens,
+                },
+            })
         } else {
             Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
         }
@@ -63,123 +120,21 @@ impl Provider for AnthropicProvider {
 
     async fn completions(
         &self,
-        payload: CompletionRequest,
+        _payload: CompletionRequest,
         _model_config: &ModelConfig,
     ) -> Result<CompletionResponse, StatusCode> {
-        let anthropic_payload = serde_json::json!({
-            "model": payload.model,
-            "prompt": format!("\n\nHuman: {}\n\nAssistant:", payload.prompt),
-            "max_tokens_to_sample": payload.max_tokens.unwrap_or(100),
-            "temperature": payload.temperature.unwrap_or(0.7),
-            "top_p": payload.top_p.unwrap_or(1.0),
-            "stop_sequences": payload.stop.unwrap_or_default(),
-        });
-
-        let response = self
-            .http_client
-            .post("https://api.anthropic.com/v1/complete")
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .header("anthropic-version", "2023-06-01")
-            .json(&anthropic_payload)
-            .send()
-            .await
-            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-
-        let status = response.status();
-        if !status.is_success() {
-            return Err(
-                StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
-            );
-        }
-
-        let anthropic_response: serde_json::Value = response
-            .json()
-            .await
-            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-
-        Ok(CompletionResponse {
-            id: anthropic_response["completion_id"]
-                .as_str()
-                .unwrap_or("")
-                .to_string(),
-            object: "text_completion".to_string(),
-            created: chrono::Utc::now().timestamp() as u64,
-            model: payload.model,
-            choices: vec![CompletionChoice {
-                text: anthropic_response["completion"]
-                    .as_str()
-                    .unwrap_or("")
-                    .to_string(),
-                index: 0,
-                logprobs: None,
-                finish_reason: Some("stop".to_string()),
-            }],
-            usage: Usage {
-                prompt_tokens: 0,
-                completion_tokens: 0,
-                total_tokens: 0,
-            },
-        })
+        // The legacy /v1/complete text API is no longer offered upstream.
+        // NOTE(review): return 501 instead of `unimplemented!()` — a panic
+        // inside an axum handler aborts the task instead of answering the client.
+        Err(StatusCode::NOT_IMPLEMENTED)
     }
 
     async fn embeddings(
         &self,
-        payload: EmbeddingsRequest,
+        _payload: EmbeddingsRequest,
         _model_config: &ModelConfig,
     ) -> Result<EmbeddingsResponse, StatusCode> {
-        let anthropic_payload = match &payload.input {
-            EmbeddingsInput::Single(text) => serde_json::json!({
-                "model": payload.model,
-                "text": text,
-            }),
-            EmbeddingsInput::Multiple(texts) => serde_json::json!({
-                "model": payload.model,
-                "text": texts,
-            }),
-        };
-
-        let response = self
-            .http_client
-            .post("https://api.anthropic.com/v1/embeddings")
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .header("anthropic-version", "2023-06-01")
-            .json(&anthropic_payload)
-            .send()
-            .await
-            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-
-        let status = response.status();
-        if !status.is_success() {
-            return Err(
-                StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
-            );
-        }
-
-        let anthropic_response: serde_json::Value = response
-            .json()
-            .await
-            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-
-        let embedding = anthropic_response["embedding"]
-            .as_array()
-            .unwrap_or(&Vec::new())
-            .iter()
-            .filter_map(|v| v.as_f64().map(|f| f as f32))
-            .collect();
-
-        Ok(EmbeddingsResponse {
-            object: "list".to_string(),
-            model: payload.model,
-            data: vec![Embeddings {
-                object: "embedding".to_string(),
-                embedding,
-                index: 0,
-            }],
-            usage: Usage {
-                prompt_tokens: 0,
-                completion_tokens: 0,
-                total_tokens: 0,
-            },
-        })
+        // Anthropic exposes no embeddings endpoint; surface 501, not a panic.
+        Err(StatusCode::NOT_IMPLEMENTED)
     }
 }