Merge pull request #870 from openedx/iahmad/ENT-8980
feat: Replaced ai chat client for curation
irfanuddinahmad authored Jul 19, 2024
2 parents d16d72d + c92c9ca commit 0e5595d
Showing 6 changed files with 79 additions and 186 deletions.
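The change removes the direct OpenAI SDK client and instead POSTs to the endpoint configured as CHAT_COMPLETION_API. As a rough sketch of the contract the new code assumes (field names are taken from the diff below; the prompt text and response values are illustrative placeholders):

# Request body built by the rewritten chat_completions (see openai_client.py below).
request_body = {
    'message_list': [
        {'role': 'assistant', 'content': 'I am a prompt'},  # placeholder prompt
    ],
}

# Response body the client expects: it calls response.json() and parses the
# JSON-encoded string stored under the 'content' key.
expected_response_body = {
    'role': 'assistant',
    'content': '["subject1", "subject2"]',  # placeholder values
}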
81 changes: 32 additions & 49 deletions enterprise_catalog/apps/ai_curation/openai_client.py
@@ -1,17 +1,11 @@
import functools
import json
import logging

import backoff
import simplejson
import requests
from django.conf import settings
from openai import (
APIConnectionError,
APIError,
APITimeoutError,
InternalServerError,
OpenAI,
RateLimitError,
)
from requests.exceptions import ConnectTimeout

from enterprise_catalog.apps.ai_curation.errors import (
AICurationError,
@@ -21,8 +15,6 @@

LOGGER = logging.getLogger(__name__)

client = OpenAI(api_key=settings.OPENAI_API_KEY)


def api_error_handler(func):
"""
@@ -37,7 +29,7 @@ def api_error_handler(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except (APIError, AICurationError) as ex:
except (ConnectionError, AICurationError) as ex:
LOGGER.exception('[AI_CURATION] API Error: Prompt: [%s]', kwargs.get('messages'))
# status_code attribute is not available for all exceptions, such as APIConnectionError and APITimeoutError
status_code = getattr(ex, 'status_code', None)
@@ -49,27 +41,19 @@ def wrapper(*args, **kwargs):
@api_error_handler
@backoff.on_exception(
backoff.expo,
(APIConnectionError, APITimeoutError, InternalServerError, RateLimitError, InvalidJSONResponseError),
(ConnectTimeout, ConnectionError, InvalidJSONResponseError),
max_tries=3,
)
def chat_completions(
messages,
response_format='json',
response_type=list,
model="gpt-4",
temperature=0.3,
max_tokens=500,
):
"""
Get a response from the chat.completions endpoint
Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting.
Args:
messages (list): List of messages to send to the chat.completions endpoint
response_format (str): Format of the response. Can be 'json' or 'text'
response_type (any): Expected type of the response. For now we only expect `list`
model (str): Model to use for the completion
temperature (number): Make model output more focused and deterministic
max_tokens (int): Maximum number of tokens that can be generated in the chat completion
Returns:
<list, text>: The response from the chat.completions endpoint
@@ -81,32 +65,31 @@ def chat_completions(
- status_code (int): The actual error code returned by the API
"""
LOGGER.info('[AI_CURATION] [CHAT_COMPLETIONS] Prompt: [%s]', messages)
response = client.chat.completions.create(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,

headers = {'Content-Type': 'application/json', 'x-api-key': settings.CHAT_COMPLETION_API_KEY}
message_list = []
for message in messages:
message_list.append({'role': 'assistant', 'content': message['content']})
body = {'message_list': message_list}
response = requests.post(
settings.CHAT_COMPLETION_API,
headers=headers,
data=json.dumps(body),
timeout=(
settings.CHAT_COMPLETION_API_CONNECT_TIMEOUT,
settings.CHAT_COMPLETION_API_READ_TIMEOUT
)
)
LOGGER.info('[AI_CURATION] [CHAT_COMPLETIONS] Response: [%s]', response)
response_content = response.choices[0].message.content

if response_format == 'json':
try:
json_response = simplejson.loads(response_content)
if isinstance(json_response, response_type):
return json_response
LOGGER.error(
'[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]',
messages,
response
)
raise InvalidJSONResponseError('Invalid response type received from chatgpt')
except simplejson.errors.JSONDecodeError as ex:
LOGGER.error(
'[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]',
messages,
response
)
raise InvalidJSONResponseError('Invalid JSON response received from chatgpt') from ex

return response_content
try:
response_content = response.json().get('content')
if response_format == 'json':
return json.loads(response_content)
return json.loads(response_content)[0]
except requests.exceptions.JSONDecodeError as ex:
LOGGER.error(
'[AI_CURATION] Invalid JSON response received: Prompt: [%s], Response: [%s]',
messages,
response
)
raise InvalidJSONResponseError('Invalid response received.') from ex
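For reviewers, a minimal usage sketch of the rewritten helper. The prompt below is a placeholder, and it assumes the Django settings referenced above (CHAT_COMPLETION_API, CHAT_COMPLETION_API_KEY and the two timeout settings) are configured:

from enterprise_catalog.apps.ai_curation.openai_client import chat_completions

messages = [{'role': 'system', 'content': 'Return a JSON list of subjects.'}]

# With response_format left at 'json' (as in the signature above) the parsed
# JSON payload is returned; any other value yields only the first element of
# the parsed list.
subjects = chat_completions(messages=messages)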
157 changes: 26 additions & 131 deletions enterprise_catalog/apps/ai_curation/tests/test_utils.py
@@ -4,12 +4,10 @@
import json
import logging
from unittest import mock
from unittest.mock import MagicMock, patch
from unittest.mock import patch

import httpx
from django.conf import settings
from django.test import TestCase
from openai import APIConnectionError

from enterprise_catalog.apps.ai_curation.errors import AICurationError
from enterprise_catalog.apps.ai_curation.utils.algolia_utils import (
@@ -75,23 +73,22 @@ def test_fetch_catalog_metadata_from_algolia(self, mock_algolia_client):

class TestChatCompletionUtils(TestCase):
@patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.LOGGER')
@patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
def test_get_filtered_subjects(self, mock_create, mock_logger):
@patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post')
def test_get_filtered_subjects(self, mock_requests, mock_logger):
"""
Test that get_filtered_subjects returns the correct filtered subjects
"""
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content=json.dumps(['subject1', 'subject2'])))]
)
mock_requests.return_value.json.return_value = {
"role": "assistant",
"content": json.dumps(['subject1', 'subject2'])
}
subjects = ['subject1', 'subject2', 'subject3', 'subject4']
query = 'test query'
expected_content = settings.AI_CURATION_FILTER_SUBJECTS_PROMPT.format(query=query, subjects=subjects)

result = get_filtered_subjects(query, subjects)

mock_create.assert_called_once_with(
messages=[{'role': 'system', 'content': expected_content}], **CHAT_COMPLETIONS_API_KEYWARGS
)
mock_requests.assert_called_once()
mock_logger.info.assert_has_calls(
[
mock.call(
@@ -103,97 +100,22 @@ def test_get_filtered_subjects(self, mock_create, mock_logger):
)
assert result == ['subject1', 'subject2']

@patch('enterprise_catalog.apps.ai_curation.openai_client.LOGGER')
@patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
def test_invalid_json(self, mock_create, mock_logger):
"""
Test that correct exception is raised if chat.completions.create send an invalid json
"""
mock_create.return_value = MagicMock(choices=[MagicMock(message=MagicMock(content='non json response'))])

messages = [
{
'role': 'system',
'content': 'I am a prompt'
}
]
with self.assertRaises(AICurationError):
chat_completions(messages)

assert mock_create.call_count == 3
assert mock_logger.error.called
mock_logger.error.assert_has_calls([
mock.call(
'[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]',
[{'role': 'system', 'content': 'I am a prompt'}],
mock.ANY
),
mock.call(
'[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]',
[{'role': 'system', 'content': 'I am a prompt'}],
mock.ANY
),
mock.call(
'[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]',
[{'role': 'system', 'content': 'I am a prompt'}],
mock.ANY
)
])

@patch('enterprise_catalog.apps.ai_curation.openai_client.LOGGER')
@patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
def test_valid_json_with_wrong_type(self, mock_create, mock_logger):
"""
Test that correct exception is raised if chat.completions.create send a valid json but wrong type
"""
mock_create.return_value = MagicMock(choices=[MagicMock(message=MagicMock(content='{"a": 1}'))])

messages = [
{
'role': 'system',
'content': 'I am a prompt'
}
]
with self.assertRaises(AICurationError):
chat_completions(messages)

assert mock_create.call_count == 3
assert mock_logger.error.called
mock_logger.error.assert_has_calls([
mock.call(
'[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]',
[{'role': 'system', 'content': 'I am a prompt'}],
mock.ANY
),
mock.call(
'[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]',
[{'role': 'system', 'content': 'I am a prompt'}],
mock.ANY
),
mock.call(
'[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]',
[{'role': 'system', 'content': 'I am a prompt'}],
mock.ANY
)
])

@patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.LOGGER')
@patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
def test_get_query_keywords(self, mock_create, mock_logger):
@patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post')
def test_get_query_keywords(self, mock_requests, mock_logger):
"""
Test that get_query_keywords returns the correct keywords
"""
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content=json.dumps(['keyword1', 'keyword2'])))]
)
mock_requests.return_value.json.return_value = {
"role": "assistant",
"content": json.dumps(['keyword1', 'keyword2'])
}
query = 'test query'
expected_content = settings.AI_CURATION_QUERY_TO_KEYWORDS_PROMPT.format(query=query)

result = get_query_keywords(query)

mock_create.assert_called_once_with(
messages=[{'role': 'system', 'content': expected_content}], **CHAT_COMPLETIONS_API_KEYWARGS
)
mock_requests.assert_called_once()
mock_logger.info.assert_has_calls(
[
mock.call(
@@ -206,25 +128,24 @@ def test_get_query_keywords(self, mock_create, mock_logger):
assert result == ['keyword1', 'keyword2']

@patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.LOGGER')
@patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
@patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post')
@patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.get_query_keywords')
def test_get_keywords_to_prose(self, mock_get_query_keywords, mock_create, mock_logger):
def test_get_keywords_to_prose(self, mock_get_query_keywords, mock_requests, mock_logger):
"""
Test that get_keywords_to_prose returns the correct prose
"""
mock_get_query_keywords.return_value = ['keyword1', 'keyword2']
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content=json.dumps(['I am a prose'])))]
)
mock_requests.return_value.json.return_value = {
"role": "assistant",
"content": json.dumps(['I am a prose'])
}
query = 'test query'
keywords = ['keyword1', 'keyword2']
expected_content = settings.AI_CURATION_KEYWORDS_TO_PROSE_PROMPT.format(query=query, keywords=keywords)

result = get_keywords_to_prose(query)

mock_create.assert_called_once_with(
messages=[{'role': 'system', 'content': expected_content}], **CHAT_COMPLETIONS_API_KEYWARGS
)
mock_requests.assert_called_once()
mock_logger.info.assert_has_calls(
[
mock.call(
@@ -237,12 +158,12 @@ def test_get_keywords_to_prose(self, mock_get_query_keywords, mock_create, mock_
assert result == 'I am a prose'

@patch('enterprise_catalog.apps.ai_curation.openai_client.LOGGER')
@patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
def test_chat_completions_retries(self, mock_create, mock_logger):
@patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post')
def test_chat_completions_retries(self, mock_requests, mock_logger):
"""
Test that retries work as expected for chat_completions
"""
mock_create.side_effect = APIConnectionError(request=httpx.Request("GET", "https://api.example.com"))
mock_requests.side_effect = ConnectionError()
messages = [
{
'role': 'system',
@@ -254,34 +175,8 @@ def test_chat_completions_retries(self, mock_create, mock_logger):
with mock.patch.multiple(backoff_logger, info=mock.DEFAULT, error=mock.DEFAULT) as mock_backoff_logger:
chat_completions(messages=messages)

assert mock_create.call_count == 3
assert mock_requests.call_count == 3
assert mock_backoff_logger['info'].call_count == 2
mock_backoff_logger['info'].assert_has_calls(
[
mock.call(
'Backing off %s(...) for %.1fs (%s)',
'chat_completions',
mock.ANY,
'openai.APIConnectionError: Connection error.'
),
mock.call(
'Backing off %s(...) for %.1fs (%s)',
'chat_completions',
mock.ANY,
'openai.APIConnectionError: Connection error.'
)
]
)
assert mock_backoff_logger['error'].call_count == 1
mock_backoff_logger['error'].assert_has_calls(
[
mock.call(
'Giving up %s(...) after %d tries (%s)',
'chat_completions',
3,
'openai.APIConnectionError: Connection error.'
)
]
)
assert mock_logger.exception.called
mock_logger.exception.assert_has_calls([mock.call('[AI_CURATION] API Error: Prompt: [%s]', messages)])
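The same stubbing pattern recurs across the updated tests; a condensed sketch (the patched path comes from this diff, the stubbed values are illustrative):

import json
from unittest import mock

with mock.patch(
    'enterprise_catalog.apps.ai_curation.openai_client.requests.post'
) as mock_post:
    mock_post.return_value.json.return_value = {
        'role': 'assistant',
        'content': json.dumps(['keyword1', 'keyword2']),
    }
    # ... invoke the code under test; it parses the stubbed 'content' field.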
@@ -2,7 +2,7 @@
Tests for ai_curation app utils.
"""
import json
from unittest.mock import MagicMock, patch
from unittest.mock import patch

from django.test import TestCase

@@ -95,16 +95,17 @@ def test_apply_keywords_filter(self):

assert apply_keywords_filter(courses, ['java']) == []

@patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
@patch('enterprise_catalog.apps.ai_curation.openai_client.requests.post')
@patch('enterprise_catalog.apps.ai_curation.utils.open_ai_utils.get_query_keywords')
def test_apply_tfidf_filter(self, mock_get_query_keywords, mock_create):
def test_apply_tfidf_filter(self, mock_get_query_keywords, mock_requests):
"""
Validate apply_tfidf_filter function.
"""
mock_get_query_keywords.return_value = ['keyword1', 'keyword2']
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content=json.dumps(['Learn data science with python'])))]
)
mock_requests.return_value.json.return_value = {
"role": "assistant",
"content": json.dumps(['Learn data science with python'])
}

courses = [
{