diff --git a/crates/goose-bench/Cargo.toml b/crates/goose-bench/Cargo.toml new file mode 100644 index 000000000..78fac6b22 --- /dev/null +++ b/crates/goose-bench/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "goose-bench" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +description.workspace = true + + +[dependencies] +anyhow = "1.0" +paste = "1.0" +ctor = "0.2.7" +goose = { path = "../goose" } +async-trait = "0.1.86" + +[target.'cfg(target_os = "windows")'.dependencies] +winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose-bench/src/eval_suites/core/complex_tasks/flappy_bird.rs b/crates/goose-bench/src/eval_suites/core/complex_tasks/flappy_bird.rs new file mode 100644 index 000000000..61d345355 --- /dev/null +++ b/crates/goose-bench/src/eval_suites/core/complex_tasks/flappy_bird.rs @@ -0,0 +1,22 @@ +use crate::eval_suites::{BenchAgent, Evaluation, EvaluationMetric}; +use crate::register_evaluation; +use async_trait::async_trait; + +pub struct FlappyBird {} + +impl FlappyBird { + pub fn new() -> Self { + FlappyBird {} + } +} + +#[async_trait] +impl Evaluation for FlappyBird { + async fn run(&self, mut agent: Box) -> anyhow::Result> { + let metrics = Vec::new(); + let _ = agent.prompt("What can you do?".to_string()).await; + Ok(metrics) + } +} + +register_evaluation!("core", FlappyBird); diff --git a/crates/goose-bench/src/eval_suites/core/complex_tasks/mod.rs b/crates/goose-bench/src/eval_suites/core/complex_tasks/mod.rs new file mode 100644 index 000000000..c09f5510f --- /dev/null +++ b/crates/goose-bench/src/eval_suites/core/complex_tasks/mod.rs @@ -0,0 +1 @@ +mod flappy_bird; diff --git a/crates/goose-bench/src/eval_suites/core/mod.rs b/crates/goose-bench/src/eval_suites/core/mod.rs new file mode 100644 index 000000000..a1efebf95 --- /dev/null +++ b/crates/goose-bench/src/eval_suites/core/mod.rs @@ -0,0 +1 @@ +mod complex_tasks; diff --git a/crates/goose-bench/src/eval_suites/evaluation.rs b/crates/goose-bench/src/eval_suites/evaluation.rs new file mode 100644 index 000000000..87890b772 --- /dev/null +++ b/crates/goose-bench/src/eval_suites/evaluation.rs @@ -0,0 +1,24 @@ +use anyhow::Result; +use async_trait::async_trait; +use goose::message::Message; + +pub type Model = (String, String); +pub type Extension = String; + +#[derive(Debug)] +pub enum EvaluationMetric { + Integer(i64), + Float(f64), + String(String), + Boolean(bool), +} + +#[async_trait] +pub trait BenchAgent: Send + Sync { + async fn prompt(&mut self, p: String) -> Result>; +} + +#[async_trait] +pub trait Evaluation: Send + Sync { + async fn run(&self, agent: Box) -> Result>; +} diff --git a/crates/goose-bench/src/eval_suites/factory.rs b/crates/goose-bench/src/eval_suites/factory.rs new file mode 100644 index 000000000..0f361ac44 --- /dev/null +++ b/crates/goose-bench/src/eval_suites/factory.rs @@ -0,0 +1,65 @@ +pub use super::Evaluation; +use std::collections::HashMap; +use std::sync::{OnceLock, RwLock}; + +type EvaluationConstructor = fn() -> Box; + +// Use std::sync::RwLock for interior mutability +static EVALUATION_REGISTRY: OnceLock>>> = + OnceLock::new(); + +/// Initialize the registry if it hasn't been initialized +fn registry() -> &'static RwLock>> { + EVALUATION_REGISTRY.get_or_init(|| RwLock::new(HashMap::new())) +} + +/// Register a new evaluation version +pub fn register_evaluation(suite_name: &'static str, constructor: fn() -> Box) { + let registry = registry(); + if let Ok(mut map) = registry.write() { + map.entry(suite_name) + .or_insert_with(Vec::new) + .push(constructor); + } +} + +pub struct EvaluationSuiteFactory; + +impl EvaluationSuiteFactory { + pub fn create(suite_name: &str) -> Option>> { + let registry = registry(); + let map = registry + .read() + .expect("Failed to read the benchmark evaluation registry."); + + let constructors = map.get(suite_name)?; + let instances = constructors + .iter() + .map(|&constructor| constructor()) + .collect::>(); + + Some(instances) + } + + pub fn available_evaluations() -> Vec<&'static str> { + registry() + .read() + .map(|map| map.keys().copied().collect()) + .unwrap_or_default() + } +} + +#[macro_export] +macro_rules! register_evaluation { + ($suite_name:expr, $evaluation_type:ty) => { + paste::paste! { + #[ctor::ctor] + #[allow(non_snake_case)] + fn [<__register_evaluation_ $suite_name>]() { + $crate::eval_suites::factory::register_evaluation($suite_name, || { + Box::new(<$evaluation_type>::new()) + }); + } + } + }; +} diff --git a/crates/goose-bench/src/eval_suites/mod.rs b/crates/goose-bench/src/eval_suites/mod.rs new file mode 100644 index 000000000..82404e34b --- /dev/null +++ b/crates/goose-bench/src/eval_suites/mod.rs @@ -0,0 +1,6 @@ +mod core; +mod evaluation; +mod factory; + +pub use evaluation::*; +pub use factory::{register_evaluation, EvaluationSuiteFactory}; diff --git a/crates/goose-bench/src/lib.rs b/crates/goose-bench/src/lib.rs new file mode 100644 index 000000000..2881661ea --- /dev/null +++ b/crates/goose-bench/src/lib.rs @@ -0,0 +1 @@ +pub mod eval_suites; diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 41cb85a60..a4554fba2 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -13,6 +13,7 @@ path = "src/main.rs" [dependencies] goose = { path = "../goose" } +goose-bench = { path = "../goose-bench" } goose-mcp = { path = "../goose-mcp" } mcp-client = { path = "../mcp-client" } mcp-server = { path = "../mcp-server" } @@ -47,6 +48,7 @@ chrono = "0.4" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] } tracing-appender = "0.2" once_cell = "1.20.2" +async-trait = "0.1.86" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose-cli/src/commands/bench.rs b/crates/goose-cli/src/commands/bench.rs new file mode 100644 index 000000000..59be398f9 --- /dev/null +++ b/crates/goose-cli/src/commands/bench.rs @@ -0,0 +1,34 @@ +use crate::session::build_session; +use crate::Session; +use async_trait::async_trait; +use goose::message::Message; +use goose_bench::eval_suites::{BenchAgent, EvaluationSuiteFactory}; + +#[async_trait] +impl BenchAgent for Session { + async fn prompt(&mut self, p: String) -> anyhow::Result> { + self.headless_start(p).await?; + Ok(self.message_history()) + } +} + +pub async fn run_benchmark(suites: Vec) { + let suites = EvaluationSuiteFactory::available_evaluations() + .into_iter() + .filter(|&s| suites.contains(&s.to_string())) + .collect::>(); + + for suite in suites { + let evaluations = match EvaluationSuiteFactory::create(suite) { + Some(evaluations) => evaluations, + None => continue, + }; + for evaluation in evaluations { + let session = build_session(None, false, Vec::new(), Vec::new()).await; + let _ = match evaluation.run(Box::new(session)).await { + Ok(report) => report, + _ => continue, + }; + } + } +} diff --git a/crates/goose-cli/src/commands/mod.rs b/crates/goose-cli/src/commands/mod.rs index e9ed50ce5..b702cddea 100644 --- a/crates/goose-cli/src/commands/mod.rs +++ b/crates/goose-cli/src/commands/mod.rs @@ -1,3 +1,4 @@ pub mod agent_version; +pub mod bench; pub mod configure; pub mod mcp; diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index 5443e142c..e53883b83 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -4,6 +4,7 @@ use clap::{CommandFactory, Parser, Subcommand}; use console::style; use goose::config::Config; use goose_cli::commands::agent_version::AgentCommand; +use goose_cli::commands::bench::run_benchmark; use goose_cli::commands::configure::handle_configure; use goose_cli::commands::mcp::run_server; use goose_cli::logging::setup_logging; @@ -140,6 +141,19 @@ enum Command { /// List available agent versions Agents(AgentCommand), + + /// Run benchmark suite + Bench { + #[arg( + short = 's', + long = "suites", + value_name = "BENCH_SUITE_NAME", + help = "Run this list of bench-suites.", + long_help = "Specify a comma-separated list of evaluation-suite names to be run.", + value_delimiter = ',' + )] + suites: Vec, + }, } #[derive(clap::ValueEnum, Clone, Debug)] @@ -207,6 +221,15 @@ async fn main() -> Result<()> { cmd.run()?; return Ok(()); } + Some(Command::Bench { suites }) => { + let suites = if suites.is_empty() { + vec!["core".to_string()] + } else { + suites + }; + run_benchmark(suites).await; + return Ok(()); + } None => { Cli::command().print_help()?; println!(); diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 2a65ef058..37acf739f 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -299,4 +299,8 @@ impl Session { pub fn session_file(&self) -> PathBuf { self.session_file.clone() } + + pub fn message_history(&self) -> Vec { + self.messages.clone() + } }