Skip to content

Commit

Permalink
fleshed out flask api for img2txt; fix to make sure init_stable_diffu…
Browse files Browse the repository at this point in the history
…sion invoked; adds show_gpu_usage during init; added todo notes; misc. cleanup
  • Loading branch information
tomasohara committed Jan 22, 2024
1 parent eba6699 commit 25e6bfe
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions mezcla/examples/hf_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
# - For a stylish Stable Diffusion interface, see
# https://github.com/comfyanonymous/ComfyUI
#
# TODO2:
# - See why HF keeps downloading model (e.g., make sure cache permanent).
# - Add test cases for flask-based API.
# TODO:
# - Set GRADIO_SERVER_NAME to 0.0.0.0?
#
Expand All @@ -36,11 +39,9 @@
# Installed modules
import diskcache
from flask import Flask, request
## OLD: import PIL
## TODO: see why following needed (i.e., plain PIL import yields intermittent errors)
import PIL.Image
import requests
## OLD: import gradio as gr
gr = None

# Local modules
Expand All @@ -58,7 +59,6 @@
"Negative tips for image")
GUIDANCE_HELP = "Degree of fidelity to prompt (1-to-30 w/ 7 suggested)--higher for more"
GUIDANCE_SCALE = system.getenv_int("GUIDANCE_SCALE", 7,
## OLD: "How much the image generation follows the prompt")
description=GUIDANCE_HELP)
SD_URL = system.getenv_value("SD_URL", None,
"URL for SD TCP/restful server--new via flask or remote")
Expand Down Expand Up @@ -100,14 +100,17 @@
PROMPT_ARG = "prompt"
NEGATIVE_ARG = "negative"
GUIDANCE_ARG = "guidance"
## TODO?: TXT2IMG_ARG = "txt2img"
TXT2IMG_ARG = "txt2img"
IMG2IMG_ARG = "img2img"
IMG2TXT_ARG = "img2txt"
IMAGE_ARG = "input-image"
DENOISING_ARG = "denoising-factor"
DUMMY_BASE64_IMAGE = "iVBORw0KGgoAAAANSUhEUgAAAA8AAAAPAgMAAABGuH3ZAAAADFBMVEUAAMzMzP////8AAABGA1scAAAAJUlEQVR4nGNgAAFGQUEowRoa6sCABBZowAgsgBEIGUQCRALAPACMHAOQvR4HGwAAAABJRU5ErkJggg=="
DUMMY_IMAGE_FILE = gh.resolve_path("dummy-image.png")
HTTP_OK = 200
# note: RestFUL API keys (via flask)
IMAGES_KEY = "images"
CAPTION_KEY = "caption"

#--------------------------------------------------------------------------------
# Globals
Expand Down Expand Up @@ -198,6 +201,7 @@ def __init__(self, use_hf_api=None, server_url=None, server_port=None, low_memor
disk_compress_level=0, # no compression
cull_limit=0) # no automatic pruning
debug.assertion(bool(self.use_hf_api) != bool(self.server_url))
show_gpu_usage()
debug.trace_object(5, self, label=f"{class_name} instance")

