Skip to content

Commit

Permalink
Adapt stability_sdk (zilliztech#277)
Browse files Browse the repository at this point in the history
Signed-off-by: Jael Gu <[email protected]>
  • Loading branch information
jaelgu authored Apr 24, 2023
1 parent 2e6fa14 commit 49d18cc
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 2 deletions.
62 changes: 62 additions & 0 deletions examples/stability_examples/text_to_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import io
import time
from PIL import Image

from gptcache import cache
from gptcache.processor.pre import get_prompt
from gptcache.adapter.stability_sdk import StabilityInference, generation
from gptcache.embedding import Onnx
from gptcache.manager.factory import manager_factory
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

# init gptcache
onnx = Onnx()
data_manager = manager_factory('sqlite,faiss,local',
data_dir='./',
vector_params={'dimension': onnx.dimension},
object_params={'path': './images'}
)
cache.init(
pre_embedding_func=get_prompt,
embedding_func=onnx.to_embeddings,
data_manager=data_manager,
similarity_evaluation=SearchDistanceEvaluation()
)

# run with gptcache
api_key = os.getenv('STABILITY_KEY', 'key-goes-here')

stability_api = StabilityInference(
key=os.environ['STABILITY_KEY'], # API Key reference.
verbose=False, # Print debug messages.
engine='stable-diffusion-xl-beta-v2-2-2', # Set the engine to use for generation.
)

start = time.time()
answers = stability_api.generate(
prompt='a cat sitting besides a dog',
width=256,
height=256
)

for resp in answers:
for artifact in resp.artifacts:
if artifact.type == generation.ARTIFACT_IMAGE:
img = Image.open(io.BytesIO(artifact.binary))
assert img.size == (256, 256)
print('Time elapsed 1:', time.time() - start)

start = time.time()
answers = stability_api.generate(
prompt='a dog and a dog sitting together',
width=512,
height=512
)

for resp in answers:
for artifact in resp.artifacts:
if artifact.type == generation.ARTIFACT_IMAGE:
img = Image.open(io.BytesIO(artifact.binary))
assert img.size == (512, 512)
print('Time elapsed 2:', time.time() - start)
124 changes: 124 additions & 0 deletions gptcache/adapter/stability_sdk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from io import BytesIO
import base64
import warnings
from dataclasses import dataclass
from typing import List

from gptcache.adapter.adapter import adapt
from gptcache.manager.scalar_data.base import Answer, DataType
from gptcache.utils.error import CacheError
from gptcache.utils import (
import_stability, import_pillow
)

import_pillow()
import_stability()

from PIL import Image as PILImage # pylint: disable=C0413
from stability_sdk import client # pylint: disable=C0413
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation # pylint: disable=C0413



class StabilityInference(client.StabilityInference):
"""client.StabilityInference Wrapper
Example:
.. code-block:: python
import os
import io
from PIL import Image
from gptcache import cache
from gptcache.processor.pre import get_prompt
from gptcache.adapter.stability_sdk import StabilityInference, generation
# init gptcache
cache.init(pre_embedding_func=get_prompt)
# run with gptcache
os.environ['STABILITY_KEY'] = 'key-goes-here'
stability_api = StabilityInference(
key=os.environ['STABILITY_KEY'], # API Key reference.
verbose=False, # Print debug messages.
engine="stable-diffusion-xl-beta-v2-2-2", # Set the engine to use for generation.
)
answers = stability_api.generate(
prompt="a cat sitting besides a dog",
width=256,
height=256
)
for resp in answers:
for artifact in resp.artifacts:
if artifact.type == generation.ARTIFACT_IMAGE:
img = Image.open(io.BytesIO(artifact.binary))
img.save('path/to/save/image.png')
"""

def llm_handler(self, *llm_args, **llm_kwargs):
try:
return super().generate(*llm_args, **llm_kwargs)
except Exception as e:
raise CacheError("stability error") from e

def generate(self, *args, **kwargs):
width = kwargs.get("width", 512)
height = kwargs.get("height", 512)

def cache_data_convert(cache_data):
return construct_resp_from_cache(cache_data, width=width, height=height)

def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
def hook_stream_data(it):
to_save = []
for resp in it:
for artifact in resp.artifacts:
try:
if artifact.finish_reason == generation.FILTER:
warnings.warn(
"Your request activated the API's safety filters and could not be processed."
"Please modify the prompt and try again.")
continue
except AttributeError:
pass
if artifact.type == generation.ARTIFACT_IMAGE:
img_b64 = base64.b64encode(artifact.binary)
to_save.append(img_b64)
yield resp
update_cache_func(Answer(to_save[0], DataType.IMAGE_BASE64))

return hook_stream_data(llm_data)

return adapt(
self.llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
)


def construct_resp_from_cache(img_64, height, width):
img_bytes = base64.b64decode((img_64))
img_file = BytesIO(img_bytes)
img = PILImage.open(img_file)
new_size = (width, height)
if new_size != img.size:
img = img.resize(new_size)
buffered = BytesIO()
img.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
yield MockAnswer(artifacts=[MockArtifact(type=generation.ARTIFACT_IMAGE, binary=img_bytes)])


@dataclass
class MockArtifact:
type: int
binary: bytes


@dataclass
class MockAnswer:
artifacts: List[MockArtifact]


4 changes: 4 additions & 0 deletions gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"import_torchvision",
"import_timm",
"import_vit",
"import_stability"
]

