Skip to content

Commit

Permalink
fixes #1006 makes sdxl turbo models work
Browse files Browse the repository at this point in the history
  • Loading branch information
w4ffl35 committed Feb 4, 2025
1 parent 21d9647 commit 68d6f85
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 36 deletions.
32 changes: 23 additions & 9 deletions src/airunner/handlers/stablediffusion/sd_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def _clear_cached_properties(self):
self._outpaint_settings = None
self._path_settings = None

@property
def use_compel(self) -> bool:
    """Whether Compel prompt-weighting is enabled in the generator settings."""
    settings = self.generator_settings_cached
    return settings.use_compel

@property
def controlnet_path(self):
version: str = self.controlnet_model.version
Expand Down Expand Up @@ -191,6 +195,13 @@ def is_safetensors(self) -> bool:
return False
return self.model_path.endswith(".safetensors")

@property
def version(self) -> str:
    """Resolved Stable Diffusion version used for model/path lookups.

    "SDXL Turbo" models share the "SDXL 1.0" directory layout and
    resources, so that label is aliased to "SDXL 1.0" here; every
    other version string passes through unchanged.
    """
    aliases = {"SDXL Turbo": "SDXL 1.0"}
    raw_version = self.generator_settings_cached.version
    return aliases.get(raw_version, raw_version)

@property
def is_sd_xl(self) -> bool:
    """Whether the selected model version is exactly Stable Diffusion XL 1.0."""
    # NOTE(review): this compares the RAW generator setting, not the aliased
    # `self.version`, so "SDXL Turbo" does NOT count as SDXL here even though
    # the rest of this commit maps turbo onto the "SDXL 1.0" resources —
    # confirm that exclusion is intentional.
    return self.generator_settings_cached.version == StableDiffusionVersion.SDXL1_0.value
Expand Down Expand Up @@ -258,13 +269,13 @@ def controlnet_image(self) -> Image:
def controlnet_model(self) -> ControlnetModel:
if (
self._controlnet_model is None or
self._controlnet_model.version != self.generator_settings_cached.version or
self._controlnet_model.version != self.version or
self._controlnet_model.display_name != self.controlnet_settings_cached.controlnet
):

self._controlnet_model = self.session.query(ControlnetModel).filter_by(
display_name=self.controlnet_settings_cached.controlnet,
version=self.generator_settings_cached.version
version=self.version
).first()

