diff --git a/Cargo.lock b/Cargo.lock index 959019b4..85e35fa4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12,6 +12,189 @@ dependencies = [ "regex", ] +[[package]] +name = "actix-codec" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f7b0a21988c1bf877cf4759ef5ddaac04c1c9fe808c9142ecb78ba97d97a28a" +dependencies = [ + "bitflags 2.8.0", + "bytes", + "futures-core", + "futures-sink", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "actix-http" +version = "3.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d48f96fc3003717aeb9856ca3d02a8c7de502667ad76eeacd830b48d2e91fac4" +dependencies = [ + "actix-codec", + "actix-rt", + "actix-service", + "actix-utils", + "ahash 0.8.11", + "base64 0.22.1", + "bitflags 2.8.0", + "brotli", + "bytes", + "bytestring", + "derive_more 0.99.19", + "encoding_rs", + "flate2", + "futures-core", + "h2 0.3.26", + "http 0.2.12", + "httparse", + "httpdate", + "itoa", + "language-tags", + "local-channel", + "mime", + "percent-encoding", + "pin-project-lite", + "rand", + "sha1", + "smallvec", + "tokio", + "tokio-util", + "tracing", + "zstd 0.13.2", +] + +[[package]] +name = "actix-macros" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01ed3140b2f8d422c68afa1ed2e85d996ea619c988ac834d255db32138655cb" +dependencies = [ + "quote", + "syn 2.0.98", +] + +[[package]] +name = "actix-router" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13d324164c51f63867b57e73ba5936ea151b8a41a1d23d1031eeb9f70d0236f8" +dependencies = [ + "bytestring", + "cfg-if", + "http 0.2.12", + "regex", + "regex-lite", + "serde", + "tracing", +] + +[[package]] +name = "actix-rt" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24eda4e2a6e042aa4e55ac438a2ae052d3b5da0ecf83d7411e1a368946925208" +dependencies = [ + "futures-core", + "tokio", +] + +[[package]] +name = "actix-server" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ca2549781d8dd6d75c40cf6b6051260a2cc2f3c62343d761a969a0640646894" +dependencies = [ + "actix-rt", + "actix-service", + "actix-utils", + "futures-core", + "futures-util", + "mio", + "socket2 0.5.8", + "tokio", + "tracing", +] + +[[package]] +name = "actix-service" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b894941f818cfdc7ccc4b9e60fa7e53b5042a2e8567270f9147d5591893373a" +dependencies = [ + "futures-core", + "paste", + "pin-project-lite", +] + +[[package]] +name = "actix-utils" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88a1dcdff1466e3c2488e1cb5c36a71822750ad43839937f85d2f4d9f8b705d8" +dependencies = [ + "local-waker", + "pin-project-lite", +] + +[[package]] +name = "actix-web" +version = "4.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9180d76e5cc7ccbc4d60a506f2c727730b154010262df5b910eb17dbe4b8cb38" +dependencies = [ + "actix-codec", + "actix-http", + "actix-macros", + "actix-router", + "actix-rt", + "actix-server", + "actix-service", + "actix-utils", + "actix-web-codegen", + "ahash 0.8.11", + "bytes", + "bytestring", + "cfg-if", + "cookie", + "derive_more 0.99.19", + "encoding_rs", + "futures-core", + "futures-util", + "impl-more", + "itoa", + "language-tags", + "log", + "mime", + "once_cell", + "pin-project-lite", + "regex", + "regex-lite", + "serde", + "serde_json", + "serde_urlencoded", + "smallvec", + "socket2 0.5.8", + "time", + "url", +] + +[[package]] +name = "actix-web-codegen" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f591380e2e68490b5dfaf1dd1aa0ebe78d84ba7067078512b4ea6e4492d622b8" +dependencies = [ + "actix-router", + "proc-macro2", + "quote", + "syn 2.0.98", +] + [[package]] name = "addr" version = "0.15.6" @@ -87,6 +270,21 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -1535,6 +1733,27 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "brotli" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bs58" version = "0.5.1" @@ -1659,6 +1878,15 @@ dependencies = [ "either", ] +[[package]] +name = "bytestring" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e465647ae23b2823b0753f50decb2d5a86d2bb2cac04788fafd1f80e45378e5f" +dependencies = [ + "bytes", +] + [[package]] name = "bzip2" version = "0.4.4" @@ -2083,6 +2311,17 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "cookie" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e859cd57d0710d9e06c381b550c06e76992472a8c6d527aecd2fc673dcc231fb" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -3408,6 +3647,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "exr" version = "1.73.0" @@ -4489,6 +4739,22 @@ dependencies = [ "tokio-native-tls", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.10" @@ -4764,6 +5030,12 @@ dependencies = [ "parity-scale-codec", ] +[[package]] +name = "impl-more" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a5a9a0ff0086c7a148acb942baaabeadf9504d10400b5a05645853729b9cd2" + [[package]] name = "impl-rlp" version = "0.3.0" @@ -5486,6 +5758,12 @@ dependencies = [ "url", ] +[[package]] +name = "language-tags" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4345964bb142484797b161f473a503a434de77149dd8c7427788c6e13379388" + [[package]] name = "lazy_static" version = "1.5.0" @@ -5588,9 +5866,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.169" +version = "0.2.170" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" [[package]] name = "libfuzzer-sys" @@ -5666,6 +5944,23 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" +[[package]] +name = "local-channel" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6cbc85e69b8df4b8bb8b89ec634e7189099cea8927a276b7384ce5488e53ec8" +dependencies = [ + "futures-core", + "futures-sink", + "local-waker", +] + +[[package]] +name = "local-waker" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d873d7c67ce09b42110d801813efbc9364414e356be9935700d368351657487" + [[package]] name = "lock_api" version = "0.4.12" @@ -5885,6 +6180,27 @@ dependencies = [ "rayon", ] +[[package]] +name = "mcp-core" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5df4144ca06f51aaeb770f5b09ec04715b12e031540ed08f382d3e36b20a13" +dependencies = [ + "actix-web", + "anyhow", + "async-trait", + "futures", + "libc", + "reqwest 0.12.12", + "reqwest-eventsource", + "serde", + "serde_json", + "tokio", + "tracing", + "url", + "uuid 1.13.1", +] + [[package]] name = "md-5" version = "0.10.6" @@ -5982,6 +6298,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", + "log", "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -7792,7 +8109,7 @@ dependencies = [ "http-body 0.4.6", "hyper 0.14.32", "hyper-rustls 0.24.2", - "hyper-tls", + "hyper-tls 0.5.0", "ipnet", "js-sys", "log", @@ -7808,7 +8125,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 0.1.2", - "system-configuration", + "system-configuration 0.5.1", "tokio", "tokio-native-tls", "tokio-rustls 0.24.1", @@ -7831,6 +8148,7 @@ checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", "h2 0.4.7", @@ -7839,12 +8157,14 @@ dependencies = [ "http-body-util", "hyper 1.6.0", "hyper-rustls 0.27.5", + "hyper-tls 0.6.0", "hyper-util", "ipnet", "js-sys", "log", "mime", "mime_guess", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -7857,7 +8177,9 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 1.0.2", + "system-configuration 0.6.1", "tokio", + "tokio-native-tls", "tokio-rustls 0.26.1", "tokio-util", "tower 0.5.2", @@ -7871,6 +8193,22 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest 0.12.12", + "thiserror 1.0.69", +] + [[package]] name = "resolv-conf" version = "0.7.0" @@ -7942,6 +8280,7 @@ dependencies = [ "futures", "glob", "lopdf", + "mcp-core", "ordered-float", "quick-xml 0.37.2", "rayon", @@ -9816,7 +10155,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", "core-foundation 0.9.4", - "system-configuration-sys", + "system-configuration-sys 0.5.0", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.8.0", + "core-foundation 0.9.4", + "system-configuration-sys 0.6.0", ] [[package]] @@ -9829,6 +10179,16 @@ dependencies = [ "libc", ] +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "system-deps" version = "6.2.2" diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 7b080b65..d141deb9 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -30,6 +30,7 @@ epub = { version = "2.1.2", optional = true } quick-xml = { version = "0.37.2", optional = true } rayon = { version = "1.10.0", optional = true } worker = { version = "0.5", optional = true } +mcp-core = { version = "0.1.33", optional = true } bytes = "1.9.0" async-stream = "0.3.6" @@ -42,6 +43,7 @@ tracing-subscriber = "0.3.18" tokio-test = "0.4.4" serde_path_to_error = "0.1.16" base64 = "0.22.1" +mcp-core = { version = "0.1.33", features = ["sse_server"] } [features] all = ["derive", "pdf", "rayon"] @@ -50,6 +52,8 @@ pdf = ["dep:lopdf"] epub = ["dep:epub", "dep:quick-xml"] rayon = ["dep:rayon"] worker = ["dep:worker"] +mcp = ["dep:mcp-core"] + [[test]] name = "embed_macro" @@ -94,3 +98,7 @@ required-features = ["derive"] [[example]] name = "together_embeddings" required-features = ["derive"] + +[[example]] +name = "mcp_tool" +required-features = ["mcp"] diff --git a/rig-core/examples/mcp_tool.rs b/rig-core/examples/mcp_tool.rs new file mode 100644 index 00000000..93874e92 --- /dev/null +++ b/rig-core/examples/mcp_tool.rs @@ -0,0 +1,135 @@ +use serde_json::json; + +use rig::{ + completion::Prompt, + providers::{self}, +}; + +use mcp_core::{ + client::ClientBuilder, + server::Server, + tool_error_response, tool_text_response, + tools::ToolHandlerFn, + transport::{ClientSseTransportBuilder, ServerSseTransport}, + types::{ + CallToolRequest, CallToolResponse, ClientCapabilities, Implementation, ServerCapabilities, + Tool, ToolResponseContent, + }, +}; + +pub struct AddTool; + +impl AddTool { + pub fn tool() -> Tool { + Tool { + name: "Add".to_string(), + description: Some("Adds two numbers together.".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "The first number to add" + }, + "b": { + "type": "number", + "description": "The second number to add" + } + }, + "required": [ + "a", + "b" + ] + }), + } + } + + pub async fn call() -> ToolHandlerFn { + move |req: CallToolRequest| { + Box::pin(async move { + let args = req.arguments.unwrap_or_default(); + + let a = match args["a"].as_f64() { + Some(val) => val, + None => { + return tool_error_response!(anyhow::anyhow!( + "Missing or invalid 'a' parameter" + )) + } + }; + let b = match args["b"].as_f64() { + Some(val) => val, + None => { + return tool_error_response!(anyhow::anyhow!( + "Missing or invalid 'b' parameter" + )) + } + }; + + tool_text_response!((a + b).to_string()) + }) + } + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt().init(); + + // Create the MCP server + let mcp_server_protocol = Server::builder("add".to_string(), "1.0".to_string()) + .capabilities(ServerCapabilities { + tools: Some(json!({ + "listChanged": false, + })), + ..Default::default() + }) + .register_tool(AddTool::tool(), AddTool::call().await) + .build(); + let mcp_server_transport = + ServerSseTransport::new("127.0.0.1".to_string(), 3000, mcp_server_protocol); + + // Start the MCP server in the background + tokio::spawn(async move { Server::start(mcp_server_transport).await }); + + // Create the MCP client + let mcp_client = ClientBuilder::new( + ClientSseTransportBuilder::new("http://localhost:3000".to_string()).build(), + ) + .build(); + // Start the MCP client + mcp_client.open().await?; + + let init_res = mcp_client + .initialize( + Implementation { + name: "mcp-client".to_string(), + version: "0.1.0".to_string(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Initialized: {:?}", init_res); + + let tools_list_res = mcp_client.list_tools(None, None).await?; + println!("Tools: {:?}", tools_list_res); + + tracing::info!("Building RIG agent"); + let completion_model = providers::openai::Client::from_env(); + let mut agent_builder = completion_model.agent("gpt-4o"); + + // Add MCP tools to the agent + agent_builder = tools_list_res + .tools + .into_iter() + .fold(agent_builder, |builder, tool| { + builder.mcp_tool(tool, mcp_client.clone().into()) + }); + let agent = agent_builder.build(); + + tracing::info!("Prompting RIG agent"); + let response = agent.prompt("Add 10 + 10").await?; + tracing::info!("Agent response: {:?}", response); + + Ok(()) +} diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index b2a299e6..630bb5eb 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -124,6 +124,9 @@ use crate::{ vector_store::{VectorStoreError, VectorStoreIndexDyn}, }; +#[cfg(feature = "mcp")] +use crate::tool::McpTool; + /// Struct representing an LLM agent. An agent is an LLM model combined with a preamble /// (i.e.: system prompt) and a static set of context documents and tools. /// All context documents and tools are always provided to the agent when prompted. @@ -412,6 +415,19 @@ impl AgentBuilder { self } + // Add an MCP tool to the agent + #[cfg(feature = "mcp")] + pub fn mcp_tool( + mut self, + tool: mcp_core::types::Tool, + client: mcp_core::client::Client, + ) -> Self { + let toolname = tool.name.clone(); + self.tools.add_tool(McpTool::from_mcp_server(tool, client)); + self.static_tools.push(toolname); + self + } + /// Add some dynamic context to the agent. On each prompt, `sample` documents from the /// dynamic context will be inserted in the request. pub fn dynamic_context( diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index e540c918..8c4b8ccb 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -185,6 +185,115 @@ impl ToolDyn for T { } } +#[cfg(feature = "mcp")] +pub struct McpTool { + definition: mcp_core::types::Tool, + client: mcp_core::client::Client, +} + +#[cfg(feature = "mcp")] +impl McpTool +where + T: mcp_core::transport::Transport, +{ + pub fn from_mcp_server( + definition: mcp_core::types::Tool, + client: mcp_core::client::Client, + ) -> Self { + Self { definition, client } + } +} + +#[cfg(feature = "mcp")] +#[derive(Debug, thiserror::Error)] +#[error("MCP tool error: {0}")] +pub struct McpToolError(String); + +#[cfg(feature = "mcp")] +impl From for ToolError { + fn from(e: McpToolError) -> Self { + ToolError::ToolCallError(Box::new(e)) + } +} + +#[cfg(feature = "mcp")] +impl ToolDyn for McpTool +where + T: mcp_core::transport::Transport, +{ + fn name(&self) -> String { + self.definition.name.clone() + } + + fn definition( + &self, + _prompt: String, + ) -> Pin + Send + Sync + '_>> { + Box::pin(async move { + ToolDefinition { + name: self.definition.name.clone(), + description: match &self.definition.description { + Some(desc) => desc.clone(), + None => String::new(), + }, + parameters: serde_json::to_value(&self.definition.input_schema).unwrap_or_default(), + } + }) + } + + fn call( + &self, + args: String, + ) -> Pin> + Send + Sync + '_>> { + let name = self.definition.name.clone(); + let args_clone = args.clone(); + let args: serde_json::Value = serde_json::from_str(&args_clone).unwrap_or_default(); + Box::pin(async move { + let result = self + .client + .call_tool(&name, Some(args)) + .await + .map_err(|e| McpToolError(format!("Tool returned an error: {}", e)))?; + + if result.is_error.unwrap_or(false) { + if let Some(error) = result.content.first() { + match error { + mcp_core::types::ToolResponseContent::Text { text } => { + return Err(McpToolError(text.clone()).into()); + } + _ => return Err(McpToolError("Unsuppported error type".to_string()).into()), + } + } else { + return Err(McpToolError("No error message returned".to_string()).into()); + } + } + + Ok(result + .content + .into_iter() + .map(|c| match c { + mcp_core::types::ToolResponseContent::Text { text } => text, + mcp_core::types::ToolResponseContent::Image { data, mime_type } => { + format!("data:{};base64,{}", mime_type, data) + } + mcp_core::types::ToolResponseContent::Resource { + resource: mcp_core::types::ResourceContents { uri, mime_type }, + } => { + format!( + "{}{}", + mime_type + .map(|m| format!("data:{};", m)) + .unwrap_or_default(), + uri + ) + } + }) + .collect::>() + .join("")) + }) + } +} + /// Wrapper trait to allow for dynamic dispatch of raggable tools pub trait ToolEmbeddingDyn: ToolDyn { fn context(&self) -> serde_json::Result;