(Studio) Update gradio and multicontrolnet UI. (#2001)
* (Studio) Update gradio and multicontrolnet UI.

* Fixes for outputgallery, exe build

* Fix image return types.

* Update Gradio to 4.7.1

* Fix send buttons and hiresfix

* Various bugfixes and SDXL additions.

* More UI fixes and txt2img_sdxl presets.

* Enable SDXL-Turbo and custom models, custom VAE for SDXL

* img2img ui tweaks
monorimet authored Dec 4, 2023
1 parent 9c50edc commit d72da38
Showing 27 changed files with 1,121 additions and 194 deletions.
3 changes: 3 additions & 0 deletions apps/stable_diffusion/shark_sd.spec
@@ -19,6 +19,9 @@ a = Analysis(
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

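A note on the spec change above: module_collection_mode is a PyInstaller option (available since PyInstaller 5.11) that controls how a package is bundled; 'py' ships the package as plain .py source files instead of bytecode frozen into the PYZ archive, presumably needed here because gradio inspects its own source at runtime. A minimal sketch of the relevant Analysis arguments, with a hypothetical entry-point path:

# Minimal .spec sketch; PyInstaller injects Analysis when executing a spec.
a = Analysis(
    ["apps/stable_diffusion/web/index.py"],  # hypothetical entry point
    module_collection_mode={
        # 'py' = collect as source files rather than frozen bytecode.
        "gradio": "py",
    },
)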
4 changes: 3 additions & 1 deletion apps/stable_diffusion/shark_studio_imports.py
@@ -31,6 +31,7 @@
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
@@ -75,6 +76,7 @@
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
@@ -85,4 +87,4 @@
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
hiddenimports += ["iree._runtime"]
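For reference, the helpers used throughout this file come from PyInstaller's hook utilities; a condensed sketch of the pattern (these are real PyInstaller APIs):

from PyInstaller.utils.hooks import (
    collect_data_files,   # ship a package's non-Python data files
    collect_submodules,   # enumerate importable submodules for hiddenimports
    copy_metadata,        # ship dist-info that packages read at runtime
)

datas = copy_metadata("gradio")
hiddenimports = [m for m in collect_submodules("gradio") if "tests" not in m]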
61 changes: 49 additions & 12 deletions apps/stable_diffusion/src/models/model_wrappers.py
@@ -109,7 +109,7 @@ def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
if "vulkan" in device:
_device = args.iree_vulkan_target_triple
_device = _device.replace("-", "_")
vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb")
vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb")
if vmfb_path.exists():
shark_module = SharkInference(
None,
@@ -436,24 +436,48 @@ def __init__(
super().__init__()
self.vae = None
if custom_vae == "":
print(f"Loading default vae, with target {model_id}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
precision = "fp16" if "fp16" in custom_vae else None
print(f"Loading custom vae, with target {custom_vae}")
if os.path.exists(custom_vae):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
custom_vae = "/".join(
[
custom_vae.split("/")[-2].split("\\")[-1],
custom_vae.split("/")[-1],
]
)
print("Using hub to get custom vae")
try:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
variant=precision,
)
except:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
print(f"Loading custom vae, with state {custom_vae}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae

def forward(self, latents):
image = self.vae.decode(latents / 0.13025, return_dict=False)[
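Condensed, hedged sketch of the custom-VAE loading paths in the __init__ hunk above (the diffusers calls are real; the control flow is compacted and the bare except narrowed for readability):

import os
from diffusers.models import AutoencoderKL

def load_vae(model_id, custom_vae, low_cpu_mem_usage=False):
    if not custom_vae:
        # Default: the VAE bundled with the base model.
        return AutoencoderKL.from_pretrained(
            model_id, subfolder="vae", low_cpu_mem_usage=low_cpu_mem_usage
        )
    if isinstance(custom_vae, dict):
        # A raw state dict: load the base VAE, then patch its weights.
        vae = AutoencoderKL.from_pretrained(
            model_id, subfolder="vae", low_cpu_mem_usage=low_cpu_mem_usage
        )
        vae.load_state_dict(custom_vae)
        return vae
    if os.path.exists(custom_vae):
        # A local checkpoint directory.
        return AutoencoderKL.from_pretrained(
            custom_vae, low_cpu_mem_usage=low_cpu_mem_usage
        )
    # Otherwise keep only the last two path segments as a hub repo id and
    # prefer the fp16 variant when the name suggests one.
    repo_id = "/".join(custom_vae.replace("\\", "/").split("/")[-2:])
    variant = "fp16" if "fp16" in custom_vae else None
    try:
        return AutoencoderKL.from_pretrained(
            repo_id, low_cpu_mem_usage=low_cpu_mem_usage, variant=variant
        )
    except Exception:  # the committed code uses a bare except here
        return AutoencoderKL.from_pretrained(
            repo_id, low_cpu_mem_usage=low_cpu_mem_usage
        )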
@@ -465,7 +489,12 @@ def forward(self, latents):
inputs = tuple(self.inputs["vae"])
# Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
# pipeline.
is_f16 = False
if not self.custom_vae:
is_f16 = False
elif "16" in self.custom_vae:
is_f16 = True
else:
is_f16 = False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
@@ -917,11 +946,19 @@ def __init__(
low_cpu_mem_usage=False,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
try:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
variant="fp16",
)
except:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
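The UNet loader above applies the same opportunistic-variant pattern as the VAE loader: try fp16 weights first, fall back to the repo default. A self-contained equivalent, with the bare except narrowed for clarity:

from diffusers import UNet2DConditionModel

def load_unet(model_id, low_cpu_mem_usage=False):
    try:
        # Prefer half-precision weights when the repo ships them.
        return UNet2DConditionModel.from_pretrained(
            model_id,
            subfolder="unet",
            low_cpu_mem_usage=low_cpu_mem_usage,
            variant="fp16",
        )
    except Exception:  # the committed code uses a bare except here
        return UNet2DConditionModel.from_pretrained(
            model_id,
            subfolder="unet",
            low_cpu_mem_usage=low_cpu_mem_usage,
        )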
5 changes: 4 additions & 1 deletion apps/stable_diffusion/src/models/opt_params.py
@@ -123,7 +123,10 @@ def get_clip():
return get_shark_model(bucket, model_name, iree_flags)


def get_tokenizer(subfolder="tokenizer"):
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
if hf_model_id is not None:
args.hf_model_id = hf_model_id

tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder=subfolder
)
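Usage sketch for the new override. Note the side effect: the function writes hf_model_id back into the global args, so later calls that omit the argument reuse the last override.

# Hypothetical call: fetch SDXL's second tokenizer from an explicit repo.
# "tokenizer_2" is the subfolder name SDXL repos use on the hub.
tokenizer_2 = get_tokenizer(
    subfolder="tokenizer_2",
    hf_model_id="stabilityai/stable-diffusion-xl-base-1.0",
)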
@@ -158,7 +158,10 @@ def generate_images(
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
stencils,
images,
resample_type,
control_mode,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -18,7 +18,10 @@
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
@@ -16,7 +16,10 @@
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
@@ -38,6 +41,7 @@ def __init__(
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
@@ -48,8 +52,10 @@
import_mlir: bool,
use_lora: str,
ondemand: bool,
is_fp32_vae: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.is_fp32_vae = is_fp32_vae

def prepare_latents(
self,
@@ -203,10 +209,10 @@ def generate_images(
# Img latents -> PIL images.
all_imgs = []
self.load_vae()
# imgs = self.decode_latents_sdxl(None)
# all_imgs.extend(imgs)
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
imgs = self.decode_latents_sdxl(
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
@@ -20,7 +20,10 @@
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
@@ -52,6 +55,7 @@ def __init__(
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
@@ -62,21 +66,23 @@
import_mlir: bool,
use_lora: str,
ondemand: bool,
is_f32_vae: bool = False,
):
self.vae = None
self.text_encoder = None
self.text_encoder_2 = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.scheduler = scheduler
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
self.is_f32_vae = is_f32_vae
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
@@ -202,6 +208,9 @@ def encode_prompt_sdxl(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
hf_model_id: Optional[
str
] = "stabilityai/stable-diffusion-xl-base-1.0",
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -211,7 +220,7 @@
batch_size = prompt_embeds.shape[0]

# Define tokenizers and text encoders
self.tokenizer_2 = get_tokenizer("tokenizer_2")
self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id)
self.load_clip_sdxl()
tokenizers = (
[self.tokenizer, self.tokenizer_2]
@@ -332,7 +341,7 @@ def encode_prompt_sdxl(
gc.collect()

# TODO: Look into dtype for text_encoder_2!
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
prompt_embeds = prompt_embeds.to(dtype=torch.float16)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -523,6 +532,9 @@ def produce_img_latents_sdxl(
cpu_scheduling,
guidance_scale,
dtype,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
# return None
self.status = SD_STATE_IDLE
@@ -533,11 +545,22 @@
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
# expand the latents if we are doing classifier free guidance
if isinstance(latents, np.ndarray):
latents = torch.tensor(latents)
latent_model_input = torch.cat([latents] * 2)

latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
).to(dtype)
)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)

noise_pred = self.unet(
"forward",
@@ -549,11 +572,17 @@
add_time_ids,
guidance_scale,
),
send_to_host=False,
send_to_host=True,
)
if not isinstance(latents, torch.Tensor):
latents = torch.from_numpy(latents).to("cpu")
noise_pred = torch.from_numpy(noise_pred).to("cpu")

latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
latents = latents.detach().numpy()
noise_pred = noise_pred.detach().numpy()

step_time = (time.time() - step_start_time) * 1000
step_time_sum += step_time
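The loop above now returns the UNet output to the host (send_to_host=True) and runs the scheduler step on the CPU before handing latents back as numpy. A runnable sketch of that numpy/torch round trip, with a stock diffusers scheduler standing in for the Shark one:

import numpy as np
import torch
from diffusers import EulerAncestralDiscreteScheduler

scheduler = EulerAncestralDiscreteScheduler()
scheduler.set_timesteps(3)

latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    # Stand-in for the device-side UNet call, which returns numpy.
    noise_pred = np.random.randn(*latents.shape).astype(np.float32)
    # Host-side scheduler step: numpy -> torch, then step on the CPU.
    noise_pred = torch.from_numpy(noise_pred).to("cpu")
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]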
@@ -569,11 +598,15 @@

return latents

def decode_latents_sdxl(self, latents):
latents = latents.to(torch.float32)
def decode_latents_sdxl(self, latents, is_fp32_vae):
# latents are in unet dtype here so switch if we want to use fp32
if is_fp32_vae:
print("Casting latents to float32 for VAE")
latents = latents.to(torch.float32)
images = self.vae("forward", (latents,))
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()

images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
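Runnable sketch of the latent-to-PIL postprocessing above, with random data standing in for the VAE output (assumes NCHW float images in [-1, 1]):

import numpy as np
import torch
from PIL import Image

# Stand-in for self.vae("forward", (latents,)): NCHW float32 in [-1, 1].
images = np.random.uniform(-1, 1, (1, 3, 512, 512)).astype(np.float32)
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)  # map to [0, 1]
images = images.cpu().permute(0, 2, 3, 1).float().numpy()  # NCHW -> NHWC
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(img[:, :, :3]) for img in images]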

@@ -666,6 +699,17 @@ def from_pretrained(
return cls(
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
)
if cls.__name__ == "Text2ImageSDXLPipeline":
is_fp32_vae = True if "16" not in custom_vae else False
return cls(
scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
is_fp32_vae,
)

return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)

# #####################################################
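Aside: the fp32-VAE decision in from_pretrained above reduces to a single expression.

# Equivalent to: True if "16" not in custom_vae else False
is_fp32_vae = "16" not in custom_vae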
3 changes: 3 additions & 0 deletions apps/stable_diffusion/src/schedulers/__init__.py
@@ -2,3 +2,6 @@
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
SharkEulerAncestralDiscreteScheduler,
)