From 5586bd9925e9f6cf8c23e5664dd6105c0382b732 Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Sat, 12 Oct 2024 08:08:01 -0600 Subject: [PATCH 1/8] strip emoji and add more tests --- .../handlers/tts/speecht5_tts_handler.py | 45 ++++++++++---- .../tests/test_speecht5_tts_handler.py | 59 +++++++++++++++++++ 2 files changed, 94 insertions(+), 10 deletions(-) diff --git a/src/airunner/handlers/tts/speecht5_tts_handler.py b/src/airunner/handlers/tts/speecht5_tts_handler.py index 6a0777c83..c777e771f 100644 --- a/src/airunner/handlers/tts/speecht5_tts_handler.py +++ b/src/airunner/handlers/tts/speecht5_tts_handler.py @@ -234,10 +234,7 @@ def _unload_speaker_embeddings(self): def _do_generate(self, message): self.logger.debug("Generating text-to-speech with T5") - text = self._replace_unspeakable_characters(message) - text = self._roman_to_int(text) - text = self._replace_numbers_with_words(text) - text = text.strip() + text = self._prepare_text(message) if text == "": return None @@ -300,9 +297,18 @@ def _move_inputs_to_device(self, inputs): self.logger.error(e) return inputs + def _prepare_text(self, text) -> str: + text = self._replace_unspeakable_characters(text) + text = self._strip_emoji_characters(text) + text = self._roman_to_int(text) + text = self._replace_numbers_with_words(text) + text = re.sub(r"\s+", " ", text) # Remove extra spaces + text = text.strip() + return text + @staticmethod - def _replace_unspeakable_characters(text): - # strip things like eplisis, etc + def _replace_unspeakable_characters(text) -> str: + # strip things like ellipsis, etc text = text.replace("...", " ") text = text.replace("…", " ") text = text.replace("’", "") @@ -325,12 +331,31 @@ def _replace_unspeakable_characters(text): # replace tabs text = text.replace("\t", " ") - # replace excessive spaces - text = re.sub(r"\s+", " ", text) return text @staticmethod - def _roman_to_int(text): + def _strip_emoji_characters(text) -> str: + # strip emojis + emoji_pattern = re.compile( + "[" + "\U0001F600-\U0001F64F" # emoticons + "\U0001F300-\U0001F5FF" # symbols & pictographs + "\U0001F680-\U0001F6FF" # transport & map symbols + "\U0001F700-\U0001F77F" # alchemical symbols + "\U0001F780-\U0001F7FF" # Geometric Shapes Extended + "\U0001F800-\U0001F8FF" # Supplemental Arrows-C + "\U0001F900-\U0001F9FF" # Supplemental Symbols and Pictographs + "\U0001FA00-\U0001FA6F" # Chess Symbols + "\U0001FA70-\U0001FAFF" # Symbols and Pictographs Extended-A + "\U00002702-\U000027B0" # Dingbats + "\U000024C2-\U0001F251" + "]+", flags=re.UNICODE + ) + text = emoji_pattern.sub(r'', text) + return text + + @staticmethod + def _roman_to_int(text) -> str: roman_numerals = { 'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000 } @@ -352,7 +377,7 @@ def convert_roman_to_int(roman): return result @staticmethod - def _replace_numbers_with_words(text): + def _replace_numbers_with_words(text) -> str: p = inflect.engine() # Handle time formats separately diff --git a/src/airunner/tests/test_speecht5_tts_handler.py b/src/airunner/tests/test_speecht5_tts_handler.py index 1e710adc2..64e1bf9b7 100644 --- a/src/airunner/tests/test_speecht5_tts_handler.py +++ b/src/airunner/tests/test_speecht5_tts_handler.py @@ -50,5 +50,64 @@ def test_roman_to_int(self): with self.subTest(roman=roman, expected=expected): self.assertEqual(handler._roman_to_int(roman), expected) + def test_replace_unspeakable_characters(self): + handler = SpeechT5TTSHandler() + + # Test cases + test_cases = { + "Hello... world!": "Hello world!", + "This is an ellipsis…": "This is an ellipsis ", + "Smart quotes ‘single’ and “double”": "Smart quotes single and double", + "Em dash — and en dash –": "Em dash and en dash ", + "Tabs\tand\nnewlines\r\n": "Tabs and newlines ", + } + + for input_text, expected_output in test_cases.items(): + with self.subTest(input_text=input_text, expected_output=expected_output): + self.assertEqual(handler._replace_unspeakable_characters(input_text), expected_output) + + def test_strip_emoji_characters(self): + handler = SpeechT5TTSHandler() + + # Test cases + test_cases = { + "😊": "", + "😂": "", + "👍": "", + "🏆": "", + "😊😂👍🏆": "", + "😊😊😊": "", + "😂👍🏆😊": "", + "🏆👍😂😊": "", + "👍🏆😊😂": "", + "No emojis here": "No emojis here" + } + + for input_text, expected_output in test_cases.items(): + with self.subTest(input_text=input_text, expected_output=expected_output): + self.assertEqual(handler._strip_emoji_characters(input_text), expected_output) + + def test_prepare_text(self): + handler = SpeechT5TTSHandler() + + # Test cases + test_cases = { + "Emoji 😊 should be removed": "Emoji should be removed", + "Mixed 😊 text with ‘quotes’ and — dashes": "Mixed text with quotes and dashes", + "Multiple spaces": "Multiple spaces", + "😊": "", + "Hello 😊": "Hello", + "😊😊😊": "", + "Mixed text 😊 with emoji": "Mixed text with emoji", + "Multiple emojis 😊😂👍": "Multiple emojis", + "Text with various emojis 😊😂👍🏆": "Text with various emojis", + "Emojis at the end 😊😂👍🏆": "Emojis at the end", + "No emojis here": "No emojis here" + } + + for input_text, expected_output in test_cases.items(): + with self.subTest(input_text=input_text, expected_output=expected_output): + self.assertEqual(handler._prepare_text(input_text), expected_output) + if __name__ == '__main__': unittest.main() From de1a2da01b924c189cc9c06764ececd7924a1c64 Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Sat, 12 Oct 2024 08:08:11 -0600 Subject: [PATCH 2/8] remove unused import --- src/airunner/handlers/stt/whisper_handler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/airunner/handlers/stt/whisper_handler.py b/src/airunner/handlers/stt/whisper_handler.py index d81fb6be7..f8685ce79 100644 --- a/src/airunner/handlers/stt/whisper_handler.py +++ b/src/airunner/handlers/stt/whisper_handler.py @@ -10,7 +10,6 @@ from airunner.handlers.base_handler import BaseHandler from airunner.enums import SignalCode, ModelType, ModelStatus, LLMChatRole from airunner.exceptions import NaNException -from airunner.settings import DEFAULT_STT_HF_PATH from airunner.utils.clear_memory import clear_memory From 1acb90ce3e5d8c2404a5f013ed23fc5a2f55541e Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Sat, 12 Oct 2024 08:08:39 -0600 Subject: [PATCH 3/8] use property directly without assigning variable --- src/airunner/handlers/stt/whisper_handler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/airunner/handlers/stt/whisper_handler.py b/src/airunner/handlers/stt/whisper_handler.py index f8685ce79..27f6bdb54 100644 --- a/src/airunner/handlers/stt/whisper_handler.py +++ b/src/airunner/handlers/stt/whisper_handler.py @@ -98,11 +98,10 @@ def unload(self): self.change_model_status(ModelType.STT, ModelStatus.UNLOADED) def _load_model(self): - model_path = self.model_path - self.logger.debug(f"Loading model from {model_path}") + self.logger.debug(f"Loading model from {self.model_path}") try: self._model = WhisperForConditionalGeneration.from_pretrained( - model_path, + self.model_path, local_files_only=True, torch_dtype=torch.bfloat16, device_map=self.device, From 72ee5d6f33021b617bda4c4df921f2e37b104873 Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Sat, 12 Oct 2024 08:09:10 -0600 Subject: [PATCH 4/8] remove debugging statement to quiet down logs --- src/airunner/workers/tts_generator_worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/airunner/workers/tts_generator_worker.py b/src/airunner/workers/tts_generator_worker.py index 8f07e248e..c3853c175 100644 --- a/src/airunner/workers/tts_generator_worker.py +++ b/src/airunner/workers/tts_generator_worker.py @@ -92,7 +92,6 @@ def handle_message(self, data): return # Add the incoming tokens to the list - self.logger.debug("Adding tokens to list...") self.tokens.extend(data["message"]) finalize = data.get("finalize", False) From cd0e00087450330eb2d1f9219b827488fb5728be Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Sat, 12 Oct 2024 08:12:44 -0600 Subject: [PATCH 5/8] rearranged functions for readability --- .../handlers/tts/speecht5_tts_handler.py | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/src/airunner/handlers/tts/speecht5_tts_handler.py b/src/airunner/handlers/tts/speecht5_tts_handler.py index c777e771f..948816e33 100644 --- a/src/airunner/handlers/tts/speecht5_tts_handler.py +++ b/src/airunner/handlers/tts/speecht5_tts_handler.py @@ -297,6 +297,34 @@ def _move_inputs_to_device(self, inputs): self.logger.error(e) return inputs + def _unload_model(self): + self._model = None + self._current_model = None + clear_memory(self.memory_settings.default_gpu_tts) + + def _unload_processor(self): + self._processor = None + clear_memory(self.memory_settings.default_gpu_tts) + + def _unload_vocoder(self): + self._vocoder = None + clear_memory(self.memory_settings.default_gpu_tts) + + def _unload_tokenizer(self): + self.tokenizer = None + clear_memory(self.memory_settings.default_gpu_tts) + + def unblock_tts_generator_signal(self): + self.logger.debug("Unblocking text-to-speech generation...") + self._do_interrupt = False + self._paused = False + + def interrupt_process_signal(self): + self._do_interrupt = True + self._cancel_generated_speech = False + self._paused = True + self._text_queue = Queue() + def _prepare_text(self, text) -> str: text = self._replace_unspeakable_characters(text) text = self._strip_emoji_characters(text) @@ -398,31 +426,3 @@ def _replace_numbers_with_words(text) -> str: result = re.sub(r'\b([AP])M\b', r'\1 M', result) return result - - def _unload_model(self): - self._model = None - self._current_model = None - clear_memory(self.memory_settings.default_gpu_tts) - - def _unload_processor(self): - self._processor = None - clear_memory(self.memory_settings.default_gpu_tts) - - def _unload_vocoder(self): - self._vocoder = None - clear_memory(self.memory_settings.default_gpu_tts) - - def _unload_tokenizer(self): - self.tokenizer = None - clear_memory(self.memory_settings.default_gpu_tts) - - def unblock_tts_generator_signal(self): - self.logger.debug("Unblocking text-to-speech generation...") - self._do_interrupt = False - self._paused = False - - def interrupt_process_signal(self): - self._do_interrupt = True - self._cancel_generated_speech = False - self._paused = True - self._text_queue = Queue() From 5499e3c8dd637255c7398084ee41520bf1710fc2 Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Sat, 12 Oct 2024 09:00:42 -0600 Subject: [PATCH 6/8] fix switch widgets --- src/airunner/widgets/canvas/input_image.py | 6 ++++++ src/airunner/widgets/switch_widget/switch_widget.py | 6 ++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/airunner/widgets/canvas/input_image.py b/src/airunner/widgets/canvas/input_image.py index 624ba9d5e..cade55ded 100644 --- a/src/airunner/widgets/canvas/input_image.py +++ b/src/airunner/widgets/canvas/input_image.py @@ -76,12 +76,18 @@ def showEvent(self, event): self.ui.link_to_grid_image_button.hide() self.ui.link_to_grid_image_button.hide() self.ui.lock_input_image_button.hide() + self.ui.EnableSwitch.blockSignals(True) + self.ui.EnableSwitch.checked = self.current_settings.enabled + self.ui.EnableSwitch.setChecked(self.current_settings.enabled) + self.ui.EnableSwitch.dPtr.animate(self.current_settings.enabled) + self.ui.EnableSwitch.blockSignals(False) else: self.ui.EnableSwitch.blockSignals(True) self.ui.link_to_grid_image_button.blockSignals(True) self.ui.lock_input_image_button.blockSignals(True) self.ui.EnableSwitch.checked = self.current_settings.enabled self.ui.EnableSwitch.setChecked(self.current_settings.enabled) + self.ui.EnableSwitch.dPtr.animate(self.current_settings.enabled) self.ui.link_to_grid_image_button.setChecked( self.current_settings.use_grid_image_as_input ) diff --git a/src/airunner/widgets/switch_widget/switch_widget.py b/src/airunner/widgets/switch_widget/switch_widget.py index a8998444c..5921c6b6f 100644 --- a/src/airunner/widgets/switch_widget/switch_widget.py +++ b/src/airunner/widgets/switch_widget/switch_widget.py @@ -81,10 +81,8 @@ def __init__(self, parent=None): self.clicked.connect(self.dPtr.animate) self.clicked.connect(self.emitToggled) self._backgroundColor = QColor("blue") # Initialize the internal attribute - - # Initialize the checked state - self.setChecked(True) - self.dPtr.animate(True) + self.setChecked(False) + self.dPtr.animate(False) def emitToggled(self, checked): self.toggled.emit(checked) From 8899b28beb1296471e8e9d70f529756c6829e152 Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Sat, 12 Oct 2024 09:01:21 -0600 Subject: [PATCH 7/8] re-introduce prompt weight converter allows auto1111 style weights --- .../handlers/stablediffusion/sd_handler.py | 61 ++++--- .../tests/test_prompt_weight_convert.py | 164 ++++++++++++++++++ 2 files changed, 205 insertions(+), 20 deletions(-) create mode 100644 src/airunner/tests/test_prompt_weight_convert.py diff --git a/src/airunner/handlers/stablediffusion/sd_handler.py b/src/airunner/handlers/stablediffusion/sd_handler.py index 410e37709..ce2cd5521 100644 --- a/src/airunner/handlers/stablediffusion/sd_handler.py +++ b/src/airunner/handlers/stablediffusion/sd_handler.py @@ -33,6 +33,7 @@ EngineResponseCode, ModelAction ) from airunner.exceptions import PipeNotLoadedException, InterruptedException +from airunner.handlers.stablediffusion.prompt_weight_bridge import PromptWeightBridge from airunner.settings import MIN_NUM_INFERENCE_STEPS_IMG2IMG from airunner.utils.clear_memory import clear_memory from airunner.utils.convert_base64_to_image import convert_base64_to_image @@ -390,6 +391,26 @@ def _pipeline_class(self): def mask_blur(self) -> int: return self.outpaint_settings_cached.mask_blur + @property + def prompt(self): + prompt = self.generator_settings_cached.prompt + return PromptWeightBridge.convert(prompt) + + @property + def second_prompt(self): + prompt = self.generator_settings_cached.second_prompt + return PromptWeightBridge.convert(prompt) + + @property + def negative_prompt(self): + prompt = self.generator_settings_cached.negative_prompt + return PromptWeightBridge.convert(prompt) + + @property + def second_negative_prompt(self): + prompt = self.generator_settings_cached.second_negative_prompt + return PromptWeightBridge.convert(prompt) + def load_safety_checker(self): """ Public method to load the safety checker model. @@ -1330,43 +1351,43 @@ def _load_prompt_embeds(self): self.logger.debug("Loading prompt embeds") if not self.generator_settings_cached.use_compel: return - prompt = self.generator_settings_cached.prompt - negative_prompt = self.generator_settings_cached.negative_prompt - prompt_2 = self.generator_settings_cached.second_prompt - negative_prompt_2 = self.generator_settings_cached.second_negative_prompt + prompt = self.prompt + negative_prompt = self.negative_prompt + second_prompt = self.second_prompt + second_negative_prompt = self.second_negative_prompt if ( self._current_prompt != prompt or self._current_negative_prompt != negative_prompt - or self._current_prompt_2 != prompt_2 - or self._current_negative_prompt_2 != negative_prompt_2 + or self._current_prompt_2 != second_prompt + or self._current_negative_prompt_2 != second_negative_prompt ): self._unload_latents() self._current_prompt = prompt self._current_negative_prompt = negative_prompt - self._current_prompt_2 = prompt_2 - self._current_negative_prompt_2 = negative_prompt_2 + self._current_prompt_2 = second_prompt + self._current_negative_prompt_2 = second_negative_prompt self._unload_prompt_embeds() pooled_prompt_embeds = None negative_pooled_prompt_embeds = None - if prompt != "" and prompt_2 != "": - compel_prompt = f'("{prompt}", "{prompt_2}").and()' - elif prompt != "" and prompt_2 == "": + if prompt != "" and second_prompt != "": + compel_prompt = f'("{prompt}", "{second_prompt}").and()' + elif prompt != "" and second_prompt == "": compel_prompt = prompt - elif prompt == "" and prompt_2 != "": - compel_prompt = prompt_2 + elif prompt == "" and second_prompt != "": + compel_prompt = second_prompt else: compel_prompt = "" - if negative_prompt != "" and negative_prompt_2 != "": - compel_negative_prompt = f'("{negative_prompt}", "{negative_prompt_2}").and()' - elif negative_prompt != "" and negative_prompt_2 == "": + if negative_prompt != "" and second_negative_prompt != "": + compel_negative_prompt = f'("{negative_prompt}", "{second_negative_prompt}").and()' + elif negative_prompt != "" and second_negative_prompt == "": compel_negative_prompt = negative_prompt - elif negative_prompt == "" and negative_prompt_2 != "": - compel_negative_prompt = negative_prompt_2 + elif negative_prompt == "" and second_negative_prompt != "": + compel_negative_prompt = second_negative_prompt else: compel_negative_prompt = "" @@ -1442,8 +1463,8 @@ def _prepare_data(self, active_rect = None) -> dict: )) else: args.update(dict( - prompt=self.generator_settings_cached.prompt, - negative_prompt=self.generator_settings_cached.negative_prompt + prompt=self.prompt, + negative_prompt=self.negative_prompt )) width = int(self.application_settings_cached.working_width) diff --git a/src/airunner/tests/test_prompt_weight_convert.py b/src/airunner/tests/test_prompt_weight_convert.py new file mode 100644 index 000000000..1428d549c --- /dev/null +++ b/src/airunner/tests/test_prompt_weight_convert.py @@ -0,0 +1,164 @@ +import unittest + +from airunner.handlers.stablediffusion.prompt_weight_bridge import PromptWeightBridge + + +class TestPromptWeightConvert(unittest.TestCase): + def test_simple(self): + prompt = "Example (a GHI:1.4)" + expected_prompt = "Example (a GHI)1.4" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + def test_use_case_a(self): + prompt = "Example (ABC): 1.23 XYZ (DEF) (GHI:1.4)" + expected_prompt = "Example (ABC)1.1: 1.23 XYZ (DEF)1.1 (GHI)1.4" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + # + def test_use_case_b(self): + prompt = "(a dog:0.5) and a cat" + expected_prompt = "(a dog)0.5 and a cat" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + def test_use_case_c(self): + prompt = "A perfect photo of a woman wearing a respirator wandering through the (toxic wasteland:1.3)" + expected_prompt = "A perfect photo of a woman wearing a respirator wandering through the (toxic wasteland)1.3" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + def test_use_case_d(self): + prompt = "(worst quality:0.8), fantasy, cartoon, halftone print, (cinematic:1.2), verybadimagenegative_v1.3, easynegative, (surreal:0.8), (modernism:0.8), (art deco:0.8), (art nouveau:0.8)" + expected_prompt = "(worst quality)0.8, fantasy, cartoon, halftone print, (cinematic)1.2, verybadimagenegative_v1.3, easynegative, (surreal)0.8, (modernism)0.8, (art deco)0.8, (art nouveau)0.8" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + def test_convert_basic_parentheses(self): + prompt = "(a hammer) and a cat in a car" + expected_prompt = "(a hammer)1.1 and a cat in a car" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a foo and (a cat) in a car" + expected_prompt = "a foo and (a cat)1.1 in a car" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "(a bar) and (a cat) in a car" + expected_prompt = "(a bar)1.1 and (a cat)1.1 in a car" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "(a baz) and (a cat) in (a car)" + expected_prompt = "(a baz)1.1 and (a cat)1.1 in (a car)1.1" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "((a car)) and a cat" + expected_prompt = "(a car)1.21 and a cat" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "(((((((((((a car))))))))))) and a cat" + expected_prompt = "(a car)1.4 and a cat" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "(((((((((((a car))))))))))) and (((a cat)))" + expected_prompt = "(a car)1.4 and (a cat)1.33" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + def test_convert_basic_brackets(self): + prompt = "[a hammer] and a cat in a car" + expected_prompt = "(a hammer)0.9 and a cat in a car" + self.assertEqual(PromptWeightBridge.convert_basic_brackets(prompt), expected_prompt) + + prompt = "a foo and [a cat] in a car" + expected_prompt = "a foo and (a cat)0.9 in a car" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "[a bar] and [a cat] in a car" + expected_prompt = "(a bar)0.9 and (a cat)0.9 in a car" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "[a baz] and [a cat] in [a car]" + expected_prompt = "(a baz)0.9 and (a cat)0.9 in (a car)0.9" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "[[a car]] and a cat" + expected_prompt = "(a car)0.79 and a cat" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "[[[[[[[[[[[a car]]]]]]]]]]] and a cat" + expected_prompt = "(a car)0.0 and a cat" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "[[[[[[[[[[[a car]]]]]]]]]]] and [[[a cat]]]" + expected_prompt = "(a car)0.0 and (a cat)0.67" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + def test_convert_prompt_with_weight_value(self): + prompt = "(a bird:0.5) and a plane" + expected_prompt = "(a bird)0.5 and a plane" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a dog and (a cat:0.6)" + expected_prompt = "a dog and (a cat)0.6" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "(a boat:0.5) and (a ship:0.6)" + expected_prompt = "(a boat)0.5 and (a ship)0.6" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "(a man:0.5) and a woman" + expected_prompt = "(a man)0.5 and a woman" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + def test_compel_prompt_weight(self): + prompt = "(a asdf)1.4 and a cat in a car" + expected_prompt = "(a asdf)1.4 and a cat in a car" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a man (eating an apple)++++" + expected_prompt = "a man (eating an apple)++++" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a man (eating fruit)+" + expected_prompt = "a man (eating fruit)+" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a man (eating a bannana)----" + expected_prompt = "a man (eating a bannana)----" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a man (eating bread)-" + expected_prompt = "a man (eating bread)-" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = '("blue sphere", "red cube").blend(0.25,0.75)' + expected_prompt = prompt + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + def test_mixed_prompt(self): + prompt = "a man (drinking (water))0.5" + expected_prompt = "a man (drinking (water)1.1)0.5" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a man (drinking ((milk)))0.5" + expected_prompt = "a man (drinking (milk)1.21)0.5" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a man (drinking+ ((juice)))0.5" + expected_prompt = "a man (drinking+ (juice)1.21)0.5" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a man ((drinking:1.1) ((beer)))0.5" + expected_prompt = "a man ((drinking)1.1 (beer)1.21)0.5" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "a fox (drinking) water in a [bar] from (the country) of [canada]" + expected_prompt = "a fox (drinking)1.1 water in a (bar)0.9 from (the country)1.1 of (canada)0.9" + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + def test_blend_conversion(self): + prompt = "A [frog:turtle:0.1] on a leaf in the forest" + expected_prompt = 'A ("frog", "turtle").blend(0.1, 0.9) on a leaf in the forest' + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "A [frog:turtle:0.5] on a leaf in the forest" + expected_prompt = 'A ("frog", "turtle").blend(0.5, 0.5) on a leaf in the forest' + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) + + prompt = "A (frog:turtle:0.1) on a leaf in the forest" + expected_prompt = 'A ("frog", "turtle").blend(0.9, 0.1) on a leaf in the forest' + self.assertEqual(PromptWeightBridge.convert(prompt), expected_prompt) From 2a867baf660fa0beed30e0b810c4c62f63b5240a Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Sat, 12 Oct 2024 15:08:58 -0600 Subject: [PATCH 8/8] Improvements to - real time voice conversations - text processing for llm - interrupt signal --- src/airunner/enums.py | 1 + src/airunner/handlers/llm/agent/base_agent.py | 27 ++-- .../llm/causal_lm_transformer_base_handler.py | 6 +- src/airunner/handlers/stt/whisper_handler.py | 146 ++++++++---------- .../handlers/tts/speecht5_tts_handler.py | 9 +- .../tests/test_speecht5_tts_handler.py | 2 +- .../tests/test_tts_generator_worker.py | 25 +++ src/airunner/utils/get_torch_device.py | 2 + .../widgets/llm/chat_prompt_widget.py | 11 +- .../widgets/tool_tab/tool_tab_widget.py | 6 - src/airunner/windows/main/main_window.py | 103 ++++++------ src/airunner/worker_manager.py | 42 +---- src/airunner/workers/audio_capture_worker.py | 18 ++- src/airunner/workers/llm_generate_worker.py | 10 +- src/airunner/workers/tts_generator_worker.py | 27 +++- src/airunner/workers/tts_vocalizer_worker.py | 7 +- 16 files changed, 210 insertions(+), 232 deletions(-) create mode 100644 src/airunner/tests/test_tts_generator_worker.py diff --git a/src/airunner/enums.py b/src/airunner/enums.py index 2016054ca..22296306b 100644 --- a/src/airunner/enums.py +++ b/src/airunner/enums.py @@ -352,6 +352,7 @@ class LLMActionType(Enum): TOGGLE_TTS = "TOGGLE TEXT-TO-SPEECH: If the user requests that you turn on or off or toggle text-to-speech, choose this action." PERFORM_RAG_SEARCH = "SEARCH: If the user requests that you search for information, choose this action." SUMMARIZE = "SUMMARIZE" + DO_NOTHING = "DO NOTHING: If the user's request is unclear or you are unable to determine the user's intent, choose this action." diff --git a/src/airunner/handlers/llm/agent/base_agent.py b/src/airunner/handlers/llm/agent/base_agent.py index 2d14c2158..5c255a151 100644 --- a/src/airunner/handlers/llm/agent/base_agent.py +++ b/src/airunner/handlers/llm/agent/base_agent.py @@ -16,6 +16,7 @@ from llama_index.core.chat_engine import ContextChatEngine from llama_index.core import SimpleKeywordTableIndex from llama_index.core.indices.keyword_table import KeywordTableSimpleRetriever +from transformers import TextIteratorStreamer from airunner.handlers.llm.huggingface_llm import HuggingFaceLLM from airunner.handlers.llm.custom_embedding import CustomEmbedding @@ -82,7 +83,7 @@ def __init__(self, *args, **kwargs): self.action = LLMActionType.CHAT self.rendered_template = None self.tokenizer = kwargs.pop("tokenizer", None) - self.streamer = kwargs.pop("streamer", None) + self.streamer = TextIteratorStreamer(self.tokenizer) self.chat_template = kwargs.pop("chat_template", "") self.is_mistral = kwargs.pop("is_mistral", True) self.conversation_id = None @@ -97,12 +98,11 @@ def __init__(self, *args, **kwargs): @property def available_actions(self): return { - 0: LLMActionType.QUIT_APPLICATION, - 1: LLMActionType.TOGGLE_FULLSCREEN, - 2: LLMActionType.TOGGLE_TTS, - 3: LLMActionType.GENERATE_IMAGE, - 4: LLMActionType.PERFORM_RAG_SEARCH, - 5: LLMActionType.CHAT, + 0: LLMActionType.TOGGLE_FULLSCREEN, + 1: LLMActionType.TOGGLE_TTS, + 2: LLMActionType.GENERATE_IMAGE, + 3: LLMActionType.PERFORM_RAG_SEARCH, + 4: LLMActionType.CHAT, } @property @@ -163,7 +163,7 @@ def interrupt_process(self): def do_interrupt_process(self): interrupt = self.do_interrupt - self.do_interrupt = False + self.streamer = TextIteratorStreamer(self.tokenizer) return interrupt @property @@ -303,9 +303,7 @@ def build_system_prompt( self.names_prompt(use_names, botname, username), self.mood(botname, bot_mood, use_mood), system_instructions, - "------\n", - "Chat History:\n", - f"{self.username}: {self.prompt}\n", + self.history_prompt(), ] elif action is LLMActionType.SUMMARIZE: @@ -502,10 +500,9 @@ def run( self.create_conversation() # Add the user's message to history - if action in ( - LLMActionType.CHAT, - LLMActionType.PERFORM_RAG_SEARCH, - LLMActionType.GENERATE_IMAGE, + if action not in ( + LLMActionType.APPLICATION_COMMAND, + LLMActionType.UPDATE_MOOD ): self.add_message_to_history(self.prompt, LLMChatRole.HUMAN) diff --git a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py index dad3e4105..655c28098 100644 --- a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py +++ b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py @@ -226,6 +226,8 @@ def clear_history(self): """ Public method to clear the chat agent history """ + if not self._chat_agent: + return self.logger.debug("Clearing chat history") self._chat_agent.clear_history() @@ -301,7 +303,6 @@ def _load_agent(self): self._chat_agent = BaseAgent( model=self._model, tokenizer=self._tokenizer, - streamer=self._streamer, chat_template=self.chat_template, is_mistral=self.is_mistral, ) @@ -378,8 +379,7 @@ def _load_model_local(self): def _do_generate(self, prompt: str, action: LLMActionType): self.logger.debug("Generating response") - model_path = self.model_path - if self._current_model_path != model_path: + if self._current_model_path != self.model_path: self.unload() self.load() if action is LLMActionType.CHAT and self.chatbot.use_mood: diff --git a/src/airunner/handlers/stt/whisper_handler.py b/src/airunner/handlers/stt/whisper_handler.py index 27f6bdb54..b4f425d9f 100644 --- a/src/airunner/handlers/stt/whisper_handler.py +++ b/src/airunner/handlers/stt/whisper_handler.py @@ -8,7 +8,7 @@ from transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor from airunner.handlers.base_handler import BaseHandler -from airunner.enums import SignalCode, ModelType, ModelStatus, LLMChatRole +from airunner.enums import SignalCode, ModelType, ModelStatus from airunner.exceptions import NaNException from airunner.utils.clear_memory import clear_memory @@ -27,6 +27,10 @@ def __init__(self, *args, **kwargs): self._feature_extractor = None self._fs = 16000 + @property + def dtype(self): + return torch.bfloat16 + @property def stt_is_loading(self): return self.model_status is ModelStatus.LOADING @@ -57,15 +61,15 @@ def process_audio(self, audio_data): # Convert the byte string to a float32 array inputs = np.frombuffer(item, dtype=np.int16) inputs = inputs.astype(np.float32) / 32767.0 + transcription = None try: transcription = self._process_inputs(inputs) except Exception as e: self.logger.error(f"Failed to process inputs {e}") self.logger.error(e) - try: - self._process_human_speech(transcription) - except ValueError as e: - self.logger.error(f"Failed to process audio {e}") + + if transcription: + self._send_transcription(transcription) def load(self): if self.stt_is_loading or self.stt_is_loaded: @@ -99,17 +103,18 @@ def unload(self): def _load_model(self): self.logger.debug(f"Loading model from {self.model_path}") + device = self.device try: self._model = WhisperForConditionalGeneration.from_pretrained( self.model_path, local_files_only=True, - torch_dtype=torch.bfloat16, - device_map=self.device, - use_safetensors=True + torch_dtype=self.dtype, + device_map=device, + use_safetensors=True, + force_download=False ) except Exception as e: - self.logger.error(f"Failed to load model") - self.logger.error(e) + self.logger.error(f"Failed to load model: {e}") return None def _load_processor(self): @@ -119,12 +124,11 @@ def _load_processor(self): self._processor = WhisperProcessor.from_pretrained( model_path, local_files_only=True, - torch_dtype=torch.bfloat16, + torch_dtype=self.dtype, device_map=self.device ) except Exception as e: - self.logger.error(f"Failed to load processor") - self.logger.error(e) + self.logger.error(f"Failed to load processor: {e}") return None def _load_feature_extractor(self): @@ -134,7 +138,7 @@ def _load_feature_extractor(self): self._feature_extractor = WhisperFeatureExtractor.from_pretrained( model_path, local_files_only=True, - torch_dtype=torch.bfloat16, + torch_dtype=self.dtype, device_map=self.device ) except Exception as e: @@ -157,59 +161,34 @@ def _unload_feature_extractor(self): self._feature_extractor = None clear_memory(self.device) - def _process_inputs( - self, - inputs: np.ndarray, - role: LLMChatRole = LLMChatRole.HUMAN, - ) -> str: - inputs = torch.from_numpy(inputs) + def _process_inputs(self, inputs: np.ndarray) -> str: + if not self._feature_extractor: + return "" + inputs = torch.from_numpy(inputs).to(torch.float32).to(self.device) if torch.isnan(inputs).any(): raise NaNException + # Move inputs to CPU and ensure they are in float32 before passing to _feature_extractor + inputs = inputs.cpu().to(torch.float32) inputs = self._feature_extractor(inputs, sampling_rate=self._fs, return_tensors="pt") - if torch.isnan(inputs.input_features).any(): - raise NaNException - inputs["input_features"] = inputs["input_features"].to(torch.bfloat16) if torch.isnan(inputs.input_features).any(): raise NaNException - inputs = inputs.to(self._model.device) + inputs["input_features"] = inputs["input_features"].to(self.dtype).to(self.device) if torch.isnan(inputs.input_features).any(): raise NaNException - transcription = self._run(inputs, role) + transcription = self._run(inputs) if transcription is None or 'nan' in transcription: raise NaNException return transcription - def _process_human_speech(self, transcription: str = None): - """ - Process the human speech. - This method is called when the model has processed the human speech - and the transcription is ready to be added to the chat history. - This should only be used for human speech. - :param transcription: - :return: - """ - if transcription == "": - raise ValueError("Transcription is empty") - self.logger.debug("Processing human speech") - data = { - "message": transcription, - "role": LLMChatRole.HUMAN - } - self.emit_signal( - SignalCode.ADD_CHATBOT_MESSAGE_SIGNAL, - data - ) - def _run( self, - inputs, - role: LLMChatRole = LLMChatRole.HUMAN, + inputs ) -> str: """ Run the model on the given inputs. @@ -231,31 +210,39 @@ def _run( if torch.isnan(input_features).any(): raise NaNException - generated_ids = self._model.generate( - input_features=input_features, - # generation_config=None, - # logits_processor=None, - # stopping_criteria=None, - # prefix_allowed_tokens_fn=None, - # synced_gpus=True, - # return_timestamps=None, - # task="transcribe", - # language="en", - # is_multilingual=True, - # prompt_ids=None, - # prompt_condition_type=None, - # condition_on_prev_tokens=None, - temperature=0.8, - # compression_ratio_threshold=None, - # logprob_threshold=None, - # no_speech_threshold=None, - # num_segment_frames=None, - # attention_mask=None, - # time_precision=0.02, - # return_token_timestamps=None, - # return_segments=False, - # return_dict_in_generate=None, - ) + try: + generated_ids = self._model.generate( + input_features=input_features, + # generation_config=None, + # logits_processor=None, + # stopping_criteria=None, + # prefix_allowed_tokens_fn=None, + # synced_gpus=True, + # return_timestamps=None, + # task="transcribe", + # language="en", + is_multilingual=False, + # prompt_ids=None, + # prompt_condition_type=None, + # condition_on_prev_tokens=None, + temperature=0.8, + compression_ratio_threshold=1.35, + logprob_threshold=-1.0, + no_speech_threshold=0.2, + # num_segment_frames=None, + # attention_mask=None, + time_precision=0.02, + # return_token_timestamps=None, + # return_segments=False, + # return_dict_in_generate=None, + ) + except RuntimeError as e: + generated_ids = None + self.logger.error(f"Error in model generation: {e}") + + if generated_ids is None: + return "" + if torch.isnan(generated_ids).any(): raise NaNException @@ -263,16 +250,19 @@ def _run( if len(transcription) == 0 or len(transcription.split(" ")) == 1: return "" - # Emit the transcription so that other handlers can use it + return transcription + + def _send_transcription(self, transcription: str): + """ + Emit the transcription so that other handlers can use it + """ self.emit_signal(SignalCode.AUDIO_PROCESSOR_RESPONSE_SIGNAL, { - "transcription": transcription, - "role": role + "transcription": transcription }) - return transcription - def process_transcription(self, generated_ids) -> str: # Decode the generated ids + generated_ids = generated_ids.to("cpu").to(torch.float32) transcription = self._processor.batch_decode( generated_ids, skip_special_tokens=True diff --git a/src/airunner/handlers/tts/speecht5_tts_handler.py b/src/airunner/handlers/tts/speecht5_tts_handler.py index 948816e33..c89761bc2 100644 --- a/src/airunner/handlers/tts/speecht5_tts_handler.py +++ b/src/airunner/handlers/tts/speecht5_tts_handler.py @@ -328,10 +328,10 @@ def interrupt_process_signal(self): def _prepare_text(self, text) -> str: text = self._replace_unspeakable_characters(text) text = self._strip_emoji_characters(text) - text = self._roman_to_int(text) + # the following function is currently disabled because we must first find a + # reliable way to handle the word "I" and distinguish it from the Roman numeral "I" + # text = self._roman_to_int(text) text = self._replace_numbers_with_words(text) - text = re.sub(r"\s+", " ", text) # Remove extra spaces - text = text.strip() return text @staticmethod @@ -339,13 +339,12 @@ def _replace_unspeakable_characters(text) -> str: # strip things like ellipsis, etc text = text.replace("...", " ") text = text.replace("…", " ") - text = text.replace("’", "") + text = text.replace("’", "'") text = text.replace("“", "") text = text.replace("”", "") text = text.replace("‘", "") text = text.replace("–", "") text = text.replace("—", "") - text = text.replace("'", "") text = text.replace('"', "") text = text.replace("-", "") text = text.replace("-", "") diff --git a/src/airunner/tests/test_speecht5_tts_handler.py b/src/airunner/tests/test_speecht5_tts_handler.py index 64e1bf9b7..2985355dd 100644 --- a/src/airunner/tests/test_speecht5_tts_handler.py +++ b/src/airunner/tests/test_speecht5_tts_handler.py @@ -43,7 +43,7 @@ def test_roman_to_int(self): "M": "1000", "MMXXI": "2021", "This is a IV test": "This is a 4 test", - "A test with no roman numerals": "A test with no roman numerals" + "A test with no roman numerals": "A test with no roman numerals", } for roman, expected in test_cases.items(): diff --git a/src/airunner/tests/test_tts_generator_worker.py b/src/airunner/tests/test_tts_generator_worker.py new file mode 100644 index 000000000..f0cccff5b --- /dev/null +++ b/src/airunner/tests/test_tts_generator_worker.py @@ -0,0 +1,25 @@ +import unittest +from airunner.workers.tts_generator_worker import TTSGeneratorWorker + +class TestTTSGeneratorWorker(unittest.TestCase): + + def test_split_text_at_punctuation(self): + test_cases = [ + ("Hello world.", ["Hello world"]), + ("Hello world. How are you?", ["Hello world", "How are you"]), + ("Hello! How are you? I'm fine.", ["Hello", "How are you", "I'm fine"]), + ("No punctuation here", ["No punctuation here"]), + ("Multiple\nlines\nhere", ["Multiple", "lines", "here"]), + ("Comma, separated, values", ["Comma", "separated", "values"]), + ("Mixed punctuation! Really? Yes.", ["Mixed punctuation", "Really", "Yes"]), + ("The time is 12:45.", ["The time is 1245"]), + ("Meet me at 09:30 AM.", ["Meet me at 0930 AM"]), + ("It happened at 23:59:59.", ["It happened at 235959"]), + ] + + for text, expected_chunks in test_cases: + with self.subTest(text=text, expected_chunks=expected_chunks): + self.assertEqual(TTSGeneratorWorker._split_text_at_punctuation(text), expected_chunks) + +if __name__ == '__main__': + unittest.main() diff --git a/src/airunner/utils/get_torch_device.py b/src/airunner/utils/get_torch_device.py index 92d37623b..51b3235c5 100644 --- a/src/airunner/utils/get_torch_device.py +++ b/src/airunner/utils/get_torch_device.py @@ -3,4 +3,6 @@ def get_torch_device(card_index: int = 0): use_cuda = torch.cuda.is_available() + if not use_cuda: + print("WARNING: CUDA NOT AVAILABLE, USING CPU") return torch.device(f"cuda:{card_index}" if use_cuda else "cpu") diff --git a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py index 293692881..2503cd5f2 100644 --- a/src/airunner/widgets/llm/chat_prompt_widget.py +++ b/src/airunner/widgets/llm/chat_prompt_widget.py @@ -98,8 +98,8 @@ def _set_conversation_widgets(self, messages): def on_hear_signal(self, data: dict): transcription = data["transcription"] - self.respond_to_voice(transcription) self.ui.prompt.setPlainText(transcription) + self.ui.send_button.click() def on_add_to_conversation_signal(self, name, text, is_bot): self.add_message_to_conversation(name=name, message=text, is_bot=is_bot) @@ -161,6 +161,9 @@ def action_button_clicked_send(self): def interrupt_button_clicked(self): self.emit_signal(SignalCode.INTERRUPT_PROCESS_SIGNAL) + self.stop_progress_bar() + self.generating = False + self.enable_send_button() @property def action(self) -> str: @@ -289,12 +292,6 @@ def display_action_menu(self): def insert_newline(self): self.ui.prompt.insertPlainText("\n") - - def respond_to_voice(self, transcript: str): - transcript = transcript.strip() - if transcript == "." or transcript is None or transcript == "": - return - self.do_generate(prompt_override=transcript) def describe_image(self, image, callback): self.do_generate( diff --git a/src/airunner/widgets/tool_tab/tool_tab_widget.py b/src/airunner/widgets/tool_tab/tool_tab_widget.py index 439698f4b..383e29e25 100644 --- a/src/airunner/widgets/tool_tab/tool_tab_widget.py +++ b/src/airunner/widgets/tool_tab/tool_tab_widget.py @@ -10,9 +10,3 @@ class ToolTabWidget(BaseWidget): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.__application_settings = QSettings(ORGANIZATION, APPLICATION_NAME) - - def showEvent(self, event): - self.ui.tool_tab_widget_container.setCurrentIndex( - int(self.__application_settings.value("tool_tab_widget_index", defaultValue=0)) - ) diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py index 2e870f814..3af0687fe 100644 --- a/src/airunner/windows/main/main_window.py +++ b/src/airunner/windows/main/main_window.py @@ -49,6 +49,14 @@ ) from airunner.styles_mixin import StylesMixin from airunner.utils.convert_image_to_base64 import convert_image_to_base64 +from airunner.utils.create_worker import create_worker +from airunner.workers.audio_capture_worker import AudioCaptureWorker +from airunner.workers.audio_processor_worker import AudioProcessorWorker +from airunner.workers.llm_generate_worker import LLMGenerateWorker +from airunner.workers.mask_generator_worker import MaskGeneratorWorker +from airunner.workers.sd_worker import SDWorker +from airunner.workers.tts_generator_worker import TTSGeneratorWorker +from airunner.workers.tts_vocalizer_worker import TTSVocalizerWorker from airunner.utils.get_version import get_version from airunner.utils.set_widget_state import set_widget_state @@ -64,7 +72,6 @@ from airunner.windows.prompt_browser.prompt_browser import PromptBrowser from airunner.windows.settings.airunner_settings import SettingsWindow from airunner.windows.update.update_window import UpdateWindow -from airunner.worker_manager import WorkerManager class MainWindow( @@ -104,23 +111,10 @@ class MainWindow( def __init__( self, *args, - disable_sd: bool = False, - disable_llm: bool = False, - disable_tts: bool = False, - disable_stt: bool = False, - use_cuda: bool = True, - tts_enabled: bool = False, - stt_enabled: bool = False, - ai_mode: bool = True, defendatron=None, **kwargs ): self.ui = self.ui_class_() - self.disable_sd = disable_sd - self.disable_llm = disable_llm - self.disable_tts = disable_tts - self.disable_stt = disable_stt - self.defendatron = defendatron self.quitting = False self.update_popup = None @@ -146,7 +140,6 @@ def __init__( self.status_error_color = STATUS_ERROR_COLOR self.status_normal_color_light = STATUS_NORMAL_COLOR_LIGHT self.status_normal_color_dark = STATUS_NORMAL_COLOR_DARK - self.is_started = False self._themes = None self.button_clicked_signal = Signal(dict) self.status_widget = None @@ -158,35 +151,19 @@ def __init__( self.listening = False self.initialized = False self._model_status = {model_type: ModelStatus.UNLOADED for model_type in ModelType} - self.logger = Logger(prefix=self.__class__.__name__) self.logger.debug("Starting AI Runnner") MediatorMixin.__init__(self) SettingsMixin.__init__(self) - super().__init__(*args, **kwargs) - self._updating_settings = True - self.__application_settings = QSettings(ORGANIZATION, APPLICATION_NAME) - PipelineMixin.__init__(self) AIModelMixin.__init__(self) self._updating_settings = False - + self._worker_manager = None self.register_signals() - self.initialize_ui() - self.worker_manager = None - self.is_started = True - self.image_window = None - - for item in ( - (SignalCode.AI_MODELS_SAVE_OR_UPDATE_SIGNAL, self.on_ai_models_save_or_update_signal), - (SignalCode.NAVIGATE_TO_URL, self.on_navigate_to_url), - ): - self.register(item[0], item[1]) - - self.emit_signal(SignalCode.APPLICATION_MAIN_WINDOW_LOADED_SIGNAL, { "main_window": self }) + self._initialize_workers() @property def generator_tab_widget(self): @@ -536,25 +513,30 @@ def show_layers(self): def register_signals(self): self.logger.debug("Connecting signals") - self.register(SignalCode.SD_SAVE_PROMPT_SIGNAL, self.on_save_stablediffusion_prompt_signal) - self.register(SignalCode.QUIT_APPLICATION, self.action_quit_triggered) - self.register(SignalCode.SD_NSFW_CONTENT_DETECTED_SIGNAL, self.on_nsfw_content_detected_signal) - self.register(SignalCode.ENABLE_BRUSH_TOOL_SIGNAL, lambda _message: self.action_toggle_brush(True)) - self.register(SignalCode.ENABLE_ERASER_TOOL_SIGNAL, lambda _message: self.action_toggle_eraser(True)) - self.register(SignalCode.ENABLE_SELECTION_TOOL_SIGNAL, lambda _message: self.action_toggle_select(True)) - self.register(SignalCode.ENABLE_MOVE_TOOL_SIGNAL, lambda _message: self.action_toggle_active_grid_area(True)) - self.register(SignalCode.BASH_EXECUTE_SIGNAL, self.on_bash_execute_signal) - self.register(SignalCode.WRITE_FILE, self.on_write_file_signal) - self.register(SignalCode.TOGGLE_FULLSCREEN_SIGNAL, self.on_toggle_fullscreen_signal) - self.register(SignalCode.TOGGLE_TTS_SIGNAL, self.on_toggle_tts) - self.register(SignalCode.TOGGLE_SD_SIGNAL, self.on_toggle_sd) - self.register(SignalCode.TOGGLE_LLM_SIGNAL, self.on_toggle_llm) - self.register(SignalCode.APPLICATION_RESET_SETTINGS_SIGNAL, self.action_reset_settings) - self.register(SignalCode.APPLICATION_RESET_PATHS_SIGNAL, self.on_reset_paths_signal) - self.register(SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal) - self.register(SignalCode.KEYBOARD_SHORTCUTS_UPDATED, self.on_keyboard_shortcuts_updated) - self.register(SignalCode.HISTORY_UPDATED, self.on_history_updated), - self.register(SignalCode.REFRESH_STYLESHEET_SIGNAL, self.on_theme_changed_signal) + for item in ( + (SignalCode.SD_SAVE_PROMPT_SIGNAL, self.on_save_stablediffusion_prompt_signal), + (SignalCode.QUIT_APPLICATION, self.action_quit_triggered), + (SignalCode.SD_NSFW_CONTENT_DETECTED_SIGNAL, self.on_nsfw_content_detected_signal), + (SignalCode.ENABLE_BRUSH_TOOL_SIGNAL, lambda _message: self.action_toggle_brush(True)), + (SignalCode.ENABLE_ERASER_TOOL_SIGNAL, lambda _message: self.action_toggle_eraser(True)), + (SignalCode.ENABLE_SELECTION_TOOL_SIGNAL, lambda _message: self.action_toggle_select(True)), + (SignalCode.ENABLE_MOVE_TOOL_SIGNAL, lambda _message: self.action_toggle_active_grid_area(True)), + (SignalCode.BASH_EXECUTE_SIGNAL, self.on_bash_execute_signal), + (SignalCode.WRITE_FILE, self.on_write_file_signal), + (SignalCode.TOGGLE_FULLSCREEN_SIGNAL, self.on_toggle_fullscreen_signal), + (SignalCode.TOGGLE_TTS_SIGNAL, self.on_toggle_tts), + (SignalCode.TOGGLE_SD_SIGNAL, self.on_toggle_sd), + (SignalCode.TOGGLE_LLM_SIGNAL, self.on_toggle_llm), + (SignalCode.APPLICATION_RESET_SETTINGS_SIGNAL, self.action_reset_settings), + (SignalCode.APPLICATION_RESET_PATHS_SIGNAL, self.on_reset_paths_signal), + (SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal), + (SignalCode.KEYBOARD_SHORTCUTS_UPDATED, self.on_keyboard_shortcuts_updated), + (SignalCode.HISTORY_UPDATED, self.on_history_updated), + (SignalCode.REFRESH_STYLESHEET_SIGNAL, self.on_theme_changed_signal), + (SignalCode.AI_MODELS_SAVE_OR_UPDATE_SIGNAL, self.on_ai_models_save_or_update_signal), + (SignalCode.NAVIGATE_TO_URL, self.on_navigate_to_url), + ): + self.register(item[0], item[1]) def on_reset_paths_signal(self): self.reset_path_settings() @@ -607,6 +589,7 @@ def initialize_ui(self): self.initialize_widget_elements() self.ui.actionUndo.setEnabled(False) self.ui.actionRedo.setEnabled(False) + self.emit_signal(SignalCode.APPLICATION_MAIN_WINDOW_LOADED_SIGNAL, {"main_window": self}) def initialize_widget_elements(self): for item in ( @@ -935,7 +918,6 @@ def showEvent(self, event): icon_data[1], "dark" if self.application_settings.dark_mode_enabled else "light" ) - self._initialize_worker_manager() self.logger.debug("Showing window") self._set_keyboard_shortcuts() @@ -981,14 +963,15 @@ def _set_keyboard_shortcuts(self): session.close() - def _initialize_worker_manager(self): + def _initialize_workers(self): self.logger.debug("Initializing worker manager") - self.worker_manager = WorkerManager( - disable_sd=self.disable_sd, - disable_llm=self.disable_llm, - disable_tts=self.disable_tts, - disable_stt=self.disable_stt - ) + self._mask_generator_worker = create_worker(MaskGeneratorWorker) + self._sd_worker = create_worker(SDWorker) + self._stt_audio_capture_worker = create_worker(AudioCaptureWorker) + self._stt_audio_processor_worker = create_worker(AudioProcessorWorker) + self._tts_generator_worker = create_worker(TTSGeneratorWorker) + self._tts_vocalizer_worker = create_worker(TTSVocalizerWorker) + self._llm_generate_worker = create_worker(LLMGenerateWorker) def _initialize_filter_actions(self): # add more filters: diff --git a/src/airunner/worker_manager.py b/src/airunner/worker_manager.py index 8295de332..fe0d7dbea 100644 --- a/src/airunner/worker_manager.py +++ b/src/airunner/worker_manager.py @@ -22,14 +22,7 @@ class WorkerManager(QObject, MediatorMixin, SettingsMixin): request_signal_status = Signal(str) image_generated_signal = Signal(dict) - def __init__( - self, - disable_sd: bool = False, - disable_llm: bool = False, - disable_tts: bool = False, - disable_stt: bool = False, - agent_options: dict = None - ): + def __init__(self): MediatorMixin.__init__(self) SettingsMixin.__init__(self) super().__init__() @@ -42,32 +35,7 @@ def __init__( self._stt_audio_capture_worker = None self._stt_audio_processor_worker = None - self.agent_options = agent_options - - if not disable_sd: - self.register_sd_workers() - - if not disable_llm: - self.register_llm_workers(self.agent_options) - - if not disable_tts: - self.register_tts_workers() - - if not disable_stt: - self.register_stt_workers() - - self.mask_generator_worker = create_worker(MaskGeneratorWorker) - - def register_sd_workers(self): - self._sd_worker = create_worker(SDWorker) - - def register_llm_workers(self, agent_options): - self._llm_generate_worker = create_worker(LLMGenerateWorker, agent_options=agent_options) - - def register_tts_workers(self): - self._tts_generator_worker = create_worker(TTSGeneratorWorker) - self._tts_vocalizer_worker = create_worker(TTSVocalizerWorker) - - def register_stt_workers(self): - self._stt_audio_capture_worker = create_worker(AudioCaptureWorker) - self._stt_audio_processor_worker = create_worker(AudioProcessorWorker) + self.register_sd_workers() + self.register_llm_workers() + self.register_tts_workers() + self.register_stt_workers() diff --git a/src/airunner/workers/audio_capture_worker.py b/src/airunner/workers/audio_capture_worker.py index 4f05e055e..e4a9ec029 100644 --- a/src/airunner/workers/audio_capture_worker.py +++ b/src/airunner/workers/audio_capture_worker.py @@ -1,5 +1,4 @@ import queue -import threading import time import sounddevice as sd @@ -30,7 +29,9 @@ def __init__(self): self.stream = None self.running = False self._audio_process_queue = queue.Queue() - self._capture_thread = None + #self._capture_thread = None + if self.application_settings.stt_enabled: + self._start_listening() def on_AudioCaptureWorker_response_signal(self, message: dict): item: np.ndarray = message["item"] @@ -38,13 +39,11 @@ def on_AudioCaptureWorker_response_signal(self, message: dict): self.add_to_queue(item) def on_stt_start_capture_signal(self): - if self._capture_thread is not None and self._capture_thread.is_alive(): - return - self._capture_thread = threading.Thread(target=self._start_listening) - self._capture_thread.start() + if not self.listening: + self._start_listening() def on_stt_stop_capture_signal(self): - if self._capture_thread is not None and self._capture_thread.is_alive(): + if self.listening: self._stop_listening() def start(self): @@ -62,9 +61,11 @@ def start(self): try: chunk, overflowed = self.stream.read(int(chunk_duration * fs)) except sd.PortAudioError as e: + self.logger.error(f"PortAudioError: {e}") QThread.msleep(SLEEP_TIME_IN_MS) continue if np.max(np.abs(chunk)) > volume_input_threshold: # check if chunk is not silence + self.logger.debug("Heard voice") is_receiving_input = True self.emit_signal(SignalCode.INTERRUPT_PROCESS_SIGNAL) voice_input_start_time = time.time() @@ -73,6 +74,7 @@ def start(self): end_time = voice_input_start_time + silence_buffer_seconds if time.time() >= end_time: if len(recording) > 0: + self.logger.debug("Sending audio to audio_processor_worker") self.emit_signal( SignalCode.AUDIO_CAPTURE_WORKER_RESPONSE_SIGNAL, { @@ -113,4 +115,4 @@ def _stop_listening(self): self.stream.close() except Exception as e: self.logger.error(e) - self._capture_thread.join() + # self._capture_thread.join() diff --git a/src/airunner/workers/llm_generate_worker.py b/src/airunner/workers/llm_generate_worker.py index dcc560bed..90aa2d53f 100644 --- a/src/airunner/workers/llm_generate_worker.py +++ b/src/airunner/workers/llm_generate_worker.py @@ -19,6 +19,7 @@ def __init__(self, agent_options=None): (SignalCode.RAG_RELOAD_INDEX_SIGNAL, self.on_llm_reload_rag_index_signal), (SignalCode.ADD_CHATBOT_MESSAGE_SIGNAL, self.on_llm_add_chatbot_response_to_history), (SignalCode.LOAD_CONVERSATION, self.on_llm_load_conversation), + (SignalCode.INTERRUPT_PROCESS_SIGNAL, self.llm_on_interrupt_process_signal), ): self.register(signal[0], signal[1]) @@ -42,16 +43,19 @@ def _load_llm(self, data): callback(data) def on_llm_clear_history_signal(self): - self.llm.clear_history() + if self.llm: + self.llm.clear_history() def on_llm_request_signal(self, message: dict): self.add_to_queue(message) def llm_on_interrupt_process_signal(self): - self.llm.do_interrupt() + if self.llm: + self.llm.do_interrupt() def on_llm_reload_rag_index_signal(self): - self.llm.reload_rag() + if self.llm: + self.llm.reload_rag() def on_llm_add_chatbot_response_to_history(self, message): self.llm.add_chatbot_response_to_history(message) diff --git a/src/airunner/workers/tts_generator_worker.py b/src/airunner/workers/tts_generator_worker.py index c3853c175..1525fe3de 100644 --- a/src/airunner/workers/tts_generator_worker.py +++ b/src/airunner/workers/tts_generator_worker.py @@ -1,4 +1,5 @@ import queue +import re import threading from airunner.enums import SignalCode, TTSModel, ModelStatus @@ -51,10 +52,11 @@ def on_interrupt_process_signal(self): self.tts.interrupt_process_signal() def on_unblock_tts_generator_signal(self): - self.logger.debug("Unblocking TTS generation...") - self.do_interrupt = False - self.paused = False - self.tts.unblock_tts_generator_signal() + if self.application_settings.tts_enabled: + self.logger.debug("Unblocking TTS generation...") + self.do_interrupt = False + self.paused = False + self.tts.unblock_tts_generator_signal() def on_enable_tts_signal(self): if self.tts: @@ -98,6 +100,15 @@ def handle_message(self, data): # Convert the tokens to a string text = "".join(self.tokens) + # Regular expression to match timestamps in the format HH:MM + timestamp_pattern = re.compile(r'\b(\d{1,2}):(\d{2})\b') + + # Replace the colon in the matched timestamps with a space + text = timestamp_pattern.sub(r'\1 \2', text) + + def word_count(s): + return len(s.split()) + if finalize: self._generate(text) self.play_queue_started = True @@ -112,12 +123,16 @@ def handle_message(self, data): if p in text: split_text = text.split(p, 1) # Split at the first occurrence of punctuation if len(split_text) > 1: - sentence = split_text[0] + before, after = split_text[0], split_text[1] + if p == ",": + if word_count(before) < 3 or word_count(after) < 3: + continue # Skip splitting if there are not enough words around the comma + sentence = before self._generate(sentence) self.play_queue_started = True # Convert the remaining string back to a list of tokens - remaining_text = split_text[1].strip() + remaining_text = after.strip() if not self.do_interrupt: self.tokens = list(remaining_text) break diff --git a/src/airunner/workers/tts_vocalizer_worker.py b/src/airunner/workers/tts_vocalizer_worker.py index bf03c24e2..67d1abd5e 100644 --- a/src/airunner/workers/tts_vocalizer_worker.py +++ b/src/airunner/workers/tts_vocalizer_worker.py @@ -36,9 +36,10 @@ def on_interrupt_process_signal(self): self.queue = Queue() def on_unblock_tts_generator_signal(self): - self.logger.debug("Starting TTS stream...") - self.accept_message = True - self.stream.start() + if self.application_settings.tts_enabled: + self.logger.debug("Starting TTS stream...") + self.accept_message = True + self.stream.start() def start_stream(self): if sd.query_devices(kind='output'):