
Merge branch 'tom-dev' into main
tomasohara authored Jan 26, 2024
2 parents 9285d9f + 913fc4c commit 17a173b
Showing 17 changed files with 372 additions and 250 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/github.yml
@@ -29,8 +29,10 @@ jobs:
matrix:
os: [ubuntu-latest]
## TODO: os: [ubuntu-20.04, ubuntu-latest]
## OLD:
python-version: ["3.8", "3.9", "3.10"]
## OLD: python-version: ["3.8", "3.9", "3.10"]
## NOTE: Need 3.8.17+ for typing support, due to backporting limitations with typing_extensions (see html_utils.py)
## TODO?:
python-version: ["3.8.17", "3.9", "3.10"]
# Note: The Dockerfile uses 3.11 so try different versions for the runner VM
## TODO: python-version: ["3.9", "3.10", "3.11"]

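Note: the NOTE above refers to the usual typing backport pattern; a minimal, hypothetical sketch (html_utils.py itself is not shown in this diff):

    # On pre-3.10 interpreters, newer typing features come from typing_extensions;
    # very old 3.8 point releases can lack pieces the backport relies on.
    import sys

    if sys.version_info >= (3, 10):
        from typing import TypeAlias              # native on 3.10+
    else:
        from typing_extensions import TypeAlias   # backport for 3.8/3.9 runners

    HtmlText: TypeAlias = str                     # hypothetical alias, illustration only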
2 changes: 1 addition & 1 deletion mezcla/__init__.py
@@ -18,7 +18,7 @@
Tom O'Hara
Feb 2022
"""
version = "1.3.9.7"
version = "1.3.9.8"
__VERSION__ = version
__version__ = __VERSION__

3 changes: 2 additions & 1 deletion mezcla/debug.py
@@ -535,7 +535,8 @@ def trace_expr(level, *values, **kwargs):
expressions = []

# Output initial text
trace(level, prefix, no_eol=no_eol)
if prefix:
trace(level, prefix, no_eol=no_eol)

# Output each expression value
for expression, value in zip_longest(expressions, values):
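Note: a brief usage sketch of the guarded prefix (illustrative; exact output formatting follows trace_expr's delimiter defaults):

    from mezcla import debug

    x, y = 1, 2
    debug.trace_expr(1, x, y)                    # with the fix, no empty prefix line is emitted
    debug.trace_expr(1, x, y, prefix="point: ")  # a non-empty prefix still prints first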
101 changes: 73 additions & 28 deletions mezcla/examples/hf_stable_diffusion.py
@@ -21,6 +21,9 @@
# - For a stylish Stable Diffusion interface, see
# https://github.com/comfyanonymous/ComfyUI
#
# TODO2:
# - See why HF keeps downloading model (e.g., make sure cache permanent).
# - Add test cases for flask-based API.
# TODO:
# - Set GRADIO_SERVER_NAME to 0.0.0.0?
#
@@ -36,11 +39,9 @@
# Installed modules
import diskcache
from flask import Flask, request
## OLD: import PIL
## TODO: see why following needed (i.e., plain PIL import yields intermittent errors)
import PIL.Image
import requests
## OLD: import gradio as gr
gr = None

# Local modules
@@ -58,7 +59,6 @@
"Negative tips for image")
GUIDANCE_HELP = "Degree of fidelity to prompt (1-to-30 w/ 7 suggested)--higher for more"
GUIDANCE_SCALE = system.getenv_int("GUIDANCE_SCALE", 7,
## OLD: "How much the image generation follows the prompt")
description=GUIDANCE_HELP)
SD_URL = system.getenv_value("SD_URL", None,
"URL for SD TCP/restful server--new via flask or remote")
@@ -91,23 +91,38 @@
HF_SD_MODEL = system.getenv_text(
"HF_SD_MODEL", "CompVis/stable-diffusion-v1-4",
description="Hugging Face model for Stable Diffusion")


STREAMLINED_CLIP = system.getenv_bool(
# note: Doesn't default to LOW_MEMORY, which just uses 16-bit floating point:
# several settings are changed (see apply_low_vram_defaults in clip_interrogator).
"STREAMLINED_CLIP", False,
description="Use streamlined CLIP settings to reduce memory usage")
REGULAR_CLIP = (not STREAMLINED_CLIP)
CAPTION_MODEL = system.getenv_text(
"CAPTION_MODEL", ("blip-large" if REGULAR_CLIP else "blip-base"),
description="Caption model to use in CLIP interrogation")
CLIP_MODEL = system.getenv_text(
"CLIP_MODEL", ("ViT-L-14/openai" if REGULAR_CLIP else "ViT-B-16/openai"),
# TODO4: see https://arxiv.org/pdf/2010.11929.pdf for ViT-S-NN explanation
description="Model to use for CLIP interrogation")
#
BATCH_ARG = "batch"
SERVER_ARG = "server"
UI_ARG = "UI"
PORT_ARG = "port"
PROMPT_ARG = "prompt"
NEGATIVE_ARG = "negative"
GUIDANCE_ARG = "guidance"
## TODO?: TXT2IMG_ARG = "txt2img"
TXT2IMG_ARG = "txt2img"
IMG2IMG_ARG = "img2img"
IMG2TXT_ARG = "img2txt"
IMAGE_ARG = "input-image"
## OLD: IMAGE_ARG = "input-image"
DENOISING_ARG = "denoising-factor"
DUMMY_BASE64_IMAGE = "iVBORw0KGgoAAAANSUhEUgAAAA8AAAAPAgMAAABGuH3ZAAAADFBMVEUAAMzMzP////8AAABGA1scAAAAJUlEQVR4nGNgAAFGQUEowRoa6sCABBZowAgsgBEIGUQCRALAPACMHAOQvR4HGwAAAABJRU5ErkJggg=="
DUMMY_IMAGE_FILE = gh.resolve_path("dummy-image.png")
HTTP_OK = 200
# note: RESTful API keys (via flask)
IMAGES_KEY = "images"
CAPTION_KEY = "caption"

#--------------------------------------------------------------------------------
# Globals
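Note: a rough sketch of how the CLIP-related settings above fit together; it mirrors init_clip_interrogation() later in this diff and assumes clip_interrogator's public Config/Interrogator API:

    import PIL.Image
    from clip_interrogator import Config, Interrogator

    clip_config = Config(caption_model_name=CAPTION_MODEL,   # e.g., "blip-large"
                         clip_model_name=CLIP_MODEL)         # e.g., "ViT-L-14/openai"
    if STREAMLINED_CLIP:
        clip_config.apply_low_vram_defaults()     # smaller settings to cut VRAM use
    engine = Interrogator(clip_config)
    caption = engine.interrogate(PIL.Image.open("dummy-image.png"))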
@@ -198,12 +213,14 @@ def __init__(self, use_hf_api=None, server_url=None, server_port=None, low_memor
disk_compress_level=0, # no compression
cull_limit=0) # no automatic pruning
debug.assertion(bool(self.use_hf_api) != bool(self.server_url))
show_gpu_usage()
debug.trace_object(5, self, label=f"{class_name} instance")

def init_pipeline(self, txt2img=None, img2img=None):
"""Initialize Stable Diffusion"""
debug.trace(4, "init_pipeline()")
debug.assertion(not (txt2img and img2img))
## TODO2: fix lack of support for img2img
debug.assertion(not img2img)
# pylint: disable=import-outside-toplevel
from diffusers import StableDiffusionPipeline
@@ -249,8 +266,12 @@ def init_clip_interrogation(self):
# pylint: disable=import-outside-toplevel
debug.trace(4, "init_clip_interrogation()")
from clip_interrogator import Config, Interrogator
self.img2txt_engine = Interrogator(Config(clip_model_name="ViT-L-14/openai"))

clip_config = Config(caption_model_name=CAPTION_MODEL,
clip_model_name=CLIP_MODEL)
if STREAMLINED_CLIP:
clip_config.apply_low_vram_defaults()
self.img2txt_engine = Interrogator(clip_config)

def infer(self, prompt=None, negative_prompt=None, scale=None, num_images=None,
skip_img_spec=False, width=None, height=None, skip_cache=False):
"""Generate images using positive PROMPT and NEGATIVE one, along with guidance SCALE,
@@ -338,7 +359,7 @@ def infer_non_cached(self, prompt=None, negative_prompt=None, scale=None, num_im
images_request = requests.post(url, json=payload, timeout=(5 * 60))
debug.trace_object(6, images_request)
debug.trace_expr(5, payload, images_request, images_request.json(), delim="\n")
for image in images_request.json()["images"]:
for image in images_request.json()[IMAGES_KEY]:
image_b64 = image
if not skip_img_spec:
image_b64 = (f"data:image/png;base64,{image_b64}")
@@ -439,7 +460,7 @@ def infer_img2img_non_cached(self, image_b64=None, denoise=None, prompt=None, ne
images_request = requests.post(url, json=payload, timeout=(5 * 60))
debug.trace_object(6, images_request)
debug.trace_expr(5, payload, images_request, images_request.json(), delim="\n")
for image in images_request.json()["images"]:
for image in images_request.json()[IMAGES_KEY]:
image_b64 = image
if not skip_img_spec:
image_b64 = (f"data:image/png;base64,{image_b64}")
@@ -487,7 +508,7 @@ def infer_img2txt_non_cached(self, image_b64=None):
request_result = requests.post(url, json=payload, timeout=(5 * 60))
debug.trace_object(6, request_result)
debug.trace_expr(5, payload, request_result, request_result.json(), delim="\n")
image_caption = request_result.json()["caption"]
image_caption = request_result.json()[CAPTION_KEY]

debug.trace_fmt(5, "infer_img2txt_non_cached() => {r!r}", r=image_caption)
return image_caption
@@ -497,7 +518,7 @@ def infer_img2txt_non_cached(image_b64=None):

def init_stable_diffusion(use_hf_api=None):
"""Initialize stable diffusion usage, locally if USE_HF_API"""
debug.trace(4, "init_stable_diffusion({use_hf_api})")
debug.trace(4, f"init_stable_diffusion({use_hf_api})")
init()
global sd_instance
sd_instance = StableDiffusion(use_hf_api=use_hf_api)
@@ -515,7 +536,7 @@ def handle_infer():
debug.trace_expr(5, params)
if not sd_instance:
init_stable_diffusion()
images_spec = {"images": sd_instance.infer(**params)}
images_spec = {IMAGES_KEY: sd_instance.infer(**params)}
# note: see https://stackoverflow.com/questions/45412228/sending-json-and-status-code-with-a-flask-response
result = (json.dumps(images_spec), HTTP_OK)
debug.trace_object(7, result)
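Note: a hedged client-side sketch for this endpoint; the route ('/') is inferred from the trace label, and the payload keys are assumed to mirror infer()'s keyword arguments:

    import requests

    SERVER = "http://127.0.0.1:7860"              # hypothetical host:port (see SD_URL/SD_PORT)
    payload = {"prompt": "wide-angle photo of a red barn", "num_images": 1}
    reply = requests.post(f"{SERVER}/", json=payload, timeout=5 * 60)
    assert reply.status_code == 200
    for image_b64 in reply.json()["images"]:      # IMAGES_KEY above
        print(image_b64[:32], "...")              # base64-encoded PNG data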
@@ -535,6 +556,8 @@ def infer(prompt=None, negative_prompt=None, scale=None, num_images=None, skip_i
Note: intended just for the gradio UI
"""
debug.trace(5, f"[sd_instance] infer{(prompt, negative_prompt, scale, skip_img_spec)}")
if not sd_instance:
init_stable_diffusion()
return sd_instance.infer(prompt=prompt, negative_prompt=negative_prompt, scale=scale, num_images=num_images, skip_img_spec=skip_img_spec)


@@ -547,7 +570,9 @@ def handle_infer_img2img():
debug.trace_object(5, request)
params = request.get_json()
debug.trace_expr(5, params)
images_spec = {"images": sd_instance.infer_img2img(**params)}
if not sd_instance:
init_stable_diffusion()
images_spec = {IMAGES_KEY: sd_instance.infer_img2img(**params)}
# note: see https://stackoverflow.com/questions/45412228/sending-json-and-status-code-with-a-flask-response
result = (json.dumps(images_spec), HTTP_OK)
debug.trace_object(7, result)
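Note: a similar hedged sketch for img2img; the /img2img route is assumed by symmetry with /img2txt, and the payload keys follow infer_img2img_non_cached's signature (image_b64, denoise, prompt):

    import base64
    import requests

    SERVER = "http://127.0.0.1:7860"              # hypothetical host:port
    with open("input.png", "rb") as handle:       # hypothetical input image
        image_b64 = base64.b64encode(handle.read()).decode()
    payload = {"image_b64": image_b64, "denoise": 0.75, "prompt": "same scene at dusk"}
    reply = requests.post(f"{SERVER}/img2img", json=payload, timeout=5 * 60)
    images = reply.json()["images"]               # IMAGES_KEY, as with handle_infer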
@@ -564,7 +589,6 @@ def infer_img2img(image_spec=None, denoise=None, prompt=None, negative_prompt=N
debug.trace(5, "Warning: using first image in image_spec for infer_img2img")
image_spec = image_spec[0]
if ((image_spec is not None) and (not isinstance(image_spec, str))):
## TODO?: image = PIL.Image.fromarray(image_spec, mode="RGB")
debug.trace_expr(7, image_spec)
image = (image_spec)
image_spec = encode_PIL_image(image)
@@ -576,10 +600,20 @@

@flask_app.route('/img2txt', methods=['GET', 'POST'])
def handle_infer_img2txt():
"""Process request to do inference to generate text description of image"""
# Note: result return via hash with caption key
## TODO: {"caption": sd_instance.infer_img2txt(**params)}
raise NotImplementedError()
"""Process request to do inference to generate text description of image
Note: result returned via hash with caption key"""
## TODO3: add helper for common flask bookkeeping
debug.trace(6, "[flask_app /] handle_infer_img2txt()")
debug.trace_object(5, request)
params = request.get_json()
debug.trace_expr(5, params)
if not sd_instance:
init_stable_diffusion()
caption_spec = {CAPTION_KEY: sd_instance.infer_img2txt(**params)}
result = (json.dumps(caption_spec), HTTP_OK)
debug.trace_object(7, result)
debug.trace_fmt(7, "handle_infer_img2txt() => {r}", r=result)
return result
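Note: a matching client sketch for the new /img2txt handler (the image_b64 payload key is an assumption based on infer_img2txt_non_cached's signature):

    import base64
    import requests

    SERVER = "http://127.0.0.1:7860"              # hypothetical host:port
    with open("photo.png", "rb") as handle:
        payload = {"image_b64": base64.b64encode(handle.read()).decode()}
    reply = requests.post(f"{SERVER}/img2txt", json=payload, timeout=5 * 60)
    print(reply.json()["caption"])                # CAPTION_KEY above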


def infer_img2txt(image_spec):
@@ -1035,7 +1069,8 @@ def main():

# Parse command line argument, show usage if --help given
# TODO? auto_help=False
main_app = Main(description=__doc__, skip_input=True,
main_app = Main(description=__doc__,
## OLD: skip_input=True,
boolean_options=[(BATCH_ARG, "Use batch mode--no UI"),
(SERVER_ARG, "Run flask server"),
(UI_ARG, "Show user interface"),
@@ -1044,29 +1079,34 @@
(IMG2TXT_ARG, "Run image-to-text: clip interrogator")],
text_options=[(PROMPT_ARG, "Positive prompt"),
(NEGATIVE_ARG, "Negative prompt"),
(IMAGE_ARG, "Filename for img2img input image")],
## OLD: (IMAGE_ARG, "Filename for img2img input image")
],
int_options=[(GUIDANCE_ARG, GUIDANCE_HELP)],
float_options=[(DENOISING_ARG, "Denoising factor for img2img")])
debug.trace_object(5, main_app)
debug.assertion(main_app.parsed_args)
#
batch_mode = main_app.get_parsed_option(BATCH_ARG)
input_image_file = main_app.filename
BATCH_MODE_DEFAULT = (input_image_file != "-")
batch_mode = main_app.get_parsed_option(BATCH_ARG, BATCH_MODE_DEFAULT)
server_mode = main_app.get_parsed_option(SERVER_ARG)
ui_mode = main_app.get_parsed_option(UI_ARG, not (batch_mode or server_mode))
## OLD: ui_mode = main_app.get_parsed_option(UI_ARG, not (batch_mode or server_mode))
ui_mode = main_app.get_parsed_option(UI_ARG)
prompt = main_app.get_parsed_option(PROMPT_ARG, PROMPT)
negative_prompt = main_app.get_parsed_option(NEGATIVE_ARG, NEGATIVE_PROMPT)
guidance = main_app.get_parsed_option(GUIDANCE_ARG, GUIDANCE_SCALE)
## TODO?: use_txt2img = main_app.get_parsed_option(TXT2IMG_ARG, USE_TXT2IMG)
use_img2img = main_app.get_parsed_option(IMG2IMG_ARG, USE_IMG2IMG)
use_img2txt = main_app.get_parsed_option(IMG2TXT_ARG, USE_IMG2TXT)
input_image_file = main_app.get_parsed_option(IMAGE_ARG)
## OLD: input_image_file = main_app.get_parsed_option(IMAGE_ARG)
denoising_factor = main_app.get_parsed_option(DENOISING_ARG)
## TODO?: debug.assertion(use_txt2img ^ use_img2img)
# TODO2: BASENAME and NUM_IMAGES (options)
## TODO: x_mode = main_app.get_parsed_option(X_ARG)
debug.assertion(not (batch_mode and server_mode))

# Invoke UI via HTTP unless in batch mode
# Invoke UI via HTTP unless in batch or server mode
show_gpu_usage()
init_stable_diffusion()
if batch_mode:
# Optionally convert input image into base64
Expand All @@ -1092,9 +1132,14 @@ def main():
debug.trace_object(5, flask_app)
flask_app.run(host=SD_URL, port=SD_PORT, debug=SD_DEBUG)
# Start UI
else:
debug.assertion(ui_mode)
elif ui_mode:
debug.assertion(main_app.filename == "-")
run_ui(use_img2img=use_img2img)
# Otherwise, show command-line options
else:
## TODO3: expose print_usage directly through main_app
main_app.parser.print_usage()
show_gpu_usage()


if __name__ == '__main__':