better handling of content generation #380

Merged (1 commit) on Jan 17, 2024
158 changes: 61 additions & 97 deletions src/airunner/aihandler/engine.py
@@ -10,9 +10,9 @@
 from airunner.workers.worker import Worker
 from airunner.aihandler.enums import EngineRequestCode, EngineResponseCode
 from airunner.aihandler.image_processor import ImageProcessor
-from airunner.aihandler.llm import LLMEngine
+from airunner.aihandler.llm import LLMController
 from airunner.aihandler.logger import Logger
-from airunner.aihandler.runner import SDRunner
+from airunner.aihandler.runner import SDController
 from airunner.aihandler.speech_to_text import SpeechToText
 from airunner.aihandler.tts import TTS
 
@@ -34,17 +34,19 @@ class Engine(QObject):
     # Signals
     request_signal_status = pyqtSignal(str)
     hear_signal = pyqtSignal(str)
+    text_generated_signal = pyqtSignal(dict)
+    image_generated_signal = pyqtSignal(dict)
 
     # Loaded flags
     llm_loaded: bool = False
     sd_loaded: bool = False
 
     # Model controllers
-    llm = None
-    sd = None
-    tts = None
-    stt = None
-    ocr = None
+    llm_controller = None
+    sd_controller = None
+    tts_controller = None
+    stt_controller = None
+    ocr_controller = None
 
     # Message properties for EngineResponseCode.TEXT_STREAMED
     message = ""
@@ -80,7 +82,7 @@ def do_response(self, response):
 
     def cancel(self):
         self.logger.info("Canceling")
-        self.sd.cancel()
+        self.sd_controller.cancel()
         self.request_worker.cancel()
 
     # END OFFLINE CLIENT
@@ -135,19 +137,20 @@ def response_worker_response_signal_slot(self, message):
             EngineResponseCode.TEXT_STREAMED: self.handle_text_streamed,
             EngineResponseCode.IMAGE_GENERATED: self.handle_image_generated,
             EngineResponseCode.CAPTION_GENERATED: self.handle_caption_generated,
+            EngineResponseCode.CLEAR_MEMORY: self.clear_memory
         }.get(message["code"], self.handle_default_response)(message["message"], message["code"])
 
     def handle_generate_text(self, message):
         self.move_sd_to_cpu()
-        self.llm.do_request(message["message"])
+        self.llm_controller.do_request(message["message"])
 
     def handle_generate_image(self, message):
         self.unload_llm(
             message,
-            self.app.settings["memory_settings"]["unload_unused_model"],
+            self.app.settings["memory_settings"]["unload_unused_models"],
             self.app.settings["memory_settings"]["move_unused_model_to_cpu"]
         )
-        self.sd.generator_sample(message)
+        self.sd_controller.do_request(message["message"])
 
     def handle_generate_caption(self, message):
         pass
@@ -162,29 +165,26 @@ def handle_text_streamed(self, message, code):
         self.current_message = self.current_message.replace("</s>", "")
         # check if sentence enders are in self.current_message
         is_end_of_message = "</s>" in message
-        self.tts.add_text(message.replace("</s>", ""), is_end_of_message=is_end_of_message)
-        self.app.message_handler_signal.emit(dict(
-            code=EngineResponseCode.ADD_TO_CONVERSATION,
-            message=dict(
-                name=self.app.settings["llm_generator_settings"]["botname"],
-                text=message.replace("</s>", ""),
-                is_bot=True,
-                first_message=self.first_message,
-                last_message=is_end_of_message
-            )
+        self.tts_controller.add_text(message.replace("</s>", ""), is_end_of_message=is_end_of_message)
+        self.text_generated_signal.emit(dict(
+            name=self.app.settings["llm_generator_settings"]["botname"],
+            text=message.replace("</s>", ""),
+            is_bot=True,
+            first_message=self.first_message,
+            last_message=is_end_of_message
         ))
         self.first_message = False
         if is_end_of_message:
             self.first_message = True
             self.message = ""
             self.current_message = ""
 
-        # self.stt.do_listen()
+        # self.stt_controller.do_listen()
 
-    def handle_image_generated(self, message):
-        self.send_message(message, code)
+    def handle_image_generated(self, message, code):
+        self.image_generated_signal.emit(message)
 
-    def handle_caption_generated(self, message):
+    def handle_caption_generated(self, message, code):
         self.send_message(message, code)
 
     def __init__(self, **kwargs):
@@ -193,13 +193,23 @@ def __init__(self, **kwargs):
         self.app = kwargs.get("app", None)
         self.message_handler = kwargs.get("message_handler", None)
         self.clear_memory()
-        self.initialize_llm() # Large language model
-        self.initialize_sd() # Art model
-        self.initialize_tts() # Text to speech model (voice)
-        self.initialize_stt() # Speech to text model (ears)
-        # self.initialize_ocr() # Vision to text model (eyes)
 
-        self.llm.response_signal.connect(self.do_response)
+        # Initialize Controllers
+        self.llm_controller = LLMController(engine=self)
+        self.sd_controller = SDController(engine=self)
+        #self.stt_controller = SpeechToText(engine=self, hear_signal=self.hear_signal, duration=10.0, fs=16000)
+        #self.hear_signal.connect(self.hear)
+        # self.listen_thread = threading.Thread(target=self.stt_controller.listen)
+        # self.listen_thread.start()
+
+        self.tts_controller = TTS(engine=self)
+        #self.tts_thread = threading.Thread(target=self.tts_controller.run)
+        #self.tts_thread.start()
+
+        # self.ocr_controller = ImageProcessor(engine=self)
+
+        self.llm_controller.response_signal.connect(self.do_response)
+        self.sd_controller.response_signal.connect(self.do_response)
 
         # Request worker and thread
         self.request_worker = Worker(prefix="RequestWorker")
@@ -222,8 +232,8 @@ def handle_default(self, message):
     def handle_default(self, message):
         self.logger.error(f"Unknown code: {message['code']}")
 
-    def handle_default_response(self, message):
-        self.logger.error(f"handle_default_response Unknown code: {message['code']}")
+    def handle_default_response(self, message, code):
+        self.app.send_message(code, message)
 
     def request_queue_size(self):
         return self.request_worker.queue.qsize()
@@ -237,51 +247,6 @@ def send_message(self, message, code=None):
             code=code,
             message=message
         ))
-
-    def initialize_llm(self):
-        """
-        Initialize the LLM.
-        """
-        self.llm = LLMEngine(app=self.app, engine=self)
-
-    def initialize_sd(self):
-        """
-        Initialize Stable Diffusion.
-        """
-        self.sd = SDRunner(
-            app=self.app,
-            message_handler=self.message_handler,
-            engine=self
-        )
-
-    def initialize_stt(self):
-        """
-        Initialize speech to text.
-        """
-        self.stt = SpeechToText(
-            hear_signal=self.hear_signal,
-            engine=self,
-            duration=10.0,
-            fs=16000
-        )
-        self.hear_signal.connect(self.hear)
-        # self.listen_thread = threading.Thread(target=self.stt.listen)
-        # self.listen_thread.start()
-
-    def initialize_tts(self):
-        """
-        Initialize text to speech.
-        """
-        tts_settings = self.app.settings["tts_settings"]
-        self.tts = TTS(engine=self)
-        self.tts_thread = threading.Thread(target=self.tts.run)
-        self.tts_thread.start()
-
-    def initialize_ocr(self):
-        """
-        Initialize vision to text.
-        """
-        self.ocr = ImageProcessor(engine=self)
 
     # def generator_sample(self, data: dict):
     # """
@@ -300,22 +265,22 @@ def initialize_ocr(self):
     #     if not self.llm_loaded:
    #         self.logger.info("Preparing LLM")
    #         # if self.tts:
-    #         #     self.tts.move_model(to_cpu=False)
+    #         #     self.tts_controller.move_model(to_cpu=False)
    #         self.llm_loaded = True
    #         do_unload_model = data["request_data"].get("unload_unused_model", False)
    #         do_move_to_cpu = not do_unload_model and data["request_data"].get("move_unused_model_to_cpu", False)
    #         if do_move_to_cpu:
    #             self.move_sd_to_cpu()
    #         elif do_unload_model:
-    #             self.sd.unload()
+    #             self.sd_controller.unload()
    #         self.logger.info("Engine calling llm.do_generate")
-    #         self.llm.do_generate(data)
+    #         self.llm_controller.do_generate(data)
 
    # def tts_generator_sample(self, data: dict):
    #     if "tts_request" not in data or not self.tts:
    #         return
    #     self.logger.info("Preparing TTS model...")
-    #     # self.tts.move_model(to_cpu=False)
+    #     # self.tts_controller.move_model(to_cpu=False)
    #     signal = data["request_data"].get("signal", None)
    #     message_object = data["request_data"].get("message_object", None)
    #     is_bot = data["request_data"].get("is_bot", False)
@@ -326,7 +291,7 @@ def initialize_ocr(self):
    #     # check if ends with a proper sentence ender, if not, add a period
    #     if not text.endswith((".", "?", "!", "...", "-", "—", )):
    #         text += "."
-    #     generator = self.tts.add_text(text, "a", data["request_data"]["tts_settings"])
+    #     generator = self.tts_controller.add_text(text, "a", data["request_data"]["tts_settings"])
    #     for success in generator:
    #         if signal and success:
    #             signal.emit(message_object, is_bot, first_message, last_message)
@@ -339,23 +304,23 @@ def initialize_ocr(self):
    #         self.sd_loaded = True
    #         self.do_unload_llm()
    #     self.logger.info("Engine calling sd.generator_sample")
-    #     self.sd.generator_sample(data)
+    #     self.sd_controller.generator_sample(data)
 
     def do_listen(self):
-        # self.stt.do_listen()
+        # self.stt_controller.do_listen()
         pass
 
     def cancel(self):
         """
         Cancel Stable Diffusion request.
         """
-        self.sd.cancel()
+        self.sd_controller.cancel()
 
     def unload_stablediffusion(self):
         """
         Unload the Stable Diffusion model from memory.
         """
-        self.sd.unload()
+        self.sd_controller.unload()
 
     def parse_message(self, message):
         if message:
@@ -385,9 +350,9 @@ def handle_tts(self, message: str):
         #         tts_settings=self.app.settings["tts_settings"]
         #     )
         # )
-        self.tts.add_text(message)
+        self.tts_controller.add_text(message)
 
-    def clear_memory(self):
+    def clear_memory(self, *args, **kwargs):
         """
         Clear the GPU ram.
         """
@@ -398,13 +363,13 @@ def clear_memory(self):
 
     def clear_llm_history(self):
         if self.llm:
-            self.llm.clear_history()
+            self.llm_controller.clear_history()
 
     def stop(self):
         self.logger.info("Stopping")
         self.request_worker.stop()
         self.response_worker.stop()
-        self.stt.stop()
+        #self.stt_controller.stop()
 
     def unload_llm(self, request_data: dict, do_unload_model: bool, move_unused_model_to_cpu: bool):
         """
@@ -427,19 +392,18 @@ def unload_llm(self, request_data: dict, do_unload_model: bool, move_unused_model_to_cpu: bool):
 
         if do_move_to_cpu:
             self.logger.info("Moving LLM to CPU")
-            self.llm.move_to_cpu()
+            self.llm_controller.move_to_cpu()
             self.clear_memory()
         elif do_unload_model:
             self.do_unload_llm()
 
     def do_unload_llm(self):
         self.logger.info("Unloading LLM")
-        self.llm.unload_model()
-        self.llm.unload_tokenizer()
+        self.llm_controller.do_unload_llm()
         self.clear_memory()
 
     def move_sd_to_cpu(self):
-        if self.sd.is_pipe_on_cpu or not self.sd.has_pipe:
+        if self.sd_controller.is_pipe_on_cpu or not self.sd_controller.has_pipe:
             return
-        self.sd.move_pipe_to_cpu()
+        self.sd_controller.move_pipe_to_cpu()
         self.clear_memory()
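
Note for reviewers: the diff above replaces the generic app.message_handler_signal emission with two dedicated Qt signals. Below is a minimal sketch of how an application-side consumer might subscribe to them. EngineClient and its slot names are hypothetical illustrations; only text_generated_signal, image_generated_signal, and the payload keys come from this PR, and PyQt6 is assumed as the binding.

# Hypothetical consumer of the new Engine signals; illustrative only,
# not part of this PR.
from PyQt6.QtCore import QObject, pyqtSlot


class EngineClient(QObject):
    def __init__(self, engine):
        super().__init__()
        # Subscribe to the per-content-type signals added in this PR,
        # instead of filtering a single message_handler_signal stream.
        engine.text_generated_signal.connect(self.on_text_generated)
        engine.image_generated_signal.connect(self.on_image_generated)

    @pyqtSlot(dict)
    def on_text_generated(self, data: dict):
        # Payload keys mirror the dict emitted in handle_text_streamed:
        # name, text, is_bot, first_message, last_message.
        print(f"{data['name']}: {data['text']}")

    @pyqtSlot(dict)
    def on_image_generated(self, data: dict):
        # Raw message forwarded by handle_image_generated.
        print("image generated:", data)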
2 changes: 2 additions & 0 deletions src/airunner/aihandler/enums.py
@@ -18,6 +18,8 @@ class EngineResponseCode(Enum):
     TEXT_STREAMED = 701
     CAPTION_GENERATED = 800
     ADD_TO_CONVERSATION = 900
+    CLEAR_MEMORY = 1000
+    NSFW_CONTENT_DETECTED = 1100
 
 
 class EngineRequestCode(Enum):
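
The two new codes extend the dispatch map shown in engine.py above: CLEAR_MEMORY is wired directly to Engine.clear_memory (which is why that method now accepts *args and **kwargs), while NSFW_CONTENT_DETECTED has no dedicated handler in this PR and falls through to handle_default_response, which forwards the code and message to the application. A self-contained sketch of that dispatch pattern, illustrative rather than code from this PR:

# Standalone sketch of the handler-map dispatch used in
# Engine.response_worker_response_signal_slot; names below are
# reproduced from the diff, the dispatch() helper is hypothetical.
from enum import Enum


class EngineResponseCode(Enum):
    TEXT_STREAMED = 701
    CAPTION_GENERATED = 800
    ADD_TO_CONVERSATION = 900
    CLEAR_MEMORY = 1000
    NSFW_CONTENT_DETECTED = 1100


def dispatch(message: dict, handlers: dict, default) -> None:
    # Look up a handler by code; unknown codes fall back to the default,
    # mirroring .get(message["code"], self.handle_default_response).
    handlers.get(message["code"], default)(message["message"], message["code"])


def default_handler(message, code):
    # Stand-in for handle_default_response, which forwards to the app.
    print(f"forwarding unhandled engine response {code}: {message}")


dispatch(
    dict(code=EngineResponseCode.NSFW_CONTENT_DETECTED, message="blocked image"),
    handlers={},
    default=default_handler,
)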