diff --git a/enterprise_catalog/apps/ai_curation/openai_client.py b/enterprise_catalog/apps/ai_curation/openai_client.py index f8f71b320..ecba25de2 100644 --- a/enterprise_catalog/apps/ai_curation/openai_client.py +++ b/enterprise_catalog/apps/ai_curation/openai_client.py @@ -1,17 +1,11 @@ import functools +import json import logging import backoff -import simplejson +import requests from django.conf import settings -from openai import ( - APIConnectionError, - APIError, - APITimeoutError, - InternalServerError, - OpenAI, - RateLimitError, -) +from requests.exceptions import ConnectTimeout from enterprise_catalog.apps.ai_curation.errors import ( AICurationError, @@ -21,8 +15,6 @@ LOGGER = logging.getLogger(__name__) -client = OpenAI(api_key=settings.OPENAI_API_KEY) - def api_error_handler(func): """ @@ -37,7 +29,7 @@ def api_error_handler(func): def wrapper(*args, **kwargs): try: return func(*args, **kwargs) - except (APIError, AICurationError) as ex: + except (ConnectionError, AICurationError) as ex: LOGGER.exception('[AI_CURATION] API Error: Prompt: [%s]', kwargs.get('messages')) # status_code attribute is not available for all exceptions, such as APIConnectionError and APITimeoutError status_code = getattr(ex, 'status_code', None) @@ -49,27 +41,19 @@ def wrapper(*args, **kwargs): @api_error_handler @backoff.on_exception( backoff.expo, - (APIConnectionError, APITimeoutError, InternalServerError, RateLimitError, InvalidJSONResponseError), + (ConnectTimeout, ConnectionError, InvalidJSONResponseError), max_tries=3, ) def chat_completions( messages, response_format='json', - response_type=list, - model="gpt-4", - temperature=0.3, - max_tokens=500, ): """ - Get a response from the chat.completions endpoint + Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting. Args: messages (list): List of messages to send to the chat.completions endpoint response_format (str): Format of the response. Can be 'json' or 'text' - response_type (any): Expected type of the response. For now we only expect `list` - model (str): Model to use for the completion - temperature (number): Make model output more focused and deterministic - max_tokens (int): Maximum number of tokens that can be generated in the chat completion Returns: : The response from the chat.completions endpoint @@ -81,32 +65,31 @@ def chat_completions( - status_code (int): The actual error code returned by the API """ LOGGER.info('[AI_CURATION] [CHAT_COMPLETIONS] Prompt: [%s]', messages) - response = client.chat.completions.create( - messages=messages, - model=model, - temperature=temperature, - max_tokens=max_tokens, + + headers = {'Content-Type': 'application/json', 'x-api-key': settings.CHAT_COMPLETION_API_KEY} + message_list = [] + for message in messages: + message_list.append({'role': 'assistant', 'content': message['content']}) + body = {'message_list': message_list} + response = requests.post( + settings.CHAT_COMPLETION_API, + headers=headers, + data=json.dumps(body), + timeout=( + settings.CHAT_COMPLETION_API_CONNECT_TIMEOUT, + settings.CHAT_COMPLETION_API_READ_TIMEOUT + ) ) LOGGER.info('[AI_CURATION] [CHAT_COMPLETIONS] Response: [%s]', response) - response_content = response.choices[0].message.content - - if response_format == 'json': - try: - json_response = simplejson.loads(response_content) - if isinstance(json_response, response_type): - return json_response - LOGGER.error( - '[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]', - messages, - response - ) - raise InvalidJSONResponseError('Invalid response type received from chatgpt') - except simplejson.errors.JSONDecodeError as ex: - LOGGER.error( - '[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]', - messages, - response - ) - raise InvalidJSONResponseError('Invalid JSON response received from chatgpt') from ex - - return response_content + try: + response_content = response.json().get('content') + if response_format == 'json': + return json.loads(response_content) + return json.loads(response_content)[0] + except requests.exceptions.JSONDecodeError as ex: + LOGGER.error( + '[AI_CURATION] Invalid JSON response received: Prompt: [%s], Response: [%s]', + messages, + response + ) + raise InvalidJSONResponseError('Invalid response received.') from ex diff --git a/enterprise_catalog/apps/ai_curation/tests/test_utils.py b/enterprise_catalog/apps/ai_curation/tests/test_utils.py index 0ac23cc52..947dede3a 100644 --- a/enterprise_catalog/apps/ai_curation/tests/test_utils.py +++ b/enterprise_catalog/apps/ai_curation/tests/test_utils.py @@ -4,12 +4,10 @@ import json import logging from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import patch -import httpx from django.conf import settings from django.test import TestCase -from openai import APIConnectionError from enterprise_catalog.apps.ai_curation.errors import AICurationError from enterprise_catalog.apps.ai_curation.utils.algolia_utils import ( @@ -75,23 +73,22 @@ def test_fetch_catalog_metadata_from_algolia(self, mock_algolia_client): class TestChatCompletionUtils(TestCase): @patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.LOGGER') - @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create') - def test_get_filtered_subjects(self, mock_create, mock_logger): + @patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post') + def test_get_filtered_subjects(self, mock_requests, mock_logger): """ Test that get_filtered_subjects returns the correct filtered subjects """ - mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(['subject1', 'subject2'])))] - ) + mock_requests.return_value.json.return_value = { + "role": "assistant", + "content": json.dumps(['subject1', 'subject2']) + } subjects = ['subject1', 'subject2', 'subject3', 'subject4'] query = 'test query' expected_content = settings.AI_CURATION_FILTER_SUBJECTS_PROMPT.format(query=query, subjects=subjects) result = get_filtered_subjects(query, subjects) - mock_create.assert_called_once_with( - messages=[{'role': 'system', 'content': expected_content}], **CHAT_COMPLETIONS_API_KEYWARGS - ) + mock_requests.assert_called_once() mock_logger.info.assert_has_calls( [ mock.call( @@ -103,97 +100,22 @@ def test_get_filtered_subjects(self, mock_create, mock_logger): ) assert result == ['subject1', 'subject2'] - @patch('enterprise_catalog.apps.ai_curation.openai_client.LOGGER') - @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create') - def test_invalid_json(self, mock_create, mock_logger): - """ - Test that correct exception is raised if chat.completions.create send an invalid json - """ - mock_create.return_value = MagicMock(choices=[MagicMock(message=MagicMock(content='non json response'))]) - - messages = [ - { - 'role': 'system', - 'content': 'I am a prompt' - } - ] - with self.assertRaises(AICurationError): - chat_completions(messages) - - assert mock_create.call_count == 3 - assert mock_logger.error.called - mock_logger.error.assert_has_calls([ - mock.call( - '[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]', - [{'role': 'system', 'content': 'I am a prompt'}], - mock.ANY - ), - mock.call( - '[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]', - [{'role': 'system', 'content': 'I am a prompt'}], - mock.ANY - ), - mock.call( - '[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]', - [{'role': 'system', 'content': 'I am a prompt'}], - mock.ANY - ) - ]) - - @patch('enterprise_catalog.apps.ai_curation.openai_client.LOGGER') - @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create') - def test_valid_json_with_wrong_type(self, mock_create, mock_logger): - """ - Test that correct exception is raised if chat.completions.create send a valid json but wrong type - """ - mock_create.return_value = MagicMock(choices=[MagicMock(message=MagicMock(content='{"a": 1}'))]) - - messages = [ - { - 'role': 'system', - 'content': 'I am a prompt' - } - ] - with self.assertRaises(AICurationError): - chat_completions(messages) - - assert mock_create.call_count == 3 - assert mock_logger.error.called - mock_logger.error.assert_has_calls([ - mock.call( - '[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]', - [{'role': 'system', 'content': 'I am a prompt'}], - mock.ANY - ), - mock.call( - '[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]', - [{'role': 'system', 'content': 'I am a prompt'}], - mock.ANY - ), - mock.call( - '[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]', - [{'role': 'system', 'content': 'I am a prompt'}], - mock.ANY - ) - ]) - @patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.LOGGER') - @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create') - def test_get_query_keywords(self, mock_create, mock_logger): + @patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post') + def test_get_query_keywords(self, mock_requests, mock_logger): """ Test that get_query_keywords returns the correct keywords """ - mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(['keyword1', 'keyword2'])))] - ) + mock_requests.return_value.json.return_value = { + "role": "assistant", + "content": json.dumps(['keyword1', 'keyword2']) + } query = 'test query' expected_content = settings.AI_CURATION_QUERY_TO_KEYWORDS_PROMPT.format(query=query) result = get_query_keywords(query) - mock_create.assert_called_once_with( - messages=[{'role': 'system', 'content': expected_content}], **CHAT_COMPLETIONS_API_KEYWARGS - ) + mock_requests.assert_called_once() mock_logger.info.assert_has_calls( [ mock.call( @@ -206,25 +128,24 @@ def test_get_query_keywords(self, mock_create, mock_logger): assert result == ['keyword1', 'keyword2'] @patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.LOGGER') - @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create') + @patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post') @patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.get_query_keywords') - def test_get_keywords_to_prose(self, mock_get_query_keywords, mock_create, mock_logger): + def test_get_keywords_to_prose(self, mock_get_query_keywords, mock_requests, mock_logger): """ Test that get_keywords_to_prose returns the correct prose """ mock_get_query_keywords.return_value = ['keyword1', 'keyword2'] - mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(['I am a prose'])))] - ) + mock_requests.return_value.json.return_value = { + "role": "assistant", + "content": json.dumps(['I am a prose']) + } query = 'test query' keywords = ['keyword1', 'keyword2'] expected_content = settings.AI_CURATION_KEYWORDS_TO_PROSE_PROMPT.format(query=query, keywords=keywords) result = get_keywords_to_prose(query) - mock_create.assert_called_once_with( - messages=[{'role': 'system', 'content': expected_content}], **CHAT_COMPLETIONS_API_KEYWARGS - ) + mock_requests.assert_called_once() mock_logger.info.assert_has_calls( [ mock.call( @@ -237,12 +158,12 @@ def test_get_keywords_to_prose(self, mock_get_query_keywords, mock_create, mock_ assert result == 'I am a prose' @patch('enterprise_catalog.apps.ai_curation.openai_client.LOGGER') - @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create') - def test_chat_completions_retries(self, mock_create, mock_logger): + @patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post') + def test_chat_completions_retries(self, mock_requests, mock_logger): """ Test that retries work as expected for chat_completions """ - mock_create.side_effect = APIConnectionError(request=httpx.Request("GET", "https://api.example.com")) + mock_requests.side_effect = ConnectionError() messages = [ { 'role': 'system', @@ -254,34 +175,8 @@ def test_chat_completions_retries(self, mock_create, mock_logger): with mock.patch.multiple(backoff_logger, info=mock.DEFAULT, error=mock.DEFAULT) as mock_backoff_logger: chat_completions(messages=messages) - assert mock_create.call_count == 3 + assert mock_requests.call_count == 3 assert mock_backoff_logger['info'].call_count == 2 - mock_backoff_logger['info'].assert_has_calls( - [ - mock.call( - 'Backing off %s(...) for %.1fs (%s)', - 'chat_completions', - mock.ANY, - 'openai.APIConnectionError: Connection error.' - ), - mock.call( - 'Backing off %s(...) for %.1fs (%s)', - 'chat_completions', - mock.ANY, - 'openai.APIConnectionError: Connection error.' - ) - ] - ) assert mock_backoff_logger['error'].call_count == 1 - mock_backoff_logger['error'].assert_has_calls( - [ - mock.call( - 'Giving up %s(...) after %d tries (%s)', - 'chat_completions', - 3, - 'openai.APIConnectionError: Connection error.' - ) - ] - ) assert mock_logger.exception.called mock_logger.exception.assert_has_calls([mock.call('[AI_CURATION] API Error: Prompt: [%s]', messages)]) diff --git a/enterprise_catalog/apps/ai_curation/tests/utils/test_generate_curation_utils.py b/enterprise_catalog/apps/ai_curation/tests/utils/test_generate_curation_utils.py index b2df8d991..ed6a51a80 100644 --- a/enterprise_catalog/apps/ai_curation/tests/utils/test_generate_curation_utils.py +++ b/enterprise_catalog/apps/ai_curation/tests/utils/test_generate_curation_utils.py @@ -2,7 +2,7 @@ Tests for ai_curation app utils. """ import json -from unittest.mock import MagicMock, patch +from unittest.mock import patch from django.test import TestCase @@ -95,16 +95,17 @@ def test_apply_keywords_filter(self): assert apply_keywords_filter(courses, ['java']) == [] - @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create') + @patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post') @patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.get_query_keywords') - def test_apply_tfidf_filter(self, mock_get_query_keywords, mock_create): + def test_apply_tfidf_filter(self, mock_get_query_keywords, mock_requests): """ Validate apply_tfidf_filter function. """ mock_get_query_keywords.return_value = ['keyword1', 'keyword2'] - mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(['Learn data science with python'])))] - ) + mock_requests.return_value.json.return_value = { + "role": "assistant", + "content": json.dumps(['Learn data science with python']) + } courses = [ { diff --git a/enterprise_catalog/settings/devstack.py b/enterprise_catalog/settings/devstack.py index bb7d70b3a..6cc7cdeb1 100644 --- a/enterprise_catalog/settings/devstack.py +++ b/enterprise_catalog/settings/devstack.py @@ -87,3 +87,8 @@ 'LOCATION': 'enterprise.catalog.memcached:11211', } } + +CHAT_COMPLETION_API = 'http://test.chat.ai' +CHAT_COMPLETION_API_KEY = 'test chat completion api key' +CHAT_COMPLETION_API_CONNECT_TIMEOUT = 1 +CHAT_COMPLETION_API_READ_TIMEOUT = 15 diff --git a/enterprise_catalog/settings/local.py b/enterprise_catalog/settings/local.py index 7e44697b9..afd260024 100644 --- a/enterprise_catalog/settings/local.py +++ b/enterprise_catalog/settings/local.py @@ -61,6 +61,10 @@ }) ENABLE_AUTO_AUTH = True +CHAT_COMPLETION_API = 'http://test.chat.ai' +CHAT_COMPLETION_API_KEY = 'test chat completion api key' +CHAT_COMPLETION_API_CONNECT_TIMEOUT = 1 +CHAT_COMPLETION_API_READ_TIMEOUT = 15 ##################################################################### # Lastly, see if the developer has any local overrides. diff --git a/enterprise_catalog/settings/test.py b/enterprise_catalog/settings/test.py index f39925675..18099ac1a 100644 --- a/enterprise_catalog/settings/test.py +++ b/enterprise_catalog/settings/test.py @@ -30,3 +30,8 @@ results_dir = tempfile.TemporaryDirectory() CELERY_RESULT_BACKEND = f'file://{results_dir.name}' + +CHAT_COMPLETION_API = 'http://test.chat.ai' +CHAT_COMPLETION_API_KEY = 'test chat completion api key' +CHAT_COMPLETION_API_CONNECT_TIMEOUT = 1 +CHAT_COMPLETION_API_READ_TIMEOUT = 15