Skip to content

Commit

Permalink
feat: add custom openai host endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
AnthonyRonning committed Feb 22, 2025
1 parent 9693b40 commit 697d17c
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 0 deletions.
6 changes: 6 additions & 0 deletions crates/goose-server/src/routes/providers_and_keys.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,11 @@
"description": "Connect to Azure OpenAI Service",
"models": ["gpt-4o", "gpt-4o-mini"],
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
},
"custom_openai": {
"name": "Custom OpenAI",
"description": "Access OpenAI-compatible models through a custom endpoint",
"models": ["gpt-4"],
"required_keys": ["CUSTOM_OPENAI_API_KEY", "CUSTOM_OPENAI_HOST"]
}
}
105 changes: 105 additions & 0 deletions crates/goose/src/providers/custom_openai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use std::time::Duration;

use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use super::formats::openai::{create_request, get_usage, response_to_message};
use super::utils::{ImageFormat, handle_response_openai_compat};
use crate::message::Message;
use crate::model::ModelConfig;
use mcp_core::tool::Tool;

/// Model name used when the user has not configured one explicitly.
pub const CUSTOM_OPENAI_DEFAULT_MODEL: &str = "gpt-4";

/// Provider for OpenAI-compatible APIs served from a user-configured host.
///
/// Configured via the `CUSTOM_OPENAI_API_KEY` (secret) and `CUSTOM_OPENAI_HOST`
/// config keys; see `CustomOpenAiProvider::from_env`.
#[derive(Debug, serde::Serialize)]
pub struct CustomOpenAiProvider {
    // HTTP client holds no configuration, so it is excluded from serialization.
    #[serde(skip)]
    client: Client,
    // Base URL of the OpenAI-compatible endpoint (from `CUSTOM_OPENAI_HOST`).
    host: String,
    // Bearer token sent in the `Authorization` header of each request.
    api_key: String,
    // Model name and generation parameters used when building requests.
    model: ModelConfig,
}

impl Default for CustomOpenAiProvider {
    /// Builds a provider for the metadata's default model, reading credentials
    /// from the global config.
    ///
    /// Panics when required configuration is missing, mirroring the other
    /// providers' `Default` impls.
    fn default() -> Self {
        let default_model = ModelConfig::new(Self::metadata().default_model);
        Self::from_env(default_model).expect("Failed to initialize Custom OpenAI provider")
    }
}

impl CustomOpenAiProvider {
    /// Reads `CUSTOM_OPENAI_API_KEY` (secret) and `CUSTOM_OPENAI_HOST` from the
    /// global config and builds a provider for `model`.
    ///
    /// # Errors
    /// Returns an error when either config key is missing or the HTTP client
    /// cannot be constructed.
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();
        let api_key: String = config.get_secret("CUSTOM_OPENAI_API_KEY")?;
        let host: String = config.get("CUSTOM_OPENAI_HOST")?;

        // Generous timeout: chat completions can legitimately run for minutes.
        let client = Client::builder()
            .timeout(Duration::from_secs(600))
            .build()?;

        Ok(Self {
            client,
            host,
            api_key,
            model,
        })
    }

    /// POSTs `payload` to `<host>/v1/chat/completions` with a bearer token and
    /// funnels the response through the shared OpenAI-compatible handler.
    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
        // `Url::join` replaces the last path segment when the base does not end
        // with '/': a host of "https://example.com/openai" would otherwise
        // resolve to "https://example.com/v1/chat/completions", dropping the
        // "/openai" prefix. Normalize to a trailing slash so any path component
        // in the configured host is preserved.
        let mut base = self.host.trim_end_matches('/').to_string();
        base.push('/');
        let base_url = url::Url::parse(&base)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
        let url = base_url.join("v1/chat/completions").map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
        })?;

        let response = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&payload)
            .send()
            .await?;

        handle_response_openai_compat(response).await
    }
}

