diff --git a/src/airunner/aihandler/base_handler.py b/src/airunner/aihandler/base_handler.py new file mode 100644 index 000000000..f412cbeb6 --- /dev/null +++ b/src/airunner/aihandler/base_handler.py @@ -0,0 +1,13 @@ +from PyQt6.QtCore import QObject + +from airunner.mediator_mixin import MediatorMixin +from airunner.windows.main.settings_mixin import SettingsMixin +from airunner.aihandler.logger import Logger + + +class BaseHandler(QObject, MediatorMixin, SettingsMixin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + SettingsMixin.__init__(self) + MediatorMixin.__init__(self) + self.logger = Logger(prefix=self.__class__.__name__) diff --git a/src/airunner/aihandler/llm_handler.py b/src/airunner/aihandler/llm_handler.py index 83b6172a4..e50beea9c 100644 --- a/src/airunner/aihandler/llm_handler.py +++ b/src/airunner/aihandler/llm_handler.py @@ -9,14 +9,13 @@ from transformers import InstructBlipForConditionalGeneration from transformers import InstructBlipProcessor from transformers import TextIteratorStreamer - -from PyQt6.QtCore import QObject +from airunner.aihandler.base_handler import BaseHandler from airunner.aihandler.logger import Logger from airunner.mediator_mixin import MediatorMixin -class LLMHandler(QObject, MediatorMixin): +class LLMHandler(BaseHandler): logger = Logger(prefix="LLMHandler") dtype = "" local_files_only = True @@ -78,10 +77,6 @@ def has_gpu(self): if self.dtype == "32bit" or not self.use_gpu: return False return torch.cuda.is_available() - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - MediatorMixin.__init__(self) def move_to_cpu(self): if self.model: diff --git a/src/airunner/aihandler/sd_handler.py b/src/airunner/aihandler/sd_handler.py index 4653d6428..e8074bdad 100644 --- a/src/airunner/aihandler/sd_handler.py +++ b/src/airunner/aihandler/sd_handler.py @@ -22,6 +22,7 @@ from diffusers import StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, AsymmetricAutoencoderKL from diffusers import ConsistencyDecoderVAE from transformers import AutoFeatureExtractor +from airunner.aihandler.base_handler import BaseHandler from airunner.aihandler.enums import FilterType from airunner.aihandler.mixins.compel_mixin import CompelMixin @@ -40,8 +41,6 @@ from airunner.windows.main.pipeline_mixin import PipelineMixin from airunner.windows.main.controlnet_model_mixin import ControlnetModelMixin from airunner.windows.main.ai_model_mixin import AIModelMixin -from airunner.mediator_mixin import MediatorMixin -from airunner.windows.main.settings_mixin import SettingsMixin from airunner.service_locator import ServiceLocator from airunner.utils import clear_memory @@ -49,7 +48,7 @@ class SDHandler( - QObject, + BaseHandler, MergeMixin, LoraMixin, MemoryEfficientMixin, @@ -59,14 +58,12 @@ class SDHandler( SchedulerMixin, # Data Mixins - SettingsMixin, LayerMixin, LoraDataMixin, EmbeddingDataMixin, PipelineMixin, ControlnetModelMixin, AIModelMixin, - MediatorMixin ): logger = Logger(prefix="SDHandler") _current_model: str = "" @@ -720,15 +717,13 @@ def original_model_data(self): def __init__(self, **kwargs): #self.logger.set_level(LOG_LEVEL) - MediatorMixin.__init__(self) - SettingsMixin.__init__(self) + super().__init__() LayerMixin.__init__(self) LoraDataMixin.__init__(self) EmbeddingDataMixin.__init__(self) PipelineMixin.__init__(self) ControlnetModelMixin.__init__(self) AIModelMixin.__init__(self) - super().__init__() self.logger.info("Loading Stable Diffusion model runner...") self.safety_checker_model = self.models_by_pipeline_action("safety_checker") self.text_encoder_model = self.models_by_pipeline_action("text_encoder") diff --git a/src/airunner/aihandler/speech_to_text.py b/src/airunner/aihandler/speech_to_text.py deleted file mode 100644 index aa0afebd3..000000000 --- a/src/airunner/aihandler/speech_to_text.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch -import numpy as np - -from transformers import AutoProcessor, WhisperForConditionalGeneration, AutoFeatureExtractor - -from PyQt6.QtCore import QObject -from PyQt6.QtCore import pyqtSignal - -from airunner.aihandler.logger import Logger -from airunner.mediator_mixin import MediatorMixin - - -class STTHandler(QObject, MediatorMixin): - logger = Logger(prefix="STTHandler") - listening = False - move_to_cpu_signal = pyqtSignal() - - def on_process_audio(self, audio_data, fs): - inputs = np.squeeze(audio_data) - inputs = self.feature_extractor(inputs, sampling_rate=fs, return_tensors="pt") - inputs = inputs.to(self.model.device) - transcription = self.run(inputs) - self.emit("stt_audio_processed", transcription) - - def on_move_to_cpu(self): - self.logger.info("Moving model to CPU") - self.model = self.model.to("cpu") - - def __init__(self): - super().__init__() - MediatorMixin.__init__(self) - self.load_model() - self.register("move_to_cpu_signal", self) - self.register("process_audio", self) - - @property - def device(self): - return torch.device("cuda" if self.use_cuda else "cpu") - - @property - def use_cuda(self): - return torch.cuda.is_available() - - def load_model(self): - self.logger.info("Loading model") - self.model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to(self.device) - self.processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") - self.feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") - - is_on_gpu = False - def move_to_gpu(self): - if not self.is_on_gpu: - self.logger.info("Moving model to GPU") - self.model = self.model.to(self.device) - self.processor = self.processor - self.feature_extractor = self.feature_extractor - self.is_on_gpu = True - - def move_inputs_to_device(self, inputs): - if self.use_cuda: - self.logger.info("Moving inputs to CUDA") - try: - inputs = {k: v.cuda() for k, v in inputs.items()} - except AttributeError: - pass - return inputs - - def run(self, inputs): - self.logger.info("Running model") - input_features = inputs.input_features - input_features = self.move_inputs_to_device(input_features) - generated_ids = self.model.generate(inputs=input_features) - transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - transcription = transcription.strip() - print("transcription: ", transcription) - if len(transcription) == 0 or len(transcription.split(" ")) == 1: - return None - return transcription diff --git a/src/airunner/aihandler/stt_handler.py b/src/airunner/aihandler/stt_handler.py index aa0afebd3..2eb7f3f95 100644 --- a/src/airunner/aihandler/stt_handler.py +++ b/src/airunner/aihandler/stt_handler.py @@ -3,15 +3,12 @@ from transformers import AutoProcessor, WhisperForConditionalGeneration, AutoFeatureExtractor -from PyQt6.QtCore import QObject +from airunner.aihandler.base_handler import BaseHandler + from PyQt6.QtCore import pyqtSignal -from airunner.aihandler.logger import Logger -from airunner.mediator_mixin import MediatorMixin - -class STTHandler(QObject, MediatorMixin): - logger = Logger(prefix="STTHandler") +class STTHandler(BaseHandler): listening = False move_to_cpu_signal = pyqtSignal() @@ -26,9 +23,8 @@ def on_move_to_cpu(self): self.logger.info("Moving model to CPU") self.model = self.model.to("cpu") - def __init__(self): - super().__init__() - MediatorMixin.__init__(self) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.load_model() self.register("move_to_cpu_signal", self) self.register("process_audio", self) diff --git a/src/airunner/aihandler/tts.py b/src/airunner/aihandler/tts_handler.py similarity index 97% rename from src/airunner/aihandler/tts.py rename to src/airunner/aihandler/tts_handler.py index 790a72089..59b160928 100644 --- a/src/airunner/aihandler/tts.py +++ b/src/airunner/aihandler/tts_handler.py @@ -3,17 +3,14 @@ from queue import Queue -from PyQt6.QtCore import QObject, pyqtSlot +from PyQt6.QtCore import pyqtSlot from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, BarkModel, BarkProcessor from datasets import load_dataset +from airunner.aihandler.base_handler import BaseHandler -from airunner.aihandler.logger import Logger -from airunner.mediator_mixin import MediatorMixin -from airunner.windows.main.settings_mixin import SettingsMixin - -class TTS(QObject, MediatorMixin, SettingsMixin): +class TTSHandler(BaseHandler): """ Generates speech from given text. Responsible for managing the model, processor, vocoder, and speaker embeddings. @@ -21,7 +18,6 @@ class TTS(QObject, MediatorMixin, SettingsMixin): Use from a worker to avoid blocking the main thread. """ - logger = Logger(prefix="TTS") character_replacement_map = { "\n": " ", "’": "'", @@ -130,9 +126,7 @@ def sentence_chunks(self): return self.tts_settings["sentence_chunks"] def __init__(self, *args, **kwargs): - super().__init__() - SettingsMixin.__init__(self) - MediatorMixin.__init__(self) + super().__init__(*args, **kwargs) self.logger.info("Loading") self.corpus = [] self.processor = None diff --git a/src/airunner/aihandler/vision.py b/src/airunner/aihandler/vision_handler.py similarity index 51% rename from src/airunner/aihandler/vision.py rename to src/airunner/aihandler/vision_handler.py index 5e4c6ac98..e4b0ebdd7 100644 --- a/src/airunner/aihandler/vision.py +++ b/src/airunner/aihandler/vision_handler.py @@ -1,7 +1,8 @@ from PyQt6.QtCore import QObject +from airunner.aihandler.base_handler import BaseHandler from airunner.mediator_mixin import MediatorMixin -class Vision(QObject, MediatorMixin): +class VisionHandler(BaseHandler): pass \ No newline at end of file diff --git a/src/airunner/workers/audio_processor_worker.py b/src/airunner/workers/audio_processor_worker.py index 941a35dcf..b0dca87f4 100644 --- a/src/airunner/workers/audio_processor_worker.py +++ b/src/airunner/workers/audio_processor_worker.py @@ -1,4 +1,4 @@ -from airunner.aihandler.speech_to_text import STTHandler +from airunner.aihandler.stt_handler import STTHandler from airunner.workers.worker import Worker diff --git a/src/airunner/workers/tts_generator_worker.py b/src/airunner/workers/tts_generator_worker.py index 57035ad5b..60c2cfd86 100644 --- a/src/airunner/workers/tts_generator_worker.py +++ b/src/airunner/workers/tts_generator_worker.py @@ -1,7 +1,7 @@ import time from airunner.workers.worker import Worker -from airunner.aihandler.tts import TTS +from airunner.aihandler.tts_handler import TTSHandler class TTSGeneratorWorker(Worker): @@ -10,7 +10,7 @@ class TTSGeneratorWorker(Worker): """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.tts = TTS() + self.tts = TTSHandler() self.tts.run() self.play_queue = [] self.play_queue_started = False