def init_pipeline(self, txt2img=None, img2img=None):
Expand Down Expand Up @@ -338,7 +342,7 @@ def infer_non_cached(self, prompt=None, negative_prompt=None, scale=None, num_im
images_request = requests.post(url, json=payload, timeout=(5 * 60))
debug.trace_object(6, images_request)
debug.trace_expr(5, payload, images_request, images_request.json(), delim="\n")
for image in images_request.json()["images"]:
for image in images_request.json()[IMAGES_KEY]:
image_b64 = image
if not skip_img_spec:
image_b64 = (f"data:image/png;base64,{image_b64}")
Expand Down Expand Up @@ -439,7 +443,7 @@ def infer_img2img_non_cached(self, image_b64=None, denoise=None, prompt=None, ne
images_request = requests.post(url, json=payload, timeout=(5 * 60))
debug.trace_object(6, images_request)
debug.trace_expr(5, payload, images_request, images_request.json(), delim="\n")
for image in images_request.json()["images"]:
for image in images_request.json()[IMAGES_KEY]:
image_b64 = image
if not skip_img_spec:
image_b64 = (f"data:image/png;base64,{image_b64}")
Expand Down Expand Up @@ -487,7 +491,7 @@ def infer_img2txt_non_cached(self, image_b64=None):
request_result = requests.post(url, json=payload, timeout=(5 * 60))
debug.trace_object(6, request_result)
debug.trace_expr(5, payload, request_result, request_result.json(), delim="\n")
image_caption = request_result.json()["caption"]
image_caption = request_result.json()[CAPTION_KEY]

debug.trace_fmt(5, "infer_img2txt_non_cached() => {r!r}", r=image_caption)
return image_caption
Expand Down Expand Up @@ -515,7 +519,7 @@ def handle_infer():
debug.trace_expr(5, params)
if not sd_instance:
init_stable_diffusion()
images_spec = {"images": sd_instance.infer(**params)}
images_spec = {IMAGES_KEY: sd_instance.infer(**params)}
# note: see https://stackoverflow.com/questions/45412228/sending-json-and-status-code-with-a-flask-response
result = (json.dumps(images_spec), HTTP_OK)
debug.trace_object(7, result)
Expand All @@ -535,6 +539,8 @@ def infer(prompt=None, negative_prompt=None, scale=None, num_images=None, skip_i
Note: intended just for the gradio UI"
"""
debug.trace(5, f"[sd_instance] infer{(prompt, negative_prompt, scale, skip_img_spec)}")
if not sd_instance:
init_stable_diffusion()
return sd_instance.infer(prompt=prompt, negative_prompt=negative_prompt, scale=scale, num_images=num_images, skip_img_spec=skip_img_spec)


Expand All @@ -547,7 +553,9 @@ def handle_infer_img2img():
debug.trace_object(5, request)
params = request.get_json()
debug.trace_expr(5, params)
images_spec = {"images": sd_instance.infer_img2img(**params)}
if not sd_instance:
init_stable_diffusion()
images_spec = {IMAGES_KEY: sd_instance.infer_img2img(**params)}
# note: see https://stackoverflow.com/questions/45412228/sending-json-and-status-code-with-a-flask-response
result = (json.dumps(images_spec), HTTP_OK)
debug.trace_object(7, result)
Expand All @@ -564,7 +572,6 @@ def infer_img2img(image_spec=None, denoise=None, prompt=None, negative_prompt=N
debug.trace(5, "Warning: using first image in image_spec for infer_img2img")
image_spec = image_spec[0]
if ((image_spec is not None) and (not isinstance(image_spec, str))):
## TODO?: image = PIL.Image.fromarray(image_spec, mode="RGB")
debug.trace_expr(7, image_spec)
image = (image_spec)
image_spec = encode_PIL_image(image)
Expand All @@ -576,10 +583,20 @@ def infer_img2img(image_spec=None, denoise=None, prompt=None, negative_prompt=N

@flask_app.route('/img2txt', methods=['GET', 'POST'])
def handle_infer_img2txt():
"""Process request to do inference to generate text description of image"""
# Note: result return via hash with caption key
## TODO: {"caption": sd_instance.infer_img2txt(**params)}
raise NotImplementedError()
"""Process request to do inference to generate text description of image
Note: result returned via hash with caption key"""
## TODO3: add helper for common flask bookkeeping
debug.trace(6, "[flask_app /] handle_infer_img2txt()")
debug.trace_object(5, request)
params = request.get_json()
debug.trace_expr(5, params)
if not sd_instance:
init_stable_diffusion()
caption_spec = {CAPTION_KEY: sd_instance.infer_img2txt(**params)}
result = (json.dumps(caption_spec), HTTP_OK)
debug.trace_object(7, result)
debug.trace_fmt(7, "handle_infer_img2txt() => {r}", r=result)
return result


def infer_img2txt(image_spec):
Expand Down

0 comments on commit 25e6bfe

Please sign in to comment.