#[async_trait]
impl Provider for CustomOpenAiProvider {
    /// Static registry entry: id, display name, description, default model,
    /// known model list, docs link, and the configuration keys this provider
    /// reads.
    fn metadata() -> ProviderMetadata {
        let config_keys = vec![
            ConfigKey::new("CUSTOM_OPENAI_API_KEY", true, true, None),
            ConfigKey::new("CUSTOM_OPENAI_HOST", true, false, Some("")),
        ];
        ProviderMetadata::new(
            "custom_openai",
            "Custom OpenAI",
            "OpenAI-compatible API with custom host",
            CUSTOM_OPENAI_DEFAULT_MODEL,
            vec![CUSTOM_OPENAI_DEFAULT_MODEL.to_string()],
            "",
            config_keys,
        )
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    /// Sends a single chat-completion request and returns the assistant
    /// message together with token-usage accounting.
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let request = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
        let response = self.post(request).await?;

        let message = response_to_message(response.clone())?;
        let usage = get_usage(&response)?;
        let model_name = self.model.model_name.clone();

        Ok((message, ProviderUsage::new(model_name, usage)))
    }
}
3 changes: 3 additions & 0 deletions crates/goose/src/providers/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use super::{
azure::AzureProvider,
base::{Provider, ProviderMetadata},
bedrock::BedrockProvider,
custom_openai::CustomOpenAiProvider,
databricks::DatabricksProvider,
google::GoogleProvider,
groq::GroqProvider,
Expand All @@ -18,6 +19,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
AnthropicProvider::metadata(),
AzureProvider::metadata(),
BedrockProvider::metadata(),
CustomOpenAiProvider::metadata(),
DatabricksProvider::metadata(),
GoogleProvider::metadata(),
GroqProvider::metadata(),
Expand All @@ -33,6 +35,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
"bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)),
"custom_openai" => Ok(Box::new(CustomOpenAiProvider::from_env(model)?)),
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod anthropic;
pub mod azure;
pub mod base;
pub mod bedrock;
pub mod custom_openai;
pub mod databricks;
pub mod errors;
mod factory;
Expand Down
1 change: 1 addition & 0 deletions ui/desktop/src/components/settings/api_keys/utils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export function isSecretKey(keyName: string): boolean {
'OPENAI_HOST',
'AZURE_OPENAI_ENDPOINT',
'AZURE_OPENAI_DEPLOYMENT_NAME',
'CUSTOM_OPENAI_HOST',
];
return !nonSecretKeys.includes(keyName);
}
Expand Down
8 changes: 8 additions & 0 deletions ui/desktop/src/components/settings/models/hardcoded_stuff.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export const goose_models: Model[] = [
{ id: 17, name: 'qwen2.5', provider: 'Ollama' },
{ id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
{ id: 20, name: 'gpt-4', provider: 'Custom OpenAI' },
];

export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
Expand Down Expand Up @@ -47,6 +48,8 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet'];

export const azure_openai_models = ['gpt-4o'];

export const custom_openai_models = ['gpt-4'];

export const default_models = {
openai: 'gpt-4o',
anthropic: 'claude-3-5-sonnet-latest',
Expand All @@ -56,6 +59,7 @@ export const default_models = {
openrouter: 'anthropic/claude-3.5-sonnet',
ollama: 'qwen2.5',
azure_openai: 'gpt-4o',
custom_openai: 'gpt-4',
};

export function getDefaultModel(key: string): string | undefined {
Expand All @@ -73,11 +77,13 @@ export const required_keys = {
Google: ['GOOGLE_API_KEY'],
OpenRouter: ['OPENROUTER_API_KEY'],
'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
'Custom OpenAI': ['CUSTOM_OPENAI_API_KEY', 'CUSTOM_OPENAI_HOST'],
};

export const default_key_value = {
OPENAI_HOST: 'https://api.openai.com',
OLLAMA_HOST: 'localhost',
CUSTOM_OPENAI_HOST: '',
};

export const supported_providers = [
Expand All @@ -89,6 +95,7 @@ export const supported_providers = [
'Ollama',
'OpenRouter',
'Azure OpenAI',
'Custom OpenAI',
];

export const model_docs_link = [
Expand All @@ -113,4 +120,5 @@ export const provider_aliases = [
{ provider: 'OpenRouter', alias: 'openrouter' },
{ provider: 'Google', alias: 'google' },
{ provider: 'Azure OpenAI', alias: 'azure_openai' },
{ provider: 'Custom OpenAI', alias: 'custom_openai' },
];

0 comments on commit 697d17c

Please sign in to comment.