diff --git a/tests/integration_tests/test_embedding_openai.py b/tests/integration_tests/test_embedding_openai.py index eade2349c..0fffd8922 100644 --- a/tests/integration_tests/test_embedding_openai.py +++ b/tests/integration_tests/test_embedding_openai.py @@ -1,10 +1,13 @@ import os -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest +from openai.version import VERSION from guardrails.embedding import OpenAIEmbedding +OPENAI_VERSION = VERSION + class MockOpenAIEmbedding: def __init__( @@ -29,12 +32,6 @@ def json(self): return {"data": self.data} -@pytest.fixture -def mock_openai_embedding(monkeypatch): - monkeypatch.setattr("openai.resources.Embeddings.create", MockOpenAIEmbedding()) - return MockOpenAIEmbedding - - @pytest.mark.skipif( os.environ.get("OPENAI_API_KEY") in [None, "mocked"], reason="openai api key not set", @@ -51,18 +48,32 @@ def test_embedding_query(self): result = e.embed_query("foo") assert len(result) == 1536 - def test_embed_query(self, mock_openai_embedding): + def test_embed_query(self, mocker): + mock_create = None + if OPENAI_VERSION.startswith("0"): + mock_create = mocker.patch("openai.Embedding.create") + else: + mock_create = mocker.patch("openai.resources.Embeddings.create") + + mock_create.return_value = MockOpenAIEmbedding() + instance = OpenAIEmbedding() instance._get_embedding = Mock(return_value=[[1.0, 2.0, 3.0]]) result = instance.embed_query("test query") assert result == [1.0, 2.0, 3.0] - @patch("os.environ.get", return_value="test_api_key") - @patch( - "openai.resources.Embeddings.create", - return_value=MockResponse(data=[[1.0, 2.0, 3.0]]), - ) - def test__get_embedding(self, mock_create, mock_get_env): + def test__get_embedding(self, mocker): + mock_environ = mocker.patch("os.environ.get") + mock_environ.return_value = "test_api_key" + + mock_create = None + if OPENAI_VERSION.startswith("0"): + mock_create = mocker.patch("openai.Embedding.create") + else: + mock_create = mocker.patch("openai.resources.Embeddings.create") + + mock_create.return_value = MockResponse(data=[[1.0, 2.0, 3.0]]) + instance = OpenAIEmbedding(api_key="test_api_key") result = instance._get_embedding(["test text"]) assert result == [[1.0, 2.0, 3.0]]