[feat] bedrock anthropic
brunoalho99 committed Nov 11, 2024
1 parent 49a402d commit 20e7abe
Showing 3 changed files with 57 additions and 44 deletions.
44 changes: 42 additions & 2 deletions libs/core/llmstudio_core/config.yaml
@@ -81,11 +81,51 @@ providers:
- BEDROCK_ACCESS_KEY
- BEDROCK_REGION
models:
+      anthropic.claude-3-sonnet-20240229-v1:0:
+        mode: chat
+        max_tokens: 200000
+        input_token_cost: 0.000003
+        output_token_cost: 0.000015
+      anthropic.claude-3-5-sonnet-20240620-v1:0:
+        mode: chat
+        max_tokens: 200000
+        input_token_cost: 0.000003
+        output_token_cost: 0.000015
+      anthropic.claude-3-5-sonnet-20241022-v2:0:
+        mode: chat
+        max_tokens: 200000
+        input_token_cost: 0.000003
+        output_token_cost: 0.000015
+      anthropic.claude-3-haiku-20240307-v1:0:
+        mode: chat
+        max_tokens: 200000
+        input_token_cost: 0.00000025
+        output_token_cost: 0.00000125
+      anthropic.claude-3-5-haiku-20241022-v1:0:
+        mode: chat
+        max_tokens: 200000
+        input_token_cost: 0.000001
+        output_token_cost: 0.000005
+      anthropic.claude-3-opus-20240229-v1:0:
+        mode: chat
+        max_tokens: 200000
+        input_token_cost: 0.000015
+        output_token_cost: 0.000075
      anthropic.claude-instant-v1:
        mode: chat
        max_tokens: 100000
-        input_token_cost: 0.00000163
-        output_token_cost: 0.00000551
+        input_token_cost: 0.0000008
+        output_token_cost: 0.000024
+      anthropic.claude-v2:
+        mode: chat
+        max_tokens: 100000
+        input_token_cost: 0.000008
+        output_token_cost: 0.000024
+      anthropic.claude-v2:1:
+        mode: chat
+        max_tokens: 100000
+        input_token_cost: 0.000008
+        output_token_cost: 0.000024
parameters:
temperature:
name: "Temperature"
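The per-token prices above can be turned into a request-cost estimate. A minimal sketch, assuming the YAML is loaded into a plain dict and that the model entries sit under providers -> bedrock -> models; the function name and loading step are illustrative, not part of llmstudio:

import yaml  # assumes PyYAML is installed

def estimate_cost(config_path: str, model_id: str, input_tokens: int, output_tokens: int) -> float:
    """Estimate the USD cost of one request from the per-token prices in config.yaml."""
    with open(config_path) as f:
        config = yaml.safe_load(f)
    model_cfg = config["providers"]["bedrock"]["models"][model_id]  # assumed nesting
    return (
        input_tokens * model_cfg["input_token_cost"]
        + output_tokens * model_cfg["output_token_cost"]
    )

# Example: 1,200 prompt tokens and 300 completion tokens on Claude 3.5 Sonnet:
# 1200 * 0.000003 + 300 * 0.000015 = 0.0081 USD
print(estimate_cost("config.yaml", "anthropic.claude-3-5-sonnet-20240620-v1:0", 1200, 300))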
@@ -15,7 +15,6 @@
)

import boto3
-from fastapi import HTTPException
from llmstudio_core.exceptions import ProviderError
from llmstudio_core.providers.provider import ChatRequest, ProviderCore, provider

@@ -30,18 +29,23 @@
)
from pydantic import ValidationError

+SERVICE = "bedrock-runtime"


@provider
-class BedrockAntropicProvider(ProviderCore):
+class BedrockAnthropicProvider(ProviderCore):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
-        self.access_key = (
-            self.access_key if self.access_key else os.getenv("BEDROCK_ACCESS_KEY")
-        )
-        self.secret_key = (
-            self.secret_key if self.secret_key else os.getenv("BEDROCK_SECRET_KEY")
+        self._client = boto3.client(
+            SERVICE,
+            region_name=self.region if self.region else os.getenv("BEDROCK_REGION"),
+            aws_access_key_id=self.access_key
+            if self.access_key
+            else os.getenv("BEDROCK_ACCESS_KEY"),
+            aws_secret_access_key=self.secret_key
+            if self.secret_key
+            else os.getenv("BEDROCK_SECRET_KEY"),
        )
-        self.region = self.region if self.region else os.getenv("BEDROCK_REGION")

@staticmethod
def _provider_config_name():
@@ -57,26 +61,6 @@ async def agenerate_client(self, request: ChatRequest) -> Coroutine[Any, Any, An
def generate_client(self, request: ChatRequest) -> Coroutine[Any, Any, Generator]:
"""Generate an AWS Bedrock client"""
try:

-            service = "bedrock-runtime"

-            if (
-                self.access_key is None
-                or self.secret_key is None
-                or self.region is None
-            ):
-                raise HTTPException(
-                    status_code=400,
-                    detail="AWS credentials were not given or not set in environment variables.",
-                )

-            client = boto3.client(
-                service,
-                region_name=self.region,
-                aws_access_key_id=self.access_key,
-                aws_secret_access_key=self.secret_key,
-            )

messages, system_prompt = self._process_messages(request.chat_input)
tools = self._process_tools(request.parameters)

@@ -95,7 +79,7 @@ def generate_client(self, request: ChatRequest) -> Coroutine[Any, Any, Generator
if tools:
client_params["toolConfig"] = tools

-            return client.converse_stream(**client_params)
+            return self._client.converse_stream(**client_params)
except Exception as e:
raise ProviderError(str(e))
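
For reference, converse_stream here is the boto3 Bedrock Runtime Converse Stream API. A minimal standalone sketch of the same call outside llmstudio — the region default, model ID, and prompt are illustrative assumptions:

import os

import boto3

# Credentials resolve from the environment or the standard AWS config chain (assumed).
client = boto3.client("bedrock-runtime", region_name=os.getenv("BEDROCK_REGION", "us-east-1"))

response = client.converse_stream(
    modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": [{"text": "Hello, Claude!"}]}],
    system=[{"text": "You are a concise assistant."}],
    inferenceConfig={"maxTokens": 512, "temperature": 0.7},
)

# Text arrives incrementally in contentBlockDelta events on the returned event stream.
for event in response["stream"]:
    if "contentBlockDelta" in event:
        print(event["contentBlockDelta"]["delta"].get("text", ""), end="", flush=True)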

@@ -249,18 +233,7 @@ def parse_response(self, response: AsyncGenerator[Any, None], **kwargs) -> Any:
def _process_messages(
chat_input: Union[str, List[Dict[str, str]]]
) -> List[Dict[str, Union[List[Dict[str, str]], str]]]:
"""
Generate input text for the Bedrock API based on the provided chat input.
Args:
chat_input (Union[str, List[Dict[str, str]]]): The input text or a list of message dictionaries.
Returns:
List[Dict[str, Union[List[Dict[str, str]], str]]]: A list of formatted messages for the Bedrock API.

Raises:
HTTPException: If the input is invalid.
"""
if isinstance(chat_input, str):
return [
{
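The return value of _process_messages is truncated above; for orientation, the Converse API expects messages in roughly this shape (an assumption based on the Bedrock Converse API, not copied from the repo):

# A plain-string chat_input maps to a single user turn:
messages = [{"role": "user", "content": [{"text": "What is the capital of Portugal?"}]}]
# A system prompt travels separately, e.g. system=[{"text": "Answer briefly."}].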
@@ -1,6 +1,6 @@
from typing import Any, AsyncGenerator, Coroutine, Generator

-from llmstudio_core.providers.bedrock_providers.antropic import BedrockAntropicProvider
+from llmstudio_core.providers.bedrock.anthropic import BedrockAnthropicProvider
from llmstudio_core.providers.provider import ChatRequest, ProviderCore, provider


@@ -13,7 +13,7 @@ def __init__(self, config, **kwargs):

def _get_provider(self, model):
if "anthropic." in model:
-            return BedrockAntropicProvider(config=self.config, **self.kwargs)
+            return BedrockAnthropicProvider(config=self.config, **self.kwargs)

raise ValueError(f" provider is not yet supported.")
