From 68d6f8593444aae6d6421f52677317a485894197 Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Tue, 4 Feb 2025 12:25:48 -0700 Subject: [PATCH] fixes #1006 makes sdxl turbo models work --- .../handlers/stablediffusion/sd_handler.py | 32 +++++--- src/airunner/workers/model_scanner_worker.py | 75 ++++++++++++------- 2 files changed, 71 insertions(+), 36 deletions(-) diff --git a/src/airunner/handlers/stablediffusion/sd_handler.py b/src/airunner/handlers/stablediffusion/sd_handler.py index 839f91901..764a14c51 100644 --- a/src/airunner/handlers/stablediffusion/sd_handler.py +++ b/src/airunner/handlers/stablediffusion/sd_handler.py @@ -143,6 +143,10 @@ def _clear_cached_properties(self): self._outpaint_settings = None self._path_settings = None + @property + def use_compel(self) -> bool: + return self.generator_settings_cached.use_compel + @property def controlnet_path(self): version: str = self.controlnet_model.version @@ -191,6 +195,13 @@ def is_safetensors(self) -> bool: return False return self.model_path.endswith(".safetensors") + @property + def version(self) -> str: + version = self.generator_settings_cached.version + if version == "SDXL Turbo": + version = "SDXL 1.0" + return version + @property def is_sd_xl(self) -> bool: return self.generator_settings_cached.version == StableDiffusionVersion.SDXL1_0.value @@ -258,13 +269,13 @@ def controlnet_image(self) -> Image: def controlnet_model(self) -> ControlnetModel: if ( self._controlnet_model is None or - self._controlnet_model.version != self.generator_settings_cached.version or + self._controlnet_model.version != self.version or self._controlnet_model.display_name != self.controlnet_settings_cached.controlnet ): self._controlnet_model = self.session.query(ControlnetModel).filter_by( display_name=self.controlnet_settings_cached.controlnet, - version=self.generator_settings_cached.version + version=self.version ).first() return self._controlnet_model @@ -320,7 +331,7 @@ def lora_base_path(self) -> str: os.path.join( self.path_settings_cached.base_path, "art/models", - self.generator_settings_cached.version, + self.version, "lora" ) ) @@ -909,7 +920,7 @@ def _load_scheduler(self, scheduler=None): self.change_model_status(ModelType.SCHEDULER, ModelStatus.LOADING) scheduler_name = scheduler or self.generator_settings_cached.scheduler base_path:str = self.path_settings_cached.base_path - scheduler_version:str = self.generator_settings_cached.version + scheduler_version:str = self.version scheduler_path = os.path.expanduser( os.path.join( base_path, @@ -947,6 +958,9 @@ def _prepare_quantization_settings(self, data: dict) -> dict: """ Quantize the model if possible. """ + if self.is_sd_xl_turbo: + return data + path = os.path.expanduser(os.path.join( self.path_settings_cached.base_path, "art", @@ -1091,7 +1105,7 @@ def _move_pipe_to_device(self): def _load_lora(self): enabled_lora = self.session.query(Lora).filter_by( - version=self.generator_settings_cached.version, + version=self.version, enabled=True ).all() for lora in enabled_lora: @@ -1152,7 +1166,7 @@ def _load_embeddings(self): self.logger.error(f"Failed to unload embeddings: {e}") embeddings = self.session.query(Embedding).filter_by( - version=self.generator_settings_cached.version + version=self.version ).all() for embedding in embeddings: @@ -1173,7 +1187,7 @@ def _load_embeddings(self): self.logger.debug("No embeddings enabled") def _load_compel(self): - if self.generator_settings_cached.use_compel: + if self.use_compel: try: self._load_textual_inversion_manager() self._load_compel_proc() @@ -1503,7 +1517,7 @@ def _load_prompt_embeds(self): self.logger.debug("Compel proc is not loading - attempting to load") self._load_compel() self.logger.debug("Loading prompt embeds") - if not self.generator_settings_cached.use_compel: + if not self.use_compel: return prompt = self.prompt @@ -1604,7 +1618,7 @@ def _prepare_data(self, active_rect = None) -> dict: )) self._set_lora_adapters() - if self.generator_settings_cached.use_compel: + if self.use_compel: args.update(dict( prompt_embeds=self._prompt_embeds, negative_prompt_embeds=self._negative_prompt_embeds, diff --git a/src/airunner/workers/model_scanner_worker.py b/src/airunner/workers/model_scanner_worker.py index 806ac5d4c..57684de8e 100644 --- a/src/airunner/workers/model_scanner_worker.py +++ b/src/airunner/workers/model_scanner_worker.py @@ -15,6 +15,21 @@ def __init__(self, *args, **kwargs): PipelineMixin.__init__(self) def handle_message(self): + # ensure turbo path exists + turbo_paths = ( + os.path.expanduser(os.path.join( + self.path_settings.base_path, "art/models", "SDXL 1.0", "txt2img", "turbo_models" + )), + os.path.expanduser(os.path.join( + self.path_settings.base_path, "art/models", "SDXL 1.0", "inpaint", "turbo_models" + )) + ) + for turbo_path in turbo_paths: + if not os.path.exists(turbo_path): + os.makedirs(turbo_path) + with open(os.path.join(turbo_path, "README.txt"), "w") as f: + f.write("Place Stable Diffusion XL Turbo, Lightning and Hyper models here") + self.scan_for_models() self.remove_missing_models() @@ -45,35 +60,41 @@ def scan_for_models(self): action = action_item.name if "controlnet_processors" in action_item.path: continue - with os.scandir(action_item.path) as file_object: - for file_item in file_object: - model = AIModels() - model.name = os.path.basename(file_item.path) - model.path = file_item.path - model.branch = "main" - model.version = version - model.category = "stablediffusion" - model.pipeline_action = action - model.enabled = True - model.model_type = "art" - model.is_default = False - if file_item.is_file(): # ckpt or safetensors file - if file_item.name.endswith(".ckpt") or file_item.name.endswith(".safetensors"): - name = file_item.name.replace(".ckpt", "").replace(".safetensors", "") - model.name = name - else: - model = None - elif file_item.is_dir(): # diffusers folder - is_diffusers_directory = True - for diffuser_folder in diffusers_folders: - if not os.path.exists(os.path.join(file_item.path, diffuser_folder)): - is_diffusers_directory = False + paths = (action_item.path, os.path.join(action_item.path, "turbo_models")) + for path in paths: + if not os.path.exists(path): + continue + with os.scandir(path) as file_object: + for file_item in file_object: + model = AIModels() + model.name = os.path.basename(file_item.path) + model.path = file_item.path + model.branch = "main" + if "turbo_models" in path: + version = "SDXL Turbo" + model.version = version + model.category = "stablediffusion" + model.pipeline_action = action + model.enabled = True + model.model_type = "art" + model.is_default = False + if file_item.is_file(): # ckpt or safetensors file + if file_item.name.endswith(".ckpt") or file_item.name.endswith(".safetensors"): + name = file_item.name.replace(".ckpt", "").replace(".safetensors", "") + model.name = name + else: model = None - if is_diffusers_directory: - model.name = file_item.name + elif file_item.is_dir(): # diffusers folder + is_diffusers_directory = True + for diffuser_folder in diffusers_folders: + if not os.path.exists(os.path.join(file_item.path, diffuser_folder)): + is_diffusers_directory = False + model = None + if is_diffusers_directory: + model.name = file_item.name - if model: - models.append(model) + if model: + models.append(model) self.emit_signal(SignalCode.AI_MODELS_SAVE_OR_UPDATE_SIGNAL, {"models": models}) def remove_missing_models(self):