diff --git a/setup.py b/setup.py index 4c9044e1e..5ecc7d35f 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="airunner", - version="3.0.20", + version="3.0.21", author="Capsize LLC", description="A Stable Diffusion GUI", long_description=open("README.md", "r", encoding="utf-8").read(), diff --git a/src/airunner/enums.py b/src/airunner/enums.py index 22296306b..8d077f3f2 100644 --- a/src/airunner/enums.py +++ b/src/airunner/enums.py @@ -244,6 +244,9 @@ class SignalCode(Enum): HISTORY_UPDATED = enum.auto() CANVAS_IMAGE_UPDATED_SIGNAL = enum.auto() + UNLOAD_NON_SD_MODELS = enum.auto() + LOAD_NON_SD_MODELS = enum.auto() + class EngineResponseCode(Enum): STATUS = 100 ERROR = 200 diff --git a/src/airunner/handlers/llm/agent/base_agent.py b/src/airunner/handlers/llm/agent/base_agent.py index 4c0e1b1e9..8cb664a36 100644 --- a/src/airunner/handlers/llm/agent/base_agent.py +++ b/src/airunner/handlers/llm/agent/base_agent.py @@ -545,6 +545,7 @@ def run( is_first_message=True, is_end_of_message=True, name=self.botname, + action=action ) ) @@ -614,6 +615,7 @@ def run_with_thread( is_first_message=is_first_message, is_end_of_message=False, name=self.botname, + action=LLMActionType.CHAT ) ) is_first_message = False @@ -668,6 +670,7 @@ def run_with_thread( is_first_message=is_first_message, is_end_of_message=is_end_of_message, name=self.botname, + action=action ) ) else: @@ -678,6 +681,7 @@ def run_with_thread( is_first_message=is_first_message, is_end_of_message=is_end_of_message, name=self.botname, + action=action ) ) is_first_message = False @@ -702,6 +706,7 @@ def run_with_thread( is_first_message=is_first_message, is_end_of_message=is_end_of_message, name=self.botname, + action=action ) ) is_first_message = False diff --git a/src/airunner/handlers/stablediffusion/sd_handler.py b/src/airunner/handlers/stablediffusion/sd_handler.py index ce2cd5521..009dafcb9 100644 --- a/src/airunner/handlers/stablediffusion/sd_handler.py +++ b/src/airunner/handlers/stablediffusion/sd_handler.py @@ -998,7 +998,10 @@ def _load_embeddings(self): self.logger.error("Pipe is None, unable to load embeddings") return self.logger.debug("Loading embeddings") - self._pipe.unload_textual_inversion() + try: + self._pipe.unload_textual_inversion() + except RuntimeError as e: + self.logger.error(f"Failed to unload embeddings: {e}") session = self.db_handler.get_db_session() embeddings = session.query(Embedding).filter_by( version=self.generator_settings_cached.version diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py index 56e48aec4..0c78df71f 100644 --- a/src/airunner/widgets/generator_form/generator_form_widget.py +++ b/src/airunner/widgets/generator_form/generator_form_widget.py @@ -7,7 +7,7 @@ from airunner.data.models.settings_models import ShortcutKeys from airunner.enums import SignalCode, GeneratorSection, ImageCategory, ImagePreset, StableDiffusionVersion, \ - ModelStatus, ModelType + ModelStatus, ModelType, LLMActionType from airunner.mediator_mixin import MediatorMixin from airunner.settings import PHOTO_REALISTIC_NEGATIVE_PROMPT, ILLUSTRATION_NEGATIVE_PROMPT from airunner.utils.random_seed import random_seed @@ -185,15 +185,15 @@ def on_llm_image_prompt_generated_signal(self, data): message="Your image is generating...", is_first_message=True, is_end_of_message=True, - name=self.chatbot.name + name=self.chatbot.name, + action=LLMActionType.GENERATE_IMAGE ) ) - # Unload the LLM - if self.application_settings.llm_enabled: 
- self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict( - callback=self.unload_llm_callback - )) + # Unload non-Stable Diffusion models + self.emit_signal(SignalCode.UNLOAD_NON_SD_MODELS, dict( + callback=self.unload_llm_callback + )) # Set the prompts in the generator form UI data = self.extract_json_from_message(data["message"]) @@ -238,45 +238,18 @@ def finalize_image_generated_by_llm(self, data): message="Your image has been generated", is_first_message=True, is_end_of_message=True, - name=self.chatbot.name + name=self.chatbot.name, + action=LLMActionType.GENERATE_IMAGE ) - # If SD is enabled, emit a signal to unload SD. - if self.application_settings.sd_enabled: - # If LLM is disabled, emit a signal to load it. - if not self.application_settings.llm_enabled: - self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict( - callback=lambda d: self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict( - callback=lambda d: self.emit_signal( - SignalCode.LLM_TEXT_STREAMED_SIGNAL, - image_generated_message - ) - )) - )) - else: - self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict( - callback=lambda d: self.emit_signal( - SignalCode.LLM_TEXT_STREAMED_SIGNAL, - image_generated_message - ) - )) - else: - # If SD is disabled and LLM is disabled, emit a signal to load LLM - # with a callback to add the image generated message to the conversation. - if not self.application_settings.llm_enabled: - self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict( - callback=lambda d: self.emit_signal( - SignalCode.LLM_TEXT_STREAMED_SIGNAL, - image_generated_message - ) - )) - else: - # If SD is disabled and LLM is enabled, emit a signal to add - # the image generated message to the conversation. - self.emit_signal( + self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict( + callback=lambda d: self.emit_signal(SignalCode.LOAD_NON_SD_MODELS, dict( + callback=lambda d: self.emit_signal( SignalCode.LLM_TEXT_STREAMED_SIGNAL, image_generated_message ) + )) + )) ########################################################################## # End LLM Generated Image handlers ########################################################################## @@ -382,10 +355,10 @@ def extract_json_from_message(self, message): json_dict = json.loads(json_block) return json_dict except json.JSONDecodeError as e: - print(f"Error decoding JSON: {e}") + self.logger.error(f"Error decoding JSON block: {e}") return {} else: - print("No JSON block found in the message.") + self.logger.error("No JSON block found in message") return {} def get_memory_options(self): diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py index 5eb4f4170..33a7f3bbc 100644 --- a/src/airunner/windows/main/main_window.py +++ b/src/airunner/windows/main/main_window.py @@ -527,6 +527,8 @@ def register_signals(self): (SignalCode.TOGGLE_TTS_SIGNAL, self.on_toggle_tts), (SignalCode.TOGGLE_SD_SIGNAL, self.on_toggle_sd), (SignalCode.TOGGLE_LLM_SIGNAL, self.on_toggle_llm), + (SignalCode.UNLOAD_NON_SD_MODELS, self.on_unload_non_sd_models), + (SignalCode.LOAD_NON_SD_MODELS, self.on_load_non_sd_models), (SignalCode.APPLICATION_RESET_SETTINGS_SIGNAL, self.action_reset_settings), (SignalCode.APPLICATION_RESET_PATHS_SIGNAL, self.on_reset_paths_signal), (SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal), @@ -682,6 +684,22 @@ def on_toggle_fullscreen_signal(self): else: self.showFullScreen() + def on_unload_non_sd_models(self, data:dict=None): + self._llm_generate_worker.on_llm_on_unload_signal() + self._tts_generator_worker.unload() 
+        self._stt_audio_processor_worker.unload()
+        callback = data.get("callback", None) if data else None
+        if callback:
+            callback(data)
+
+    def on_load_non_sd_models(self, data:dict=None):
+        self._llm_generate_worker.load()
+        self._tts_generator_worker.load()
+        self._stt_audio_processor_worker.load()
+        callback = data.get("callback", None) if data else None
+        if callback:
+            callback(data)
+
     def on_toggle_llm(self, data:dict=None, val=None):
         if val is None:
             val = not self.application_settings.llm_enabled
diff --git a/src/airunner/workers/agent_worker.py b/src/airunner/workers/agent_worker.py
index 1e022c6c2..b1b35bba8 100644
--- a/src/airunner/workers/agent_worker.py
+++ b/src/airunner/workers/agent_worker.py
@@ -1,7 +1,7 @@
 import traceback
 import torch
 
-from airunner.enums import SignalCode
+from airunner.enums import SignalCode, LLMActionType
 from airunner.workers.worker import Worker
 
 
@@ -44,6 +44,7 @@ def handle_message(self, message):
                     is_first_message=True,
                     is_end_of_message=True,
                     name=message["botname"],
+                    action=LLMActionType.CHAT
                 )
             )
         else:
@@ -58,6 +59,7 @@ def handle_message(self, message):
                     is_first_message=True,
                     is_end_of_message=True,
                     name=message["botname"],
+                    action=LLMActionType.CHAT
                 )
             )
 
diff --git a/src/airunner/workers/audio_capture_worker.py b/src/airunner/workers/audio_capture_worker.py
index e4a9ec029..46e8d13d7 100644
--- a/src/airunner/workers/audio_capture_worker.py
+++ b/src/airunner/workers/audio_capture_worker.py
@@ -5,7 +5,7 @@
 import numpy as np
 from PySide6.QtCore import QThread
 
-from airunner.enums import SignalCode
+from airunner.enums import SignalCode, ModelStatus
 from airunner.settings import SLEEP_TIME_IN_MS
 from airunner.workers.worker import Worker
 
@@ -21,6 +21,7 @@ def __init__(self):
             (SignalCode.AUDIO_CAPTURE_WORKER_RESPONSE_SIGNAL, self.on_AudioCaptureWorker_response_signal),
             (SignalCode.STT_START_CAPTURE_SIGNAL, self.on_stt_start_capture_signal),
             (SignalCode.STT_STOP_CAPTURE_SIGNAL, self.on_stt_stop_capture_signal),
+            (SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal),
         ))
         self.listening: bool = False
         self.voice_input_start_time: time.time = None
@@ -29,9 +30,6 @@ def __init__(self):
         self.stream = None
         self.running = False
         self._audio_process_queue = queue.Queue()
-        #self._capture_thread = None
-        if self.application_settings.stt_enabled:
-            self._start_listening()
 
     def on_AudioCaptureWorker_response_signal(self, message: dict):
         item: np.ndarray = message["item"]
@@ -46,6 +44,14 @@ def on_stt_stop_capture_signal(self):
         if self.listening:
             self._stop_listening()
 
+    def on_model_status_changed_signal(self, message: dict):
+        model = message["model"]
+        status = message["status"]
+        if model == "stt" and status is ModelStatus.LOADED:
+            self._start_listening()
+        elif model == "stt" and status in (ModelStatus.UNLOADED, ModelStatus.FAILED):
+            self._stop_listening()
+
     def start(self):
         self.logger.debug("Starting audio capture worker")
         self.running = True
@@ -64,6 +70,10 @@ def start(self):
                 self.logger.error(f"PortAudioError: {e}")
                 QThread.msleep(SLEEP_TIME_IN_MS)
                 continue
+            except Exception as e:
+                self.logger.error(e)
+                QThread.msleep(SLEEP_TIME_IN_MS)
+                continue
             if np.max(np.abs(chunk)) > volume_input_threshold:  # check if chunk is not silence
                 self.logger.debug("Heard voice")
                 is_receiving_input = True
@@ -92,21 +102,21 @@ def start(self):
 
     def _start_listening(self):
         self.logger.debug("Start listening")
+        if self.stream is not None:
+            self._end_stream()
+        self._initialize_stream()
         self.listening = True
-        fs = self.stt_settings.fs
-        channels = self.stt_settings.channels
-        if self.stream is None:
-            self.stream = sd.InputStream(samplerate=fs, channels=channels)
-
-        try:
-            self.stream.start()
-        except Exception as e:
-            self.logger.error(e)
 
     def _stop_listening(self):
         self.logger.debug("Stop listening")
         self.listening = False
         self.running = False
+        self._end_stream()
+        # self._capture_thread.join()
+
+    def _end_stream(self):
+        if self.stream is None:
+            return
         try:
             self.stream.stop()
         except Exception as e:
@@ -115,4 +125,13 @@ def _stop_listening(self):
             self.stream.close()
         except Exception as e:
             self.logger.error(e)
-        # self._capture_thread.join()
+        self.stream = None
+
+    def _initialize_stream(self):
+        fs = self.stt_settings.fs
+        channels = self.stt_settings.channels
+        self.stream = sd.InputStream(samplerate=fs, channels=channels)
+        try:
+            self.stream.start()
+        except Exception as e:
+            self.logger.error(e)
diff --git a/src/airunner/workers/audio_processor_worker.py b/src/airunner/workers/audio_processor_worker.py
index 6dd764295..a0e1c4334 100644
--- a/src/airunner/workers/audio_processor_worker.py
+++ b/src/airunner/workers/audio_processor_worker.py
@@ -22,10 +22,14 @@ def __init__(self):
         ))
 
     def start_worker_thread(self):
-        self._stt = WhisperHandler()
+        self._initialize_stt_handler()
         if self.application_settings.stt_enabled:
             self._stt.load()
 
+    def _initialize_stt_handler(self):
+        if self._stt is None:
+            self._stt = WhisperHandler()
+
     def on_stt_load_signal(self):
         if self._stt:
             threading.Thread(target=self._stt_load).start()
@@ -34,13 +38,22 @@ def on_stt_unload_signal(self):
         if self._stt:
             threading.Thread(target=self._stt_unload).start()
 
+    def unload(self):
+        self._stt_unload()
+
+    def load(self):
+        self._initialize_stt_handler()
+        self._stt_load()
+
     def _stt_load(self):
-        self._stt.load()
-        self.emit_signal(SignalCode.STT_START_CAPTURE_SIGNAL)
+        if self._stt:
+            self._stt.load()
+            self.emit_signal(SignalCode.STT_START_CAPTURE_SIGNAL)
 
     def _stt_unload(self):
         self.emit_signal(SignalCode.STT_STOP_CAPTURE_SIGNAL)
-        self._stt.unload()
+        if self._stt:
+            self._stt.unload()
 
     def on_stt_process_audio_signal(self, message):
         self.add_to_queue(message)
diff --git a/src/airunner/workers/llm_generate_worker.py b/src/airunner/workers/llm_generate_worker.py
index 2f7266d85..1b4e93f38 100644
--- a/src/airunner/workers/llm_generate_worker.py
+++ b/src/airunner/workers/llm_generate_worker.py
@@ -35,7 +35,8 @@ def on_quit_application_signal(self):
     def on_llm_request_worker_response_signal(self, message: dict):
         self.add_to_queue(message)
 
-    def on_llm_on_unload_signal(self, data):
+    def on_llm_on_unload_signal(self, data=None):
+        data = data or {}
         self.logger.debug("Unloading LLM")
         self.llm.unload()
         callback = data.get("callback", None)
@@ -80,7 +81,10 @@ def _load_llm_thread(self, data=None):
         self._llm_thread = threading.Thread(target=self._load_llm, args=(data,))
         self._llm_thread.start()
 
-    def _load_llm(self, data):
+    def load(self):
+        self._load_llm()
+
+    def _load_llm(self, data=None):
         data = data or {}
         if self.llm is None:
             self.llm = CausalLMTransformerBaseHandler(agent_options=self.agent_options)
diff --git a/src/airunner/workers/tts_generator_worker.py b/src/airunner/workers/tts_generator_worker.py
index 1525fe3de..e5198d249 100644
--- a/src/airunner/workers/tts_generator_worker.py
+++ b/src/airunner/workers/tts_generator_worker.py
@@ -2,7 +2,7 @@
 import re
 import threading
 
-from airunner.enums import SignalCode, TTSModel, ModelStatus
+from airunner.enums import SignalCode, TTSModel, ModelStatus, LLMActionType
 from airunner.handlers.tts.espeak_tts_handler import EspeakTTSHandler
 from airunner.handlers.tts.speecht5_tts_handler import SpeechT5TTSHandler
 from airunner.workers.worker import Worker
@@ -31,6 +31,10 @@ def on_llm_text_streamed_signal(self, data):
         if not self.application_settings.tts_enabled:
             return
 
+        action = data.get("action", LLMActionType.CHAT)
+        if action is LLMActionType.GENERATE_IMAGE:
+            return
+
         if self.tts.model_status is not ModelStatus.LOADED:
             self.tts.load()
 
@@ -69,15 +73,17 @@ def on_disable_tts_signal(self):
         thread.start()
 
     def start_worker_thread(self):
-        tts_model = self.tts_settings.model.lower()
+        self._initialize_tts_handler()
+        if self.application_settings.tts_enabled:
+            self.tts.load()
 
+    def _initialize_tts_handler(self):
+        tts_model = self.tts_settings.model.lower()
         if tts_model == TTSModel.ESPEAK.value:
             tts_handler_class_ = EspeakTTSHandler
         else:
             tts_handler_class_ = SpeechT5TTSHandler
         self.tts = tts_handler_class_()
-        if self.application_settings.tts_enabled:
-            self.tts.load()
 
     def add_to_queue(self, message):
         if self.do_interrupt:
@@ -140,10 +146,20 @@ def word_count(s):
         self.on_interrupt_process_signal()
 
     def _load_tts(self):
-        self.tts.load()
+        if self.tts:
+            self.tts.load()
+
+    def load(self):
+        if not self.tts:
+            self._initialize_tts_handler()
+        self._load_tts()
+
+    def unload(self):
+        self._unload_tts()
 
     def _unload_tts(self):
-        self.tts.unload()
+        if self.tts:
+            self.tts.unload()
 
     def _generate(self, message):
         if self.do_interrupt:
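
Reviewer note on the new signal flow: the round trip introduced above (UNLOAD_NON_SD_MODELS before image generation, TOGGLE_SD_SIGNAL plus LOAD_NON_SD_MODELS afterwards) is driven entirely by optional callback entries in the signal payload, with each handler invoking the callback once its own work is done. Below is a minimal, self-contained Python sketch of that pattern. The SignalBus class is a stand-in for AI Runner's MediatorMixin and is not part of this diff; only the SignalCode member names and the callback guard mirror the real handlers above.

    # Hypothetical, simplified model of the callback-chained signals in this diff.
    from enum import Enum, auto


    class SignalCode(Enum):
        UNLOAD_NON_SD_MODELS = auto()
        LOAD_NON_SD_MODELS = auto()
        LLM_TEXT_STREAMED_SIGNAL = auto()


    class SignalBus:
        """Stand-in for MediatorMixin: one handler per signal, dict payloads."""
        def __init__(self):
            self._handlers = {}

        def register(self, code, handler):
            self._handlers[code] = handler

        def emit_signal(self, code, data=None):
            self._handlers[code](data or {})


    bus = SignalBus()


    def on_unload_non_sd_models(data: dict):
        print("unloading LLM / TTS / STT to free VRAM for Stable Diffusion")
        callback = data.get("callback", None) if data else None
        if callback:  # mirrors the guard in main_window.py
            callback(data)


    def on_load_non_sd_models(data: dict):
        print("reloading LLM / TTS / STT")
        callback = data.get("callback", None) if data else None
        if callback:
            callback(data)


    bus.register(SignalCode.UNLOAD_NON_SD_MODELS, on_unload_non_sd_models)
    bus.register(SignalCode.LOAD_NON_SD_MODELS, on_load_non_sd_models)
    bus.register(SignalCode.LLM_TEXT_STREAMED_SIGNAL, lambda d: print(d["message"]))

    # Free the non-SD models, then restore them and stream a confirmation,
    # mirroring finalize_image_generated_by_llm above.
    bus.emit_signal(SignalCode.UNLOAD_NON_SD_MODELS, dict(
        callback=lambda d: bus.emit_signal(SignalCode.LOAD_NON_SD_MODELS, dict(
            callback=lambda d: bus.emit_signal(
                SignalCode.LLM_TEXT_STREAMED_SIGNAL,
                dict(message="Your image has been generated"),
            )
        ))
    ))

Chaining through callbacks keeps the ordering explicit (the other models are only reloaded after Stable Diffusion is toggled off) and replaces the nested TOGGLE_SD_SIGNAL/TOGGLE_LLM_SIGNAL branching that this diff removes.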