From 79fa2af60b21cc0dfa4cfef2f69a3bf8a15ad1c5 Mon Sep 17 00:00:00 2001 From: Max Novich Date: Thu, 23 Jan 2025 16:00:13 -0800 Subject: [PATCH] maintain extension name consistent (#721) Co-authored-by: Bradley Axen --- crates/goose-cli/src/commands/configure.rs | 70 +++++++++------------ crates/goose-cli/src/commands/session.rs | 19 ++++-- crates/goose-server/src/routes/extension.rs | 13 +++- crates/goose/examples/agent.rs | 2 +- crates/goose/src/agents/capabilities.rs | 26 ++++---- crates/goose/src/agents/extension.rs | 37 ++++++++--- crates/goose/src/config/extensions.rs | 26 +++++--- 7 files changed, 118 insertions(+), 75 deletions(-) diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index a9c96bd4a..5a0ce341b 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -31,15 +31,12 @@ pub async fn handle_configure() -> Result<(), Box> { style("goose configure").cyan() ); // Since we are setting up for the first time, we'll also enable the developer system - ExtensionManager::set( - "developer", - ExtensionEntry { - enabled: true, - config: ExtensionConfig::Builtin { - name: "developer".to_string(), - }, + ExtensionManager::set(ExtensionEntry { + enabled: true, + config: ExtensionConfig::Builtin { + name: "developer".to_string(), }, - )?; + })?; } else { let _ = config.clear(); println!( @@ -267,7 +264,7 @@ pub fn toggle_extensions_dialog() -> Result<(), Box> { // Create a list of extension names and their enabled status let extension_status: Vec<(String, bool)> = extensions .iter() - .map(|(name, entry)| (name.clone(), entry.enabled)) + .map(|entry| (entry.config.name().to_string(), entry.enabled)) .collect(); // Get currently enabled extensions for the selection @@ -347,26 +344,23 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { .interact()? .to_string(); - ExtensionManager::set( - &extension, - ExtensionEntry { - enabled: true, - config: ExtensionConfig::Builtin { - name: extension.clone(), - }, + ExtensionManager::set(ExtensionEntry { + enabled: true, + config: ExtensionConfig::Builtin { + name: extension.clone(), }, - )?; + })?; cliclack::outro(format!("Enabled {} extension", style(extension).green()))?; } "stdio" => { - let extensions = ExtensionManager::get_all()?; + let extensions = ExtensionManager::get_all_names()?; let name: String = cliclack::input("What would you like to call this extension?") .placeholder("my-extension") .validate(move |input: &String| { if input.is_empty() { Err("Please enter a name") - } else if extensions.contains_key(input) { + } else if extensions.contains(input) { Err("An extension with this name already exists") } else { Ok(()) @@ -412,28 +406,26 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } } - ExtensionManager::set( - &name, - ExtensionEntry { - enabled: true, - config: ExtensionConfig::Stdio { - cmd, - args, - envs: Envs::new(envs), - }, + ExtensionManager::set(ExtensionEntry { + enabled: true, + config: ExtensionConfig::Stdio { + name: name.clone(), + cmd, + args, + envs: Envs::new(envs), }, - )?; + })?; cliclack::outro(format!("Added {} extension", style(name).green()))?; } "sse" => { - let extensions = ExtensionManager::get_all()?; + let extensions = ExtensionManager::get_all_names()?; let name: String = cliclack::input("What would you like to call this extension?") .placeholder("my-remote-extension") .validate(move |input: &String| { if input.is_empty() { Err("Please enter a name") - } else if extensions.contains_key(input) { + } else if extensions.contains(input) { Err("An extension with this name already exists") } else { Ok(()) @@ -476,16 +468,14 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } } - ExtensionManager::set( - &name, - ExtensionEntry { - enabled: true, - config: ExtensionConfig::Sse { - uri, - envs: Envs::new(envs), - }, + ExtensionManager::set(ExtensionEntry { + enabled: true, + config: ExtensionConfig::Sse { + name: name.clone(), + uri, + envs: Envs::new(envs), }, - )?; + })?; cliclack::outro(format!("Added {} extension", style(name).green()))?; } diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 2148ca2ad..bb40be017 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -41,10 +41,11 @@ pub async fn build_session( .expect("Failed to create agent"); // Setup extensions for the agent - for (name, extension) in ExtensionManager::get_all().expect("should load extensions") { + for extension in ExtensionManager::get_all().expect("should load extensions") { if extension.enabled { + let config = extension.config.clone(); agent - .add_extension(extension.config.clone()) + .add_extension(config.clone()) .await .unwrap_or_else(|e| { let err = match e { @@ -53,8 +54,11 @@ pub async fn build_session( } _ => e.to_string(), }; - println!("Failed to start extension: {}, {:?}", name, err); - println!("Please check extension configuration for {}.", name); + println!("Failed to start extension: {}, {:?}", config.name(), err); + println!( + "Please check extension configuration for {}.", + config.name() + ); process::exit(1); }); } @@ -81,7 +85,14 @@ pub async fn build_session( } let cmd = parts.remove(0).to_string(); + //this is an ephemeral extension so name does not matter + let name = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(8) + .map(char::from) + .collect(); let config = ExtensionConfig::Stdio { + name, cmd, args: parts.iter().map(|s| s.to_string()).collect(), envs: Envs::new(envs), diff --git a/crates/goose-server/src/routes/extension.rs b/crates/goose-server/src/routes/extension.rs index d0f04bb10..1ab492343 100644 --- a/crates/goose-server/src/routes/extension.rs +++ b/crates/goose-server/src/routes/extension.rs @@ -16,6 +16,8 @@ enum ExtensionConfigRequest { /// Server-Sent Events (SSE) extension. #[serde(rename = "sse")] Sse { + /// The name to identify this extension + name: String, /// The URI endpoint for the SSE extension. uri: String, /// List of environment variable keys. The server will fetch their values from the keyring. @@ -24,6 +26,8 @@ enum ExtensionConfigRequest { /// Standard I/O (stdio) extension. #[serde(rename = "stdio")] Stdio { + /// The name to identify this extension + name: String, /// The command to execute. cmd: String, /// Arguments for the command. @@ -73,7 +77,11 @@ async fn add_extension( // Construct ExtensionConfig with Envs populated from keyring based on provided env_keys. let extension_config: ExtensionConfig = match request { - ExtensionConfigRequest::Sse { uri, env_keys } => { + ExtensionConfigRequest::Sse { + name, + uri, + env_keys, + } => { let mut env_map = HashMap::new(); for key in env_keys { match config.get_secret(&key) { @@ -97,11 +105,13 @@ async fn add_extension( } ExtensionConfig::Sse { + name, uri, envs: Envs::new(env_map), } } ExtensionConfigRequest::Stdio { + name, cmd, args, env_keys, @@ -129,6 +139,7 @@ async fn add_extension( } ExtensionConfig::Stdio { + name, cmd, args, envs: Envs::new(env_map), diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index a6c9cc27c..6ecd44775 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -14,7 +14,7 @@ async fn main() { // Setup an agent with the developer extension let mut agent = AgentFactory::create("reference", provider).expect("default should exist"); - let config = ExtensionConfig::stdio("./target/debug/developer"); + let config = ExtensionConfig::stdio("developer", "./target/debug/developer"); agent.add_extension(config).await.unwrap(); println!("Extensions:"); diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 4daa8d986..b73d37508 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -98,24 +98,22 @@ impl Capabilities { /// Add a new MCP extension based on the provided client type // TODO IMPORTANT need to ensure this times out if the extension command is broken! pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> { - let mut client: Box = match config { - ExtensionConfig::Sse { ref uri, ref envs } => { + let mut client: Box = match &config { + ExtensionConfig::Sse { uri, envs, .. } => { let transport = SseTransport::new(uri, envs.get_env()); let handle = transport.start().await?; let service = McpService::with_timeout(handle, Duration::from_secs(300)); Box::new(McpClient::new(service)) } ExtensionConfig::Stdio { - ref cmd, - ref args, - ref envs, + cmd, args, envs, .. } => { let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env()); let handle = transport.start().await?; let service = McpService::with_timeout(handle, Duration::from_secs(300)); Box::new(McpClient::new(service)) } - ExtensionConfig::Builtin { ref name } => { + ExtensionConfig::Builtin { name } => { // For builtin extensions, we run the current executable with mcp and extension name let cmd = std::env::current_exe() .expect("should find the current executable") @@ -148,18 +146,18 @@ impl Capabilities { // Store instructions if provided if let Some(instructions) = init_result.instructions { self.instructions - .insert(init_result.server_info.name.clone(), instructions); + .insert(config.name().to_string(), instructions); } // if the server is capable if resources we track it if init_result.capabilities.resources.is_some() { self.resource_capable_extensions - .insert(sanitize(init_result.server_info.name.clone())); + .insert(sanitize(config.name().to_string())); } - // Store the client + // Store the client using the provided name self.clients.insert( - sanitize(init_result.server_info.name.clone()), + sanitize(config.name().to_string()), Arc::new(Mutex::new(client)), ); @@ -180,15 +178,13 @@ impl Capabilities { /// Get aggregated usage statistics pub async fn remove_extension(&mut self, name: &str) -> ExtensionResult<()> { self.clients.remove(name); + self.instructions.remove(name); + self.resource_capable_extensions.remove(name); Ok(()) } pub async fn list_extensions(&self) -> ExtensionResult> { - let mut extensions = Vec::new(); - for name in self.clients.keys() { - extensions.push(name.clone()); - } - Ok(extensions) + Ok(self.clients.keys().cloned().collect()) } pub async fn get_usage(&self) -> Vec { diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index cbcc48910..db5d5f3bc 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -47,6 +47,8 @@ pub enum ExtensionConfig { /// Server-sent events client with a URI endpoint #[serde(rename = "sse")] Sse { + /// The name used to identify this extension + name: String, uri: String, #[serde(default)] envs: Envs, @@ -54,6 +56,8 @@ pub enum ExtensionConfig { /// Standard I/O client with command and arguments #[serde(rename = "stdio")] Stdio { + /// The name used to identify this extension + name: String, cmd: String, args: Vec, #[serde(default)] @@ -61,7 +65,10 @@ pub enum ExtensionConfig { }, /// Built-in extension that is part of the goose binary #[serde(rename = "builtin")] - Builtin { name: String }, + Builtin { + /// The name used to identify this extension + name: String, + }, } impl Default for ExtensionConfig { @@ -73,15 +80,17 @@ impl Default for ExtensionConfig { } impl ExtensionConfig { - pub fn sse>(uri: S) -> Self { + pub fn sse>(name: S, uri: S) -> Self { Self::Sse { + name: name.into(), uri: uri.into(), envs: Envs::default(), } } - pub fn stdio>(cmd: S) -> Self { + pub fn stdio>(name: S, cmd: S) -> Self { Self::Stdio { + name: name.into(), cmd: cmd.into(), args: vec![], envs: Envs::default(), @@ -94,7 +103,10 @@ impl ExtensionConfig { S: Into, { match self { - Self::Stdio { cmd, envs, .. } => Self::Stdio { + Self::Stdio { + name, cmd, envs, .. + } => Self::Stdio { + name, cmd, envs, args: args.into_iter().map(Into::into).collect(), @@ -102,14 +114,25 @@ impl ExtensionConfig { other => other, } } + + /// Get the extension name regardless of variant + pub fn name(&self) -> &str { + match self { + Self::Sse { name, .. } => name, + Self::Stdio { name, .. } => name, + Self::Builtin { name } => name, + } + } } impl std::fmt::Display for ExtensionConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ExtensionConfig::Sse { uri, .. } => write!(f, "SSE({})", uri), - ExtensionConfig::Stdio { cmd, args, .. } => { - write!(f, "Stdio({} {})", cmd, args.join(" ")) + ExtensionConfig::Sse { name, uri, .. } => write!(f, "SSE({}: {})", name, uri), + ExtensionConfig::Stdio { + name, cmd, args, .. + } => { + write!(f, "Stdio({}: {} {})", name, cmd, args.join(" ")) } ExtensionConfig::Builtin { name } => write!(f, "Builtin({})", name), } diff --git a/crates/goose/src/config/extensions.rs b/crates/goose/src/config/extensions.rs index 50e018177..e5d566d3e 100644 --- a/crates/goose/src/config/extensions.rs +++ b/crates/goose/src/config/extensions.rs @@ -1,10 +1,9 @@ +use super::base::Config; +use crate::agents::ExtensionConfig; use anyhow::Result; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use super::base::Config; -use crate::agents::ExtensionConfig; - const DEFAULT_EXTENSION: &str = "developer"; #[derive(Debug, Deserialize, Serialize, Clone)] @@ -52,13 +51,13 @@ impl ExtensionManager { } /// Set or update an extension configuration - pub fn set(name: &str, entry: ExtensionEntry) -> Result<()> { + pub fn set(entry: ExtensionEntry) -> Result<()> { let config = Config::global(); let mut extensions: HashMap = config.get("extensions").unwrap_or_else(|_| HashMap::new()); - extensions.insert(name.to_string(), entry); + extensions.insert(entry.config.name().parse()?, entry); config.set("extensions", serde_json::to_value(extensions)?)?; Ok(()) } @@ -90,9 +89,19 @@ impl ExtensionManager { } /// Get all extensions and their configurations - pub fn get_all() -> Result> { + pub fn get_all() -> Result> { let config = Config::global(); - Ok(config.get("extensions").unwrap_or_else(|_| HashMap::new())) + let extensions: HashMap = + config.get("extensions").unwrap_or(HashMap::new()); + Ok(Vec::from_iter(extensions.values().cloned())) + } + + /// Get all extension names + pub fn get_all_names() -> Result> { + let config = Config::global(); + Ok(config + .get("extensions") + .unwrap_or_else(|_| get_keys(Default::default()))) } /// Check if an extension is enabled @@ -104,3 +113,6 @@ impl ExtensionManager { Ok(extensions.get(name).map(|e| e.enabled).unwrap_or(false)) } } +fn get_keys(entries: HashMap) -> Vec { + entries.into_keys().collect() +}