Merge pull request #392 from Capsize-Games/develop
Removes chat templates (unused)
w4ffl35 authored Jan 24, 2024
2 parents 0ce6077 + c059a3b commit 9d69ccc
Showing 6 changed files with 89 additions and 124 deletions.
12 changes: 0 additions & 12 deletions src/airunner/aihandler/chat_templates/conversation.j2

This file was deleted.

91 changes: 26 additions & 65 deletions src/airunner/aihandler/engine.py
@@ -1,20 +1,19 @@
import torch
import traceback
import gc

from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
from PyQt6.QtCore import QObject, pyqtSignal
from airunner.aihandler.enums import EngineRequestCode, EngineResponseCode
from airunner.aihandler.logger import Logger
from airunner.mediator_mixin import MediatorMixin
from airunner.workers.tts_generator_worker import TTSGeneratorWorker
from airunner.workers.tts_vocalizer_worker import TTSVocalizerWorker
from airunner.workers.worker import Worker
from airunner.aihandler.llm import LLMController
from airunner.aihandler.llm import LLMGenerateWorker, LLMRequestWorker
from airunner.aihandler.logger import Logger
from airunner.aihandler.runner import SDGenerateWorker, SDRequestWorker
from airunner.aihandler.tts import TTS
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.service_locator import ServiceLocator
from airunner.utils import clear_memory


class EngineRequestWorker(Worker):
@@ -79,17 +78,14 @@ def do_response(self, response):
"""
self.response_worker.add_to_queue(response)

@pyqtSlot(object)
def on_engine_cancel_signal(self, _ignore):
self.logger.info("Canceling")
self.emit("sd_cancel_signal")
self.request_worker.cancel()

@pyqtSlot(object)
def on_engine_stop_processing_queue_signal(self):
self.do_process_queue = False

@pyqtSlot(object)
def on_engine_start_processing_queue_signal(self):
self.do_process_queue = True

@@ -103,7 +99,6 @@ def on_hear_signal(self, message):
def handle_generate_caption(self, message):
pass

@pyqtSlot(object)
def on_caption_generated_signal(self, message):
print("TODO: caption generated signal", message)

@@ -118,7 +113,6 @@ def __init__(self, **kwargs):
self.clear_memory()

# Initialize Controllers
self.llm_controller = LLMController(engine=self)
#self.stt_controller = STTController(engine=self)
# self.ocr_controller = ImageProcessor(engine=self)
self.tts_controller = TTS(engine=self)
@@ -147,9 +141,15 @@ def __init__(self, **kwargs):

self.generator_worker = self.create_worker(TTSGeneratorWorker)
self.vocalizer_worker = self.create_worker(TTSVocalizerWorker)

self.request_worker = self.create_worker(LLMRequestWorker)
self.generate_worker = self.create_worker(LLMGenerateWorker)

self.register("tts_request", self)

@pyqtSlot(dict)
def on_LLMGenerateWorker_response_signal(self, message:dict):
self.emit("llm_controller_response_signal", message)

def on_tts_request(self, data: dict):
self.generator_worker.add_to_queue(data)

@@ -175,34 +175,33 @@ def on_EngineResponseWorker_response_signal(self, response:dict):
if code == EngineResponseCode.IMAGE_GENERATED:
self.emit("image_generated_signal", response["message"])

@pyqtSlot()
def on_clear_memory_signal(self):
self.clear_memory()

@pyqtSlot(object)
def on_llm_text_streamed_signal(self, data):
self.do_tts_request(data["message"], data["is_end_of_message"])
self.emit("add_bot_message_to_conversation", data)

@pyqtSlot(object)
def on_sd_image_generated_signal(self, message):
self.emit("image_generated_signal", message)

@pyqtSlot(object)
def on_text_generate_request_signal(self, message):
self.move_sd_to_cpu()
self.llm_controller.do_request(message)
self.emit("llm_request_signal", message)

@pyqtSlot(object)
def on_image_generate_request_signal(self, message):
self.logger.info("on_image_generate_request_signal received")
# self.unload_llm(
# message,
# self.memory_settings["unload_unused_models"],
# self.memory_settings["move_unused_model_to_cpu"]
# )
self.emit("unload_llm_signal", dict(
do_unload_model=self.memory_settings["unload_unused_models"],
move_unused_model_to_cpu=self.memory_settings["move_unused_model_to_cpu"],
dtype=self.llm_generator_settings["dtype"],
callback=lambda _message=message: self.do_image_generate_request(_message)
))

def do_image_generate_request(self, message):
self.clear_memory()
self.emit("engine_do_request_signal", dict(
code=EngineRequestCode.GENERATE_IMAGE,
code=EngineRequestCode.GENERATE_IMAGE,
message=message
))

@@ -292,57 +291,19 @@ def do_tts_request(self, message: str, is_end_of_message: bool=False):
is_end_of_message=is_end_of_message,
))

def clear_memory(self, *args, **kwargs):
"""
Clear the GPU ram.
"""
self.logger.info("Clearing memory")
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()

def on_clear_llm_history_signal(self):
if self.llm:
self.llm_controller.clear_history()
self.emit("clear_history")

def stop(self):
self.logger.info("Stopping")
self.request_worker.stop()
self.response_worker.stop()

def unload_llm(self, request_data: dict, do_unload_model: bool, move_unused_model_to_cpu: bool):
"""
This function will either
1. Leave the LLM on the GPU
2. Move it to the CPU
3. Unload it from memory
The choice is dependent on the current dtype and other settings.
"""
do_move_to_cpu = not do_unload_model and move_unused_model_to_cpu

if request_data:
# First check the dtype
dtype = self.llm_generator_settings["dtype"]
if dtype in ["2bit", "4bit", "8bit"]:
do_unload_model = True
do_move_to_cpu = False

if do_move_to_cpu:
self.logger.info("Moving LLM to CPU")
self.llm_controller.move_to_cpu()
self.clear_memory()
# elif do_unload_model:
# self.do_unload_llm()

def do_unload_llm(self):
self.logger.info("Unloading LLM")
self.llm_controller.do_unload_llm()
#self.clear_memory()

def move_sd_to_cpu(self):
if ServiceLocator.get("is_pipe_on_cpu")() or not ServiceLocator.get("has_pipe")():
return
self.emit("move_pipe_to_cpu_signal")
self.clear_memory()
self.clear_memory()

def clear_memory(self):
clear_memory()
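
The deleted Engine.clear_memory above called the CUDA APIs directly; the new code delegates to a shared clear_memory helper imported from airunner.utils, which is not shown in this diff. Judging from the method it replaces, that helper presumably looks roughly like the following sketch (a guess, not code from the commit):

import gc
import torch

def clear_memory(*args, **kwargs):
    # Free cached CUDA allocations and run garbage collection, mirroring the
    # removed Engine.clear_memory method; guarded so it also works without a GPU.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()
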
83 changes: 45 additions & 38 deletions src/airunner/aihandler/llm.py
@@ -21,61 +21,62 @@ def __init__(self, prefix="LLMGenerateWorker"):
self.llm = LLM()
super().__init__(prefix=prefix)
self.register("clear_history", self)
self.register("LLMRequestWorker_response_signal", self)
self.register("unload_llm_signal", self)

def on_unload_llm_signal(self, message):
"""
This function will either
1. Leave the LLM on the GPU
2. Move it to the CPU
3. Unload it from memory
The choice is dependent on the current dtype and other settings.
"""
do_unload_model = message.get("do_unload_model", False)
move_unused_model_to_cpu = message.get("move_unused_model_to_cpu", False)
do_move_to_cpu = not do_unload_model and move_unused_model_to_cpu
dtype = message.get("dtype", "")
callback = message.get("callback", None)
if dtype in ["2bit", "4bit", "8bit"]:
do_unload_model = True
do_move_to_cpu = False
if do_move_to_cpu:
self.logger.info("Moving LLM to CPU")
self.llm.move_to_cpu()
elif do_unload_model:
self.llm.unload()
if callback:
callback()

def on_LLMRequestWorker_response_signal(self, message):
self.add_to_queue(message)

def handle_message(self, message):
for response in self.llm.do_generate(message):
self.emit("llm_text_streamed_signal", response)

def on_clear_history(self):
self.llm.clear_history()

def unload_llm(self):
self.llm.unload()


class LLMRequestWorker(Worker):
def __init__(self, prefix="LLMRequestWorker"):
super().__init__(prefix=prefix)
self.register("llm_request_signal", self)

def on_llm_request_signal(self, message):
print("adding llm request to queue", message)
self.add_to_queue(message)

def handle_message(self, message):
super().handle_message(message)


class LLMController(QObject, MediatorMixin):

def __init__(self, *args, **kwargs):
MediatorMixin.__init__(self)
self.engine = kwargs.pop("engine")
super().__init__(*args, **kwargs)
self.logger = Logger(prefix="LLMController")

self.request_worker = self.create_worker(LLMRequestWorker)
self.generate_worker = self.create_worker(LLMGenerateWorker)
self.register("LLMRequestWorker_response_signal", self)
self.register("LLMGenerateWorker_response_signal", self)

def pause(self):
self.request_worker.pause()
self.generate_worker.pause()

def resume(self):
self.request_worker.resume()
self.generate_worker.resume()

def do_request(self, message):
self.request_worker.add_to_queue(message)

def clear_history(self):
self.emit("clear_history")

def on_LLMRequestWorker_response_signal(self, message):
self.generate_worker.add_to_queue(message)

def on_LLMGenerateWorker_response_signal(self, message:dict):
self.emit("llm_controller_response_signal", message)

def do_unload_llm(self):
self.generate_worker.llm.unload_model()
self.generate_worker.llm.unload_tokenizer()


class LLM(QObject, MediatorMixin):
logger = Logger(prefix="LLM")
dtype = ""
@@ -254,6 +255,12 @@ def load_processor(self, local_files_only = None):
def clear_history(self):
self.history = []

def unload(self):
self.unload_model()
self.unload_tokenizer()
self.unload_processor()
self._processing_request = False

def unload_tokenizer(self):
self.logger.info("Unloading tokenizer")
self.tokenizer = None
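
The new on_unload_llm_signal handler replaces the engine-side unload_llm method: it either leaves the LLM on the GPU, moves it to the CPU, or unloads it, forcing a full unload for quantized dtypes, and then invokes the callback supplied with the signal. A minimal self-contained sketch of that decision (a hypothetical helper, not code from this commit):

def resolve_unload_action(do_unload_model: bool, move_unused_model_to_cpu: bool, dtype: str) -> str:
    # Quantized weights ("2bit", "4bit", "8bit") cannot simply be moved to the
    # CPU, so the handler forces a full unload in that case.
    if dtype in ("2bit", "4bit", "8bit"):
        return "unload"
    if not do_unload_model and move_unused_model_to_cpu:
        return "move_to_cpu"
    return "unload" if do_unload_model else "keep_on_gpu"
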
16 changes: 8 additions & 8 deletions src/airunner/aihandler/runner.py
@@ -45,6 +45,7 @@
from airunner.mediator_mixin import MediatorMixin
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.service_locator import ServiceLocator
from airunner.utils import clear_memory

torch.backends.cuda.matmul.allow_tf32 = True

@@ -81,6 +82,7 @@ def handle_message(self, data):
image_base_path = data["image_base_path"]
message = data["message"]
for response in self.sd.generator_sample(message):
print("RESPONSE FROM sd.generate_sample", response)
if not response:
continue

@@ -1117,6 +1119,7 @@ def call_pipe(self, **kwargs):
return self.call_pipe_txt2vid(**args)

if not self.is_outpaint and not self.is_vid_action and not self.is_upscale:
self.latents = self.latents.to(self.device)
args["latents"] = self.latents

args["clip_skip"] = self.clip_skip
@@ -1413,7 +1416,7 @@ def on_unload_stablediffusion_signal(self):
def unload(self):
self.unload_model()
self.unload_tokenizer()
self.clear_memory()
clear_memory()

def unload_model(self):
self.pipe = None
@@ -1434,7 +1437,7 @@ def process_upscale(self, data: dict):
denoise_strength=self.options.get("denoise_strength", 0.5),
face_enhance=self.options.get("face_enhance", True),
).run()
self.clear_memory()
clear_memory()
else:
self.log_error("No image found, unable to upscale")
# check if results is a list
@@ -1476,7 +1479,7 @@ def generator_sample(self, data: dict):
error = e
if "PYTORCH_CUDA_ALLOC_CONF" in str(e):
error = self.cuda_error_message
self.clear_memory()
clear_memory()
self.reset_applied_memory_settings()
else:
error_message = f"Error during generation"
@@ -1574,12 +1577,9 @@ def unload_unused_models(self):
val.to("cpu")
setattr(self, action, None)
del val
self.clear_memory()
clear_memory()
self.reset_applied_memory_settings()

def clear_memory(self):
self.emit("clear_memory_signal")

def load_model(self):
self.logger.info("Loading model")
self.torch_compile_applied = False
@@ -1733,7 +1733,7 @@ def download_from_original_stable_diffusion_ckpt(self, path, local_files_only=No
def clear_controlnet(self):
self.logger.info("Clearing controlnet")
self._controlnet = None
self.clear_memory()
clear_memory()
self.reset_applied_memory_settings()
self.controlnet_loaded = False

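
The added line in call_pipe moves the cached latents onto the pipeline's device before they are handed to the diffusion call; latents left on the CPU would typically raise a device-mismatch error once the pipeline runs on CUDA. A minimal illustration of the pattern (tensor shape and device picked for the example only):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
latents = torch.randn(1, 4, 64, 64)   # latents start out on the CPU
latents = latents.to(device)          # mirrors self.latents.to(self.device)
assert latents.device.type == device
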
