fixes typing and PEP8 warnings
w4ffl35 committed Feb 5, 2025
1 parent b142db0 commit 14533bc
Showing 2 changed files with 23 additions and 17 deletions.
1 change: 1 addition & 0 deletions src/airunner/enums.py
@@ -252,6 +252,7 @@ class SignalCode(Enum):
MISSING_REQUIRED_MODELS = enum.auto()

class EngineResponseCode(Enum):
+ NONE = 0
STATUS = 100
ERROR = 200
WARNING = 300
39 changes: 22 additions & 17 deletions src/airunner/handlers/stablediffusion/sd_handler.py
@@ -47,7 +47,7 @@


class SDHandler(BaseHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._controlnet_model = None
self._controlnet: Optional[ControlNetModel] = None
@@ -558,7 +558,7 @@ def unload(self):
clear_memory()
self.change_model_status(ModelType.SD, ModelStatus.UNLOADED)

- def handle_generate_signal(self, message: dict=None):
+ def handle_generate_signal(self, message: Optional[Dict] = None):
self.load()
self._clear_cached_properties()
self._swap_pipeline()
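
The signature change above is representative of the typing fixes in this commit: a parameter that defaults to None is annotated Optional explicitly, rather than relying on the implicit-Optional behavior that PEP 484 deprecates and modern type checkers flag. A minimal sketch of the pattern; the body and the "prompt" key are illustrative, not the handler's actual logic:

```python
from typing import Dict, Optional


def handle_generate_signal(message: Optional[Dict] = None) -> None:
    # `message: dict = None` relied on implicit Optional; spelling out
    # Optional[Dict] makes the None default explicit for type checkers.
    message = message or {}
    prompt = message.get("prompt", "")  # hypothetical key, for illustration
```
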
@@ -568,6 +568,7 @@ def handle_generate_signal(self, message: dict=None):
):
self._current_state = HandlerState.PREPARING_TO_GENERATE
response = None
+ code = EngineResponseCode.NONE
try:
response = self._generate()
code = EngineResponseCode.IMAGE_GENERATED
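
The new EngineResponseCode.NONE member pairs with this pre-initialization: binding `code` before the `try` block guarantees the name exists on every path, even when `_generate()` raises before `IMAGE_GENERATED` is assigned, so the error path can still report a code. A minimal sketch of that control flow; the `handle` wrapper and the numeric value of IMAGE_GENERATED are assumptions for illustration:

```python
from enum import Enum


class EngineResponseCode(Enum):
    NONE = 0
    STATUS = 100
    ERROR = 200
    WARNING = 300
    IMAGE_GENERATED = 400  # hypothetical value; the real one lives in enums.py


def handle(generate_fn):
    response = None
    # Pre-binding `code` avoids an UnboundLocalError if generate_fn() raises
    # before IMAGE_GENERATED is ever assigned.
    code = EngineResponseCode.NONE
    try:
        response = generate_fn()
        code = EngineResponseCode.IMAGE_GENERATED
    except Exception:
        code = EngineResponseCode.ERROR
    return code, response
```
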
@@ -704,7 +705,7 @@ def _generate(self):
is_outpaint=self.is_outpaint
)

- def _initialize_metadata(self, images: List[Any], data:Dict) -> Optional[dict]:
+ def _initialize_metadata(self, images: List[Any], data: Dict) -> Optional[dict]:
metadata = None
if self.metadata_settings.export_metadata:
metadata_dict = dict()
@@ -764,7 +765,7 @@ def _initialize_metadata(self, images: List[Any], data:Dict) -> Optional[dict]:
metadata = [metadata_dict for _ in range(len(images))]
return metadata

- def _export_images(self, images: List[Any], data:Dict):
+ def _export_images(self, images: List[Any], data: Dict):
extension = self.application_settings_cached.image_export_type
filename = "image"
file_path = os.path.expanduser(
@@ -923,7 +924,6 @@ def _load_controlnet_processor(self):
if self._controlnet_processor is not None:
return
self.logger.debug(f"Loading controlnet processor {self.controlnet_model.name}")
- #self._controlnet_processor = Processor(self.controlnet_model.name)
controlnet_data = controlnet_aux_models[self.controlnet_model.name]
controlnet_class_: Any = controlnet_data["class"]
checkpoint: bool = controlnet_data["checkpoint"]
@@ -938,8 +938,8 @@
def _load_scheduler(self, scheduler=None):
self.change_model_status(ModelType.SCHEDULER, ModelStatus.LOADING)
scheduler_name = scheduler or self.generator_settings_cached.scheduler
- base_path:str = self.path_settings_cached.base_path
- scheduler_version:str = self.version
+ base_path: str = self.path_settings_cached.base_path
+ scheduler_version: str = self.version
scheduler_path = os.path.expanduser(
os.path.join(
base_path,
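
The two annotation fixes above address PEP 8 spacing: pycodestyle reports `base_path:str` as E231 (missing whitespace after ':'). A tiny before/after illustration with an arbitrary sample value:

```python
base_path:str = "~/models"   # flagged: E231, no space after the colon
base_path: str = "~/models"  # compliant annotated assignment
```
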
@@ -1024,7 +1024,7 @@ def _prepare_quantization_settings(self, data: dict) -> dict:
)
return data

- def _prepare_tiny_autoencoder(self, data: dict) -> dict:
+ def _prepare_tiny_autoencoder(self, data: Dict) -> Optional[Dict]:
if not self.is_outpaint:
path = os.path.expanduser(os.path.join(
self.path_settings_cached.base_path,
@@ -1110,7 +1110,7 @@ def _send_pipeline_loaded_signal(self):
pipeline_type = "img2img"
elif pipeline_class in self.outpaint_pipelines:
pipeline_type = "inpaint"
- self.emit_signal(SignalCode.SD_PIPELINE_LOADED_SIGNAL, { "pipeline": pipeline_type })
+ self.emit_signal(SignalCode.SD_PIPELINE_LOADED_SIGNAL, {"pipeline": pipeline_type})

def _move_pipe_to_device(self):
if self._pipe is not None:
@@ -1161,7 +1161,7 @@ def _load_lora_weights(self, lora: Lora):
def _set_lora_adapters(self):
self.logger.debug("Setting LORA adapters")

- loaded_lora_id = [l.id for l in self._loaded_lora.values()]
+ loaded_lora_id = [lora.id for lora in self._loaded_lora.values()]
enabled_lora = self.session.query(Lora).filter(Lora.id.in_(loaded_lora_id)).all()
adapter_weights = []
adapter_names = []
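
Renaming the comprehension variable from `l` to `lora` resolves pycodestyle E741 (ambiguous single-letter names such as `l`, which is easily misread as `1` or `I`). A self-contained sketch, using a stand-in dataclass in place of the project's Lora model:

```python
from dataclasses import dataclass


@dataclass
class Lora:
    id: int
    path: str


loaded_lora = {"a.safetensors": Lora(id=1, path="a.safetensors")}

# `lora` carries the same meaning as `l` without the ambiguity.
loaded_lora_id = [lora.id for lora in loaded_lora.values()]
```
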
@@ -1196,7 +1196,11 @@ def _load_embeddings(self):
else:
try:
self.logger.debug(f"Loading embedding {embedding_path}")
- self._pipe.load_textual_inversion(embedding_path, token=embedding.name, weight_name=embedding_path)
+ self._pipe.load_textual_inversion(
+     embedding_path,
+     token=embedding.name,
+     weight_name=embedding_path
+ )
self._loaded_embeddings.append(embedding_path)
except Exception as e:
self.logger.error(f"Failed to load embedding {embedding_path}: {e}")
@@ -1273,7 +1277,6 @@ def _make_memory_efficient(self):

def _finalize_load_stable_diffusion(self):
safety_checker_ready = True
- tokenizer_ready = True
if self.use_safety_checker:
safety_checker_ready = (
self._safety_checker is not None and
@@ -1363,7 +1366,7 @@ def _apply_cpu_offload(self, attr_val):
def _apply_model_offload(self, attr_val):
if attr_val and not self.memory_settings.use_enable_sequential_cpu_offload:
self.logger.debug("Enabling model cpu offload")
- #self._move_stable_diffusion_to_cpu()
+ # self._move_stable_diffusion_to_cpu()
self._pipe.enable_model_cpu_offload()
else:
self.logger.debug("Model cpu offload disabled")
@@ -1454,7 +1457,7 @@ def _unload_loras(self):
self._loaded_lora = {}
self._disabled_lora = []

- def _unload_lora(self, lora:Lora):
+ def _unload_lora(self, lora: Lora):
if lora.path in self._loaded_lora:
self.logger.debug(f"Unloading LORA {lora.path}")
del self._loaded_lora[lora.path]
@@ -1590,7 +1593,7 @@ def _load_prompt_embeds(self):

if self.is_sd_xl_or_turbo:
prompt_embeds, pooled_prompt_embeds = self._compel_proc.build_conditioning_tensor(compel_prompt)
- negative_prompt_embeds, negative_pooled_prompt_embeds = self._compel_proc.build_conditioning_tensor(compel_negative_prompt)
+ negative_prompt_embeds, negative_pooled_prompt_embeds = self._compel_proc.build_conditioning_tensor(
+     compel_negative_prompt
+ )
else:
prompt_embeds = self._compel_proc.build_conditioning_tensor(compel_prompt)
negative_prompt_embeds = self._compel_proc.build_conditioning_tensor(compel_negative_prompt)
@@ -1622,7 +1627,7 @@ def _clear_memory_efficient_settings(self):
if key.endswith("_applied"):
self._memory_settings_flags[key] = None

- def _prepare_data(self, active_rect = None) -> dict:
+ def _prepare_data(self, active_rect=None) -> Dict:
"""
Here we are loading the arguments for the Stable Diffusion generator.
:return:
@@ -1755,7 +1760,7 @@ def _prepare_data(self, active_rect = None) -> dict:
))
return args

- def _resize_image(self, image: Image, max_width: int, max_height: int) -> Image:
+ def _resize_image(self, image: Image, max_width: int, max_height: int) -> Optional[Image]:
"""
Resize the image to ensure it is not larger than max_width and max_height,
while maintaining the aspect ratio.
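
The return annotation widens to Optional[Image], which suggests the helper can return None rather than always an image (for instance when it is given nothing to resize). A sketch of an aspect-ratio-preserving resize under that assumption; this is not the project's actual implementation:

```python
from typing import Optional

from PIL import Image


def resize_image(
    image: Optional[Image.Image], max_width: int, max_height: int
) -> Optional[Image.Image]:
    # None passes through unchanged; otherwise scale down so that neither
    # dimension exceeds its maximum, preserving the aspect ratio.
    if image is None:
        return None
    scale = min(max_width / image.width, max_height / image.height, 1.0)
    if scale >= 1.0:
        return image  # already within bounds; never upscale
    new_size = (int(image.width * scale), int(image.height * scale))
    return image.resize(new_size, Image.LANCZOS)
```
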
