From 2a867baf660fa0beed30e0b810c4c62f63b5240a Mon Sep 17 00:00:00 2001
From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com>
Date: Sat, 12 Oct 2024 15:08:58 -0600
Subject: [PATCH] Improvements to:
 - real-time voice conversations
 - text processing for the LLM
 - the interrupt signal

---
 src/airunner/enums.py                                 |   1 +
 src/airunner/handlers/llm/agent/base_agent.py         |  27 ++-
 .../llm/causal_lm_transformer_base_handler.py         |   6 +-
 src/airunner/handlers/stt/whisper_handler.py          | 146 ++++++++----------
 .../handlers/tts/speecht5_tts_handler.py              |   9 +-
 .../tests/test_speecht5_tts_handler.py                |   2 +-
 .../tests/test_tts_generator_worker.py                |  25 +++
 src/airunner/utils/get_torch_device.py                |   2 +
 .../widgets/llm/chat_prompt_widget.py                 |  11 +-
 .../widgets/tool_tab/tool_tab_widget.py               |   6 -
 src/airunner/windows/main/main_window.py              | 103 ++++++------
 src/airunner/worker_manager.py                        |  42 +----
 src/airunner/workers/audio_capture_worker.py          |  18 ++-
 src/airunner/workers/llm_generate_worker.py           |  10 +-
 src/airunner/workers/tts_generator_worker.py          |  27 +++-
 src/airunner/workers/tts_vocalizer_worker.py          |   7 +-
 16 files changed, 210 insertions(+), 232 deletions(-)
 create mode 100644 src/airunner/tests/test_tts_generator_worker.py

diff --git a/src/airunner/enums.py b/src/airunner/enums.py
index 2016054ca..22296306b 100644
--- a/src/airunner/enums.py
+++ b/src/airunner/enums.py
@@ -352,6 +352,7 @@ class LLMActionType(Enum):
     TOGGLE_TTS = "TOGGLE TEXT-TO-SPEECH: If the user requests that you turn on or off or toggle text-to-speech, choose this action."
     PERFORM_RAG_SEARCH = "SEARCH: If the user requests that you search for information, choose this action."
     SUMMARIZE = "SUMMARIZE"
+    DO_NOTHING = "DO NOTHING: If the user's request is unclear or you are unable to determine the user's intent, choose this action."
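The new DO_NOTHING member gives the action classifier an explicit fallback for unclear requests. A minimal sketch of how a dispatcher could treat it; perform_action is an illustrative helper, not part of this patch:

    from airunner.enums import LLMActionType

    def perform_action(action: LLMActionType) -> None:
        # DO_NOTHING is the safe default when the model cannot classify
        # the user's intent: the agent simply skips execution.
        if action is LLMActionType.DO_NOTHING:
            return
        ...  # dispatch the remaining action types as usual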
diff --git a/src/airunner/handlers/llm/agent/base_agent.py b/src/airunner/handlers/llm/agent/base_agent.py
index 2d14c2158..5c255a151 100644
--- a/src/airunner/handlers/llm/agent/base_agent.py
+++ b/src/airunner/handlers/llm/agent/base_agent.py
@@ -16,6 +16,7 @@
 from llama_index.core.chat_engine import ContextChatEngine
 from llama_index.core import SimpleKeywordTableIndex
 from llama_index.core.indices.keyword_table import KeywordTableSimpleRetriever
+from transformers import TextIteratorStreamer
 
 from airunner.handlers.llm.huggingface_llm import HuggingFaceLLM
 from airunner.handlers.llm.custom_embedding import CustomEmbedding
@@ -82,7 +83,7 @@ def __init__(self, *args, **kwargs):
         self.action = LLMActionType.CHAT
         self.rendered_template = None
         self.tokenizer = kwargs.pop("tokenizer", None)
-        self.streamer = kwargs.pop("streamer", None)
+        self.streamer = TextIteratorStreamer(self.tokenizer)
         self.chat_template = kwargs.pop("chat_template", "")
         self.is_mistral = kwargs.pop("is_mistral", True)
         self.conversation_id = None
@@ -97,12 +98,11 @@ def __init__(self, *args, **kwargs):
     @property
     def available_actions(self):
         return {
-            0: LLMActionType.QUIT_APPLICATION,
-            1: LLMActionType.TOGGLE_FULLSCREEN,
-            2: LLMActionType.TOGGLE_TTS,
-            3: LLMActionType.GENERATE_IMAGE,
-            4: LLMActionType.PERFORM_RAG_SEARCH,
-            5: LLMActionType.CHAT,
+            0: LLMActionType.TOGGLE_FULLSCREEN,
+            1: LLMActionType.TOGGLE_TTS,
+            2: LLMActionType.GENERATE_IMAGE,
+            3: LLMActionType.PERFORM_RAG_SEARCH,
+            4: LLMActionType.CHAT,
         }
 
     @property
@@ -163,7 +163,7 @@ def interrupt_process(self):
 
     def do_interrupt_process(self):
         interrupt = self.do_interrupt
-        self.do_interrupt = False
+        self.streamer = TextIteratorStreamer(self.tokenizer)
         return interrupt
 
     @property
@@ -303,9 +303,7 @@ def build_system_prompt(
                 self.names_prompt(use_names, botname, username),
                 self.mood(botname, bot_mood, use_mood),
                 system_instructions,
-                "------\n",
-                "Chat History:\n",
-                f"{self.username}: {self.prompt}\n",
+                self.history_prompt(),
             ]
 
         elif action is LLMActionType.SUMMARIZE:
@@ -502,10 +500,9 @@ def run(
             self.create_conversation()
 
         # Add the user's message to history
-        if action in (
-            LLMActionType.CHAT,
-            LLMActionType.PERFORM_RAG_SEARCH,
-            LLMActionType.GENERATE_IMAGE,
+        if action not in (
+            LLMActionType.APPLICATION_COMMAND,
+            LLMActionType.UPDATE_MOOD
         ):
             self.add_message_to_history(self.prompt, LLMChatRole.HUMAN)
diff --git a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py
index dad3e4105..655c28098 100644
--- a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py
+++ b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py
@@ -226,6 +226,8 @@ def clear_history(self):
         """
         Public method to clear the chat agent history
         """
+        if not self._chat_agent:
+            return
         self.logger.debug("Clearing chat history")
         self._chat_agent.clear_history()
 
@@ -301,7 +303,6 @@ def _load_agent(self):
             self._chat_agent = BaseAgent(
                 model=self._model,
                 tokenizer=self._tokenizer,
-                streamer=self._streamer,
                 chat_template=self.chat_template,
                 is_mistral=self.is_mistral,
             )
@@ -378,8 +379,7 @@ def _load_model_local(self):
 
     def _do_generate(self, prompt: str, action: LLMActionType):
         self.logger.debug("Generating response")
-        model_path = self.model_path
-        if self._current_model_path != model_path:
+        if self._current_model_path != self.model_path:
             self.unload()
             self.load()
         if action is LLMActionType.CHAT and self.chatbot.use_mood:
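BaseAgent now builds its own TextIteratorStreamer from the tokenizer instead of receiving one from the handler, and do_interrupt_process swaps in a fresh streamer so an interrupted stream cannot leak stale tokens into the next response. For reference, a minimal sketch of the standard TextIteratorStreamer pattern from transformers, assuming a causal LM and tokenizer are already loaded (variable names are illustrative):

    from threading import Thread
    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    # generate() blocks, so it runs in a thread while the caller
    # consumes decoded text from the streamer as it arrives
    thread = Thread(
        target=model.generate,
        kwargs={"input_ids": input_ids, "streamer": streamer, "max_new_tokens": 256},
    )
    thread.start()
    for text_chunk in streamer:
        print(text_chunk, end="", flush=True)
    thread.join()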
inputs["input_features"].to(torch.bfloat16) if torch.isnan(inputs.input_features).any(): raise NaNException - inputs = inputs.to(self._model.device) + inputs["input_features"] = inputs["input_features"].to(self.dtype).to(self.device) if torch.isnan(inputs.input_features).any(): raise NaNException - transcription = self._run(inputs, role) + transcription = self._run(inputs) if transcription is None or 'nan' in transcription: raise NaNException return transcription - def _process_human_speech(self, transcription: str = None): - """ - Process the human speech. - This method is called when the model has processed the human speech - and the transcription is ready to be added to the chat history. - This should only be used for human speech. - :param transcription: - :return: - """ - if transcription == "": - raise ValueError("Transcription is empty") - self.logger.debug("Processing human speech") - data = { - "message": transcription, - "role": LLMChatRole.HUMAN - } - self.emit_signal( - SignalCode.ADD_CHATBOT_MESSAGE_SIGNAL, - data - ) - def _run( self, - inputs, - role: LLMChatRole = LLMChatRole.HUMAN, + inputs ) -> str: """ Run the model on the given inputs. @@ -231,31 +210,39 @@ def _run( if torch.isnan(input_features).any(): raise NaNException - generated_ids = self._model.generate( - input_features=input_features, - # generation_config=None, - # logits_processor=None, - # stopping_criteria=None, - # prefix_allowed_tokens_fn=None, - # synced_gpus=True, - # return_timestamps=None, - # task="transcribe", - # language="en", - # is_multilingual=True, - # prompt_ids=None, - # prompt_condition_type=None, - # condition_on_prev_tokens=None, - temperature=0.8, - # compression_ratio_threshold=None, - # logprob_threshold=None, - # no_speech_threshold=None, - # num_segment_frames=None, - # attention_mask=None, - # time_precision=0.02, - # return_token_timestamps=None, - # return_segments=False, - # return_dict_in_generate=None, - ) + try: + generated_ids = self._model.generate( + input_features=input_features, + # generation_config=None, + # logits_processor=None, + # stopping_criteria=None, + # prefix_allowed_tokens_fn=None, + # synced_gpus=True, + # return_timestamps=None, + # task="transcribe", + # language="en", + is_multilingual=False, + # prompt_ids=None, + # prompt_condition_type=None, + # condition_on_prev_tokens=None, + temperature=0.8, + compression_ratio_threshold=1.35, + logprob_threshold=-1.0, + no_speech_threshold=0.2, + # num_segment_frames=None, + # attention_mask=None, + time_precision=0.02, + # return_token_timestamps=None, + # return_segments=False, + # return_dict_in_generate=None, + ) + except RuntimeError as e: + generated_ids = None + self.logger.error(f"Error in model generation: {e}") + + if generated_ids is None: + return "" + if torch.isnan(generated_ids).any(): raise NaNException @@ -263,16 +250,19 @@ def _run( if len(transcription) == 0 or len(transcription.split(" ")) == 1: return "" - # Emit the transcription so that other handlers can use it + return transcription + + def _send_transcription(self, transcription: str): + """ + Emit the transcription so that other handlers can use it + """ self.emit_signal(SignalCode.AUDIO_PROCESSOR_RESPONSE_SIGNAL, { - "transcription": transcription, - "role": role + "transcription": transcription }) - return transcription - def process_transcription(self, generated_ids) -> str: # Decode the generated ids + generated_ids = generated_ids.to("cpu").to(torch.float32) transcription = self._processor.batch_decode( generated_ids, 
diff --git a/src/airunner/handlers/tts/speecht5_tts_handler.py b/src/airunner/handlers/tts/speecht5_tts_handler.py
index 948816e33..c89761bc2 100644
--- a/src/airunner/handlers/tts/speecht5_tts_handler.py
+++ b/src/airunner/handlers/tts/speecht5_tts_handler.py
@@ -328,10 +328,10 @@ def interrupt_process_signal(self):
     def _prepare_text(self, text) -> str:
         text = self._replace_unspeakable_characters(text)
         text = self._strip_emoji_characters(text)
-        text = self._roman_to_int(text)
+        # the following function is currently disabled because we must first find a
+        # reliable way to handle the word "I" and distinguish it from the Roman numeral "I"
+        # text = self._roman_to_int(text)
         text = self._replace_numbers_with_words(text)
-        text = re.sub(r"\s+", " ", text)  # Remove extra spaces
-        text = text.strip()
         return text
 
     @staticmethod
@@ -339,13 +339,12 @@ def _replace_unspeakable_characters(text) -> str:
         # strip things like ellipsis, etc
         text = text.replace("...", " ")
         text = text.replace("…", " ")
-        text = text.replace("’", "")
+        text = text.replace("’", "'")
         text = text.replace("“", "")
         text = text.replace("”", "")
         text = text.replace("‘", "")
         text = text.replace("–", "")
         text = text.replace("—", "")
-        text = text.replace("'", "")
         text = text.replace('"', "")
         text = text.replace("-", "")
         text = text.replace("-", "")
diff --git a/src/airunner/tests/test_speecht5_tts_handler.py b/src/airunner/tests/test_speecht5_tts_handler.py
index 64e1bf9b7..2985355dd 100644
--- a/src/airunner/tests/test_speecht5_tts_handler.py
+++ b/src/airunner/tests/test_speecht5_tts_handler.py
@@ -43,7 +43,7 @@ def test_roman_to_int(self):
             "M": "1000",
             "MMXXI": "2021",
             "This is a IV test": "This is a 4 test",
-            "A test with no roman numerals": "A test with no roman numerals"
+            "A test with no roman numerals": "A test with no roman numerals",
         }
 
         for roman, expected in test_cases.items():
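_roman_to_int is disabled rather than removed because a naive converter also rewrites the pronoun "I" as "1". If the feature returns, one possible heuristic (illustrative only, with known false positives such as "DIM" or "MIX") is to skip a lone "I" while still matching longer numerals:

    import re

    # Match runs of Roman-numeral letters, but never a standalone "I"
    ROMAN = re.compile(r"\b(?!I\b)[IVXLCDM]+\b")

    ROMAN.findall("This is a IV test")   # ['IV']
    ROMAN.findall("I think so")          # [] -- the pronoun survives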
Yes.", ["Mixed punctuation", "Really", "Yes"]), + ("The time is 12:45.", ["The time is 1245"]), + ("Meet me at 09:30 AM.", ["Meet me at 0930 AM"]), + ("It happened at 23:59:59.", ["It happened at 235959"]), + ] + + for text, expected_chunks in test_cases: + with self.subTest(text=text, expected_chunks=expected_chunks): + self.assertEqual(TTSGeneratorWorker._split_text_at_punctuation(text), expected_chunks) + +if __name__ == '__main__': + unittest.main() diff --git a/src/airunner/utils/get_torch_device.py b/src/airunner/utils/get_torch_device.py index 92d37623b..51b3235c5 100644 --- a/src/airunner/utils/get_torch_device.py +++ b/src/airunner/utils/get_torch_device.py @@ -3,4 +3,6 @@ def get_torch_device(card_index: int = 0): use_cuda = torch.cuda.is_available() + if not use_cuda: + print("WARNING: CUDA NOT AVAILABLE, USING CPU") return torch.device(f"cuda:{card_index}" if use_cuda else "cpu") diff --git a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py index 293692881..2503cd5f2 100644 --- a/src/airunner/widgets/llm/chat_prompt_widget.py +++ b/src/airunner/widgets/llm/chat_prompt_widget.py @@ -98,8 +98,8 @@ def _set_conversation_widgets(self, messages): def on_hear_signal(self, data: dict): transcription = data["transcription"] - self.respond_to_voice(transcription) self.ui.prompt.setPlainText(transcription) + self.ui.send_button.click() def on_add_to_conversation_signal(self, name, text, is_bot): self.add_message_to_conversation(name=name, message=text, is_bot=is_bot) @@ -161,6 +161,9 @@ def action_button_clicked_send(self): def interrupt_button_clicked(self): self.emit_signal(SignalCode.INTERRUPT_PROCESS_SIGNAL) + self.stop_progress_bar() + self.generating = False + self.enable_send_button() @property def action(self) -> str: @@ -289,12 +292,6 @@ def display_action_menu(self): def insert_newline(self): self.ui.prompt.insertPlainText("\n") - - def respond_to_voice(self, transcript: str): - transcript = transcript.strip() - if transcript == "." 
diff --git a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py
index 293692881..2503cd5f2 100644
--- a/src/airunner/widgets/llm/chat_prompt_widget.py
+++ b/src/airunner/widgets/llm/chat_prompt_widget.py
@@ -98,8 +98,8 @@ def _set_conversation_widgets(self, messages):
 
     def on_hear_signal(self, data: dict):
         transcription = data["transcription"]
-        self.respond_to_voice(transcription)
         self.ui.prompt.setPlainText(transcription)
+        self.ui.send_button.click()
 
     def on_add_to_conversation_signal(self, name, text, is_bot):
         self.add_message_to_conversation(name=name, message=text, is_bot=is_bot)
@@ -161,6 +161,9 @@ def action_button_clicked_send(self):
 
     def interrupt_button_clicked(self):
         self.emit_signal(SignalCode.INTERRUPT_PROCESS_SIGNAL)
+        self.stop_progress_bar()
+        self.generating = False
+        self.enable_send_button()
 
     @property
     def action(self) -> str:
@@ -289,12 +292,6 @@ def display_action_menu(self):
     def insert_newline(self):
         self.ui.prompt.insertPlainText("\n")
 
-    def respond_to_voice(self, transcript: str):
-        transcript = transcript.strip()
-        if transcript == "." or transcript is None or transcript == "":
-            return
-        self.do_generate(prompt_override=transcript)
-
     def describe_image(self, image, callback):
         self.do_generate(
diff --git a/src/airunner/widgets/tool_tab/tool_tab_widget.py b/src/airunner/widgets/tool_tab/tool_tab_widget.py
index 439698f4b..383e29e25 100644
--- a/src/airunner/widgets/tool_tab/tool_tab_widget.py
+++ b/src/airunner/widgets/tool_tab/tool_tab_widget.py
@@ -10,9 +10,3 @@ class ToolTabWidget(BaseWidget):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.__application_settings = QSettings(ORGANIZATION, APPLICATION_NAME)
-
-    def showEvent(self, event):
-        self.ui.tool_tab_widget_container.setCurrentIndex(
-            int(self.__application_settings.value("tool_tab_widget_index", defaultValue=0))
-        )
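Voice input now flows through the same path as typed input: the transcription is placed in the prompt box and the send button is clicked programmatically. The removed respond_to_voice method also guarded against empty or "."-only transcriptions; if that guard is still wanted, a minimal version could live in on_hear_signal (illustrative, not part of this patch):

    def on_hear_signal(self, data: dict):
        transcription = (data.get("transcription") or "").strip()
        if transcription in ("", "."):
            return  # ignore silence artifacts from the STT handler
        self.ui.prompt.setPlainText(transcription)
        self.ui.send_button.click()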
diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py
index 2e870f814..3af0687fe 100644
--- a/src/airunner/windows/main/main_window.py
+++ b/src/airunner/windows/main/main_window.py
@@ -49,6 +49,14 @@
 )
 from airunner.styles_mixin import StylesMixin
 from airunner.utils.convert_image_to_base64 import convert_image_to_base64
+from airunner.utils.create_worker import create_worker
+from airunner.workers.audio_capture_worker import AudioCaptureWorker
+from airunner.workers.audio_processor_worker import AudioProcessorWorker
+from airunner.workers.llm_generate_worker import LLMGenerateWorker
+from airunner.workers.mask_generator_worker import MaskGeneratorWorker
+from airunner.workers.sd_worker import SDWorker
+from airunner.workers.tts_generator_worker import TTSGeneratorWorker
+from airunner.workers.tts_vocalizer_worker import TTSVocalizerWorker
 from airunner.utils.get_version import get_version
 from airunner.utils.set_widget_state import set_widget_state
 
@@ -64,7 +72,6 @@
 from airunner.windows.prompt_browser.prompt_browser import PromptBrowser
 from airunner.windows.settings.airunner_settings import SettingsWindow
 from airunner.windows.update.update_window import UpdateWindow
-from airunner.worker_manager import WorkerManager
 
 
 class MainWindow(
@@ -104,23 +111,10 @@ class MainWindow(
     def __init__(
         self,
         *args,
-        disable_sd: bool = False,
-        disable_llm: bool = False,
-        disable_tts: bool = False,
-        disable_stt: bool = False,
-        use_cuda: bool = True,
-        tts_enabled: bool = False,
-        stt_enabled: bool = False,
-        ai_mode: bool = True,
         defendatron=None,
         **kwargs
     ):
         self.ui = self.ui_class_()
-        self.disable_sd = disable_sd
-        self.disable_llm = disable_llm
-        self.disable_tts = disable_tts
-        self.disable_stt = disable_stt
-
         self.defendatron = defendatron
         self.quitting = False
         self.update_popup = None
@@ -146,7 +140,6 @@ def __init__(
         self.status_error_color = STATUS_ERROR_COLOR
         self.status_normal_color_light = STATUS_NORMAL_COLOR_LIGHT
         self.status_normal_color_dark = STATUS_NORMAL_COLOR_DARK
-        self.is_started = False
         self._themes = None
         self.button_clicked_signal = Signal(dict)
         self.status_widget = None
@@ -158,35 +151,19 @@ def __init__(
         self.listening = False
         self.initialized = False
         self._model_status = {model_type: ModelStatus.UNLOADED for model_type in ModelType}
-
         self.logger = Logger(prefix=self.__class__.__name__)
         self.logger.debug("Starting AI Runnner")
         MediatorMixin.__init__(self)
         SettingsMixin.__init__(self)
-
         super().__init__(*args, **kwargs)
-
         self._updating_settings = True
-        self.__application_settings = QSettings(ORGANIZATION, APPLICATION_NAME)
-
         PipelineMixin.__init__(self)
         AIModelMixin.__init__(self)
         self._updating_settings = False
-
+        self._worker_manager = None
         self.register_signals()
-        self.initialize_ui()
-        self.worker_manager = None
-        self.is_started = True
-        self.image_window = None
-
-        for item in (
-            (SignalCode.AI_MODELS_SAVE_OR_UPDATE_SIGNAL, self.on_ai_models_save_or_update_signal),
-            (SignalCode.NAVIGATE_TO_URL, self.on_navigate_to_url),
-        ):
-            self.register(item[0], item[1])
-
-        self.emit_signal(SignalCode.APPLICATION_MAIN_WINDOW_LOADED_SIGNAL, { "main_window": self })
+        self.initialize_ui()
+        self._initialize_workers()
 
     @property
     def generator_tab_widget(self):
@@ -536,25 +513,30 @@ def show_layers(self):
 
     def register_signals(self):
         self.logger.debug("Connecting signals")
-        self.register(SignalCode.SD_SAVE_PROMPT_SIGNAL, self.on_save_stablediffusion_prompt_signal)
-        self.register(SignalCode.QUIT_APPLICATION, self.action_quit_triggered)
-        self.register(SignalCode.SD_NSFW_CONTENT_DETECTED_SIGNAL, self.on_nsfw_content_detected_signal)
-        self.register(SignalCode.ENABLE_BRUSH_TOOL_SIGNAL, lambda _message: self.action_toggle_brush(True))
-        self.register(SignalCode.ENABLE_ERASER_TOOL_SIGNAL, lambda _message: self.action_toggle_eraser(True))
-        self.register(SignalCode.ENABLE_SELECTION_TOOL_SIGNAL, lambda _message: self.action_toggle_select(True))
-        self.register(SignalCode.ENABLE_MOVE_TOOL_SIGNAL, lambda _message: self.action_toggle_active_grid_area(True))
-        self.register(SignalCode.BASH_EXECUTE_SIGNAL, self.on_bash_execute_signal)
-        self.register(SignalCode.WRITE_FILE, self.on_write_file_signal)
-        self.register(SignalCode.TOGGLE_FULLSCREEN_SIGNAL, self.on_toggle_fullscreen_signal)
-        self.register(SignalCode.TOGGLE_TTS_SIGNAL, self.on_toggle_tts)
-        self.register(SignalCode.TOGGLE_SD_SIGNAL, self.on_toggle_sd)
-        self.register(SignalCode.TOGGLE_LLM_SIGNAL, self.on_toggle_llm)
-        self.register(SignalCode.APPLICATION_RESET_SETTINGS_SIGNAL, self.action_reset_settings)
-        self.register(SignalCode.APPLICATION_RESET_PATHS_SIGNAL, self.on_reset_paths_signal)
-        self.register(SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal)
-        self.register(SignalCode.KEYBOARD_SHORTCUTS_UPDATED, self.on_keyboard_shortcuts_updated)
-        self.register(SignalCode.HISTORY_UPDATED, self.on_history_updated),
-        self.register(SignalCode.REFRESH_STYLESHEET_SIGNAL, self.on_theme_changed_signal)
+        for item in (
+            (SignalCode.SD_SAVE_PROMPT_SIGNAL, self.on_save_stablediffusion_prompt_signal),
+            (SignalCode.QUIT_APPLICATION, self.action_quit_triggered),
+            (SignalCode.SD_NSFW_CONTENT_DETECTED_SIGNAL, self.on_nsfw_content_detected_signal),
+            (SignalCode.ENABLE_BRUSH_TOOL_SIGNAL, lambda _message: self.action_toggle_brush(True)),
+            (SignalCode.ENABLE_ERASER_TOOL_SIGNAL, lambda _message: self.action_toggle_eraser(True)),
+            (SignalCode.ENABLE_SELECTION_TOOL_SIGNAL, lambda _message: self.action_toggle_select(True)),
+            (SignalCode.ENABLE_MOVE_TOOL_SIGNAL, lambda _message: self.action_toggle_active_grid_area(True)),
+            (SignalCode.BASH_EXECUTE_SIGNAL, self.on_bash_execute_signal),
+            (SignalCode.WRITE_FILE, self.on_write_file_signal),
+            (SignalCode.TOGGLE_FULLSCREEN_SIGNAL, self.on_toggle_fullscreen_signal),
+            (SignalCode.TOGGLE_TTS_SIGNAL, self.on_toggle_tts),
+            (SignalCode.TOGGLE_SD_SIGNAL, self.on_toggle_sd),
+            (SignalCode.TOGGLE_LLM_SIGNAL, self.on_toggle_llm),
+            (SignalCode.APPLICATION_RESET_SETTINGS_SIGNAL, self.action_reset_settings),
+            (SignalCode.APPLICATION_RESET_PATHS_SIGNAL, self.on_reset_paths_signal),
+            (SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal),
+            (SignalCode.KEYBOARD_SHORTCUTS_UPDATED, self.on_keyboard_shortcuts_updated),
+            (SignalCode.HISTORY_UPDATED, self.on_history_updated),
+            (SignalCode.REFRESH_STYLESHEET_SIGNAL, self.on_theme_changed_signal),
+            (SignalCode.AI_MODELS_SAVE_OR_UPDATE_SIGNAL, self.on_ai_models_save_or_update_signal),
+            (SignalCode.NAVIGATE_TO_URL, self.on_navigate_to_url),
+        ):
+            self.register(item[0], item[1])
 
     def on_reset_paths_signal(self):
         self.reset_path_settings()
@@ -607,6 +589,7 @@ def initialize_ui(self):
         self.initialize_widget_elements()
         self.ui.actionUndo.setEnabled(False)
         self.ui.actionRedo.setEnabled(False)
+        self.emit_signal(SignalCode.APPLICATION_MAIN_WINDOW_LOADED_SIGNAL, {"main_window": self})
 
     def initialize_widget_elements(self):
         for item in (
@@ -935,7 +918,6 @@ def showEvent(self, event):
                 icon_data[1],
                 "dark" if self.application_settings.dark_mode_enabled else "light"
             )
-        self._initialize_worker_manager()
 
         self.logger.debug("Showing window")
         self._set_keyboard_shortcuts()
@@ -981,14 +963,15 @@ def _set_keyboard_shortcuts(self):
 
         session.close()
 
-    def _initialize_worker_manager(self):
+    def _initialize_workers(self):
         self.logger.debug("Initializing worker manager")
-        self.worker_manager = WorkerManager(
-            disable_sd=self.disable_sd,
-            disable_llm=self.disable_llm,
-            disable_tts=self.disable_tts,
-            disable_stt=self.disable_stt
-        )
+        self._mask_generator_worker = create_worker(MaskGeneratorWorker)
+        self._sd_worker = create_worker(SDWorker)
+        self._stt_audio_capture_worker = create_worker(AudioCaptureWorker)
+        self._stt_audio_processor_worker = create_worker(AudioProcessorWorker)
+        self._tts_generator_worker = create_worker(TTSGeneratorWorker)
+        self._tts_vocalizer_worker = create_worker(TTSVocalizerWorker)
+        self._llm_generate_worker = create_worker(LLMGenerateWorker)
 
     def _initialize_filter_actions(self):
         # add more filters:
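MainWindow now owns its workers directly through create_worker instead of delegating to WorkerManager, and the disable_sd/disable_llm/disable_tts/disable_stt constructor flags are gone, so every worker starts unconditionally. If conditional startup were still needed, a sketch along these lines could restore it (the settings guards are hypothetical, not part of this patch):

    def _initialize_workers(self):
        self._llm_generate_worker = create_worker(LLMGenerateWorker)
        # hypothetical guards; the patch starts all workers unconditionally
        if self.application_settings.tts_enabled:
            self._tts_generator_worker = create_worker(TTSGeneratorWorker)
            self._tts_vocalizer_worker = create_worker(TTSVocalizerWorker)
        if self.application_settings.stt_enabled:
            self._stt_audio_capture_worker = create_worker(AudioCaptureWorker)
            self._stt_audio_processor_worker = create_worker(AudioProcessorWorker)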
diff --git a/src/airunner/worker_manager.py b/src/airunner/worker_manager.py
index 8295de332..fe0d7dbea 100644
--- a/src/airunner/worker_manager.py
+++ b/src/airunner/worker_manager.py
@@ -22,14 +22,7 @@ class WorkerManager(QObject, MediatorMixin, SettingsMixin):
     request_signal_status = Signal(str)
     image_generated_signal = Signal(dict)
 
-    def __init__(
-        self,
-        disable_sd: bool = False,
-        disable_llm: bool = False,
-        disable_tts: bool = False,
-        disable_stt: bool = False,
-        agent_options: dict = None
-    ):
+    def __init__(self):
         MediatorMixin.__init__(self)
         SettingsMixin.__init__(self)
         super().__init__()
@@ -42,32 +35,7 @@ def __init__(
         self._stt_audio_capture_worker = None
         self._stt_audio_processor_worker = None
 
-        self.agent_options = agent_options
-
-        if not disable_sd:
-            self.register_sd_workers()
-
-        if not disable_llm:
-            self.register_llm_workers(self.agent_options)
-
-        if not disable_tts:
-            self.register_tts_workers()
-
-        if not disable_stt:
-            self.register_stt_workers()
-
-        self.mask_generator_worker = create_worker(MaskGeneratorWorker)
-
-    def register_sd_workers(self):
-        self._sd_worker = create_worker(SDWorker)
-
-    def register_llm_workers(self, agent_options):
-        self._llm_generate_worker = create_worker(LLMGenerateWorker, agent_options=agent_options)
-
-    def register_tts_workers(self):
-        self._tts_generator_worker = create_worker(TTSGeneratorWorker)
-        self._tts_vocalizer_worker = create_worker(TTSVocalizerWorker)
-
-    def register_stt_workers(self):
-        self._stt_audio_capture_worker = create_worker(AudioCaptureWorker)
-        self._stt_audio_processor_worker = create_worker(AudioProcessorWorker)
+        self.register_sd_workers()
+        self.register_llm_workers()
+        self.register_tts_workers()
+        self.register_stt_workers()
diff --git a/src/airunner/workers/audio_capture_worker.py b/src/airunner/workers/audio_capture_worker.py
index 4f05e055e..e4a9ec029 100644
--- a/src/airunner/workers/audio_capture_worker.py
+++ b/src/airunner/workers/audio_capture_worker.py
@@ -1,5 +1,4 @@
 import queue
-import threading
 import time
 
 import sounddevice as sd
@@ -30,7 +29,9 @@ def __init__(self):
         self.stream = None
         self.running = False
         self._audio_process_queue = queue.Queue()
-        self._capture_thread = None
+        #self._capture_thread = None
+        if self.application_settings.stt_enabled:
+            self._start_listening()
 
     def on_AudioCaptureWorker_response_signal(self, message: dict):
         item: np.ndarray = message["item"]
@@ -38,13 +39,11 @@ def on_AudioCaptureWorker_response_signal(self, message: dict):
         self.add_to_queue(item)
 
     def on_stt_start_capture_signal(self):
-        if self._capture_thread is not None and self._capture_thread.is_alive():
-            return
-        self._capture_thread = threading.Thread(target=self._start_listening)
-        self._capture_thread.start()
+        if not self.listening:
+            self._start_listening()
 
     def on_stt_stop_capture_signal(self):
-        if self._capture_thread is not None and self._capture_thread.is_alive():
+        if self.listening:
             self._stop_listening()
 
     def start(self):
@@ -62,9 +61,11 @@ def start(self):
             try:
                 chunk, overflowed = self.stream.read(int(chunk_duration * fs))
             except sd.PortAudioError as e:
+                self.logger.error(f"PortAudioError: {e}")
                 QThread.msleep(SLEEP_TIME_IN_MS)
                 continue
             if np.max(np.abs(chunk)) > volume_input_threshold:  # check if chunk is not silence
+                self.logger.debug("Heard voice")
                 is_receiving_input = True
                 self.emit_signal(SignalCode.INTERRUPT_PROCESS_SIGNAL)
                 voice_input_start_time = time.time()
@@ -73,6 +74,7 @@ def start(self):
                 end_time = voice_input_start_time + silence_buffer_seconds
                 if time.time() >= end_time:
                     if len(recording) > 0:
+                        self.logger.debug("Sending audio to audio_processor_worker")
                         self.emit_signal(
                             SignalCode.AUDIO_CAPTURE_WORKER_RESPONSE_SIGNAL,
                             {
@@ -113,4 +115,4 @@ def _stop_listening(self):
                 self.stream.close()
             except Exception as e:
                 self.logger.error(e)
-        self._capture_thread.join()
+        # self._capture_thread.join()
diff --git a/src/airunner/workers/llm_generate_worker.py b/src/airunner/workers/llm_generate_worker.py
index dcc560bed..90aa2d53f 100644
--- a/src/airunner/workers/llm_generate_worker.py
+++ b/src/airunner/workers/llm_generate_worker.py
@@ -19,6 +19,7 @@ def __init__(self, agent_options=None):
             (SignalCode.RAG_RELOAD_INDEX_SIGNAL, self.on_llm_reload_rag_index_signal),
             (SignalCode.ADD_CHATBOT_MESSAGE_SIGNAL, self.on_llm_add_chatbot_response_to_history),
             (SignalCode.LOAD_CONVERSATION, self.on_llm_load_conversation),
+            (SignalCode.INTERRUPT_PROCESS_SIGNAL, self.llm_on_interrupt_process_signal),
         ):
             self.register(signal[0], signal[1])
 
@@ -42,16 +43,19 @@ def _load_llm(self, data):
             callback(data)
 
     def on_llm_clear_history_signal(self):
-        self.llm.clear_history()
+        if self.llm:
+            self.llm.clear_history()
 
     def on_llm_request_signal(self, message: dict):
         self.add_to_queue(message)
 
     def llm_on_interrupt_process_signal(self):
-        self.llm.do_interrupt()
+        if self.llm:
+            self.llm.do_interrupt()
 
     def on_llm_reload_rag_index_signal(self):
-        self.llm.reload_rag()
+        if self.llm:
+            self.llm.reload_rag()
 
     def on_llm_add_chatbot_response_to_history(self, message):
         self.llm.add_chatbot_response_to_history(message)
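AudioCaptureWorker no longer spawns its own threading.Thread; it begins listening in __init__ when stt_enabled is set and relies on the thread the worker framework already provides. A standalone sketch of the same capture pattern with sounddevice, assuming illustrative parameter values:

    import numpy as np
    import sounddevice as sd

    fs = 16000                 # sample rate used by the STT pipeline
    chunk_duration = 0.5       # seconds per read; illustrative
    volume_threshold = 0.05    # crude silence gate; illustrative

    with sd.InputStream(samplerate=fs, channels=1, dtype="float32") as stream:
        chunk, overflowed = stream.read(int(chunk_duration * fs))
        if np.max(np.abs(chunk)) > volume_threshold:
            print("heard voice")  # the worker would start recording here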
diff --git a/src/airunner/workers/tts_generator_worker.py b/src/airunner/workers/tts_generator_worker.py
index c3853c175..1525fe3de 100644
--- a/src/airunner/workers/tts_generator_worker.py
+++ b/src/airunner/workers/tts_generator_worker.py
@@ -1,4 +1,5 @@
 import queue
+import re
 import threading
 
 from airunner.enums import SignalCode, TTSModel, ModelStatus
@@ -51,10 +52,11 @@ def on_interrupt_process_signal(self):
         self.tts.interrupt_process_signal()
 
     def on_unblock_tts_generator_signal(self):
-        self.logger.debug("Unblocking TTS generation...")
-        self.do_interrupt = False
-        self.paused = False
-        self.tts.unblock_tts_generator_signal()
+        if self.application_settings.tts_enabled:
+            self.logger.debug("Unblocking TTS generation...")
+            self.do_interrupt = False
+            self.paused = False
+            self.tts.unblock_tts_generator_signal()
 
     def on_enable_tts_signal(self):
         if self.tts:
@@ -98,6 +100,15 @@ def handle_message(self, data):
         # Convert the tokens to a string
         text = "".join(self.tokens)
 
+        # Regular expression to match timestamps in the format HH:MM
+        timestamp_pattern = re.compile(r'\b(\d{1,2}):(\d{2})\b')
+
+        # Replace the colon in the matched timestamps with a space
+        text = timestamp_pattern.sub(r'\1 \2', text)
+
+        def word_count(s):
+            return len(s.split())
+
         if finalize:
             self._generate(text)
             self.play_queue_started = True
@@ -112,12 +123,16 @@ def handle_message(self, data):
                 if p in text:
                     split_text = text.split(p, 1)  # Split at the first occurrence of punctuation
                     if len(split_text) > 1:
-                        sentence = split_text[0]
+                        before, after = split_text[0], split_text[1]
+                        if p == ",":
+                            if word_count(before) < 3 or word_count(after) < 3:
+                                continue  # Skip splitting if there are not enough words around the comma
+                        sentence = before
                         self._generate(sentence)
                         self.play_queue_started = True
 
                         # Convert the remaining string back to a list of tokens
-                        remaining_text = split_text[1].strip()
+                        remaining_text = after.strip()
                         if not self.do_interrupt:
                             self.tokens = list(remaining_text)
                         break
diff --git a/src/airunner/workers/tts_vocalizer_worker.py b/src/airunner/workers/tts_vocalizer_worker.py
index bf03c24e2..67d1abd5e 100644
--- a/src/airunner/workers/tts_vocalizer_worker.py
+++ b/src/airunner/workers/tts_vocalizer_worker.py
@@ -36,9 +36,10 @@ def on_interrupt_process_signal(self):
         self.queue = Queue()
 
     def on_unblock_tts_generator_signal(self):
-        self.logger.debug("Starting TTS stream...")
-        self.accept_message = True
-        self.stream.start()
+        if self.application_settings.tts_enabled:
+            self.logger.debug("Starting TTS stream...")
+            self.accept_message = True
+            self.stream.start()
 
     def start_stream(self):
         if sd.query_devices(kind='output'):
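Note that handle_message only splits at a comma when both sides carry at least three words, which prevents choppy audio on short clauses. Its timestamp substitution also differs from the test expectations above: it replaces the colon with a space ("09:30" becomes "09 30") rather than joining the digits, since the substitution and _split_text_at_punctuation are separate code paths. A condensed form of the comma rule (illustrative):

    def should_split_at_comma(before: str, after: str) -> bool:
        # Mirrors the guard in handle_message: require at least three
        # words on each side before treating a comma as a break point.
        return len(before.split()) >= 3 and len(after.split()) >= 3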