Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: provider settings alpha version #625

Merged
merged 32 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
80ccc17
list provider secret status endpoint
lily-de Jan 16, 2025
6c60d80
add tests
lily-de Jan 16, 2025
f8fc4a8
passing tests
lily-de Jan 16, 2025
3875e49
add back in store route
lily-de Jan 16, 2025
c8e8162
resolve conflicts
lily-de Jan 17, 2025
265bde2
ability to delete keys
lily-de Jan 16, 2025
4620c94
show selected provider
lily-de Jan 16, 2025
074bb20
selected provider displayed
lily-de Jan 16, 2025
896ed09
allow for key editing
lily-de Jan 16, 2025
226604f
broken attempt to change providers
lily-de Jan 16, 2025
0c2ca40
stable keys page
lily-de Jan 16, 2025
36660bb
resolve conflicts
lily-de Jan 17, 2025
4d24342
refresh current window provider state
lily-de Jan 17, 2025
da317cf
update list of supported providers
lily-de Jan 17, 2025
111ef59
resolve conflicts
lily-de Jan 17, 2025
7774f59
fetch all supported providers to show in settings
lily-de Jan 17, 2025
0cafc5c
resolve conflicts
lily-de Jan 17, 2025
a0f36d1
modularize the provider keys code
lily-de Jan 17, 2025
9ff3f9c
test
lily-de Jan 17, 2025
929fee8
remove unused import
lily-de Jan 17, 2025
5b4afb0
hide provider settings
lily-de Jan 17, 2025
585094d
lint
lily-de Jan 17, 2025
e886cb6
lint desktop
lily-de Jan 17, 2025
dd7fd36
lint rust
lily-de Jan 17, 2025
60f6588
lint desktop
lily-de Jan 17, 2025
7ef8530
lint
lily-de Jan 17, 2025
b64852a
lint
lily-de Jan 17, 2025
107ee2e
try to fix tests
lily-de Jan 17, 2025
46ccb4d
fix tests
lily-de Jan 17, 2025
40d6e6b
fix whitespace
lily-de Jan 17, 2025
a1cc92a
change anthropic to what is on v1.0
lily-de Jan 17, 2025
6e61270
remove whitespace
lily-de Jan 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/goose-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ http = "1.0"
config = { version = "0.14.1", features = ["toml"] }
thiserror = "1.0"
clap = { version = "4.4", features = ["derive"] }
once_cell = "1.18"

