diff --git a/src/airunner/aihandler/stablediffusion/sd_handler.py b/src/airunner/aihandler/stablediffusion/sd_handler.py
index 65d2ae777..062612f37 100644
--- a/src/airunner/aihandler/stablediffusion/sd_handler.py
+++ b/src/airunner/aihandler/stablediffusion/sd_handler.py
@@ -305,7 +305,10 @@ def unload_stable_diffusion(self):
         clear_memory()
         self.change_model_status(ModelType.SD, ModelStatus.UNLOADED)

-    def handle_generate_signal(self, message: dict):
+    def handle_generate_signal(self, message: dict=None):
+        if self._model_status is not ModelStatus.LOADED:
+            self.load_stable_diffusion()
+
         if self._current_state not in (
             HandlerState.GENERATING,
             HandlerState.PREPARING_TO_GENERATE
@@ -319,6 +322,11 @@ def handle_generate_signal(self, message: dict):
             except Exception as e:
                 self.logger.error(f"Error generating image: {e}")
                 response = None
+            if message is not None:
+                callback = message.get("callback", None)
+                if callback:
+                    callback(message)
             self.emit_signal(SignalCode.ENGINE_RESPONSE_WORKER_RESPONSE_SIGNAL, {
                 'code': EngineResponseCode.IMAGE_GENERATED,
                 'message': response
diff --git a/src/airunner/aihandler/tts/speecht5_tts_handler.py b/src/airunner/aihandler/tts/speecht5_tts_handler.py
index a5c319770..0cc48dc60 100644
--- a/src/airunner/aihandler/tts/speecht5_tts_handler.py
+++ b/src/airunner/aihandler/tts/speecht5_tts_handler.py
@@ -90,9 +90,11 @@ def speaker_embeddings_path(self):
         )

     def generate(self, message):
-        if self._do_interrupt or self._paused:
+        if self._model_status is not ModelStatus.LOADED:
             return None

+        if self._do_interrupt or self._paused:
+            return None
         try:
             return self._do_generate(message)
         except torch.cuda.OutOfMemoryError:
diff --git a/src/airunner/enums.py b/src/airunner/enums.py
index cf99dff71..bebde2ff5 100644
--- a/src/airunner/enums.py
+++ b/src/airunner/enums.py
@@ -183,6 +183,8 @@ class SignalCode(Enum):
     QUIT_APPLICATION = "quit"
     TOGGLE_FULLSCREEN_SIGNAL = "fullscreen_signal"
     TOGGLE_TTS_SIGNAL = "toggle_tts_signal"
+    TOGGLE_SD_SIGNAL = "toggle_sd_signal"
+    TOGGLE_LLM_SIGNAL = "toggle_llm_signal"
     START_AUTO_IMAGE_GENERATION_SIGNAL = "start_auto_image_generation_signal"
     STOP_AUTO_IMAGE_GENERATION_SIGNAL = "stop_auto_image_generation_signal"
     LINES_UPDATED_SIGNAL = "lines_updated_signal"
diff --git a/src/airunner/widgets/canvas/custom_scene.py b/src/airunner/widgets/canvas/custom_scene.py
index df9f74b81..d8bb6c21b 100644
--- a/src/airunner/widgets/canvas/custom_scene.py
+++ b/src/airunner/widgets/canvas/custom_scene.py
@@ -412,6 +412,9 @@ def on_image_generated_signal(self, response):
         else:
             self.logger.error(f"Unhandled response code: {code}")
         self.emit_signal(SignalCode.APPLICATION_STOP_SD_PROGRESS_BAR_SIGNAL)
+        callback = response.get("callback", None)
+        if callback:
+            callback(response)

     def on_canvas_clear_signal(self):
         self.update_current_settings("image", None)
diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py
index 34074d9b0..23ddf0bab 100644
--- a/src/airunner/widgets/generator_form/generator_form_widget.py
+++ b/src/airunner/widgets/generator_form/generator_form_widget.py
@@ -169,10 +169,36 @@ def on_stop_image_generator_progress_bar_signal(self, _data):
     def on_progress_signal(self, message):
         self.handle_progress_bar(message)

+    ##########################################################################
+    # LLM Generated Image handlers
+    ##########################################################################
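+    # These handlers drive the "LLM generates an image" flow: the LLM
+    # produces the prompts, the LLM is unloaded, Stable Diffusion is loaded
+    # if necessary, the image is generated, and the models and conversation
+    # are then restored through the callbacks passed along with each signal.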
     def on_llm_image_prompt_generated_signal(self, data):
+        """
+        This slot is called after an LLM has generated the prompts for an image.
+        It sets the prompts in the generator form UI and continues the image generation process.
+        """
+
+        # Send a message to the user as the chatbot, letting them know that the image is generating
+        self.emit_signal(
+            SignalCode.APPLICATION_ADD_BOT_MESSAGE_TO_CONVERSATION,
+            dict(
+                message="Your image is generating...",
+                is_first_message=True,
+                is_end_of_message=True,
+                name=self.chatbot.name
+            )
+        )
+
+        # Unload the LLM
+        if self.application_settings.llm_enabled:
+            self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict(
+                callback=self.unload_llm_callback
+            ))
+
+        # Set the prompts in the generator form UI
         data = self.extract_json_from_message(data["message"])
-        prompt = data.get("description", None)
-        secondary_prompt = data.get("composition", None)
+        prompt = data.get("composition", None)
+        secondary_prompt = data.get("description", None)
         prompt_type = data.get("type", ImageCategory.PHOTO.value)
         if prompt_type == "photo":
             negative_prompt = PHOTO_REALISTIC_NEGATIVE_PROMPT
@@ -182,7 +208,78 @@ def on_llm_image_prompt_generated_signal(self, data):
         self.ui.negative_prompt.setPlainText(negative_prompt)
         self.ui.secondary_prompt.setPlainText(secondary_prompt)
         self.ui.secondary_negative_prompt.setPlainText(negative_prompt)
-        self.handle_generate_button_clicked()
+
+    def unload_llm_callback(self, data:dict=None):
+        """
+        Callback function to be called after the LLM has been unloaded.
+        """
+        if not self.application_settings.sd_enabled:
+            # If SD is not enabled, enable it and then emit a signal to generate the image.
+            # The callback function is handled by the signal handler for the SD_LOAD_SIGNAL.
+            # The finalize function is a callback which is called after the image has been generated.
+            self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict(
+                callback=self.handle_generate_button_clicked,
+                finalize=self.finalize_image_generated_by_llm
+            ))
+        else:
+            # If SD is already enabled, emit a signal to generate the image.
+            # The finalize function is a callback which is called after the image has been generated.
+            self.handle_generate_button_clicked(dict(
+                finalize=self.finalize_image_generated_by_llm
+            ))
+
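+    # finalize_image_generated_by_llm runs once the image has been generated:
+    # it toggles Stable Diffusion back off if it is still enabled, re-enables
+    # the LLM if it was toggled off, and posts a confirmation message to the
+    # conversation.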
+    def finalize_image_generated_by_llm(self, data):
+        """
+        Callback function to be called after the image has been generated.
+        """
+
+        # Create a message to be sent to the user as a chatbot message
+        image_generated_message = dict(
+            message="Your image has been generated",
+            is_first_message=True,
+            is_end_of_message=True,
+            name=self.chatbot.name
+        )
+
+        # If SD is enabled, emit a signal to unload SD.
+        if self.application_settings.sd_enabled:
+            # If LLM is disabled, emit a signal to load it.
+            if not self.application_settings.llm_enabled:
+                self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict(
+                    callback=lambda d: self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict(
+                        callback=lambda d: self.emit_signal(
+                            SignalCode.APPLICATION_ADD_BOT_MESSAGE_TO_CONVERSATION,
+                            image_generated_message
+                        )
+                    ))
+                ))
+            else:
+                self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict(
+                    callback=lambda d: self.emit_signal(
+                        SignalCode.APPLICATION_ADD_BOT_MESSAGE_TO_CONVERSATION,
+                        image_generated_message
+                    )
+                ))
+        else:
+            # If SD is disabled and LLM is disabled, emit a signal to load LLM
+            # with a callback to add the image generated message to the conversation.
+            if not self.application_settings.llm_enabled:
+                self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict(
+                    callback=lambda d: self.emit_signal(
+                        SignalCode.APPLICATION_ADD_BOT_MESSAGE_TO_CONVERSATION,
+                        image_generated_message
+                    )
+                ))
+            else:
+                # If SD is disabled and LLM is enabled, emit a signal to add
+                # the image generated message to the conversation.
+                self.emit_signal(
+                    SignalCode.APPLICATION_ADD_BOT_MESSAGE_TO_CONVERSATION,
+                    image_generated_message
+                )
+    ##########################################################################
+    # End LLM Generated Image handlers
+    ##########################################################################

     def _set_chatbot_mood(self):
         self.ui.mood_label.setText(self.chatbot.bot_mood)
@@ -232,8 +329,17 @@ def handle_image_presets_changed(self, val):
     def do_generate_image_from_image_signal_handler(self, _data):
         self.do_generate()

-    def do_generate(self):
-        self.emit_signal(SignalCode.DO_GENERATE_SIGNAL)
+    def do_generate(self, data=None):
+        if data:
+            finalize = data.get("finalize", None)
+            if finalize:
+                data = dict(
+                    callback=finalize
+                )
+            else:
+                data = None
+        self.emit_signal(SignalCode.DO_GENERATE_SIGNAL, data)

     def activate_ai_mode(self):
         ai_mode = self.application_settings.ai_mode
@@ -259,18 +365,19 @@ def handle_second_prompt_changed(self):
     def handle_second_negative_prompt_changed(self):
         pass

-    def handle_generate_button_clicked(self):
+    def handle_generate_button_clicked(self, data=None):
         self.start_progress_bar()
-        self.generate()
+        self.generate(data)

     @Slot()
     def handle_interrupt_button_clicked(self):
         self.emit_signal(SignalCode.INTERRUPT_IMAGE_GENERATION_SIGNAL)

-    def generate(self):
+    def generate(self, data=None):
         if self.generator_settings.random_seed:
             self.seed = random_seed()
-        self.do_generate()
+        self.do_generate(data)
         self.seed_override = None

     def do_generate_image(self):
diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py
index 40c0687f9..5b7e597b5 100644
--- a/src/airunner/windows/main/main_window.py
+++ b/src/airunner/windows/main/main_window.py
@@ -320,6 +321,8 @@ def register_signals(self):
         self.register(SignalCode.WRITE_FILE, self.on_write_file_signal)
         self.register(SignalCode.TOGGLE_FULLSCREEN_SIGNAL, self.on_toggle_fullscreen_signal)
         self.register(SignalCode.TOGGLE_TTS_SIGNAL, self.on_toggle_tts)
+        self.register(SignalCode.TOGGLE_SD_SIGNAL, self.on_toggle_sd)
+        self.register(SignalCode.TOGGLE_LLM_SIGNAL, self.on_toggle_llm)
         self.register(SignalCode.APPLICATION_RESET_SETTINGS_SIGNAL, self.action_reset_settings)
         self.register(SignalCode.APPLICATION_RESET_PATHS_SIGNAL, self.on_reset_paths_signal)
         self.register(SignalCode.REFRESH_STYLESHEET_SIGNAL, self.refresh_stylesheet)
@@ -554,9 +557,6 @@ def on_toggle_fullscreen_signal(self):
         else:
             self.showFullScreen()

-    def on_toggle_tts(self):
-        self.tts_button_toggled(not self.application_settings.tts_enabled)
-
     @Slot(bool)
     def action_outpaint_toggled(self, val: bool):
         self.update_outpaint_settings("enabled", val)
@@ -574,76 +574,33 @@ def action_run_setup_wizard_clicked(self):
         self.show_setup_wizard()

     @Slot(bool)
-    def action_toggle_llm(self, val):
-        if self._model_status[ModelType.LLM] is ModelStatus.LOADING:
-            val = not val
-        self.ui.actionToggle_LLM.blockSignals(True)
-        self.ui.actionToggle_LLM.setChecked(val)
-        self.ui.actionToggle_LLM.blockSignals(False)
-        QApplication.processEvents()
-        self.update_application_settings("llm_enabled", val)
-        if self._model_status[ModelType.LLM] is not ModelStatus.LOADING:
-            if val:
-                self.emit_signal(SignalCode.LLM_LOAD_SIGNAL)
-            else:
-                self.emit_signal(SignalCode.LLM_UNLOAD_SIGNAL)
+    def action_toggle_llm(self, val: bool):
+        self.on_toggle_llm(val=val)

     @Slot(bool)
     def action_image_generator_toggled(self, val: bool):
-        if self._model_status[ModelType.SD] is ModelStatus.LOADING:
-            val = not val
-        self.ui.actionToggle_Stable_Diffusion.blockSignals(True)
-        self.ui.actionToggle_Stable_Diffusion.setChecked(val)
-        self.ui.actionToggle_Stable_Diffusion.blockSignals(False)
-        QApplication.processEvents()
-        self.update_application_settings("sd_enabled", val)
-        if self._model_status[ModelType.SD] is not ModelStatus.LOADING:
-            if val:
-                self.emit_signal(SignalCode.SD_LOAD_SIGNAL)
-            else:
-                self.emit_signal(SignalCode.SD_UNLOAD_SIGNAL)
+        self.on_toggle_sd(val=val)

     @Slot(bool)
-    def action_controlnet_toggled(self, val: bool):
-        if self._model_status[ModelType.CONTROLNET] is ModelStatus.LOADING:
-            val = not val
-        self.ui.actionToggle_Controlnet.blockSignals(True)
-        self.ui.actionToggle_Controlnet.setChecked(val)
-        self.ui.actionToggle_Controlnet.blockSignals(False)
-        QApplication.processEvents()
-        self.update_controlnet_settings("enabled", val)
-        for widget in [self.ui.actionToggle_Controlnet, self.ui.enable_controlnet]:
-            widget.blockSignals(True)
-            widget.setChecked(val)
-            widget.blockSignals(False)
-        if self._model_status[ModelType.CONTROLNET] is not ModelStatus.LOADING:
-            if val:
-                self.emit_signal(SignalCode.CONTROLNET_LOAD_SIGNAL)
-            else:
-                self.emit_signal(SignalCode.CONTROLNET_UNLOAD_SIGNAL)
+    def tts_button_toggled(self, val: bool):
+        self.on_toggle_tts(val=val)

     @Slot(bool)
-    def tts_button_toggled(self, val):
-        if self._model_status[ModelType.TTS] is ModelStatus.LOADING:
-            val = not val
-        self.ui.actionToggle_Text_to_Speech.blockSignals(True)
-        self.ui.actionToggle_Text_to_Speech.setChecked(val)
-        self.ui.actionToggle_Text_to_Speech.blockSignals(False)
-        QApplication.processEvents()
-        self.update_application_settings("tts_enabled", val)
-        if self._model_status[ModelType.TTS] is not ModelStatus.LOADING:
-            if val:
-                self.emit_signal(SignalCode.TTS_ENABLE_SIGNAL)
-            else:
-                self.emit_signal(SignalCode.TTS_DISABLE_SIGNAL)
+    def action_controlnet_toggled(self, val: bool):
+        self._update_action_button(
+            ModelType.CONTROLNET,
+            self.ui.actionToggle_Controlnet,
+            val,
+            SignalCode.CONTROLNET_LOAD_SIGNAL,
+            SignalCode.CONTROLNET_UNLOAD_SIGNAL,
+            "enabled"
+        )

     @Slot(bool)
     def v2t_button_toggled(self, val):
         if self._model_status[ModelType.STT] is ModelStatus.LOADING:
             val = not val
         self.ui.actionToggle_Speech_to_Text.blockSignals(True)
         self.ui.actionToggle_Speech_to_Text.setChecked(val)
         self.ui.actionToggle_Speech_to_Text.blockSignals(False)
         QApplication.processEvents()
         self.update_application_settings("stt_enabled", val)
         if not val:
@@ -651,6 +608,73 @@ def v2t_button_toggled(self, val):
         else:
             self.emit_signal(SignalCode.STT_LOAD_SIGNAL)

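+    # The on_toggle_* handlers below back both the menu actions and the new
+    # TOGGLE_*_SIGNALs. When "val" is omitted the current setting is
+    # inverted, and any "callback" in "data" is forwarded along with the
+    # load/unload signal.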
+    def on_toggle_llm(self, data:dict=None, val=None):
+        if val is None:
+            val = not self.application_settings.llm_enabled
+        self._update_action_button(
+            ModelType.LLM,
+            self.ui.actionToggle_LLM,
+            val,
+            SignalCode.LLM_LOAD_SIGNAL,
+            SignalCode.LLM_UNLOAD_SIGNAL,
+            "llm_enabled",
+            data
+        )
+
+    def on_toggle_sd(self, data:dict=None, val=None):
+        if val is None:
+            val = not self.application_settings.sd_enabled
+        self._update_action_button(
+            ModelType.SD,
+            self.ui.actionToggle_Stable_Diffusion,
+            val,
+            SignalCode.SD_LOAD_SIGNAL,
+            SignalCode.SD_UNLOAD_SIGNAL,
+            "sd_enabled",
+            data
+        )
+
+    def on_toggle_tts(self, data:dict=None, val=None):
+        if val is None:
+            val = not self.application_settings.tts_enabled
+        self._update_action_button(
+            ModelType.TTS,
+            self.ui.actionToggle_Text_to_Speech,
+            val,
+            SignalCode.TTS_ENABLE_SIGNAL,
+            SignalCode.TTS_DISABLE_SIGNAL,
+            "tts_enabled",
+            data
+        )
+
+    def _update_action_button(
+        self,
+        model_type,
+        element,
+        val:bool,
+        load_signal: SignalCode,
+        unload_signal: SignalCode,
+        application_setting:str=None,
+        data:dict=None
+    ):
+        if self._model_status[model_type] is ModelStatus.LOADING:
+            val = not val
+        element.blockSignals(True)
+        element.setChecked(val)
+        element.blockSignals(False)
+        QApplication.processEvents()
+        if application_setting:
+            if model_type is ModelType.CONTROLNET:
+                self.update_controlnet_settings(application_setting, val)
+            else:
+                self.update_application_settings(application_setting, val)
+        if self._model_status[model_type] is not ModelStatus.LOADING:
+            if val:
+                self.emit_signal(load_signal, data)
+            else:
+                self.emit_signal(unload_signal, data)
+
     @Slot()
     def action_stats_triggered(self):
         from airunner.widgets.stats.stats_widget import StatsWidget
diff --git a/src/airunner/workers/llm_generate_worker.py b/src/airunner/workers/llm_generate_worker.py
index c3059a58c..951dc57c5 100644
--- a/src/airunner/workers/llm_generate_worker.py
+++ b/src/airunner/workers/llm_generate_worker.py
@@ -22,12 +22,18 @@ def __init__(self, prefix=None, agent_options=None):
     def on_llm_request_worker_response_signal(self, message: dict):
         self.add_to_queue(message)

-    def on_llm_on_unload_signal(self):
+    def on_llm_on_unload_signal(self, data=None):
         self.logger.debug("Unloading LLM")
         self.llm.unload()
+        callback = data.get("callback", None) if data else None
+        if callback:
+            callback(data)

-    def on_llm_load_model_signal(self):
+    def on_llm_load_model_signal(self, data=None):
         self.llm.load()
+        callback = data.get("callback", None) if data else None
+        if callback:
+            callback(data)

     def on_llm_clear_history_signal(self):
         self.llm.clear_history()
diff --git a/src/airunner/workers/sd_worker.py b/src/airunner/workers/sd_worker.py
index ade6c6338..56fe557e2 100644
--- a/src/airunner/workers/sd_worker.py
+++ b/src/airunner/workers/sd_worker.py
@@ -113,21 +113,29 @@ def on_unload_controlnet_signal(self, _data=None):
         thread = threading.Thread(target=self._unload_controlnet)
         thread.start()

-    def on_load_stablediffusion_signal(self):
+    def on_load_stablediffusion_signal(self, data:dict=None):
         if self.sd:
-            thread = threading.Thread(target=self._load_sd)
+            thread = threading.Thread(target=self._load_sd, args=(data,))
             thread.start()

-    def on_unload_stablediffusion_signal(self, _data=None):
+    def on_unload_stablediffusion_signal(self, data=None):
         if self.sd:
-            thread = threading.Thread(target=self._unload_sd)
+            thread = threading.Thread(target=self._unload_sd, args=(data,))
             thread.start()

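+    # _load_sd and _unload_sd run on a worker thread; if the signal payload
+    # contains a "callback", it is invoked there once the model has finished
+    # loading or unloading.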
-    def _load_sd(self):
+    def _load_sd(self, data:dict=None):
         self.sd.load_stable_diffusion()
+        if data:
+            callback = data.get("callback", None)
+            if callback is not None:
+                callback(data)

-    def _unload_sd(self):
+    def _unload_sd(self, data:dict=None):
         self.sd.unload_stable_diffusion()
+        if data:
+            callback = data.get("callback", None)
+            if callback is not None:
+                callback(data)

     def _load_controlnet(self):
         self.sd.load_controlnet()
diff --git a/src/airunner/workers/tts_generator_worker.py b/src/airunner/workers/tts_generator_worker.py
index defaa7e02..0b31d4765 100644
--- a/src/airunner/workers/tts_generator_worker.py
+++ b/src/airunner/workers/tts_generator_worker.py
@@ -132,8 +132,10 @@ def generate(self, message):

         if type(message) == dict:
             message = message.get("message", "")
-
-        response = self.tts.generate(message)
+
+        response = None
+        if self.tts:
+            response = self.tts.generate(message)

         if self.do_interrupt:
             return