Skip to content

Commit

Permalink
Merge pull request #395 from Capsize-Games/develop
Browse files Browse the repository at this point in the history
Code maintenance
  • Loading branch information
w4ffl35 authored Jan 24, 2024
2 parents d4c6046 + b346635 commit 49b8653
Show file tree
Hide file tree
Showing 15 changed files with 82 additions and 73 deletions.
13 changes: 13 additions & 0 deletions src/airunner/aihandler/base_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from PyQt6.QtCore import QObject

from airunner.mediator_mixin import MediatorMixin
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.aihandler.logger import Logger


class BaseHandler(QObject, MediatorMixin, SettingsMixin):
    """Common base class for the AI handler classes (LLM, SD, STT, TTS, vision).

    Combines Qt object behavior with the application's mediator and settings
    mixins, and provides a logger tagged with the concrete subclass name.
    """

    def __init__(self, *args, **kwargs):
        # Forward positional/keyword args to QObject (e.g. a Qt parent).
        super().__init__(*args, **kwargs)
        # NOTE(review): the explicit mixin __init__ calls suggest these mixins
        # do not chain via cooperative super(); confirm — if they do chain,
        # their initializers would run twice here.
        SettingsMixin.__init__(self)
        MediatorMixin.__init__(self)
        # Logger prefix is the subclass name (e.g. "LLMHandler", "SDHandler").
        self.logger = Logger(prefix=self.__class__.__name__)
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
from transformers import InstructBlipForConditionalGeneration
from transformers import InstructBlipProcessor
from transformers import TextIteratorStreamer

from PyQt6.QtCore import QObject
from airunner.aihandler.base_handler import BaseHandler

from airunner.aihandler.logger import Logger
from airunner.mediator_mixin import MediatorMixin


class LLM(QObject, MediatorMixin):
logger = Logger(prefix="LLM")
class LLMHandler(BaseHandler):
logger = Logger(prefix="LLMHandler")
dtype = ""
local_files_only = True
set_attention_mask = False
Expand Down Expand Up @@ -78,10 +77,6 @@ def has_gpu(self):
if self.dtype == "32bit" or not self.use_gpu:
return False
return torch.cuda.is_available()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
MediatorMixin.__init__(self)

def move_to_cpu(self):
if self.model:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from diffusers import StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, AsymmetricAutoencoderKL
from diffusers import ConsistencyDecoderVAE
from transformers import AutoFeatureExtractor
from airunner.aihandler.base_handler import BaseHandler

from airunner.aihandler.enums import FilterType
from airunner.aihandler.mixins.compel_mixin import CompelMixin
Expand All @@ -40,16 +41,14 @@
from airunner.windows.main.pipeline_mixin import PipelineMixin
from airunner.windows.main.controlnet_model_mixin import ControlnetModelMixin
from airunner.windows.main.ai_model_mixin import AIModelMixin
from airunner.mediator_mixin import MediatorMixin
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.service_locator import ServiceLocator
from airunner.utils import clear_memory

torch.backends.cuda.matmul.allow_tf32 = True