import importlib.util
Expand Down Expand Up @@ -155,3 +156,6 @@ def import_vit():

def import_replicate():
_check_library("replicate")

def import_stability():
_check_library("stability_sdk", package="stability-sdk")
2 changes: 0 additions & 2 deletions tests/unit_tests/adapter/test_diffusers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from unittest.mock import patch

from gptcache.adapter import diffusers as cache_diffusers
from gptcache import cache
from gptcache.processor.pre import get_prompt
Expand Down
121 changes: 121 additions & 0 deletions tests/unit_tests/adapter/test_stability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from unittest.mock import patch
import base64
from io import BytesIO
import os
import numpy as np


from gptcache.adapter import stability_sdk as cache_stability
from gptcache.adapter.stability_sdk import generation, construct_resp_from_cache
from gptcache import cache
from gptcache.processor.pre import get_prompt
from gptcache.manager.factory import manager_factory
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

from gptcache.utils import (
import_stability, import_pillow
)

import_pillow()
import_stability()

import stability_sdk
from PIL import ImageChops, Image


def test_stability_inference_map():
cache.init(pre_embedding_func=get_prompt)
expected_img = Image.new("RGB", (1, 1))

buffered = BytesIO()
expected_img.save(buffered, format="PNG")
test_img_b64 = base64.b64encode(buffered.getvalue())
expected_response = construct_resp_from_cache(test_img_b64, 1, 1)

stability_api = cache_stability.StabilityInference(key="ThisIsTest")
with patch.object(stability_sdk.client.StabilityInference, "generate") as mock_call:
mock_call.return_value = expected_response

answer_response = stability_api.generate(prompt="Test prompt", width=1, height=1)
answers = []
for resp in answer_response:
for artifact in resp.artifacts:
if artifact.type == generation.ARTIFACT_IMAGE:
answers.append(Image.open(BytesIO(artifact.binary)))
assert len(answers) == 1, f"Expect to get 1 image but got {len(answers)}"
diff = ImageChops.difference(answers[0], expected_img)
assert not diff.getbbox()

answer_response = stability_api.generate(prompt="Test prompt", width=2, height=2)
answers = []
for resp in answer_response:
for artifact in resp.artifacts:
if artifact.type == generation.ARTIFACT_IMAGE:
img = Image.open(BytesIO(artifact.binary))
assert img.size == (2, 2), "Incorrect image size."
answers.append(img)
assert len(answers) == 1, f"Expect to get 1 image but got {len(answers)}"
diff = ImageChops.difference(answers[0], expected_img)
assert not diff.getbbox()


def test_stability_inference_faiss():
faiss_file = "faiss.index"
if os.path.isfile(faiss_file):
os.remove(faiss_file)

data_manager = manager_factory('sqlite,faiss,local',
data_dir='./',
vector_params={'dimension': 2},
object_params={'path': './images'}
)
cache.init(
pre_embedding_func=get_prompt,
embedding_func=lambda x, **_: np.random.random((2,)).astype("float32"),
data_manager=data_manager,
similarity_evaluation=SearchDistanceEvaluation()
)

expected_img = Image.new("RGB", (1, 1))

buffered = BytesIO()
expected_img.save(buffered, format="PNG")
test_img_b64 = base64.b64encode(buffered.getvalue())
expected_response = construct_resp_from_cache(test_img_b64, 1, 1)

with patch("stability_sdk.client.StabilityInference.generate") as mock_call:
mock_call.return_value = expected_response

stability_api = cache_stability.StabilityInference(key="ThisIsTest")
answer_response = stability_api.generate(prompt="Test prompt", width=1, height=1)
answers = []
for resp in answer_response:
for artifact in resp.artifacts:
if artifact.type == generation.ARTIFACT_IMAGE:
img = Image.open(BytesIO(artifact.binary))
assert img.size == (1, 1), "Incorrect image size."
answers.append(img)
assert len(answers) == 1, f"Expect to get 1 image but got {len(answers)}"
diff = ImageChops.difference(answers[0], expected_img)
assert not diff.getbbox()

answer_response = stability_api.generate(prompt="Test prompt", width=2, height=2)
answers = []
for resp in answer_response:
for artifact in resp.artifacts:
if artifact.type == generation.ARTIFACT_IMAGE:
img = Image.open(BytesIO(artifact.binary))
assert img.size == (2, 2), "Incorrect image size."
answers.append(img)
assert len(answers) == 1, f"Expect to get 1 image but got {len(answers)}"
diff = ImageChops.difference(answers[0], expected_img)
assert not diff.getbbox()



if __name__ == "__main__":
test_stability_inference_map()
test_stability_inference_faiss()



0 comments on commit 49d18cc

Please sign in to comment.