From abf95c699f8eb1a9adefb368d4a814deed4a241f Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Mon, 22 Jan 2024 13:05:05 -0700 Subject: [PATCH] Fixes LLM, adding image to canvas from thumbnail panel, saving layers and grabbing --- src/airunner/aihandler/engine.py | 25 ++- src/airunner/aihandler/runner.py | 172 +++++++++--------- src/airunner/aihandler/transformer_runner.py | 2 - src/airunner/signal_mediator.py | 5 +- .../widgets/canvas_plus/canvas_plus_widget.py | 69 ++++--- .../widgets/canvas_plus/draggables.py | 8 +- .../generator_form/generator_form_widget.py | 10 +- .../generator_form/generator_tab_widget.py | 1 - src/airunner/windows/main/ai_model_mixin.py | 56 +++--- src/airunner/windows/main/main_window.py | 3 + src/airunner/windows/main/settings_mixin.py | 29 ++- src/airunner/workers/worker.py | 8 +- 12 files changed, 209 insertions(+), 179 deletions(-) diff --git a/src/airunner/aihandler/engine.py b/src/airunner/aihandler/engine.py index 68307874f..d01997cc9 100644 --- a/src/airunner/aihandler/engine.py +++ b/src/airunner/aihandler/engine.py @@ -8,9 +8,10 @@ from airunner.workers.worker import Worker from airunner.aihandler.llm import LLMController from airunner.aihandler.logger import Logger -from airunner.aihandler.runner import SDController +from airunner.aihandler.runner import SDGenerateWorker, SDRequestWorker from airunner.aihandler.tts import TTS from airunner.windows.main.settings_mixin import SettingsMixin +from airunner.service_locator import ServiceLocator class EngineRequestWorker(Worker): @@ -51,7 +52,6 @@ class Engine(QObject, MediatorMixin, SettingsMixin): # Model controllers llm_controller = None - sd_controller = None tts_controller = None stt_controller = None ocr_controller = None @@ -69,7 +69,7 @@ def do_response(self, response): @pyqtSlot(object) def on_engine_cancel_signal(self, _ignore): self.logger.info("Canceling") - self.sd_controller.cancel() + self.emit("sd_cancel_signal") self.request_worker.cancel() @pyqtSlot(object) @@ -105,7 +105,6 @@ def __init__(self, **kwargs): # Initialize Controllers self.llm_controller = LLMController(engine=self) - self.sd_controller = SDController(engine=self) #self.stt_controller = STTController(engine=self) # self.ocr_controller = ImageProcessor(engine=self) self.tts_controller = TTS(engine=self) @@ -123,16 +122,14 @@ def __init__(self, **kwargs): self.register("EngineResponseWorker_response_signal", self) self.register("text_generate_request_signal", self) self.register("image_generate_request_signal", self) - self.register("sd_controller_response_signal", self) self.register("llm_controller_response_signal", self) self.register("llm_text_streamed_signal", self) + + self.sd_request_worker = self.create_worker(SDRequestWorker) + self.sd_generate_worker = self.create_worker(SDGenerateWorker) self.request_worker = self.create_worker(EngineRequestWorker) self.response_worker = self.create_worker(EngineResponseWorker) - - - def on_sd_controller_response_signal(self, message): - self.do_response(message) def on_llm_controller_response_signal(self, message): self.do_response(message) @@ -178,7 +175,7 @@ def on_image_generate_request_signal(self, message): self.memory_settings["unload_unused_models"], self.memory_settings["move_unused_model_to_cpu"] ) - self.sd_controller.do_request(message["message"]) + self.sd_request_worker.add_to_queue(message) def request_queue_size(self): return self.request_worker.queue.qsize() @@ -249,7 +246,7 @@ def unload_stablediffusion(self): """ Unload the 
Stable Diffusion model from memory. """ - self.sd_controller.unload() + self.emit("unload_stablediffusion_signal") def parse_message(self, message): if message: @@ -313,10 +310,10 @@ def unload_llm(self, request_data: dict, do_unload_model: bool, move_unused_mode def do_unload_llm(self): self.logger.info("Unloading LLM") self.llm_controller.do_unload_llm() - self.clear_memory() + #self.clear_memory() def move_sd_to_cpu(self): - if self.sd_controller.is_pipe_on_cpu or not self.sd_controller.has_pipe: + if ServiceLocator.get("is_pipe_on_cpu")() or not ServiceLocator.get("has_pipe")(): return - self.sd_controller.move_pipe_to_cpu() + self.emit("move_pipe_to_cpu_signal") self.clear_memory() \ No newline at end of file diff --git a/src/airunner/aihandler/runner.py b/src/airunner/aihandler/runner.py index 29b1902f9..c10f293ae 100644 --- a/src/airunner/aihandler/runner.py +++ b/src/airunner/aihandler/runner.py @@ -10,7 +10,7 @@ import torch from PIL import Image, ImageDraw, ImageFont -from PyQt6.QtCore import QObject +from PyQt6.QtCore import QObject, pyqtSlot from controlnet_aux.processor import Processor from diffusers.pipelines.stable_diffusion.convert_from_ckpt import \ @@ -44,20 +44,31 @@ from airunner.workers.worker import Worker from airunner.mediator_mixin import MediatorMixin from airunner.windows.main.settings_mixin import SettingsMixin - -logger = Logger(prefix="SDRunner") +from airunner.service_locator import ServiceLocator torch.backends.cuda.matmul.allow_tf32 = True + class SDRequestWorker(Worker): def __init__(self, prefix="SDRequestWorker"): super().__init__(prefix=prefix) + + def handle_message(self, message): + self.emit("add_sd_response_to_queue_signal", dict( + message=message, + image_base_path=self.path_settings["image_path"] + )) class SDGenerateWorker(Worker): def __init__(self, prefix): self.sd = SDRunner() super().__init__(prefix) + self.register("add_sd_response_to_queue_signal", self) + + @pyqtSlot(object) + def on_add_sd_response_to_queue_signal(self, data): + self.add_to_queue(data) def handle_message(self, data): image_base_path = data["image_base_path"] @@ -90,44 +101,6 @@ def handle_message(self, data): image=image )) response["images"] = updated_images - self.emit("sd_image_generated_signal", response) - - -class SDController(QObject, MediatorMixin): - - logger = Logger(prefix="SDController") - - def __init__(self, *args, **kwargs): - MediatorMixin.__init__(self) - self.engine = kwargs.pop("engine", None) - super().__init__() - self.request_worker = self.create_worker(SDRequestWorker) - self.generate_worker = self.create_worker(SDGenerateWorker) - self.register("SDGenerateWorker_response_signal", self) - self.register("SDRequestWorker_response_signal", self) - - def do_request(self, message): - self.request_worker.add_to_queue(message) - - def on_SDRequestWorker_response_signal(self, message): - self.generate_worker.add_to_queue(dict( - message=message, - image_base_path=self.path_settings["image_path"] - )) - - def on_SDGenerateWorker_response_signal(self, message): - self.emit("sd_controller_response_signal", message) - - @property - def is_pipe_on_cpu(self): - return self.generate_worker.sd.is_pipe_on_cpu - - @property - def has_pipe(self): - return self.generate_worker.sd.has_pipe - - def move_pipe_to_cpu(self): - self.generate_worker.sd.move_pipe_to_cpu() class SDRunner( @@ -150,6 +123,7 @@ class SDRunner( AIModelMixin, MediatorMixin ): + logger = Logger(prefix="SDRunner") _current_model: str = "" _previous_model: str = "" _initialized: bool = False @@ 
-244,7 +218,7 @@ def local_files_only(self): @local_files_only.setter def local_files_only(self, value): - logger.info("Setting local_files_only to %s" % value) + self.logger.info("Setting local_files_only to %s" % value) self._local_files_only = value @property @@ -635,7 +609,7 @@ def pipe(self): elif self.is_vid_action: return self.txt2vid else: - logger.warning(f"Invalid action {self.action} unable to get pipe") + self.logger.warning(f"Invalid action {self.action} unable to get pipe") @pipe.setter def pipe(self, value): @@ -800,14 +774,28 @@ def original_model_data(self): return self.options.get("original_model_data", {}) def __init__(self, **kwargs): - #logger.set_level(LOG_LEVEL) - logger.info("Loading Stable Diffusion model runner...") + #self.logger.set_level(LOG_LEVEL) MediatorMixin.__init__(self) SettingsMixin.__init__(self) + LayerMixin.__init__(self) + LoraDataMixin.__init__(self) + EmbeddingDataMixin.__init__(self) + PipelineMixin.__init__(self) + ControlnetModelMixin.__init__(self) + AIModelMixin.__init__(self) super().__init__() + self.logger.info("Loading Stable Diffusion model runner...") self.safety_checker_model = self.models_by_pipeline_action("safety_checker") self.text_encoder_model = self.models_by_pipeline_action("text_encoder") self.inpaint_vae_model = self.models_by_pipeline_action("inpaint_vae") + self.register("sd_cancel_signal", self) + services = [ + "is_pipe_on_cpu", + "has_pipe", + ] + + for service in services: + ServiceLocator.register(service, lambda: getattr(self, service)) self._safety_checker = None self._controlnet = None @@ -818,6 +806,8 @@ def __init__(self, **kwargs): self.depth2img = None self.txt2vid = None + self.register("unload_stablediffusion_signal", self) + @staticmethod def latents_to_image(latents: torch.Tensor): image = latents.permute(0, 2, 3, 1) @@ -875,7 +865,7 @@ def is_safetensor_file(model): def initialize(self): if not self.initialized or self.reload_model or self.pipe is None: - logger.info("Initializing") + self.logger.info("Initializing") self.compel_proc = None self.prompt_embeds = None self.negative_prompt_embeds = None @@ -893,7 +883,8 @@ def generator(self, device=None, seed=None): return torch.Generator(device=device).manual_seed(seed) def prepare_options(self, data): - logger.info(f"Preparing options") + print("DATA", data) + self.logger.info(f"Preparing options") action = data["action"] options = data["options"] requested_model = options.get(f"model", None) @@ -906,9 +897,9 @@ def prepare_options(self, data): if (self.is_pipe_loaded and (sequential_cpu_offload_changed)) or model_changed: # model change if model_changed: - logger.info(f"Model changed, reloading model" + f" (from {self.model} to {requested_model})") + self.logger.info(f"Model changed, reloading model" + f" (from {self.model} to {requested_model})") if sequential_cpu_offload_changed: - logger.info(f"Sequential cpu offload changed, reloading model") + self.logger.info(f"Sequential cpu offload changed, reloading model") self.reload_model = True self.clear_scheduler() self.clear_controlnet() @@ -935,7 +926,7 @@ def error_handler(self, error): if "got an unexpected keyword argument 'image'" in message and self.action in ["outpaint", "pix2pix", "depth2img"]: message = f"This model does not support {self.action}" traceback.print_exc() - logger.error(error) + self.logger.error(error) self.emit("error_signal", message) def initialize_safety_checker(self, local_files_only=None): @@ -964,16 +955,16 @@ def load_safety_checker(self): if not self.pipe: return if not 
self.do_nsfw_filter: - logger.info("Disabling safety checker") + self.logger.info("Disabling safety checker") self.pipe.safety_checker = None elif self.pipe.safety_checker is None: - logger.info("Loading safety checker") + self.logger.info("Loading safety checker") self.pipe.safety_checker = self.safety_checker if self.pipe.safety_checker: self.pipe.safety_checker.to(self.device) def do_sample(self, **kwargs): - logger.info(f"Sampling {self.action}") + self.logger.info(f"Sampling {self.action}") if self.is_vid_action: message = "Generating video" @@ -1003,12 +994,12 @@ def do_sample(self, **kwargs): try: images = output.images except AttributeError: - logger.error("Unable to get images from output") + self.logger.error("Unable to get images from output") if self.action_has_safety_checker: try: nsfw_content_detected = output.nsfw_content_detected except AttributeError: - logger.error("Unable to get nsfw_content_detected from output") + self.logger.error("Unable to get nsfw_content_detected from output") return images, nsfw_content_detected def generate_latents(self): @@ -1068,7 +1059,7 @@ def call_pipe(self, **kwargs): "prompt_embeds": self.prompt_embeds, }) except Exception as _e: - logger.warning("Compel failed: " + str(_e)) + self.logger.warning("Compel failed: " + str(_e)) args.update({ "prompt": self.prompt, }) @@ -1082,7 +1073,7 @@ def call_pipe(self, **kwargs): "negative_prompt_embeds": self.negative_prompt_embeds, }) except Exception as _e: - logger.warning("Compel failed: " + str(_e)) + self.logger.warning("Compel failed: " + str(_e)) args.update({ "negative_prompt": self.negative_prompt, }) @@ -1107,7 +1098,7 @@ def call_pipe(self, **kwargs): args["generator"] = generator if self.enable_controlnet: - logger.info(f"Setting up controlnet") + self.logger.info(f"Setting up controlnet") args = self.load_controlnet_arguments(**args) self.load_safety_checker() @@ -1147,7 +1138,7 @@ def call_pipe_txt2vid(self, **kwargs): ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1] frame_ids = list(range(ch_start, ch_end)) try: - logger.info(f"Generating video with {len(frame_ids)} frames") + self.logger.info(f"Generating video with {len(frame_ids)} frames") self.emit("status_signal", f"Generating video, frames {cur_frame} to {cur_frame + len(frame_ids)-1} of {self.n_samples}") cur_frame += len(frame_ids) kwargs = { @@ -1225,7 +1216,7 @@ def prepare_extra_args(self, _data, image, mask): return extra_args def sample_diffusers_model(self, data: dict): - logger.info("sample_diffusers_model") + self.logger.info("sample_diffusers_model") from pytorch_lightning import seed_everything image = self.image mask = self.mask @@ -1258,7 +1249,7 @@ def process_prompts(self, data, seed): data["options"][f"negative_prompt"] = [negative_prompt for _ in range(self.batch_size)] return data prompt_data = self.prompt_data - logger.info(f"Process prompt") + self.logger.info(f"Process prompt") if self.deterministic_seed: prompt = data["options"][f"prompt"] if ".blend(" in prompt: @@ -1302,7 +1293,7 @@ def process_prompts(self, data, seed): def process_data(self, data: dict): import traceback - logger.info("Runner: process_data called") + self.logger.info("Runner: process_data called") self.requested_data = data self.prepare_options(data) #self.prepare_scheduler() @@ -1314,7 +1305,7 @@ def process_data(self, data: dict): def generate(self, data: dict): if not self.pipe: return - logger.info("generate called") + self.logger.info("generate called") self.do_cancel = False self.process_data(data) @@ -1405,6 
+1396,10 @@ def callback(self, step: int, _time_step, latents): } self.emit("progress_signal", res) + @pyqtSlot(object) + def on_unload_stablediffusion_signal(self): + self.unload() + def unload(self): self.unload_model() self.unload_tokenizer() @@ -1417,7 +1412,7 @@ def unload_tokenizer(self): self.tokenizer = None def process_upscale(self, data: dict): - logger.info("Processing upscale") + self.logger.info("Processing upscale") image = self.input_image results = [] if image: @@ -1447,7 +1442,7 @@ def generator_sample(self, data: dict): return self.image_handler(images, self.requested_data, None) if not self.pipe: - logger.info("pipe is None") + self.logger.info("pipe is None") return self.emit("status_signal", f"Generating {'video' if self.is_vid_action else 'image'}") @@ -1457,7 +1452,7 @@ def generator_sample(self, data: dict): try: self.initialized = self.__dict__[action] is not None except KeyError: - logger.info(f"{action} model has not been initialized yet") + self.logger.info(f"{action} model has not been initialized yet") self.initialized = False error = None @@ -1488,7 +1483,8 @@ def generator_sample(self, data: dict): self._current_model = "" self.local_files_only = True - def cancel(self): + @pyqtSlot(object) + def on_sd_cancel_signal(self): self.do_cancel = True def log_error(self, error, message=None): @@ -1497,7 +1493,7 @@ def log_error(self, error, message=None): self.error_handler(message) def load_controlnet_from_ckpt(self, pipeline): - logger.info("Loading controlnet from ckpt") + self.logger.info("Loading controlnet from ckpt") pipeline = self.controlnet_action_diffuser( vae=pipeline.vae, text_encoder=pipeline.text_encoder, @@ -1512,7 +1508,7 @@ def load_controlnet_from_ckpt(self, pipeline): return pipeline def load_controlnet(self): - logger.info(f"Loading controlnet {self.controlnet_type} self.controlnet_model {self.controlnet_model}") + self.logger.info(f"Loading controlnet {self.controlnet_type} self.controlnet_model {self.controlnet_model}") self._controlnet = None self.current_controlnet_type = self.controlnet_type controlnet = StableDiffusionControlNetPipeline.from_pretrained( @@ -1524,17 +1520,17 @@ def load_controlnet(self): def preprocess_for_controlnet(self, image): if self.current_controlnet_type != self.controlnet_type or not self.processor: - logger.info("Loading controlnet processor " + self.controlnet_type) + self.logger.info("Loading controlnet processor " + self.controlnet_type) self.current_controlnet_type = self.controlnet_type - logger.info("Controlnet: Processing image") + self.logger.info("Controlnet: Processing image") self.processor = Processor(self.controlnet_type) if self.processor: - logger.info("Controlnet: Processing image") + self.logger.info("Controlnet: Processing image") image = self.processor(image) # resize image to width and height image = image.resize((self.width, self.height)) return image - logger.error("No controlnet processor found") + self.logger.error("No controlnet processor found") def load_controlnet_arguments(self, **kwargs): if not self.is_vid_action: @@ -1552,7 +1548,7 @@ def load_controlnet_arguments(self, **kwargs): return kwargs def unload_unused_models(self): - logger.info("Unloading unused models") + self.logger.info("Unloading unused models") for action in [ "txt2img", "img2img", @@ -1575,7 +1571,7 @@ def clear_memory(self): self.emit("clear_memory_signal") def load_model(self): - logger.info("Loading model") + self.logger.info("Loading model") self.torch_compile_applied = False self.lora_loaded = False 
self.embeds_loaded = False @@ -1600,7 +1596,7 @@ def load_model(self): if self.pipe is None or self.reload_model: kwargs["from_safetensors"] = self.is_safetensors - logger.info(f"Loading model from scratch {self.reload_model}") + self.logger.info(f"Loading model from scratch {self.reload_model}") self.reset_applied_memory_settings() self.send_model_loading_message(self.model_path) @@ -1628,7 +1624,7 @@ def load_model(self): except OSError as e: self.handle_missing_files(self.action) else: - logger.info(f"Loading model {self.model_path} from PRETRAINED") + self.logger.info(f"Loading model {self.model_path} from PRETRAINED") scheduler = self.load_scheduler() if scheduler: kwargs["scheduler"] = scheduler @@ -1647,12 +1643,12 @@ def load_model(self): Initialize pipe for video to video zero """ if self.pipe and self.is_vid2vid: - logger.info("Initializing pipe for vid2vid") + self.logger.info("Initializing pipe for vid2vid") self.pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2)) self.pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2)) if self.is_outpaint: - logger.info("Initializing vae for inpaint / outpaint") + self.logger.info("Initializing vae for inpaint / outpaint") self.pipe.vae = AsymmetricAutoencoderKL.from_pretrained( self.inpaint_vae_model, torch_dtype=self.data_type @@ -1677,7 +1673,7 @@ def load_model(self): #self.load_learned_embed_in_clip() def load_ckpt_model(self): - logger.info(f"Loading ckpt file {self.model_path}") + self.logger.info(f"Loading ckpt file {self.model_path}") pipeline = self.download_from_original_stable_diffusion_ckpt(path=self.model_path) return pipeline @@ -1721,11 +1717,11 @@ def download_from_original_stable_diffusion_ckpt(self, path, local_files_only=No local_files_only=False ) except Exception as e: - logger.error(f"Failed to load model from ckpt: {e}") + self.logger.error(f"Failed to load model from ckpt: {e}") return pipe def clear_controlnet(self): - logger.info("Clearing controlnet") + self.logger.info("Clearing controlnet") self._controlnet = None self.clear_memory() self.reset_applied_memory_settings() @@ -1738,14 +1734,14 @@ def load_vae(self): ) def reuse_pipeline(self, do_load_controlnet): - logger.info("Reusing pipeline") + self.logger.info("Reusing pipeline") pipe = None if self.is_txt2img: pipe = self.img2img if self.txt2img is None else self.txt2img elif self.is_img2img: pipe = self.txt2img if self.img2img is None else self.img2img if pipe is None: - logger.warning("Failed to reuse pipeline") + self.logger.warning("Failed to reuse pipeline") self.clear_controlnet() return kwargs = pipe.components @@ -1804,7 +1800,7 @@ def send_model_loading_message(self, model_name): self.emit("status_signal", message) def prepare_model(self): - logger.info("Prepare model") + self.logger.info("Prepare model") # get model and switch to it # get models from database @@ -1824,18 +1820,18 @@ def prepare_model(self): def unload_controlnet(self): if self.pipe: - logger.info("Unloading controlnet") + self.logger.info("Unloading controlnet") self.pipe.controlnet = None self.controlnet_loaded = False def handle_missing_files(self, action): if not self.attempt_download: if self.is_ckpt_model or self.is_safetensors: - logger.info("Required files not found, attempting download") + self.logger.info("Required files not found, attempting download") else: import traceback traceback.print_exc() - logger.info("Model not found, attempting download") + self.logger.info("Model not found, attempting download") # check if we have an internet 
connection if self.allow_online_when_missing_files: self.emit("status_signal", "Downloading model files") diff --git a/src/airunner/aihandler/transformer_runner.py b/src/airunner/aihandler/transformer_runner.py index 567d4efce..5db67a6f6 100644 --- a/src/airunner/aihandler/transformer_runner.py +++ b/src/airunner/aihandler/transformer_runner.py @@ -99,12 +99,10 @@ def move_to_device(self, device=None): def unload_tokenizer(self): self.logger.info("Unloading tokenizer") self.tokenizer = None - self.engine.clear_memory() def unload_model(self): self.model = None self.processor = None - self.engine.clear_memory() def quantization_config(self): config = None diff --git a/src/airunner/signal_mediator.py b/src/airunner/signal_mediator.py index e152bec5c..76f0d69da 100644 --- a/src/airunner/signal_mediator.py +++ b/src/airunner/signal_mediator.py @@ -29,7 +29,10 @@ def register(self, signal_name, slot_parent): # Create a new Signal instance for this signal name self.signals[signal_name] = Signal() # Connect the Signal's pyqtSignal to the receive method of the slot parent - self.signals[signal_name].signal.connect(getattr(slot_parent, f"on_{signal_name}")) + try: + self.signals[signal_name].signal.connect(getattr(slot_parent, f"on_{signal_name}")) + except Exception as e: + print(f"Error connecting signal {signal_name}", e) def emit(self, signal_name, data=None): if signal_name in self.signals: diff --git a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py index 291eeb9bf..72d85fcd0 100644 --- a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py +++ b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py @@ -10,7 +10,7 @@ from PyQt6.QtWidgets import QGraphicsPixmapItem from PyQt6 import QtWidgets, QtCore from PyQt6.QtCore import pyqtSlot -from PyQt6.QtWidgets import QGraphicsItemGroup +from PyQt6.QtWidgets import QGraphicsItemGroup, QGraphicsItem from airunner.workers.image_data_worker import ImageDataWorker from airunner.aihandler.logger import Logger @@ -172,23 +172,6 @@ def current_active_image(self): def current_active_image(self, value): self.add_image_to_current_layer(value) - def add_image_to_current_layer(self,value): - self.logger.info("Adding image to current layer") - layer_index = self.settings["current_layer_index"] - base_64_image = "" - - try: - if value: - buffered = io.BytesIO() - value.save(buffered, format="PNG") - base_64_image = base64.b64encode(buffered.getvalue()) - except Exception as e: - self.logger.error(e) - - settings = self.settings - settings["layers"][layer_index]["base_64_image"] = base_64_image - self.settings = settings - @property def layer_container_widget(self): # TODO @@ -204,6 +187,7 @@ def __init__(self, *args, **kwargs): self.register("main_window_loaded_signal", self) self._zoom_level = 1 self.canvas_container.resizeEvent = self.window_resized + self.pixmaps = {} self.image_data_worker = self.create_worker(ImageDataWorker) self.canvas_resize_worker = self.create_worker(CanvasResizeWorker) @@ -494,6 +478,28 @@ def set_canvas_color(self): return self.scene.setBackgroundBrush(QBrush(QColor(self.canvas_color))) + def add_image_to_current_layer(self,value): + self.logger.info("Adding image to current layer") + layer_index = self.settings["current_layer_index"] + base_64_image = "" + + try: + if value: + buffered = io.BytesIO() + value.save(buffered, format="PNG") + base_64_image = base64.b64encode(buffered.getvalue()) + except Exception as e: + self.logger.error(e) + + settings = self.settings 
+ # If there's an existing image in the layer, remove it from the scene + if layer_index in self.pixmaps and isinstance(self.pixmaps[layer_index], QGraphicsItem): + if self.pixmaps[layer_index].scene() == self.scene: + self.scene.removeItem(self.pixmaps[layer_index]) + del self.pixmaps[layer_index] + settings["layers"][layer_index]["base_64_image"] = base_64_image + self.settings = settings + def draw_layers(self): layers = self.settings["layers"] for index, layer in enumerate(layers): @@ -507,18 +513,23 @@ def draw_layers(self): ) if not layer["visible"]: - if layer["pixmap"] in self.scene.items(): - self.scene.removeItem(layer["pixmap"]) + if index in self.pixmaps and isinstance(self.pixmaps[index], QGraphicsItem) and self.pixmaps[index].scene() == self.scene: + self.scene.removeItem(self.pixmaps[index]) elif layer["visible"]: - if type(layer["pixmap"]) is not DraggablePixmap or layer["pixmap"] not in self.scene.items(): - print("adding to scene") - layer["pixmap"].convertFromImage(ImageQt(image)) - layer["pixmap"] = DraggablePixmap(self, layer["pixmap"]) - self.emit("update_layer_signal", dict( - layer=layer, - index=index - )) - self.scene.addItem(layer["pixmap"]) + # If there's an existing pixmap in the layer, remove it from the scene + if index in self.pixmaps and isinstance(self.pixmaps[index], QGraphicsItem): + if self.pixmaps[index].scene() == self.scene: + self.scene.removeItem(self.pixmaps[index]) + del self.pixmaps[index] + pixmap = QPixmap() + pixmap.convertFromImage(ImageQt(image)) + self.pixmaps[index] = DraggablePixmap(self, pixmap) + self.emit("update_layer_signal", dict( + layer=layer, + index=index + )) + if self.pixmaps[index].scene() != self.scene: + self.scene.addItem(self.pixmaps[index]) continue def set_scene_rect(self): diff --git a/src/airunner/widgets/canvas_plus/draggables.py b/src/airunner/widgets/canvas_plus/draggables.py index 457aea4ad..f482869be 100644 --- a/src/airunner/widgets/canvas_plus/draggables.py +++ b/src/airunner/widgets/canvas_plus/draggables.py @@ -3,17 +3,21 @@ from PyQt6.QtCore import QRect from PyQt6.QtGui import QBrush, QColor, QPen, QPixmap, QPainter from PyQt6.QtWidgets import QGraphicsItem, QGraphicsPixmapItem +from airunner.windows.main.settings_mixin import SettingsMixin +from airunner.mediator_mixin import MediatorMixin -class DraggablePixmap(QGraphicsPixmapItem): +class DraggablePixmap(QGraphicsPixmapItem, MediatorMixin, SettingsMixin): def __init__(self, parent, pixmap): self.parent = parent super().__init__(pixmap) + MediatorMixin.__init__(self) + SettingsMixin.__init__(self) self.pixmap = pixmap self.setFlag(QGraphicsItem.GraphicsItemFlag.ItemIsMovable, True) def snap_to_grid(self): - cell_size = self.parent.app.settings["grid_settings"]["cell_size"] + cell_size = self.grid_settings["cell_size"] x = round(self.x() / cell_size) * cell_size y = round(self.y() / cell_size) * cell_size x += self.parent.last_pos.x() diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py index 9f4cdf874..c92617b9c 100644 --- a/src/airunner/widgets/generator_form/generator_form_widget.py +++ b/src/airunner/widgets/generator_form/generator_form_widget.py @@ -20,7 +20,6 @@ class GeneratorForm(BaseWidget): deterministic_seed = None initialized = False parent = None - generate_signal = pyqtSignal(dict) current_prompt_value = None current_negative_prompt_value = None @@ -109,8 +108,8 @@ def on_application_settings_changed_signal(self): self.activate_ai_mode() @pyqtSlot(object) - 
def on_progress_signal(self, response): - self.handle_progress_bar(response["message"]) + def on_progress_signal(self, message): + self.handle_progress_bar(message) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -306,7 +305,7 @@ def do_generate(self, extra_options=None, seed=None, do_deterministic=False, ove n_iter = int(override_data.get("n_iter", 1)) n_samples = int(override_data.get("n_samples", self.generator_settings["n_samples"])) # iterate over all keys in model_data - model_data = {} + model_data=self.generator_settings for k,v in override_data.items(): if k.startswith("model_data_"): model_data[k.replace("model_data_", "")] = v @@ -324,6 +323,7 @@ def do_generate(self, extra_options=None, seed=None, do_deterministic=False, ove print(model_data, self.generator_settings["model"]) name = model_data["name"] if "name" in model_data else self.generator_settings["model"] model = self.get_service("ai_model_by_name")(name) + print("MODEL:", model, name) # set the model data, first using model_data pulled from the override_data model_data = dict( name=model_data.get("name", model["name"]), @@ -427,8 +427,6 @@ def do_generate(self, extra_options=None, seed=None, do_deterministic=False, ove Emitting generate_signal with options allows us to pass more options to the dict from modal windows such as the image interpolation window. """ - self.emit("generate_image_signal", options) - memory_options = self.get_memory_options() self.emit("image_generate_request_signal", dict( diff --git a/src/airunner/widgets/generator_form/generator_tab_widget.py b/src/airunner/widgets/generator_form/generator_tab_widget.py index 5d212cc27..1684ba1ce 100644 --- a/src/airunner/widgets/generator_form/generator_tab_widget.py +++ b/src/airunner/widgets/generator_form/generator_tab_widget.py @@ -7,7 +7,6 @@ class GeneratorTabWidget(BaseWidget): widget_class_ = Ui_generator_tab - generate_signal = pyqtSignal(dict) data = {} clip_skip_disabled_tabs = [] clip_skip_disabled_sections = ["upscale", "superresolution", "txt2vid"] diff --git a/src/airunner/windows/main/ai_model_mixin.py b/src/airunner/windows/main/ai_model_mixin.py index abb202ef5..25a7755d7 100644 --- a/src/airunner/windows/main/ai_model_mixin.py +++ b/src/airunner/windows/main/ai_model_mixin.py @@ -5,24 +5,30 @@ class AIModelMixin: def __init__(self): - ServiceLocator.register("ai_model_paths", self.ai_model_paths) - ServiceLocator.register("ai_models_find", self.ai_models_find) - ServiceLocator.register("ai_model_categories", self.ai_model_categories) - ServiceLocator.register("ai_model_pipeline_actions", self.ai_model_pipeline_actions) - ServiceLocator.register("ai_model_versions", self.ai_model_versions) - ServiceLocator.register("ai_model_get_disabled_default", self.ai_model_get_disabled_default) - ServiceLocator.register("ai_model_get_all", self.ai_model_get_all) - ServiceLocator.register("ai_model_update", self.ai_model_update) - ServiceLocator.register("ai_model_get_by_filter", self.ai_model_get_by_filter) - ServiceLocator.register("ai_model_names_by_section", self.ai_model_names_by_section) - ServiceLocator.register("ai_models_by_category", self.ai_models_by_category) + services = [ + "ai_model_paths", + "ai_models_find", + "ai_model_categories", + "ai_model_pipeline_actions", + "ai_model_versions", + "ai_model_get_disabled_default", + "ai_model_get_all", + "ai_model_update", + "ai_model_get_by_filter", + "ai_model_names_by_section", + "ai_models_by_category", + "ai_model_by_name" + ] + + for service in services: + 
ServiceLocator.register(service, getattr(self, service)) self.register("ai_model_save_or_update_signal", self) self.register("ai_model_delete_signal", self) self.register("ai_model_create_signal", self) def ai_model_get_by_filter(self, filter_dict): - return [item for item in self.settings["ai_models"] if all(item.get(k) == v for k, v in filter_dict.items())] + return [item for item in self.ai_models if all(item.get(k) == v for k, v in filter_dict.items())] @pyqtSlot(object) def on_ai_model_create_signal(self, item): @@ -32,7 +38,7 @@ def on_ai_model_create_signal(self, item): def ai_model_update(self, item): settings = self.settings - for i, existing_item in enumerate(self.settings["ai_models"]): + for i, existing_item in enumerate(self.ai_models): if existing_item['name'] == item['name']: settings["ai_models"][i] = item self.settings = settings @@ -41,20 +47,20 @@ def ai_model_update(self, item): @pyqtSlot(object) def on_ai_model_delete_signal(self, item): settings = self.settings - settings["ai_models"] = [existing_item for existing_item in self.settings["ai_models"] if existing_item['name'] != item['name']] + settings["ai_models"] = [existing_item for existing_item in self.ai_models if existing_item['name'] != item['name']] self.settings = settings def ai_model_names_by_section(self, section): - return [model["name"] for model in self.settings["ai_models"] if model["section"] == section] + return [model["name"] for model in self.ai_models if model["section"] == section] def models_by_pipeline_action(self, pipeline_action): - return [model for model in self.settings["ai_models"] if model["pipeline_action"] == pipeline_action] + return [model for model in self.ai_models if model["pipeline_action"] == pipeline_action] def ai_models_find(self, search="", default=False): - return [model for model in self.settings["ai_models"] if model["is_default"] == default and search.lower() in model["name"].lower()] + return [model for model in self.ai_models if model["is_default"] == default and search.lower() in model["name"].lower()] def ai_model_get_disabled_default(self): - return [model for model in self.settings["ai_models"] if model["is_default"] == True and model["enabled"] == False] + return [model for model in self.ai_models if model["is_default"] == True and model["enabled"] == False] @pyqtSlot(object) def on_ai_model_save_or_update_signal(self, model_data): @@ -66,7 +72,7 @@ def on_ai_model_save_or_update_signal(self, model_data): self.emit("ai_model_create_signal", model_data) def ai_model_paths(self, model_type=None, pipeline_action=None): - models = self.settings["ai_models"] + models = self.ai_models if model_type: models = [model for model in models if "model_type" in model and model["model_type"] == model_type] if pipeline_action: @@ -75,22 +81,22 @@ def ai_model_paths(self, model_type=None, pipeline_action=None): return [model["path"] for model in models] def ai_model_categories(self): - return [model["category"] for model in self.settings["ai_models"]] + return [model["category"] for model in self.ai_models] def ai_model_pipeline_actions(self): - return [model["pipeline_action"] for model in self.settings["ai_models"]] + return [model["pipeline_action"] for model in self.ai_models] def ai_model_versions(self): - return [model["version"] for model in self.settings["ai_models"]] + return [model["version"] for model in self.ai_models] def ai_models_by_category(self, category): - return [model for model in self.settings["ai_models"] if model["category"] == category] + return [model for 
model in self.ai_models if model["category"] == category] def ai_model_by_name(self, name): try: - return [model for model in self.settings["ai_models"] if model["name"] == name][0] + return [model for model in self.ai_models if model["name"] == name][0] except Exception as e: self.logger.error(f"Error finding model by name: {name}") def ai_model_get_all(self): - return self.settings["ai_models"] \ No newline at end of file + return self.ai_models \ No newline at end of file diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py index c5dadf0e5..405f6c210 100644 --- a/src/airunner/windows/main/main_window.py +++ b/src/airunner/windows/main/main_window.py @@ -286,6 +286,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) MediatorMixin.__init__(self) SettingsMixin.__init__(self) + + self.update_settings() + LoraMixin.__init__(self) LayerMixin.__init__(self) EmbeddingMixin.__init__(self) diff --git a/src/airunner/windows/main/settings_mixin.py b/src/airunner/windows/main/settings_mixin.py index 86a177b87..26e8a6905 100644 --- a/src/airunner/windows/main/settings_mixin.py +++ b/src/airunner/windows/main/settings_mixin.py @@ -18,9 +18,7 @@ def __init__(self): ServiceLocator.register("get_settings", self.get_settings) ServiceLocator.register("set_settings", self.set_settings) self.register("reset_settings_signal", self) - - def get_settings(self): - return self.application_settings.value("settings", dict( + self.default_settings = dict( current_layer_index=0, ocr_enabled=False, tts_enabled=False, @@ -374,9 +372,18 @@ def get_settings(self): controlnet=controlnet_bootstrap_data, ai_models=model_bootstrap_data, image_filters=imagefilter_bootstrap_data, - ), - type=dict - ) + ) + + def update_settings(self): + default_settings = self.default_settings + current_settings = self.settings + for k,v in default_settings.items(): + if k not in current_settings: + current_settings[k] = v + self.settings = current_settings + + def get_settings(self): + return self.application_settings.value("settings", self.default_settings, type=dict) def set_settings(self, val): self.application_settings.setValue("settings", val) @@ -388,6 +395,16 @@ def on_reset_settings_signal(self): self.application_settings.sync() self.set_settings(self.get_settings()) + @property + def ai_models(self): + return self.get_settings()["ai_models"] + + @ai_models.setter + def ai_models(self, val): + settings = self.get_settings() + settings["ai_models"] = val + self.set_settings(settings) + @property def generator_settings(self): return self.get_settings()["generator_settings"] diff --git a/src/airunner/workers/worker.py b/src/airunner/workers/worker.py index 04dc5df4d..27b3aff2b 100644 --- a/src/airunner/workers/worker.py +++ b/src/airunner/workers/worker.py @@ -4,20 +4,18 @@ from airunner.aihandler.logger import Logger from airunner.mediator_mixin import MediatorMixin +from airunner.windows.main.settings_mixin import SettingsMixin -class Worker(QObject, MediatorMixin): +class Worker(QObject, MediatorMixin, SettingsMixin): queue_type = "get_next_item" finished = pyqtSignal() - @property - def settings(self): - return self.application_settings.value("settings") - def __init__(self, prefix="Worker"): self.prefix = prefix super().__init__() MediatorMixin.__init__(self) + SettingsMixin.__init__(self) self.logger = Logger(prefix=prefix) self.running = False self.queue = queue.Queue()
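For reference, the moving parts this patch leans on are the "on_<signal_name>" handler convention enforced by SignalMediator.register (which is why the new try/except around signal.connect matters) and the callable-based ServiceLocator used so the engine can query "is_pipe_on_cpu" / "has_pipe" without holding an SDController. The sketch below is illustrative only, not part of the patch: the class names SignalMediator, ServiceLocator, and the service/signal names match the patch, but the simplified plain-Python bodies, the FakeRunner stand-in, and the emit loop are assumptions (the real SignalMediator connects handlers through pyqtSignal).

class SignalMediator:
    """Minimal sketch of the mediator: handlers are found by naming convention."""

    def __init__(self):
        self.signals = {}

    def register(self, signal_name, slot_parent):
        # Look up the handler by convention: on_<signal_name>.
        handler = getattr(slot_parent, f"on_{signal_name}", None)
        if handler is None:
            print(f"Error connecting signal {signal_name}: no handler on {slot_parent}")
            return
        self.signals.setdefault(signal_name, []).append(handler)

    def emit(self, signal_name, data=None):
        # The real implementation emits a pyqtSignal; here we call handlers directly.
        for handler in self.signals.get(signal_name, []):
            handler(data)


class ServiceLocator:
    """Minimal sketch of the service registry: names map to callables."""

    _services = {}

    @classmethod
    def register(cls, name, fn):
        cls._services[name] = fn

    @classmethod
    def get(cls, name):
        return cls._services[name]


class FakeRunner:
    """Illustrative stand-in for SDRunner: registers services and a cancel handler."""

    is_pipe_on_cpu = False
    has_pipe = True

    def __init__(self, mediator):
        mediator.register("sd_cancel_signal", self)
        for service in ("is_pipe_on_cpu", "has_pipe"):
            # Binding `service` as a default argument pins each lambda to its own
            # attribute name; a bare `lambda: getattr(self, service)` would see only
            # the last loop value when called later.
            ServiceLocator.register(service, lambda s=service: getattr(self, s))

    def on_sd_cancel_signal(self, _data):
        print("cancel requested")


mediator = SignalMediator()
runner = FakeRunner(mediator)
mediator.emit("sd_cancel_signal")        # -> "cancel requested"
print(ServiceLocator.get("has_pipe")())  # -> True

Usage mirrors the patch: Engine.move_sd_to_cpu calls ServiceLocator.get("is_pipe_on_cpu")() instead of touching a controller, and emitting "sd_cancel_signal" reaches SDRunner.on_sd_cancel_signal because the handler name follows the on_<signal_name> convention.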