From 019ba7051d7e69d82ae7214b322fa6b5f5fe2858 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 2 Feb 2024 11:25:41 -0600 Subject: [PATCH] Small cleanup --- apps/shark_studio/api/sd.py | 40 ------------------- apps/shark_studio/modules/img_processing.py | 44 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 40 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 2822d83829..c26c25bf00 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -374,46 +374,6 @@ def decode_latents(self, latents, cpu_scheduling=True): pil_images = self.image_processor.numpy_to_pil(images) return pil_images - # def process_sd_init_image(self, sd_init_image, resample_type): - # if isinstance(sd_init_image, list): - # images = [] - # for img in sd_init_image: - # img, _ = self.process_sd_init_image(img, resample_type) - # images.append(img) - # is_img2img = True - # return images, is_img2img - # if isinstance(sd_init_image, str): - # if os.path.isfile(sd_init_image): - # sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB") - # image, is_img2img = self.process_sd_init_image( - # sd_init_image, resample_type - # ) - # else: - # image = None - # is_img2img = False - # elif isinstance(sd_init_image, Image.Image): - # image = sd_init_image.convert("RGB") - # elif sd_init_image: - # image = sd_init_image["image"].convert("RGB") - # else: - # image = None - # is_img2img = False - # if image: - # resample_type = ( - # resamplers[resample_type] - # if resample_type in resampler_list - # # Fallback to Lanczos - # else Image.Resampling.LANCZOS - # ) - # image = image.resize((self.width, self.height), resample=resample_type) - # image_arr = np.stack([np.array(i) for i in (image,)], axis=0) - # image_arr = image_arr / 255.0 - # image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype) - # image_arr = 2 * (image_arr - 0.5) - # is_img2img = True - # image = image_arr - # return image, is_img2img - def generate_images( self, prompt, diff --git a/apps/shark_studio/modules/img_processing.py b/apps/shark_studio/modules/img_processing.py index 80062814cf..821f7b86eb 100644 --- a/apps/shark_studio/modules/img_processing.py +++ b/apps/shark_studio/modules/img_processing.py @@ -1,6 +1,8 @@ import os import re import json +import torch +import numpy as np from csv import DictWriter from PIL import Image, PngImagePlugin @@ -8,6 +10,7 @@ from datetime import datetime as dt from base64 import decode + resamplers = { "Lanczos": Image.Resampling.LANCZOS, "Nearest Neighbor": Image.Resampling.NEAREST, @@ -158,3 +161,44 @@ def resize_stencil(image: Image.Image, width, height, resampler_type=None): resampler = resamplers["Nearest Neighbor"] new_image = image.resize((n_width, n_height), resampler=resampler) return new_image, n_width, n_height + + +def process_sd_init_image(self, sd_init_image, resample_type): + if isinstance(sd_init_image, list): + images = [] + for img in sd_init_image: + img, _ = self.process_sd_init_image(img, resample_type) + images.append(img) + is_img2img = True + return images, is_img2img + if isinstance(sd_init_image, str): + if os.path.isfile(sd_init_image): + sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB") + image, is_img2img = self.process_sd_init_image( + sd_init_image, resample_type + ) + else: + image = None + is_img2img = False + elif isinstance(sd_init_image, Image.Image): + image = sd_init_image.convert("RGB") + elif sd_init_image: + image = sd_init_image["image"].convert("RGB") + else: + image = None + is_img2img = False + if image: + resample_type = ( + resamplers[resample_type] + if resample_type in resampler_list + # Fallback to Lanczos + else Image.Resampling.LANCZOS + ) + image = image.resize((self.width, self.height), resample=resample_type) + image_arr = np.stack([np.array(i) for i in (image,)], axis=0) + image_arr = image_arr / 255.0 + image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype) + image_arr = 2 * (image_arr - 0.5) + is_img2img = True + image = image_arr + return image, is_img2img \ No newline at end of file