diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py
index 7380fda7bb..d482504ef7 100644
--- a/apps/stable_diffusion/src/models/model_wrappers.py
+++ b/apps/stable_diffusion/src/models/model_wrappers.py
@@ -443,6 +443,7 @@ def __init__(
                         low_cpu_mem_usage=low_cpu_mem_usage,
                     )
                 elif not isinstance(custom_vae, dict):
+                    precision = "fp16" if "fp16" in custom_vae else None
                     print(f"Loading custom vae, with target {custom_vae}")
                     if os.path.exists(custom_vae):
                         self.vae = AutoencoderKL.from_pretrained(
@@ -457,12 +458,19 @@ def __init__(
                             ]
                         )
                         print("Using hub to get custom vae")
-                        self.vae = AutoencoderKL.from_pretrained(
-                            custom_vae,
-                            low_cpu_mem_usage=low_cpu_mem_usage,
-                        )
+                        try:
+                            self.vae = AutoencoderKL.from_pretrained(
+                                custom_vae,
+                                low_cpu_mem_usage=low_cpu_mem_usage,
+                                variant=precision,
+                            )
+                        except:
+                            self.vae = AutoencoderKL.from_pretrained(
+                                custom_vae,
+                                low_cpu_mem_usage=low_cpu_mem_usage,
+                            )
                 else:
-                    print(f"Loading custom vae, with target {custom_vae}")
+                    print(f"Loading custom vae, with state {custom_vae}")
                     self.vae = AutoencoderKL.from_pretrained(
                         model_id,
                         subfolder="vae",
@@ -938,11 +946,19 @@ def __init__(
                 low_cpu_mem_usage=False,
             ):
                 super().__init__()
-                self.unet = UNet2DConditionModel.from_pretrained(
-                    model_id,
-                    subfolder="unet",
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                )
+                try:
+                    self.unet = UNet2DConditionModel.from_pretrained(
+                        model_id,
+                        subfolder="unet",
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        variant="fp16",
+                    )
+                except:
+                    self.unet = UNet2DConditionModel.from_pretrained(
+                        model_id,
+                        subfolder="unet",
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                    )
                 if (
                     args.attention_slicing is not None
                     and args.attention_slicing != "none"
@@ -1084,6 +1100,7 @@ def __init__(
                         model_id,
                         subfolder="text_encoder",
                         low_cpu_mem_usage=low_cpu_mem_usage,
+                        variant="fp16",
                     )
                 else:
                     self.text_encoder = (
@@ -1091,6 +1108,7 @@ def __init__(
                             model_id,
                             subfolder="text_encoder_2",
                             low_cpu_mem_usage=low_cpu_mem_usage,
+                            variant="fp16",
                         )
                     )
 
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py
index 87777f3feb..a3b52793e9 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py
@@ -209,8 +209,6 @@ def generate_images(
         # Img latents -> PIL images.
         all_imgs = []
         self.load_vae()
-        # imgs = self.decode_latents_sdxl(None)
-        # all_imgs.extend(imgs)
         for i in range(0, latents.shape[0], batch_size):
             imgs = self.decode_latents_sdxl(
                 latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
index 95a80ea546..08f932919a 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
@@ -20,7 +20,10 @@
     HeunDiscreteScheduler,
 )
 from shark.shark_inference import SharkInference
-from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
+from apps.stable_diffusion.src.schedulers import (
+    SharkEulerDiscreteScheduler,
+    SharkEulerAncestralDiscreteScheduler,
+)
 from apps.stable_diffusion.src.models import (
     SharkifyStableDiffusionModel,
     get_vae,
@@ -52,6 +55,7 @@ def __init__(
         EulerAncestralDiscreteScheduler,
         DPMSolverMultistepScheduler,
         SharkEulerDiscreteScheduler,
+        SharkEulerAncestralDiscreteScheduler,
         DEISMultistepScheduler,
         DDPMScheduler,
         DPMSolverSinglestepScheduler,
diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
index 5267dd9efc..bf19e1234a 100644
--- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
+++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
@@ -15,6 +15,7 @@
     cancel_sd,
     set_model_default_configs,
 )
+from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
 from apps.stable_diffusion.web.utils.metadata import import_png_metadata
 from apps.stable_diffusion.web.utils.common_label_calc import status_label
 from apps.stable_diffusion.src import (
@@ -271,7 +272,7 @@ def txt2img_sdxl_inf(
                     elem_id="custom_model",
                     value="None",
                     choices=[
-                        "None",
+                        None,
                         "madebyollin/sdxl-vae-fp16-fix",
                     ]
                     + get_custom_model_files("vae"),
@@ -339,6 +340,8 @@ def txt2img_sdxl_inf(
                         "DDIM",
                         "SharkEulerAncestralDiscrete",
                         "SharkEulerDiscrete",
+                        "EulerAncestralDiscrete",
+                        "EulerDiscrete",
                     ],
                     allow_custom_value=False,
                     visible=True,
@@ -402,7 +405,7 @@ def txt2img_sdxl_inf(
                         50,
                         value=args.guidance_scale,
                         step=0.1,
-                        label="CFG Scale",
+                        label="Guidance Scale",
                     )
                     ondemand = gr.Checkbox(
                         value=args.ondemand,
@@ -562,12 +565,14 @@ def txt2img_sdxl_inf(
                 custom_vae,
             ],
         )
-        txt2img_sdxl_custom_model.select(
+        txt2img_sdxl_custom_model.change(
            fn=set_model_default_configs,
            inputs=[
                txt2img_sdxl_custom_model,
            ],
            outputs=[
+               prompt,
+               negative_prompt,
                steps,
                scheduler,
                guidance_scale,
@@ -576,3 +581,9 @@ def txt2img_sdxl_inf(
                 custom_vae,
             ],
         )
+        lora_weights.change(
+            fn=lora_changed,
+            inputs=[lora_weights],
+            outputs=[lora_tags],
+            queue=True,
+        )
diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py
index bcc12cbdb4..d6b4abd03a 100644
--- a/apps/stable_diffusion/web/ui/txt2img_ui.py
+++ b/apps/stable_diffusion/web/ui/txt2img_ui.py
@@ -377,6 +377,11 @@ def resource_path(relative_path):
                         lines=2,
                         elem_id="prompt_box",
                     )
+                    # TODO: coming soon
+                    autogen = gr.Checkbox(
+                        label="Continuous Generation",
+                        visible=False,
+                    )
                     negative_prompt = gr.Textbox(
                         label="Negative Prompt",
                         value=args.negative_prompts[0],
diff --git a/apps/stable_diffusion/web/ui/utils.py b/apps/stable_diffusion/web/ui/utils.py
index 67ef3d3590..546598ec16 100644
--- a/apps/stable_diffusion/web/ui/utils.py
+++ b/apps/stable_diffusion/web/ui/utils.py
@@ -4,6 +4,7 @@
 import math
 import json
 import safetensors
+import gradio as gr
 from pathlib import Path
 
 from apps.stable_diffusion.src import args
@@ -272,6 +273,8 @@ def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
     else:
         # We don't have default metadata to setup a good config. Do not change configs.
         return [
+            gr.Textbox(label="Prompt", interactive=True, visible=True),
+            gr.update(),
             gr.update(),
             gr.update(),
             gr.update(),
@@ -285,6 +288,8 @@ def get_config_from_json(model_ckpt_or_id, jsonconfig):
     # TODO: make this work properly. It is currently not user-exposed.
     cfgdata = json.load(jsonconfig)
     return [
+        cfgdata["prompt_box_behavior"],
+        cfgdata["neg_prompt_box_behavior"],
         cfgdata["steps"],
         cfgdata["scheduler"],
         cfgdata["guidance_scale"],
@@ -305,13 +310,27 @@ def default_config_exists(model_ckpt_or_id):
 
 
 default_configs = {
-    "stabilityai/sdxl-turbo": [1, "DDIM", 0, 512, 512, ""],
+    "stabilityai/sdxl-turbo": [
+        gr.Textbox(label="", interactive=False, value=None, visible=False),
+        gr.Textbox(
+            label="Prompt",
+            value="A shark lady watching her friend build a snowman, deep orange sky, color block, high resolution, ((8k uhd, excellent artwork))",
+        ),
+        gr.Slider(0, 5, value=2),
+        gr.Dropdown(value="DDIM"),
+        gr.Slider(0, value=0),
+        512,
+        512,
+        "madebyollin/sdxl-vae-fp16-fix",
+    ],
     "stabilityai/stable-diffusion-xl-base-1.0": [
-        50,
+        gr.Textbox(label="Prompt", interactive=True, visible=True),
+        gr.Textbox(label="Negative Prompt", interactive=True),
+        40,
         "DDIM",
         7.5,
-        512,
-        512,
+        gr.Slider(value=1024, interactive=False),
+        gr.Slider(value=1024, interactive=False),
         "madebyollin/sdxl-vae-fp16-fix",
     ],
 }
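
A minimal standalone sketch of the fp16-variant fallback this patch applies to the UNet and custom-VAE loaders (the text encoders request the variant unconditionally): ask diffusers for the "fp16" weight variant first, and fall back to the default weights when the repository does not publish one. The helper name and example model id below are illustrative, not part of the patch.

```python
from diffusers import UNet2DConditionModel


def load_unet_prefer_fp16(model_id: str, low_cpu_mem_usage: bool = False):
    try:
        # Fetches e.g. unet/diffusion_pytorch_model.fp16.safetensors when the
        # repository publishes a half-precision variant of the checkpoint.
        return UNet2DConditionModel.from_pretrained(
            model_id,
            subfolder="unet",
            low_cpu_mem_usage=low_cpu_mem_usage,
            variant="fp16",
        )
    except Exception:  # narrowed from the patch's bare `except:`
        # No fp16 variant published; load the default weights instead.
        return UNet2DConditionModel.from_pretrained(
            model_id,
            subfolder="unet",
            low_cpu_mem_usage=low_cpu_mem_usage,
        )


# Example: SDXL base publishes fp16 variants, so the first branch succeeds.
unet = load_unet_prefer_fp16("stabilityai/stable-diffusion-xl-base-1.0")
```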