Skip to content

Commit

Permalink
Fix OpenAI Embedding Mocks (#446)
Browse files Browse the repository at this point in the history
* mock openai embeddings based on version

* lint fix

* remove logs
  • Loading branch information
CalebCourier authored Nov 21, 2023
1 parent 7f34e2d commit ce930f0
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions tests/integration_tests/test_embedding_openai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
from unittest.mock import Mock, patch
from unittest.mock import Mock

import pytest
from openai.version import VERSION

from guardrails.embedding import OpenAIEmbedding

OPENAI_VERSION = VERSION


class MockOpenAIEmbedding:
def __init__(
Expand All @@ -29,12 +32,6 @@ def json(self):
return {"data": self.data}


@pytest.fixture
def mock_openai_embedding(monkeypatch):
monkeypatch.setattr("openai.resources.Embeddings.create", MockOpenAIEmbedding())
return MockOpenAIEmbedding


@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") in [None, "mocked"],
reason="openai api key not set",
Expand All @@ -51,18 +48,32 @@ def test_embedding_query(self):
result = e.embed_query("foo")
assert len(result) == 1536

def test_embed_query(self, mock_openai_embedding):
def test_embed_query(self, mocker):
mock_create = None
if OPENAI_VERSION.startswith("0"):
mock_create = mocker.patch("openai.Embedding.create")
else:
mock_create = mocker.patch("openai.resources.Embeddings.create")

mock_create.return_value = MockOpenAIEmbedding()

instance = OpenAIEmbedding()
instance._get_embedding = Mock(return_value=[[1.0, 2.0, 3.0]])
result = instance.embed_query("test query")
assert result == [1.0, 2.0, 3.0]

@patch("os.environ.get", return_value="test_api_key")
@patch(
"openai.resources.Embeddings.create",
return_value=MockResponse(data=[[1.0, 2.0, 3.0]]),
)
def test__get_embedding(self, mock_create, mock_get_env):
def test__get_embedding(self, mocker):
mock_environ = mocker.patch("os.environ.get")
mock_environ.return_value = "test_api_key"

mock_create = None
if OPENAI_VERSION.startswith("0"):
mock_create = mocker.patch("openai.Embedding.create")
else:
mock_create = mocker.patch("openai.resources.Embeddings.create")

mock_create.return_value = MockResponse(data=[[1.0, 2.0, 3.0]])

instance = OpenAIEmbedding(api_key="test_api_key")
result = instance._get_embedding(["test text"])
assert result == [[1.0, 2.0, 3.0]]
Expand Down

0 comments on commit ce930f0

Please sign in to comment.