From 22ab93318e6c4a185cc2dc820de86e133163138c Mon Sep 17 00:00:00 2001
From: "Michael B. Klein"
Date: Fri, 8 Mar 2024 21:38:24 +0000
Subject: [PATCH] Use the BedrockChat LLM instead of AzureOpenAI

---
 chat/src/event_config.py       | 52 ++++++++++------------------
 chat/src/helpers/response.py   | 10 +++----
 chat/src/requirements.txt      |  2 +-
 chat/src/setup.py              | 14 ++++-----
 chat/template.yaml             | 31 ++++++++++----------
 chat/test/test_event_config.py | 28 ++----------------
 6 files changed, 43 insertions(+), 94 deletions(-)

diff --git a/chat/src/event_config.py b/chat/src/event_config.py
index b9da4881..c6cdc648 100644
--- a/chat/src/event_config.py
+++ b/chat/src/event_config.py
@@ -7,7 +7,7 @@
 from setup import (
     opensearch_client,
     opensearch_vector_store,
-    openai_chat_client,
+    bedrock_chat_client,
 )
 from typing import List
 from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler
@@ -38,15 +38,13 @@ class EventConfig:
 
     api_token: ApiToken = field(init=False)
     attributes: List[str] = field(init=False)
-    azure_endpoint: str = field(init=False)
-    azure_resource_name: str = field(init=False)
     debug_mode: bool = field(init=False)
-    deployment_name: str = field(init=False)
+    model_id: str = field(init=False)
     document_prompt: PromptTemplate = field(init=False)
     event: dict = field(default_factory=dict)
+    index_name: str = field(init=False)
     is_logged_in: bool = field(init=False)
     k: int = field(init=False)
-    openai_api_version: str = field(init=False)
     payload: dict = field(default_factory=dict)
     prompt_text: str = field(init=False)
     prompt: PromptTemplate = field(init=False)
@@ -61,13 +59,11 @@ def __post_init__(self):
         self.payload = json.loads(self.event.get("body", "{}"))
         self.api_token = ApiToken(signed_token=self.payload.get("auth"))
         self.attributes = self._get_attributes()
-        self.azure_endpoint = self._get_azure_endpoint()
-        self.azure_resource_name = self._get_azure_resource_name()
         self.debug_mode = self._is_debug_mode_enabled()
-        self.deployment_name = self._get_deployment_name()
+        self.index_name = self._get_opensearch_index()
+        self.model_id = self._get_model_id()
         self.is_logged_in = self.api_token.is_logged_in()
         self.k = self._get_k()
-        self.openai_api_version = self._get_openai_api_version()
         self.prompt_text = self._get_prompt_text()
         self.request_context = self.event.get("requestContext", {})
         self.question = self.payload.get("question")
@@ -88,7 +84,7 @@ def _get_payload_value_with_superuser_check(self, key, default):
     def _get_attributes_function(self):
         try:
             opensearch = opensearch_client()
-            mapping = opensearch.indices.get_mapping(index="dc-v2-work")
+            mapping = opensearch.indices.get_mapping(index=self._get_opensearch_index())
             return list(next(iter(mapping.values()))['mappings']['properties'].keys())
         except StopIteration:
             return []
@@ -96,34 +92,20 @@
     def _get_attributes(self):
         return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES)
 
-    def _get_azure_endpoint(self):
-        default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
-        return self._get_payload_value_with_superuser_check("azure_endpoint", default)
-
-    def _get_azure_resource_name(self):
-        azure_resource_name = self._get_payload_value_with_superuser_check(
-            "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
-        )
-        if not azure_resource_name:
-            raise EnvironmentError(
-                "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
-            )
-        return azure_resource_name
-
-    def _get_deployment_name(self):
+    def _get_model_id(self):
         return self._get_payload_value_with_superuser_check(
-            "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
+            "model_id", os.getenv("AI_MODEL_ID")
         )
 
     def _get_k(self):
         value = self._get_payload_value_with_superuser_check("k", K_VALUE)
         return min(value, MAX_K)
 
-    def _get_openai_api_version(self):
+    def _get_opensearch_index(self):
         return self._get_payload_value_with_superuser_check(
-            "openai_api_version", VERSION
+            "index", os.getenv("INDEX_NAME")
         )
-
+    
     def _get_prompt_text(self):
         return self._get_payload_value_with_superuser_check("prompt", prompt_template())
 
@@ -144,10 +126,8 @@ def debug_message(self):
             "type": "debug",
             "message": {
                 "attributes": self.attributes,
-                "azure_endpoint": self.azure_endpoint,
-                "deployment_name": self.deployment_name,
+                "model_id": self.model_id,
                 "k": self.k,
-                "openai_api_version": self.openai_api_version,
                 "prompt": self.prompt_text,
                 "question": self.question,
                 "ref": self.ref,
@@ -173,13 +153,11 @@ def setup_llm_request(self):
         self._setup_chain()
 
     def _setup_vector_store(self):
-        self.opensearch = opensearch_vector_store()
+        self.opensearch = opensearch_vector_store(index_name=self.index_name)
 
     def _setup_chat_client(self):
-        self.client = openai_chat_client(
-            deployment_name=self.deployment_name,
-            openai_api_base=self.azure_endpoint,
-            openai_api_version=self.openai_api_version,
+        self.client = bedrock_chat_client(
+            model_id=self.model_id,
             callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
             streaming=True,
         )
diff --git a/chat/src/helpers/response.py b/chat/src/helpers/response.py
index a3b946d4..66c424dd 100644
--- a/chat/src/helpers/response.py
+++ b/chat/src/helpers/response.py
@@ -1,5 +1,4 @@
 from helpers.metrics import token_usage
-from openai.error import InvalidRequestError
 
 def base_response(config, response):
     return {"answer": response["output_text"], "ref": config.ref}
@@ -9,11 +8,9 @@ def debug_response(config, response, original_question):
     response_base = base_response(config, response)
     debug_info = {
         "attributes": config.attributes,
-        "azure_endpoint": config.azure_endpoint,
-        "deployment_name": config.deployment_name,
+        "model_id": config.model_id,
         "is_superuser": config.api_token.is_superuser(),
         "k": config.k,
-        "openai_api_version": config.openai_api_version,
         "prompt": config.prompt_text,
         "ref": config.ref,
         "temperature": config.temperature,
@@ -35,7 +32,8 @@ def get_and_send_original_question(config, docs):
         "question": config.question,
         "source_documents": doc_response,
     }
-    config.socket.send(original_question)
+    if (config.socket):
+        config.socket.send(original_question)
     return original_question
 
 def extract_prompt_value(v):
@@ -58,7 +56,7 @@ def prepare_response(config):
             prepared_response = debug_response(config, response, original_question)
         else:
             prepared_response = base_response(config, response)
-    except InvalidRequestError as err:
+    except Exception as err:
         prepared_response = {
             "question": config.question,
             "error": str(err),
diff --git a/chat/src/requirements.txt b/chat/src/requirements.txt
index 04100144..4ae766a4 100644
--- a/chat/src/requirements.txt
+++ b/chat/src/requirements.txt
@@ -1,6 +1,6 @@
 # Runtime Dependencies
 boto3~=1.34.13
-langchain~=0.1.8
+langchain
 langchain-community
 openai~=0.27.8
 opensearch-py
diff --git a/chat/src/setup.py b/chat/src/setup.py
index 39a99338..7d7200b3 100644
--- a/chat/src/setup.py
+++ b/chat/src/setup.py
@@ -1,5 +1,5 @@
 from content_handler import ContentHandler
-from langchain_community.chat_models import AzureChatOpenAI
+from langchain_community.chat_models import BedrockChat
 from langchain_community.embeddings import SagemakerEndpointEmbeddings
 from langchain_community.vectorstores import OpenSearchVectorSearch
 from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -12,14 +12,10 @@ def prefix(value):
     env_prefix = None if env_prefix == "" else env_prefix
     return '-'.join(filter(None, [env_prefix, value]))
 
-def openai_chat_client(**kwargs):
-    return AzureChatOpenAI(
-        openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
-        **kwargs,
-    )
+def bedrock_chat_client(model_id=os.getenv("AI_MODEL_ID"), region_name=os.getenv("AWS_REGION"), **kwargs):
+    return BedrockChat(model_id=model_id, region_name=region_name, **kwargs)
 
 def opensearch_client(region_name=os.getenv("AWS_REGION")):
-    print(region_name)
     session = boto3.Session(region_name=region_name)
     awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials())
     endpoint = os.getenv("ELASTICSEARCH_ENDPOINT")
@@ -31,7 +27,7 @@ def opensearch_client(region_name=os.getenv("AWS_REGION")):
         http_auth=awsauth,
     )
 
-def opensearch_vector_store(region_name=os.getenv("AWS_REGION")):
+def opensearch_vector_store(region_name=os.getenv("AWS_REGION"), index_name="dc-v2-work"):
     session = boto3.Session(region_name=region_name)
     awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials())
 
@@ -44,7 +40,7 @@ def opensearch_vector_store(region_name=os.getenv("AWS_REGION")):
     )
 
     docsearch = OpenSearchVectorSearch(
-        index_name=prefix("dc-v2-work"),
+        index_name=index_name,
         embedding_function=embeddings,
         opensearch_url="https://" + os.getenv("ELASTICSEARCH_ENDPOINT"),
         connection_class=RequestsHttpConnection,
diff --git a/chat/template.yaml b/chat/template.yaml
index d7696246..b97b3cde 100644
--- a/chat/template.yaml
+++ b/chat/template.yaml
@@ -2,27 +2,22 @@ AWSTemplateFormatVersion: "2010-09-09"
 Transform: AWS::Serverless-2016-10-31
 Description: Websocket Chat API for dc-api-v2
 Parameters:
+  AIModelId:
+    Type: String
+    Description: Amazon Bedrock Model
   ApiTokenSecret:
     Type: String
     Description: Secret Key for Encrypting JWTs (must match IIIF server)
-  AzureOpenaiApiKey:
-    Type: String
-    Description: Azure OpenAI API Key
-  AzureOpenaiEmbeddingDeploymentId:
-    Type: String
-    Description: Azure OpenAI Embedding Deployment ID
-  AzureOpenaiLlmDeploymentId:
-    Type: String
-    Description: Azure OpenAI LLM Deployment ID
-  AzureOpenaiResourceName:
-    Type: String
-    Description: Azure OpenAI Resource Name
   ElasticsearchEndpoint:
     Type: String
     Description: Elasticsearch URL
   EmbeddingEndpoint:
     Type: String
     Description: Sagemaker Inference Endpoint
+  IndexName:
+    Type: String
+    Description: Index or alias to use for vector search
+    Default: dc-v2-work
 Resources:
   ApiGwAccountConfig:
     Type: "AWS::ApiGateway::Account"
@@ -197,13 +192,11 @@
       Timeout: 300
       Environment:
         Variables:
+          AI_MODEL_ID: !Ref AIModelId
           API_TOKEN_SECRET: !Ref ApiTokenSecret
-          AZURE_OPENAI_API_KEY: !Ref AzureOpenaiApiKey
-          AZURE_OPENAI_EMBEDDING_DEPLOYMENT_ID: !Ref AzureOpenaiEmbeddingDeploymentId
-          AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId
-          AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName
           ELASTICSEARCH_ENDPOINT: !Ref ElasticsearchEndpoint
           EMBEDDING_ENDPOINT: !Ref EmbeddingEndpoint
+          INDEX_NAME: !Ref IndexName
       Policies:
         - Statement:
             - Effect: Allow
              Action:
                - 'es:ESHttpGet'
                - 'es:ESHttpPost'
              Resource: '*'
+        - Statement:
+            - Effect: Allow
+              Action:
+                - 'bedrock:InvokeModel'
+                - 'bedrock:InvokeModelWithResponseStream'
+              Resource: arn:aws:bedrock:*::foundation-model/*
         - Statement:
             - Effect: Allow
              Action:
diff --git a/chat/test/test_event_config.py b/chat/test/test_event_config.py
index 55f8381d..cf7bbcbf 100644
--- a/chat/test/test_event_config.py
+++ b/chat/test/test_event_config.py
@@ -8,39 +8,22 @@
 
 from unittest import TestCase, mock
 
-class TestEventConfigWithoutAzureResource(TestCase):
-    def test_requires_an_azure_resource(self):
-        with self.assertRaises(EnvironmentError):
-            EventConfig()
-
-
 @mock.patch.dict(
     os.environ,
     {
-        "AZURE_OPENAI_RESOURCE_NAME": "test",
+        "AI_MODEL_ID": "test",
     },
 )
 class TestEventConfig(TestCase):
-    def test_fetches_attributes_from_vector_database(self):
-        os.environ.pop("AZURE_OPENAI_RESOURCE_NAME", None)
-        with self.assertRaises(EnvironmentError):
-            EventConfig()
-
-    def test_defaults(self):
-        actual = EventConfig(event={"body": json.dumps({"attributes": ["title"]})})
-        expected_defaults = {"azure_endpoint": "https://test.openai.azure.com/"}
-        self.assertEqual(actual.azure_endpoint, expected_defaults["azure_endpoint"])
-
     def test_attempt_override_without_superuser_status(self):
         actual = EventConfig(
             event={
                 "body": json.dumps(
                     {
-                        "azure_resource_name": "new_name_for_test",
                         "attributes": ["title", "subject", "date_created"],
                         "index": "testIndex",
                         "k": 100,
-                        "openai_api_version": "2024-01-01",
+                        "model_id": "model_override",
                         "question": "test question",
                         "ref": "test ref",
                         "temperature": 0.9,
                     }
                 )
             }
         )
         expected_output = {
             "attributes": EventConfig.DEFAULT_ATTRIBUTES,
-            "azure_endpoint": "https://test.openai.azure.com/",
+            "model_id": "test",
             "k": 5,
-            "openai_api_version": "2023-07-01-preview",
             "question": "test question",
             "ref": "test ref",
             "temperature": 0.2,
             "text_key": "title",
         }
-        self.assertEqual(actual.azure_endpoint, expected_output["azure_endpoint"])
         self.assertEqual(actual.attributes, expected_output["attributes"])
         self.assertEqual(actual.k, expected_output["k"])
-        self.assertEqual(
-            actual.openai_api_version, expected_output["openai_api_version"]
-        )
         self.assertEqual(actual.question, expected_output["question"])
         self.assertEqual(actual.ref, expected_output["ref"])
         self.assertEqual(actual.temperature, expected_output["temperature"])
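
The new bedrock_chat_client factory in chat/src/setup.py is a thin wrapper around
langchain_community's BedrockChat, so the change can be smoke-tested outside the
Lambda. A minimal sketch, assuming Bedrock access and default boto3 credentials;
the model ID and region below are illustrative stand-ins for the AI_MODEL_ID and
AWS_REGION environment variables the deployed function actually reads:

    from langchain_community.chat_models import BedrockChat

    # Illustrative values; in the Lambda these come from AI_MODEL_ID / AWS_REGION.
    llm = BedrockChat(
        model_id="anthropic.claude-v2",
        region_name="us-east-1",
        streaming=False,
    )

    # BedrockChat is a LangChain chat model: invoke() accepts a plain string
    # and returns an AIMessage whose .content holds the completion text.
    print(llm.invoke("Say hello from Bedrock.").content)

The bedrock:InvokeModel / bedrock:InvokeModelWithResponseStream policy statement
added to chat/template.yaml is what authorizes these same calls (streaming and
non-streaming) for the deployed function.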