return self._controlnet_model
Expand Down Expand Up @@ -320,7 +331,7 @@ def lora_base_path(self) -> str:
os.path.join(
self.path_settings_cached.base_path,
"art/models",
self.generator_settings_cached.version,
self.version,
"lora"
)
)
Expand Down Expand Up @@ -909,7 +920,7 @@ def _load_scheduler(self, scheduler=None):
self.change_model_status(ModelType.SCHEDULER, ModelStatus.LOADING)
scheduler_name = scheduler or self.generator_settings_cached.scheduler
base_path:str = self.path_settings_cached.base_path
scheduler_version:str = self.generator_settings_cached.version
scheduler_version:str = self.version
scheduler_path = os.path.expanduser(
os.path.join(
base_path,
Expand Down Expand Up @@ -947,6 +958,9 @@ def _prepare_quantization_settings(self, data: dict) -> dict:
"""
Quantize the model if possible.
"""
if self.is_sd_xl_turbo:
return data

path = os.path.expanduser(os.path.join(
self.path_settings_cached.base_path,
"art",
Expand Down Expand Up @@ -1091,7 +1105,7 @@ def _move_pipe_to_device(self):
def _load_lora(self):

enabled_lora = self.session.query(Lora).filter_by(
version=self.generator_settings_cached.version,
version=self.version,
enabled=True
).all()
for lora in enabled_lora:
Expand Down Expand Up @@ -1152,7 +1166,7 @@ def _load_embeddings(self):
self.logger.error(f"Failed to unload embeddings: {e}")

embeddings = self.session.query(Embedding).filter_by(
version=self.generator_settings_cached.version
version=self.version
).all()

for embedding in embeddings:
Expand All @@ -1173,7 +1187,7 @@ def _load_embeddings(self):
self.logger.debug("No embeddings enabled")

def _load_compel(self):
if self.generator_settings_cached.use_compel:
if self.use_compel:
try:
self._load_textual_inversion_manager()
self._load_compel_proc()
Expand Down Expand Up @@ -1503,7 +1517,7 @@ def _load_prompt_embeds(self):
self.logger.debug("Compel proc is not loading - attempting to load")
self._load_compel()
self.logger.debug("Loading prompt embeds")
if not self.generator_settings_cached.use_compel:
if not self.use_compel:
return

prompt = self.prompt
Expand Down Expand Up @@ -1604,7 +1618,7 @@ def _prepare_data(self, active_rect = None) -> dict:
))
self._set_lora_adapters()

if self.generator_settings_cached.use_compel:
if self.use_compel:
args.update(dict(
prompt_embeds=self._prompt_embeds,
negative_prompt_embeds=self._negative_prompt_embeds,
Expand Down
75 changes: 48 additions & 27 deletions src/airunner/workers/model_scanner_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ def __init__(self, *args, **kwargs):
PipelineMixin.__init__(self)

def handle_message(self):
    """Ensure the SDXL turbo model folders exist, then run the model scan.

    Creates a ``turbo_models`` directory (with a README explaining its
    purpose) under both the txt2img and inpaint "SDXL 1.0" model paths,
    then triggers the regular scan / stale-entry cleanup pass.
    """
    for pipeline_action in ("txt2img", "inpaint"):
        turbo_path = os.path.expanduser(os.path.join(
            self.path_settings.base_path,
            "art/models",
            "SDXL 1.0",
            pipeline_action,
            "turbo_models"
        ))
        if not os.path.exists(turbo_path):
            # exist_ok=True guards against a race where another process
            # creates the directory between the exists() check and here.
            os.makedirs(turbo_path, exist_ok=True)
            # Only drop the README when the folder is newly created, so a
            # user-deleted README is not resurrected on every scan.
            with open(os.path.join(turbo_path, "README.txt"), "w") as f:
                f.write("Place Stable Diffusion XL Turbo, Lightning and Hyper models here")

    self.scan_for_models()
    self.remove_missing_models()

Expand Down Expand Up @@ -45,35 +60,41 @@ def scan_for_models(self):
action = action_item.name
if "controlnet_processors" in action_item.path:
continue
with os.scandir(action_item.path) as file_object:
for file_item in file_object:
model = AIModels()
model.name = os.path.basename(file_item.path)
model.path = file_item.path
model.branch = "main"
model.version = version
model.category = "stablediffusion"
model.pipeline_action = action
model.enabled = True
model.model_type = "art"
model.is_default = False
if file_item.is_file(): # ckpt or safetensors file
if file_item.name.endswith(".ckpt") or file_item.name.endswith(".safetensors"):
name = file_item.name.replace(".ckpt", "").replace(".safetensors", "")
model.name = name
else:
model = None
elif file_item.is_dir(): # diffusers folder
is_diffusers_directory = True
for diffuser_folder in diffusers_folders:
if not os.path.exists(os.path.join(file_item.path, diffuser_folder)):
is_diffusers_directory = False
paths = (action_item.path, os.path.join(action_item.path, "turbo_models"))
for path in paths:
if not os.path.exists(path):
continue
with os.scandir(path) as file_object:
for file_item in file_object:
model = AIModels()
model.name = os.path.basename(file_item.path)
model.path = file_item.path
model.branch = "main"
if "turbo_models" in path:
version = "SDXL Turbo"
model.version = version
model.category = "stablediffusion"
model.pipeline_action = action
model.enabled = True
model.model_type = "art"
model.is_default = False
if file_item.is_file(): # ckpt or safetensors file
if file_item.name.endswith(".ckpt") or file_item.name.endswith(".safetensors"):
name = file_item.name.replace(".ckpt", "").replace(".safetensors", "")
model.name = name
else:
model = None
if is_diffusers_directory:
model.name = file_item.name
elif file_item.is_dir(): # diffusers folder
is_diffusers_directory = True
for diffuser_folder in diffusers_folders:
if not os.path.exists(os.path.join(file_item.path, diffuser_folder)):
is_diffusers_directory = False
model = None
if is_diffusers_directory:
model.name = file_item.name

if model:
models.append(model)
if model:
models.append(model)
self.emit_signal(SignalCode.AI_MODELS_SAVE_OR_UPDATE_SIGNAL, {"models": models})

def remove_missing_models(self):
Expand Down

0 comments on commit 68d6f85

Please sign in to comment.