Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: permission before tool call #1313

Merged
merged 19 commits into from
Feb 24, 2025
65 changes: 57 additions & 8 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ use anyhow::Result;
use etcetera::choose_app_strategy;
use goose::agents::extension::{Envs, ExtensionConfig};
use goose::agents::Agent;
use goose::config::Config;
use goose::message::{Message, MessageContent};
use mcp_core::handler::ToolError;
use rand::{distributions::Alphanumeric, Rng};
use rustyline::Editor;
use std::path::PathBuf;
use tokio;

Expand Down Expand Up @@ -104,7 +106,7 @@ impl Session {
}

pub async fn start(&mut self) -> Result<()> {
let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?;
let mut editor = Editor::<(), rustyline::history::DefaultHistory>::new()?;

// Load history from messages
for msg in self
Expand All @@ -120,16 +122,18 @@ impl Session {
}
}
}

let config = Config::global();
output::display_greeting();
loop {
let goose_mode = config.get("GOOSE_MODE").unwrap_or("auto".to_string());
match input::get_input(&mut editor)? {
input::InputResult::Message(content) => {
self.messages.push(Message::user().with_text(&content));
storage::persist_messages(&self.session_file, &self.messages)?;

output::show_thinking();
self.process_agent_response().await?;
self.process_agent_response(&mut editor, Some(goose_mode.clone()))
.await?;
output::hide_thinking();
}
input::InputResult::Exit => break,
Expand Down Expand Up @@ -185,23 +189,68 @@ impl Session {
}

pub async fn headless_start(&mut self, initial_message: String) -> Result<()> {
// Load settings from config
let config = Config::global();
let goose_mode = config.get("GOOSE_MODE").unwrap_or("auto".to_string());

self.messages
.push(Message::user().with_text(&initial_message));
storage::persist_messages(&self.session_file, &self.messages)?;
self.process_agent_response().await?;
let mut editor = Editor::<(), rustyline::history::DefaultHistory>::new()?;
self.process_agent_response(&mut editor, Some(goose_mode.clone()))
.await?;
Ok(())
}

async fn process_agent_response(&mut self) -> Result<()> {
let mut stream = self.agent.reply(&self.messages).await?;
async fn process_agent_response(
&mut self,
editor: &mut Editor<(), rustyline::history::DefaultHistory>,
goose_mode: Option<String>,
) -> Result<()> {
let mut stream = self.agent.reply(&self.messages, goose_mode).await?;

use futures::StreamExt;
loop {
tokio::select! {
result = stream.next() => {
match result {
Some(Ok(message)) => {
self.messages.push(message.clone());
Some(Ok(mut message)) => {

// Handle tool confirmation requests before rendering
if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
output::hide_thinking();

// Format the confirmation prompt
let prompt = format!(
"Goose would like to call the tool: {}\nWith arguments: {}\nAllow? (y/n): ",
confirmation.tool_name,
serde_json::to_string_pretty(&confirmation.arguments).unwrap_or_default()
);
output::render_message(&Message::assistant().with_text(&prompt));

// Get confirmation from user
let confirmed = match input::get_input(editor)? {
input::InputResult::Message(content) => {
content.trim().to_lowercase().starts_with('y')
}
_ => false,
};
let confirmation_request = Message::user().with_tool_confirmation_request(
confirmation.id.clone(),
confirmation.tool_name.clone(),
confirmation.arguments.clone(),
);

self.agent.handle_confirmation(confirmation.id.clone(), confirmed).await;

message = confirmation_request;
}

// Only push the message if it's not a tool confirmation request
if !message.content.iter().any(|content| matches!(content, MessageContent::ToolConfirmationRequest(_))) {
self.messages.push(message.clone());
}

storage::persist_messages(&self.session_file, &self.messages)?;
output::hide_thinking();
output::render_message(&message);
Expand Down
1 change: 1 addition & 0 deletions crates/goose-cli/src/session/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ pub fn render_message(message: &Message) {
MessageContent::Text(text) => print_markdown(&text.text, theme),
MessageContent::ToolRequest(req) => render_tool_request(req, theme),
MessageContent::ToolResponse(resp) => render_tool_response(resp, theme),
MessageContent::ToolConfirmationRequest(_) => {}
MessageContent::Image(image) => {
println!("Image: [data: {}, type: {}]", image.data, image.mime_type);
}
Expand Down
12 changes: 7 additions & 5 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,11 @@ async fn stream_message(
.await?;
}
}
MessageContent::ToolConfirmationRequest(_) => {
// TODO
}
MessageContent::Image(_) => {
// TODO
continue;
}
MessageContent::ToolResponse(_) => {
// Tool responses should only come from the user
Expand Down Expand Up @@ -308,10 +310,10 @@ async fn handler(
}
};

let mut stream = match agent.reply(&messages).await {
let mut stream = match agent.reply(&messages, Some("auto".to_string())).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
tracing::error!("Failed to start reply stream: {:?}", e);
let _ = tx
.send(ProtocolFormatter::format_error(&e.to_string()))
.await;
Expand Down Expand Up @@ -395,10 +397,10 @@ async fn ask_handler(

// Get response from agent
let mut response_text = String::new();
let mut stream = match agent.reply(&messages).await {
let mut stream = match agent.reply(&messages, Some("auto".to_string())).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
tracing::error!("Failed to start reply stream: {:?}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
Expand Down
5 changes: 4 additions & 1 deletion crates/goose/examples/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ async fn main() {
let messages = vec![Message::user()
.with_text("can you summarize the readme.md in this dir using just a haiku?")];

let mut stream = agent.reply(&messages).await.unwrap();
let mut stream = agent
.reply(&messages, Some("auto".to_string()))
.await
.unwrap();
while let Some(message) = stream.next().await {
println!(
"{}",
Expand Down
9 changes: 8 additions & 1 deletion crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ use crate::providers::base::ProviderUsage;
#[async_trait]
pub trait Agent: Send + Sync {
/// Create a stream that yields each message as it's generated by the agent
async fn reply(&self, messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>>;
async fn reply(
&self,
messages: &[Message],
goose_mode: Option<String>,
) -> Result<BoxStream<'_, Result<Message>>>;

/// Add a new MCP client to the agent
async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()>;
Expand All @@ -32,6 +36,9 @@ pub trait Agent: Send + Sync {
/// Add custom text to be included in the system prompt
async fn extend_system_prompt(&mut self, extension: String);

/// Handle a confirmation response for a tool request
async fn handle_confirmation(&self, request_id: String, confirmed: bool);

/// Override the system prompt with custom text
async fn override_system_prompt(&mut self, template: String);
}
5 changes: 5 additions & 0 deletions crates/goose/src/agents/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,15 @@ impl Agent for ReferenceAgent {
Ok(Value::Null)
}

async fn handle_confirmation(&self, _request_id: String, _confirmed: bool) {
// TODO implement
}

#[instrument(skip(self, messages), fields(user_message))]
async fn reply(
&self,
messages: &[Message],
_goose_mode: Option<String>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
let mut messages = messages.to_vec();
let reply_span = tracing::Span::current();
Expand Down
106 changes: 87 additions & 19 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/// It makes no attempt to handle context limits, and cannot read resources
use async_trait::async_trait;
use futures::stream::BoxStream;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tracing::{debug, error, instrument, warn};

Expand All @@ -16,7 +17,7 @@ use crate::register_agent;
use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation};
use indoc::indoc;
use mcp_core::tool::Tool;
use mcp_core::{tool::Tool, Content};
use serde_json::{json, Value};

const MAX_TRUNCATION_ATTEMPTS: usize = 3;
Expand All @@ -26,14 +27,21 @@ const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
pub struct TruncateAgent {
capabilities: Mutex<Capabilities>,
token_counter: TokenCounter,
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
}

impl TruncateAgent {
pub fn new(provider: Box<dyn Provider>) -> Self {
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
// Create channel with buffer size 32 (adjust if needed)
let (tx, rx) = mpsc::channel(32);

Self {
capabilities: Mutex::new(Capabilities::new(provider)),
token_counter,
confirmation_tx: tx,
confirmation_rx: Mutex::new(rx),
}
}

Expand Down Expand Up @@ -121,10 +129,18 @@ impl Agent for TruncateAgent {
Ok(Value::Null)
}

/// Handle a confirmation response for a tool request
async fn handle_confirmation(&self, request_id: String, confirmed: bool) {
if let Err(e) = self.confirmation_tx.send((request_id, confirmed)).await {
error!("Failed to send confirmation: {}", e);
}
}

#[instrument(skip(self, messages), fields(user_message))]
async fn reply(
&self,
messages: &[Message],
goose_mode: Option<String>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
let mut messages = messages.to_vec();
let reply_span = tracing::Span::current();
Expand Down Expand Up @@ -191,7 +207,6 @@ impl Agent for TruncateAgent {
Ok(Box::pin(async_stream::try_stream! {
let _reply_guard = reply_span.enter();
loop {
// Attempt to get completion from provider
match capabilities.provider().complete(
&system_prompt,
&messages,
Expand All @@ -218,24 +233,77 @@ impl Agent for TruncateAgent {
break;
}

// Then dispatch each in parallel
let futures: Vec<_> = tool_requests
.iter()
.filter_map(|request| request.tool_call.clone().ok())
.map(|tool_call| capabilities.dispatch_tool_call(tool_call))
.collect();

// Process all the futures in parallel but wait until all are finished
let outputs = futures::future::join_all(futures).await;

// Create a message with the responses
// Process each tool request sequentially, asking for confirmation
let mut message_tool_response = Message::user();
// Now combine these into MessageContent::ToolResponse using the original ID
for (request, output) in tool_requests.iter().zip(outputs.into_iter()) {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
// Clone goose_mode once before the match to avoid move issues
let mode = goose_mode.clone().unwrap_or_else(|| "auto".to_string());
match mode.as_str() {
"approve" => {
// Process each tool request sequentially with confirmation
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
);
yield confirmation;

// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
if let Some((req_id, confirmed)) = rx.recv().await {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to add any waiting deadline?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for cli, we shouldn't add a timeout imo, but for GUI, it should be considered

if req_id == request.id {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious, will we have the case that req_id != request.id

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, the identical message id should be passed back.

if confirmed {
// User approved - dispatch the tool call
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
} else {
// User declined - add declined response
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text("User declined to run this tool.")]),
);
}
}
}
}
}
},
"chat" => {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think instead of skipping tool calls in the loop, we really want to not advertise the tools to the agent at all? e.g. remove them from the call to provider.complete - otherwise i could see this leading to confusing dialogue

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Subtle detail, but the ToolConfirmationRequest is not added to the agent loop and is not visible to the agent. Added some context in the chat mode to allow the agent to frame the tool call requests as plans.

// Skip all tool calls in chat mode
for request in &tool_requests {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text("Tool call skipped in Goose chat mode")]),
);
}
},
_ => {
if mode != "auto" {
warn!("Unknown GOOSE_MODE: {mode:?}. Defaulting to 'auto' mode.");
}
// Process tool requests in parallel
let mut tool_futures = Vec::new();
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
tool_futures.push(async {
let output = capabilities.dispatch_tool_call(tool_call).await;
(request.id.clone(), output)
});
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
}
}

yield message_tool_response.clone();
Expand Down
Loading
Loading