fix: make anthropic work (#8)
nirga authored Nov 15, 2024
1 parent ee4f524 · commit 2b96274
Showing 2 changed files with 86 additions and 121 deletions.
22 changes: 19 additions & 3 deletions src/models/chat.rs
@@ -29,19 +29,35 @@ pub struct ChatCompletionRequest {
pub user: Option<String>,
}

#[derive(Deserialize, Serialize, Clone)]
#[serde(untagged)]
pub enum ChatMessageContent {
String(String),
Array(Vec<ChatMessageContentPart>),
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatMessageContentPart {
#[serde(rename = "type")]
pub r#type: String,
pub text: String,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessage {
pub role: String,
pub content: String,
pub content: ChatMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub object: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub created: Option<u64>,
pub model: String,
pub choices: Vec<ChatCompletionChoice>,
pub usage: Usage,
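For reference, the #[serde(untagged)] attribute makes serde try each ChatMessageContent variant in order, so a message's content field can arrive either as a plain JSON string (the usual OpenAI shape) or as an array of typed parts (the shape the Anthropic provider now produces). A minimal standalone sketch, assuming serde (with the derive feature) and serde_json as dependencies; the types are duplicated here so the snippet compiles on its own:

use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize, Clone, Debug)]
#[serde(untagged)]
enum ChatMessageContent {
    String(String),
    Array(Vec<ChatMessageContentPart>),
}

#[derive(Deserialize, Serialize, Clone, Debug)]
struct ChatMessageContentPart {
    #[serde(rename = "type")]
    r#type: String,
    text: String,
}

fn main() -> serde_json::Result<()> {
    // Plain-string content, as most OpenAI-compatible clients send it.
    let simple: ChatMessageContent = serde_json::from_str(r#""Hello there""#)?;
    // Array-of-parts content, which the Anthropic mapping below produces.
    let parts: ChatMessageContent =
        serde_json::from_str(r#"[{"type": "text", "text": "Hello there"}]"#)?;
    println!("{simple:?}\n{parts:?}");
    Ok(())
}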
185 changes: 67 additions & 118 deletions src/providers/anthropic.rs
@@ -1,14 +1,16 @@
use axum::async_trait;
use axum::http::StatusCode;
use serde::{Deserialize, Serialize};

use super::provider::Provider;
use crate::config::models::{ModelConfig, Provider as ProviderConfig};
use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::models::common::Usage;
use crate::models::completion::{CompletionChoice, CompletionRequest, CompletionResponse};
use crate::models::embeddings::{
Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse,
use crate::models::chat::{
ChatCompletionChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
ChatMessageContentPart,
};
use crate::models::common::Usage;
use crate::models::completion::{CompletionRequest, CompletionResponse};
use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
use reqwest::Client;

pub struct AnthropicProvider {
@@ -17,6 +19,27 @@ pub struct AnthropicProvider {
http_client: Client,
}

#[derive(Deserialize, Serialize, Clone)]
struct AnthropicContent {
pub text: String,
#[serde(rename = "type")]
pub r#type: String,
}

#[derive(Deserialize, Serialize, Clone)]
struct AnthropicChatCompletionResponse {
pub id: String,
pub model: String,
pub content: Vec<AnthropicContent>,
pub usage: AnthropicUsage,
}

#[derive(Deserialize, Serialize, Clone)]
struct AnthropicUsage {
pub input_tokens: u32,
pub output_tokens: u32,
}

#[async_trait]
impl Provider for AnthropicProvider {
fn new(config: &ProviderConfig) -> Self {
@@ -43,7 +66,7 @@ impl Provider for AnthropicProvider {
let response = self
.http_client
.post("https://api.anthropic.com/v1/messages")
.header("Authorization", format!("Bearer {}", self.api_key))
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.json(&payload)
.send()
@@ -52,134 +75,60 @@

let status = response.status();
if status.is_success() {
response
let anthropic_response: AnthropicChatCompletionResponse = response
.json()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.expect("Failed to parse Anthropic response");

Ok(ChatCompletionResponse {
id: anthropic_response.id,
object: None,
created: None,
model: anthropic_response.model,
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatCompletionMessage {
name: None,
role: "assistant".to_string(),
content: crate::models::chat::ChatMessageContent::Array(
anthropic_response
.content
.into_iter()
.map(|content| ChatMessageContentPart {
r#type: content.r#type,
text: content.text,
})
.collect(),
),
},
finish_reason: Some("stop".to_string()),
logprobs: None,
}],
usage: Usage {
prompt_tokens: anthropic_response.usage.input_tokens,
completion_tokens: anthropic_response.usage.output_tokens,
total_tokens: anthropic_response.usage.input_tokens
+ anthropic_response.usage.output_tokens,
},
})
} else {
Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
}
}

async fn completions(
&self,
payload: CompletionRequest,
_payload: CompletionRequest,
_model_config: &ModelConfig,
) -> Result<CompletionResponse, StatusCode> {
let anthropic_payload = serde_json::json!({
"model": payload.model,
"prompt": format!("\n\nHuman: {}\n\nAssistant:", payload.prompt),
"max_tokens_to_sample": payload.max_tokens.unwrap_or(100),
"temperature": payload.temperature.unwrap_or(0.7),
"top_p": payload.top_p.unwrap_or(1.0),
"stop_sequences": payload.stop.unwrap_or_default(),
});

let response = self
.http_client
.post("https://api.anthropic.com/v1/complete")
.header("Authorization", format!("Bearer {}", self.api_key))
.header("anthropic-version", "2023-06-01")
.json(&anthropic_payload)
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let status = response.status();
if !status.is_success() {
return Err(
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
);
}

let anthropic_response: serde_json::Value = response
.json()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

Ok(CompletionResponse {
id: anthropic_response["completion_id"]
.as_str()
.unwrap_or("")
.to_string(),
object: "text_completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: payload.model,
choices: vec![CompletionChoice {
text: anthropic_response["completion"]
.as_str()
.unwrap_or("")
.to_string(),
index: 0,
logprobs: None,
finish_reason: Some("stop".to_string()),
}],
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
})
unimplemented!()
}

async fn embeddings(
&self,
payload: EmbeddingsRequest,
_payload: EmbeddingsRequest,
_model_config: &ModelConfig,
) -> Result<EmbeddingsResponse, StatusCode> {
let anthropic_payload = match &payload.input {
EmbeddingsInput::Single(text) => serde_json::json!({
"model": payload.model,
"text": text,
}),
EmbeddingsInput::Multiple(texts) => serde_json::json!({
"model": payload.model,
"text": texts,
}),
};

let response = self
.http_client
.post("https://api.anthropic.com/v1/embeddings")
.header("Authorization", format!("Bearer {}", self.api_key))
.header("anthropic-version", "2023-06-01")
.json(&anthropic_payload)
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let status = response.status();
if !status.is_success() {
return Err(
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
);
}

let anthropic_response: serde_json::Value = response
.json()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let embedding = anthropic_response["embedding"]
.as_array()
.unwrap_or(&Vec::new())
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();

Ok(EmbeddingsResponse {
object: "list".to_string(),
model: payload.model,
data: vec![Embeddings {
object: "embedding".to_string(),
embedding,
index: 0,
}],
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
})
unimplemented!()
}
}
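For context, the authentication change above is the core of the fix: Anthropic's Messages API authenticates with an x-api-key header plus a required anthropic-version header, not an Authorization: Bearer token. A minimal standalone sketch of the corrected call, assuming reqwest (with the json feature), serde_json, and tokio as dependencies; the model name is only a placeholder:

use reqwest::Client;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let api_key = std::env::var("ANTHROPIC_API_KEY").expect("set ANTHROPIC_API_KEY");
    let payload = serde_json::json!({
        "model": "claude-3-5-sonnet-20241022", // placeholder model name
        "max_tokens": 256,
        "messages": [{ "role": "user", "content": "Say hello." }],
    });

    let response = Client::new()
        .post("https://api.anthropic.com/v1/messages")
        .header("x-api-key", api_key.as_str()) // not "Authorization: Bearer ..."
        .header("anthropic-version", "2023-06-01")
        .json(&payload)
        .send()
        .await?;

    println!("{}", response.text().await?);
    Ok(())
}

The response body carries content as an array of { type, text } parts and usage as input_tokens/output_tokens, which is what chat_completions above maps onto the OpenAI-style ChatCompletionResponse.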
