From 2a74528447d06f015b9f7fa4fb159ae161341ccd Mon Sep 17 00:00:00 2001 From: Uddhav Kambli Date: Sun, 23 Feb 2025 12:57:08 -0500 Subject: [PATCH] feat: add GCP Vertex AI platform as provider Signed-off-by: Uddhav Kambli --- Cargo.lock | 129 +- .../src/routes/providers_and_keys.json | 6 + crates/goose/Cargo.toml | 6 + crates/goose/src/providers/factory.rs | 3 + .../src/providers/formats/gcpvertexai.rs | 388 ++++++ crates/goose/src/providers/formats/mod.rs | 1 + crates/goose/src/providers/gcpauth.rs | 1048 +++++++++++++++++ crates/goose/src/providers/gcpvertexai.rs | 342 ++++++ crates/goose/src/providers/mod.rs | 2 + crates/goose/tests/truncate_agent.rs | 25 +- .../docs/getting-started/providers.md | 23 +- .../components/settings/api_keys/utils.tsx | 2 + .../settings/models/hardcoded_stuff.tsx | 19 + 13 files changed, 1970 insertions(+), 24 deletions(-) create mode 100644 crates/goose/src/providers/formats/gcpvertexai.rs create mode 100644 crates/goose/src/providers/gcpauth.rs create mode 100644 crates/goose/src/providers/gcpvertexai.rs diff --git a/Cargo.lock b/Cargo.lock index d2f5f8096..aecdcb947 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -230,9 +230,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.85" +version = "0.1.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" +checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" dependencies = [ "proc-macro2", "quote", @@ -1699,6 +1699,12 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "dyn-clone" version = "1.0.17" @@ -1883,6 +1889,12 @@ dependencies 
= [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "futures" version = "0.3.31" @@ -2149,10 +2161,12 @@ dependencies = [ "futures", "include_dir", "indoc", + "jsonwebtoken", "keyring", "lazy_static", "mcp-client", "mcp-core", + "mockall", "nanoid", "once_cell", "paste", @@ -2568,7 +2582,7 @@ dependencies = [ "http 1.2.0", "hyper 1.6.0", "hyper-util", - "rustls 0.23.21", + "rustls 0.23.23", "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", @@ -3025,6 +3039,21 @@ dependencies = [ "serde", ] +[[package]] +name = "jsonwebtoken" +version = "9.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" +dependencies = [ + "base64 0.22.1", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "keyring" version = "3.6.1" @@ -3364,6 +3393,32 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mockall" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "monostate" version = "0.1.13" @@ -3736,6 +3791,16 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" +[[package]] +name = "pem" +version = "3.0.5" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3" +dependencies = [ + "base64 0.22.1", + "serde", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -3938,6 +4003,32 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "predicates" +version = "3.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" + +[[package]] +name = "predicates-tree" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "prettyplease" version = "0.2.29" @@ -4060,7 +4151,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash 2.1.0", - "rustls 0.23.21", + "rustls 0.23.23", "socket2", "thiserror 2.0.11", "tokio", @@ -4078,7 +4169,7 @@ dependencies = [ "rand", "ring", "rustc-hash 2.1.0", - "rustls 0.23.21", + "rustls 0.23.23", "rustls-pki-types", "slab", "thiserror 2.0.11", @@ -4390,7 +4481,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.21", + "rustls 0.23.23", "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", @@ -4511,9 +4602,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.21" +version = "0.23.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" +checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" dependencies = [ "once_cell", "ring", @@ -4974,6 +5065,18 @@ dependencies 
= [ "quote", ] +[[package]] +name = "simple_asn1" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 2.0.11", + "time", +] + [[package]] name = "siphasher" version = "1.0.1" @@ -5265,6 +5368,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "test-case" version = "3.3.1" @@ -5534,7 +5643,7 @@ version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ - "rustls 0.23.21", + "rustls 0.23.23", "tokio", ] @@ -6764,7 +6873,7 @@ dependencies = [ "hyper-util", "log", "percent-encoding", - "rustls 0.23.21", + "rustls 0.23.23", "rustls-pemfile 2.2.0", "seahash", "serde", diff --git a/crates/goose-server/src/routes/providers_and_keys.json b/crates/goose-server/src/routes/providers_and_keys.json index 34589cc08..a754c20a9 100644 --- a/crates/goose-server/src/routes/providers_and_keys.json +++ b/crates/goose-server/src/routes/providers_and_keys.json @@ -17,6 +17,12 @@ "models": ["goose"], "required_keys": ["DATABRICKS_HOST"] }, + "gcp_vertex_ai": { + "name": "GCP Vertex AI", + "description": "Use Vertex AI platform models", + "models": ["claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05"], + "required_keys": ["GCP_PROJECT_ID", "GCP_LOCATION"] + }, "google": { "name": "Google", "description": "Lorem ipsum", diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 3d0d4a7b2..07a525c6f 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -66,6 +66,9 @@ aws-config = { 
version = "1.1.7", features = ["behavior-version-latest"] } aws-smithy-types = "1.2.12" aws-sdk-bedrockruntime = "1.72.0" +# For GCP Vertex AI provider auth +jsonwebtoken = "9.3.1" + [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } @@ -73,6 +76,9 @@ winapi = { version = "0.3", features = ["wincred"] } criterion = "0.5" tempfile = "3.15.0" serial_test = "3.2.0" +mockall = "0.13.1" +wiremock = "0.6.0" +tokio = { version = "1.0", features = ["full"] } [[example]] name = "agent" diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index d17fb8893..d2ba7bfb4 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -4,6 +4,7 @@ use super::{ base::{Provider, ProviderMetadata}, bedrock::BedrockProvider, databricks::DatabricksProvider, + gcpvertexai::GcpVertexAIProvider, google::GoogleProvider, groq::GroqProvider, ollama::OllamaProvider, @@ -19,6 +20,7 @@ pub fn providers() -> Vec { AzureProvider::metadata(), BedrockProvider::metadata(), DatabricksProvider::metadata(), + GcpVertexAIProvider::metadata(), GoogleProvider::metadata(), GroqProvider::metadata(), OllamaProvider::metadata(), @@ -37,6 +39,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result Ok(Box::new(GroqProvider::from_env(model)?)), "ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)), "openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)), + "gcp_vertex_ai" => Ok(Box::new(GcpVertexAIProvider::from_env(model)?)), "google" => Ok(Box::new(GoogleProvider::from_env(model)?)), _ => Err(anyhow::anyhow!("Unknown provider: {}", name)), } diff --git a/crates/goose/src/providers/formats/gcpvertexai.rs b/crates/goose/src/providers/formats/gcpvertexai.rs new file mode 100644 index 000000000..a61f24ffd --- /dev/null +++ b/crates/goose/src/providers/formats/gcpvertexai.rs @@ -0,0 +1,388 @@ +use crate::message::Message; +use crate::model::ModelConfig; +use 
crate::providers::base::Usage; +use anyhow::{Context, Result}; +use mcp_core::tool::Tool; +use std::fmt; +use serde_json::Value; +use super::{anthropic, google}; + +/// Sensible default values of Google Cloud Platform (GCP) locations for model deployment. +/// +/// Each variant corresponds to a specific GCP region where models can be hosted. +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum GcpLocation { + /// Represents the us-central1 region in Iowa + Iowa, + /// Represents the us-east5 region in Ohio + Ohio, +} + +impl fmt::Display for GcpLocation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Iowa => write!(f, "us-central1"), + Self::Ohio => write!(f, "us-east5"), + } + } +} + +impl TryFrom<&str> for GcpLocation { + type Error = ModelError; + + fn try_from(s: &str) -> Result { + match s { + "us-central1" => Ok(Self::Iowa), + "us-east5" => Ok(Self::Ohio), + _ => Err(ModelError::UnsupportedLocation(s.to_string())), + } + } +} + +/// Represents errors that can occur during model operations. +/// +/// This enum encompasses various error conditions that might arise when working +/// with GCP Vertex AI models, including unsupported models, invalid requests, +/// and unsupported locations. +#[derive(Debug, thiserror::Error)] +pub enum ModelError { + /// Error when an unsupported Vertex AI model is specified + #[error("Unsupported Vertex AI model: {0}")] + UnsupportedModel(String), + /// Error when the request structure is invalid + #[error("Invalid request structure: {0}")] + InvalidRequest(String), + /// Error when an unsupported GCP location is specified + #[error("Unsupported GCP location: {0}")] + UnsupportedLocation(String), +} + +/// Represents available GCP Vertex AI models for Goose. +/// +/// This enum encompasses different model families and their versions +/// that are supported in the GCP Vertex AI platform. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub enum GcpVertexAIModel { + /// Claude model family with specific versions + Claude(ClaudeVersion), + /// Gemini model family with specific versions + Gemini(GeminiVersion), +} + +/// Represents available versions of the Claude model for Goose. +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum ClaudeVersion { + /// Claude 3.5 Sonnet initial version + Sonnet35, + /// Claude 3.5 Sonnet version 2 + Sonnet35V2, +} + +/// Represents available versions of the Gemini model for Goose. +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum GeminiVersion { + /// Gemini 1.5 Pro version + Pro15, + /// Gemini 2.0 Flash version + Flash20, + /// Gemini 2.0 Pro Experimental version + Pro20Exp, +} + +impl fmt::Display for GcpVertexAIModel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let model_id = match self { + Self::Claude(version) => match version { + ClaudeVersion::Sonnet35 => "claude-3-5-sonnet@20240620", + ClaudeVersion::Sonnet35V2 => "claude-3-5-sonnet-v2@20241022", + }, + Self::Gemini(version) => match version { + GeminiVersion::Pro15 => "gemini-1.5-pro-002", + GeminiVersion::Flash20 => "gemini-2.0-flash-001", + GeminiVersion::Pro20Exp => "gemini-2.0-pro-exp-02-05", + }, + }; + write!(f, "{}", model_id) + } +} + +impl GcpVertexAIModel { + /// Returns the default GCP location for the model. 
+ /// + /// Each model family has a well-known location: + /// - Claude models default to Ohio (us-east5) + /// - Gemini models default to Iowa (us-central1) + pub fn default_location(&self) -> GcpLocation { + match self { + Self::Claude(_) => GcpLocation::Ohio, + Self::Gemini(_) => GcpLocation::Iowa, + } + } +} + +impl TryFrom<&str> for GcpVertexAIModel { + type Error = ModelError; + + fn try_from(s: &str) -> Result { + match s { + "claude-3-5-sonnet@20240620" => Ok(Self::Claude(ClaudeVersion::Sonnet35)), + "claude-3-5-sonnet-v2@20241022" => Ok(Self::Claude(ClaudeVersion::Sonnet35V2)), + "gemini-1.5-pro-002" => Ok(Self::Gemini(GeminiVersion::Pro15)), + "gemini-2.0-flash-001" => Ok(Self::Gemini(GeminiVersion::Flash20)), + "gemini-2.0-pro-exp-02-05" => Ok(Self::Gemini(GeminiVersion::Pro20Exp)), + _ => Err(ModelError::UnsupportedModel(s.to_string())), + } + } +} + +/// Holds context information for a model request since the Vertex AI platform +/// supports multiple model families. +/// +/// This structure maintains information about the model being used +/// and provides utility methods for handling model-specific operations. +#[derive(Debug, Clone)] +pub struct RequestContext { + /// The GCP Vertex AI model being used + pub model: GcpVertexAIModel, +} + +impl RequestContext { + /// Creates a new RequestContext from a model ID string. + /// + /// # Arguments + /// * `model_id` - The string identifier of the model + /// + /// # Returns + /// * `Result` - A new RequestContext if the model ID is valid + pub fn new(model_id: &str) -> Result { + Ok(Self { + model: GcpVertexAIModel::try_from(model_id) + .with_context(|| format!("Failed to parse model ID: {}", model_id))?, + }) + } + + /// Returns the provider associated with the model. + pub fn provider(&self) -> ModelProvider { + match self.model { + GcpVertexAIModel::Claude(_) => ModelProvider::Anthropic, + GcpVertexAIModel::Gemini(_) => ModelProvider::Google, + } + } +} + +/// Represents available model providers. 
+#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum ModelProvider { + /// Anthropic provider (Claude models) + Anthropic, + /// Google provider (Gemini models) + Google, +} + +impl ModelProvider { + /// Returns the string representation of the provider. + pub fn as_str(&self) -> &'static str { + match self { + Self::Anthropic => "anthropic", + Self::Google => "google", + } + } +} + +/// Creates an Anthropic-specific Vertex AI request payload. +/// +/// # Arguments +/// * `model_config` - Configuration for the model +/// * `system` - System prompt +/// * `messages` - Array of messages +/// * `tools` - Array of available tools +/// +/// # Returns +/// * `Result` - JSON request payload for Anthropic API +fn create_anthropic_request( + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], +) -> Result { + let mut request = anthropic::create_request(model_config, system, messages, tools)?; + + let obj = request + .as_object_mut() + .ok_or_else(|| ModelError::InvalidRequest("Request is not a JSON object".to_string()))?; + + obj.remove("model"); + obj.insert( + "anthropic_version".to_string(), + Value::String("vertex-2023-10-16".to_string()), + ); + + Ok(request) +} + +/// Creates a Gemini-specific Vertex AI request payload. +/// +/// # Arguments +/// * `model_config` - Configuration for the model +/// * `system` - System prompt +/// * `messages` - Array of messages +/// * `tools` - Array of available tools +/// +/// # Returns +/// * `Result` - JSON request payload for Google API +fn create_google_request( + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], +) -> Result { + google::create_request(model_config, system, messages, tools) +} + +/// Creates a provider-specific request payload and context. 
+/// +/// # Arguments +/// * `model_config` - Configuration for the model +/// * `system` - System prompt +/// * `messages` - Array of messages +/// * `tools` - Array of available tools +/// +/// # Returns +/// * `Result<(Value, RequestContext)>` - Tuple of request payload and context +pub fn create_request( + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], +) -> Result<(Value, RequestContext)> { + let context = RequestContext::new(&model_config.model_name)?; + + let request = match context.model { + GcpVertexAIModel::Claude(_) => create_anthropic_request(model_config, system, messages, tools)?, + GcpVertexAIModel::Gemini(_) => create_google_request(model_config, system, messages, tools)?, + }; + + Ok((request, context)) +} + +/// Converts a provider response to a Message. +/// +/// # Arguments +/// * `response` - The raw response from the provider +/// * `request_context` - Context information about the request +/// +/// # Returns +/// * `Result` - Converted message +pub fn response_to_message(response: Value, request_context: RequestContext) -> Result { + match request_context.provider() { + ModelProvider::Anthropic => anthropic::response_to_message(response), + ModelProvider::Google => google::response_to_message(response), + } +} + +/// Extracts token usage information from the response data. 
+/// +/// # Arguments +/// * `data` - The response data containing usage information +/// * `request_context` - Context information about the request +/// +/// # Returns +/// * `Result` - Usage statistics +pub fn get_usage(data: &Value, request_context: &RequestContext) -> Result { + match request_context.provider() { + ModelProvider::Anthropic => anthropic::get_usage(data), + ModelProvider::Google => google::get_usage(data), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::Result; + + #[test] + fn test_model_parsing() -> Result<()> { + let valid_models = [ + "claude-3-5-sonnet@20240620", + "claude-3-5-sonnet-v2@20241022", + "gemini-1.5-pro-002", + "gemini-2.0-flash-001", + "gemini-2.0-pro-exp-02-05", + ]; + + for model_id in valid_models { + let model = GcpVertexAIModel::try_from(model_id)?; + assert_eq!(model.to_string(), model_id); + } + + assert!(GcpVertexAIModel::try_from("unsupported-model").is_err()); + Ok(()) + } + + #[test] + fn test_request_context() -> Result<()> { + let context = RequestContext::new("claude-3-5-sonnet@20240620")?; + assert!(matches!(context.provider(), ModelProvider::Anthropic)); + + let context = RequestContext::new("gemini-1.5-pro-002")?; + assert!(matches!(context.provider(), ModelProvider::Google)); + + assert!(RequestContext::new("unsupported-model").is_err()); + Ok(()) + } + + #[test] + fn test_create_request() -> Result<()> { + let test_cases = [ + ("claude-3-5-sonnet@20240620", ModelProvider::Anthropic), + ("gemini-1.5-pro-002", ModelProvider::Google), + ]; + + for (model_id, expected_provider) in test_cases { + let model_config = ModelConfig::new(model_id.to_string()); + let system = "You are a helpful assistant."; + let messages = vec![Message::user().with_text("Hello")]; + let tools = vec![]; + + let (request, context) = create_request(&model_config, system, &messages, &tools)?; + + assert!(request.is_object()); + assert_eq!(context.provider(), expected_provider); + } + + Ok(()) + } + + #[test] + fn 
test_default_locations() -> Result<()> { + let test_cases = [ + ("claude-3-5-sonnet@20240620", GcpLocation::Ohio), + ("claude-3-5-sonnet-v2@20241022", GcpLocation::Ohio), + ("gemini-1.5-pro-002", GcpLocation::Iowa), + ("gemini-2.0-flash-001", GcpLocation::Iowa), + ("gemini-2.0-pro-exp-02-05", GcpLocation::Iowa), + ]; + + for (model_id, expected_location) in test_cases { + let model = GcpVertexAIModel::try_from(model_id)?; + assert_eq!( + model.default_location(), + expected_location, + "Model {} should have default location {:?}", + model_id, + expected_location + ); + + let context = RequestContext::new(model_id)?; + assert_eq!( + context.model.default_location(), + expected_location, + "RequestContext for {} should have default location {:?}", + model_id, + expected_location + ); + } + + Ok(()) + } +} \ No newline at end of file diff --git a/crates/goose/src/providers/formats/mod.rs b/crates/goose/src/providers/formats/mod.rs index 780f38488..a429147f9 100644 --- a/crates/goose/src/providers/formats/mod.rs +++ b/crates/goose/src/providers/formats/mod.rs @@ -1,4 +1,5 @@ pub mod anthropic; pub mod bedrock; +pub mod gcpvertexai; pub mod google; pub mod openai; diff --git a/crates/goose/src/providers/gcpauth.rs b/crates/goose/src/providers/gcpauth.rs new file mode 100644 index 000000000..726561ce7 --- /dev/null +++ b/crates/goose/src/providers/gcpauth.rs @@ -0,0 +1,1048 @@ +use async_trait::async_trait; +use jsonwebtoken::{encode, EncodingKey, Header}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use std::{env, fmt, io}; +use tokio::sync::RwLock; + + +/// Represents errors that can occur during GCP authentication. +/// +/// This enum encompasses various error conditions that might arise during +/// the authentication process, including credential loading, token creation, +/// and token exchange operations. 
+#[derive(Debug, thiserror::Error)] +pub enum AuthError { + /// Error when loading credentials from the filesystem or environment + #[error("Failed to load credentials: {0}")] + CredentialsError(String), + + /// Error during JWT token creation + #[error("Token creation failed: {0}")] + TokenCreationError(String), + + /// Error during OAuth token exchange + #[error("Token exchange failed: {0}")] + TokenExchangeError(String), +} + +/// Represents an authentication token with its type and value. +/// +/// This structure holds both the token type (e.g., "Bearer") and its +/// actual value, typically used for authentication with GCP services. +/// The token is obtained either through service account or user credentials. +#[derive(Debug, Clone)] +pub struct AuthToken { + /// The type of the token (e.g., "Bearer") + pub token_type: String, + /// The actual token value + pub token_value: String, +} + +impl fmt::Display for AuthToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} {}", self.token_type, self.token_value) + } +} + +/// Represents the types of Application Default Credentials (ADC) supported. +/// +/// GCP supports multiple credential types for authentication. This enum +/// represents the two main types: authorized user and service account. +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum AdcCredentials { + /// Credentials for an authorized user (typically from gcloud auth) + AuthorizedUser(AuthorizedUserCredentials), + /// Credentials for a service account + ServiceAccount(ServiceAccountCredentials), + /// Credentials for the GCP native default account + DefaultAccount(TokenResponse), +} + +/// Credentials for an authorized user account. +/// +/// These credentials are typically obtained through interactive login +/// with the gcloud CLI tool. 
+#[derive(Debug, Deserialize)] +struct AuthorizedUserCredentials { + /// OAuth 2.0 client ID + client_id: String, + /// OAuth 2.0 client secret + client_secret: String, + /// OAuth 2.0 refresh token + refresh_token: String, + /// URI for token refresh requests + #[serde(default = "default_token_uri")] + token_uri: String, +} + +/// Credentials for a service account. +/// +/// These credentials are typically obtained from a JSON key file +/// downloaded from the Google Cloud Console. +#[derive(Debug, Deserialize)] +struct ServiceAccountCredentials { + /// Service account email address + client_email: String, + /// The private key from JSON credential for signing JWT tokens + private_key: String, + /// URI for token exchange requests + token_uri: String, +} + +/// Returns the default OAuth 2.0 token endpoint. +fn default_token_uri() -> String { + "https://oauth2.googleapis.com/token".to_string() +} + +/// A trait that defines operations for interacting with the filesystem. +/// +/// This trait provides an abstraction over filesystem operations, primarily +/// for reading credential files. It enables testing through mock implementations. +#[async_trait] +pub trait FilesystemOps { + /// Reads the contents of a file into a string. + /// + /// # Arguments + /// * `path` - The path to the file to read + /// + /// # Returns + /// * `Result` - The contents of the file or an error + async fn read_to_string(&self, path: String) -> Result; +} + +/// A trait that defines operations for accessing environment variables. +/// +/// This trait provides an abstraction over environment variable access, +/// enabling testing through mock implementations. +pub trait EnvOps { + /// Retrieves the value of an environment variable.
+ /// + /// # Arguments + /// * `key` - The name of the environment variable + /// + /// # Returns + /// * `Result` - The value of the variable or an error if not found + fn get_var(&self, key: &str) -> Result; +} + +/// A concrete implementation of FilesystemOps using the actual filesystem. +/// +/// This implementation uses tokio's async filesystem operations for +/// reading files in an asynchronous manner. +pub struct RealFilesystemOps; + +/// A concrete implementation of EnvOps using the actual environment. +/// +/// This implementation directly accesses system environment variables +/// through the standard library. +pub struct RealEnvOps; + +#[async_trait] +impl FilesystemOps for RealFilesystemOps { + async fn read_to_string(&self, path: String) -> Result { + tokio::fs::read_to_string(path).await + } +} + +impl EnvOps for RealEnvOps { + fn get_var(&self, key: &str) -> Result { + env::var(key) + } +} + +impl AdcCredentials { + /// Loads credentials from the default locations. + /// https://cloud.google.com/docs/authentication/application-default-credentials#personal + /// + /// Attempts to load credentials in the following order: + /// 1. GOOGLE_APPLICATION_CREDENTIALS environment variable + /// 2. Default gcloud credentials path (~/.config/gcloud/application_default_credentials.json) + /// 3. 
Metadata server if running in GCP + async fn load() -> Result { + Self::load_impl(&RealFilesystemOps, &RealEnvOps, "http://metadata.google.internal").await + } + + async fn load_impl( + fs_ops: &impl FilesystemOps, + env_ops: &impl EnvOps, + metadata_base_url: &str, + ) -> Result { + // Try GOOGLE_APPLICATION_CREDENTIALS first + if let Ok(cred_path) = Self::get_env_credentials_path(env_ops) { + if let Ok(creds) = Self::load_from_file(fs_ops, &cred_path).await { + return Ok(creds); + } + } + + // Try default gcloud credentials path + if let Ok(cred_path) = Self::get_default_credentials_path(env_ops) { + if let Ok(creds) = Self::load_from_file(fs_ops, &cred_path).await { + return Ok(creds); + } + } + + // Try metadata server if running on GCP + if let Ok(creds) = Self::load_from_metadata_server(metadata_base_url).await { + return Ok(creds); + } + + Err(AuthError::CredentialsError("No valid credentials found in any location".to_string())) + } + + async fn load_from_file(fs_ops: &impl FilesystemOps, path: &str) -> Result { + let content = fs_ops.read_to_string(path.to_string()) + .await + .map_err(|e| AuthError::CredentialsError( + format!("Failed to read credentials from {}: {}", path, e) + ))?; + + serde_json::from_str(&content) + .map_err(|e| AuthError::CredentialsError(format!("Invalid credentials format: {}", e))) + } + + fn get_env_credentials_path(env_ops: &impl EnvOps) -> Result { + env_ops.get_var("GOOGLE_APPLICATION_CREDENTIALS") + .map_err(|_| AuthError::CredentialsError("GOOGLE_APPLICATION_CREDENTIALS not set".to_string())) + } + + fn get_default_credentials_path(env_ops: &impl EnvOps) -> Result { + let (env_var, subpath) = if cfg!(windows) { + ("APPDATA", "gcloud\\application_default_credentials.json") + } else { + ("HOME", ".config/gcloud/application_default_credentials.json") + }; + + env_ops.get_var(env_var) + .map(|dir| PathBuf::from(dir).join(subpath).to_string_lossy().into_owned()) + .map_err(|_| AuthError::CredentialsError("Could not determine user 
home directory".to_string())) + } + + async fn load_from_metadata_server(base_url: &str) -> Result { + let client = reqwest::Client::new(); + let metadata_path = "/computeMetadata/v1/instance/service-accounts/default/token"; + + let response = client + .get(format!("{}{}", base_url, metadata_path)) + .header("Metadata-Flavor", "Google") + .send() + .await + .map_err(|e| AuthError::CredentialsError(format!("Metadata server request failed: {}", e)))?; + + if !response.status().is_success() { + return Err(AuthError::CredentialsError( + "Not running on GCP or metadata server unavailable".to_string() + )); + } + + // Get the identity token and credentials from metadata server + let token_response = response + .json::() + .await + .map_err(|e| AuthError::CredentialsError(format!("Invalid metadata response: {}", e)))?; + + // Note: When using metadata server, we have access to the OAuth2 access token + // that can be used to authenticate applications. + Ok(AdcCredentials::DefaultAccount(TokenResponse { + token_type: token_response.token_type, + access_token: token_response.access_token, + expires_in: token_response.expires_in, + })) + } +} + +/// Claims structure for JWT tokens. +/// +/// These claims are included in the JWT token used for service account +/// authentication. +#[derive(Debug, Serialize)] +struct JwtClaims { + /// Token issuer (service account email) + iss: String, + /// Token subject (service account email) + sub: String, + /// Service account scope within role + scope: String, + /// Token audience (OAuth endpoint) + aud: String, + /// Token issued at timestamp + iat: u64, + /// Token expiration timestamp + exp: u64, +} + +/// Holds a cached token and its expiration time. +/// +/// Used internally to implement token caching and automatic refresh. 
+#[derive(Debug, Clone)] +struct CachedToken { + /// The cached authentication token + token: AuthToken, + /// When the token will expire + expires_at: Instant, +} + +/// Response structure for token exchange requests. +#[derive(Debug, Deserialize, Clone)] +struct TokenResponse { + /// The access token string + access_token: String, + /// Token lifetime in seconds + expires_in: u64, + /// Token type (e.g., "Bearer") + #[serde(default)] + token_type: String, +} + +/// Handles authentication with Google Cloud Platform services. +/// +/// This struct manages the complete authentication lifecycle including: +/// - Loading and validating credentials +/// - Creating and refreshing tokens +/// - Caching tokens for efficient reuse +/// - Managing concurrent access through atomic operations +/// +/// It supports both service account and authorized user authentication methods, +/// automatically selecting the appropriate method based on available credentials. +/// +#[derive(Debug)] +pub struct GcpAuth { + /// The loaded credentials (service account or authorized user) + credentials: AdcCredentials, + /// HTTP client for making token exchange requests + client: reqwest::Client, + /// Thread-safe cache for the current token + cached_token: Arc>>, +} + +impl GcpAuth { + /// Creates a new GCP authentication handler. + /// + /// Initializes the authentication handler by: + /// 1. Loading credentials from default locations + /// 2. Setting up an HTTP client for token requests + /// 3. Initializing the token cache + /// + /// The credentials are loaded in the following order: + /// 1. GOOGLE_APPLICATION_CREDENTIALS environment variable + /// 2. Default gcloud credentials path + /// 3.
GCP metadata server (when running on GCP) + /// + /// # Returns + /// * `Result` - A new GcpAuth instance or an error if initialization fails + pub async fn new() -> Result { + Ok(Self { + credentials: AdcCredentials::load().await?, + client: reqwest::Client::new(), + cached_token: Arc::new(RwLock::new(None)), + }) + } + + /// Retrieves a valid authentication token. + /// + /// This method implements an efficient token management strategy: + /// 1. Checks the cache for a valid token + /// 2. Returns the cached token if not expired + /// 3. Obtains a new token if needed or expired + /// 4. Uses double-checked locking for thread safety + /// + /// The returned token includes a type (usually "Bearer") and the actual + /// token value used for authentication with GCP services. + /// + /// # Returns + /// * `Result` - A valid authentication token or an error + pub async fn get_token(&self) -> Result { + // Try read lock first for better concurrency + if let Some(cached) = self.cached_token.read().await.as_ref() { + if cached.expires_at > Instant::now() { + return Ok(cached.token.clone()); + } + } + + // Take write lock only if needed + let mut token_guard = self.cached_token.write().await; + + // Double-check expiration after acquiring write lock + if let Some(cached) = token_guard.as_ref() { + if cached.expires_at > Instant::now() { + return Ok(cached.token.clone()); + } + } + + // Get new token + let token_response = match &self.credentials { + AdcCredentials::ServiceAccount(creds) => self.get_service_account_token(creds).await?, + AdcCredentials::AuthorizedUser(creds) => self.get_authorized_user_token(creds).await?, + AdcCredentials::DefaultAccount(creds ) => self.get_default_access_token(creds).await?, + }; + + let auth_token = AuthToken { + token_type: if token_response.token_type.is_empty() { + "Bearer".to_string() + } else { + token_response.token_type + }, + token_value: token_response.access_token, + }; + + let expires_at = Instant::now() + Duration::from_secs( 
+ token_response.expires_in.saturating_sub(30) // 30 second buffer + ); + + *token_guard = Some(CachedToken { + token: auth_token.clone(), + expires_at, + }); + + Ok(auth_token) + } + + /// Creates a JWT token for service account authentication. + /// + /// # Arguments + /// * `creds` - Service account credentials for signing the token + /// + /// # Returns + /// * `Result` - A signed JWT token + fn create_jwt_token(&self, creds: &ServiceAccountCredentials) -> Result { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| AuthError::TokenCreationError(e.to_string()))? + .as_secs(); + + let claims = JwtClaims { + iss: creds.client_email.clone(), + sub: creds.client_email.clone(), + scope: "https://www.googleapis.com/auth/cloud-platform".to_string(), + aud: creds.token_uri.clone(), + iat: now, + exp: now + 3600, // 1 hours validity + }; + + let encoding_key = EncodingKey::from_rsa_pem(creds.private_key.as_bytes()) + .map_err(|e| AuthError::TokenCreationError(format!("Invalid private key: {}", e)))?; + + encode(&Header::new(jsonwebtoken::Algorithm::RS256), &claims, &encoding_key) + .map_err(|e| AuthError::TokenCreationError(format!("Failed to create JWT: {}", e))) + } + + /// Exchanges a token or assertion for an access token. 
+ /// + /// # Arguments + /// * `token_uri` - The token exchange endpoint + /// * `params` - Parameters for the token exchange request + /// + /// # Returns + /// * `Result` - The token exchange response + async fn exchange_token( + &self, + token_uri: &str, + params: &[(&str, &str)], + ) -> Result { + let response = self.client + .post(token_uri) + .form(params) + .send() + .await + .map_err(|e| AuthError::TokenExchangeError(e.to_string()))?; + + let status = response.status(); + if !status.is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(AuthError::TokenExchangeError(format!( + "Status {}: {}", + status, + error_text + ))); + } + + response + .json::() + .await + .map_err(|e| AuthError::TokenExchangeError(format!("Invalid response: {}", e))) + } + + /// Gets a token using service account credentials. + /// + /// # Arguments + /// * `creds` - Service account credentials + /// + /// # Returns + /// * `Result` - The token response + async fn get_service_account_token( + &self, + creds: &ServiceAccountCredentials, + ) -> Result { + let jwt = self.create_jwt_token(creds)?; + let params = [ + ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), + ("assertion", &jwt), + ("scope", "https://www.googleapis.com/auth/cloud-platform"), + ]; + + self.exchange_token(&creds.token_uri, ¶ms).await + } + + /// Gets a token using authorized user credentials. 
+ /// + /// # Arguments + /// * `creds` - Authorized user credentials + /// + /// # Returns + /// * `Result` - The token response + async fn get_authorized_user_token( + &self, + creds: &AuthorizedUserCredentials, + ) -> Result { + let params = [ + ("client_id", creds.client_id.as_str()), + ("client_secret", creds.client_secret.as_str()), + ("refresh_token", creds.refresh_token.as_str()), + ("grant_type", "refresh_token"), + ("scope", "https://www.googleapis.com/auth/cloud-platform"), + ]; + + self.exchange_token(&creds.token_uri, ¶ms).await + } + + /// Gets a token directly from the GCP metadata endpoint. + /// + /// # Arguments + /// * `creds` - Default Access Token Response + /// + /// # Returns + /// * `Result` - The token response + async fn get_default_access_token( + &self, + creds: &TokenResponse, + ) -> Result { + Ok(creds.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mockall::predicate::eq; + use tokio::time::sleep; + use wiremock::matchers::{header, method, path}; + // Only import what we need + use wiremock::{Mock, MockServer, ResponseTemplate}; + + mockall::mock! { + #[derive(Debug)] + FilesystemOpsMock {} + + #[async_trait] + impl FilesystemOps for FilesystemOpsMock { + async fn read_to_string(&self, path: String) -> Result; + } + } + + mockall::mock! 
{ + #[derive(Debug)] + EnvOpsMock {} + + impl EnvOps for EnvOpsMock { + fn get_var(&self, key: &str) -> Result; + } + } + + struct TestContext { + fs_mock: MockFilesystemOpsMock, + env_mock: MockEnvOpsMock, + mock_server: Option, + } + + impl TestContext { + fn new() -> Self { + Self { + fs_mock: MockFilesystemOpsMock::new(), + env_mock: MockEnvOpsMock::new(), + mock_server: None, + } + } + + async fn with_metadata_server(mut self) -> Self { + self.mock_server = Some(MockServer::start().await); + self + } + } + + // Test fixtures for credentials + fn mock_service_account() -> ServiceAccountCredentials { + ServiceAccountCredentials { + client_email: "test@test.com".to_string(), + // This is a generated test credential + private_key: "-----BEGIN RSA PRIVATE KEY----- +MIIJJwIBAAKCAgEA1AjOgxm0Op/DDVhMK1ifZatszNsKvuFSK12uuJ5oWkOIO+kt +GW/bgN3E+naX9Zsq6yeVG+uJsw9XQbLGKvHAV+H1QIarIGQCsyLUTX06AUdf9Hg7 +bhMK2u6LQm2vnyF+pNu9Xu9zRRS7BIVrtn3ECNIpj+AuTXuZvI2bsfu6W2c54tIa +KuDY68zonesmyfukbMpXiTOPWk6il7Uuj51EcgjDOT1y1fgA6UEIcUb3znq8pqQf +ebnF22rgGH4zFHkJa2j1cCVmJcCyBi74phdupeF80Y6NxNrxcehQzSePrb6PoDwa +VeA7I+9Voi8gCCExztydi1rhMgELvBDbWySLgKPLy3I7apHP6M2FOh8aYUoojX7+ +h7wD+ecMYLUxeZaTtgCKj4igAO14c1c6OVR5UWUlbGFTVxRCZ/+5JsfSzO6DRpql +YcJudtqg1hqAvHEmneSA+/mtFKfRYd86jgHlHFZVIdCdo5CFRBMniYJiJj8/MIKW +TQsmjxLTNTQfsJ92X2sMizJWvlg6d+oP6biYWEhKvkuiKG60PYf/17IMddk16pkM +aYWfVIuDxYzduXDmaX03NV8TfeZIXA9C3SdINePju8U0V3ElK6ipQ6zcb/wSFCcj +v1MmDZ8M7t2F8uhQk+k38BRco9tDlsgZ/yC8n9XZDGi7gUgd0IbRVRPUDt0CAwEA +AQKCAgBRWW+h7OKw+0qifBX9K2s8XqDHl+JviZM1ACRgwKXYu8Aw/C1JbRkSQAOq +9IUovfehcPZMV/nksSYRFr3hDA93qEGoGALf0n8Wq244rKrsgq3V5asneDbZ+FuF +iP+wVfF43rWxDr1y65k1CttgkK/9kmRPxvr8z0cUiGAL0UCWgOw8kc9oVAvlrCAz +Nl0TcXCMLLWY9icxxqmq+uB6SSRRe/sqouDEJvpyg3jxvQCmP4DRjnZlBVlb7Y08 +2G5QlH+Ariw8cpzWLzAeHzdWwfa5veFdpQvPUxD/WtplW6BMUKhaGbUg7X7DMrfw +GZR4igPKEep/5MYxoSUXaoA+X68FYP753HHnQl10r6NsDymAmsAmWMxwUb/Ip6u/ +n19DI8ZXMdgb7aNwDAFdTOYmRVR+UVmJBMKyFKkVDsmqZabYB0yTECHh7Apunro/ 
+oJEK4E8JHjtLt/+7hhytZNS7e2Je1fw8DeRLoa6cMBraJS3CKEKaabgwmc0yY5ME +fRvt9kqn8XnJON4zV+I80d9S77ihcTr8xlFI+9PAutlmYe5ZgTls4fKpcl8WWxsU +kuQzL+u5I7TBvGZ3XL2uZKc2CPYLho8MGHbh4t5qF3zwjLFWZoQSPywBo7cN0kMP +e5NhjEOY81LvPHTuAup8hnJ8JjR2qHTD7/qZ7e1tOrH7IrhyIQKCAQEA7pqIhffw +O95e/ZshBLynFXVgvTEBzvnsBm7q9ItR2ytcGb15yJl+JNtv3Jcg5uMmfmd2tXxr +68MaJ5/V2j2PQGLcPVlIhCW0b9NH8/c2NA15o78QClbh4x0eqz4qCfwmGsktPC6Q +YUVaFKng+ECTWwjFTApKFUZFE/Jrg2N8RdMjYFIvLEMal8Co1AIn62eHPwC8xlW7 +69F+80KvxxEVmkDxEhG1p/BMQ+dimWdrtxyB+20LWK1N7zpg/Cmzo50gyLxvvJ6W +ekXdJpG1LcwVZxqvUK1NMvbxpLFFUY4ZCmotlw9M8i/3W+Hfs4HSqKI3lUOYDYQd +8xRQw6N8BSOHFwKCAQEA435dxFB46FgYN8NfCv8qUgO38maO0pETQjrUh5A4J3pS +UyNIWqAmlkMo9tCDQZMyvhl8fV/uQoeDW9FiCijaffE7POkyRRTt+0mz/xuxjoeT +Dc5IREE6xcLOd/nH6EsWZu3B0HWoLcK+63Dt2psGFUdqMRAuwr9XGfI3uqr8slTQ +uqTpEc+/i80/hyWSu4+dDTwt+sU4+3dYiY719GHOXy5/j54jz0LwjiH4G7Di5teT +yAWRX9SD06dSHy1qgqY7LZ3cxtLmQEGmFtTEPL5h/tPKx/tyX3baEiH6MmyuS1FK +o30TYQMb16taN4wC1ztDjJ/BCOJqVOF5fU1kNYFSKwKCAQB4CgDDPXB7/izV89SR +uINqtUm9BMm/IlcPCYBlFS5SUCcewAdj12zyB//n/5RK9F5qW40KUxVMYDRpWO1S +xYOrRdE9gAyOhxWW6LmbUHTRjTH0Imxkdz9fbkf+qOCnc1aMRUffriFu/mAKY0jO +PFamBuyTi92nhFm+ZkiWqldcHZP/onkfEIdxbzjAqHEC6mvNU4alVX6cbiIrKhKa +2MqAd0mQ6J32ZltIEkG1oaU8UzhFkJ+TtmSuBTXDxwscNjHHK54fS72yuDFBdS6s +Yq8l1vP6Z6WeDUSWsaSJGi8Y4UAcblMsyNruO926Rob/1dSW4JG/wwb6Qu867aW4 +RB5zAoIBABsXyJkBsHSTUUcK2H3Zx7N+x+BxgF7pci64DOmcLmPdOIK4N/y7B/1r +QCysxoT/v9JN/Lp9u0VnGCjONevZ07OeEBz/9MGvbWw46dve83VzBftl7staLWKy +AZ7eO4WZs7BMboGiEYZppA0sJNedEMtl9uqi7763xOrNIv/zLycZ3MXtr+g0Iq7G +oeM5gVEfGGgkG6G67T9dhkjTos0Y/NfvFLgI8GDVqwpyVzcNCOjPEcWHjDmqeIyz +Z59Y7E9k9rVHEK0JHuzWJK6hZkGJtuf/Vy4b7xIZeH0iWMa6lMNZihcQZUdvdFhq +CtOEtC3n2/KacAXb2SgEtlBK8D1DCoMCggEAVypafwslJIId0hyNJmX0QesXSfbT +AqNSNifeQTby0fqyJUJbslxS6AauQnPwUNEZHiFnRGVJ3FgMNnm7hdDaguVdjS6S +tgBJmh9PW84RqJm8BNMguUBzUWId4Nh1xDJtI+Klhx8YA2Sfx7nHkabQLAkolmAW +g/kWgQ+sZowHm8h9KJ84ojqC1LeZKjnvhINPGCXM8JhzPOABsDfl5fNFeK5+xOSG +erYuWN1BB3Dl3Pal75Ryu7vqk/0uumdRWfqOkf4wgUIZvD+mRdngT9QmK9doT8z7 
+iXVBc2YmAuU8hiOFUPxtyQfNzG5fQ0rhJSewdtyWxIadJSLj6fsK+AEsNQ== +-----END RSA PRIVATE KEY-----".to_string(), + token_uri: "https://oauth2.googleapis.com/token".to_string(), + } + } + + fn mock_authorized_user() -> AuthorizedUserCredentials { + AuthorizedUserCredentials { + client_id: "test_client".to_string(), + client_secret: "test_secret".to_string(), + refresh_token: "test_refresh".to_string(), + token_uri: "https://oauth2.googleapis.com/token".to_string(), + } + } + + // Helper function to create a test GcpAuth instance with credentials + async fn create_test_auth_with_creds(creds: AdcCredentials) -> GcpAuth { + GcpAuth { + credentials: creds, + client: reqwest::Client::new(), + cached_token: Arc::new(RwLock::new(None)), + } + } + + #[tokio::test] + async fn test_token_caching() { + let auth = GcpAuth { + credentials: AdcCredentials::ServiceAccount(mock_service_account()), + client: reqwest::Client::new(), + cached_token: Arc::new(RwLock::new(Some(CachedToken { + token: AuthToken { + token_type: "Bearer".to_string(), + token_value: "cached_token".to_string(), + }, + expires_at: Instant::now() + Duration::from_secs(3600), + }))), + }; + + // First call should return cached token + let token1 = auth.get_token().await.unwrap(); + assert_eq!(token1.token_value, "cached_token"); + + // Second call should return same cached token + let token2 = auth.get_token().await.unwrap(); + assert_eq!(token2.token_value, "cached_token"); + } + + #[tokio::test] + async fn test_token_expiration() { + let auth = GcpAuth { + credentials: AdcCredentials::ServiceAccount(mock_service_account()), + client: reqwest::Client::new(), + cached_token: Arc::new(RwLock::new(Some(CachedToken { + token: AuthToken { + token_type: "Bearer".to_string(), + token_value: "expired_token".to_string(), + }, + expires_at: Instant::now() - Duration::from_secs(1), + }))), + }; + + // Should fail as token is expired and real credentials aren't available + let result = auth.get_token().await; + 
assert!(result.is_err()); + } + + #[tokio::test] + async fn test_invalid_credentials() { + let auth = create_test_auth_with_creds( + AdcCredentials::ServiceAccount(ServiceAccountCredentials { + client_email: "".to_string(), + private_key: "invalid".to_string(), + token_uri: "https://invalid.example.com".to_string(), + }) + ).await; + + let result = auth.get_token().await; + assert!(result.is_err()); + match result { + Err(AuthError::TokenCreationError(_)) => (), + _ => panic!("Expected TokenCreationError"), + } + } + + #[tokio::test] + async fn test_concurrent_token_access() { + let auth = Arc::new(GcpAuth { + credentials: AdcCredentials::ServiceAccount(mock_service_account()), + client: reqwest::Client::new(), + cached_token: Arc::new(RwLock::new(Some(CachedToken { + token: AuthToken { + token_type: "Bearer".to_string(), + token_value: "concurrent_token".to_string(), + }, + expires_at: Instant::now() + Duration::from_secs(3600), + }))), + }); + + let mut handles = vec![]; + + // Spawn multiple concurrent token requests + for _ in 0..10 { + let auth_clone = Arc::clone(&auth); + handles.push(tokio::spawn(async move { + auth_clone.get_token().await.unwrap() + })); + } + + // All requests should return the same cached token + for handle in handles { + let token = handle.await.unwrap(); + assert_eq!(token.token_value, "concurrent_token"); + } + } + + #[tokio::test] + async fn test_token_refresh_race_condition() { + let auth = Arc::new(GcpAuth { + credentials: AdcCredentials::ServiceAccount(mock_service_account()), + client: reqwest::Client::new(), + cached_token: Arc::new(RwLock::new(Some(CachedToken { + token: AuthToken { + token_type: "Bearer".to_string(), + token_value: "about_to_expire".to_string(), + }, + expires_at: Instant::now() + Duration::from_millis(100), + }))), + }); + + let mut handles = vec![]; + + for i in 0..5 { + let auth_clone = Arc::clone(&auth); + handles.push(tokio::spawn(async move { + sleep(Duration::from_millis(i * 50)).await; + let result = 
auth_clone.get_token().await; + match result { + Ok(token) => { + // Should be the cached token since we can't actually exchange tokens in tests + assert_eq!(token.token_value, "about_to_expire", + "Expected cached token, got: {}", token.token_value); + }, + Err(e) => { + match e { + AuthError::TokenExchangeError(err) => { + // This is expected - we can't actually exchange tokens in tests + assert!(err.contains("invalid_scope") || err.contains("400"), + "Unexpected error message: {}", err); + }, + other => panic!("Unexpected error type: {:?}", other), + } + } + } + })); + } + + // Wait for all handles + for handle in handles { + handle.await.unwrap(); + } + } + + #[tokio::test] + async fn test_authorized_user_token() { + let auth = GcpAuth { + credentials: AdcCredentials::AuthorizedUser(mock_authorized_user()), + client: reqwest::Client::new(), + cached_token: Arc::new(RwLock::new(None)), + }; + + // This should fail since we can't actually make the token exchange request + let result = auth.get_token().await; + assert!(result.is_err()); + match result { + Err(AuthError::TokenExchangeError(_)) => (), + _ => panic!("Expected TokenExchangeError"), + } + } + + #[tokio::test] + async fn test_service_account_jwt_creation() { + let auth = GcpAuth { + credentials: AdcCredentials::ServiceAccount(mock_service_account()), + client: reqwest::Client::new(), + cached_token: Arc::new(RwLock::new(None)), + }; + + let jwt = auth.create_jwt_token(&mock_service_account()); + assert!(jwt.is_ok(), "JWT creation failed: {:?}", jwt.err()); + let jwt_str = jwt.unwrap(); + assert!(jwt_str.starts_with("ey"), "JWT should start with 'ey'"); + assert_eq!(jwt_str.matches('.').count(), 2, "JWT should have exactly 2 dots"); + } + + #[tokio::test] + async fn test_load_from_env_credentials() { + let mut context = TestContext::new(); + + // Mock environment variable + context.env_mock + .expect_get_var() + .with(eq("GOOGLE_APPLICATION_CREDENTIALS")) + .times(1) + .return_once(|_| 
Ok("/path/to/credentials.json".to_string())); + + // Mock file content - convert &str to String for comparison + let creds_content = r#"{ + "type": "service_account", + "client_email": "test@example.com", + "private_key": "-----BEGIN PRIVATE KEY-----\nMIIE...test...key\n-----END PRIVATE KEY-----\n", + "token_uri": "https://oauth2.googleapis.com/token" + }"#; + + context.fs_mock + .expect_read_to_string() + .with(eq("/path/to/credentials.json".to_string())) // Convert to String + .times(1) + .return_once(move |_| Ok(creds_content.to_string())); + + let result = AdcCredentials::load_impl( + &context.fs_mock, + &context.env_mock, + "http://metadata.example.com" + ).await; + + assert!(result.is_ok()); + if let Ok(AdcCredentials::ServiceAccount(sa)) = result { + assert_eq!(sa.client_email, "test@example.com"); + assert!(sa.private_key.contains("test...key")); + } else { + panic!("Expected ServiceAccount credentials"); + } + } + + #[tokio::test] + async fn test_load_from_default_path() { + let mut context = TestContext::new(); + + // Mock environment variables + context.env_mock + .expect_get_var() + .with(eq("GOOGLE_APPLICATION_CREDENTIALS")) + .times(1) + .return_once(|_| Err(env::VarError::NotPresent)); + + let home_var = if cfg!(windows) { "APPDATA" } else { "HOME" }; + context.env_mock + .expect_get_var() + .with(eq(home_var)) + .times(1) + .return_once(|_| Ok("/home/testuser".to_string())); + + // Mock file content + let creds_content = r#"{ + "type": "authorized_user", + "client_id": "test_client", + "client_secret": "test_secret", + "refresh_token": "test_refresh" + }"#; + + let expected_path = if cfg!(windows) { + "/home/testuser/gcloud/application_default_credentials.json".to_string() + } else { + "/home/testuser/.config/gcloud/application_default_credentials.json".to_string() + }; + + context.fs_mock + .expect_read_to_string() + .with(eq(expected_path.clone())) // Use clone() to avoid borrowing issues + .times(1) + .return_once(move |_| 
Ok(creds_content.to_string())); + + let result = AdcCredentials::load_impl( + &context.fs_mock, + &context.env_mock, + "http://metadata.example.com" + ).await; + + assert!(result.is_ok()); + if let Ok(AdcCredentials::AuthorizedUser(au)) = result { + assert_eq!(au.client_id, "test_client"); + assert_eq!(au.client_secret, "test_secret"); + assert_eq!(au.refresh_token, "test_refresh"); + } else { + panic!("Expected AuthorizedUser credentials"); + } + } + + #[tokio::test] + async fn test_load_from_metadata_server() { + let mut context = TestContext::new(); + + // Mock environment variable lookups to fail + context.env_mock + .expect_get_var() + .with(eq("GOOGLE_APPLICATION_CREDENTIALS")) + .times(1) + .return_once(|_| Err(env::VarError::NotPresent)); + + let home_var = if cfg!(windows) { "APPDATA" } else { "HOME" }; + context.env_mock + .expect_get_var() + .with(eq(home_var)) + .times(1) + .return_once(|_| Err(env::VarError::NotPresent)); + + // Initialize mock server + let context = context.with_metadata_server().await; + let mock_server = context.mock_server.as_ref().expect("Mock server should be initialized"); + + // Define expected token values + let expected_token = "test_token"; + let expected_type = "Bearer"; + let expected_expires = 3600; + + // Configure mock response + Mock::given(method("GET")) + .and(path("/computeMetadata/v1/instance/service-accounts/default/token")) + .and(header("Metadata-Flavor", "Google")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({ + "access_token": expected_token, + "expires_in": expected_expires, + "token_type": expected_type, + })) + ) + .mount(mock_server) + .await; + + // Execute the code under test + let result = AdcCredentials::load_impl( + &context.fs_mock, + &context.env_mock, + &mock_server.uri() + ).await; + + // Assertions + assert!(result.is_ok(), "Expected successful result, got {:?}", result); + + if let Ok(AdcCredentials::DefaultAccount(token_response)) = result { + 
assert_eq!(token_response.access_token, expected_token); + assert_eq!(token_response.token_type, expected_type); + assert_eq!(token_response.expires_in, expected_expires); + } else { + panic!("Expected DefaultAccount credentials, got {:?}", result); + } + } + + #[tokio::test] + async fn test_invalid_credentials_file() { + let mut context = TestContext::new(); + + // Mock GOOGLE_APPLICATION_CREDENTIALS environment variable + context.env_mock + .expect_get_var() + .with(eq("GOOGLE_APPLICATION_CREDENTIALS")) + .times(1) + .return_once(|_| Ok("/path/to/credentials.json".to_string())); + + // Mock filesystem read for the invalid credentials file + context.fs_mock + .expect_read_to_string() + .with(eq("/path/to/credentials.json".to_string())) + .times(1) + .return_once(|_| Ok("invalid json".to_string())); + + // Mock HOME/APPDATA environment variable + let home_var = if cfg!(windows) { "APPDATA" } else { "HOME" }; + context.env_mock + .expect_get_var() + .with(eq(home_var)) + .times(1) + .return_once(|_| Ok("/home/user".to_string())); + + // Mock filesystem read for the default credentials path + let default_creds_path = if cfg!(windows) { + "/home/user/gcloud/application_default_credentials.json" + } else { + "/home/user/.config/gcloud/application_default_credentials.json" + }; + context.fs_mock + .expect_read_to_string() + .with(eq(default_creds_path.to_string())) + .times(1) + .return_once(|_| Err(std::io::Error::new(std::io::ErrorKind::NotFound, "File not found"))); + + let result = AdcCredentials::load_impl( + &context.fs_mock, + &context.env_mock, + "http://metadata.example.com" + ).await; + + assert!(matches!(result, Err(AuthError::CredentialsError(_)))); + } + + #[tokio::test] + async fn test_no_credentials_found() { + let mut context = TestContext::new(); + + // Mock all credential sources to fail + context.env_mock + .expect_get_var() + .with(eq("GOOGLE_APPLICATION_CREDENTIALS")) + .times(1) + .return_once(|_| Err(env::VarError::NotPresent)); + + 
context.env_mock + .expect_get_var() + .with(eq(if cfg!(windows) { "APPDATA" } else { "HOME" })) + .times(1) + .return_once(|_| Err(env::VarError::NotPresent)); + + let result = AdcCredentials::load_impl(&context.fs_mock, &context.env_mock, "http://metadata.example.com").await; + assert!(matches!(result, Err(AuthError::CredentialsError(_)))); + } +} diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs new file mode 100644 index 000000000..fda50646d --- /dev/null +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -0,0 +1,342 @@ +use std::format; +use std::time::Duration; +use std::vec; + +use anyhow::Result; +use async_trait::async_trait; +use reqwest::{Client, StatusCode}; +use serde_json::Value; +use url::Url; + +use crate::message::Message; +use crate::model::ModelConfig; +use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; + +use crate::providers::errors::ProviderError; +use crate::providers::formats::gcpvertexai::{ + create_request, + get_usage, + response_to_message, + ClaudeVersion, + GcpVertexAIModel, + GeminiVersion, + ModelProvider, + RequestContext, +}; + +use crate::providers::gcpauth::GcpAuth; +use crate::providers::utils::emit_debug_trace; +use mcp_core::tool::Tool; + +/// Base URL for GCP Vertex AI documentation +const GCP_VERTEX_AI_DOC_URL: &str = "https://cloud.google.com/vertex-ai"; +/// Fallback default GCP region for model deployment +const GCP_DEFAULT_LOCATION: &str = "us-central1"; +/// Default timeout for API requests in seconds +const DEFAULT_TIMEOUT_SECS: u64 = 600; + +/// Represents errors specific to GCP Vertex AI operations. +/// +/// This enum encompasses various error conditions that might arise when working +/// with the GCP Vertex AI provider, particularly around URL construction and authentication. 
+#[derive(Debug, thiserror::Error)] +enum GcpVertexAIError { + /// Error when URL construction fails + #[error("Invalid URL configuration: {0}")] + InvalidUrl(String), + + /// Error during GCP authentication + #[error("Authentication error: {0}")] + AuthError(String), +} + +/// Provider implementation for Google Cloud Platform's Vertex AI service. +/// +/// This provider enables interaction with various AI models hosted on GCP Vertex AI, +/// including Claude and Gemini model families. It handles authentication, request routing, +/// and response processing for the Vertex AI API endpoints. +#[derive(Debug, serde::Serialize)] +pub struct GcpVertexAIProvider { + /// HTTP client for making API requests + #[serde(skip)] + client: Client, + /// GCP authentication handler + #[serde(skip)] + auth: GcpAuth, + /// Base URL for the Vertex AI API + host: String, + /// GCP project identifier + project_id: String, + /// GCP region for model deployment + location: String, + /// Configuration for the specific model being used + model: ModelConfig, +} + +impl GcpVertexAIProvider { + /// Creates a new provider instance from environment configuration. + /// + /// This is a convenience method that initializes the provider using + /// environment variables and default settings. + /// + /// # Arguments + /// * `model` - Configuration for the model to be used + pub fn from_env(model: ModelConfig) -> Result { + futures::executor::block_on(Self::new(model)) + } + + /// Creates a new provider instance with the specified model configuration. + /// + /// Initializes the provider with custom settings and establishes necessary + /// client connections and authentication. 
+ /// + /// # Arguments + /// * `model` - Configuration for the model to be used + pub async fn new(model: ModelConfig) -> Result { + let config = crate::config::Config::global(); + let project_id = config.get("GCP_PROJECT_ID")?; + let location = Self::determine_location(&config, &model)?; + let host = format!("https://{}-aiplatform.googleapis.com", location); + + let client = Client::builder() + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) + .build()?; + + let auth = GcpAuth::new().await?; + + Ok(Self { + client, + auth, + host, + project_id, + location, + model, + }) + } + + /// Determines the appropriate GCP location for model deployment. + /// + /// Location is determined in the following order: + /// 1. Custom location from GCP_LOCATION environment variable + /// 2. Model's default location + /// 3. Global default location (us-central1) + fn determine_location(config: &crate::config::Config, model: &ModelConfig) -> Result { + Ok(config + .get("GCP_LOCATION") + .ok() + .filter(|loc: &String| !loc.trim().is_empty() && loc != "default") + .unwrap_or_else(|| { + GcpVertexAIModel::try_from(model.model_name.as_str()) + .map(|m| m.default_location().to_string()) + .unwrap_or_else(|_| GCP_DEFAULT_LOCATION.to_string()) + })) + } + + /// Retrieves an authentication token for API requests. + /// + /// # Returns + /// * `Result` - Bearer token for authentication + async fn get_auth_header(&self) -> Result { + self.auth + .get_token() + .await + .map(|token| format!("Bearer {}", token.token_value)) + .map_err(|e| GcpVertexAIError::AuthError(e.to_string())) + } + + /// Constructs the appropriate API endpoint URL for a given provider. 
+ /// + /// # Arguments + /// * `provider` - The model provider (Anthropic or Google) + /// + /// # Returns + /// * `Result` - Fully qualified API endpoint URL + fn build_request_url(&self, provider: ModelProvider) -> Result { + let base_url = Url::parse(&self.host) + .map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?; + + let path = format!( + "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}", + self.project_id, + self.location, + provider.as_str(), + self.model.model_name, + match provider { + ModelProvider::Anthropic => "streamRawPredict", + ModelProvider::Google => "generateContent", + } + ); + + base_url + .join(&path) + .map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string())) + } + + /// Makes an authenticated POST request to the Vertex AI API. + /// + /// # Arguments + /// * `payload` - The request payload to send + /// * `context` - Request context containing model information + /// + /// # Returns + /// * `Result` - JSON response from the API + async fn post(&self, payload: Value, context: RequestContext) -> Result { + let url = self.build_request_url(context.provider()) + .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; + + let auth_header = self.get_auth_header() + .await + .map_err(|e| ProviderError::Authentication(e.to_string()))?; + + let response = self.client + .post(url) + .json(&payload) + .header("Authorization", auth_header) + .send() + .await + .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; + + let status = response.status(); + let response_json = response + .json::() + .await + .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))?; + + match status { + StatusCode::OK => Ok(response_json), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + tracing::debug!("Authentication failed. Status: {status}, Payload: {payload:?}"); + Err(ProviderError::Authentication(format!( + "Authentication failed: {response_json:?}" + ))) + } + _ => { + tracing::debug!("Request failed. 
Status: {status}, Response: {response_json:?}"); + Err(ProviderError::RequestFailed(format!( + "Request failed with status {status}: {response_json:?}" + ))) + } + } + } +} + +impl Default for GcpVertexAIProvider { + fn default() -> Self { + let model = ModelConfig::new(Self::metadata().default_model); + futures::executor::block_on(Self::new(model)) + .expect("Failed to initialize VertexAI provider") + } +} + +#[async_trait] +impl Provider for GcpVertexAIProvider { + /// Returns metadata about the GCP Vertex AI provider. + /// + /// This includes information about supported models, configuration requirements, + /// and documentation links. + fn metadata() -> ProviderMetadata + where + Self: Sized, + { + let known_models = vec![ + GcpVertexAIModel::Claude(ClaudeVersion::Sonnet35), + GcpVertexAIModel::Claude(ClaudeVersion::Sonnet35V2), + GcpVertexAIModel::Gemini(GeminiVersion::Pro15), + GcpVertexAIModel::Gemini(GeminiVersion::Flash20), + GcpVertexAIModel::Gemini(GeminiVersion::Pro20Exp), + ] + .into_iter() + .map(|model| model.to_string()) + .collect(); + + ProviderMetadata::new( + "gcp_vertex_ai", + "GCP Vertex AI", + "Access variety of AI models such as Claude, Gemini through Vertex AI", + GcpVertexAIModel::Claude(ClaudeVersion::Sonnet35V2).to_string().as_str(), + known_models, + GCP_VERTEX_AI_DOC_URL, + vec![ + ConfigKey::new("GCP_PROJECT_ID", true, false, None), + ConfigKey::new("GCP_LOCATION", false, false, None), + ], + ) + } + + /// Completes a model interaction by sending a request and processing the response. 
+ /// + /// # Arguments + /// * `system` - System prompt or context + /// * `messages` - Array of previous messages in the conversation + /// * `tools` - Array of available tools for the model + /// + /// # Returns + /// * `Result<(Message, ProviderUsage)>` - Tuple of response message and usage statistics + #[tracing::instrument( + skip(self, system, messages, tools), + fields(model_config, input, output, input_tokens, output_tokens, total_tokens) + )] + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let (request, context) = create_request(&self.model, system, messages, tools)?; + let response = self.post(request.clone(), context.clone()).await?; + let usage = get_usage(&response, &context)?; + + emit_debug_trace(self, &request, &response, &usage); + + let message = response_to_message(response.clone(), context.clone())?; + let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage); + + Ok((message, provider_usage)) + } + + /// Returns the current model configuration. 
+ fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_provider_conversion() { + assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic"); + assert_eq!(ModelProvider::Google.as_str(), "google"); + } + + #[tokio::test] + async fn test_url_construction() { + let model = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string()); + let provider = GcpVertexAIProvider { + client: Client::new(), + auth: GcpAuth::new().await.expect("Failed to create GcpAuth"), + host: "https://us-east5-aiplatform.googleapis.com".to_string(), + project_id: "test-project".to_string(), + location: "us-east5".to_string(), + model, + }; + + let url = provider + .build_request_url(ModelProvider::Anthropic) + .unwrap() + .to_string(); + + assert!(url.contains("publishers/anthropic")); + assert!(url.contains("projects/test-project")); + assert!(url.contains("locations/us-east5")); + } + + #[test] + fn test_provider_metadata() { + let metadata = GcpVertexAIProvider::metadata(); + assert!(metadata.known_models.contains(&"claude-3-5-sonnet-v2@20241022".to_string())); + assert!(metadata.known_models.contains(&"gemini-1.5-pro-002".to_string())); + assert_eq!(metadata.config_keys.len(), 2); + } +} \ No newline at end of file diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 634224fd7..74b19848f 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -6,6 +6,8 @@ pub mod databricks; pub mod errors; mod factory; pub mod formats; +mod gcpauth; +pub mod gcpvertexai; pub mod google; pub mod groq; pub mod oauth; diff --git a/crates/goose/tests/truncate_agent.rs b/crates/goose/tests/truncate_agent.rs index 4225797f9..cbaf296be 100644 --- a/crates/goose/tests/truncate_agent.rs +++ b/crates/goose/tests/truncate_agent.rs @@ -6,12 +6,18 @@ use goose::agents::AgentFactory; use goose::message::Message; use goose::model::ModelConfig; use 
goose::providers::base::Provider; -use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider}; use goose::providers::{ - azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider, + anthropic::AnthropicProvider, + azure::AzureProvider, + bedrock::BedrockProvider, + databricks::DatabricksProvider, + gcpvertexai::GcpVertexAIProvider, + google::GoogleProvider, + groq::GroqProvider, + ollama::OllamaProvider, + openai::OpenAiProvider, openrouter::OpenRouterProvider, }; -use goose::providers::{google::GoogleProvider, groq::GroqProvider}; #[derive(Debug, PartialEq)] enum ProviderType { @@ -20,6 +26,7 @@ enum ProviderType { Anthropic, Bedrock, Databricks, + GcpVertexAI, Google, Groq, Ollama, @@ -42,6 +49,7 @@ impl ProviderType { ProviderType::Groq => &["GROQ_API_KEY"], ProviderType::Ollama => &[], ProviderType::OpenRouter => &["OPENROUTER_API_KEY"], + ProviderType::GcpVertexAI => &["GCP_PROJECT_ID", "GCP_LOCATION"], } } @@ -70,6 +78,7 @@ impl ProviderType { ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?), ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?), ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?), + ProviderType::GcpVertexAI => Box::new(GcpVertexAIProvider::from_env(model_config)?), ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?), ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?), ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?), @@ -290,4 +299,14 @@ mod tests { }) .await } + + #[tokio::test] + async fn test_truncate_agent_with_gcpvertexai() -> Result<()> { + run_test_with_config(TestConfig { + provider_type: ProviderType::GcpVertexAI, + model: "claude-3-5-sonnet-v2@20241022", + context_window: 200_000, + }) + .await + } } diff --git a/documentation/docs/getting-started/providers.md b/documentation/docs/getting-started/providers.md index 
0384dc59d..666a5b8c9 100644 --- a/documentation/docs/getting-started/providers.md +++ b/documentation/docs/getting-started/providers.md @@ -17,17 +17,18 @@ Goose relies heavily on tool calling capabilities and currently works best with ## Available Providers -| Provider | Description | Parameters | -|-----------------------------------------------|-----------------------------------------------------|---------------------------------------| -|[Amazon Bedrock](https://aws.amazon.com/bedrock/)| Offers a variety of foundation models, including Claude, Jurassic-2, and others. **Environment variables must be set in advance, not configured through `goose configure`** | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION`| -| [Anthropic](https://www.anthropic.com/) | Offers Claude, an advanced AI model for natural language tasks. | `ANTHROPIC_API_KEY` | -|[Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5.| `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME` | -| [Databricks](https://www.databricks.com/) | Unified data analytics and AI platform for building and deploying models. | `DATABRICKS_HOST`, `DATABRICKS_TOKEN` | -| [Gemini](https://ai.google.dev/gemini-api/docs) | Advanced LLMs by Google with multimodal capabilities (text, images). | `GOOGLE_API_KEY` | -| [Groq](https://groq.com/) | High-performance inference hardware and tools for LLMs. | `GROQ_API_KEY` | -| [Ollama](https://ollama.com/) | Local model runner supporting Qwen, Llama, DeepSeek, and other open-source models. **Because this provider runs locally, you must first [download and run a model](/docs/getting-started/providers#local-llms-ollama).** | `OLLAMA_HOST` | -| [OpenAI](https://platform.openai.com/api-keys) | Provides gpt-4o, o1, and other advanced language models. Also supports OpenAI-compatible endpoints (e.g., self-hosted LLaMA, vLLM, KServe). 
**o1-mini and o1-preview are not supported because Goose uses tool calling.** | `OPENAI_API_KEY`, `OPENAI_HOST` (optional), `OPENAI_ORGANIZATION` (optional), `OPENAI_PROJECT` (optional) | -| [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` | +| Provider | Description | Parameters | +|-----------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------| +| [Amazon Bedrock](https://aws.amazon.com/bedrock/) | Offers a variety of foundation models, including Claude, Jurassic-2, and others. **Environment variables must be set in advance, not configured through `goose configure`** | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION` | +| [Anthropic](https://www.anthropic.com/) | Offers Claude, an advanced AI model for natural language tasks. | `ANTHROPIC_API_KEY` | +| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. | `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME` | +| [Databricks](https://www.databricks.com/) | Unified data analytics and AI platform for building and deploying models. | `DATABRICKS_HOST`, `DATABRICKS_TOKEN` | +| [Gemini](https://ai.google.dev/gemini-api/docs) | Advanced LLMs by Google with multimodal capabilities (text, images). | `GOOGLE_API_KEY` | +| [GCP Vertex AI](https://cloud.google.com/vertex-ai) | Google Cloud's Vertex AI platform, supporting Gemini and Claude models. **Credentials must be configured in advance. 
Follow the instructions at https://cloud.google.com/vertex-ai/docs/authentication.** | `GCP_PROJECT_ID`, `GCP_LOCATION` (optional - `default` picks up known supported locations). | +| [Groq](https://groq.com/) | High-performance inference hardware and tools for LLMs. | `GROQ_API_KEY` | +| [Ollama](https://ollama.com/) | Local model runner supporting Qwen, Llama, DeepSeek, and other open-source models. **Because this provider runs locally, you must first [download and run a model](/docs/getting-started/providers#local-llms-ollama).** | `OLLAMA_HOST` | +| [OpenAI](https://platform.openai.com/api-keys) | Provides gpt-4o, o1, and other advanced language models. Also supports OpenAI-compatible endpoints (e.g., self-hosted LLaMA, vLLM, KServe). **o1-mini and o1-preview are not supported because Goose uses tool calling.** | `OPENAI_API_KEY`, `OPENAI_HOST` (optional), `OPENAI_ORGANIZATION` (optional), `OPENAI_PROJECT` (optional) | +| [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. 
| `OPENROUTER_API_KEY` | diff --git a/ui/desktop/src/components/settings/api_keys/utils.tsx b/ui/desktop/src/components/settings/api_keys/utils.tsx index b29014db8..75811459b 100644 --- a/ui/desktop/src/components/settings/api_keys/utils.tsx +++ b/ui/desktop/src/components/settings/api_keys/utils.tsx @@ -10,6 +10,8 @@ export function isSecretKey(keyName: string): boolean { 'OPENAI_HOST', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME', + 'GCP_PROJECT_ID', + 'GCP_LOCATION', ]; return !nonSecretKeys.includes(keyName); } diff --git a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx index b38840153..9479570ae 100644 --- a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx +++ b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx @@ -19,6 +19,11 @@ export const goose_models: Model[] = [ { id: 17, name: 'qwen2.5', provider: 'Ollama' }, { id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' }, { id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' }, + { id: 20, name: 'claude-3-5-sonnet-v2@20241022', provider: 'GCP Vertex AI' }, + { id: 21, name: 'claude-3-5-sonnet@20240620', provider: 'GCP Vertex AI' }, + { id: 22, name: 'gemini-1.5-pro-002', provider: 'GCP Vertex AI' }, + { id: 23, name: 'gemini-2.0-flash-001', provider: 'GCP Vertex AI' }, + { id: 24, name: 'gemini-2.0-pro-exp-02-05', provider: 'GCP Vertex AI' }, ]; export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1']; @@ -47,6 +52,14 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet']; export const azure_openai_models = ['gpt-4o']; +export const gcp_vertex_ai_models = [ + 'claude-3-5-sonnet-v2@20241022', + 'claude-3-5-sonnet@20240620', + 'gemini-1.5-pro-002', + 'gemini-2.0-flash-001', + 'gemini-2.0-pro-exp-02-05', +]; + export const default_models = { openai: 'gpt-4o', anthropic: 'claude-3-5-sonnet-latest', @@ -56,6 +69,7 @@ export const default_models = { 
openrouter: 'anthropic/claude-3.5-sonnet', ollama: 'qwen2.5', azure_openai: 'gpt-4o', + gcp_vertex_ai: 'claude-3-5-sonnet-v2@20241022', }; export function getDefaultModel(key: string): string | undefined { @@ -73,11 +87,13 @@ export const required_keys = { Google: ['GOOGLE_API_KEY'], OpenRouter: ['OPENROUTER_API_KEY'], 'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'], + 'GCP Vertex AI': ['GCP_PROJECT_ID', 'GCP_LOCATION'], }; export const default_key_value = { OPENAI_HOST: 'https://api.openai.com', OLLAMA_HOST: 'localhost', + GCP_LOCATION: 'default', }; export const supported_providers = [ @@ -89,6 +105,7 @@ export const supported_providers = [ 'Ollama', 'OpenRouter', 'Azure OpenAI', + 'GCP Vertex AI', ]; export const model_docs_link = [ @@ -102,6 +119,7 @@ export const model_docs_link = [ }, { name: 'OpenRouter', href: 'https://openrouter.ai/models' }, { name: 'Ollama', href: 'https://ollama.com/library' }, + { name: 'GCP Vertex AI', href: 'https://cloud.google.com/vertex-ai' }, ]; export const provider_aliases = [ @@ -113,4 +131,5 @@ export const provider_aliases = [ { provider: 'OpenRouter', alias: 'openrouter' }, { provider: 'Google', alias: 'google' }, { provider: 'Azure OpenAI', alias: 'azure_openai' }, + { provider: 'GCP Vertex AI', alias: 'gcp_vertex_ai' }, ];