forked from zilliztech/GPTCache
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adapt stability_sdk (zilliztech#277)
Signed-off-by: Jael Gu <[email protected]>
- Loading branch information
Showing
5 changed files
with
311 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import os | ||
import io | ||
import time | ||
from PIL import Image | ||
|
||
from gptcache import cache | ||
from gptcache.processor.pre import get_prompt | ||
from gptcache.adapter.stability_sdk import StabilityInference, generation | ||
from gptcache.embedding import Onnx | ||
from gptcache.manager.factory import manager_factory | ||
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation | ||
|
||
# init gptcache | ||
onnx = Onnx() | ||
data_manager = manager_factory('sqlite,faiss,local', | ||
data_dir='./', | ||
vector_params={'dimension': onnx.dimension}, | ||
object_params={'path': './images'} | ||
) | ||
cache.init( | ||
pre_embedding_func=get_prompt, | ||
embedding_func=onnx.to_embeddings, | ||
data_manager=data_manager, | ||
similarity_evaluation=SearchDistanceEvaluation() | ||
) | ||
|
||
# run with gptcache | ||
api_key = os.getenv('STABILITY_KEY', 'key-goes-here') | ||
|
||
stability_api = StabilityInference( | ||
key=os.environ['STABILITY_KEY'], # API Key reference. | ||
verbose=False, # Print debug messages. | ||
engine='stable-diffusion-xl-beta-v2-2-2', # Set the engine to use for generation. | ||
) | ||
|
||
start = time.time() | ||
answers = stability_api.generate( | ||
prompt='a cat sitting besides a dog', | ||
width=256, | ||
height=256 | ||
) | ||
|
||
for resp in answers: | ||
for artifact in resp.artifacts: | ||
if artifact.type == generation.ARTIFACT_IMAGE: | ||
img = Image.open(io.BytesIO(artifact.binary)) | ||
assert img.size == (256, 256) | ||
print('Time elapsed 1:', time.time() - start) | ||
|
||
start = time.time() | ||
answers = stability_api.generate( | ||
prompt='a dog and a dog sitting together', | ||
width=512, | ||
height=512 | ||
) | ||
|
||
for resp in answers: | ||
for artifact in resp.artifacts: | ||
if artifact.type == generation.ARTIFACT_IMAGE: | ||
img = Image.open(io.BytesIO(artifact.binary)) | ||
assert img.size == (512, 512) | ||
print('Time elapsed 2:', time.time() - start) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from io import BytesIO | ||
import base64 | ||
import warnings | ||
from dataclasses import dataclass | ||
from typing import List | ||
|
||
from gptcache.adapter.adapter import adapt | ||
from gptcache.manager.scalar_data.base import Answer, DataType | ||
from gptcache.utils.error import CacheError | ||
from gptcache.utils import ( | ||
import_stability, import_pillow | ||
) | ||
|
||
import_pillow() | ||
import_stability() | ||
|
||
from PIL import Image as PILImage # pylint: disable=C0413 | ||
from stability_sdk import client # pylint: disable=C0413 | ||
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation # pylint: disable=C0413 | ||
|
||
|
||
|
||
class StabilityInference(client.StabilityInference): | ||
"""client.StabilityInference Wrapper | ||
Example: | ||
.. code-block:: python | ||
import os | ||
import io | ||
from PIL import Image | ||
from gptcache import cache | ||
from gptcache.processor.pre import get_prompt | ||
from gptcache.adapter.stability_sdk import StabilityInference, generation | ||
# init gptcache | ||
cache.init(pre_embedding_func=get_prompt) | ||
# run with gptcache | ||
os.environ['STABILITY_KEY'] = 'key-goes-here' | ||
stability_api = StabilityInference( | ||
key=os.environ['STABILITY_KEY'], # API Key reference. | ||
verbose=False, # Print debug messages. | ||
engine="stable-diffusion-xl-beta-v2-2-2", # Set the engine to use for generation. | ||
) | ||
answers = stability_api.generate( | ||
prompt="a cat sitting besides a dog", | ||
width=256, | ||
height=256 | ||
) | ||
for resp in answers: | ||
for artifact in resp.artifacts: | ||
if artifact.type == generation.ARTIFACT_IMAGE: | ||
img = Image.open(io.BytesIO(artifact.binary)) | ||
img.save('path/to/save/image.png') | ||
""" | ||
|
||
def llm_handler(self, *llm_args, **llm_kwargs): | ||
try: | ||
return super().generate(*llm_args, **llm_kwargs) | ||
except Exception as e: | ||
raise CacheError("stability error") from e | ||
|
||
def generate(self, *args, **kwargs): | ||
width = kwargs.get("width", 512) | ||
height = kwargs.get("height", 512) | ||
|
||
def cache_data_convert(cache_data): | ||
return construct_resp_from_cache(cache_data, width=width, height=height) | ||
|
||
def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument | ||
def hook_stream_data(it): | ||
to_save = [] | ||
for resp in it: | ||
for artifact in resp.artifacts: | ||
try: | ||
if artifact.finish_reason == generation.FILTER: | ||
warnings.warn( | ||
"Your request activated the API's safety filters and could not be processed." | ||
"Please modify the prompt and try again.") | ||
continue | ||
except AttributeError: | ||
pass | ||
if artifact.type == generation.ARTIFACT_IMAGE: | ||
img_b64 = base64.b64encode(artifact.binary) | ||
to_save.append(img_b64) | ||
yield resp | ||
update_cache_func(Answer(to_save[0], DataType.IMAGE_BASE64)) | ||
|
||
return hook_stream_data(llm_data) | ||
|
||
return adapt( | ||
self.llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs | ||
) | ||
|
||
|
||
def construct_resp_from_cache(img_64, height, width): | ||
img_bytes = base64.b64decode((img_64)) | ||
img_file = BytesIO(img_bytes) | ||
img = PILImage.open(img_file) | ||
new_size = (width, height) | ||
if new_size != img.size: | ||
img = img.resize(new_size) | ||
buffered = BytesIO() | ||
img.save(buffered, format="PNG") | ||
img_bytes = buffered.getvalue() | ||
yield MockAnswer(artifacts=[MockArtifact(type=generation.ARTIFACT_IMAGE, binary=img_bytes)]) | ||
|
||
|
||
@dataclass | ||
class MockArtifact: | ||
type: int | ||
binary: bytes | ||
|
||
|
||
@dataclass | ||
class MockAnswer: | ||
artifacts: List[MockArtifact] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from unittest.mock import patch | ||
import base64 | ||
from io import BytesIO | ||
import os | ||
import numpy as np | ||
|
||
|
||
from gptcache.adapter import stability_sdk as cache_stability | ||
from gptcache.adapter.stability_sdk import generation, construct_resp_from_cache | ||
from gptcache import cache | ||
from gptcache.processor.pre import get_prompt | ||
from gptcache.manager.factory import manager_factory | ||
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation | ||
|
||
from gptcache.utils import ( | ||
import_stability, import_pillow | ||
) | ||
|
||
import_pillow() | ||
import_stability() | ||
|
||
import stability_sdk | ||
from PIL import ImageChops, Image | ||
|
||
|
||
def test_stability_inference_map(): | ||
cache.init(pre_embedding_func=get_prompt) | ||
expected_img = Image.new("RGB", (1, 1)) | ||
|
||
buffered = BytesIO() | ||
expected_img.save(buffered, format="PNG") | ||
test_img_b64 = base64.b64encode(buffered.getvalue()) | ||
expected_response = construct_resp_from_cache(test_img_b64, 1, 1) | ||
|
||
stability_api = cache_stability.StabilityInference(key="ThisIsTest") | ||
with patch.object(stability_sdk.client.StabilityInference, "generate") as mock_call: | ||
mock_call.return_value = expected_response | ||
|
||
answer_response = stability_api.generate(prompt="Test prompt", width=1, height=1) | ||
answers = [] | ||
for resp in answer_response: | ||
for artifact in resp.artifacts: | ||
if artifact.type == generation.ARTIFACT_IMAGE: | ||
answers.append(Image.open(BytesIO(artifact.binary))) | ||
assert len(answers) == 1, f"Expect to get 1 image but got {len(answers)}" | ||
diff = ImageChops.difference(answers[0], expected_img) | ||
assert not diff.getbbox() | ||
|
||
answer_response = stability_api.generate(prompt="Test prompt", width=2, height=2) | ||
answers = [] | ||
for resp in answer_response: | ||
for artifact in resp.artifacts: | ||
if artifact.type == generation.ARTIFACT_IMAGE: | ||
img = Image.open(BytesIO(artifact.binary)) | ||
assert img.size == (2, 2), "Incorrect image size." | ||
answers.append(img) | ||
assert len(answers) == 1, f"Expect to get 1 image but got {len(answers)}" | ||
diff = ImageChops.difference(answers[0], expected_img) | ||
assert not diff.getbbox() | ||
|
||
|
||
def test_stability_inference_faiss(): | ||
faiss_file = "faiss.index" | ||
if os.path.isfile(faiss_file): | ||
os.remove(faiss_file) | ||
|
||
data_manager = manager_factory('sqlite,faiss,local', | ||
data_dir='./', | ||
vector_params={'dimension': 2}, | ||
object_params={'path': './images'} | ||
) | ||
cache.init( | ||
pre_embedding_func=get_prompt, | ||
embedding_func=lambda x, **_: np.random.random((2,)).astype("float32"), | ||
data_manager=data_manager, | ||
similarity_evaluation=SearchDistanceEvaluation() | ||
) | ||
|
||
expected_img = Image.new("RGB", (1, 1)) | ||
|
||
buffered = BytesIO() | ||
expected_img.save(buffered, format="PNG") | ||
test_img_b64 = base64.b64encode(buffered.getvalue()) | ||
expected_response = construct_resp_from_cache(test_img_b64, 1, 1) | ||
|
||
with patch("stability_sdk.client.StabilityInference.generate") as mock_call: | ||
mock_call.return_value = expected_response | ||
|
||
stability_api = cache_stability.StabilityInference(key="ThisIsTest") | ||
answer_response = stability_api.generate(prompt="Test prompt", width=1, height=1) | ||
answers = [] | ||
for resp in answer_response: | ||
for artifact in resp.artifacts: | ||
if artifact.type == generation.ARTIFACT_IMAGE: | ||
img = Image.open(BytesIO(artifact.binary)) | ||
assert img.size == (1, 1), "Incorrect image size." | ||
answers.append(img) | ||
assert len(answers) == 1, f"Expect to get 1 image but got {len(answers)}" | ||
diff = ImageChops.difference(answers[0], expected_img) | ||
assert not diff.getbbox() | ||
|
||
answer_response = stability_api.generate(prompt="Test prompt", width=2, height=2) | ||
answers = [] | ||
for resp in answer_response: | ||
for artifact in resp.artifacts: | ||
if artifact.type == generation.ARTIFACT_IMAGE: | ||
img = Image.open(BytesIO(artifact.binary)) | ||
assert img.size == (2, 2), "Incorrect image size." | ||
answers.append(img) | ||
assert len(answers) == 1, f"Expect to get 1 image but got {len(answers)}" | ||
diff = ImageChops.difference(answers[0], expected_img) | ||
assert not diff.getbbox() | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
test_stability_inference_map() | ||
test_stability_inference_faiss() | ||
|
||
|
||
|