class SDRunner(
QObject,
class SDHandler(
BaseHandler,
MergeMixin,
LoraMixin,
MemoryEfficientMixin,
Expand All @@ -59,16 +58,14 @@ class SDRunner(
SchedulerMixin,

# Data Mixins
SettingsMixin,
LayerMixin,
LoraDataMixin,
EmbeddingDataMixin,
PipelineMixin,
ControlnetModelMixin,
AIModelMixin,
MediatorMixin
):
logger = Logger(prefix="SDRunner")
logger = Logger(prefix="SDHandler")
_current_model: str = ""
_previous_model: str = ""
_initialized: bool = False
Expand Down Expand Up @@ -720,15 +717,13 @@ def original_model_data(self):

def __init__(self, **kwargs):
#self.logger.set_level(LOG_LEVEL)
MediatorMixin.__init__(self)
SettingsMixin.__init__(self)
super().__init__()
LayerMixin.__init__(self)
LoraDataMixin.__init__(self)
EmbeddingDataMixin.__init__(self)
PipelineMixin.__init__(self)
ControlnetModelMixin.__init__(self)
AIModelMixin.__init__(self)
super().__init__()
self.logger.info("Loading Stable Diffusion model runner...")
self.safety_checker_model = self.models_by_pipeline_action("safety_checker")
self.text_encoder_model = self.models_by_pipeline_action("text_encoder")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@

from transformers import AutoProcessor, WhisperForConditionalGeneration, AutoFeatureExtractor

from PyQt6.QtCore import QObject
from airunner.aihandler.base_handler import BaseHandler

from PyQt6.QtCore import pyqtSignal

from airunner.aihandler.logger import Logger
from airunner.mediator_mixin import MediatorMixin


class SpeechToText(QObject, MediatorMixin):
logger = Logger(prefix="SpeechToText")
class STTHandler(BaseHandler):
listening = False
move_to_cpu_signal = pyqtSignal()

Expand All @@ -26,9 +23,8 @@ def on_move_to_cpu(self):
self.logger.info("Moving model to CPU")
self.model = self.model.to("cpu")

def __init__(self):
super().__init__()
MediatorMixin.__init__(self)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.load_model()
self.register("move_to_cpu_signal", self)
self.register("process_audio", self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,21 @@

from queue import Queue

from PyQt6.QtCore import QObject, pyqtSlot
from PyQt6.QtCore import pyqtSlot

from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, BarkModel, BarkProcessor
from datasets import load_dataset
from airunner.aihandler.base_handler import BaseHandler

from airunner.aihandler.logger import Logger
from airunner.mediator_mixin import MediatorMixin
from airunner.windows.main.settings_mixin import SettingsMixin


class TTS(QObject, MediatorMixin, SettingsMixin):
class TTSHandler(BaseHandler):
"""
Generates speech from given text.
Responsible for managing the model, processor, vocoder, and speaker embeddings.
Generates using either the SpeechT5 or Bark model.
Use from a worker to avoid blocking the main thread.
"""
logger = Logger(prefix="TTS")
character_replacement_map = {
"\n": " ",
"’": "'",
Expand Down Expand Up @@ -130,9 +126,7 @@ def sentence_chunks(self):
return self.tts_settings["sentence_chunks"]

def __init__(self, *args, **kwargs):
super().__init__()
SettingsMixin.__init__(self)
MediatorMixin.__init__(self)
super().__init__(*args, **kwargs)
self.logger.info("Loading")
self.corpus = []
self.processor = None
Expand Down
8 changes: 8 additions & 0 deletions src/airunner/aihandler/vision_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from PyQt6.QtCore import QObject
from airunner.aihandler.base_handler import BaseHandler

from airunner.mediator_mixin import MediatorMixin


class VisionHandler(BaseHandler):
    """Handler for vision models.

    Currently a stub: inherits all behavior from BaseHandler and adds
    nothing of its own.
    """
4 changes: 2 additions & 2 deletions src/airunner/windows/main/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from airunner.windows.main.controlnet_model_mixin import ControlnetModelMixin
from airunner.windows.main.ai_model_mixin import AIModelMixin
from airunner.windows.main.image_filter_mixin import ImageFilterMixin
from airunner.aihandler.engine import Engine
from airunner.worker_manager import WorkerManager
from airunner.mediator_mixin import MediatorMixin
from airunner.service_locator import ServiceLocator

Expand Down Expand Up @@ -300,7 +300,7 @@ def __init__(self, *args, **kwargs):
ServiceLocator.register("get_callback_for_slider", self.get_callback_for_slider)


self.engine = Engine()
self.engine = WorkerManager()

self.ui.setupUi(self)

Expand Down
46 changes: 22 additions & 24 deletions src/airunner/aihandler/engine.py → src/airunner/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from airunner.workers.sd_generate_worker import SDGenerateWorker
from airunner.workers.sd_request_worker import SDRequestWorker
from airunner.aihandler.logger import Logger
from airunner.aihandler.tts import TTS
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.service_locator import ServiceLocator
from airunner.utils import clear_memory
from airunner.workers.vision_capture_worker import VisionCaptureWorker
from airunner.workers.vision_processor_worker import VisionProcessorWorker


class Message:
Expand All @@ -29,7 +30,7 @@ def __init__(self, *args, **kwargs):
self.conversation = kwargs.get("conversation")


class Engine(QObject, MediatorMixin, SettingsMixin):
class WorkerManager(QObject, MediatorMixin, SettingsMixin):
"""
The engine is responsible for processing requests and offloading
them to the appropriate AI model controller.
Expand All @@ -51,12 +52,12 @@ def do_response(self, response):
Handle a response from the application by putting it into
a response worker queue.
"""
self.response_worker.add_to_queue(response)
self.engine_response_worker.add_to_queue(response)

def on_engine_cancel_signal(self, _ignore):
    """Cancel in-flight work: notify Stable Diffusion and cancel the
    engine request worker's queue.

    :param _ignore: unused signal payload.
    """
    self.logger.info("Canceling")
    self.emit("sd_cancel_signal")
    # Renamed attribute (was self.request_worker, which collided with the
    # LLM request worker); stale pre-rename call dropped.
    self.engine_request_worker.cancel()

def on_engine_stop_processing_queue_signal(self):
self.do_process_queue = False
Expand Down Expand Up @@ -87,11 +88,6 @@ def __init__(self, **kwargs):
self.logger = Logger(prefix="Engine")
self.clear_memory()

# Initialize Controllers
#self.stt_controller = STTController(engine=self)
# self.ocr_controller = ImageProcessor(engine=self)
self.tts_controller = TTS(engine=self)

self.register("hear_signal", self)
self.register("engine_cancel_signal", self)
self.register("engine_stop_processing_queue_signal", self)
Expand All @@ -107,30 +103,32 @@ def __init__(self, **kwargs):
self.register("image_generate_request_signal", self)
self.register("llm_response_signal", self)
self.register("llm_text_streamed_signal", self)
self.register("AudioCaptureWorker_response_signal", self)
self.register("AudioProcessorWorker_processed_audio", self)

self.sd_request_worker = self.create_worker(SDRequestWorker)
self.sd_generate_worker = self.create_worker(SDGenerateWorker)

self.request_worker = self.create_worker(EngineRequestWorker)
self.response_worker = self.create_worker(EngineResponseWorker)
self.engine_request_worker = self.create_worker(EngineRequestWorker)
self.engine_response_worker = self.create_worker(EngineResponseWorker)

self.generator_worker = self.create_worker(TTSGeneratorWorker)
self.vocalizer_worker = self.create_worker(TTSVocalizerWorker)
self.tts_generator_worker = self.create_worker(TTSGeneratorWorker)
self.tts_vocalizer_worker = self.create_worker(TTSVocalizerWorker)

self.request_worker = self.create_worker(LLMRequestWorker)
self.generate_worker = self.create_worker(LLMGenerateWorker)
self.llm_request_worker = self.create_worker(LLMRequestWorker)
self.llm_generate_worker = self.create_worker(LLMGenerateWorker)

self.audio_capture_worker = self.create_worker(AudioCaptureWorker)
self.audio_processor_worker = self.create_worker(AudioProcessorWorker)
self.stt_audio_capture_worker = self.create_worker(AudioCaptureWorker)
self.stt_audio_processor_worker = self.create_worker(AudioProcessorWorker)

self.register("AudioCaptureWorker_response_signal", self)
self.register("AudioProcessorWorker_processed_audio", self)
self.vision_capture_worker = self.create_worker(VisionCaptureWorker)
self.vision_processor_worker = self.create_worker(VisionProcessorWorker)

self.register("tts_request", self)

def on_AudioCaptureWorker_response_signal(self, message: np.ndarray):
    """Forward captured audio samples to the STT audio-processor worker.

    :param message: raw audio buffer emitted by AudioCaptureWorker.
    """
    self.logger.info("Heard signal")
    # Renamed attribute (was self.audio_processor_worker); stale
    # pre-rename call dropped.
    self.stt_audio_processor_worker.add_to_queue(message)

def on_AudioProcessorWorker_processed_audio(self, message: np.ndarray):
self.logger.info("Processed audio")
Expand All @@ -140,7 +138,7 @@ def on_LLMGenerateWorker_response_signal(self, message:dict):
self.emit("llm_response_signal", message)

def on_tts_request(self, data: dict):
    """Queue a text-to-speech request on the TTS generator worker.

    :param data: TTS request payload passed through unchanged.
    """
    # Renamed attribute (was self.generator_worker); stale pre-rename
    # call dropped.
    self.tts_generator_worker.add_to_queue(data)

def on_llm_response_signal(self, message):
self.do_response(message)
Expand Down Expand Up @@ -195,7 +193,7 @@ def do_image_generate_request(self, message):
))

def request_queue_size(self):
    """Return the number of pending requests in the engine request
    worker's queue."""
    # Renamed attribute (was self.request_worker); stale pre-rename
    # return dropped.
    return self.engine_request_worker.queue.qsize()

def do_listen(self):
# self.stt_controller.do_listen()
Expand Down Expand Up @@ -227,8 +225,8 @@ def on_clear_llm_history_signal(self):

def stop(self):
    """Stop the engine request and response workers."""
    self.logger.info("Stopping")
    # Renamed attributes (were self.request_worker / self.response_worker);
    # stale pre-rename calls dropped.
    self.engine_request_worker.stop()
    self.engine_response_worker.stop()

def move_sd_to_cpu(self):
if ServiceLocator.get("is_pipe_on_cpu")() or not ServiceLocator.get("has_pipe")():
Expand Down
4 changes: 2 additions & 2 deletions src/airunner/workers/audio_processor_worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from airunner.aihandler.speech_to_text import SpeechToText
from airunner.aihandler.stt_handler import STTHandler
from airunner.workers.worker import Worker


Expand All @@ -11,7 +11,7 @@ class AudioProcessorWorker(Worker):

def __init__(self, prefix):
    """Create the worker, load the speech-to-text handler and subscribe
    to processed-audio notifications.

    :param prefix: logger/worker prefix forwarded to the Worker base.
    """
    super().__init__(prefix=prefix)
    # Post-rename handler class (was SpeechToText); the stale pre-rename
    # assignment left over from the diff is dropped.
    self.stt = STTHandler()
    self.register("stt_audio_processed", self)

def on_stt_audio_processed(self, transcription):
Expand Down
4 changes: 2 additions & 2 deletions src/airunner/workers/llm_generate_worker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from airunner.workers.worker import Worker
from airunner.aihandler.llm import LLM
from airunner.aihandler.llm_handler import LLMHandler


class LLMGenerateWorker(Worker):
def __init__(self, prefix="LLMGenerateWorker"):
self.llm = LLM()
self.llm = LLMHandler()
super().__init__(prefix=prefix)
self.register("clear_history", self)
self.register("LLMRequestWorker_response_signal", self)
Expand Down
4 changes: 2 additions & 2 deletions src/airunner/workers/sd_generate_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

from airunner.aihandler.enums import EngineResponseCode
from airunner.workers.worker import Worker
from airunner.aihandler.runner import SDRunner
from airunner.aihandler.sd_handler import SDHandler

torch.backends.cuda.matmul.allow_tf32 = True


class SDGenerateWorker(Worker):
def __init__(self, prefix="SDGenerateWorker"):
    """Create the worker, instantiate the Stable Diffusion handler and
    subscribe to queued SD responses.

    :param prefix: logger/worker prefix forwarded to the Worker base.
    """
    super().__init__(prefix=prefix)
    # Post-rename handler class (was SDRunner); the stale pre-rename
    # assignment left over from the diff is dropped.
    self.sd = SDHandler()
    self.register("add_sd_response_to_queue_signal", self)

def on_add_sd_response_to_queue_signal(self, request):
Expand Down
5 changes: 2 additions & 3 deletions src/airunner/workers/tts_generator_worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time

from airunner.workers.worker import Worker
from airunner.aihandler.tts import TTS
from airunner.aihandler.tts_handler import TTSHandler


class TTSGeneratorWorker(Worker):
Expand All @@ -10,7 +10,7 @@ class TTSGeneratorWorker(Worker):
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tts = TTS()
self.tts = TTSHandler()
self.tts.run()
self.play_queue = []
self.play_queue_started = False
Expand Down Expand Up @@ -49,7 +49,6 @@ def generate(self, message):
else:
response = self.generate_with_t5(text)

print("adding to stream", response)
self.emit("TTSGeneratorWorker_add_to_stream_signal", response)

def move_inputs_to_device(self, inputs):
Expand Down
Loading

0 comments on commit 49b8653

Please sign in to comment.