Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

data model #1276

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions crates/goose-bench/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "goose-bench"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
description.workspace = true


[dependencies]
# Error/result type used by the `BenchAgent` and `Evaluation` trait methods.
anyhow = "1.0"
# `paste` and `ctor` are both expanded inside the `register_evaluation!` macro.
paste = "1.0"
ctor = "0.2.7"
# Supplies `goose::message::Message`, the element type returned by `BenchAgent::prompt`.
goose = { path = "../goose" }
# Allows `async fn` in the `BenchAgent` and `Evaluation` traits.
async-trait = "0.1.86"

# NOTE(review): no use of `winapi` is visible in the goose-bench sources of this change — confirm it is needed.
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use crate::eval_suites::{BenchAgent, Evaluation, EvaluationMetric};
use crate::register_evaluation;
use async_trait::async_trait;

/// Stateless evaluation type: it has no fields, so every instance is interchangeable.
#[derive(Debug, Default)]
pub struct FlappyBird {}

impl FlappyBird {
    /// Creates a new `FlappyBird` evaluation.
    ///
    /// Equivalent to [`FlappyBird::default`]. The explicit constructor is kept because the
    /// `register_evaluation!` macro instantiates evaluations through `<Type>::new()`.
    pub fn new() -> Self {
        Self::default()
    }
}

#[async_trait]
impl Evaluation for FlappyBird {
    /// Runs the evaluation by sending one fixed prompt to `agent`.
    ///
    /// No metrics are collected yet, so a successful run returns an empty vector.
    ///
    /// # Errors
    ///
    /// Returns the agent's error when the prompt fails, so that a failed run is not reported
    /// as a success with zero metrics.
    async fn run(&self, mut agent: Box<dyn BenchAgent>) -> anyhow::Result<Vec<EvaluationMetric>> {
        // The reply itself is not inspected; only whether the prompt succeeded matters here.
        agent.prompt("What can you do?".to_string()).await?;
        Ok(Vec::new())
    }
}

// Adds `FlappyBird` to the "core" suite: the macro expands to a `#[ctor::ctor]` function that
// passes a constructor closure for this type to the registry's `register_evaluation`.
register_evaluation!("core", FlappyBird);
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mod flappy_bird;
1 change: 1 addition & 0 deletions crates/goose-bench/src/eval_suites/core/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mod complex_tasks;
24 changes: 24 additions & 0 deletions crates/goose-bench/src/eval_suites/evaluation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use anyhow::Result;
use async_trait::async_trait;
use goose::message::Message;

/// A pair of strings identifying a model.
// NOTE(review): the meaning of each element is not established in this file — presumably
// (provider, model name); confirm against the call sites.
pub type Model = (String, String);
/// An extension, represented as a plain string.
// NOTE(review): presumably the extension's name; it is not used in this file — confirm.
pub type Extension = String;

/// A single measurement produced by an evaluation run.
///
/// `Clone` and `PartialEq` are derived so that reports can be copied and compared.
/// `Eq` cannot be derived because of the `f64` payload.
#[derive(Debug, Clone, PartialEq)]
pub enum EvaluationMetric {
    /// A signed integer measurement.
    Integer(i64),
    /// A floating-point measurement.
    Float(f64),
    /// A free-form textual measurement.
    String(String),
    /// A true/false measurement.
    Boolean(bool),
}

/// An agent that a benchmark evaluation can drive with text prompts.
#[async_trait]
pub trait BenchAgent: Send + Sync {
// NOTE(review): the `Session` implementation returns its whole message history rather than
// only the new replies — confirm other implementers are expected to do the same.
/// Sends the prompt `p` to the agent and returns a list of messages, or an error.
async fn prompt(&mut self, p: String) -> Result<Vec<Message>>;
}

/// A benchmark task that is run against a boxed `BenchAgent`.
#[async_trait]
pub trait Evaluation: Send + Sync {
/// Runs the evaluation, taking ownership of `agent`, and returns the metrics it produced
/// or an error.
async fn run(&self, agent: Box<dyn BenchAgent>) -> Result<Vec<EvaluationMetric>>;
}
65 changes: 65 additions & 0 deletions crates/goose-bench/src/eval_suites/factory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
pub use super::Evaluation;
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};

