From 7d7257240fcf36af298d899ceb2ffd7b9baed988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Fri, 31 Jan 2025 22:36:29 +0000 Subject: [PATCH 1/9] Add Bedrock provider --- crates/goose/Cargo.toml | 5 + crates/goose/src/providers/bedrock.rs | 332 ++++++++++++++++++++++++++ crates/goose/src/providers/factory.rs | 3 + crates/goose/src/providers/mod.rs | 1 + 4 files changed, 341 insertions(+) create mode 100644 crates/goose/src/providers/bedrock.rs diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 053f55ab3..a706d6a8e 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -61,6 +61,11 @@ once_cell = "1.20.2" dirs = "6.0.0" rand = "0.8.5" +# For Bedrock provider +aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } +aws-smithy-types = "1.2.12" +aws-sdk-bedrockruntime = "1.72.0" + [dev-dependencies] criterion = "0.5" tempfile = "3.15.0" diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs new file mode 100644 index 000000000..c0b9f1823 --- /dev/null +++ b/crates/goose/src/providers/bedrock.rs @@ -0,0 +1,332 @@ +use std::collections::HashMap; + +use anyhow::{anyhow, bail, Result}; +use async_trait::async_trait; +use aws_sdk_bedrockruntime::{types as bedrock, Client}; +use aws_smithy_types::{Document, Number}; +use chrono::Utc; +use mcp_core::{Content, Role, Tool, ToolCall, ToolError, ToolResult}; +use serde_json::Value; + +use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::errors::ProviderError; +use crate::config::Config; +use crate::message::{Message, MessageContent}; +use crate::model::ModelConfig; + +pub const BEDROCK_DOC_LINK: &str = + "https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html"; + +pub const BEDROCK_DEFAULT_MODEL: &str = "anthropic.claude-3-5-sonnet-20240620-v1:0"; +pub const BEDROCK_KNOWN_MODELS: &[&str] = &[ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-5-sonnet-20241022-v2:0", +]; + +#[derive(Debug, serde::Serialize)] +pub struct BedrockProvider { + #[serde(skip)] + client: Client, + model: ModelConfig, +} + +impl BedrockProvider { + pub fn from_env(model: ModelConfig) -> Result { + let config = Config::global(); + let sdk_config = tokio::task::block_in_place(|| { + let mut aws_config = aws_config::from_env(); + + if let Some(region) = config.get::("AWS_REGION").ok() { + aws_config = aws_config.region(aws_config::Region::new(region)); + } + + tokio::runtime::Handle::current().block_on(aws_config.load()) + }); + let client = Client::new(&sdk_config); + + Ok(Self { client, model }) + } +} + +#[async_trait] +impl Provider for BedrockProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + "bedrock", + "Amazon Bedrock", + "Run models through Amazon Bedrock", + BEDROCK_DEFAULT_MODEL, + BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(), + BEDROCK_DOC_LINK, + vec![ConfigKey::new("AWS_REGION", false, false, None)], + ) + } + + fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } + + #[tracing::instrument( + skip(self, system, messages, tools), + fields(model_config, input, output, input_tokens, output_tokens, total_tokens) + )] + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let model_name = &self.model.model_name; + + let response = self + .client + .converse() + .tool_config(to_bedrock_tool_config(tools)?) + .model_id(model_name.to_string()) + .system(bedrock::SystemContentBlock::Text(system.to_string())) + .set_messages(Some( + messages + .iter() + .map(to_bedrock_message) + .collect::>()?, + )) + .send() + .await + .or_else(|err| Err(anyhow!("Failed to call Bedrock: {}", err)))?; + + let message = match response.output { + Some(bedrock::ConverseOutput::Message(message)) => message, + _ => { + return Err(ProviderError::RequestFailed( + "No output from Bedrock".to_string(), + )) + } + }; + + let usage = response + .usage + .as_ref() + .map(from_bedrock_usage) + .unwrap_or_default(); + + let message = from_bedrock_message(&message)?; + let provider_usage = ProviderUsage::new(model_name.to_string(), usage); + + Ok((message, provider_usage)) + } +} + +fn to_bedrock_message(message: &Message) -> Result { + bedrock::Message::builder() + .role(to_bedrock_role(&message.role)) + .set_content(Some( + message + .content + .iter() + .map(to_bedrock_message_content) + .collect::>()?, + )) + .build() + .map_err(|err| anyhow!("Failed to construct Bedrock message: {}", err)) +} + +fn to_bedrock_message_content(content: &MessageContent) -> Result { + Ok(match content { + MessageContent::Text(text) => bedrock::ContentBlock::Text(text.text.to_string()), + MessageContent::Image(_) => { + bail!("Image content is not supported by Bedrock provider yet") + } + MessageContent::ToolRequest(tool_req) => { + let tool_use_id = tool_req.id.to_string(); + let tool_use = if let Some(call) = tool_req.tool_call.as_ref().ok() { + bedrock::ToolUseBlock::builder() + .tool_use_id(tool_use_id) + .name(call.name.to_string()) + .input(to_bedrock_json(&call.arguments)) + .build() + } else { + bedrock::ToolUseBlock::builder() + .tool_use_id(tool_use_id) + .build() + }?; + bedrock::ContentBlock::ToolUse(tool_use) + } + MessageContent::ToolResponse(tool_res) => { + let content = match &tool_res.tool_result { + Ok(content) => Some( + content + .iter() + .map(to_bedrock_tool_result_content_block) + .collect::>()?, + ), + Err(_) => None, + }; + bedrock::ContentBlock::ToolResult( + bedrock::ToolResultBlock::builder() + .tool_use_id(tool_res.id.to_string()) + .status(if content.is_some() { + bedrock::ToolResultStatus::Success + } else { + bedrock::ToolResultStatus::Error + }) + .set_content(content) + .build()?, + ) + } + }) +} + +fn to_bedrock_tool_result_content_block( + content: &Content, +) -> Result { + Ok(match content { + Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()), + Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"), + Content::Resource(_) => bail!("Resource content is not supported by Bedrock provider yet"), + }) +} + +fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole { + match role { + Role::User => bedrock::ConversationRole::User, + Role::Assistant => bedrock::ConversationRole::Assistant, + } +} + +fn to_bedrock_tool_config(tools: &[Tool]) -> Result { + Ok(bedrock::ToolConfiguration::builder() + .set_tools(Some( + tools.iter().map(to_bedrock_tool).collect::>()?, + )) + .build()?) +} + +fn to_bedrock_tool(tool: &Tool) -> Result { + Ok(bedrock::Tool::ToolSpec( + bedrock::ToolSpecification::builder() + .name(tool.name.to_string()) + .description(tool.description.to_string()) + .input_schema(bedrock::ToolInputSchema::Json(to_bedrock_json( + &tool.input_schema, + ))) + .build()?, + )) +} + +fn to_bedrock_json(value: &Value) -> Document { + match value { + Value::Null => Document::Null, + Value::Bool(bool) => Document::Bool(*bool), + Value::Number(num) => { + if let Some(n) = num.as_u64() { + Document::Number(Number::PosInt(n)) + } else if let Some(n) = num.as_i64() { + Document::Number(Number::NegInt(n)) + } else if let Some(n) = num.as_f64() { + Document::Number(Number::Float(n)) + } else { + unreachable!() + } + } + Value::String(str) => Document::String(str.to_string()), + Value::Array(arr) => Document::Array(arr.into_iter().map(to_bedrock_json).collect()), + Value::Object(obj) => Document::Object(HashMap::from_iter( + obj.into_iter() + .map(|(key, val)| (key.to_string(), to_bedrock_json(val))), + )), + } +} + +fn from_bedrock_message(message: &bedrock::Message) -> Result { + let role = from_bedrock_role(message.role())?; + let content = message + .content() + .iter() + .map(from_bedrock_content_block) + .collect::>>()?; + let created = Utc::now().timestamp(); + + Ok(Message { + role, + content, + created, + }) +} + +fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result { + Ok(match block { + bedrock::ContentBlock::Text(text) => MessageContent::text(text), + bedrock::ContentBlock::ToolUse(tool_use) => MessageContent::tool_request( + tool_use.tool_use_id.to_string(), + Ok(ToolCall::new( + tool_use.name.to_string(), + from_bedrock_json(&tool_use.input), + )), + ), + bedrock::ContentBlock::ToolResult(tool_res) => MessageContent::tool_response( + tool_res.tool_use_id.to_string(), + if tool_res.content.is_empty() { + Err(ToolError::ExecutionError( + "Empty content for tool use from Bedrock".to_string(), + )) + } else { + tool_res + .content + .iter() + .map(from_bedrock_tool_result_content_block) + .collect::>>() + }, + ), + _ => bail!("Unsupported content block type from Bedrock"), + }) +} + +fn from_bedrock_tool_result_content_block( + content: &bedrock::ToolResultContentBlock, +) -> ToolResult { + Ok(match content { + bedrock::ToolResultContentBlock::Text(text) => Content::text(text.to_string()), + _ => { + return Err(ToolError::ExecutionError( + "Unsupported tool result from Bedrock".to_string(), + )) + } + }) +} + +fn from_bedrock_role(role: &bedrock::ConversationRole) -> Result { + Ok(match role { + bedrock::ConversationRole::User => Role::User, + bedrock::ConversationRole::Assistant => Role::Assistant, + _ => bail!("Unknown role from Bedrock"), + }) +} + +fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage { + Usage { + input_tokens: Some(usage.input_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(usage.total_tokens), + } +} + +fn from_bedrock_json(document: &Document) -> Value { + match document { + Document::Null => Value::Null, + Document::Bool(bool) => Value::Bool(*bool), + Document::Number(num) => match num { + Number::PosInt(i) => Value::Number((*i).into()), + Number::NegInt(i) => Value::Number((*i).into()), + Number::Float(f) => { + Value::Number(serde_json::Number::from_f64(*f).expect("Expected a valid f64")) + } + }, + Document::String(str) => Value::String(str.clone()), + Document::Array(arr) => Value::Array(arr.iter().map(from_bedrock_json).collect()), + Document::Object(obj) => Value::Object( + obj.iter() + .map(|(key, val)| (key.clone(), from_bedrock_json(val))) + .collect(), + ), + } +} diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index ed169aa7e..d17fb8893 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -2,6 +2,7 @@ use super::{ anthropic::AnthropicProvider, azure::AzureProvider, base::{Provider, ProviderMetadata}, + bedrock::BedrockProvider, databricks::DatabricksProvider, google::GoogleProvider, groq::GroqProvider, @@ -16,6 +17,7 @@ pub fn providers() -> Vec { vec![ AnthropicProvider::metadata(), AzureProvider::metadata(), + BedrockProvider::metadata(), DatabricksProvider::metadata(), GoogleProvider::metadata(), GroqProvider::metadata(), @@ -30,6 +32,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result Ok(Box::new(OpenAiProvider::from_env(model)?)), "anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)), "azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)), + "bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)), "databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)), "groq" => Ok(Box::new(GroqProvider::from_env(model)?)), "ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)), diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index de6225767..634224fd7 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -1,6 +1,7 @@ pub mod anthropic; pub mod azure; pub mod base; +pub mod bedrock; pub mod databricks; pub mod errors; mod factory; From 01954b0bd1a9c8d0428cdca78971ec6bc7e17df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 16:13:57 +0000 Subject: [PATCH 2/9] Fix Clippy errors --- crates/goose/src/providers/bedrock.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index c0b9f1823..f4e5e72f3 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -36,7 +36,7 @@ impl BedrockProvider { let sdk_config = tokio::task::block_in_place(|| { let mut aws_config = aws_config::from_env(); - if let Some(region) = config.get::("AWS_REGION").ok() { + if let Ok(region) = config.get::("AWS_REGION") { aws_config = aws_config.region(aws_config::Region::new(region)); } @@ -92,7 +92,7 @@ impl Provider for BedrockProvider { )) .send() .await - .or_else(|err| Err(anyhow!("Failed to call Bedrock: {}", err)))?; + .map_err(|err| anyhow!("Failed to call Bedrock: {}", err))?; let message = match response.output { Some(bedrock::ConverseOutput::Message(message)) => message, @@ -138,7 +138,7 @@ fn to_bedrock_message_content(content: &MessageContent) -> Result { let tool_use_id = tool_req.id.to_string(); - let tool_use = if let Some(call) = tool_req.tool_call.as_ref().ok() { + let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() { bedrock::ToolUseBlock::builder() .tool_use_id(tool_use_id) .name(call.name.to_string()) @@ -229,7 +229,7 @@ fn to_bedrock_json(value: &Value) -> Document { } } Value::String(str) => Document::String(str.to_string()), - Value::Array(arr) => Document::Array(arr.into_iter().map(to_bedrock_json).collect()), + Value::Array(arr) => Document::Array(arr.iter().map(to_bedrock_json).collect()), Value::Object(obj) => Document::Object(HashMap::from_iter( obj.into_iter() .map(|(key, val)| (key.to_string(), to_bedrock_json(val))), From b48b9be38fb4e9401e871e642d55478ca99a85c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 16:17:17 +0000 Subject: [PATCH 3/9] Return a `Result<>` from `from_bedrock_json` instead of using `expect` --- crates/goose/src/providers/bedrock.rs | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index f4e5e72f3..c882b5086 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -260,7 +260,7 @@ fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result MessageContent::tool_response( @@ -310,23 +310,25 @@ fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage { } } -fn from_bedrock_json(document: &Document) -> Value { - match document { +fn from_bedrock_json(document: &Document) -> Result { + Ok(match document { Document::Null => Value::Null, Document::Bool(bool) => Value::Bool(*bool), Document::Number(num) => match num { Number::PosInt(i) => Value::Number((*i).into()), Number::NegInt(i) => Value::Number((*i).into()), - Number::Float(f) => { - Value::Number(serde_json::Number::from_f64(*f).expect("Expected a valid f64")) - } + Number::Float(f) => Value::Number( + serde_json::Number::from_f64(*f).ok_or(anyhow!("Expected a valid float"))?, + ), }, Document::String(str) => Value::String(str.clone()), - Document::Array(arr) => Value::Array(arr.iter().map(from_bedrock_json).collect()), + Document::Array(arr) => { + Value::Array(arr.iter().map(from_bedrock_json).collect::>()?) + } Document::Object(obj) => Value::Object( obj.iter() - .map(|(key, val)| (key.clone(), from_bedrock_json(val))) - .collect(), + .map(|(key, val)| Ok((key.clone(), from_bedrock_json(val)?))) + .collect::>()?, ), - } + }) } From 722cac09935426db630e76bf0f7bd5d3b6eb4c7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 16:23:14 +0000 Subject: [PATCH 4/9] Remove `AWS_REGION` configuration and just rely on `aws_config::from_env()` --- crates/goose/src/providers/bedrock.rs | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index c882b5086..0a086fc3a 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -8,9 +8,8 @@ use chrono::Utc; use mcp_core::{Content, Role, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; -use crate::config::Config; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; @@ -32,15 +31,8 @@ pub struct BedrockProvider { impl BedrockProvider { pub fn from_env(model: ModelConfig) -> Result { - let config = Config::global(); let sdk_config = tokio::task::block_in_place(|| { - let mut aws_config = aws_config::from_env(); - - if let Ok(region) = config.get::("AWS_REGION") { - aws_config = aws_config.region(aws_config::Region::new(region)); - } - - tokio::runtime::Handle::current().block_on(aws_config.load()) + tokio::runtime::Handle::current().block_on(aws_config::from_env().load()) }); let client = Client::new(&sdk_config); @@ -58,7 +50,7 @@ impl Provider for BedrockProvider { BEDROCK_DEFAULT_MODEL, BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(), BEDROCK_DOC_LINK, - vec![ConfigKey::new("AWS_REGION", false, false, None)], + vec![], ) } From 4975a7778d9b9d69a78f6cb88260ee565d4cb361 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 16:40:45 +0000 Subject: [PATCH 5/9] Add Bedrock provider tests --- crates/goose/src/providers/bedrock.rs | 7 ++++++ crates/goose/tests/providers.rs | 32 ++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 0a086fc3a..e08dd7abd 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -40,6 +40,13 @@ impl BedrockProvider { } } +impl Default for BedrockProvider { + fn default() -> Self { + let model = ModelConfig::new(BedrockProvider::metadata().default_model); + BedrockProvider::from_env(model).expect("Failed to initialize Bedrock provider") + } +} + #[async_trait] impl Provider for BedrockProvider { fn metadata() -> ProviderMetadata { diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 6a5f4b9da..332f3ee76 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -3,7 +3,9 @@ use dotenv::dotenv; use goose::message::{Message, MessageContent}; use goose::providers::base::Provider; use goose::providers::errors::ProviderError; -use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter}; +use goose::providers::{ + anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter, +}; use mcp_core::content::Content; use mcp_core::tool::Tool; use std::collections::HashMap; @@ -374,6 +376,34 @@ async fn test_azure_provider() -> Result<()> { .await } +#[tokio::test] +async fn test_bedrock_provider_long_term_credentials() -> Result<()> { + test_provider( + "Bedrock", + &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + None, + bedrock::BedrockProvider::default, + ) + .await +} + +#[tokio::test] +async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { + let env_mods = HashMap::from_iter([ + // Ensure to unset long-term credentials to use AWS Profile provider + ("AWS_ACCESS_KEY_ID", None), + ("AWS_SECRET_ACCESS_KEY", None), + ]); + + test_provider( + "Bedrock AWS Profile Credentials", + &["AWS_PROFILE"], + Some(env_mods), + bedrock::BedrockProvider::default, + ) + .await +} + #[tokio::test] async fn test_databricks_provider() -> Result<()> { test_provider( From 19d08d694fa35107c4941d6cf9d567594ba320d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 17:43:22 +0000 Subject: [PATCH 6/9] Add truncate agent tests for Bedrock provider --- crates/goose/tests/truncate_agent.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/crates/goose/tests/truncate_agent.rs b/crates/goose/tests/truncate_agent.rs index d3702d5f8..fbd496741 100644 --- a/crates/goose/tests/truncate_agent.rs +++ b/crates/goose/tests/truncate_agent.rs @@ -8,7 +8,7 @@ use goose::model::ModelConfig; use goose::providers::base::Provider; use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider}; use goose::providers::{ - azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider, + azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider, }; use goose::providers::{google::GoogleProvider, groq::GroqProvider}; @@ -18,6 +18,7 @@ enum ProviderType { Azure, OpenAi, Anthropic, + Bedrock, Databricks, Google, Groq, @@ -35,6 +36,7 @@ impl ProviderType { ], ProviderType::OpenAi => &["OPENAI_API_KEY"], ProviderType::Anthropic => &["ANTHROPIC_API_KEY"], + ProviderType::Bedrock => &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], ProviderType::Databricks => &["DATABRICKS_HOST"], ProviderType::Google => &["GOOGLE_API_KEY"], ProviderType::Groq => &["GROQ_API_KEY"], @@ -66,6 +68,7 @@ impl ProviderType { ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?), ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?), ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?), + ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?), ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?), ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?), ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?), @@ -200,6 +203,16 @@ mod tests { .await } + #[tokio::test] + async fn test_truncate_agent_with_bedrock() -> Result<()> { + run_test_with_config(TestConfig { + provider_type: ProviderType::Bedrock, + model: "anthropic.claude-3-5-sonnet-20241022-v2:0", + context_window: 200_000, + }) + .await + } + #[tokio::test] async fn test_truncate_agent_with_databricks() -> Result<()> { run_test_with_config(TestConfig { From f712ed8085af87fd5e0167653dcd2be696f5fa70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 17:44:35 +0000 Subject: [PATCH 7/9] Use `futures::executor::block_on` to load AWS Config Tokio's runtime panics on single threaded tests. --- crates/goose/src/providers/bedrock.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index e08dd7abd..f72c5068d 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -31,9 +31,7 @@ pub struct BedrockProvider { impl BedrockProvider { pub fn from_env(model: ModelConfig) -> Result { - let sdk_config = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(aws_config::from_env().load()) - }); + let sdk_config = futures::executor::block_on(aws_config::load_from_env()); let client = Client::new(&sdk_config); Ok(Self { client, model }) From b700083164e92646f6ba5e31f0d2a38c575b7b9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 17:45:54 +0000 Subject: [PATCH 8/9] Properly map Bedrock errors to `ProviderError`s --- crates/goose/src/providers/bedrock.rs | 46 +++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index f72c5068d..849a377b2 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; +use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; use aws_smithy_types::{Document, Number}; use chrono::Utc; @@ -75,21 +76,52 @@ impl Provider for BedrockProvider { ) -> Result<(Message, ProviderUsage), ProviderError> { let model_name = &self.model.model_name; - let response = self + let mut request = self .client .converse() - .tool_config(to_bedrock_tool_config(tools)?) - .model_id(model_name.to_string()) .system(bedrock::SystemContentBlock::Text(system.to_string())) + .model_id(model_name.to_string()) .set_messages(Some( messages .iter() .map(to_bedrock_message) .collect::>()?, - )) - .send() - .await - .map_err(|err| anyhow!("Failed to call Bedrock: {}", err))?; + )); + + if !tools.is_empty() { + request = request.tool_config(to_bedrock_tool_config(tools)?); + } + + let response = request.send().await; + + let response = match response { + Ok(response) => response, + Err(err) => { + return Err(match err.into_service_error() { + ConverseError::AccessDeniedException(err) => { + ProviderError::Authentication(format!("Failed to call Bedrock: {}", err)) + } + ConverseError::ThrottlingException(err) => { + ProviderError::RateLimitExceeded(format!("Failed to call Bedrock: {}", err)) + } + ConverseError::ValidationException(err) + if err + .message() + .unwrap_or_default() + .contains("Input is too long for requested model.") => + { + ProviderError::ContextLengthExceeded(format!( + "Failed to call Bedrock: {}", + err + )) + } + ConverseError::ModelErrorException(err) => { + ProviderError::ExecutionError(format!("Failed to call Bedrock: {}", err)) + } + err => ProviderError::ServerError(format!("Failed to call Bedrock: {}", err,)), + }); + } + }; let message = match response.output { Some(bedrock::ConverseOutput::Message(message)) => message, From 06d396f21660a1d1b20ab9b618a9e91e4df7665e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Tue, 4 Feb 2025 17:52:26 +0000 Subject: [PATCH 9/9] Add support for text resources --- crates/goose/src/providers/bedrock.rs | 48 +++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 849a377b2..b57c4ceee 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::path::Path; use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; @@ -6,7 +7,7 @@ use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; use aws_smithy_types::{Document, Number}; use chrono::Utc; -use mcp_core::{Content, Role, Tool, ToolCall, ToolError, ToolResult}; +use mcp_core::{Content, ResourceContents, Role, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; @@ -185,7 +186,7 @@ fn to_bedrock_message_content(content: &MessageContent) -> Result Some( content .iter() - .map(to_bedrock_tool_result_content_block) + .map(|c| to_bedrock_tool_result_content_block(&tool_res.id, c)) .collect::>()?, ), Err(_) => None, @@ -206,12 +207,15 @@ fn to_bedrock_message_content(content: &MessageContent) -> Result Result { Ok(match content { Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()), Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"), - Content::Resource(_) => bail!("Resource content is not supported by Bedrock provider yet"), + Content::Resource(resource) => bedrock::ToolResultContentBlock::Document( + to_bedrock_document(tool_use_id, &resource.resource)?, + ), }) } @@ -266,6 +270,44 @@ fn to_bedrock_json(value: &Value) -> Document { } } +fn to_bedrock_document( + tool_use_id: &str, + content: &ResourceContents, +) -> Result { + let (uri, text) = match content { + ResourceContents::TextResourceContents { uri, text, .. } => (uri, text), + ResourceContents::BlobResourceContents { .. } => { + bail!("Blob resource content is not supported by Bedrock provider yet") + } + }; + + let filename = Path::new(uri) + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or(uri); + + let (name, format) = match filename.split_once('.') { + Some((name, "txt")) => (name, bedrock::DocumentFormat::Txt), + Some((name, "csv")) => (name, bedrock::DocumentFormat::Csv), + Some((name, "md")) => (name, bedrock::DocumentFormat::Md), + Some((name, "html")) => (name, bedrock::DocumentFormat::Html), + Some((name, _)) => (name, bedrock::DocumentFormat::Txt), + _ => (filename, bedrock::DocumentFormat::Txt), + }; + + // Since we can't use the full path (due to character limit and also Bedrock does not accept `/` etc.), + // and Bedrock wants document names to be unique, we're adding `tool_use_id` as a prefix to make + // document names unique. + let name = format!("{tool_use_id}-{name}"); + + bedrock::DocumentBlock::builder() + .format(format) + .name(name) + .source(bedrock::DocumentSource::Bytes(text.as_bytes().into())) + .build() + .map_err(|err| anyhow!("Failed to construct Bedrock document: {}", err)) +} + fn from_bedrock_message(message: &bedrock::Message) -> Result { let role = from_bedrock_role(message.role())?; let content = message