Rename speech to text and make use of new BaseHandler
w4ffl35 committed Jan 24, 2024
1 parent 2ee7011 commit b346635
Showing 9 changed files with 32 additions and 116 deletions.
13 changes: 13 additions & 0 deletions src/airunner/aihandler/base_handler.py
@@ -0,0 +1,13 @@
from PyQt6.QtCore import QObject

from airunner.mediator_mixin import MediatorMixin
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.aihandler.logger import Logger


class BaseHandler(QObject, MediatorMixin, SettingsMixin):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        SettingsMixin.__init__(self)
        MediatorMixin.__init__(self)
        self.logger = Logger(prefix=self.__class__.__name__)
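With the base class in place, a handler inherits its QObject wiring, settings access, mediator registration, and a class-prefixed logger from a single super().__init__() call. A minimal sketch of the subclass pattern this commit moves toward (ExampleHandler and its log message are hypothetical, not part of the diff):

from airunner.aihandler.base_handler import BaseHandler


class ExampleHandler(BaseHandler):
    def __init__(self, *args, **kwargs):
        # BaseHandler.__init__ sets up QObject, SettingsMixin, MediatorMixin
        # and self.logger prefixed with this class's name.
        super().__init__(*args, **kwargs)
        self.logger.info("handler ready")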
9 changes: 2 additions & 7 deletions src/airunner/aihandler/llm_handler.py
@@ -9,14 +9,13 @@
from transformers import InstructBlipForConditionalGeneration
from transformers import InstructBlipProcessor
from transformers import TextIteratorStreamer

from PyQt6.QtCore import QObject
from airunner.aihandler.base_handler import BaseHandler

from airunner.aihandler.logger import Logger
from airunner.mediator_mixin import MediatorMixin


class LLMHandler(QObject, MediatorMixin):
class LLMHandler(BaseHandler):
    logger = Logger(prefix="LLMHandler")
    dtype = ""
    local_files_only = True
@@ -78,10 +77,6 @@ def has_gpu(self):
        if self.dtype == "32bit" or not self.use_gpu:
            return False
        return torch.cuda.is_available()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        MediatorMixin.__init__(self)

    def move_to_cpu(self):
        if self.model:
11 changes: 3 additions & 8 deletions src/airunner/aihandler/sd_handler.py
@@ -22,6 +22,7 @@
from diffusers import StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, AsymmetricAutoencoderKL
from diffusers import ConsistencyDecoderVAE
from transformers import AutoFeatureExtractor
from airunner.aihandler.base_handler import BaseHandler

from airunner.aihandler.enums import FilterType
from airunner.aihandler.mixins.compel_mixin import CompelMixin
@@ -40,16 +41,14 @@
from airunner.windows.main.pipeline_mixin import PipelineMixin
from airunner.windows.main.controlnet_model_mixin import ControlnetModelMixin
from airunner.windows.main.ai_model_mixin import AIModelMixin
from airunner.mediator_mixin import MediatorMixin
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.service_locator import ServiceLocator
from airunner.utils import clear_memory

torch.backends.cuda.matmul.allow_tf32 = True


class SDHandler(
    QObject,
    BaseHandler,
    MergeMixin,
    LoraMixin,
    MemoryEfficientMixin,
Expand All @@ -59,14 +58,12 @@ class SDHandler(
    SchedulerMixin,

    # Data Mixins
    SettingsMixin,
    LayerMixin,
    LoraDataMixin,
    EmbeddingDataMixin,
    PipelineMixin,
    ControlnetModelMixin,
    AIModelMixin,
    MediatorMixin
):
    logger = Logger(prefix="SDHandler")
    _current_model: str = ""
@@ -720,15 +717,13 @@ def original_model_data(self):

    def __init__(self, **kwargs):
        #self.logger.set_level(LOG_LEVEL)
        MediatorMixin.__init__(self)
        SettingsMixin.__init__(self)
        super().__init__()
        LayerMixin.__init__(self)
        LoraDataMixin.__init__(self)
        EmbeddingDataMixin.__init__(self)
        PipelineMixin.__init__(self)
        ControlnetModelMixin.__init__(self)
        AIModelMixin.__init__(self)
        super().__init__()
        self.logger.info("Loading Stable Diffusion model runner...")
        self.safety_checker_model = self.models_by_pipeline_action("safety_checker")
        self.text_encoder_model = self.models_by_pipeline_action("text_encoder")
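As the diff reads, super().__init__() in SDHandler.__init__ moves from the end of the method to the top and the explicit MediatorMixin.__init__/SettingsMixin.__init__ calls are dropped, so BaseHandler supplies settings, the mediator, and logging before the data mixins initialize. A quick sanity check of the new resolution order (a sketch, assuming the package is importable):

from airunner.aihandler.sd_handler import SDHandler

# BaseHandler should now sit immediately after SDHandler in the MRO,
# ahead of the feature and data mixins.
print([cls.__name__ for cls in SDHandler.__mro__][:4])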
78 changes: 0 additions & 78 deletions src/airunner/aihandler/speech_to_text.py

This file was deleted.

14 changes: 5 additions & 9 deletions src/airunner/aihandler/stt_handler.py
@@ -3,15 +3,12 @@

from transformers import AutoProcessor, WhisperForConditionalGeneration, AutoFeatureExtractor

from PyQt6.QtCore import QObject
from airunner.aihandler.base_handler import BaseHandler

from PyQt6.QtCore import pyqtSignal

from airunner.aihandler.logger import Logger
from airunner.mediator_mixin import MediatorMixin


class STTHandler(QObject, MediatorMixin):
    logger = Logger(prefix="STTHandler")
class STTHandler(BaseHandler):
    listening = False
    move_to_cpu_signal = pyqtSignal()

@@ -26,9 +23,8 @@ def on_move_to_cpu(self):
self.logger.info("Moving model to CPU")
self.model = self.model.to("cpu")

def __init__(self):
super().__init__()
MediatorMixin.__init__(self)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.load_model()
self.register("move_to_cpu_signal", self)
self.register("process_audio", self)
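STTHandler keeps the register()/on_* pairing from MediatorMixin: it registers for "move_to_cpu_signal" and "process_audio" in __init__ and exposes on_move_to_cpu as the matching slot. Below is a purely illustrative stand-in for that dispatch pattern, not airunner's actual MediatorMixin:

class ToyMediator:
    # Illustrative only: maps a signal name to registered handlers and
    # calls a correspondingly named on_* method on each of them.
    def __init__(self):
        self._listeners = {}

    def register(self, signal_name, handler):
        self._listeners.setdefault(signal_name, []).append(handler)

    def emit_signal(self, signal_name, *args, **kwargs):
        slot_name = "on_" + signal_name.removesuffix("_signal")
        for handler in self._listeners.get(signal_name, []):
            getattr(handler, slot_name)(*args, **kwargs)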
src/airunner/aihandler/tts_handler.py
@@ -3,25 +3,21 @@

from queue import Queue

from PyQt6.QtCore import QObject, pyqtSlot
from PyQt6.QtCore import pyqtSlot

from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, BarkModel, BarkProcessor
from datasets import load_dataset
from airunner.aihandler.base_handler import BaseHandler

from airunner.aihandler.logger import Logger
from airunner.mediator_mixin import MediatorMixin
from airunner.windows.main.settings_mixin import SettingsMixin


class TTS(QObject, MediatorMixin, SettingsMixin):
class TTSHandler(BaseHandler):
"""
Generates speech from given text.
Responsible for managing the model, processor, vocoder, and speaker embeddings.
Generates using either the SpeechT5 or Bark model.
Use from a worker to avoid blocking the main thread.
"""
logger = Logger(prefix="TTS")
character_replacement_map = {
"\n": " ",
"’": "'",
@@ -130,9 +126,7 @@ def sentence_chunks(self):
        return self.tts_settings["sentence_chunks"]

    def __init__(self, *args, **kwargs):
        super().__init__()
        SettingsMixin.__init__(self)
        MediatorMixin.__init__(self)
        super().__init__(*args, **kwargs)
        self.logger.info("Loading")
        self.corpus = []
        self.processor = None
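The docstring above summarizes what TTSHandler manages: a processor, a model (SpeechT5 or Bark), a vocoder, and speaker embeddings. For reference, a self-contained SpeechT5 pass with the standard Hugging Face checkpoints looks roughly like this (checkpoint names and the xvector index are illustrative defaults, not values taken from this commit):

import torch
from datasets import load_dataset
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

# Public checkpoints; airunner's settings may point elsewhere.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

inputs = processor(text="Testing the text to speech handler.", return_tensors="pt")
speakers = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(speakers[7306]["xvector"]).unsqueeze(0)

# Returns a waveform tensor that can be written to disk or queued for playback.
speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)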
@@ -1,7 +1,8 @@
from PyQt6.QtCore import QObject
from airunner.aihandler.base_handler import BaseHandler

from airunner.mediator_mixin import MediatorMixin


class Vision(QObject, MediatorMixin):
class VisionHandler(BaseHandler):
    pass
2 changes: 1 addition & 1 deletion src/airunner/workers/audio_processor_worker.py
@@ -1,4 +1,4 @@
from airunner.aihandler.speech_to_text import STTHandler
from airunner.aihandler.stt_handler import STTHandler
from airunner.workers.worker import Worker


4 changes: 2 additions & 2 deletions src/airunner/workers/tts_generator_worker.py
@@ -1,7 +1,7 @@
import time

from airunner.workers.worker import Worker
from airunner.aihandler.tts import TTS
from airunner.aihandler.tts_handler import TTSHandler


class TTSGeneratorWorker(Worker):
@@ -10,7 +10,7 @@ class TTSGeneratorWorker(Worker):
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tts = TTS()
self.tts = TTSHandler()
self.tts.run()
self.play_queue = []
self.play_queue_started = False
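This matches the TTSHandler docstring's advice to run generation off the main thread. A generic PyQt6 pattern for hosting such a worker on its own thread (illustrative; airunner's Worker base may already encapsulate this):

from PyQt6.QtCore import QThread

from airunner.workers.tts_generator_worker import TTSGeneratorWorker

thread = QThread()
worker = TTSGeneratorWorker()   # constructs TTSHandler internally
worker.moveToThread(thread)     # keep generation off the GUI thread
thread.start()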
