Skip to content

Commit

Permalink
Full test coverage for setup.py and OpenSearchNeuralSearch
Browse files Browse the repository at this point in the history
  • Loading branch information
bmquinn committed Dec 18, 2024
1 parent 883fb67 commit c6a848d
Show file tree
Hide file tree
Showing 2 changed files with 272 additions and 28 deletions.
152 changes: 152 additions & 0 deletions chat/test/core/test_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import unittest
from unittest.mock import patch, MagicMock
import os
from opensearchpy import RequestsHttpConnection
import sys

sys.path.append('./src')

from core.setup import chat_model, checkpoint_saver, prefix, opensearch_endpoint, opensearch_client, opensearch_vector_store, websocket_client


class TestChatModel(unittest.TestCase):
def test_chat_model_returns_bedrock_instance(self):
kwargs = {"model_id": "test_model"}
with patch("core.setup.ChatBedrock") as mock_bedrock:
result = chat_model(**kwargs)
mock_bedrock.assert_called_once_with(**kwargs)
self.assertEqual(result, mock_bedrock.return_value)

class TestCheckpointSaver(unittest.TestCase):
@patch.dict(os.environ, {"CHECKPOINT_BUCKET_NAME": "test-bucket"})
@patch("core.setup.S3Checkpointer")
def test_checkpoint_saver_initialization(self, mock_s3_checkpointer):
kwargs = {"prefix": "test"}
result = checkpoint_saver(**kwargs)

mock_s3_checkpointer.assert_called_once_with(
bucket_name="test-bucket",
**kwargs
)
self.assertEqual(result, mock_s3_checkpointer.return_value)

class TestPrefix(unittest.TestCase):
def test_prefix_with_env_prefix(self):
with patch.dict(os.environ, {"ENV_PREFIX": "dev"}):
result = prefix("test")
self.assertEqual(result, "dev-test")

def test_prefix_without_env_prefix(self):
with patch.dict(os.environ, {"ENV_PREFIX": ""}):
result = prefix("test")
self.assertEqual(result, "test")

def test_prefix_with_none_env_prefix(self):
with patch.dict(os.environ, clear=True):
result = prefix("test")
self.assertEqual(result, "test")

class TestOpenSearchEndpoint(unittest.TestCase):
def test_opensearch_endpoint_with_full_url(self):
with patch.dict(os.environ, {"OPENSEARCH_ENDPOINT": "https://test.amazonaws.com"}):
result = opensearch_endpoint()
self.assertEqual(result, "test.amazonaws.com")

def test_opensearch_endpoint_with_hostname(self):
with patch.dict(os.environ, {"OPENSEARCH_ENDPOINT": "test.amazonaws.com"}):
result = opensearch_endpoint()
self.assertEqual(result, "test.amazonaws.com")

class TestOpenSearchClient(unittest.TestCase):
@patch("core.setup.boto3.Session")
@patch("core.setup.AWS4Auth")
@patch("core.setup.OpenSearch")
def test_opensearch_client_initialization(self, mock_opensearch, mock_aws4auth, mock_session):
# Setup mock credentials
mock_credentials = MagicMock()
mock_session.return_value.get_credentials.return_value = mock_credentials

with patch.dict(os.environ, {
"AWS_REGION": "us-west-2",
"OPENSEARCH_ENDPOINT": "test.amazonaws.com"
}):
_result = opensearch_client()

# Verify AWS4Auth initialization
mock_aws4auth.assert_called_once_with(
region="us-west-2",
service="es",
refreshable_credentials=mock_credentials
)

# Verify OpenSearch initialization
mock_opensearch.assert_called_once_with(
hosts=[{"host": "test.amazonaws.com", "port": 443}],
use_ssl=True,
connection_class=RequestsHttpConnection,
http_auth=mock_aws4auth.return_value
)

class TestOpenSearchVectorStore(unittest.TestCase):
@patch("core.setup.boto3.Session")
@patch("core.setup.AWS4Auth")
@patch("core.setup.OpenSearchNeuralSearch")
def test_opensearch_vector_store_initialization(self, mock_neural_search, mock_aws4auth, mock_session):
# Setup mock credentials
mock_credentials = MagicMock()
mock_session.return_value.get_credentials.return_value = mock_credentials

with patch.dict(os.environ, {
"AWS_REGION": "us-west-2",
"OPENSEARCH_ENDPOINT": "test.amazonaws.com",
"OPENSEARCH_MODEL_ID": "test-model",
"ENV_PREFIX": "dev"
}):
_result = opensearch_vector_store()

# Verify AWS4Auth initialization
mock_aws4auth.assert_called_once_with(
region="us-west-2",
service="es",
refreshable_credentials=mock_credentials
)

# Verify OpenSearchNeuralSearch initialization
mock_neural_search.assert_called_once_with(
index="dev-dc-v2-work",
model_id="test-model",
endpoint="test.amazonaws.com",
connection_class=RequestsHttpConnection,
http_auth=mock_aws4auth.return_value,
text_field="id"
)

class TestWebsocketClient(unittest.TestCase):
@patch("core.setup.boto3.client")
def test_websocket_client_with_provided_endpoint(self, mock_boto3_client):
endpoint_url = "https://test-ws.amazonaws.com"
result = websocket_client(endpoint_url)

mock_boto3_client.assert_called_once_with(
"apigatewaymanagementapi",
endpoint_url=endpoint_url
)
self.assertEqual(result, mock_boto3_client.return_value)

