From 34a4bac93131a170b898f0b61d289ad58f574b54 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 19 Feb 2025 15:04:26 -0800 Subject: [PATCH 01/16] feat: add prompts support to mcp-client, ahere to MCP spec for prompts - add new endpoints `list_prompts` and `get_prompt` in the MCP client - update prompt model in mcp-core to make `description` and `arguments` optional, following MCP spec --- crates/goose-mcp/src/developer/mod.rs | 14 ++------ crates/mcp-client/src/client.rs | 48 +++++++++++++++++++++++++-- crates/mcp-core/src/prompt.rs | 28 ++++++++++------ crates/mcp-server/src/router.rs | 27 ++++++++------- 4 files changed, 81 insertions(+), 36 deletions(-) diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index ee326ddd4..5cbdef982 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -70,9 +70,9 @@ pub fn load_prompt_files() -> HashMap { description: arg.description, required: arg.required, }) - .collect(); + .collect::>(); - let prompt = Prompt::new(&template.id, &template.template, arguments); + let prompt = Prompt::new(&template.id, Some(&template.template), Some(arguments)); if prompts.contains_key(&prompt.name) { eprintln!("Duplicate prompt name '{}' found. Skipping.", prompt.name); @@ -854,15 +854,7 @@ impl Router for DeveloperRouter { Some(Box::pin(async move { match prompts.get(&prompt_name) { - Some(prompt) => { - if prompt.description.trim().is_empty() { - Err(PromptError::InternalError(format!( - "Prompt '{prompt_name}' has an empty description" - ))) - } else { - Ok(prompt.description.clone()) - } - } + Some(prompt) => Ok(prompt.description.clone().unwrap_or_default()), None => Err(PromptError::NotFound(format!( "Prompt '{prompt_name}' not found" ))), diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 0a00e8c77..0d722e558 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,7 +1,7 @@ use mcp_core::protocol::{ - CallToolResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage, - JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, - ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, + CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError, + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, + ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -93,6 +93,10 @@ pub trait McpClientTrait: Send + Sync { async fn list_tools(&self, next_cursor: Option) -> Result; async fn call_tool(&self, name: &str, arguments: Value) -> Result; + + async fn list_prompts(&self, next_cursor: Option) -> Result; + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result; } /// The MCP client is the interface for MCP operations. @@ -346,4 +350,42 @@ where // https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2 self.send_request("tools/call", params).await } + + async fn list_prompts(&self, next_cursor: Option) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + + // If prompts is not supported, return an error + if self.server_capabilities.as_ref().unwrap().prompts.is_none() { + return Err(Error::RpcError { + code: METHOD_NOT_FOUND, + message: "Server does not support 'prompts' capability".to_string(), + }); + } + + let payload = next_cursor + .map(|cursor| serde_json::json!({"cursor": cursor})) + .unwrap_or_else(|| serde_json::json!({})); + + self.send_request("prompts/list", payload).await + } + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + + // If prompts is not supported, return an error + if self.server_capabilities.as_ref().unwrap().prompts.is_none() { + return Err(Error::RpcError { + code: METHOD_NOT_FOUND, + message: "Server does not support 'prompts' capability".to_string(), + }); + } + + let params = serde_json::json!({ "name": name, "arguments": arguments }); + + self.send_request("prompts/get", params).await + } } diff --git a/crates/mcp-core/src/prompt.rs b/crates/mcp-core/src/prompt.rs index 7b814fd44..4a0106e34 100644 --- a/crates/mcp-core/src/prompt.rs +++ b/crates/mcp-core/src/prompt.rs @@ -10,22 +10,28 @@ use serde::{Deserialize, Serialize}; pub struct Prompt { /// The name of the prompt pub name: String, - /// A description of what the prompt does - pub description: String, - /// The arguments that can be passed to customize the prompt - pub arguments: Vec, + /// Optional description of what the prompt does + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Optional arguments that can be passed to customize the prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, } impl Prompt { /// Create a new prompt with the given name, description and arguments - pub fn new(name: N, description: D, arguments: Vec) -> Self + pub fn new( + name: N, + description: Option, + arguments: Option>, + ) -> Self where N: Into, D: Into, { Prompt { name: name.into(), - description: description.into(), + description: description.map(Into::into), arguments, } } @@ -37,9 +43,11 @@ pub struct PromptArgument { /// The name of the argument pub name: String, /// A description of what the argument is used for - pub description: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, /// Whether this argument is required - pub required: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option, } /// Represents the role of a message sender in a prompt conversation @@ -151,6 +159,6 @@ pub struct PromptTemplate { #[derive(Debug, Serialize, Deserialize)] pub struct PromptArgumentTemplate { pub name: String, - pub description: String, - pub required: bool, + pub description: Option, + pub required: Option, } diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs index d2918311c..0060ffd92 100644 --- a/crates/mcp-server/src/router.rs +++ b/crates/mcp-server/src/router.rs @@ -305,18 +305,21 @@ pub trait Router: Send + Sync + 'static { }; // Validate required arguments - for arg in &prompt.arguments { - if arg.required - && (!arguments.contains_key(&arg.name) - || arguments - .get(&arg.name) - .and_then(Value::as_str) - .is_none_or(str::is_empty)) - { - return Err(RouterError::InvalidParams(format!( - "Missing required argument: '{}'", - arg.name - ))); + if let Some(args) = &prompt.arguments { + for arg in args { + if arg.required.is_some() + && arg.required.unwrap() + && (!arguments.contains_key(&arg.name) + || arguments + .get(&arg.name) + .and_then(Value::as_str) + .is_none_or(str::is_empty)) + { + return Err(RouterError::InvalidParams(format!( + "Missing required argument: '{}'", + arg.name + ))); + } } } From 5c4258b4e876a6a427e823a6401c899bc0285abe Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Thu, 20 Feb 2025 10:37:16 -0800 Subject: [PATCH 02/16] feat: handle JsonRpcMessage::Error messages to propagate to the user --- crates/mcp-client/src/transport/sse.rs | 20 +++++++++++++++----- crates/mcp-client/src/transport/stdio.rs | 14 +++++++++++--- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index ed08e4800..90dc5f2f2 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -111,13 +111,23 @@ impl SseActor { // Attempt to parse the SSE data as a JsonRpcMessage match serde_json::from_str::(&e.data) { Ok(message) => { - // If it's a response, complete the pending request - if let JsonRpcMessage::Response(resp) = &message { - if let Some(id) = &resp.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; + match &message { + JsonRpcMessage::Response(response) => { + if let Some(id) = &response.id { + pending_requests + .respond(&id.to_string(), Ok(message)) + .await; + } } + JsonRpcMessage::Error(error) => { + if let Some(id) = &error.id { + pending_requests + .respond(&id.to_string(), Ok(message)) + .await; + } + } + _ => {} // TODO: Handle other variants (Request, etc.) } - // If it's something else (notification, etc.), handle as needed } Err(err) => { warn!("Failed to parse SSE message: {err}"); diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 59d900540..7980816bf 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -87,10 +87,18 @@ impl StdioActor { "Received incoming message" ); - if let JsonRpcMessage::Response(response) = &message { - if let Some(id) = &response.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; + match &message { + JsonRpcMessage::Response(response) => { + if let Some(id) = &response.id { + pending_requests.respond(&id.to_string(), Ok(message)).await; + } } + JsonRpcMessage::Error(error) => { + if let Some(id) = &error.id { + pending_requests.respond(&id.to_string(), Ok(message)).await; + } + } + _ => {} // TODO: Handle other variants (Request, etc.) } } line.clear(); From 2523e61ce6fafa74f35ad8d664e7b66fa6049b9c Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Thu, 20 Feb 2025 16:07:48 -0800 Subject: [PATCH 03/16] test: update MockClient in test with list_prompts and get_prompt --- crates/goose/src/agents/capabilities.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 783f15def..fb487979a 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -556,7 +556,8 @@ mod tests { use mcp_client::client::Error; use mcp_client::client::McpClientTrait; use mcp_core::protocol::{ - CallToolResult, InitializeResult, ListResourcesResult, ListToolsResult, ReadResourceResult, + CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult, + ListToolsResult, ReadResourceResult, }; use serde_json::json; @@ -625,6 +626,20 @@ mod tests { _ => Err(Error::NotInitialized), } } + async fn list_prompts( + &self, + _next_cursor: Option, + ) -> Result { + Err(Error::NotInitialized) + } + + async fn get_prompt( + &self, + _name: &str, + _arguments: Value, + ) -> Result { + Err(Error::NotInitialized) + } } #[test] From 529b7fd50578ebc315f202f688930f7cb2e18a67 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 21 Feb 2025 09:14:32 -0800 Subject: [PATCH 04/16] feat: remove concrete impl of get_prompt and list_prompts, and require implementing types to define them, similar to other methods --- crates/mcp-server/src/router.rs | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs index 0060ffd92..2c277d1c4 100644 --- a/crates/mcp-server/src/router.rs +++ b/crates/mcp-server/src/router.rs @@ -97,12 +97,8 @@ pub trait Router: Send + Sync + 'static { &self, uri: &str, ) -> Pin> + Send + 'static>>; - fn list_prompts(&self) -> Option> { - None - } - fn get_prompt(&self, _prompt_name: &str) -> Option { - None - } + fn list_prompts(&self) -> Vec; + fn get_prompt(&self, prompt_name: &str) -> PromptFuture; // Helper method to create base response fn create_response(&self, id: Option) -> JsonRpcResponse { @@ -257,7 +253,7 @@ pub trait Router: Send + Sync + 'static { req: JsonRpcRequest, ) -> impl Future> + Send { async move { - let prompts = self.list_prompts().unwrap_or_default(); + let prompts = self.list_prompts(); let result = ListPromptsResult { prompts }; @@ -294,15 +290,13 @@ pub trait Router: Send + Sync + 'static { .ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?; // Fetch the prompt definition first - let prompt = match self.list_prompts() { - Some(prompts) => prompts - .into_iter() - .find(|p| p.name == prompt_name) - .ok_or_else(|| { - RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name)) - })?, - None => return Err(RouterError::PromptNotFound("No prompts available".into())), - }; + let prompt = self + .list_prompts() + .into_iter() + .find(|p| p.name == prompt_name) + .ok_or_else(|| { + RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name)) + })?; // Validate required arguments if let Some(args) = &prompt.arguments { @@ -326,7 +320,6 @@ pub trait Router: Send + Sync + 'static { // Now get the prompt content let description = self .get_prompt(prompt_name) - .ok_or_else(|| RouterError::PromptNotFound("Prompt not found".into()))? .await .map_err(|e| RouterError::Internal(e.to_string()))?; From a0110ac819c275a5ec5f05e75fff7d39a3446a21 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 21 Feb 2025 09:16:34 -0800 Subject: [PATCH 05/16] test: add impl of list/get prompt to main.rs and stdio_integration to test both new methods --- .../mcp-client/examples/stdio_integration.rs | 11 ++++++ crates/mcp-server/src/main.rs | 35 ++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs index 9acd2086d..ffdcc10c3 100644 --- a/crates/mcp-client/examples/stdio_integration.rs +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -82,5 +82,16 @@ async fn main() -> Result<(), ClientError> { let resource = client.read_resource("memo://insights").await?; println!("Resource: {resource:?}\n"); + let prompts = client.list_prompts(None).await?; + println!("Prompts: {prompts:?}\n"); + + let prompt = client + .get_prompt( + "example_prompt", + serde_json::json!({"message": "hello there!"}), + ) + .await?; + println!("Prompt: {prompt:?}\n"); + Ok(()) } diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs index eee250025..907cc1b1c 100644 --- a/crates/mcp-server/src/main.rs +++ b/crates/mcp-server/src/main.rs @@ -1,6 +1,7 @@ use anyhow::Result; use mcp_core::content::Content; -use mcp_core::handler::ResourceError; +use mcp_core::handler::{PromptError, ResourceError}; +use mcp_core::prompt::{Prompt, PromptArgument}; use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; use mcp_server::router::{CapabilitiesBuilder, RouterService}; use mcp_server::{ByteTransport, Router, Server}; @@ -61,6 +62,7 @@ impl Router for CounterRouter { CapabilitiesBuilder::new() .with_tools(false) .with_resources(false, false) + .with_prompts(false) .build() } @@ -153,6 +155,37 @@ impl Router for CounterRouter { } }) } + + fn list_prompts(&self) -> Vec { + vec![Prompt::new( + "example_prompt", + Some("This is an example prompt that takes one required agrument, message"), + Some(vec![PromptArgument { + name: "message".to_string(), + description: Some("A message to put in the prompt".to_string()), + required: Some(true), + }]), + )] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + match prompt_name.as_str() { + "example_prompt" => { + let prompt = "This is an example prompt with your message here: '{message}'"; + Ok(prompt.to_string()) + } + _ => Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))), + } + }) + } } #[tokio::main] From 0f776c07de63a6ed9960b90a932a6d2e9d157777 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 21 Feb 2025 10:59:19 -0800 Subject: [PATCH 06/16] refactor: implement list_prompts and get_prompts for mcp servers --- .../goose-mcp/src/computercontroller/mod.rs | 20 ++++++++++++++++++- crates/goose-mcp/src/developer/mod.rs | 18 +++++++---------- crates/goose-mcp/src/google_drive/mod.rs | 20 ++++++++++++++++++- crates/goose-mcp/src/jetbrains/mod.rs | 20 ++++++++++++++++++- crates/goose-mcp/src/memory/mod.rs | 19 +++++++++++++++++- crates/goose-mcp/src/tutorial/mod.rs | 20 ++++++++++++++++++- 6 files changed, 101 insertions(+), 16 deletions(-) diff --git a/crates/goose-mcp/src/computercontroller/mod.rs b/crates/goose-mcp/src/computercontroller/mod.rs index be74395b7..77d372db0 100644 --- a/crates/goose-mcp/src/computercontroller/mod.rs +++ b/crates/goose-mcp/src/computercontroller/mod.rs @@ -9,7 +9,8 @@ use std::{ use tokio::process::Command; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, tool::Tool, @@ -807,4 +808,21 @@ impl Router for ComputerControllerRouter { } }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index 5cbdef982..a0a58b2a6 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -827,39 +827,35 @@ impl Router for DeveloperRouter { Box::pin(async move { Ok("".to_string()) }) } - fn list_prompts(&self) -> Option> { - if self.prompts.is_empty() { - None - } else { - Some(self.prompts.values().cloned().collect()) - } + fn list_prompts(&self) -> Vec { + self.prompts.values().cloned().collect() } fn get_prompt( &self, prompt_name: &str, - ) -> Option> + Send + 'static>>> { + ) -> Pin> + Send + 'static>> { let prompt_name = prompt_name.trim().to_owned(); // Validate prompt name is not empty if prompt_name.is_empty() { - return Some(Box::pin(async move { + return Box::pin(async move { Err(PromptError::InvalidParameters( "Prompt name cannot be empty".to_string(), )) - })); + }); } let prompts = Arc::clone(&self.prompts); - Some(Box::pin(async move { + Box::pin(async move { match prompts.get(&prompt_name) { Some(prompt) => Ok(prompt.description.clone().unwrap_or_default()), None => Err(PromptError::NotFound(format!( "Prompt '{prompt_name}' not found" ))), } - })) + }) } } diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 2ba36a574..2ed1e7f14 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -5,7 +5,8 @@ use serde_json::{json, Value}; use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin}; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, tool::Tool, @@ -618,6 +619,23 @@ impl Router for GoogleDriveRouter { let uri_clone = uri.to_string(); Box::pin(async move { this.read_google_resource(uri_clone).await }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } impl Clone for GoogleDriveRouter { diff --git a/crates/goose-mcp/src/jetbrains/mod.rs b/crates/goose-mcp/src/jetbrains/mod.rs index 319cdcd36..0cdf80189 100644 --- a/crates/goose-mcp/src/jetbrains/mod.rs +++ b/crates/goose-mcp/src/jetbrains/mod.rs @@ -3,7 +3,8 @@ mod proxy; use anyhow::Result; use mcp_core::{ content::Content, - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, role::Role, @@ -176,6 +177,23 @@ impl Router for JetBrainsRouter { ) -> Pin> + Send + 'static>> { Box::pin(async { Err(ResourceError::NotFound("Resource not found".into())) }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } impl Clone for JetBrainsRouter { diff --git a/crates/goose-mcp/src/memory/mod.rs b/crates/goose-mcp/src/memory/mod.rs index 4a7411a54..a9fd1fa39 100644 --- a/crates/goose-mcp/src/memory/mod.rs +++ b/crates/goose-mcp/src/memory/mod.rs @@ -12,7 +12,8 @@ use std::{ }; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, tool::{Tool, ToolCall}, @@ -493,6 +494,22 @@ impl Router for MemoryRouter { ) -> Pin> + Send + 'static>> { Box::pin(async move { Ok("".to_string()) }) } + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } #[derive(Debug)] diff --git a/crates/goose-mcp/src/tutorial/mod.rs b/crates/goose-mcp/src/tutorial/mod.rs index 9d6ba3d7c..2f32b03ac 100644 --- a/crates/goose-mcp/src/tutorial/mod.rs +++ b/crates/goose-mcp/src/tutorial/mod.rs @@ -5,7 +5,8 @@ use serde_json::{json, Value}; use std::{future::Future, pin::Pin}; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, role::Role, @@ -156,6 +157,23 @@ impl Router for TutorialRouter { ) -> Pin> + Send + 'static>> { Box::pin(async move { Ok("".to_string()) }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } impl Clone for TutorialRouter { From 7d0cb115b1dd365205cff45ec8023e6335adf42f Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 19 Feb 2025 16:01:29 -0800 Subject: [PATCH 07/16] feat: add list prompts command - extend CLI input handling to support `/prompts` for listing available prompts - add `ListPrompts` variant in the input enum and update help documentation - implement prompt rendering in the session output module - update agent traits and capabilities to aggregate and list prompts from all extensions --- crates/goose-cli/src/session/input.rs | 4 ++ crates/goose-cli/src/session/mod.rs | 8 +++ crates/goose-cli/src/session/output.rs | 12 ++++ crates/goose/src/agents/agent.rs | 5 ++ crates/goose/src/agents/capabilities.rs | 81 ++++++++++++++++++++++++- crates/goose/src/agents/reference.rs | 18 ++++++ crates/goose/src/agents/truncate.rs | 18 ++++++ 7 files changed, 145 insertions(+), 1 deletion(-) diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index 7cfa94d35..0245a4563 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -9,6 +9,8 @@ pub enum InputResult { AddBuiltin(String), ToggleTheme, Retry, + ListPrompts, + //UsePrompt(String), } pub fn get_input( @@ -59,6 +61,7 @@ fn handle_slash_command(input: &str) -> Option { Some(InputResult::Retry) } "/t" => Some(InputResult::ToggleTheme), + "/prompts" => Some(InputResult::ListPrompts), s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), _ => None, @@ -72,6 +75,7 @@ fn print_help() { /t - Toggle Light/Dark/Ansi theme /extension - Add a stdio extension (format: ENV1=val1 command args...) /builtin - Add builtin extensions by name (comma-separated) +/prompts - List all available prompts by name /? or /help - Display this help message Navigation: diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index d359a297a..d6db761c7 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -14,6 +14,7 @@ use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::handler::ToolError; use rand::{distributions::Alphanumeric, Rng}; +use std::collections::HashMap; use std::path::PathBuf; use tokio; @@ -103,6 +104,10 @@ impl Session { Ok(()) } + pub async fn list_prompts(&mut self) -> HashMap> { + self.agent.list_extension_prompts().await + } + pub async fn start(&mut self) -> Result<()> { let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?; @@ -165,6 +170,9 @@ impl Session { continue; } input::InputResult::Retry => continue, + input::InputResult::ListPrompts => { + output::render_prompts(&self.list_prompts().await) + } } } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index f6ccbdfcb..47ad3248a 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -5,6 +5,7 @@ use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; use mcp_core::tool::ToolCall; use serde_json::Value; use std::cell::RefCell; +use std::collections::HashMap; use std::path::Path; // Re-export theme for use in main @@ -151,6 +152,17 @@ pub fn render_error(message: &str) { println!("\n {} {}\n", style("error:").red().bold(), message); } +pub fn render_prompts(prompts: &HashMap>) { + println!(); + for (extension, prompts) in prompts { + println!(" {}", style(extension).green()); + for prompt in prompts { + println!(" - {}", style(prompt).cyan()); + } + } + println!(); +} + pub fn render_extension_success(name: &str) { println!(); println!( diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 4500f95d1..589c98f75 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; @@ -34,4 +36,7 @@ pub trait Agent: Send + Sync { /// Override the system prompt with custom text async fn override_system_prompt(&mut self, template: String); + + /// Lists all prompts from all extensions + async fn list_extension_prompts(&self) -> HashMap>; } diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index fb487979a..caf3b24ed 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use chrono::{DateTime, TimeZone, Utc}; use futures::stream::{FuturesUnordered, StreamExt}; use mcp_client::McpService; @@ -13,7 +14,7 @@ use crate::prompt_template::{load_prompt, load_prompt_file}; use crate::providers::base::{Provider, ProviderUsage}; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::transport::{SseTransport, StdioTransport, Transport}; -use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult}; +use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; // By default, we set it to Jan 1, 2020 if the resource does not have a timestamp @@ -544,6 +545,69 @@ impl Capabilities { result } + + pub async fn list_prompts_from_extension( + &self, + extension_name: &str, + ) -> Result, ToolError> { + let client = self.clients.get(extension_name).ok_or_else(|| { + ToolError::InvalidParameters(format!("Extension {} is not valid", extension_name)) + })?; + + let client_guard = client.lock().await; + client_guard + .list_prompts(None) + .await + .map_err(|e| { + ToolError::ExecutionError(format!( + "Unable to list prompts for {}, {:?}", + extension_name, e + )) + }) + .map(|lp| lp.prompts) + } + + pub async fn list_prompts(&self) -> Result>, ToolError> { + let mut futures = FuturesUnordered::new(); + + for extension_name in self.clients.keys() { + futures.push(async move { + ( + extension_name, + self.list_prompts_from_extension(extension_name).await, + ) + }); + } + + let mut all_prompts = HashMap::new(); + let mut errors = Vec::new(); + + // Process results as they complete + while let Some(result) = futures.next().await { + let (name, prompts) = result; + match prompts { + Ok(content) => { + all_prompts.insert(name.to_string(), content); + } + Err(tool_error) => { + errors.push(tool_error); + } + } + } + + // Log any errors that occurred + if !errors.is_empty() { + tracing::error!( + errors = ?errors + .into_iter() + .map(|e| format!("{:?}", e)) + .collect::>(), + "errors from listing prompts" + ); + } + + Ok(all_prompts) + } } #[cfg(test)] @@ -617,6 +681,21 @@ mod tests { Err(Error::NotInitialized) } + async fn list_prompts( + &self, + _next_cursor: Option, + ) -> Result { + Err(Error::NotInitialized) + } + + async fn get_prompt( + &self, + _name: &str, + _arguments: Value, + ) -> Result { + Err(Error::NotInitialized) + } + async fn call_tool(&self, name: &str, _arguments: Value) -> Result { match name { "tool" | "test__tool" => Ok(CallToolResult { diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index 6c30435d9..7f9074a7d 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -2,6 +2,7 @@ /// It makes no attempt to handle context limits, and cannot read resources use async_trait::async_trait; use futures::stream::BoxStream; +use std::collections::HashMap; use tokio::sync::Mutex; use tracing::{debug, instrument}; @@ -194,6 +195,23 @@ impl Agent for ReferenceAgent { let mut capabilities = self.capabilities.lock().await; capabilities.set_system_prompt_override(template); } + + async fn list_extension_prompts(&self) -> HashMap> { + let capabilities = self.capabilities.lock().await; + capabilities + .list_prompts() + .await + .map(|prompts| { + prompts + .into_iter() + .map(|(extension, prompt_list)| { + let names = prompt_list.into_iter().map(|p| p.name).collect(); + (extension, names) + }) + .collect() + }) + .expect("Failed to list prompts") + } } register_agent!("reference", ReferenceAgent); diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index 685524d4b..c19bd53c1 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -2,6 +2,7 @@ /// It makes no attempt to handle context limits, and cannot read resources use async_trait::async_trait; use futures::stream::BoxStream; +use std::collections::HashMap; use tokio::sync::Mutex; use tracing::{debug, error, instrument, warn}; @@ -302,6 +303,23 @@ impl Agent for TruncateAgent { let mut capabilities = self.capabilities.lock().await; capabilities.set_system_prompt_override(template); } + + async fn list_extension_prompts(&self) -> HashMap> { + let capabilities = self.capabilities.lock().await; + capabilities + .list_prompts() + .await + .map(|prompts| { + prompts + .into_iter() + .map(|(extension, prompt_list)| { + let names = prompt_list.into_iter().map(|p| p.name).collect(); + (extension, names) + }) + .collect() + }) + .expect("Failed to list prompts") + } } register_agent!("truncate", TruncateAgent); From ead9c1da43c04e431b4f36f897c0a1fc37954bc3 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 19 Feb 2025 17:09:03 -0800 Subject: [PATCH 08/16] feat: add /prompt $NAME --info and placeholder for exeuction --- crates/goose-cli/src/session/input.rs | 89 +++++++++++++++++++++++--- crates/goose-cli/src/session/mod.rs | 40 +++++++++++- crates/goose-cli/src/session/output.rs | 43 +++++++++++++ crates/goose/src/agents/agent.rs | 3 +- crates/goose/src/agents/reference.rs | 12 +--- crates/goose/src/agents/truncate.rs | 12 +--- 6 files changed, 167 insertions(+), 32 deletions(-) diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index 0245a4563..32545e30f 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -1,5 +1,6 @@ use anyhow::Result; use rustyline::Editor; +use std::collections::HashMap; #[derive(Debug)] pub enum InputResult { @@ -10,7 +11,14 @@ pub enum InputResult { ToggleTheme, Retry, ListPrompts, - //UsePrompt(String), + PromptCommand(PromptCommandOptions), +} + +#[derive(Debug)] +pub struct PromptCommandOptions { + pub name: String, + pub info: bool, + pub arguments: HashMap, } pub fn get_input( @@ -52,22 +60,55 @@ pub fn get_input( } fn handle_slash_command(input: &str) -> Option { - let input = input.trim(); - - match input { - "/exit" | "/quit" => Some(InputResult::Exit), - "/?" | "/help" => { + let parts: Vec<&str> = input.trim().split_whitespace().collect(); + match parts.get(0).map(|s| *s) { + Some("/exit") | Some("/quit") => Some(InputResult::Exit), + Some("/?") | Some("/help") => { print_help(); Some(InputResult::Retry) } - "/t" => Some(InputResult::ToggleTheme), - "/prompts" => Some(InputResult::ListPrompts), - s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), - s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), + Some("/t") => Some(InputResult::ToggleTheme), + Some("/prompts") => Some(InputResult::ListPrompts), + Some("/prompt") => parse_prompt_command(&parts[1..]), + Some(s) if s.starts_with("/extension ") => { + Some(InputResult::AddExtension(s[11..].to_string())) + } + Some(s) if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), _ => None, } } +fn parse_prompt_command(args: &[&str]) -> Option { + if args.is_empty() { + return None; + } + + let mut options = PromptCommandOptions { + name: args[0].to_string(), + info: false, + arguments: HashMap::new(), + }; + + // Parse remaining arguments + let mut i = 1; + while i < args.len() { + match args[i] { + "--info" => { + options.info = true; + } + arg if arg.contains('=') => { + if let Some((key, value)) = arg.split_once('=') { + options.arguments.insert(key.to_string(), value.to_string()); + } + } + _ => return None, // Invalid format + } + i += 1; + } + + Some(InputResult::PromptCommand(options)) +} + fn print_help() { println!( "Available commands: @@ -76,6 +117,7 @@ fn print_help() { /extension - Add a stdio extension (format: ENV1=val1 command args...) /builtin - Add builtin extensions by name (comma-separated) /prompts - List all available prompts by name +/prompt [--info] [key=value...] - Get prompt info or execute a prompt /? or /help - Display this help message Navigation: @@ -135,6 +177,33 @@ mod tests { assert!(handle_slash_command("/unknown").is_none()); } + #[test] + fn test_prompt_command() { + // Test basic prompt info command + if let Some(InputResult::PromptCommand(opts)) = + handle_slash_command("/prompt test-prompt --info") + { + assert_eq!(opts.name, "test-prompt"); + assert!(opts.info); + assert!(opts.arguments.is_empty()); + } else { + panic!("Expected PromptCommand"); + } + + // Test prompt with arguments + if let Some(InputResult::PromptCommand(opts)) = + handle_slash_command("/prompt test-prompt arg1=val1 arg2=val2") + { + assert_eq!(opts.name, "test-prompt"); + assert!(!opts.info); + assert_eq!(opts.arguments.len(), 2); + assert_eq!(opts.arguments.get("arg1"), Some(&"val1".to_string())); + assert_eq!(opts.arguments.get("arg2"), Some(&"val2".to_string())); + } else { + panic!("Expected PromptCommand"); + } + } + // Test whitespace handling #[test] fn test_whitespace_handling() { diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index d6db761c7..475e10c4d 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -13,6 +13,7 @@ use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::handler::ToolError; + use rand::{distributions::Alphanumeric, Rng}; use std::collections::HashMap; use std::path::PathBuf; @@ -105,7 +106,32 @@ impl Session { } pub async fn list_prompts(&mut self) -> HashMap> { - self.agent.list_extension_prompts().await + let prompts = self.agent.list_extension_prompts().await; + prompts + .into_iter() + .map(|(extension, prompt_list)| { + let names = prompt_list.into_iter().map(|p| p.name).collect(); + (extension, names) + }) + .collect() + } + + pub async fn get_prompt_info(&mut self, name: &str) -> Result> { + let prompts = self.agent.list_extension_prompts().await; + + // Find which extension has this prompt + for (extension, prompt_list) in prompts { + if let Some(prompt) = prompt_list.iter().find(|p| p.name == name) { + return Ok(Some(output::PromptInfo { + name: prompt.name.clone(), + description: prompt.description.clone(), + arguments: prompt.arguments.clone(), + extension: Some(extension), + })); + } + } + + Ok(None) } pub async fn start(&mut self) -> Result<()> { @@ -173,6 +199,18 @@ impl Session { input::InputResult::ListPrompts => { output::render_prompts(&self.list_prompts().await) } + input::InputResult::PromptCommand(opts) => { + if opts.info { + match self.get_prompt_info(&opts.name).await? { + Some(info) => output::render_prompt_info(&info), + None => { + output::render_error(&format!("Prompt '{}' not found", opts.name)) + } + } + } else { + output::render_error("Prompt execution not yet implemented"); + } + } } } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 47ad3248a..93af711db 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -2,6 +2,7 @@ use bat::WrappingMode; use console::style; use goose::config::Config; use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use mcp_core::prompt::PromptArgument; use mcp_core::tool::ToolCall; use serde_json::Value; use std::cell::RefCell; @@ -74,6 +75,14 @@ impl ThinkingIndicator { } } +#[derive(Debug)] +pub struct PromptInfo { + pub name: String, + pub description: Option, + pub arguments: Option>, + pub extension: Option, +} + // Global thinking indicator thread_local! { static THINKING: RefCell = RefCell::new(ThinkingIndicator::default()); @@ -163,6 +172,40 @@ pub fn render_prompts(prompts: &HashMap>) { println!(); } +pub fn render_prompt_info(info: &PromptInfo) { + println!(); + + if let Some(ext) = &info.extension { + println!(" {}: {}", style("Extension").green(), ext); + } + + println!("Prompt: {}", style(&info.name).cyan().bold()); + + if let Some(desc) = &info.description { + println!("\n {}", desc); + } + + if let Some(args) = &info.arguments { + println!("\n Arguments:"); + for arg in args { + let required = arg.required.unwrap_or(false); + let req_str = if required { + style("(required)").red() + } else { + style("(optional)").dim() + }; + + println!( + " {} {} {}", + style(&arg.name).yellow(), + req_str, + arg.description.as_deref().unwrap_or("") + ); + } + } + println!(); +} + pub fn render_extension_success(name: &str) { println!(); println!( diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 589c98f75..1ef94d355 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -8,6 +8,7 @@ use serde_json::Value; use super::extension::{ExtensionConfig, ExtensionResult}; use crate::message::Message; use crate::providers::base::ProviderUsage; +use mcp_core::prompt::Prompt; /// Core trait defining the behavior of an Agent #[async_trait] @@ -38,5 +39,5 @@ pub trait Agent: Send + Sync { async fn override_system_prompt(&mut self, template: String); /// Lists all prompts from all extensions - async fn list_extension_prompts(&self) -> HashMap>; + async fn list_extension_prompts(&self) -> HashMap>; } diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index 7f9074a7d..c11bbff4a 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -15,6 +15,7 @@ use crate::providers::base::ProviderUsage; use crate::register_agent; use crate::token_counter::TokenCounter; use indoc::indoc; +use mcp_core::prompt::Prompt; use mcp_core::tool::Tool; use serde_json::{json, Value}; @@ -196,20 +197,11 @@ impl Agent for ReferenceAgent { capabilities.set_system_prompt_override(template); } - async fn list_extension_prompts(&self) -> HashMap> { + async fn list_extension_prompts(&self) -> HashMap> { let capabilities = self.capabilities.lock().await; capabilities .list_prompts() .await - .map(|prompts| { - prompts - .into_iter() - .map(|(extension, prompt_list)| { - let names = prompt_list.into_iter().map(|p| p.name).collect(); - (extension, names) - }) - .collect() - }) .expect("Failed to list prompts") } } diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index c19bd53c1..ca9644776 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -17,6 +17,7 @@ use crate::register_agent; use crate::token_counter::TokenCounter; use crate::truncate::{truncate_messages, OldestFirstTruncation}; use indoc::indoc; +use mcp_core::prompt::Prompt; use mcp_core::tool::Tool; use serde_json::{json, Value}; @@ -304,20 +305,11 @@ impl Agent for TruncateAgent { capabilities.set_system_prompt_override(template); } - async fn list_extension_prompts(&self) -> HashMap> { + async fn list_extension_prompts(&self) -> HashMap> { let capabilities = self.capabilities.lock().await; capabilities .list_prompts() .await - .map(|prompts| { - prompts - .into_iter() - .map(|(extension, prompt_list)| { - let names = prompt_list.into_iter().map(|p| p.name).collect(); - (extension, names) - }) - .collect() - }) .expect("Failed to list prompts") } } From 8025a936187fa29965feca308dee32ea7fc9f9dc Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 21 Feb 2025 11:20:57 -0800 Subject: [PATCH 09/16] refactor: revert handle_slash_command, match existing patterns fix: cherry-pick conflicts resolved --- crates/goose-cli/src/session/input.rs | 33 +++++++++++++------------ crates/goose-cli/src/session/output.rs | 2 +- crates/goose/src/agents/capabilities.rs | 16 +----------- 3 files changed, 19 insertions(+), 32 deletions(-) diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index 32545e30f..b865f55a1 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -60,39 +60,40 @@ pub fn get_input( } fn handle_slash_command(input: &str) -> Option { - let parts: Vec<&str> = input.trim().split_whitespace().collect(); - match parts.get(0).map(|s| *s) { - Some("/exit") | Some("/quit") => Some(InputResult::Exit), - Some("/?") | Some("/help") => { + let input = input.trim(); + + match input { + "/exit" | "/quit" => Some(InputResult::Exit), + "/?" | "/help" => { print_help(); Some(InputResult::Retry) } - Some("/t") => Some(InputResult::ToggleTheme), - Some("/prompts") => Some(InputResult::ListPrompts), - Some("/prompt") => parse_prompt_command(&parts[1..]), - Some(s) if s.starts_with("/extension ") => { - Some(InputResult::AddExtension(s[11..].to_string())) - } - Some(s) if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), + "/t" => Some(InputResult::ToggleTheme), + "/prompts" => Some(InputResult::ListPrompts), + s if s.starts_with("/prompt ") => parse_prompt_command(&s[8..]), + s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), + s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), _ => None, } } -fn parse_prompt_command(args: &[&str]) -> Option { - if args.is_empty() { +fn parse_prompt_command(args: &str) -> Option { + let parts: Vec<&str> = args.split_whitespace().collect(); + + if parts.is_empty() { return None; } let mut options = PromptCommandOptions { - name: args[0].to_string(), + name: parts[0].to_string(), info: false, arguments: HashMap::new(), }; // Parse remaining arguments let mut i = 1; - while i < args.len() { - match args[i] { + while i < parts.len() { + match parts[i] { "--info" => { options.info = true; } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 93af711db..525c48575 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -179,7 +179,7 @@ pub fn render_prompt_info(info: &PromptInfo) { println!(" {}: {}", style("Extension").green(), ext); } - println!("Prompt: {}", style(&info.name).cyan().bold()); + println!(" Prompt: {}", style(&info.name).cyan().bold()); if let Some(desc) = &info.description { println!("\n {}", desc); diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index caf3b24ed..fc4762ea4 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -681,21 +681,6 @@ mod tests { Err(Error::NotInitialized) } - async fn list_prompts( - &self, - _next_cursor: Option, - ) -> Result { - Err(Error::NotInitialized) - } - - async fn get_prompt( - &self, - _name: &str, - _arguments: Value, - ) -> Result { - Err(Error::NotInitialized) - } - async fn call_tool(&self, name: &str, _arguments: Value) -> Result { match name { "tool" | "test__tool" => Ok(CallToolResult { @@ -705,6 +690,7 @@ mod tests { _ => Err(Error::NotInitialized), } } + async fn list_prompts( &self, _next_cursor: Option, From 2407051ef25e1cadfa02f0bb429fe688f896bc0b Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Tue, 25 Feb 2025 13:57:34 -0800 Subject: [PATCH 10/16] feat: first pass at enabling /prompt support, just rendering the output from the mcp server --- Cargo.lock | 1 + crates/goose-cli/Cargo.toml | 1 + crates/goose-cli/src/session/input.rs | 124 ++++++++++++++++++++---- crates/goose-cli/src/session/mod.rs | 29 +++++- crates/goose/src/agents/agent.rs | 5 + crates/goose/src/agents/capabilities.rs | 19 ++++ crates/goose/src/agents/reference.rs | 25 +++++ crates/goose/src/agents/truncate.rs | 25 +++++ 8 files changed, 211 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f9675c094..6457a36d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2203,6 +2203,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "shlex", "temp-env", "tempfile", "test-case", diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 41cb85a60..7addc964d 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -47,6 +47,7 @@ chrono = "0.4" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] } tracing-appender = "0.2" once_cell = "1.20.2" +shlex = "1.3.0" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index b865f55a1..38fd2ea58 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -1,5 +1,6 @@ use anyhow::Result; use rustyline::Editor; +use shlex; use std::collections::HashMap; #[derive(Debug)] @@ -70,7 +71,22 @@ fn handle_slash_command(input: &str) -> Option { } "/t" => Some(InputResult::ToggleTheme), "/prompts" => Some(InputResult::ListPrompts), - s if s.starts_with("/prompt ") => parse_prompt_command(&s[8..]), + s if s.starts_with("/prompt") => { + if s == "/prompt" { + // No arguments case + Some(InputResult::PromptCommand(PromptCommandOptions { + name: String::new(), // Empty name will trigger the error message in the rendering + info: false, + arguments: HashMap::new(), + })) + } else if let Some(stripped) = s.strip_prefix("/prompt ") { + // Has arguments case + parse_prompt_command(stripped) + } else { + // Handle invalid cases like "/promptxyz" + None + } + } s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), _ => None, @@ -78,32 +94,37 @@ fn handle_slash_command(input: &str) -> Option { } fn parse_prompt_command(args: &str) -> Option { - let parts: Vec<&str> = args.split_whitespace().collect(); - - if parts.is_empty() { - return None; - } + let parts: Vec = shlex::split(args).unwrap_or_default(); + // set name to empty and error out in the rendering let mut options = PromptCommandOptions { - name: parts[0].to_string(), + name: parts.first().cloned().unwrap_or_default(), info: false, arguments: HashMap::new(), }; + // handle info at any point in the command + if parts.iter().any(|part| part == "--info") { + options.info = true; + } + // Parse remaining arguments let mut i = 1; + while i < parts.len() { - match parts[i] { - "--info" => { - options.info = true; - } - arg if arg.contains('=') => { - if let Some((key, value)) = arg.split_once('=') { - options.arguments.insert(key.to_string(), value.to_string()); - } - } - _ => return None, // Invalid format + let part = &parts[i]; + + // Skip flag arguments + if part == "--info" { + i += 1; + continue; + } + + // Process key=value pairs - removed redundant contains check + if let Some((key, value)) = part.split_once('=') { + options.arguments.insert(key.to_string(), value.to_string()); } + i += 1; } @@ -223,4 +244,73 @@ mod tests { panic!("Expected AddBuiltin"); } } + + // Test prompt with no arguments + #[test] + fn test_prompt_no_args() { + // Test just "/prompt" with no arguments + if let Some(InputResult::PromptCommand(opts)) = handle_slash_command("/prompt") { + assert_eq!(opts.name, ""); + assert!(!opts.info); + assert!(opts.arguments.is_empty()); + } else { + panic!("Expected PromptCommand"); + } + + // Test invalid prompt command + assert!(handle_slash_command("/promptxyz").is_none()); + } + + // Test quoted arguments + #[test] + fn test_quoted_arguments() { + // Test prompt with quoted arguments + if let Some(InputResult::PromptCommand(opts)) = handle_slash_command( + r#"/prompt test-prompt arg1="value with spaces" arg2="another value""#, + ) { + assert_eq!(opts.name, "test-prompt"); + assert_eq!(opts.arguments.len(), 2); + assert_eq!( + opts.arguments.get("arg1"), + Some(&"value with spaces".to_string()) + ); + assert_eq!( + opts.arguments.get("arg2"), + Some(&"another value".to_string()) + ); + } else { + panic!("Expected PromptCommand"); + } + + // Test prompt with mixed quoted and unquoted arguments + if let Some(InputResult::PromptCommand(opts)) = handle_slash_command( + r#"/prompt test-prompt simple=value quoted="value with \"nested\" quotes""#, + ) { + assert_eq!(opts.name, "test-prompt"); + assert_eq!(opts.arguments.len(), 2); + assert_eq!(opts.arguments.get("simple"), Some(&"value".to_string())); + assert_eq!( + opts.arguments.get("quoted"), + Some(&r#"value with "nested" quotes"#.to_string()) + ); + } else { + panic!("Expected PromptCommand"); + } + } + + // Test invalid arguments + #[test] + fn test_invalid_arguments() { + // Test prompt with invalid arguments + if let Some(InputResult::PromptCommand(opts)) = + handle_slash_command(r#"/prompt test-prompt valid=value invalid_arg another_invalid"#) + { + assert_eq!(opts.name, "test-prompt"); + assert_eq!(opts.arguments.len(), 1); + assert_eq!(opts.arguments.get("valid"), Some(&"value".to_string())); + // Invalid arguments are ignored but logged + } else { + panic!("Expected PromptCommand"); + } + } } diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 475e10c4d..3e142b203 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -13,8 +13,10 @@ use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::handler::ToolError; +use mcp_core::prompt::PromptMessage; use rand::{distributions::Alphanumeric, Rng}; +use serde_json::Value; use std::collections::HashMap; use std::path::PathBuf; use tokio; @@ -134,6 +136,11 @@ impl Session { Ok(None) } + pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result> { + let result = self.agent.get_prompt(name, arguments).await?; + Ok(result.messages) + } + pub async fn start(&mut self) -> Result<()> { let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?; @@ -200,6 +207,12 @@ impl Session { output::render_prompts(&self.list_prompts().await) } input::InputResult::PromptCommand(opts) => { + // name is required + if opts.name.is_empty() { + output::render_error("Prompt name argument is required"); + continue; + } + if opts.info { match self.get_prompt_info(&opts.name).await? { Some(info) => output::render_prompt_info(&info), @@ -208,7 +221,21 @@ impl Session { } } } else { - output::render_error("Prompt execution not yet implemented"); + // Convert the arguments HashMap to a Value + let arguments = serde_json::to_value(opts.arguments) + .map_err(|e| anyhow::anyhow!("Failed to serialize arguments: {}", e))?; + + match self.get_prompt(&opts.name, arguments).await { + Ok(messages) => { + println!( + "{:?}", + serde_json::to_string(&messages) + .unwrap_or("failed to get prompt".to_string()) + ); + continue; + } + Err(e) => output::render_error(&e.to_string()), + } } } } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 1ef94d355..9ee53522c 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -9,6 +9,7 @@ use super::extension::{ExtensionConfig, ExtensionResult}; use crate::message::Message; use crate::providers::base::ProviderUsage; use mcp_core::prompt::Prompt; +use mcp_core::protocol::GetPromptResult; /// Core trait defining the behavior of an Agent #[async_trait] @@ -40,4 +41,8 @@ pub trait Agent: Send + Sync { /// Lists all prompts from all extensions async fn list_extension_prompts(&self) -> HashMap>; + + /// Get a prompt result with the given name and arguments + /// Returns the prompt text that would be used as user input + async fn get_prompt(&self, name: &str, arguments: Value) -> Result; } diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index fc4762ea4..2e95ec706 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -2,6 +2,7 @@ use anyhow::Result; use chrono::{DateTime, TimeZone, Utc}; use futures::stream::{FuturesUnordered, StreamExt}; use mcp_client::McpService; +use mcp_core::protocol::GetPromptResult; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::sync::LazyLock; @@ -608,6 +609,24 @@ impl Capabilities { Ok(all_prompts) } + + pub async fn get_prompt( + &self, + extension_name: &str, + name: &str, + arguments: Value, + ) -> Result { + let client = self + .clients + .get(extension_name) + .ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?; + + let client_guard = client.lock().await; + client_guard + .get_prompt(name, arguments) + .await + .map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e)) + } } #[cfg(test)] diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index c11bbff4a..87118bef3 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -14,8 +14,10 @@ use crate::providers::base::Provider; use crate::providers::base::ProviderUsage; use crate::register_agent; use crate::token_counter::TokenCounter; +use anyhow::{anyhow, Result}; use indoc::indoc; use mcp_core::prompt::Prompt; +use mcp_core::protocol::GetPromptResult; use mcp_core::tool::Tool; use serde_json::{json, Value}; @@ -204,6 +206,29 @@ impl Agent for ReferenceAgent { .await .expect("Failed to list prompts") } + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + let capabilities = self.capabilities.lock().await; + + // First find which extension has this prompt + let prompts = capabilities + .list_prompts() + .await + .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; + + if let Some(extension) = prompts + .iter() + .find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name)) + .map(|(extension, _)| extension) + { + return capabilities + .get_prompt(extension, name, arguments) + .await + .map_err(|e| anyhow!("Failed to get prompt: {}", e)); + } + + Err(anyhow!("Prompt '{}' not found", name)) + } } register_agent!("reference", ReferenceAgent); diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index ca9644776..0d6e28fb5 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -16,8 +16,10 @@ use crate::providers::errors::ProviderError; use crate::register_agent; use crate::token_counter::TokenCounter; use crate::truncate::{truncate_messages, OldestFirstTruncation}; +use anyhow::{anyhow, Result}; use indoc::indoc; use mcp_core::prompt::Prompt; +use mcp_core::protocol::GetPromptResult; use mcp_core::tool::Tool; use serde_json::{json, Value}; @@ -312,6 +314,29 @@ impl Agent for TruncateAgent { .await .expect("Failed to list prompts") } + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + let capabilities = self.capabilities.lock().await; + + // First find which extension has this prompt + let prompts = capabilities + .list_prompts() + .await + .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; + + if let Some(extension) = prompts + .iter() + .find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name)) + .map(|(extension, _)| extension) + { + return capabilities + .get_prompt(extension, name, arguments) + .await + .map_err(|e| anyhow!("Failed to get prompt: {}", e)); + } + + Err(anyhow!("Prompt '{}' not found", name)) + } } register_agent!("truncate", TruncateAgent); From dd4cddbcb02d0e31633979c550ae9881e5171d28 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Tue, 25 Feb 2025 14:58:39 -0800 Subject: [PATCH 11/16] feat: convert prompt messages to agent messages, and handle prompt in agent loop Add functionality to transform PromptMessageContent to MessageContent with proper handling in the session module and add test coverage. Add the results of GetPrompt to the message conversation and run the agent loop with prompt response. --- crates/goose-cli/src/session/mod.rs | 29 ++++-- crates/goose/src/message.rs | 135 ++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 8 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 3e142b203..2b203c6c6 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -11,9 +11,9 @@ use anyhow::Result; use etcetera::choose_app_strategy; use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::Agent; -use goose::message::{Message, MessageContent}; +use goose::message::{prompt_content_to_message_content, Message, MessageContent}; use mcp_core::handler::ToolError; -use mcp_core::prompt::PromptMessage; +use mcp_core::prompt::{PromptMessage, PromptMessageRole}; use rand::{distributions::Alphanumeric, Rng}; use serde_json::Value; @@ -227,12 +227,25 @@ impl Session { match self.get_prompt(&opts.name, arguments).await { Ok(messages) => { - println!( - "{:?}", - serde_json::to_string(&messages) - .unwrap_or("failed to get prompt".to_string()) - ); - continue; + // convert the PromptMessages to Messages + for message in messages { + let msg_content = + prompt_content_to_message_content(message.content); + match message.role { + PromptMessageRole::User => { + self.messages + .push(Message::user().with_content(msg_content)); + } + PromptMessageRole::Assistant => { + self.messages.push( + Message::assistant().with_content(msg_content), + ); + } + } + } + output::show_thinking(); + self.process_agent_response().await?; + output::hide_thinking(); } Err(e) => output::render_error(&e.to_string()), } diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index 30de253ff..67877b9e0 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -10,9 +10,31 @@ use std::collections::HashSet; use chrono::Utc; use mcp_core::content::{Content, ImageContent, TextContent}; use mcp_core::handler::ToolResult; +use mcp_core::prompt::PromptMessageContent; +use mcp_core::resource::ResourceContents; use mcp_core::role::Role; use mcp_core::tool::ToolCall; +/// Convert PromptMessageContent to MessageContent +/// +/// This function allows converting from the prompt message content type +/// to the message content type used in the agent. +pub fn prompt_content_to_message_content(content: PromptMessageContent) -> MessageContent { + match content { + PromptMessageContent::Text { text } => MessageContent::text(text), + PromptMessageContent::Image { image } => MessageContent::image(image.data, image.mime_type), + PromptMessageContent::Resource { resource } => { + // For resources, convert to text content with the resource text + match resource.resource { + ResourceContents::TextResourceContents { text, .. } => MessageContent::text(text), + ResourceContents::BlobResourceContents { blob, .. } => { + MessageContent::text(format!("[Binary content: {}]", blob)) + } + } + } + } +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct ToolRequest { pub id: String, @@ -248,3 +270,116 @@ impl Message { .all(|c| matches!(c, MessageContent::Text(_))) } } + +#[cfg(test)] +mod tests { + use super::*; + use mcp_core::content::EmbeddedResource; + use mcp_core::prompt::PromptMessageContent; + use mcp_core::resource::ResourceContents; + + #[test] + fn test_prompt_content_to_message_content_text() { + let prompt_content = PromptMessageContent::Text { + text: "Hello, world!".to_string(), + }; + + let message_content = prompt_content_to_message_content(prompt_content); + + if let MessageContent::Text(text_content) = message_content { + assert_eq!(text_content.text, "Hello, world!"); + } else { + panic!("Expected MessageContent::Text"); + } + } + + #[test] + fn test_prompt_content_to_message_content_image() { + let prompt_content = PromptMessageContent::Image { + image: ImageContent { + data: "base64data".to_string(), + mime_type: "image/jpeg".to_string(), + annotations: None, + }, + }; + + let message_content = prompt_content_to_message_content(prompt_content); + + if let MessageContent::Image(image_content) = message_content { + assert_eq!(image_content.data, "base64data"); + assert_eq!(image_content.mime_type, "image/jpeg"); + } else { + panic!("Expected MessageContent::Image"); + } + } + + #[test] + fn test_prompt_content_to_message_content_text_resource() { + let resource = ResourceContents::TextResourceContents { + uri: "file:///test.txt".to_string(), + mime_type: Some("text/plain".to_string()), + text: "Resource content".to_string(), + }; + + let prompt_content = PromptMessageContent::Resource { + resource: EmbeddedResource { + resource, + annotations: None, + }, + }; + + let message_content = prompt_content_to_message_content(prompt_content); + + if let MessageContent::Text(text_content) = message_content { + assert_eq!(text_content.text, "Resource content"); + } else { + panic!("Expected MessageContent::Text"); + } + } + + #[test] + fn test_prompt_content_to_message_content_blob_resource() { + let resource = ResourceContents::BlobResourceContents { + uri: "file:///test.bin".to_string(), + mime_type: Some("application/octet-stream".to_string()), + blob: "binary_data".to_string(), + }; + + let prompt_content = PromptMessageContent::Resource { + resource: EmbeddedResource { + resource, + annotations: None, + }, + }; + + let message_content = prompt_content_to_message_content(prompt_content); + + if let MessageContent::Text(text_content) = message_content { + assert_eq!(text_content.text, "[Binary content: binary_data]"); + } else { + panic!("Expected MessageContent::Text"); + } + } + + #[test] + fn test_message_with_text() { + let message = Message::user().with_text("Hello"); + assert_eq!(message.as_concat_text(), "Hello"); + } + + #[test] + fn test_message_with_tool_request() { + let tool_call = Ok(ToolCall { + name: "test_tool".to_string(), + arguments: serde_json::json!({}), + }); + + let message = Message::assistant().with_tool_request("req1", tool_call); + assert!(message.is_tool_call()); + assert!(!message.is_tool_response()); + + let ids = message.get_tool_ids(); + assert_eq!(ids.len(), 1); + assert!(ids.contains("req1")); + } +} From a1ef0b9e33d01cee3bca83e6d0c8240d3cb960c6 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 26 Feb 2025 09:13:28 -0800 Subject: [PATCH 12/16] style: cargo fmt after merge fix: update process_agent_response call --- crates/goose-cli/src/session/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 3acb43a7e..9c26c6290 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -107,7 +107,6 @@ impl Session { Ok(()) } - pub async fn list_prompts(&mut self) -> HashMap> { let prompts = self.agent.list_extension_prompts().await; prompts @@ -258,7 +257,7 @@ impl Session { } } output::show_thinking(); - self.process_agent_response().await?; + self.process_agent_response(true).await?; output::hide_thinking(); } Err(e) => output::render_error(&e.to_string()), From 2b45c20c234bdc71c3fa53ebaaac17fb6444025a Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 26 Feb 2025 13:17:19 -0800 Subject: [PATCH 13/16] feat: add /prompts --extension $extension to filter for prompts --- crates/goose-cli/src/session/input.rs | 47 ++++++++++++++++++++++++--- crates/goose-cli/src/session/mod.rs | 26 ++++++++++++--- 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index 38fd2ea58..133068e45 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -11,7 +11,7 @@ pub enum InputResult { AddBuiltin(String), ToggleTheme, Retry, - ListPrompts, + ListPrompts(Option), PromptCommand(PromptCommandOptions), } @@ -70,7 +70,12 @@ fn handle_slash_command(input: &str) -> Option { Some(InputResult::Retry) } "/t" => Some(InputResult::ToggleTheme), - "/prompts" => Some(InputResult::ListPrompts), + "/prompts" => Some(InputResult::ListPrompts(None)), + s if s.starts_with("/prompts ") => { + // Parse arguments for /prompts command + let args = s.strip_prefix("/prompts ").unwrap_or_default(); + parse_prompts_command(args) + } s if s.starts_with("/prompt") => { if s == "/prompt" { // No arguments case @@ -93,6 +98,21 @@ fn handle_slash_command(input: &str) -> Option { } } +fn parse_prompts_command(args: &str) -> Option { + let parts: Vec = shlex::split(args).unwrap_or_default(); + + // Look for --extension flag + for i in 0..parts.len() { + if parts[i] == "--extension" && i + 1 < parts.len() { + // Return the extension name that follows the flag + return Some(InputResult::ListPrompts(Some(parts[i + 1].clone()))); + } + } + + // If we got here, there was no valid --extension flag + Some(InputResult::ListPrompts(None)) +} + fn parse_prompt_command(args: &str) -> Option { let parts: Vec = shlex::split(args).unwrap_or_default(); @@ -138,8 +158,8 @@ fn print_help() { /t - Toggle Light/Dark/Ansi theme /extension - Add a stdio extension (format: ENV1=val1 command args...) /builtin - Add builtin extensions by name (comma-separated) -/prompts - List all available prompts by name -/prompt [--info] [key=value...] - Get prompt info or execute a prompt +/prompts [--extension ] - List all available prompts, optionally filtered by extension +/prompt [--info] [key=value...] - Get prompt info or execute a prompt /? or /help - Display this help message Navigation: @@ -199,6 +219,25 @@ mod tests { assert!(handle_slash_command("/unknown").is_none()); } + #[test] + fn test_prompts_command() { + // Test basic prompts command + if let Some(InputResult::ListPrompts(extension)) = handle_slash_command("/prompts") { + assert!(extension.is_none()); + } else { + panic!("Expected ListPrompts"); + } + + // Test prompts with extension filter + if let Some(InputResult::ListPrompts(extension)) = + handle_slash_command("/prompts --extension test") + { + assert_eq!(extension, Some("test".to_string())); + } else { + panic!("Expected ListPrompts with extension"); + } + } + #[test] fn test_prompt_command() { // Test basic prompt info command diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 9c26c6290..940fca267 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -107,15 +107,28 @@ impl Session { Ok(()) } - pub async fn list_prompts(&mut self) -> HashMap> { + pub async fn list_prompts( + &mut self, + extension: Option, + ) -> Result>> { let prompts = self.agent.list_extension_prompts().await; - prompts + + // Early validation if filtering by extension + if let Some(filter) = &extension { + if !prompts.contains_key(filter) { + return Err(anyhow::anyhow!("Extension '{}' not found", filter)); + } + } + + // Convert prompts into filtered map of extension names to prompt names + Ok(prompts .into_iter() + .filter(|(ext, _)| extension.as_ref().is_none_or(|f| f == ext)) .map(|(extension, prompt_list)| { let names = prompt_list.into_iter().map(|p| p.name).collect(); (extension, names) }) - .collect() + .collect()) } pub async fn get_prompt_info(&mut self, name: &str) -> Result> { @@ -216,8 +229,11 @@ impl Session { continue; } input::InputResult::Retry => continue, - input::InputResult::ListPrompts => { - output::render_prompts(&self.list_prompts().await) + input::InputResult::ListPrompts(extension) => { + match self.list_prompts(extension).await { + Ok(prompts) => output::render_prompts(&prompts), + Err(e) => output::render_error(&e.to_string()), + } } input::InputResult::PromptCommand(opts) => { // name is required From 25746ab0763dfe86599d103614cd8afaee9bb73d Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 26 Feb 2025 16:14:31 -0800 Subject: [PATCH 14/16] feat(cli): add custom command completion with cache support - introduce `GooseCompleter` in `completion.rs` implementing rustyline traits (Completer, Hinter, Highlighter, Validator, Helper) for CLI command completions - add helper methods for completing slash commands, prompt names, flags, and argument keys - update `Session` in `mod.rs` to include a thread-safe completion cache with update and invalidate functionality - derive `Clone` for `PromptInfo` in `output.rs` to support caching --- crates/goose-cli/src/session/completion.rs | 333 +++++++++++++++++++++ crates/goose-cli/src/session/input.rs | 4 +- crates/goose-cli/src/session/mod.rs | 92 +++++- crates/goose-cli/src/session/output.rs | 2 +- 4 files changed, 427 insertions(+), 4 deletions(-) create mode 100644 crates/goose-cli/src/session/completion.rs diff --git a/crates/goose-cli/src/session/completion.rs b/crates/goose-cli/src/session/completion.rs new file mode 100644 index 000000000..b6ea008bf --- /dev/null +++ b/crates/goose-cli/src/session/completion.rs @@ -0,0 +1,333 @@ +use rustyline::completion::{Completer, Pair}; +use rustyline::highlight::{CmdKind, Highlighter}; +use rustyline::hint::Hinter; +use rustyline::validate::Validator; +use rustyline::{Helper, Result}; +use std::borrow::Cow; +use std::sync::Arc; + +use super::CompletionCache; + +/// Completer for Goose CLI commands +pub struct GooseCompleter { + completion_cache: Arc>, +} + +impl GooseCompleter { + /// Create a new GooseCompleter with a reference to the Session's completion cache + pub fn new(completion_cache: Arc>) -> Self { + Self { completion_cache } + } + + /// Complete prompt names for the /prompt command + fn complete_prompt_names(&self, line: &str) -> Result<(usize, Vec)> { + // Get the prefix of the prompt name being typed + let prefix = if line.len() > 8 { &line[8..] } else { "" }; + + // Get available prompts from cache + let cache = self.completion_cache.read().unwrap(); + + // Create completion candidates that match the prefix + let candidates: Vec = cache + .prompts + .iter() + .flat_map(|(_, names)| names) + .filter(|name| name.starts_with(prefix.trim())) + .map(|name| Pair { + display: name.clone(), + replacement: name.clone(), + }) + .collect(); + + Ok((8, candidates)) + } + + /// Complete flags for the /prompt command + fn complete_prompt_flags(&self, line: &str) -> Result<(usize, Vec)> { + // Get the last part of the line + let parts: Vec<&str> = line.split_whitespace().collect(); + if let Some(last_part) = parts.last() { + // If the last part starts with '-', it might be a partial flag + if last_part.starts_with('-') { + // Define available flags + let flags = ["--info"]; + + // Find flags that match the prefix + let matching_flags: Vec = flags + .iter() + .filter(|flag| flag.starts_with(last_part)) + .map(|flag| Pair { + display: flag.to_string(), + replacement: flag.to_string(), + }) + .collect(); + + if !matching_flags.is_empty() { + // Return matches for the partial flag + // The position is the start of the last word + let pos = line.len() - last_part.len(); + return Ok((pos, matching_flags)); + } + } + } + + // No flag completions available + Ok((line.len(), vec![])) + } + + /// Complete slash commands + fn complete_slash_commands(&self, line: &str) -> Result<(usize, Vec)> { + // Define available slash commands + let commands = [ + "/exit", + "/quit", + "/help", + "/?", + "/t", + "/extension", + "/builtin", + "/prompts", + "/prompt", + ]; + + // Find commands that match the prefix + let matching_commands: Vec = commands + .iter() + .filter(|cmd| cmd.starts_with(line)) + .map(|cmd| Pair { + display: cmd.to_string(), + replacement: format!("{} ", cmd), // Add a space after the command + }) + .collect(); + + if !matching_commands.is_empty() { + return Ok((0, matching_commands)); + } + + // No command completions available + Ok((line.len(), vec![])) + } + + /// Complete argument keys for a specific prompt + fn complete_argument_keys(&self, line: &str) -> Result<(usize, Vec)> { + let parts: Vec<&str> = line[8..].split_whitespace().collect(); + + // We need at least the prompt name + if parts.is_empty() { + return Ok((line.len(), vec![])); + } + + let prompt_name = parts[0]; + + // Get prompt info from cache + let cache = self.completion_cache.read().unwrap(); + let prompt_info = cache.prompt_info.get(prompt_name).cloned(); + + if let Some(info) = prompt_info { + if let Some(args) = info.arguments { + // Find required arguments that haven't been provided yet + let existing_args: Vec<&str> = parts + .iter() + .skip(1) + .filter_map(|part| { + if part.contains('=') { + Some(part.split('=').next().unwrap()) + } else { + None + } + }) + .collect(); + + // Check if we're trying to complete a partial argument name + if let Some(last_part) = parts.last() { + // If the last part doesn't contain '=', it might be a partial argument name + if !last_part.contains('=') { + // Find arguments that match the prefix + let matching_args: Vec = args + .iter() + .filter(|arg| { + arg.name.starts_with(last_part) + && !existing_args.contains(&arg.name.as_str()) + }) + .map(|arg| Pair { + display: format!("{}=", arg.name), + replacement: format!("{}=", arg.name), + }) + .collect(); + + if !matching_args.is_empty() { + // Return matches for the partial argument name + // The position is the start of the last word + let pos = line.len() - last_part.len(); + return Ok((pos, matching_args)); + } + + // If we have a partial argument that doesn't match anything, + // return an empty list rather than suggesting unrelated arguments + if !last_part.is_empty() { + return Ok((line.len(), vec![])); + } + } + } + + // If no partial match or no last part, suggest the first required argument + // Use a reference to avoid moving args + for arg in &args { + if arg.required.unwrap_or(false) && !existing_args.contains(&arg.name.as_str()) + { + let candidates = vec![Pair { + display: format!("{}=", arg.name), + replacement: format!("{}=", arg.name), + }]; + return Ok((line.len(), candidates)); + } + } + + // If no required arguments left, suggest optional ones + // Use a reference to avoid moving args + for arg in &args { + if !arg.required.unwrap_or(true) && !existing_args.contains(&arg.name.as_str()) + { + let candidates = vec![Pair { + display: format!("{}=", arg.name), + replacement: format!("{}=", arg.name), + }]; + return Ok((line.len(), candidates)); + } + } + } + } + + // No completions available + Ok((line.len(), vec![])) + } +} + +impl Completer for GooseCompleter { + type Candidate = Pair; + + fn complete( + &self, + line: &str, + pos: usize, + _ctx: &rustyline::Context<'_>, + ) -> Result<(usize, Vec)> { + // If the line starts with '/', it might be a slash command + if line.starts_with('/') { + // If it's just a partial slash command (no space yet) + if !line.contains(' ') { + return self.complete_slash_commands(line); + } + + // Handle /prompt command + if line.starts_with("/prompt") { + // If we're just after "/prompt" with or without a space + if line == "/prompt" || line == "/prompt " { + return self.complete_prompt_names(line); + } + + // Get the parts of the command + let parts: Vec<&str> = line.split_whitespace().collect(); + + // If we're typing a prompt name (only one part after /prompt) + if parts.len() == 2 && !line.ends_with(' ') { + return self.complete_prompt_names(line); + } + + // Check if we might be typing a flag + if let Some(last_part) = parts.last() { + if last_part.starts_with('-') { + return self.complete_prompt_flags(line); + } + } + + // If we have a prompt name and need argument completion + if parts.len() >= 2 { + return self.complete_argument_keys(line); + } + } + + // Handle /prompts command + if line.starts_with("/prompts") { + // If we're just after "/prompts" with a space + if line == "/prompts " { + // Suggest the --extension flag + return Ok(( + line.len(), + vec![Pair { + display: "--extension".to_string(), + replacement: "--extension ".to_string(), + }], + )); + } + + // Check if we might be typing the --extension flag + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() == 2 + && parts[1].starts_with('-') + && "--extension".starts_with(parts[1]) + { + return Ok(( + line.len() - parts[1].len(), + vec![Pair { + display: "--extension".to_string(), + replacement: "--extension ".to_string(), + }], + )); + } + } + } + + // Default: no completions + Ok((pos, vec![])) + } +} + +// Implement the Helper trait which is required by rustyline +impl Helper for GooseCompleter {} + +// Implement required traits with default implementations +impl Hinter for GooseCompleter { + type Hint = String; + + fn hint(&self, _line: &str, _pos: usize, _ctx: &rustyline::Context<'_>) -> Option { + None + } +} + +impl Highlighter for GooseCompleter { + fn highlight_prompt<'b, 's: 'b, 'p: 'b>( + &'s self, + prompt: &'p str, + _default: bool, + ) -> Cow<'b, str> { + Cow::Borrowed(prompt) + } + + fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> { + Cow::Borrowed(hint) + } + + fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> { + Cow::Borrowed(line) + } + + fn highlight_char(&self, _line: &str, _pos: usize, _cmd_kind: CmdKind) -> bool { + false + } +} + +impl Validator for GooseCompleter { + fn validate( + &self, + _ctx: &mut rustyline::validate::ValidationContext, + ) -> rustyline::Result { + Ok(rustyline::validate::ValidationResult::Valid(None)) + } +} + +#[cfg(test)] +mod tests { + // Tests are disabled for now due to mismatch between mock and real types + // We've manually tested the completion functionality +} diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index 133068e45..aa7e20b5e 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -3,6 +3,8 @@ use rustyline::Editor; use shlex; use std::collections::HashMap; +use super::completion::GooseCompleter; + #[derive(Debug)] pub enum InputResult { Message(String), @@ -23,7 +25,7 @@ pub struct PromptCommandOptions { } pub fn get_input( - editor: &mut Editor<(), rustyline::history::DefaultHistory>, + editor: &mut Editor, ) -> Result { // Ensure Ctrl-J binding is set for newlines editor.bind_sequence( diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 940fca267..5f5a88c62 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -1,4 +1,5 @@ mod builder; +mod completion; mod input; mod output; mod prompt; @@ -6,6 +7,7 @@ mod storage; mod thinking; pub use builder::build_session; +use completion::GooseCompleter; use anyhow::Result; use etcetera::choose_app_strategy; @@ -19,6 +21,8 @@ use rand::{distributions::Alphanumeric, Rng}; use serde_json::Value; use std::collections::HashMap; use std::path::PathBuf; +use std::sync::Arc; +use std::time::Instant; use tokio; use crate::log_usage::log_usage; @@ -27,6 +31,25 @@ pub struct Session { agent: Box, messages: Vec, session_file: PathBuf, + // Cache for completion data - using std::sync for thread safety without async + completion_cache: Arc>, +} + +// Cache structure for completion data +struct CompletionCache { + prompts: HashMap>, + prompt_info: HashMap, + last_updated: Instant, +} + +impl CompletionCache { + fn new() -> Self { + Self { + prompts: HashMap::new(), + prompt_info: HashMap::new(), + last_updated: Instant::now(), + } + } } impl Session { @@ -43,6 +66,7 @@ impl Session { agent, messages, session_file, + completion_cache: Arc::new(std::sync::RwLock::new(CompletionCache::new())), } } @@ -87,7 +111,12 @@ impl Session { self.agent .add_extension(config) .await - .map_err(|e| anyhow::anyhow!("Failed to start extension: {}", e)) + .map_err(|e| anyhow::anyhow!("Failed to start extension: {}", e))?; + + // Invalidate the completion cache when a new extension is added + self.invalidate_completion_cache().await; + + Ok(()) } /// Add a builtin extension to the session @@ -104,6 +133,10 @@ impl Session { .await .map_err(|e| anyhow::anyhow!("Failed to start builtin extension: {}", e))?; } + + // Invalidate the completion cache when a new extension is added + self.invalidate_completion_cache().await; + Ok(()) } @@ -169,7 +202,21 @@ impl Session { self.process_message(msg).await?; } - let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?; + // Initialize the completion cache + self.update_completion_cache().await?; + + // Create a new editor with our custom completer + let config = rustyline::Config::builder() + .completion_type(rustyline::CompletionType::Circular) + .build(); + let mut editor = + rustyline::Editor::::with_config( + config, + )?; + + // Set up the completer with a reference to the completion cache + let completer = GooseCompleter::new(self.completion_cache.clone()); + editor.set_helper(Some(completer)); // Load history from messages for msg in self @@ -432,4 +479,45 @@ impl Session { pub fn session_file(&self) -> PathBuf { self.session_file.clone() } + + /// Update the completion cache with fresh data + /// This should be called before the interactive session starts + pub async fn update_completion_cache(&mut self) -> Result<()> { + // Get fresh data + let prompts = self.agent.list_extension_prompts().await; + + // Update the cache with write lock + let mut cache = self.completion_cache.write().unwrap(); + cache.prompts.clear(); + cache.prompt_info.clear(); + + for (extension, prompt_list) in prompts { + let names: Vec = prompt_list.iter().map(|p| p.name.clone()).collect(); + cache.prompts.insert(extension.clone(), names); + + for prompt in prompt_list { + cache.prompt_info.insert( + prompt.name.clone(), + output::PromptInfo { + name: prompt.name.clone(), + description: prompt.description.clone(), + arguments: prompt.arguments.clone(), + extension: Some(extension.clone()), + }, + ); + } + } + + cache.last_updated = Instant::now(); + Ok(()) + } + + /// Invalidate the completion cache + /// This should be called when extensions are added or removed + async fn invalidate_completion_cache(&self) { + let mut cache = self.completion_cache.write().unwrap(); + cache.prompts.clear(); + cache.prompt_info.clear(); + cache.last_updated = Instant::now(); + } } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index d1377eb68..66021695a 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -75,7 +75,7 @@ impl ThinkingIndicator { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct PromptInfo { pub name: String, pub description: Option, From cedbac8723448796572602a4a67c4b4d18086014 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 26 Feb 2025 16:29:59 -0800 Subject: [PATCH 15/16] test(cli): add unit tests for command completion - Test slash command, prompt name, flag, and argument completion - Match assertions to actual implementation behavior - Document edge cases in test comments --- crates/goose-cli/src/session/completion.rs | 207 ++++++++++++++++++++- 1 file changed, 205 insertions(+), 2 deletions(-) diff --git a/crates/goose-cli/src/session/completion.rs b/crates/goose-cli/src/session/completion.rs index b6ea008bf..b3fc61c53 100644 --- a/crates/goose-cli/src/session/completion.rs +++ b/crates/goose-cli/src/session/completion.rs @@ -328,6 +328,209 @@ impl Validator for GooseCompleter { #[cfg(test)] mod tests { - // Tests are disabled for now due to mismatch between mock and real types - // We've manually tested the completion functionality + use super::*; + use crate::session::output; + use mcp_core::prompt::PromptArgument; + use std::sync::{Arc, RwLock}; + + // Helper function to create a test completion cache + fn create_test_cache() -> Arc> { + let mut cache = CompletionCache::new(); + + // Add some test prompts + let mut extension1_prompts = Vec::new(); + extension1_prompts.push("test_prompt1".to_string()); + extension1_prompts.push("test_prompt2".to_string()); + cache + .prompts + .insert("extension1".to_string(), extension1_prompts); + + let mut extension2_prompts = Vec::new(); + extension2_prompts.push("other_prompt".to_string()); + cache + .prompts + .insert("extension2".to_string(), extension2_prompts); + + // Add prompt info with arguments + let test_prompt1_args = vec![ + PromptArgument { + name: "required_arg".to_string(), + description: Some("A required argument".to_string()), + required: Some(true), + }, + PromptArgument { + name: "optional_arg".to_string(), + description: Some("An optional argument".to_string()), + required: Some(false), + }, + ]; + + let test_prompt1_info = output::PromptInfo { + name: "test_prompt1".to_string(), + description: Some("Test prompt 1 description".to_string()), + arguments: Some(test_prompt1_args), + extension: Some("extension1".to_string()), + }; + cache + .prompt_info + .insert("test_prompt1".to_string(), test_prompt1_info); + + let test_prompt2_info = output::PromptInfo { + name: "test_prompt2".to_string(), + description: Some("Test prompt 2 description".to_string()), + arguments: None, + extension: Some("extension1".to_string()), + }; + cache + .prompt_info + .insert("test_prompt2".to_string(), test_prompt2_info); + + let other_prompt_info = output::PromptInfo { + name: "other_prompt".to_string(), + description: Some("Other prompt description".to_string()), + arguments: None, + extension: Some("extension2".to_string()), + }; + cache + .prompt_info + .insert("other_prompt".to_string(), other_prompt_info); + + Arc::new(RwLock::new(cache)) + } + + #[test] + fn test_complete_slash_commands() { + let cache = create_test_cache(); + let completer = GooseCompleter::new(cache); + + // Test complete match + let (pos, candidates) = completer.complete_slash_commands("/exit").unwrap(); + assert_eq!(pos, 0); + assert_eq!(candidates.len(), 1); + assert_eq!(candidates[0].display, "/exit"); + assert_eq!(candidates[0].replacement, "/exit "); + + // Test partial match + let (pos, candidates) = completer.complete_slash_commands("/e").unwrap(); + assert_eq!(pos, 0); + // There might be multiple commands starting with "e" like "/exit" and "/extension" + assert!(candidates.len() >= 1); + + // Test multiple matches + let (pos, candidates) = completer.complete_slash_commands("/").unwrap(); + assert_eq!(pos, 0); + assert!(candidates.len() > 1); + + // Test no match + let (_pos, candidates) = completer.complete_slash_commands("/nonexistent").unwrap(); + assert_eq!(candidates.len(), 0); + } + + #[test] + fn test_complete_prompt_names() { + let cache = create_test_cache(); + let completer = GooseCompleter::new(cache); + + // Test with just "/prompt " + let (pos, candidates) = completer.complete_prompt_names("/prompt ").unwrap(); + assert_eq!(pos, 8); + assert_eq!(candidates.len(), 3); // All prompts + + // Test with partial prompt name + let (pos, candidates) = completer.complete_prompt_names("/prompt test").unwrap(); + assert_eq!(pos, 8); + assert_eq!(candidates.len(), 2); // test_prompt1 and test_prompt2 + + // Test with specific prompt name + let (pos, candidates) = completer + .complete_prompt_names("/prompt test_prompt1") + .unwrap(); + assert_eq!(pos, 8); + assert_eq!(candidates.len(), 1); + assert_eq!(candidates[0].display, "test_prompt1"); + + // Test with no match + let (pos, candidates) = completer + .complete_prompt_names("/prompt nonexistent") + .unwrap(); + assert_eq!(pos, 8); + assert_eq!(candidates.len(), 0); + } + + #[test] + fn test_complete_prompt_flags() { + let cache = create_test_cache(); + let completer = GooseCompleter::new(cache); + + // Test with partial flag + let (_pos, candidates) = completer + .complete_prompt_flags("/prompt test_prompt1 --") + .unwrap(); + assert_eq!(candidates.len(), 1); + assert_eq!(candidates[0].display, "--info"); + + // Test with exact flag + let (_pos, candidates) = completer + .complete_prompt_flags("/prompt test_prompt1 --info") + .unwrap(); + assert_eq!(candidates.len(), 1); + assert_eq!(candidates[0].display, "--info"); + + // Test with no match + let (_pos, candidates) = completer + .complete_prompt_flags("/prompt test_prompt1 --nonexistent") + .unwrap(); + assert_eq!(candidates.len(), 0); + + // Test with no flag + let (_pos, candidates) = completer + .complete_prompt_flags("/prompt test_prompt1") + .unwrap(); + assert_eq!(candidates.len(), 0); + } + + #[test] + fn test_complete_argument_keys() { + let cache = create_test_cache(); + let completer = GooseCompleter::new(cache); + + // Test with just a prompt name (no space after) + // This case doesn't return any candidates in the current implementation + let (_pos, candidates) = completer + .complete_argument_keys("/prompt test_prompt1") + .unwrap(); + assert_eq!(candidates.len(), 0); + + // Test with partial argument + let (_pos, candidates) = completer + .complete_argument_keys("/prompt test_prompt1 req") + .unwrap(); + assert_eq!(candidates.len(), 1); + assert_eq!(candidates[0].display, "required_arg="); + + // Test with one argument already provided + let (_pos, candidates) = completer + .complete_argument_keys("/prompt test_prompt1 required_arg=value") + .unwrap(); + assert_eq!(candidates.len(), 1); + assert_eq!(candidates[0].display, "optional_arg="); + + // Test with all arguments provided + let (_pos, candidates) = completer + .complete_argument_keys("/prompt test_prompt1 required_arg=value optional_arg=value") + .unwrap(); + assert_eq!(candidates.len(), 0); + + // Test with prompt that has no arguments + let (_pos, candidates) = completer + .complete_argument_keys("/prompt test_prompt2") + .unwrap(); + assert_eq!(candidates.len(), 0); + + // Test with nonexistent prompt + let (_pos, candidates) = completer + .complete_argument_keys("/prompt nonexistent") + .unwrap(); + assert_eq!(candidates.len(), 0); + } } From eb42a64eca18adcfda048f31ff8a51b283018e5b Mon Sep 17 00:00:00 2001 From: kalvinnchau Date: Wed, 26 Feb 2025 21:01:39 -0800 Subject: [PATCH 16/16] style: short-circut out of = \ /, cleanup rust code --- crates/goose-cli/src/session/completion.rs | 53 ++++++++++------------ 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/crates/goose-cli/src/session/completion.rs b/crates/goose-cli/src/session/completion.rs index b3fc61c53..378a444b1 100644 --- a/crates/goose-cli/src/session/completion.rs +++ b/crates/goose-cli/src/session/completion.rs @@ -140,6 +140,13 @@ impl GooseCompleter { // Check if we're trying to complete a partial argument name if let Some(last_part) = parts.last() { + // ignore if last_part starts with = / \ for suggestions + if let Some(c) = last_part.chars().next() { + if matches!(c, '=' | '/' | '\\') { + return Ok((line.len(), vec![])); + } + } + // If the last part doesn't contain '=', it might be a partial argument name if !last_part.contains('=') { // Find arguments that match the prefix @@ -170,30 +177,21 @@ impl GooseCompleter { } } - // If no partial match or no last part, suggest the first required argument - // Use a reference to avoid moving args - for arg in &args { - if arg.required.unwrap_or(false) && !existing_args.contains(&arg.name.as_str()) - { - let candidates = vec![Pair { - display: format!("{}=", arg.name), - replacement: format!("{}=", arg.name), - }]; - return Ok((line.len(), candidates)); - } - } + // Partition the arguments into required and optional, then try to find a candidate. + let (required_args, optional_args): (Vec<_>, Vec<_>) = + args.iter().partition(|arg| arg.required.unwrap_or(false)); - // If no required arguments left, suggest optional ones - // Use a reference to avoid moving args - for arg in &args { - if !arg.required.unwrap_or(true) && !existing_args.contains(&arg.name.as_str()) - { - let candidates = vec![Pair { - display: format!("{}=", arg.name), - replacement: format!("{}=", arg.name), - }]; - return Ok((line.len(), candidates)); - } + let candidate = required_args + .iter() + .chain(optional_args.iter()) // chain optional_args after required_args + .find(|arg| !existing_args.contains(&arg.name.as_str())) + .map(|arg| Pair { + display: format!("{}=", arg.name), + replacement: format!("{}=", arg.name), + }); + + if let Some(candidate) = candidate { + return Ok((line.len(), vec![candidate])); } } } @@ -338,15 +336,12 @@ mod tests { let mut cache = CompletionCache::new(); // Add some test prompts - let mut extension1_prompts = Vec::new(); - extension1_prompts.push("test_prompt1".to_string()); - extension1_prompts.push("test_prompt2".to_string()); + let extension1_prompts = vec!["test_prompt1".to_string(), "test_prompt2".to_string()]; cache .prompts .insert("extension1".to_string(), extension1_prompts); - let mut extension2_prompts = Vec::new(); - extension2_prompts.push("other_prompt".to_string()); + let extension2_prompts = vec!["other_prompt".to_string()]; cache .prompts .insert("extension2".to_string(), extension2_prompts); @@ -414,7 +409,7 @@ mod tests { let (pos, candidates) = completer.complete_slash_commands("/e").unwrap(); assert_eq!(pos, 0); // There might be multiple commands starting with "e" like "/exit" and "/extension" - assert!(candidates.len() >= 1); + assert!(!candidates.is_empty()); // Test multiple matches let (pos, candidates) = completer.complete_slash_commands("/").unwrap();