Skip to content

Commit

Permalink
Fixes #891 - allows LLM to generate images
Browse files Browse the repository at this point in the history
  • Loading branch information
w4ffl35 committed Oct 3, 2024
1 parent 5f72f6b commit 6709395
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 82 deletions.
10 changes: 9 additions & 1 deletion src/airunner/aihandler/stablediffusion/sd_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,10 @@ def unload_stable_diffusion(self):
clear_memory()
self.change_model_status(ModelType.SD, ModelStatus.UNLOADED)

def handle_generate_signal(self, message: dict):
def handle_generate_signal(self, message: dict=None):
if self._model_status is not ModelStatus.LOADED:
self.load_stable_diffusion()

if self._current_state not in (
HandlerState.GENERATING,
HandlerState.PREPARING_TO_GENERATE
Expand All @@ -319,6 +322,11 @@ def handle_generate_signal(self, message: dict):
except Exception as e:
self.logger.error(f"Error generating image: {e}")
response = None
print("HANDLE GENERATE SIGNAL", message)
if message is not None:
callback = message.get("callback", None)
if callback:
callback(message)
self.emit_signal(SignalCode.ENGINE_RESPONSE_WORKER_RESPONSE_SIGNAL, {
'code': EngineResponseCode.IMAGE_GENERATED,
'message': response
Expand Down
4 changes: 3 additions & 1 deletion src/airunner/aihandler/tts/speecht5_tts_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ def speaker_embeddings_path(self):
)

def generate(self, message):
if self._do_interrupt or self._paused:
if self._model_status is not ModelStatus.LOADED:
return None

if self._do_interrupt or self._paused:
return None
try:
return self._do_generate(message)
except torch.cuda.OutOfMemoryError:
Expand Down
2 changes: 2 additions & 0 deletions src/airunner/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ class SignalCode(Enum):
QUIT_APPLICATION = "quit"
TOGGLE_FULLSCREEN_SIGNAL = "fullscreen_signal"
TOGGLE_TTS_SIGNAL = "toggle_tts_signal"
TOGGLE_SD_SIGNAL = "toggle_sd_signal"
TOGGLE_LLM_SIGNAL = "toggle_llm_signal"
START_AUTO_IMAGE_GENERATION_SIGNAL = "start_auto_image_generation_signal"
STOP_AUTO_IMAGE_GENERATION_SIGNAL = "stop_auto_image_generation_signal"
LINES_UPDATED_SIGNAL = "lines_updated_signal"
Expand Down
3 changes: 3 additions & 0 deletions src/airunner/widgets/canvas/custom_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,9 @@ def on_image_generated_signal(self, response):
else:
self.logger.error(f"Unhandled response code: {code}")
self.emit_signal(SignalCode.APPLICATION_STOP_SD_PROGRESS_BAR_SIGNAL)
callback = response.get("callback", None)
if callback:
callback(response)

def on_canvas_clear_signal(self):
self.update_current_settings("image", None)
Expand Down
125 changes: 116 additions & 9 deletions src/airunner/widgets/generator_form/generator_form_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,36 @@ def on_stop_image_generator_progress_bar_signal(self, _data):
def on_progress_signal(self, message):
self.handle_progress_bar(message)

##########################################################################
# LLM Generated Image handlers
##########################################################################
def on_llm_image_prompt_generated_signal(self, data):
"""
This slot is called after an LLM has generated the prompts for an image.
It sets the prompts in the generator form UI and continues the image generation process.
"""

# Send a message to the user, as the chatbot, letting them know that the image is generating
self.emit_signal(
SignalCode.APPLICATION_ADD_BOT_MESSAGE_TO_CONVERSATION,
dict(
message="Your image is generating...",
is_first_message=True,
is_end_of_message=True,
name=self.chatbot.name
)
)

# Unload the LLM
if self.application_settings.llm_enabled:
self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict(
callback=self.unload_llm_callback
))

# Set the prompts in the generator form UI
data = self.extract_json_from_message(data["message"])
prompt = data.get("description", None)
secondary_prompt = data.get("composition", None)
prompt = data.get("composition", None)
secondary_prompt = data.get("description", None)
prompt_type = data.get("type", ImageCategory.PHOTO.value)
if prompt_type == "photo":
negative_prompt = PHOTO_REALISTIC_NEGATIVE_PROMPT
Expand All @@ -182,7 +208,78 @@ def on_llm_image_prompt_generated_signal(self, data):
self.ui.negative_prompt.setPlainText(negative_prompt)
self.ui.secondary_prompt.setPlainText(secondary_prompt)
self.ui.secondary_negative_prompt.setPlainText(negative_prompt)
self.handle_generate_button_clicked()

def unload_llm_callback(self, data: dict = None):
    """
    Continue the LLM-driven image flow; used as the callback of TOGGLE_LLM_SIGNAL.

    Either starts image generation directly (SD already enabled) or first
    emits TOGGLE_SD_SIGNAL and lets generation start from its callback.

    :param data: payload handed to the callback; not read by this method.
    """
    on_image_done = self.finalize_image_generated_by_llm

    if self.application_settings.sd_enabled:
        # SD is already enabled: start generating right away. The "finalize"
        # entry carries the function to run once the image has been generated.
        self.handle_generate_button_clicked(dict(
            finalize=on_image_done
        ))
        return

    # SD is not enabled: toggle it and hand over the generate entry point as
    # the callback, together with the finalize function.
    # NOTE(review): presumably the TOGGLE_SD_SIGNAL handler invokes "callback"
    # with this payload once SD is loaded -- confirm against that handler.
    self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict(
        callback=self.handle_generate_button_clicked,
        finalize=on_image_done
    ))