/// Zero-argument function pointer that builds a boxed evaluation.
type EvaluationConstructor = fn() -> Box<dyn Evaluation>;

// Process-wide map from suite name to the constructors registered under that name.
// `OnceLock` defers creating the map until first use (see `registry`), and `RwLock`
// provides the interior mutability needed to insert into a `static`.
static EVALUATION_REGISTRY: OnceLock<RwLock<HashMap<&'static str, Vec<EvaluationConstructor>>>> =
OnceLock::new();

/// Returns the global evaluation registry, creating an empty one on first access.
fn registry() -> &'static RwLock<HashMap<&'static str, Vec<EvaluationConstructor>>> {
    EVALUATION_REGISTRY.get_or_init(RwLock::default)
}

/// Registers `constructor` under `suite_name` in the global registry.
///
/// Several constructors may be registered under the same suite name; they are stored in
/// the order in which they were registered.
///
/// A poisoned lock is recovered instead of being skipped, so a registration is never
/// silently dropped.
pub fn register_evaluation(suite_name: &'static str, constructor: fn() -> Box<dyn Evaluation>) {
    // Recovering the guard is sound here: the only mutation performed on the map is an
    // entry lookup followed by a push, neither of which can leave it half-updated.
    let mut map = registry()
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    map.entry(suite_name).or_default().push(constructor);
}

/// Builds evaluation instances from the global registry.
pub struct EvaluationSuiteFactory;

impl EvaluationSuiteFactory {
    /// Instantiates every evaluation registered under `suite_name`.
    ///
    /// Returns `None` when nothing has been registered for that suite. A poisoned registry
    /// lock is recovered rather than causing a panic, matching `available_evaluations`.
    pub fn create(suite_name: &str) -> Option<Vec<Box<dyn Evaluation>>> {
        let map = registry()
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        let constructors = map.get(suite_name)?;
        Some(constructors.iter().map(|construct| construct()).collect())
    }

    /// Lists the names of all suites that have at least one registered evaluation.
    ///
    /// The names are sorted so that callers iterate suites in a deterministic order instead
    /// of the arbitrary `HashMap` key order. A poisoned registry lock is recovered rather
    /// than reported as "no suites".
    pub fn available_evaluations() -> Vec<&'static str> {
        let mut names: Vec<&'static str> = registry()
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .keys()
            .copied()
            .collect();
        names.sort_unstable();
        names
    }
}

/// Registers `$evaluation_type` under the suite `$suite_name` by generating a
/// `#[ctor::ctor]` function that calls `eval_suites::factory::register_evaluation`.
///
/// The type must provide an associated `new()` function taking no arguments, because the
/// generated constructor closure calls `<$evaluation_type>::new()`.
///
/// NOTE(review): the generated function's name is built only from `$suite_name`, so two
/// invocations with the same suite name in one module would appear to define the same
/// function — confirm, and consider including the type in the generated name.
#[macro_export]
macro_rules! register_evaluation {
($suite_name:expr, $evaluation_type:ty) => {
paste::paste! {
#[ctor::ctor]
#[allow(non_snake_case)]
fn [<__register_evaluation_ $suite_name>]() {
$crate::eval_suites::factory::register_evaluation($suite_name, || {
Box::new(<$evaluation_type>::new())
});
}
}
};
}
6 changes: 6 additions & 0 deletions crates/goose-bench/src/eval_suites/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
mod core;
mod evaluation;
mod factory;

pub use evaluation::*;
pub use factory::{register_evaluation, EvaluationSuiteFactory};
1 change: 1 addition & 0 deletions crates/goose-bench/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod eval_suites;
2 changes: 2 additions & 0 deletions crates/goose-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ path = "src/main.rs"

