-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Full test coverage for setup.py and OpenSearchNeuralSearch
- Loading branch information
Showing
2 changed files
with
272 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import unittest | ||
from unittest.mock import patch, MagicMock | ||
import os | ||
from opensearchpy import RequestsHttpConnection | ||
import sys | ||
|
||
sys.path.append('./src') | ||
|
||
from core.setup import chat_model, checkpoint_saver, prefix, opensearch_endpoint, opensearch_client, opensearch_vector_store, websocket_client | ||
|
||
|
||
class TestChatModel(unittest.TestCase): | ||
def test_chat_model_returns_bedrock_instance(self): | ||
kwargs = {"model_id": "test_model"} | ||
with patch("core.setup.ChatBedrock") as mock_bedrock: | ||
result = chat_model(**kwargs) | ||
mock_bedrock.assert_called_once_with(**kwargs) | ||
self.assertEqual(result, mock_bedrock.return_value) | ||
|
||
class TestCheckpointSaver(unittest.TestCase): | ||
@patch.dict(os.environ, {"CHECKPOINT_BUCKET_NAME": "test-bucket"}) | ||
@patch("core.setup.S3Checkpointer") | ||
def test_checkpoint_saver_initialization(self, mock_s3_checkpointer): | ||
kwargs = {"prefix": "test"} | ||
result = checkpoint_saver(**kwargs) | ||
|
||
mock_s3_checkpointer.assert_called_once_with( | ||
bucket_name="test-bucket", | ||
**kwargs | ||
) | ||
self.assertEqual(result, mock_s3_checkpointer.return_value) | ||
|
||
class TestPrefix(unittest.TestCase): | ||
def test_prefix_with_env_prefix(self): | ||
with patch.dict(os.environ, {"ENV_PREFIX": "dev"}): | ||
result = prefix("test") | ||
self.assertEqual(result, "dev-test") | ||
|
||
def test_prefix_without_env_prefix(self): | ||
with patch.dict(os.environ, {"ENV_PREFIX": ""}): | ||
result = prefix("test") | ||
self.assertEqual(result, "test") | ||
|
||
def test_prefix_with_none_env_prefix(self): | ||
with patch.dict(os.environ, clear=True): | ||
result = prefix("test") | ||
self.assertEqual(result, "test") | ||
|
||
class TestOpenSearchEndpoint(unittest.TestCase): | ||
def test_opensearch_endpoint_with_full_url(self): | ||
with patch.dict(os.environ, {"OPENSEARCH_ENDPOINT": "https://test.amazonaws.com"}): | ||
result = opensearch_endpoint() | ||
self.assertEqual(result, "test.amazonaws.com") | ||
|
||
def test_opensearch_endpoint_with_hostname(self): | ||
with patch.dict(os.environ, {"OPENSEARCH_ENDPOINT": "test.amazonaws.com"}): | ||
result = opensearch_endpoint() | ||
self.assertEqual(result, "test.amazonaws.com") | ||
|
||
class TestOpenSearchClient(unittest.TestCase): | ||
@patch("core.setup.boto3.Session") | ||
@patch("core.setup.AWS4Auth") | ||
@patch("core.setup.OpenSearch") | ||
def test_opensearch_client_initialization(self, mock_opensearch, mock_aws4auth, mock_session): | ||
# Setup mock credentials | ||
mock_credentials = MagicMock() | ||
mock_session.return_value.get_credentials.return_value = mock_credentials | ||
|
||
with patch.dict(os.environ, { | ||
"AWS_REGION": "us-west-2", | ||
"OPENSEARCH_ENDPOINT": "test.amazonaws.com" | ||
}): | ||
_result = opensearch_client() | ||
|
||
# Verify AWS4Auth initialization | ||
mock_aws4auth.assert_called_once_with( | ||
region="us-west-2", | ||
service="es", | ||
refreshable_credentials=mock_credentials | ||
) | ||
|
||
# Verify OpenSearch initialization | ||
mock_opensearch.assert_called_once_with( | ||
hosts=[{"host": "test.amazonaws.com", "port": 443}], | ||
use_ssl=True, | ||
connection_class=RequestsHttpConnection, | ||
http_auth=mock_aws4auth.return_value | ||
) | ||
|
||
class TestOpenSearchVectorStore(unittest.TestCase): | ||
@patch("core.setup.boto3.Session") | ||
@patch("core.setup.AWS4Auth") | ||
@patch("core.setup.OpenSearchNeuralSearch") | ||
def test_opensearch_vector_store_initialization(self, mock_neural_search, mock_aws4auth, mock_session): | ||
# Setup mock credentials | ||
mock_credentials = MagicMock() | ||
mock_session.return_value.get_credentials.return_value = mock_credentials | ||
|
||
with patch.dict(os.environ, { | ||
"AWS_REGION": "us-west-2", | ||
"OPENSEARCH_ENDPOINT": "test.amazonaws.com", | ||
"OPENSEARCH_MODEL_ID": "test-model", | ||
"ENV_PREFIX": "dev" | ||
}): | ||
_result = opensearch_vector_store() | ||
|
||
# Verify AWS4Auth initialization | ||
mock_aws4auth.assert_called_once_with( | ||
region="us-west-2", | ||
service="es", | ||
refreshable_credentials=mock_credentials | ||
) | ||
|
||
# Verify OpenSearchNeuralSearch initialization | ||
mock_neural_search.assert_called_once_with( | ||
index="dev-dc-v2-work", | ||
model_id="test-model", | ||
endpoint="test.amazonaws.com", | ||
connection_class=RequestsHttpConnection, | ||
http_auth=mock_aws4auth.return_value, | ||
text_field="id" | ||
) | ||
|
||
class TestWebsocketClient(unittest.TestCase): | ||
@patch("core.setup.boto3.client") | ||
def test_websocket_client_with_provided_endpoint(self, mock_boto3_client): | ||
endpoint_url = "https://test-ws.amazonaws.com" | ||
result = websocket_client(endpoint_url) | ||
|
||
mock_boto3_client.assert_called_once_with( | ||
"apigatewaymanagementapi", | ||
endpoint_url=endpoint_url | ||
) | ||
self.assertEqual(result, mock_boto3_client.return_value) | ||
|
||
@patch("core.setup.boto3.client") | ||
def test_websocket_client_with_env_endpoint(self, mock_boto3_client): | ||
with patch.dict(os.environ, {"APIGATEWAY_URL": "https://test-ws-env.amazonaws.com"}): | ||
result = websocket_client(None) | ||
|
||
mock_boto3_client.assert_called_once_with( | ||
"apigatewaymanagementapi", | ||
endpoint_url="https://test-ws-env.amazonaws.com" | ||
) | ||
self.assertEqual(result, mock_boto3_client.return_value) | ||
|
||
@patch("core.setup.boto3.client") | ||
def test_websocket_client_error_handling(self, mock_boto3_client): | ||
mock_boto3_client.side_effect = Exception("Connection error") | ||
|
||
with self.assertRaises(Exception): | ||
websocket_client("https://test-ws.amazonaws.com") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,135 @@ | ||
# ruff: noqa: E402 | ||
import sys | ||
from unittest import TestCase | ||
from unittest.mock import Mock, patch | ||
from opensearchpy import ConnectionError, AuthenticationException, NotFoundError | ||
sys.path.append('./src') | ||
|
||
from unittest import TestCase | ||
from search.opensearch_neural_search import OpenSearchNeuralSearch | ||
from langchain_core.documents import Document | ||
|
||
class MockClient(): | ||
def search(self, index, body, params): | ||
return { | ||
"hits": { | ||
"hits": [ | ||
{ | ||
"_source": { | ||
"id": "test" | ||
}, | ||
"_score": 0.12345 | ||
} | ||
] | ||
} | ||
"hits": { | ||
"hits": [ | ||
{ | ||
"_source": { | ||
"id": "test" | ||
}, | ||
"_score": 0.12345 | ||
} | ||
] | ||
} | ||
} | ||
|
||
class MockErrorClient(): | ||
def search(self, index, body, params): | ||
raise ConnectionError("Failed to connect to OpenSearch") | ||
|
||
class TestOpenSearchNeuralSearch(TestCase): | ||
def test_similarity_search(self): | ||
docs = OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").similarity_search(query="test", subquery={"_source": {"excludes": ["embedding"]}}, size=10) | ||
self.assertEqual(docs, [Document(page_content='test', metadata={'id': 'test'})]) | ||
def setUp(self): | ||
self.search = OpenSearchNeuralSearch( | ||
client=MockClient(), | ||
endpoint="test", | ||
index="test", | ||
model_id="test" | ||
) | ||
|
||
self.error_search = OpenSearchNeuralSearch( | ||
client=MockErrorClient(), | ||
endpoint="test", | ||
index="test", | ||
model_id="test" | ||
) | ||
|
||
def test_similarity_search(self): | ||
docs = self.search.similarity_search( | ||
query="test", | ||
subquery={"_source": {"excludes": ["embedding"]}}, | ||
size=10 | ||
) | ||
self.assertEqual( | ||
docs, | ||
[Document(page_content='test', metadata={'id': 'test'})] | ||
) | ||
|
||
def test_similarity_search_connection_error(self): | ||
with self.assertRaises(ConnectionError): | ||
self.error_search.similarity_search(query="test") | ||
|
||
@patch('opensearchpy.OpenSearch') | ||
def test_similarity_search_auth_error(self, mock_opensearch): | ||
mock_opensearch.return_value.search.side_effect = AuthenticationException( | ||
"Authentication failed" | ||
) | ||
search = OpenSearchNeuralSearch( | ||
client=mock_opensearch.return_value, | ||
endpoint="test", | ||
index="test", | ||
model_id="test" | ||
) | ||
with self.assertRaises(AuthenticationException): | ||
search.similarity_search(query="test") | ||
|
||
def test_similarity_search_with_score(self): | ||
docs = OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").similarity_search_with_score(query="test") | ||
self.assertEqual(docs, [(Document(page_content='test', metadata={'id': 'test'}), 0.12345)]) | ||
|
||
def test_add_texts(self): | ||
try: | ||
OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").add_texts(texts=["test"], metadatas=[{"id": "test"}]) | ||
except Exception as e: | ||
self.fail(f"from_texts raised an exception: {e}") | ||
|
||
def test_from_texts(self): | ||
try: | ||
OpenSearchNeuralSearch.from_texts(clas="test", texts=["test"], metadatas=[{"id": "test"}]) | ||
except Exception as e: | ||
self.fail(f"from_texts raised an exception: {e}") | ||
docs = self.search.similarity_search_with_score(query="test") | ||
self.assertEqual( | ||
docs, | ||
[(Document(page_content='test', metadata={'id': 'test'}), 0.12345)] | ||
) | ||
|
||
def test_similarity_search_with_score_connection_error(self): | ||
with self.assertRaises(ConnectionError): | ||
self.error_search.similarity_search_with_score(query="test") | ||
|
||
@patch('opensearchpy.OpenSearch') | ||
def test_aggregations_search_index_not_found(self, mock_opensearch): | ||
mock_opensearch.return_value.search.side_effect = NotFoundError( | ||
404, | ||
"index_not_found_exception", | ||
{"error": "index not found"} | ||
) | ||
search = OpenSearchNeuralSearch( | ||
client=mock_opensearch.return_value, | ||
endpoint="test", | ||
index="test", | ||
model_id="test" | ||
) | ||
with self.assertRaises(NotFoundError): | ||
search.aggregations_search(agg_field="test_field") | ||
|
||
def test_aggregations_search_connection_error(self): | ||
with self.assertRaises(ConnectionError): | ||
self.error_search.aggregations_search(agg_field="test_field") | ||
|
||
def test_add_texts_exception(self): | ||
# Test to ensure the exception handler works | ||
with self.assertRaises(AssertionError) as context: | ||
search = self.search | ||
search.add_texts = Mock(side_effect=Exception("Test exception")) | ||
try: | ||
search.add_texts(texts=["test"], metadatas=[{"id": "test"}]) | ||
except Exception as e: | ||
self.fail(f"add_texts raised an exception: {e}") | ||
|
||
self.assertTrue("add_texts raised an exception: Test exception" in str(context.exception)) | ||
|
||
def test_from_texts_exception(self): | ||
with self.assertRaises(AssertionError) as context: | ||
OpenSearchNeuralSearch.from_texts = Mock(side_effect=Exception("Test exception")) | ||
try: | ||
OpenSearchNeuralSearch.from_texts(texts=["test"], metadatas=[{"id": "test"}]) | ||
except Exception as e: | ||
self.fail(f"from_texts raised an exception: {e}") | ||
|
||
self.assertTrue("from_texts raised an exception: Test exception" in str(context.exception)) | ||
|
||
def test_client_initialization_error(self): | ||
with self.assertRaises(ValueError): | ||
OpenSearchNeuralSearch( | ||
endpoint="", # Empty endpoint should raise ValueError | ||
index="test", | ||
model_id="test", | ||
client=None | ||
) |