def finalize_image_generated_by_llm(self, data):
    """
    Tell the user, as the chatbot, that the LLM-requested image is done.

    Depending on the current sd_enabled / llm_enabled settings, SD and/or the
    LLM are toggled first; the chat message is emitted from the last callback
    of that chain (or immediately when nothing needs toggling).

    :param data: payload handed to the callback; not read by this method.
    """
    # Chat message shown to the user once everything has been toggled.
    image_generated_message = dict(
        message="Your image has been generated",
        is_first_message=True,
        is_end_of_message=True,
        name=self.chatbot.name
    )

    def announce(_payload):
        # Final step: post the "image generated" message to the conversation.
        self.emit_signal(
            SignalCode.APPLICATION_ADD_BOT_MESSAGE_TO_CONVERSATION,
            image_generated_message
        )

    def toggle_llm_then_announce(_payload):
        # Toggle the LLM and post the message from its callback.
        self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict(
            callback=announce
        ))

    sd_enabled = self.application_settings.sd_enabled
    llm_enabled = self.application_settings.llm_enabled

    if sd_enabled:
        # SD is enabled: toggle it (presumably unloading it -- confirm against
        # the TOGGLE_SD_SIGNAL handler), then bring the LLM back if it is
        # currently disabled before announcing.
        next_step = announce if llm_enabled else toggle_llm_then_announce
        self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict(
            callback=next_step
        ))
    elif not llm_enabled:
        # SD is disabled and so is the LLM: toggle the LLM, then announce.
        toggle_llm_then_announce(None)
    else:
        # SD is disabled and the LLM is enabled: nothing to toggle.
        announce(None)
##########################################################################
# End LLM Generated Image handlers
##########################################################################

def _set_chatbot_mood(self):
self.ui.mood_label.setText(self.chatbot.bot_mood)
Expand Down Expand Up @@ -232,8 +329,17 @@ def handle_image_presets_changed(self, val):
def do_generate_image_from_image_signal_handler(self, _data):
self.do_generate()

def do_generate(self):
self.emit_signal(SignalCode.DO_GENERATE_SIGNAL)
def do_generate(self, data=None):
    """
    Emit DO_GENERATE_SIGNAL, optionally forwarding a completion callback.

    :param data: optional dict describing the request.
        - ``None`` or an empty dict: the signal is emitted with ``None``.
        - a dict with a truthy ``"finalize"`` entry: the signal is emitted
          with ``dict(callback=finalize)``; other keys are not forwarded.
        - any other non-empty dict: forwarded unchanged.
    """
    payload = None
    if data:
        finalize = data.get("finalize", None)
        # "finalize" is re-keyed to "callback" for the emitted message
        # (presumably the key the generate-signal handler looks up --
        # confirm against that handler).
        payload = dict(callback=finalize) if finalize else data
    self.emit_signal(SignalCode.DO_GENERATE_SIGNAL, payload)

def activate_ai_mode(self):
ai_mode = self.application_settings.ai_mode
Expand All @@ -259,18 +365,19 @@ def handle_second_prompt_changed(self):
def handle_second_negative_prompt_changed(self):
pass

def handle_generate_button_clicked(self):
def handle_generate_button_clicked(self, data=None):
print("HANDLE GENERATE BUTTON CLICKED", data)
self.start_progress_bar()
self.generate()
self.generate(data)

@Slot()
def handle_interrupt_button_clicked(self):
self.emit_signal(SignalCode.INTERRUPT_IMAGE_GENERATION_SIGNAL)

def generate(self):
def generate(self, data=None):
if self.generator_settings.random_seed:
self.seed = random_seed()
self.do_generate()
self.do_generate(data)
self.seed_override = None

def do_generate_image(self):
Expand Down
Loading

0 comments on commit 6709395

Please sign in to comment.