[[bin]]
name = "goosed"
Expand Down
47 changes: 47 additions & 0 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use axum::{
};
use goose::{agents::AgentFactory, providers::factory};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Serialize)]
struct VersionsResponse {
Expand All @@ -25,6 +26,28 @@ struct CreateAgentResponse {
version: String,
}

#[derive(Deserialize)]
struct ProviderFile {
name: String,
description: String,
models: Vec<String>,
required_keys: Vec<String>,
}

#[derive(Serialize)]
struct ProviderDetails {
name: String,
description: String,
models: Vec<String>,
required_keys: Vec<String>,
}

#[derive(Serialize)]
struct ProviderList {
id: String,
details: ProviderDetails,
}

async fn get_versions() -> Json<VersionsResponse> {
let versions = AgentFactory::available_versions();
let default_version = AgentFactory::default_version().to_string();
Expand Down Expand Up @@ -64,9 +87,33 @@ async fn create_agent(
Ok(Json(CreateAgentResponse { version }))
}

async fn list_providers() -> Json<Vec<ProviderList>> {
let contents = include_str!("providers_and_keys.json");

let providers: HashMap<String, ProviderFile> =
serde_json::from_str(contents).expect("Failed to parse providers_and_keys.json");

let response: Vec<ProviderList> = providers
.into_iter()
.map(|(id, provider)| ProviderList {
id,
details: ProviderDetails {
name: provider.name,
description: provider.description,
models: provider.models,
required_keys: provider.required_keys,
},
})
.collect();

// Return the response as JSON.
Json(response)
}

pub fn routes(state: AppState) -> Router {
Router::new()
.route("/agent/versions", get(get_versions))
.route("/agent/providers", get(list_providers))
.route("/agent", post(create_agent))
.with_state(state)
}
38 changes: 38 additions & 0 deletions crates/goose-server/src/routes/providers_and_keys.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"openai": {
"name": "OpenAI",
"description": "Use GPT-4 and other OpenAI models",
"models": ["gpt-4o", "gpt-4-turbo","o1"],
"required_keys": ["OPENAI_API_KEY"]
},
"anthropic": {
"name": "Anthropic",
"description": "Use Claude and other Anthropic models",
"models": ["claude-3.5-sonnet-2"],
"required_keys": ["ANTHROPIC_API_KEY"]
},
"databricks": {
"name": "Databricks",
"description": "Connect to LLMs via Databricks",
"models": ["claude-3-5-sonnet-2"],
"required_keys": ["DATABRICKS_HOST"]
},
"google": {
"name": "Google",
"description": "Lorem ipsum",
"models": ["gemini-1.5-flash"],
"required_keys": ["GOOGLE_API_KEY"]
},
"grok": {
"name": "Grok",
"description": "Lorem ipsum",
"models": ["llama-3.3-70b-versatile"],
"required_keys": ["GROK_API_KEY"]
},
"ollama": {
"name": "Ollama",
"description": "Lorem ipsum",
"models": ["qwen2.5"],
"required_keys": []
}
}
153 changes: 151 additions & 2 deletions crates/goose-server/src/routes/secrets.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use crate::state::AppState;
use axum::{extract::State, routing::post, Json, Router};
use goose::key_manager::save_to_keyring;
use axum::{extract::State, routing::delete, routing::post, Json, Router};
use goose::key_manager::{
delete_from_keyring, get_keyring_secret, save_to_keyring, KeyRetrievalStrategy,
};
use http::{HeaderMap, StatusCode};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Serialize)]
struct SecretResponse {
Expand Down Expand Up @@ -36,8 +40,153 @@ async fn store_secret(
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderSecretRequest {
pub providers: Vec<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct SecretStatus {
pub is_set: bool,
pub location: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderResponse {
pub supported: bool,
pub name: Option<String>,
pub description: Option<String>,
pub models: Option<Vec<String>>,
pub secret_status: HashMap<String, SecretStatus>,
}

#[derive(Debug, Serialize, Deserialize)]
struct ProviderConfig {
name: String,
description: String,
models: Vec<String>,
required_keys: Vec<String>,
}

static PROVIDER_ENV_REQUIREMENTS: Lazy<HashMap<String, ProviderConfig>> = Lazy::new(|| {
let contents = include_str!("providers_and_keys.json");
serde_json::from_str(contents).expect("Failed to parse providers_and_keys.json")
});

fn check_key_status(key: &str) -> (bool, Option<String>) {
if let Ok(_value) = std::env::var(key) {
(true, Some("env".to_string()))
} else if let Ok(_) = get_keyring_secret(key, KeyRetrievalStrategy::KeyringOnly) {
(true, Some("keyring".to_string()))
} else {
(false, None)
}
}

async fn check_provider_secrets(
Json(request): Json<ProviderSecretRequest>,
) -> Result<Json<HashMap<String, ProviderResponse>>, StatusCode> {
let mut response = HashMap::new();

for provider_name in request.providers {
if let Some(provider_config) = PROVIDER_ENV_REQUIREMENTS.get(&provider_name) {
let mut secret_status = HashMap::new();

for key in &provider_config.required_keys {
let (key_set, key_location) = check_key_status(key);
secret_status.insert(
key.to_string(),
SecretStatus {
is_set: key_set,
location: key_location,
},
);
}

response.insert(
provider_name,
ProviderResponse {
supported: true,
name: Some(provider_config.name.clone()),
description: Some(provider_config.description.clone()),
models: Some(provider_config.models.clone()),
secret_status,
},
);
} else {
response.insert(
provider_name,
ProviderResponse {
supported: false,
name: None,
description: None,
models: None,
secret_status: HashMap::new(),
},
);
}
}

Ok(Json(response))
}

#[derive(Deserialize)]
struct DeleteSecretRequest {
key: String,
}

async fn delete_secret(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<DeleteSecretRequest>,
) -> Result<StatusCode, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;

if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}

// Attempt to delete the key
match delete_from_keyring(&request.key) {
Ok(_) => Ok(StatusCode::NO_CONTENT),
Err(_) => Err(StatusCode::NOT_FOUND),
}
}

pub fn routes(state: AppState) -> Router {
Router::new()
.route("/secrets/providers", post(check_provider_secrets))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will need to check for application secret as will delete

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure i know what you're saying

.route("/secrets/store", post(store_secret))
.route("/secrets/delete", delete(delete_secret))
.with_state(state)
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn test_unsupported_provider() {
// Setup
let request = ProviderSecretRequest {
providers: vec!["unsupported_provider".to_string()],
};

// Execute
let result = check_provider_secrets(Json(request)).await;

// Assert
assert!(result.is_ok());
let Json(response) = result.unwrap();

let provider_status = response
.get("unsupported_provider")
.expect("Provider should exist");
assert!(!provider_status.supported);
assert!(provider_status.secret_status.is_empty());
}
}
26 changes: 26 additions & 0 deletions crates/goose/src/key_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ pub fn save_to_keyring(key_name: &str, key_val: &str) -> std::result::Result<(),
kr.set_password(key_val).map_err(KeyManagerError::from)
}

pub fn delete_from_keyring(key_name: &str) -> std::result::Result<(), KeyManagerError> {
let kr = Entry::new("goose", key_name)?;
kr.delete_credential().map_err(KeyManagerError::from)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -100,6 +105,27 @@ mod tests {
kr.delete_credential().map_err(KeyManagerError::from)
}

#[test]
fn test_delete_from_keyring() {
let key_name = format!("{}{}", TEST_ENV_PREFIX, "DELETE_KEY");

// Save a value to the keyring
save_to_keyring(&key_name, "test_value").unwrap();

// Verify it exists
let kr = Entry::new("goose", &key_name).unwrap();
assert_eq!(kr.get_password().unwrap(), "test_value");

// Delete the keyring entry
let result = delete_from_keyring(&key_name);
assert!(result.is_ok());

// Verify deletion
let kr = Entry::new("goose", &key_name).unwrap();
let password_result = kr.get_password();
assert!(password_result.is_err());
}

#[test]
fn test_get_key_environment_only() {
let key_name = format!("{}{}", TEST_ENV_PREFIX, "ENV_KEY");
Expand Down
Loading
Loading