Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: permission before tool call #1313

Merged
merged 19 commits into from
Feb 24, 2025
51 changes: 44 additions & 7 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use goose::agents::Agent;
use goose::message::{Message, MessageContent};
use mcp_core::handler::ToolError;
use rand::{distributions::Alphanumeric, Rng};
use rustyline::Editor;
use std::path::PathBuf;
use tokio;

Expand Down Expand Up @@ -104,7 +105,7 @@ impl Session {
}

pub async fn start(&mut self) -> Result<()> {
let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?;
let mut editor = Editor::<(), rustyline::history::DefaultHistory>::new()?;

// Load history from messages
for msg in self
Expand All @@ -120,7 +121,6 @@ impl Session {
}
}
}

output::display_greeting();
loop {
match input::get_input(&mut editor)? {
Expand All @@ -129,7 +129,7 @@ impl Session {
storage::persist_messages(&self.session_file, &self.messages)?;

output::show_thinking();
self.process_agent_response().await?;
self.process_agent_response(&mut editor).await?;
output::hide_thinking();
}
input::InputResult::Exit => break,
Expand Down Expand Up @@ -188,20 +188,57 @@ impl Session {
self.messages
.push(Message::user().with_text(&initial_message));
storage::persist_messages(&self.session_file, &self.messages)?;
self.process_agent_response().await?;
let mut editor = Editor::<(), rustyline::history::DefaultHistory>::new()?;
self.process_agent_response(&mut editor).await?;
Ok(())
}

async fn process_agent_response(&mut self) -> Result<()> {
async fn process_agent_response(
&mut self,
editor: &mut Editor<(), rustyline::history::DefaultHistory>,
) -> Result<()> {
let mut stream = self.agent.reply(&self.messages).await?;

use futures::StreamExt;
loop {
tokio::select! {
result = stream.next() => {
match result {
Some(Ok(message)) => {
self.messages.push(message.clone());
Some(Ok(mut message)) => {

// Handle tool confirmation requests before rendering
if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
output::hide_thinking();

// Format the confirmation prompt
let prompt = "Goose would like to call the above tool. Allow? (y/n):".to_string();

let confirmation_request = Message::user().with_tool_confirmation_request(
confirmation.id.clone(),
confirmation.tool_name.clone(),
confirmation.arguments.clone(),
Some(prompt)
);
output::render_message(&confirmation_request);

// Get confirmation from user
let confirmed = match input::get_input(editor)? {
input::InputResult::Message(content) => {
content.trim().to_lowercase().starts_with('y')
}
_ => false,
};

self.agent.handle_confirmation(confirmation.id.clone(), confirmed).await;

message = confirmation_request;
}

// Only push the message if it's not a tool confirmation request
if !message.content.iter().any(|content| matches!(content, MessageContent::ToolConfirmationRequest(_))) {
self.messages.push(message.clone());
}

storage::persist_messages(&self.session_file, &self.messages)?;
output::hide_thinking();
output::render_message(&message);
Expand Down
16 changes: 15 additions & 1 deletion crates/goose-cli/src/session/output.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bat::WrappingMode;
use console::style;
use goose::config::Config;
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
use goose::message::{Message, MessageContent, ToolConfirmationRequest, ToolRequest, ToolResponse};
use mcp_core::tool::ToolCall;
use serde_json::Value;
use std::cell::RefCell;
Expand Down Expand Up @@ -94,6 +94,9 @@ pub fn render_message(message: &Message) {
MessageContent::Text(text) => print_markdown(&text.text, theme),
MessageContent::ToolRequest(req) => render_tool_request(req, theme),
MessageContent::ToolResponse(resp) => render_tool_response(resp, theme),
MessageContent::ToolConfirmationRequest(req) => {
render_tool_confirmation_request(req, theme)
}
MessageContent::Image(image) => {
println!("Image: [data: {}, type: {}]", image.data, image.mime_type);
}
Expand Down Expand Up @@ -147,6 +150,17 @@ fn render_tool_response(resp: &ToolResponse, theme: Theme) {
}
}

/// Prints the prompt carried by a tool confirmation request.
///
/// When the request has a prompt, every occurrence of the literal
/// `Allow? (y/n)` is replaced by a cyan-styled copy before the text is
/// printed with `println!`. When there is no prompt, a fixed fallback notice
/// is rendered through `print_markdown` using `theme`.
fn render_tool_confirmation_request(req: &ToolConfirmationRequest, theme: Theme) {
    if let Some(prompt) = &req.prompt {
        // Style only the question part; the rest of the prompt is printed as-is.
        let highlighted_question = style("Allow? (y/n)").cyan().to_string();
        let rendered = prompt.replace("Allow? (y/n)", &highlighted_question);
        println!("{}", rendered);
    } else {
        print_markdown("No prompt provided", theme);
    }
}

pub fn render_error(message: &str) {
println!("\n {} {}\n", style("error:").red().bold(), message);
}
Expand Down
13 changes: 7 additions & 6 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,14 @@ async fn stream_message(
.await?;
}
}
MessageContent::ToolConfirmationRequest(_) => {
// skip tool confirmation requests
}
MessageContent::Image(_) => {
// TODO
continue;
// skip images
}
MessageContent::ToolResponse(_) => {
// Tool responses should only come from the user
continue;
// skip tool responses
}
}
}
Expand Down Expand Up @@ -311,7 +312,7 @@ async fn handler(
let mut stream = match agent.reply(&messages).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
tracing::error!("Failed to start reply stream: {:?}", e);
let _ = tx
.send(ProtocolFormatter::format_error(&e.to_string()))
.await;
Expand Down Expand Up @@ -398,7 +399,7 @@ async fn ask_handler(
let mut stream = match agent.reply(&messages).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
tracing::error!("Failed to start reply stream: {:?}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
Expand Down
3 changes: 3 additions & 0 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ pub trait Agent: Send + Sync {
/// Add custom text to be included in the system prompt
async fn extend_system_prompt(&mut self, extension: String);

/// Handle a confirmation response for a tool request.
///
/// `request_id` identifies the tool request being answered and `confirmed`
/// carries the user's decision (`true` to allow the call).
async fn handle_confirmation(&self, request_id: String, confirmed: bool);

/// Override the system prompt with custom text
async fn override_system_prompt(&mut self, template: String);
}
4 changes: 4 additions & 0 deletions crates/goose/src/agents/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ impl Agent for ReferenceAgent {
Ok(Value::Null)
}

/// Accepts a confirmation response for a tool request.
///
/// The reference agent does not act on confirmations yet: both arguments are
/// ignored and the call returns immediately.
async fn handle_confirmation(&self, _request_id: String, _confirmed: bool) {
// TODO implement
}

#[instrument(skip(self, messages), fields(user_message))]
async fn reply(
&self,
Expand Down
119 changes: 100 additions & 19 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
/// It makes no attempt to handle context limits, and cannot read resources
use async_trait::async_trait;
use futures::stream::BoxStream;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tracing::{debug, error, instrument, warn};

use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::config::Config;
use crate::message::{Message, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
Expand All @@ -16,7 +18,7 @@ use crate::register_agent;
use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation};
use indoc::indoc;
use mcp_core::tool::Tool;
use mcp_core::{tool::Tool, Content};
use serde_json::{json, Value};

const MAX_TRUNCATION_ATTEMPTS: usize = 3;
Expand All @@ -26,14 +28,21 @@ const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
pub struct TruncateAgent {
// Capabilities built from the provider, guarded by an async mutex.
capabilities: Mutex<Capabilities>,
// Token counter created from the provider's configured tokenizer name.
token_counter: TokenCounter,
// Sending half of the confirmation channel; `handle_confirmation` pushes
// the user's decision for a tool request onto it.
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
// Receiving half of the confirmation channel, wrapped in a mutex so it can
// be locked and read through `&self`.
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
}

impl TruncateAgent {
/// Builds a `TruncateAgent` around `provider`.
///
/// The token counter is created from the provider's configured tokenizer
/// name, and a bounded channel is set up to carry `(request_id, confirmed)`
/// tool confirmation responses.
pub fn new(provider: Box<dyn Provider>) -> Self {
    // Buffer size of the confirmation channel (adjust if needed).
    const CONFIRMATION_CHANNEL_CAPACITY: usize = 32;

    let (confirmation_tx, confirmation_rx) = mpsc::channel(CONFIRMATION_CHANNEL_CAPACITY);
    // Must be computed before `provider` is moved into `Capabilities`.
    let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());

    Self {
        capabilities: Mutex::new(Capabilities::new(provider)),
        token_counter,
        confirmation_tx,
        confirmation_rx: Mutex::new(confirmation_rx),
    }
}

Expand Down Expand Up @@ -121,6 +130,13 @@ impl Agent for TruncateAgent {
Ok(Value::Null)
}

/// Forwards the user's decision for tool request `request_id` onto the
/// confirmation channel.
///
/// A failed send is logged via `error!` and otherwise ignored.
async fn handle_confirmation(&self, request_id: String, confirmed: bool) {
    let decision = (request_id, confirmed);
    match self.confirmation_tx.send(decision).await {
        Ok(()) => {}
        Err(e) => {
            error!("Failed to send confirmation: {}", e);
        }
    }
}

#[instrument(skip(self, messages), fields(user_message))]
async fn reply(
&self,
Expand All @@ -132,6 +148,10 @@ impl Agent for TruncateAgent {
let mut tools = capabilities.get_prefixed_tools().await?;
let mut truncation_attempt: usize = 0;

// Load settings from config
let config = Config::global();
let goose_mode = config.get("GOOSE_MODE").unwrap_or("auto".to_string());

// we add in the 2 resource tools if any extensions support resources
// TODO: make sure there is no collision with another extension's tool name
let read_resource_tool = Tool::new(
Expand Down Expand Up @@ -191,7 +211,6 @@ impl Agent for TruncateAgent {
Ok(Box::pin(async_stream::try_stream! {
let _reply_guard = reply_span.enter();
loop {
// Attempt to get completion from provider
match capabilities.provider().complete(
&system_prompt,
&messages,
Expand All @@ -218,24 +237,86 @@ impl Agent for TruncateAgent {
break;
}

// Then dispatch each in parallel
let futures: Vec<_> = tool_requests
.iter()
.filter_map(|request| request.tool_call.clone().ok())
.map(|tool_call| capabilities.dispatch_tool_call(tool_call))
.collect();

// Process all the futures in parallel but wait until all are finished
let outputs = futures::future::join_all(futures).await;

// Create a message with the responses
// Process tool requests depending on goose_mode
let mut message_tool_response = Message::user();
// Now combine these into MessageContent::ToolResponse using the original ID
for (request, output) in tool_requests.iter().zip(outputs.into_iter()) {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
// Clone goose_mode once before the match to avoid move issues
let mode = goose_mode.clone();
match mode.as_str() {
"approve" => {
// Process each tool request sequentially with confirmation
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the tool: {}\nAllow? (y/n): ".to_string()),
);
yield confirmation;

// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
if let Some((req_id, confirmed)) = rx.recv().await {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add a deadline for waiting on the confirmation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for cli, we shouldn't add a timeout imo, but for GUI, it should be considered

if req_id == request.id {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious: can we ever hit the case where req_id != request.id?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, the identical message id should be passed back.

if confirmed {
// User approved - dispatch the tool call
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
} else {
// User declined - add declined response
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text("User declined to run this tool.")]),
);
}
}
}
}
}
},
"chat" => {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think instead of skipping tool calls in the loop, we really want to not advertise the tools to the agent at all? e.g. remove them from the call to provider.complete - otherwise i could see this leading to confusing dialogue

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Subtle detail, but the ToolConfirmationRequest is not added to the agent loop and is not visible to the agent. Added some context in the chat mode to allow the agent to frame the tool call requests as plans.

// Skip all tool calls in chat mode
for request in &tool_requests {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text(
"The following tool call was skipped in Goose chat mode. \
In chat mode, you cannot run tool calls, instead, you can \
only provide a detailed plan to the user. Provide an \
explanation of the proposed tool call as if it were a plan. \
Only if the user asks, provide a short explanation to the \
user that they could consider running the tool above on \
their own or with a different goose mode."
)]),
);
}
},
_ => {
if mode != "auto" {
warn!("Unknown GOOSE_MODE: {mode:?}. Defaulting to 'auto' mode.");
}
// Process tool requests in parallel
let mut tool_futures = Vec::new();
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
tool_futures.push(async {
let output = capabilities.dispatch_tool_call(tool_call).await;
(request.id.clone(), output)
});
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
}
}

yield message_tool_response.clone();
Expand Down
Loading
Loading