[dependencies]
goose = { path = "../goose" }
goose-bench = { path = "../goose-bench" }
goose-mcp = { path = "../goose-mcp" }
mcp-client = { path = "../mcp-client" }
mcp-server = { path = "../mcp-server" }
Expand Down Expand Up @@ -47,6 +48,7 @@ chrono = "0.4"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] }
tracing-appender = "0.2"
once_cell = "1.20.2"
async-trait = "0.1.86"

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }
Expand Down
34 changes: 34 additions & 0 deletions crates/goose-cli/src/commands/bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use crate::session::build_session;
use crate::Session;
use async_trait::async_trait;
use goose::message::Message;
use goose_bench::eval_suites::{BenchAgent, EvaluationSuiteFactory};

/// Lets a CLI `Session` act as the agent driven by benchmark evaluations.
#[async_trait]
impl BenchAgent for Session {
    /// Feeds `p` to the session through `headless_start`, then returns a copy of the
    /// session's message history. Errors from `headless_start` are propagated.
    async fn prompt(&mut self, p: String) -> anyhow::Result<Vec<Message>> {
        self.headless_start(p).await?;
        let history = self.message_history();
        Ok(history)
    }
}

/// Runs every registered evaluation that belongs to one of the requested `suites`.
///
/// Each evaluation is given a freshly built session. Requested names that match no
/// registered suite, and evaluations that return an error, are reported on stderr instead
/// of being skipped silently; neither aborts the remaining work.
pub async fn run_benchmark(suites: Vec<String>) {
    let available = EvaluationSuiteFactory::available_evaluations();

    // Tell the user about typos or unregistered suites rather than quietly doing nothing.
    for requested in &suites {
        if !available.iter().any(|name| *name == requested.as_str()) {
            eprintln!("Unknown benchmark suite '{}', skipping.", requested);
        }
    }

    // Compare as `&str` so no temporary `String` is allocated per candidate suite.
    let selected = available
        .into_iter()
        .filter(|name| suites.iter().any(|requested| requested.as_str() == *name))
        .collect::<Vec<_>>();

    for suite in selected {
        let Some(evaluations) = EvaluationSuiteFactory::create(suite) else {
            continue;
        };

        for evaluation in evaluations {
            let session = build_session(None, false, Vec::new(), Vec::new()).await;
            // The metrics of a successful run are not consumed yet; failures are surfaced.
            if let Err(err) = evaluation.run(Box::new(session)).await {
                eprintln!("Evaluation in suite '{}' failed: {:#}", suite, err);
            }
        }
    }
}
1 change: 1 addition & 0 deletions crates/goose-cli/src/commands/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod agent_version;
pub mod bench;
pub mod configure;
pub mod mcp;
23 changes: 23 additions & 0 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use clap::{CommandFactory, Parser, Subcommand};
use console::style;
use goose::config::Config;
use goose_cli::commands::agent_version::AgentCommand;
use goose_cli::commands::bench::run_benchmark;
use goose_cli::commands::configure::handle_configure;
use goose_cli::commands::mcp::run_server;
use goose_cli::logging::setup_logging;
Expand Down Expand Up @@ -140,6 +141,19 @@ enum Command {

/// List available agent versions
Agents(AgentCommand),

/// Run benchmark suite
Bench {
#[arg(
short = 's',
long = "suites",
value_name = "BENCH_SUITE_NAME",
help = "Run this list of bench-suites.",
long_help = "Specify a comma-separated list of evaluation-suite names to be run.",
value_delimiter = ','
)]
suites: Vec<String>,
},
}

#[derive(clap::ValueEnum, Clone, Debug)]
Expand Down Expand Up @@ -207,6 +221,15 @@ async fn main() -> Result<()> {
cmd.run()?;
return Ok(());
}
Some(Command::Bench { suites }) => {
let suites = if suites.is_empty() {
vec!["core".to_string()]
} else {
suites
};
run_benchmark(suites).await;
return Ok(());
}
None => {
Cli::command().print_help()?;
println!();
Expand Down
4 changes: 4 additions & 0 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,4 +299,8 @@ impl Session {
/// Returns an owned copy of the path stored in `self.session_file`.
pub fn session_file(&self) -> PathBuf {
self.session_file.clone()
}

/// Returns a clone of the session's message list (`self.messages`).
pub fn message_history(&self) -> Vec<Message> {
self.messages.clone()
}
}
Loading