diff --git a/crates/goose-server/src/routes/providers_and_keys.json b/crates/goose-server/src/routes/providers_and_keys.json index 34589cc08..b9ca26a87 100644 --- a/crates/goose-server/src/routes/providers_and_keys.json +++ b/crates/goose-server/src/routes/providers_and_keys.json @@ -46,5 +46,11 @@ "description": "Connect to Azure OpenAI Service", "models": ["gpt-4o", "gpt-4o-mini"], "required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"] + }, + "custom_openai": { + "name": "Custom OpenAI", + "description": "Access OpenAI-compatible models through a custom endpoint", + "models": ["gpt-4"], + "required_keys": ["CUSTOM_OPENAI_API_KEY", "CUSTOM_OPENAI_HOST"] + } } diff --git a/crates/goose/src/providers/custom_openai.rs b/crates/goose/src/providers/custom_openai.rs new file mode 100644 index 000000000..8d782618d --- /dev/null +++ b/crates/goose/src/providers/custom_openai.rs @@ -0,0 +1,105 @@ +use anyhow::Result; +use async_trait::async_trait; +use reqwest::Client; +use serde_json::Value; +use std::time::Duration; + +use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; +use super::errors::ProviderError; +use super::formats::openai::{create_request, get_usage, response_to_message}; +use super::utils::{ImageFormat, handle_response_openai_compat}; +use crate::message::Message; +use crate::model::ModelConfig; +use mcp_core::tool::Tool; + +pub const CUSTOM_OPENAI_DEFAULT_MODEL: &str = "gpt-4"; + +#[derive(Debug, serde::Serialize)] +pub struct CustomOpenAiProvider { + #[serde(skip)] + client: Client, + host: String, + api_key: String, + model: ModelConfig, +} + +impl Default for CustomOpenAiProvider { + fn default() -> Self { + let model = ModelConfig::new(CustomOpenAiProvider::metadata().default_model); + CustomOpenAiProvider::from_env(model).expect("Failed to initialize Custom OpenAI provider") + } +} + +impl CustomOpenAiProvider { + pub fn from_env(model: ModelConfig) -> Result<Self> { + let config = 
crate::config::Config::global(); + let api_key: String = config.get_secret("CUSTOM_OPENAI_API_KEY")?; + let host: String = config.get("CUSTOM_OPENAI_HOST")?; + + let client = Client::builder() + .timeout(Duration::from_secs(600)) + .build()?; + + Ok(Self { + client, + host, + api_key, + model, + }) + } + + async fn post(&self, payload: Value) -> Result<Value, ProviderError> { + let base_url = url::Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join("v1/chat/completions").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; + + let response = self + .client + .post(url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&payload) + .send() + .await?; + + handle_response_openai_compat(response).await + } +} + +#[async_trait] +impl Provider for CustomOpenAiProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + "custom_openai", + "Custom OpenAI", + "OpenAI-compatible API with custom host", + CUSTOM_OPENAI_DEFAULT_MODEL, + vec![CUSTOM_OPENAI_DEFAULT_MODEL.to_string()], + "", + vec![ + ConfigKey::new("CUSTOM_OPENAI_API_KEY", true, true, None), + ConfigKey::new("CUSTOM_OPENAI_HOST", true, false, Some("")), + ], + ) + } + + fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } + + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; + + let response = self.post(payload).await?; + let message = response_to_message(response.clone())?; + let usage = get_usage(&response)?; + + Ok((message, ProviderUsage::new(self.model.model_name.clone(), usage))) + } +} diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index d17fb8893..d5a89d9ce 100644 --- a/crates/goose/src/providers/factory.rs +++ 
b/crates/goose/src/providers/factory.rs @@ -3,6 +3,7 @@ use super::{ azure::AzureProvider, base::{Provider, ProviderMetadata}, bedrock::BedrockProvider, + custom_openai::CustomOpenAiProvider, databricks::DatabricksProvider, google::GoogleProvider, groq::GroqProvider, @@ -18,6 +19,7 @@ pub fn providers() -> Vec<ProviderMetadata> { AnthropicProvider::metadata(), AzureProvider::metadata(), BedrockProvider::metadata(), + CustomOpenAiProvider::metadata(), DatabricksProvider::metadata(), GoogleProvider::metadata(), GroqProvider::metadata(), @@ -33,6 +35,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send + Sync>> { "anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)), "azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)), "bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)), + "custom_openai" => Ok(Box::new(CustomOpenAiProvider::from_env(model)?)), "databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)), "groq" => Ok(Box::new(GroqProvider::from_env(model)?)), "ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)), diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 634224fd7..e1ba41731 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -2,6 +2,7 @@ pub mod anthropic; pub mod azure; pub mod base; pub mod bedrock; +pub mod custom_openai; pub mod databricks; pub mod errors; mod factory; diff --git a/ui/desktop/src/components/settings/api_keys/utils.tsx b/ui/desktop/src/components/settings/api_keys/utils.tsx index b29014db8..979a497c8 100644 --- a/ui/desktop/src/components/settings/api_keys/utils.tsx +++ b/ui/desktop/src/components/settings/api_keys/utils.tsx @@ -10,6 +10,7 @@ export function isSecretKey(keyName: string): boolean { 'OPENAI_HOST', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME', + 'CUSTOM_OPENAI_HOST', ]; return !nonSecretKeys.includes(keyName); } diff --git a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx 
b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx index b38840153..0604a0a26 100644 --- a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx +++ b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx @@ -19,6 +19,7 @@ export const goose_models: Model[] = [ { id: 17, name: 'qwen2.5', provider: 'Ollama' }, { id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' }, { id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' }, + { id: 20, name: 'gpt-4', provider: 'Custom OpenAI' }, ]; export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1']; @@ -47,6 +48,8 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet']; export const azure_openai_models = ['gpt-4o']; +export const custom_openai_models = ['gpt-4']; + export const default_models = { openai: 'gpt-4o', anthropic: 'claude-3-5-sonnet-latest', @@ -56,6 +59,7 @@ export const default_models = { openrouter: 'anthropic/claude-3.5-sonnet', ollama: 'qwen2.5', azure_openai: 'gpt-4o', + custom_openai: 'gpt-4', }; export function getDefaultModel(key: string): string | undefined { @@ -73,11 +77,13 @@ export const required_keys = { Google: ['GOOGLE_API_KEY'], OpenRouter: ['OPENROUTER_API_KEY'], 'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'], + 'Custom OpenAI': ['CUSTOM_OPENAI_API_KEY', 'CUSTOM_OPENAI_HOST'], }; export const default_key_value = { OPENAI_HOST: 'https://api.openai.com', OLLAMA_HOST: 'localhost', + CUSTOM_OPENAI_HOST: '', }; export const supported_providers = [ @@ -89,6 +95,7 @@ export const supported_providers = [ 'Ollama', 'OpenRouter', 'Azure OpenAI', + 'Custom OpenAI', ]; export const model_docs_link = [ @@ -113,4 +120,5 @@ export const provider_aliases = [ { provider: 'OpenRouter', alias: 'openrouter' }, { provider: 'Google', alias: 'google' }, { provider: 'Azure OpenAI', alias: 'azure_openai' }, + { provider: 'Custom OpenAI', alias: 'custom_openai' }, ];