Skip to content

Commit

Permalink
Merge pull request #951 from Capsize-Games/devastator
Browse files Browse the repository at this point in the history
Devastator
  • Loading branch information
w4ffl35 authored Oct 23, 2024
2 parents 1e96bb4 + f6853da commit 41a4bbb
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 69 deletions.
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,3 @@ The security measures taken for this library are as follows
- All telemetry disabled

See [Facehuggershield](https://github.com/capsize-games/facehuggershield) for more information.

---

sudo groupadd docker
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="airunner",
version="3.1.0",
version="3.1.1",
author="Capsize LLC",
description="A Stable Diffusion GUI",
long_description=open("README.md", "r", encoding="utf-8").read(),
Expand Down Expand Up @@ -65,7 +65,8 @@
"llama-index-llms-groq==0.2.0",
"llama-index-embeddings-mistralai==0.2.0",
"EbookLib==0.18",
"html2text==2024.2.26"
"html2text==2024.2.26",
"rake_nltk==1.0.6"
],
package_data={
"airunner": [
Expand Down
47 changes: 22 additions & 25 deletions src/airunner/handlers/llm/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
import torch
from PySide6.QtCore import QObject

from llama_index.core import Settings
from llama_index.core import Settings, RAKEKeywordTableIndex
from llama_index.core.base.llms.types import ChatMessage
from llama_index.readers.file import EpubReader, PDFReader, MarkdownReader
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import PromptHelper
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core import SimpleKeywordTableIndex
from llama_index.core.indices.keyword_table import KeywordTableSimpleRetriever
from transformers import TextIteratorStreamer

Expand Down Expand Up @@ -258,6 +257,17 @@ def interrupt_process(self):
self.do_interrupt = True

def do_interrupt_process(self):
    """Report whether generation should be interrupted.

    When an interrupt is pending, an empty LLM_TEXT_STREAMED_SIGNAL chunk
    is emitted first so listeners (e.g. the chat UI) can settle their state.

    Returns:
        bool: True if generation should stop, False otherwise.
    """
    interrupted = self.do_interrupt
    if interrupted:
        payload = {
            "message": "",
            "is_first_message": False,
            "is_end_of_message": False,
            "name": self.botname,
            "action": LLMActionType.CHAT,
        }
        self.emit_signal(SignalCode.LLM_TEXT_STREAMED_SIGNAL, payload)
    return interrupted

def mood(self, botname: str, bot_mood: str, use_mood: bool) -> str:
Expand Down Expand Up @@ -564,23 +574,11 @@ def run_with_thread(

self.emit_signal(SignalCode.UNBLOCK_TTS_GENERATOR_SIGNAL)

stopping_criteria = ExternalConditionStoppingCriteria(self.do_interrupt_process)

data = self.prepare_generate_data(model_inputs, stopping_criteria)

if self.do_interrupt:
self.do_interrupt = False
self.emit_signal(
SignalCode.LLM_TEXT_STREAMED_SIGNAL,
dict(
message="",
is_first_message=False,
is_end_of_message=False,
name=self.botname,
action=LLMActionType.CHAT
)
)
return

stopping_criteria = ExternalConditionStoppingCriteria(self.do_interrupt_process)
data = self.prepare_generate_data(model_inputs, stopping_criteria)

data["streamer"] = kwargs.get("streamer", self.streamer)

Expand Down Expand Up @@ -657,6 +655,8 @@ def run_with_thread(
if eos_token in new_text:
streamed_template = streamed_template.replace(eos_token, "")
new_text = new_text.replace(eos_token, "")
streamed_template = streamed_template.replace("<</SYS>>", "")
new_text = new_text.replace("<</SYS>>", "")
is_end_of_message = True
# strip botname from new_text
new_text = new_text.replace(f"{self.botname}:", "")
Expand Down Expand Up @@ -727,7 +727,10 @@ def run_with_thread(
)

if streamed_template is not None:
if action is LLMActionType.CHAT:
if action in (
LLMActionType.CHAT,
LLMActionType.PERFORM_RAG_SEARCH,
):
self.add_message_to_history(
streamed_template,
LLMChatRole.ASSISTANT
Expand Down Expand Up @@ -776,12 +779,6 @@ def run_with_thread(
)
return streamed_template

def add_chatbot_response_to_history(self, response: dict):
    """Append a chatbot response to the conversation history.

    Args:
        response: Mapping with a "message" (text) and a "role" entry.
    """
    message = response["message"]
    role = response["role"]
    self.add_message_to_history(message, role)

def get_db_connection(self):
    """Open and return a new sqlite3 connection to the application database.

    Note: a fresh connection is created on every call; the caller is
    responsible for closing it.
    """
    db_path = 'airunner.db'
    return sqlite3.connect(db_path)

Expand Down Expand Up @@ -951,7 +948,7 @@ def __load_document_index(self):
self.logger.debug("Loading index...")
documents = self.__documents or []
try:
self.__index = SimpleKeywordTableIndex.from_documents(
self.__index = RAKEKeywordTableIndex.from_documents(
documents,
llm=self.__llm
)
Expand Down
2 changes: 1 addition & 1 deletion src/airunner/handlers/llm/huggingface_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def complete(

tokens = self._model.generate(
**inputs,
# max_new_tokens=self.max_new_tokens,
max_new_tokens=self.max_new_tokens,
stopping_criteria=self._stopping_criteria,
**self.generate_kwargs,
)
Expand Down
40 changes: 3 additions & 37 deletions src/airunner/widgets/llm/chat_prompt_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from airunner.enums import SignalCode, LLMActionType, ModelType, ModelStatus
from airunner.widgets.base_widget import BaseWidget
from airunner.widgets.llm.loading_widget import LoadingWidget
from airunner.widgets.llm.templates.chat_prompt_ui import Ui_chat_prompt
from airunner.widgets.llm.message_widget import MessageWidget

Expand All @@ -15,7 +14,6 @@ class ChatPromptWidget(BaseWidget):
def __init__(self, *args, **kwargs):
super().__init__()
self.scroll_bar = None
self._loading_widget = None
self.conversation = None
self.is_modal = True
self.generating = False
Expand All @@ -31,7 +29,6 @@ def __init__(self, *args, **kwargs):
self.messages_spacer = None
self.chat_loaded = False
self.conversation_id = None
self._has_loading_widget = False

self.ui.action.blockSignals(True)
self.ui.action.addItem("Auto")
Expand Down Expand Up @@ -99,10 +96,8 @@ def _set_conversation_widgets(self, messages):
name=message["name"],
message=message["content"],
is_bot=message["is_bot"],
first_message=True,
use_loading_widget=False
first_message=True
)
self.scroll_to_bottom()

def on_hear_signal(self, data: dict):
transcription = data["transcription"]
Expand Down Expand Up @@ -130,8 +125,7 @@ def on_add_bot_message_to_conversation(self, data: dict):
name=name,
message=message,
is_bot=True,
first_message=is_first_message,
action=data["action"]
first_message=is_first_message
)

if is_end_of_message:
Expand All @@ -152,7 +146,6 @@ def _clear_conversation(self):
self.conversation_history = []
self._clear_conversation_widgets()
self._create_conversation()
self.remove_loading_widget()

def _create_conversation(self):
conversation_id = self.create_conversation()
Expand All @@ -173,7 +166,6 @@ def interrupt_button_clicked(self):
self.emit_signal(SignalCode.INTERRUPT_PROCESS_SIGNAL)
self.stop_progress_bar()
self.generating = False
self.remove_loading_widget()
self.enable_send_button()

@property
Expand Down Expand Up @@ -214,7 +206,6 @@ def do_generate(self, image_override=None, prompt_override=None, callback=None,
}
}
)
self.scroll_to_bottom()

def on_token_signal(self, val):
self.handle_token_signal(val)
Expand Down Expand Up @@ -315,31 +306,12 @@ def describe_image(self, image, callback):
generator_name="visualqa"
)

def add_loading_widget(self):
    """Show the loading indicator in the chat scroll area.

    No-op when the indicator is already displayed.
    """
    if self._has_loading_widget:
        return
    self._has_loading_widget = True
    if self._loading_widget is None:
        # Lazily create the widget once and reuse it on later calls.
        self._loading_widget = LoadingWidget()
    layout = self.ui.scrollAreaWidgetContents.layout()
    layout.addWidget(self._loading_widget)

def remove_loading_widget(self):
    """Hide the loading indicator if it is currently shown."""
    if not self._has_loading_widget:
        return
    try:
        self.ui.scrollAreaWidgetContents.layout().removeWidget(self._loading_widget)
    except RuntimeError:
        # The underlying Qt object may already be deleted; ignore.
        pass
    self._has_loading_widget = False

def add_message_to_conversation(
self,
name,
message,
is_bot,
first_message=True,
use_loading_widget=True,
action:LLMActionType=LLMActionType.CHAT
first_message=True
):
if not first_message:
# get the last widget from the scrollAreaWidgetContents.layout()
Expand All @@ -357,16 +329,10 @@ def add_message_to_conversation(

self.remove_spacer()

if is_bot and use_loading_widget:
self.remove_loading_widget()

if message != "":
widget = MessageWidget(name=name, message=message, is_bot=is_bot)
self.ui.scrollAreaWidgetContents.layout().addWidget(widget)

if not is_bot and use_loading_widget:
self.add_loading_widget()

self.add_spacer()

# automatically scroll to the bottom of the scrollAreaWidgetContents
Expand Down

0 comments on commit 41a4bbb

Please sign in to comment.