Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add custom openai host endpoint #1349

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/goose-server/src/routes/providers_and_keys.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,11 @@
"description": "Connect to Azure OpenAI Service",
"models": ["gpt-4o", "gpt-4o-mini"],
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
},
"custom_openai": {
"name": "Custom OpenAI",
"description": "Access OpenAI-compatible models through a custom endpoint",
"models": ["gpt-4"],
"required_keys": ["CUSTOM_OPENAI_API_KEY", "CUSTOM_OPENAI_HOST"]
}
}
105 changes: 105 additions & 0 deletions crates/goose/src/providers/custom_openai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use std::time::Duration;

use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use super::formats::openai::{create_request, get_usage, response_to_message};
use super::utils::{ImageFormat, handle_response_openai_compat};
use crate::message::Message;
use crate::model::ModelConfig;
use mcp_core::tool::Tool;

pub const CUSTOM_OPENAI_DEFAULT_MODEL: &str = "gpt-4";

#[derive(Debug, serde::Serialize)]
pub struct CustomOpenAiProvider {
#[serde(skip)]
client: Client,
host: String,
api_key: String,
model: ModelConfig,
}

impl Default for CustomOpenAiProvider {
fn default() -> Self {
let model = ModelConfig::new(CustomOpenAiProvider::metadata().default_model);
CustomOpenAiProvider::from_env(model).expect("Failed to initialize Custom OpenAI provider")
}
}

impl CustomOpenAiProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let api_key: String = config.get_secret("CUSTOM_OPENAI_API_KEY")?;
let host: String = config.get("CUSTOM_OPENAI_HOST")?;

let client = Client::builder()
.timeout(Duration::from_secs(600))
.build()?;

Ok(Self {
client,
host,
api_key,
model,
})
}

async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let base_url = url::Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join("v1/chat/completions").map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;

let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&payload)
.send()
.await?;

handle_response_openai_compat(response).await
}
}

#[async_trait]
impl Provider for CustomOpenAiProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"custom_openai",
"Custom OpenAI",
"OpenAI-compatible API with custom host",
CUSTOM_OPENAI_DEFAULT_MODEL,
vec![CUSTOM_OPENAI_DEFAULT_MODEL.to_string()],
"",
vec![
ConfigKey::new("CUSTOM_OPENAI_API_KEY", true, true, None),
ConfigKey::new("CUSTOM_OPENAI_HOST", true, false, Some("")),
],
)
}

fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}

async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;

let response = self.post(payload).await?;
let message = response_to_message(response.clone())?;
let usage = get_usage(&response)?;

Ok((message, ProviderUsage::new(self.model.model_name.clone(), usage)))
}
}
3 changes: 3 additions & 0 deletions crates/goose/src/providers/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use super::{
azure::AzureProvider,
base::{Provider, ProviderMetadata},
bedrock::BedrockProvider,
custom_openai::CustomOpenAiProvider,
databricks::DatabricksProvider,
google::GoogleProvider,
groq::GroqProvider,
Expand All @@ -18,6 +19,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
AnthropicProvider::metadata(),
AzureProvider::metadata(),
BedrockProvider::metadata(),
CustomOpenAiProvider::metadata(),
DatabricksProvider::metadata(),
GoogleProvider::metadata(),
GroqProvider::metadata(),
Expand All @@ -33,6 +35,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
"bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)),
"custom_openai" => Ok(Box::new(CustomOpenAiProvider::from_env(model)?)),
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod anthropic;
pub mod azure;
pub mod base;
pub mod bedrock;
pub mod custom_openai;
pub mod databricks;
pub mod errors;
mod factory;
Expand Down
1 change: 1 addition & 0 deletions ui/desktop/src/components/settings/api_keys/utils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export function isSecretKey(keyName: string): boolean {
'OPENAI_HOST',
'AZURE_OPENAI_ENDPOINT',
'AZURE_OPENAI_DEPLOYMENT_NAME',
'CUSTOM_OPENAI_HOST',
];
return !nonSecretKeys.includes(keyName);
}
Expand Down
8 changes: 8 additions & 0 deletions ui/desktop/src/components/settings/models/hardcoded_stuff.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export const goose_models: Model[] = [
{ id: 17, name: 'qwen2.5', provider: 'Ollama' },
{ id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
{ id: 20, name: 'gpt-4', provider: 'Custom OpenAI' },
];

export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
Expand Down Expand Up @@ -47,6 +48,8 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet'];

export const azure_openai_models = ['gpt-4o'];

export const custom_openai_models = ['gpt-4'];

export const default_models = {
openai: 'gpt-4o',
anthropic: 'claude-3-5-sonnet-latest',
Expand All @@ -56,6 +59,7 @@ export const default_models = {
openrouter: 'anthropic/claude-3.5-sonnet',
ollama: 'qwen2.5',
azure_openai: 'gpt-4o',
custom_openai: 'gpt-4',
};

export function getDefaultModel(key: string): string | undefined {
Expand All @@ -73,11 +77,13 @@ export const required_keys = {
Google: ['GOOGLE_API_KEY'],
OpenRouter: ['OPENROUTER_API_KEY'],
'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
'Custom OpenAI': ['CUSTOM_OPENAI_API_KEY', 'CUSTOM_OPENAI_HOST'],
};

export const default_key_value = {
OPENAI_HOST: 'https://api.openai.com',
OLLAMA_HOST: 'localhost',
CUSTOM_OPENAI_HOST: '',
};

export const supported_providers = [
Expand All @@ -89,6 +95,7 @@ export const supported_providers = [
'Ollama',
'OpenRouter',
'Azure OpenAI',
'Custom OpenAI',
];

export const model_docs_link = [
Expand All @@ -113,4 +120,5 @@ export const provider_aliases = [
{ provider: 'OpenRouter', alias: 'openrouter' },
{ provider: 'Google', alias: 'google' },
{ provider: 'Azure OpenAI', alias: 'azure_openai' },
{ provider: 'Custom OpenAI', alias: 'custom_openai' },
];