Merge pull request #935 from Capsize-Games/devastator
Devastator
w4ffl35 authored Oct 12, 2024
2 parents f8c2e09 + 2a867ba commit 65e23a6
Showing 20 changed files with 544 additions and 296 deletions.
1 change: 1 addition & 0 deletions src/airunner/enums.py
@@ -352,6 +352,7 @@ class LLMActionType(Enum):
     TOGGLE_TTS = "TOGGLE TEXT-TO-SPEECH: If the user requests that you turn on or off or toggle text-to-speech, choose this action."
     PERFORM_RAG_SEARCH = "SEARCH: If the user requests that you search for information, choose this action."
     SUMMARIZE = "SUMMARIZE"
+    DO_NOTHING = "DO NOTHING: If the user's request is unclear or you are unable to determine the user's intent, choose this action."



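The new DO_NOTHING member follows this file's existing pattern of enum values that double as natural-language instructions for the model. A minimal sketch of how such values can be rendered into an action-selection menu; the function name and prompt wording here are illustrative, not airunner's actual code:

from airunner.enums import LLMActionType

def build_action_menu(available_actions: dict) -> str:
    # Render "index: instruction" pairs so the model can reply with a bare number.
    lines = [f"{index}: {action.value}" for index, action in available_actions.items()]
    return "Choose exactly one action by number:\n" + "\n".join(lines)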
27 changes: 12 additions & 15 deletions src/airunner/handlers/llm/agent/base_agent.py
@@ -16,6 +16,7 @@
 from llama_index.core.chat_engine import ContextChatEngine
 from llama_index.core import SimpleKeywordTableIndex
 from llama_index.core.indices.keyword_table import KeywordTableSimpleRetriever
+from transformers import TextIteratorStreamer
 
 from airunner.handlers.llm.huggingface_llm import HuggingFaceLLM
 from airunner.handlers.llm.custom_embedding import CustomEmbedding
@@ -82,7 +83,7 @@ def __init__(self, *args, **kwargs):
         self.action = LLMActionType.CHAT
         self.rendered_template = None
         self.tokenizer = kwargs.pop("tokenizer", None)
-        self.streamer = kwargs.pop("streamer", None)
+        self.streamer = TextIteratorStreamer(self.tokenizer)
         self.chat_template = kwargs.pop("chat_template", "")
         self.is_mistral = kwargs.pop("is_mistral", True)
         self.conversation_id = None
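With this change the agent constructs its own TextIteratorStreamer from its tokenizer instead of receiving one via kwargs. For reference, the standard transformers streaming pattern looks like this; the model name and prompt are placeholders, and generation runs in a worker thread while the caller iterates over decoded chunks:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
inputs = tokenizer("Hello, how are you?", return_tensors="pt")
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for chunk in streamer:  # yields decoded text as tokens arrive
    print(chunk, end="", flush=True)
thread.join()

This also plausibly explains the do_interrupt_process change below: swapping in a fresh streamer discards any text still queued from an interrupted generation.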
@@ -97,12 +98,11 @@
     @property
     def available_actions(self):
         return {
-            0: LLMActionType.QUIT_APPLICATION,
-            1: LLMActionType.TOGGLE_FULLSCREEN,
-            2: LLMActionType.TOGGLE_TTS,
-            3: LLMActionType.GENERATE_IMAGE,
-            4: LLMActionType.PERFORM_RAG_SEARCH,
-            5: LLMActionType.CHAT,
+            0: LLMActionType.TOGGLE_FULLSCREEN,
+            1: LLMActionType.TOGGLE_TTS,
+            2: LLMActionType.GENERATE_IMAGE,
+            3: LLMActionType.PERFORM_RAG_SEARCH,
+            4: LLMActionType.CHAT,
         }
 
     @property
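QUIT_APPLICATION is dropped from the selectable actions and the remaining entries are renumbered. Since the table is keyed by small integers, a plausible consumption pattern is to have the model answer with a bare index and map it back; this helper is hypothetical, not airunner's dispatch code, and the CHAT fallback is an assumption:

from airunner.enums import LLMActionType

def parse_action(reply: str, available_actions: dict) -> LLMActionType:
    # Map a numeric reply back onto the action table; fall back to CHAT
    # when the reply is not a valid index.
    try:
        return available_actions[int(reply.strip())]
    except (ValueError, KeyError):
        return LLMActionType.CHAT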
@@ -163,7 +163,7 @@ def interrupt_process(self):
 
     def do_interrupt_process(self):
         interrupt = self.do_interrupt
-        self.do_interrupt = False
+        self.streamer = TextIteratorStreamer(self.tokenizer)
         return interrupt
 
     @property
@@ -303,9 +303,7 @@ def build_system_prompt(
                 self.names_prompt(use_names, botname, username),
                 self.mood(botname, bot_mood, use_mood),
                 system_instructions,
-                "------\n",
-                "Chat History:\n",
-                f"{self.username}: {self.prompt}\n",
+                self.history_prompt(),
             ]
 
         elif action is LLMActionType.SUMMARIZE:
@@ -502,10 +500,9 @@ def run(
             self.create_conversation()
 
         # Add the user's message to history
-        if action in (
-            LLMActionType.CHAT,
-            LLMActionType.PERFORM_RAG_SEARCH,
-            LLMActionType.GENERATE_IMAGE,
+        if action not in (
+            LLMActionType.APPLICATION_COMMAND,
+            LLMActionType.UPDATE_MOOD
         ):
             self.add_message_to_history(self.prompt, LLMChatRole.HUMAN)
 
@@ -226,6 +226,8 @@ def clear_history(self):
"""
Public method to clear the chat agent history
"""
if not self._chat_agent:
return
self.logger.debug("Clearing chat history")
self._chat_agent.clear_history()

@@ -301,7 +303,6 @@ def _load_agent(self):
             self._chat_agent = BaseAgent(
                 model=self._model,
                 tokenizer=self._tokenizer,
-                streamer=self._streamer,
                 chat_template=self.chat_template,
                 is_mistral=self.is_mistral,
             )
@@ -378,8 +379,7 @@ def _load_model_local(self):
 
     def _do_generate(self, prompt: str, action: LLMActionType):
         self.logger.debug("Generating response")
-        model_path = self.model_path
-        if self._current_model_path != model_path:
+        if self._current_model_path != self.model_path:
             self.unload()
             self.load()
         if action is LLMActionType.CHAT and self.chatbot.use_mood:
61 changes: 41 additions & 20 deletions src/airunner/handlers/stablediffusion/sd_handler.py
@@ -33,6 +33,7 @@
     EngineResponseCode, ModelAction
 )
 from airunner.exceptions import PipeNotLoadedException, InterruptedException
+from airunner.handlers.stablediffusion.prompt_weight_bridge import PromptWeightBridge
 from airunner.settings import MIN_NUM_INFERENCE_STEPS_IMG2IMG
 from airunner.utils.clear_memory import clear_memory
 from airunner.utils.convert_base64_to_image import convert_base64_to_image
@@ -390,6 +391,26 @@ def _pipeline_class(self):
     def mask_blur(self) -> int:
         return self.outpaint_settings_cached.mask_blur
 
+    @property
+    def prompt(self):
+        prompt = self.generator_settings_cached.prompt
+        return PromptWeightBridge.convert(prompt)
+
+    @property
+    def second_prompt(self):
+        prompt = self.generator_settings_cached.second_prompt
+        return PromptWeightBridge.convert(prompt)
+
+    @property
+    def negative_prompt(self):
+        prompt = self.generator_settings_cached.negative_prompt
+        return PromptWeightBridge.convert(prompt)
+
+    @property
+    def second_negative_prompt(self):
+        prompt = self.generator_settings_cached.second_negative_prompt
+        return PromptWeightBridge.convert(prompt)
+
     def load_safety_checker(self):
         """
         Public method to load the safety checker model.
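The four new properties route every prompt string through PromptWeightBridge.convert before it reaches compel or the pipeline. The bridge's implementation is not part of this diff; assuming it translates A1111-style attention weights into compel's syntax, a stand-in converter might look roughly like this (the regex and the syntax mapping are both assumptions):

import re

def convert(prompt: str) -> str:
    # Assumed mapping: "(token:1.2)" (A1111 style) -> "(token)1.2" (compel style).
    return re.sub(r"\(([^():]+):([\d.]+)\)", r"(\1)\2", prompt)

print(convert("a (cat:1.3) on a mat"))  # -> a (cat)1.3 on a mat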
@@ -1330,43 +1351,43 @@ def _load_prompt_embeds(self):
         self.logger.debug("Loading prompt embeds")
         if not self.generator_settings_cached.use_compel:
             return
-        prompt = self.generator_settings_cached.prompt
-        negative_prompt = self.generator_settings_cached.negative_prompt
-        prompt_2 = self.generator_settings_cached.second_prompt
-        negative_prompt_2 = self.generator_settings_cached.second_negative_prompt
+        prompt = self.prompt
+        negative_prompt = self.negative_prompt
+        second_prompt = self.second_prompt
+        second_negative_prompt = self.second_negative_prompt
 
         if (
             self._current_prompt != prompt
             or self._current_negative_prompt != negative_prompt
-            or self._current_prompt_2 != prompt_2
-            or self._current_negative_prompt_2 != negative_prompt_2
+            or self._current_prompt_2 != second_prompt
+            or self._current_negative_prompt_2 != second_negative_prompt
         ):
             self._unload_latents()
             self._current_prompt = prompt
             self._current_negative_prompt = negative_prompt
-            self._current_prompt_2 = prompt_2
-            self._current_negative_prompt_2 = negative_prompt_2
+            self._current_prompt_2 = second_prompt
+            self._current_negative_prompt_2 = second_negative_prompt
             self._unload_prompt_embeds()
 
         pooled_prompt_embeds = None
         negative_pooled_prompt_embeds = None
 
-        if prompt != "" and prompt_2 != "":
-            compel_prompt = f'("{prompt}", "{prompt_2}").and()'
-        elif prompt != "" and prompt_2 == "":
+        if prompt != "" and second_prompt != "":
+            compel_prompt = f'("{prompt}", "{second_prompt}").and()'
+        elif prompt != "" and second_prompt == "":
             compel_prompt = prompt
-        elif prompt == "" and prompt_2 != "":
-            compel_prompt = prompt_2
+        elif prompt == "" and second_prompt != "":
+            compel_prompt = second_prompt
         else:
             compel_prompt = ""
 
-        if negative_prompt != "" and negative_prompt_2 != "":
-            compel_negative_prompt = f'("{negative_prompt}", "{negative_prompt_2}").and()'
-        elif negative_prompt != "" and negative_prompt_2 == "":
+        if negative_prompt != "" and second_negative_prompt != "":
+            compel_negative_prompt = f'("{negative_prompt}", "{second_negative_prompt}").and()'
+        elif negative_prompt != "" and second_negative_prompt == "":
             compel_negative_prompt = negative_prompt
-        elif negative_prompt == "" and negative_prompt_2 != "":
-            compel_negative_prompt = negative_prompt_2
+        elif negative_prompt == "" and second_negative_prompt != "":
+            compel_negative_prompt = second_negative_prompt
         else:
             compel_negative_prompt = ""
 
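The strings built above use compel's conjunction syntax, where ("a", "b").and() fuses two sub-prompts into a single conditioning tensor. A minimal end-to-end sketch for the single-text-encoder (SD 1.5) case; SDXL additionally requires the pooled embeds the surrounding code tracks as pooled_prompt_embeds, and the model id and prompts are placeholders:

from compel import Compel
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

# ".and()" concatenates the two sub-prompts into one conditioning tensor.
prompt_embeds = compel('("a photo of a cat", "in the style of an oil painting").and()')
image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=30).images[0]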
@@ -1442,8 +1463,8 @@ def _prepare_data(self, active_rect = None) -> dict:
             ))
         else:
             args.update(dict(
-                prompt=self.generator_settings_cached.prompt,
-                negative_prompt=self.generator_settings_cached.negative_prompt
+                prompt=self.prompt,
+                negative_prompt=self.negative_prompt
             ))
 
         width = int(self.application_settings_cached.working_width)
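When use_compel is disabled, _prepare_data instead passes the converted prompt strings straight through, and the assembled args dict maps onto diffusers' standard keyword arguments. Continuing the pipeline sketch above (all values are placeholders):

args = dict(
    prompt="a photo of a cat",              # self.prompt, already PromptWeightBridge-converted
    negative_prompt="blurry, low quality",  # self.negative_prompt
    width=512,                              # working_width from application settings
    height=512,                             # working_height from application settings
)
image = pipe(**args).images[0]  # pipe as constructed in the earlier sketch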