@patch("core.setup.boto3.client")
def test_websocket_client_with_env_endpoint(self, mock_boto3_client):
with patch.dict(os.environ, {"APIGATEWAY_URL": "https://test-ws-env.amazonaws.com"}):
result = websocket_client(None)

mock_boto3_client.assert_called_once_with(
"apigatewaymanagementapi",
endpoint_url="https://test-ws-env.amazonaws.com"
)
self.assertEqual(result, mock_boto3_client.return_value)

@patch("core.setup.boto3.client")
def test_websocket_client_error_handling(self, mock_boto3_client):
mock_boto3_client.side_effect = Exception("Connection error")

with self.assertRaises(Exception):
websocket_client("https://test-ws.amazonaws.com")
148 changes: 120 additions & 28 deletions chat/test/search/test_opensearch_neural_search.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,135 @@
# ruff: noqa: E402
import sys
from unittest import TestCase
from unittest.mock import Mock, patch
from opensearchpy import ConnectionError, AuthenticationException, NotFoundError
sys.path.append('./src')

from unittest import TestCase
from search.opensearch_neural_search import OpenSearchNeuralSearch
from langchain_core.documents import Document

class MockClient():
def search(self, index, body, params):
return {
"hits": {
"hits": [
{
"_source": {
"id": "test"
},
"_score": 0.12345
}
]
}
"hits": {
"hits": [
{
"_source": {
"id": "test"
},
"_score": 0.12345
}
]
}
}

class MockErrorClient():
def search(self, index, body, params):
raise ConnectionError("Failed to connect to OpenSearch")

class TestOpenSearchNeuralSearch(TestCase):
def test_similarity_search(self):
docs = OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").similarity_search(query="test", subquery={"_source": {"excludes": ["embedding"]}}, size=10)
self.assertEqual(docs, [Document(page_content='test', metadata={'id': 'test'})])
def setUp(self):
self.search = OpenSearchNeuralSearch(
client=MockClient(),
endpoint="test",
index="test",
model_id="test"
)

self.error_search = OpenSearchNeuralSearch(
client=MockErrorClient(),
endpoint="test",
index="test",
model_id="test"
)

def test_similarity_search(self):
docs = self.search.similarity_search(
query="test",
subquery={"_source": {"excludes": ["embedding"]}},
size=10
)
self.assertEqual(
docs,
[Document(page_content='test', metadata={'id': 'test'})]
)

def test_similarity_search_connection_error(self):
with self.assertRaises(ConnectionError):
self.error_search.similarity_search(query="test")

@patch('opensearchpy.OpenSearch')
def test_similarity_search_auth_error(self, mock_opensearch):
mock_opensearch.return_value.search.side_effect = AuthenticationException(
"Authentication failed"
)
search = OpenSearchNeuralSearch(
client=mock_opensearch.return_value,
endpoint="test",
index="test",
model_id="test"
)
with self.assertRaises(AuthenticationException):
search.similarity_search(query="test")

def test_similarity_search_with_score(self):
docs = OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").similarity_search_with_score(query="test")
self.assertEqual(docs, [(Document(page_content='test', metadata={'id': 'test'}), 0.12345)])

def test_add_texts(self):
try:
OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").add_texts(texts=["test"], metadatas=[{"id": "test"}])
except Exception as e:
self.fail(f"from_texts raised an exception: {e}")

def test_from_texts(self):
try:
OpenSearchNeuralSearch.from_texts(clas="test", texts=["test"], metadatas=[{"id": "test"}])
except Exception as e:
self.fail(f"from_texts raised an exception: {e}")
docs = self.search.similarity_search_with_score(query="test")
self.assertEqual(
docs,
[(Document(page_content='test', metadata={'id': 'test'}), 0.12345)]
)

def test_similarity_search_with_score_connection_error(self):
with self.assertRaises(ConnectionError):
self.error_search.similarity_search_with_score(query="test")

@patch('opensearchpy.OpenSearch')
def test_aggregations_search_index_not_found(self, mock_opensearch):
mock_opensearch.return_value.search.side_effect = NotFoundError(
404,
"index_not_found_exception",
{"error": "index not found"}
)
search = OpenSearchNeuralSearch(
client=mock_opensearch.return_value,
endpoint="test",
index="test",
model_id="test"
)
with self.assertRaises(NotFoundError):
search.aggregations_search(agg_field="test_field")

def test_aggregations_search_connection_error(self):
with self.assertRaises(ConnectionError):
self.error_search.aggregations_search(agg_field="test_field")

def test_add_texts_exception(self):
# Test to ensure the exception handler works
with self.assertRaises(AssertionError) as context:
search = self.search
search.add_texts = Mock(side_effect=Exception("Test exception"))
try:
search.add_texts(texts=["test"], metadatas=[{"id": "test"}])
except Exception as e:
self.fail(f"add_texts raised an exception: {e}")

self.assertTrue("add_texts raised an exception: Test exception" in str(context.exception))

def test_from_texts_exception(self):
with self.assertRaises(AssertionError) as context:
OpenSearchNeuralSearch.from_texts = Mock(side_effect=Exception("Test exception"))
try:
OpenSearchNeuralSearch.from_texts(texts=["test"], metadatas=[{"id": "test"}])
except Exception as e:
self.fail(f"from_texts raised an exception: {e}")

self.assertTrue("from_texts raised an exception: Test exception" in str(context.exception))

def test_client_initialization_error(self):
with self.assertRaises(ValueError):
OpenSearchNeuralSearch(
endpoint="", # Empty endpoint should raise ValueError
index="test",
model_id="test",
client=None
)

0 comments on commit c6a848d

Please sign in to comment.