Code cleanup #393

Merged 3 commits on Jan 24, 2024

Changes from all commits:
3 changes: 2 additions & 1 deletion .gitignore
@@ -21,4 +21,5 @@ src/airunner/extensions/
alembic.ini
src/airunner/scripts/realesrgan
src/airunner/scripts/weights
gfpgan
gfpgan
test.json
104 changes: 9 additions & 95 deletions src/airunner/aihandler/engine.py
@@ -6,42 +6,19 @@
from airunner.mediator_mixin import MediatorMixin
from airunner.workers.tts_generator_worker import TTSGeneratorWorker
from airunner.workers.tts_vocalizer_worker import TTSVocalizerWorker
from airunner.workers.worker import Worker
from airunner.aihandler.llm import LLMGenerateWorker, LLMRequestWorker
from airunner.workers.llm_request_worker import LLMRequestWorker
from airunner.workers.llm_generate_worker import LLMGenerateWorker
from airunner.workers.engine_request_worker import EngineRequestWorker
from airunner.workers.engine_response_worker import EngineResponseWorker
from airunner.workers.sd_generate_worker import SDGenerateWorker
from airunner.workers.sd_request_worker import SDRequestWorker
from airunner.aihandler.logger import Logger
from airunner.aihandler.runner import SDGenerateWorker, SDRequestWorker
from airunner.aihandler.tts import TTS
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.service_locator import ServiceLocator
from airunner.utils import clear_memory


class EngineRequestWorker(Worker):
def __init__(self, prefix="EngineRequestWorker"):
super().__init__(prefix=prefix)
self.register("engine_do_request_signal", self)

def on_engine_do_request_signal(self, request):
self.logger.info("Adding to queue")
self.add_to_queue(request)

def handle_message(self, request):
if request["code"] == EngineRequestCode.GENERATE_IMAGE:
self.emit("sd_request_signal", request)
else:
self.logger.error(f"Unknown code: {request['code']}")


class EngineResponseWorker(Worker):
def __init__(self, prefix="EngineResponseWorker"):
super().__init__(prefix=prefix)
self.register("engine_do_response_signal", self)

def on_engine_do_response_signal(self, request):
self.logger.info("Adding to queue")
self.add_to_queue(request)


class Message:
def __init__(self, *args, **kwargs):
self.name = kwargs.get("name")
@@ -63,11 +40,6 @@ class Engine(QObject, MediatorMixin, SettingsMixin):
llm_loaded: bool = False
sd_loaded: bool = False

# Model controllers
llm_controller = None
stt_controller = None
ocr_controller = None

message = ""
current_message = ""

@@ -130,7 +102,7 @@ def __init__(self, **kwargs):
self.register("EngineResponseWorker_response_signal", self)
self.register("text_generate_request_signal", self)
self.register("image_generate_request_signal", self)
self.register("llm_controller_response_signal", self)
self.register("llm_response_signal", self)
self.register("llm_text_streamed_signal", self)

self.sd_request_worker = self.create_worker(SDRequestWorker)
@@ -148,12 +120,12 @@ def __init__(self, **kwargs):
self.register("tts_request", self)

def on_LLMGenerateWorker_response_signal(self, message:dict):
self.emit("llm_controller_response_signal", message)
self.emit("llm_response_signal", message)

def on_tts_request(self, data: dict):
self.generator_worker.add_to_queue(data)

def on_llm_controller_response_signal(self, message):
def on_llm_response_signal(self, message):
self.do_response(message)

def EngineRequestWorker_handle_default(self, message):
@@ -208,64 +180,6 @@ def do_image_generate_request(self, message):
def request_queue_size(self):
return self.request_worker.queue.qsize()

# def generator_sample(self, data: dict):
# """
# This function will determine if the request
# :param data:
# :return:
# """
# self.logger.info("generator_sample called")
# self.llm_generator_sample(data)
# self.tts_generator_sample(data)
# self.sd_generator_sample(data)

# def llm_generator_sample(self, data: dict):
# if "llm_request" not in data or not self.llm:
# return
# if not self.llm_loaded:
# self.logger.info("Preparing LLM")
# # if self.tts:
# # self.tts_controller.move_model(to_cpu=False)
# self.llm_loaded = True
# do_unload_model = data["request_data"].get("unload_unused_model", False)
# do_move_to_cpu = not do_unload_model and data["request_data"].get("move_unused_model_to_cpu", False)
# if do_move_to_cpu:
# self.move_sd_to_cpu()
# elif do_unload_model:
# self.sd_controller.unload()
# self.logger.info("Engine calling llm.do_generate")
# self.llm_controller.do_generate(data)

# def tts_generator_sample(self, data: dict):
# if "tts_request" not in data or not self.tts:
# return
# self.logger.info("Preparing TTS model...")
# # self.tts_controller.move_model(to_cpu=False)
# signal = data["request_data"].get("signal", None)
# message_object = data["request_data"].get("message_object", None)
# is_bot = data["request_data"].get("is_bot", False)
# first_message = data["request_data"].get("first_message", None)
# last_message = data["request_data"].get("last_message", None)
# if data["request_data"]["tts_settings"]["enable_tts"]:
# text = data["request_data"]["text"]
# # check if ends with a proper sentence ender, if not, add a period
# if not text.endswith((".", "?", "!", "...", "-", "—", )):
# text += "."
# generator = self.tts_controller.add_text(text, "a", data["request_data"]["tts_settings"])
# for success in generator:
# if signal and success:
# signal.emit(message_object, is_bot, first_message, last_message)

# def sd_generator_sample(self, data:dict):
# if "options" not in data or "sd_request" not in data["options"] or not self.sd:
# return
# if not self.sd_loaded:
# self.logger.info("Preparing Stable Diffusion")
# self.sd_loaded = True
# self.do_unload_llm()
# self.logger.info("Engine calling sd.generator_sample")
# self.sd_controller.generator_sample(data)

def do_listen(self):
# self.stt_controller.do_listen()
pass
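
Note: the EngineRequestWorker and EngineResponseWorker classes deleted above are not dropped outright; judging by the new imports, they now live in the dedicated airunner.workers.engine_request_worker and airunner.workers.engine_response_worker modules. A minimal sketch of the relocated request worker, reconstructed from the deleted code (the module contents are an assumption, not confirmed by this diff):

# Presumed shape of src/airunner/workers/engine_request_worker.py,
# reconstructed from the code deleted in this diff.
from airunner.aihandler.enums import EngineRequestCode
from airunner.workers.worker import Worker


class EngineRequestWorker(Worker):
    def __init__(self, prefix="EngineRequestWorker"):
        super().__init__(prefix=prefix)
        # Subscribe to engine requests routed through the mediator
        self.register("engine_do_request_signal", self)

    def on_engine_do_request_signal(self, request):
        self.logger.info("Adding to queue")
        self.add_to_queue(request)

    def handle_message(self, request):
        # Image requests are forwarded to the Stable Diffusion workers;
        # any other request code is an error
        if request["code"] == EngineRequestCode.GENERATE_IMAGE:
            self.emit("sd_request_signal", request)
        else:
            self.logger.error(f"Unknown code: {request['code']}")
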
61 changes: 0 additions & 61 deletions src/airunner/aihandler/llm.py
@@ -13,69 +13,8 @@
from PyQt6.QtCore import QObject

from airunner.aihandler.logger import Logger
from airunner.workers.worker import Worker
from airunner.mediator_mixin import MediatorMixin

class LLMGenerateWorker(Worker):
def __init__(self, prefix="LLMGenerateWorker"):
self.llm = LLM()
super().__init__(prefix=prefix)
self.register("clear_history", self)
self.register("LLMRequestWorker_response_signal", self)
self.register("unload_llm_signal", self)

def on_unload_llm_signal(self, message):
"""
This function will either
1. Leave the LLM on the GPU
2. Move it to the CPU
3. Unload it from memory
The choice is dependent on the current dtype and other settings.
"""
do_unload_model = message.get("do_unload_model", False)
move_unused_model_to_cpu = message.get("move_unused_model_to_cpu", False)
do_move_to_cpu = not do_unload_model and move_unused_model_to_cpu
dtype = message.get("dtype", "")
callback = message.get("callback", None)
if dtype in ["2bit", "4bit", "8bit"]:
do_unload_model = True
do_move_to_cpu = False
if do_move_to_cpu:
self.logger.info("Moving LLM to CPU")
self.llm.move_to_cpu()
elif do_unload_model:
self.llm.unload()
if callback:
callback()

def on_LLMRequestWorker_response_signal(self, message):
self.add_to_queue(message)

def handle_message(self, message):
for response in self.llm.do_generate(message):
self.emit("llm_text_streamed_signal", response)

def on_clear_history(self):
self.llm.clear_history()

def unload_llm(self):
self.llm.unload()


class LLMRequestWorker(Worker):
def __init__(self, prefix="LLMRequestWorker"):
super().__init__(prefix=prefix)
self.register("llm_request_signal", self)

def on_llm_request_signal(self, message):
print("adding llm request to queue", message)
self.add_to_queue(message)

def handle_message(self, message):
super().handle_message(message)


class LLM(QObject, MediatorMixin):
logger = Logger(prefix="LLM")
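
The on_unload_llm_signal handler removed above encodes a small decision table: a quantized model (2bit, 4bit, or 8bit) is always unloaded outright because its weights cannot simply be moved to the CPU; otherwise the message flags choose between moving to CPU, unloading, or leaving the model on the GPU. A standalone sketch of that decision (the function name and return values are illustrative assumptions):

def resolve_unload_action(message: dict) -> str:
    """Mirror the branching of the removed handler (illustrative only)."""
    do_unload_model = message.get("do_unload_model", False)
    do_move_to_cpu = not do_unload_model and message.get("move_unused_model_to_cpu", False)
    if message.get("dtype", "") in ("2bit", "4bit", "8bit"):
        # Quantized weights cannot be relocated to the CPU
        return "unload"
    if do_move_to_cpu:
        return "move_to_cpu"
    if do_unload_model:
        return "unload"
    return "keep_on_gpu"

# Example: a 4bit model is unloaded even when the caller asked for a CPU move.
# resolve_unload_action({"dtype": "4bit", "move_unused_model_to_cpu": True}) == "unload"
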
1 change: 0 additions & 1 deletion src/airunner/aihandler/mixins/merge_mixin.py
@@ -12,7 +12,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.register("sd_merge_models_signal", self)

@pyqtSlot(object)
def on_sd_merge_models_signal(self, options):
print("TODO: on_sd_merge_models_signal")

74 changes: 2 additions & 72 deletions src/airunner/aihandler/runner.py
@@ -1,4 +1,3 @@
import os
import base64
import re
import traceback
@@ -10,7 +9,7 @@
import torch

from PIL import Image, ImageDraw, ImageFont
from PyQt6.QtCore import QObject, pyqtSlot
from PyQt6.QtCore import QObject

from controlnet_aux.processor import Processor
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import \
@@ -24,7 +23,7 @@
from diffusers import ConsistencyDecoderVAE
from transformers import AutoFeatureExtractor

from airunner.aihandler.enums import EngineRequestCode, EngineResponseCode, FilterType
from airunner.aihandler.enums import FilterType
from airunner.aihandler.mixins.compel_mixin import CompelMixin
from airunner.aihandler.mixins.embedding_mixin import EmbeddingMixin
from airunner.aihandler.mixins.lora_mixin import LoraMixin
@@ -41,7 +40,6 @@
from airunner.windows.main.pipeline_mixin import PipelineMixin
from airunner.windows.main.controlnet_model_mixin import ControlnetModelMixin
from airunner.windows.main.ai_model_mixin import AIModelMixin
from airunner.workers.worker import Worker
from airunner.mediator_mixin import MediatorMixin
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.service_locator import ServiceLocator
@@ -50,72 +48,6 @@
torch.backends.cuda.matmul.allow_tf32 = True


class SDRequestWorker(Worker):
def __init__(self, prefix="SDRequestWorker"):
super().__init__(prefix=prefix)
self.register("sd_request_signal", self)

def on_sd_request_signal(self, request):
self.logger.info("Request recieved")
self.add_to_queue(request["message"])

def handle_message(self, message):
self.logger.info("Handling message")
self.emit("add_sd_response_to_queue_signal", dict(
message=message,
image_base_path=self.path_settings["image_path"]
))


class SDGenerateWorker(Worker):
def __init__(self, prefix="SDGenerateWorker"):
super().__init__(prefix=prefix)
self.sd = SDRunner()
self.register("add_sd_response_to_queue_signal", self)

def on_add_sd_response_to_queue_signal(self, request):
self.logger.info("Request recieved")
self.add_to_queue(request)

def handle_message(self, data):
self.logger.info("Generating")
image_base_path = data["image_base_path"]
message = data["message"]
for response in self.sd.generator_sample(message):
print("RESPONSE FROM sd.generate_sample", response)
if not response:
continue

images = response['images']
data = response["data"]
nsfw_content_detected = response["nsfw_content_detected"]
if nsfw_content_detected:
self.emit("nsfw_content_detected_signal", response)
continue

seed = data["options"]["seed"]
updated_images = []
for index, image in enumerate(images):
# hash the prompt and negative prompt along with the action
action = data["action"]
prompt = data["options"]["prompt"][0]
negative_prompt = data["options"]["negative_prompt"][0]
prompt_hash = hash(f"{action}{prompt}{negative_prompt}{index}")
image_name = f"{prompt_hash}_{seed}.png"
image_path = os.path.join(image_base_path, image_name)
# save the image
image.save(image_path)
updated_images.append(dict(
path=image_path,
image=image
))
response["images"] = updated_images
self.emit("engine_do_response_signal", dict(
code=EngineResponseCode.IMAGE_GENERATED,
message=response
))


class SDRunner(
QObject,
MergeMixin,
@@ -1409,7 +1341,6 @@ def callback(self, step: int, _time_step, latents):
}
self.emit("progress_signal", res)

@pyqtSlot(object)
def on_unload_stablediffusion_signal(self):
self.unload()

@@ -1496,7 +1427,6 @@ def generator_sample(self, data: dict):
self._current_model = ""
self.local_files_only = True

@pyqtSlot(object)
def on_sd_cancel_signal(self):
self.do_cancel = True

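
As with the engine workers, the SDRequestWorker and SDGenerateWorker classes deleted above match the new airunner.workers.sd_request_worker and airunner.workers.sd_generate_worker imports in engine.py. One behavior worth noting in the deleted handler is the deterministic image filename, built from a hash of the action, prompts, and image index plus the seed; a sketch of that scheme (the helper name is an assumption):

import os

def image_save_path(base_path, action, prompt, negative_prompt, index, seed):
    # Mirrors the naming in the deleted SDGenerateWorker.handle_message.
    # Note: Python salts str hashes per process (PYTHONHASHSEED), so these
    # names are unique within a run but not stable across runs.
    prompt_hash = hash(f"{action}{prompt}{negative_prompt}{index}")
    return os.path.join(base_path, f"{prompt_hash}_{seed}.png")
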
4 changes: 1 addition & 3 deletions src/airunner/aihandler/tts.py
@@ -1,11 +1,9 @@
import time
import torch
import sounddevice as sd
import numpy as np

from queue import Queue

from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot, QThread
from PyQt6.QtCore import QObject, pyqtSlot

from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, BarkModel, BarkProcessor
from datasets import load_dataset