diff --git a/mezcla/examples/hf_stable_diffusion.py b/mezcla/examples/hf_stable_diffusion.py
index dcc3013a..66a04a10 100755
--- a/mezcla/examples/hf_stable_diffusion.py
+++ b/mezcla/examples/hf_stable_diffusion.py
@@ -21,6 +21,9 @@
 # - For a stylish Stable Diffusion interface, see
 #   https://github.com/comfyanonymous/ComfyUI
 #
+# TODO2:
+# - See why HF keeps downloading model (e.g., make sure cache permanent).
+# - Add test cases for flask-based API.
 # TODO:
 # - Set GRADIO_SERVER_NAME to 0.0.0.0?
 #
@@ -36,11 +39,9 @@
 # Installed modules
 import diskcache
 from flask import Flask, request
-## OLD: import PIL
 ## TODO: see why following needed (i.e., plain PIL import yields intermittent errors)
 import PIL.Image
 import requests
-## OLD: import gradio as gr
 gr = None
 
 # Local modules
@@ -58,7 +59,6 @@
                                      "Negative tips for image")
 GUIDANCE_HELP = "Degree of fidelity to prompt (1-to-30 w/ 7 suggested)--higher for more"
 GUIDANCE_SCALE = system.getenv_int("GUIDANCE_SCALE", 7,
-                                   ## OLD: "How much the image generation follows the prompt")
                                    description=GUIDANCE_HELP)
 SD_URL = system.getenv_value("SD_URL", None,
                              "URL for SD TCP/restful server--new via flask or remote")
@@ -100,7 +100,7 @@
 PROMPT_ARG = "prompt"
 NEGATIVE_ARG = "negative"
 GUIDANCE_ARG = "guidance"
-## TODO?: TXT2IMG_ARG = "txt2img"
+TXT2IMG_ARG = "txt2img"
 IMG2IMG_ARG = "img2img"
 IMG2TXT_ARG = "img2txt"
 IMAGE_ARG = "input-image"
@@ -108,6 +108,9 @@
 DUMMY_BASE64_IMAGE = "iVBORw0KGgoAAAANSUhEUgAAAA8AAAAPAgMAAABGuH3ZAAAADFBMVEUAAMzMzP////8AAABGA1scAAAAJUlEQVR4nGNgAAFGQUEowRoa6sCABBZowAgsgBEIGUQCRALAPACMHAOQvR4HGwAAAABJRU5ErkJggg=="
 DUMMY_IMAGE_FILE = gh.resolve_path("dummy-image.png")
 HTTP_OK = 200
+# note: RestFUL API keys (via flask)
+IMAGES_KEY = "images"
+CAPTION_KEY = "caption"
 
 #--------------------------------------------------------------------------------
 # Globals
@@ -198,6 +201,7 @@ def __init__(self, use_hf_api=None, server_url=None, server_port=None, low_memor
             disk_compress_level=0,  # no compression
             cull_limit=0)           # no automatic pruning
         debug.assertion(bool(self.use_hf_api) != bool(self.server_url))
+        show_gpu_usage()
         debug.trace_object(5, self, label=f"{class_name} instance")
 
     def init_pipeline(self, txt2img=None, img2img=None):
