Commit

Merge pull request #386 from Capsize-Games/develop
Fixes for speech to text
w4ffl35 authored Jan 19, 2024
2 parents 67aad9f + da78110 commit 57ab2be
Showing 152 changed files with 2,089 additions and 5,780 deletions.
311 changes: 112 additions & 199 deletions src/airunner/aihandler/engine.py

Large diffs are not rendered by default.

23 changes: 0 additions & 23 deletions src/airunner/aihandler/enums.py
@@ -5,29 +5,6 @@ class FilterType(Enum):
     PIXEL_ART = "pixelart"


-class EngineResponseCode(Enum):
-    STATUS = 100
-    ERROR = 200
-    WARNING = 300
-    PROGRESS = 400
-    IMAGE_GENERATED = 500
-    CONTROLNET_IMAGE_GENERATED = 501
-    MASK_IMAGE_GENERATED = 502
-    EMBEDDING_LOAD_FAILED = 600
-    TEXT_GENERATED = 700
-    TEXT_STREAMED = 701
-    CAPTION_GENERATED = 800
-    ADD_TO_CONVERSATION = 900
-    CLEAR_MEMORY = 1000
-    NSFW_CONTENT_DETECTED = 1100
-
-
-class EngineRequestCode(Enum):
-    GENERATE_IMAGE = 100
-    GENERATE_TEXT = 200
-    GENERATE_CAPTION = 300


 class Scheduler(Enum):
     EULER_ANCESTRAL = "Euler a"
     EULER = "Euler"
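Note: the deleted response and request codes are superseded by named signals routed through a MediatorMixin, which the diffs below adopt (for example, EMBEDDING_LOAD_FAILED becomes an "embedding_load_failed_signal" emit). A minimal sketch of that register/emit pattern, assuming a simple shared registry; the structure here is illustrative, not the repo's actual MediatorMixin:

# Minimal pub-sub sketch of the register/emit pattern used in the diffs below.
# Illustrative only; the real mixin is imported from airunner.mediator_mixin.
class MediatorMixin:
    _registry = {}  # signal name -> list of listeners, shared across instances

    def register(self, signal_name, listener):
        # Subscribe a listener; it must define a method named on_<signal_name>.
        self._registry.setdefault(signal_name, []).append(listener)

    def emit(self, signal_name, *args):
        # Dispatch to every registered listener's on_<signal_name> handler.
        for listener in self._registry.get(signal_name, []):
            getattr(listener, f"on_{signal_name}")(*args)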
86 changes: 47 additions & 39 deletions src/airunner/aihandler/llm.py
@@ -11,46 +11,45 @@
 from transformers import TextIteratorStreamer

 from PyQt6.QtCore import QObject
-from PyQt6.QtCore import pyqtSignal, pyqtSlot, QThread

-from airunner.aihandler.enums import EngineResponseCode
 from airunner.aihandler.logger import Logger
 from airunner.workers.worker import Worker
+from airunner.mediator_mixin import MediatorMixin


-class GenerateWorker(Worker):
-    def __init__(self, prefix):
+class LLMGenerateWorker(Worker):
+    def __init__(self, prefix="LLMGenerateWorker"):
         self.llm = LLM()
         super().__init__(prefix=prefix)
+        self.register("clear_history", self)

     def handle_message(self, message):
         for response in self.llm.do_generate(message):
-            self.response_signal.emit(response)
+            self.emit("llm_text_streamed_signal", response)

+    def on_clear_history(self):
+        self.llm.clear_history()
+
+
+class LLMRequestWorker(Worker):
+    def __init__(self, prefix="LLMRequestWorker"):
+        super().__init__(prefix=prefix)
+
+    def handle_message(self, message):
+        super().handle_message(message)


-class LLMController(QObject):
-    logger = Logger(prefix="LLMController")
-    response_signal = pyqtSignal(dict)
+class LLMController(QObject, MediatorMixin):

     def __init__(self, *args, **kwargs):
-        self.engine = kwargs.pop("engine", None)
-        self.app = self.engine.app
+        MediatorMixin.__init__(self)
+        self.engine = kwargs.pop("engine")
         super().__init__(*args, **kwargs)
+        self.logger = Logger(prefix="LLMController")

-        self.request_worker = Worker(prefix="LLM Request Worker")
-        self.request_worker_thread = QThread()
-        self.request_worker.moveToThread(self.request_worker_thread)
-        self.request_worker.response_signal.connect(self.request_worker_response_signal_slot)
-        self.request_worker.finished.connect(self.request_worker_thread.quit)
-        self.request_worker_thread.started.connect(self.request_worker.start)
-        self.request_worker_thread.start()
-
-        self.generate_worker = GenerateWorker(prefix="LLM Generate Worker")
-        self.generate_worker_thread = QThread()
-        self.generate_worker.moveToThread(self.generate_worker_thread)
-        self.generate_worker.response_signal.connect(self.generate_worker_response_signal_slot)
-        self.generate_worker.finished.connect(self.generate_worker_thread.quit)
-        self.generate_worker_thread.started.connect(self.generate_worker.start)
-        self.generate_worker_thread.start()
+        self.request_worker = self.create_worker(LLMRequestWorker)
+        self.generate_worker = self.create_worker(LLMGenerateWorker)
+        self.register("LLMRequestWorker_response_signal", self)
+        self.register("LLMGenerateWorker_response_signal", self)

     def pause(self):
         self.request_worker.pause()
@@ -63,20 +62,21 @@ def resume(self):
     def do_request(self, message):
         self.request_worker.add_to_queue(message)

-    @pyqtSlot(dict)
-    def request_worker_response_signal_slot(self, message):
+    def clear_history(self):
+        self.emit("clear_history")
+
+    def on_LLMRequestWorker_response_signal(self, message):
         self.generate_worker.add_to_queue(message)

-    @pyqtSlot(dict)
-    def generate_worker_response_signal_slot(self, message):
-        self.response_signal.emit(message)
+    def on_LLMGenerateWorker_response_signal(self, message:dict):
+        self.emit("llm_controller_response_signal", message)

     def do_unload_llm(self):
         self.generate_worker.llm.unload_model()
         self.generate_worker.llm.unload_tokenizer()


-class LLM(QObject):
+class LLM(QObject, MediatorMixin):
     logger = Logger(prefix="LLM")
     dtype = ""
     local_files_only = True
@@ -138,11 +138,11 @@ def has_gpu(self):
         if self.dtype == "32bit" or not self.use_gpu:
             return False
         return torch.cuda.is_available()

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # self.llm_api = LLMAPI(app=app)
+        MediatorMixin.__init__(self)

     def move_to_cpu(self):
         if self.model:
             self.logger.info("Moving model to CPU")
@@ -211,7 +211,6 @@ def load_model(self, local_files_only = None):
             params["quantization_config"] = config

         path = self.current_model_path
-        # self.engine.send_message(f"Loading {self.requested_generator_name} model from {path}")

         auto_class_ = None
         if self.requested_generator_name == "seq2seq":
@@ -508,6 +507,8 @@ def generate(self):
         n = 0
         streamed_template = ""
         replaced = False
+        is_end_of_message = False
+        is_first_message = True
         for new_text in self.streamer:
             # strip all newlines from new_text
             parsed_new_text = new_text.replace("\n", " ")
@@ -532,15 +533,22 @@
                 replaced = True
                 streamed_template = streamed_template.replace(rendered_template, "")
             else:
+                if "</s>" in new_text:
+                    streamed_template = streamed_template.replace("</s>", "")
+                    new_text = new_text.replace("</s>", "")
+                    is_end_of_message = True
                 yield dict(
                     code=EngineResponseCode.TEXT_STREAMED,
-                    message=new_text
+                    message=new_text,
+                    is_first_message=is_first_message,
+                    is_end_of_message=is_end_of_message,
+                    name=self.botname,
                 )
+                is_first_message = False

-            if "</s>" in new_text:
+            if is_end_of_message:
                 self.history.append({
                     "role": "bot",
-                    "content": streamed_template.replace("</s>", "").strip()
+                    "content": streamed_template.strip()
                 })
                 streamed_template = ""
                 replaced = False
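The generate() changes above move end-of-message detection ahead of the yield, so the "</s>" stop token never reaches consumers and every chunk carries first/last flags. A distilled, runnable sketch of that flow (simplified: the real method also strips the rendered prompt template and appends to history):

# Distilled sketch of the new stream-termination logic, simplified from generate().
def stream_chunks(streamer, botname="Bot", stop_token="</s>"):
    is_first_message = True
    is_end_of_message = False
    for new_text in streamer:
        if stop_token in new_text:
            # Strip the stop token before yielding and flag the final chunk.
            new_text = new_text.replace(stop_token, "")
            is_end_of_message = True
        yield dict(
            message=new_text,
            is_first_message=is_first_message,
            is_end_of_message=is_end_of_message,
            name=botname,
        )
        is_first_message = False

# Usage with any iterable of text chunks:
for response in stream_chunks(["Hello", " world!</s>"]):
    print(response)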
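The QThread boilerplate deleted from LLMController also hints at what the new create_worker() helper centralizes. A hedged reconstruction based only on those deleted lines; the repo's actual helper may differ:

# Hedged reconstruction of a create_worker() helper, inferred from the QThread
# wiring deleted above; not the repo's actual implementation.
from PyQt6.QtCore import QThread

def create_worker(self, worker_class):
    worker = worker_class()
    thread = QThread()
    worker.moveToThread(thread)           # run the worker off the GUI thread
    worker.finished.connect(thread.quit)  # stop the thread when work ends
    thread.started.connect(worker.start)  # kick off the worker loop
    thread.start()
    # Keep references so Qt does not garbage-collect the thread or worker.
    self._threads = getattr(self, "_threads", [])
    self._threads.append((worker, thread))
    return worker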
103 changes: 0 additions & 103 deletions src/airunner/aihandler/llm_api.py

This file was deleted.

9 changes: 4 additions & 5 deletions src/airunner/aihandler/mixins/embedding_mixin.py
@@ -1,6 +1,5 @@
 import os
 from airunner.aihandler.logger import Logger as logger
-from airunner.aihandler.enums import EngineResponseCode


 class EmbeddingMixin:
@@ -25,10 +24,10 @@ def load_learned_embed_in_clip(self):
                    self.pipe.load_textual_inversion(path, token=token, weight_name=f)
                except Exception as e:
                    if "already in tokenizer vocabulary" not in str(e):
-                        self.send_message({
-                            "embedding_name": token,
-                            "model_name": self.model,
-                        }, EngineResponseCode.EMBEDDING_LOAD_FAILED)
+                        self.emit("embedding_load_failed_signal", dict(
+                            embedding_name=token,
+                            model_name=self.model,
+                        ))
                        logger.warning(e)
        except AttributeError as e:
            if "load_textual_inversion" in str(e):
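For context, this mixin wraps diffusers' textual-inversion loading. A hedged usage sketch; the model id, path, token, and file name are placeholders, not values from this repo:

# Placeholder ids/paths; illustrates the pipe.load_textual_inversion call used above.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
try:
    pipe.load_textual_inversion(
        "/path/to/embeddings",        # directory holding the embedding weights
        token="<my-style>",           # prompt token the embedding binds to
        weight_name="my_style.pt",    # specific file to load
    )
except Exception as e:
    if "already in tokenizer vocabulary" not in str(e):
        # Mirrors the mixin: surface real failures, tolerate double-loads.
        print(f"embedding load failed: {e}")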
15 changes: 14 additions & 1 deletion src/airunner/aihandler/mixins/merge_mixin.py
@@ -1,11 +1,24 @@
 import os

+from PyQt6.QtCore import pyqtSlot
+
 from airunner.aihandler.logger import Logger

 logger = Logger(prefix="MergeMixin")


 class MergeMixin:
-    def merge_models(self, base_model_path, models_to_merge_path, weights, output_path, name, action):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.register("sd_merge_models_signal", self)
+
+    @pyqtSlot(object)
+    def on_sd_merge_models_signal(self, options):
+        print("TODO: on_sd_merge_models_signal")
+
+    @pyqtSlot(object)
+    def merge_models(self, options):
+        base_model_path, models_to_merge_path, weights, output_path, name, action = options
         from diffusers import (
             StableDiffusionPipeline,
             StableDiffusionInstructPix2PixPipeline,
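The body of merge_models() is truncated above. As a generic illustration of what a weighted checkpoint merge computes, here is plain state-dict interpolation; this is not the project's actual merge implementation:

# Generic weighted merge of two checkpoints' state dicts; illustrative only.
import torch

def merge_state_dicts(base_sd, other_sd, weight=0.5):
    # Linear interpolation over tensors both checkpoints share:
    # merged = (1 - weight) * base + weight * other
    merged = dict(base_sd)
    for key, tensor in other_sd.items():
        if key in merged and merged[key].shape == tensor.shape:
            merged[key] = (1.0 - weight) * merged[key] + weight * tensor
    return merged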
3 changes: 1 addition & 2 deletions src/airunner/aihandler/mixins/scheduler_mixin.py
@@ -119,8 +119,7 @@ def change_scheduler(self):
     def prepare_scheduler(self):
         scheduler_name = self.options.get(f"scheduler", "euler_a")
         if self.scheduler_name != scheduler_name:
-            logger.info(f"Prepare scheduler {scheduler_name}")
-            self.send_message("Preparing scheduler...")
+            self.emit("status_signal", f"Preparing scheduler {scheduler_name}")
             self.scheduler_name = scheduler_name
             self.do_change_scheduler = True
         else:
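prepare_scheduler() only flags that a change is needed; the swap itself is a standard diffusers scheduler replacement. A hedged sketch of that operation, assuming the "Euler a" name maps to EulerAncestralDiscreteScheduler as in the Scheduler enum above:

# Sketch of swapping a diffusers pipeline's scheduler in place.
from diffusers import EulerAncestralDiscreteScheduler

def swap_scheduler(pipe):
    # Rebuild from the current scheduler's config so timestep settings carry over.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe.scheduler.config
    )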