@@ -338,7 +342,7 @@ def infer_non_cached(self, prompt=None, negative_prompt=None, scale=None, num_im
             images_request = requests.post(url, json=payload, timeout=(5 * 60))
             debug.trace_object(6, images_request)
             debug.trace_expr(5, payload, images_request, images_request.json(), delim="\n")
-            for image in images_request.json()["images"]:
+            for image in images_request.json()[IMAGES_KEY]:
                 image_b64 = image
                 if not skip_img_spec:
                     image_b64 = (f"data:image/png;base64,{image_b64}")
@@ -439,7 +443,7 @@ def infer_img2img_non_cached(self, image_b64=None, denoise=None, prompt=None, ne
             images_request = requests.post(url, json=payload, timeout=(5 * 60))
             debug.trace_object(6, images_request)
             debug.trace_expr(5, payload, images_request, images_request.json(), delim="\n")
-            for image in images_request.json()["images"]:
+            for image in images_request.json()[IMAGES_KEY]:
                 image_b64 = image
                 if not skip_img_spec:
                     image_b64 = (f"data:image/png;base64,{image_b64}")
@@ -487,7 +491,7 @@ def infer_img2txt_non_cached(self, image_b64=None):
             request_result = requests.post(url, json=payload, timeout=(5 * 60))
             debug.trace_object(6, request_result)
             debug.trace_expr(5, payload, request_result, request_result.json(), delim="\n")
-            image_caption = request_result.json()["caption"]
+            image_caption = request_result.json()[CAPTION_KEY]
         debug.trace_fmt(5,
                         "infer_img2txt_non_cached() => {r!r}", r=image_caption)
         return image_caption
@@ -515,7 +519,7 @@ def handle_infer():
     debug.trace_expr(5, params)
     if not sd_instance:
         init_stable_diffusion()
-    images_spec = {"images": sd_instance.infer(**params)}
+    images_spec = {IMAGES_KEY: sd_instance.infer(**params)}
     # note: see https://stackoverflow.com/questions/45412228/sending-json-and-status-code-with-a-flask-response
     result = (json.dumps(images_spec), HTTP_OK)
     debug.trace_object(7, result)
@@ -535,6 +539,8 @@ def infer(prompt=None, negative_prompt=None, scale=None, num_images=None, skip_i
     Note: intended just for the gradio UI"
     """
     debug.trace(5, f"[sd_instance] infer{(prompt, negative_prompt, scale, skip_img_spec)}")
+    if not sd_instance:
+        init_stable_diffusion()
     return sd_instance.infer(prompt=prompt, negative_prompt=negative_prompt, scale=scale,
                              num_images=num_images, skip_img_spec=skip_img_spec)
 
@@ -547,7 +553,9 @@ def handle_infer_img2img():
     debug.trace_object(5, request)
     params = request.get_json()
     debug.trace_expr(5, params)
-    images_spec = {"images": sd_instance.infer_img2img(**params)}
+    if not sd_instance:
+        init_stable_diffusion()
+    images_spec = {IMAGES_KEY: sd_instance.infer_img2img(**params)}
     # note: see https://stackoverflow.com/questions/45412228/sending-json-and-status-code-with-a-flask-response
     result = (json.dumps(images_spec), HTTP_OK)
     debug.trace_object(7, result)
@@ -564,7 +572,6 @@ def infer_img2img(image_spec=None, denoise=None, prompt=None, negative_prompt=N
         debug.trace(5, "Warning: using first image in image_spec for infer_img2img")
         image_spec = image_spec[0]
     if ((image_spec is not None) and (not isinstance(image_spec, str))):
-        ## TODO?: image = PIL.Image.fromarray(image_spec, mode="RGB")
         debug.trace_expr(7, image_spec)
         image = (image_spec)
         image_spec = encode_PIL_image(image)
@@ -576,10 +583,20 @@
 
 @flask_app.route('/img2txt', methods=['GET', 'POST'])
 def handle_infer_img2txt():
-    """Process request to do inference to generate text description of image"""
-    # Note: result return via hash with caption key
-    ## TODO: {"caption": sd_instance.infer_img2txt(**params)}
-    raise NotImplementedError()
+    """Process request to do inference to generate text description of image
+    Note: result returned via hash with caption key"""
+    ## TODO3: add helper for common flask bookkeeping
+    debug.trace(6, "[flask_app /] handle_infer_img2txt()")
+    debug.trace_object(5, request)
+    params = request.get_json()
+    debug.trace_expr(5, params)
+    if not sd_instance:
+        init_stable_diffusion()
+    caption_spec = {CAPTION_KEY: sd_instance.infer_img2txt(**params)}
+    result = (json.dumps(caption_spec), HTTP_OK)
+    debug.trace_object(7, result)
+    debug.trace_fmt(7, "handle_infer_img2txt() => {r}", r=result)
+    return result
 
 
 def infer_img2txt(image_spec):