From 24686630a99c9c3c80df76555b1dfba50e9cb181 Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:49:42 -0600 Subject: [PATCH 1/2] update README.md --- README.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/README.md b/README.md index 96d02193c..cae36bbde 100644 --- a/README.md +++ b/README.md @@ -204,7 +204,3 @@ The security measures taken for this library are as follows - All telemetry disabled See [Facehuggershield](https://github.com/capsize-games/facehuggershield) for more information. - ---- - -sudo groupadd docker From f6853dabd34171eaf05ff788ac0970c3b05d6c76 Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:45:06 -0600 Subject: [PATCH 2/2] switch to RAKEKeywordTableIndex for better RAG results remove loading widget improve smooth scrolling --- setup.py | 5 +- src/airunner/handlers/llm/agent/base_agent.py | 47 +++++++++---------- src/airunner/handlers/llm/huggingface_llm.py | 2 +- .../widgets/llm/chat_prompt_widget.py | 40 ++-------------- 4 files changed, 29 insertions(+), 65 deletions(-) diff --git a/setup.py b/setup.py index fa055813e..fa187152d 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="airunner", - version="3.1.0", + version="3.1.1", author="Capsize LLC", description="A Stable Diffusion GUI", long_description=open("README.md", "r", encoding="utf-8").read(), @@ -65,7 +65,8 @@ "llama-index-llms-groq==0.2.0", "llama-index-embeddings-mistralai==0.2.0", "EbookLib==0.18", - "html2text==2024.2.26" + "html2text==2024.2.26", + "rake_nltk==1.0.6" ], package_data={ "airunner": [ diff --git a/src/airunner/handlers/llm/agent/base_agent.py b/src/airunner/handlers/llm/agent/base_agent.py index 645936a62..ee1c37d73 100644 --- a/src/airunner/handlers/llm/agent/base_agent.py +++ b/src/airunner/handlers/llm/agent/base_agent.py @@ -8,14 +8,13 @@ import torch from PySide6.QtCore import QObject -from llama_index.core import Settings +from llama_index.core import Settings, RAKEKeywordTableIndex from llama_index.core.base.llms.types import ChatMessage from llama_index.readers.file import EpubReader, PDFReader, MarkdownReader from llama_index.core import SimpleDirectoryReader from llama_index.core.node_parser import SentenceSplitter from llama_index.core import PromptHelper from llama_index.core.chat_engine import ContextChatEngine -from llama_index.core import SimpleKeywordTableIndex from llama_index.core.indices.keyword_table import KeywordTableSimpleRetriever from transformers import TextIteratorStreamer @@ -258,6 +257,17 @@ def interrupt_process(self): self.do_interrupt = True def do_interrupt_process(self): + if self.do_interrupt: + self.emit_signal( + SignalCode.LLM_TEXT_STREAMED_SIGNAL, + dict( + message="", + is_first_message=False, + is_end_of_message=False, + name=self.botname, + action=LLMActionType.CHAT + ) + ) return self.do_interrupt def mood(self, botname: str, bot_mood: str, use_mood: bool) -> str: @@ -564,23 +574,11 @@ def run_with_thread( self.emit_signal(SignalCode.UNBLOCK_TTS_GENERATOR_SIGNAL) - stopping_criteria = ExternalConditionStoppingCriteria(self.do_interrupt_process) - - data = self.prepare_generate_data(model_inputs, stopping_criteria) - if self.do_interrupt: self.do_interrupt = False - self.emit_signal( - SignalCode.LLM_TEXT_STREAMED_SIGNAL, - dict( - message="", - is_first_message=False, - is_end_of_message=False, - name=self.botname, - action=LLMActionType.CHAT - ) - ) - return + + stopping_criteria = ExternalConditionStoppingCriteria(self.do_interrupt_process) + data = self.prepare_generate_data(model_inputs, stopping_criteria) data["streamer"] = kwargs.get("streamer", self.streamer) @@ -657,6 +655,8 @@ def run_with_thread( if eos_token in new_text: streamed_template = streamed_template.replace(eos_token, "") new_text = new_text.replace(eos_token, "") + streamed_template = streamed_template.replace("<>", "") + new_text = new_text.replace("<>", "") is_end_of_message = True # strip botname from new_text new_text = new_text.replace(f"{self.botname}:", "") @@ -727,7 +727,10 @@ def run_with_thread( ) if streamed_template is not None: - if action is LLMActionType.CHAT: + if action in ( + LLMActionType.CHAT, + LLMActionType.PERFORM_RAG_SEARCH, + ): self.add_message_to_history( streamed_template, LLMChatRole.ASSISTANT @@ -776,12 +779,6 @@ def run_with_thread( ) return streamed_template - def add_chatbot_response_to_history(self, response: dict): - self.add_message_to_history( - response["message"], - response["role"] - ) - def get_db_connection(self): return sqlite3.connect('airunner.db') @@ -951,7 +948,7 @@ def __load_document_index(self): self.logger.debug("Loading index...") documents = self.__documents or [] try: - self.__index = SimpleKeywordTableIndex.from_documents( + self.__index = RAKEKeywordTableIndex.from_documents( documents, llm=self.__llm ) diff --git a/src/airunner/handlers/llm/huggingface_llm.py b/src/airunner/handlers/llm/huggingface_llm.py index 86da6b641..407443f2c 100644 --- a/src/airunner/handlers/llm/huggingface_llm.py +++ b/src/airunner/handlers/llm/huggingface_llm.py @@ -342,7 +342,7 @@ def complete( tokens = self._model.generate( **inputs, - # max_new_tokens=self.max_new_tokens, + max_new_tokens=self.max_new_tokens, stopping_criteria=self._stopping_criteria, **self.generate_kwargs, ) diff --git a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py index 371551d5d..7a287e616 100644 --- a/src/airunner/widgets/llm/chat_prompt_widget.py +++ b/src/airunner/widgets/llm/chat_prompt_widget.py @@ -4,7 +4,6 @@ from airunner.enums import SignalCode, LLMActionType, ModelType, ModelStatus from airunner.widgets.base_widget import BaseWidget -from airunner.widgets.llm.loading_widget import LoadingWidget from airunner.widgets.llm.templates.chat_prompt_ui import Ui_chat_prompt from airunner.widgets.llm.message_widget import MessageWidget @@ -15,7 +14,6 @@ class ChatPromptWidget(BaseWidget): def __init__(self, *args, **kwargs): super().__init__() self.scroll_bar = None - self._loading_widget = None self.conversation = None self.is_modal = True self.generating = False @@ -31,7 +29,6 @@ def __init__(self, *args, **kwargs): self.messages_spacer = None self.chat_loaded = False self.conversation_id = None - self._has_loading_widget = False self.ui.action.blockSignals(True) self.ui.action.addItem("Auto") @@ -99,10 +96,8 @@ def _set_conversation_widgets(self, messages): name=message["name"], message=message["content"], is_bot=message["is_bot"], - first_message=True, - use_loading_widget=False + first_message=True ) - self.scroll_to_bottom() def on_hear_signal(self, data: dict): transcription = data["transcription"] @@ -130,8 +125,7 @@ def on_add_bot_message_to_conversation(self, data: dict): name=name, message=message, is_bot=True, - first_message=is_first_message, - action=data["action"] + first_message=is_first_message ) if is_end_of_message: @@ -152,7 +146,6 @@ def _clear_conversation(self): self.conversation_history = [] self._clear_conversation_widgets() self._create_conversation() - self.remove_loading_widget() def _create_conversation(self): conversation_id = self.create_conversation() @@ -173,7 +166,6 @@ def interrupt_button_clicked(self): self.emit_signal(SignalCode.INTERRUPT_PROCESS_SIGNAL) self.stop_progress_bar() self.generating = False - self.remove_loading_widget() self.enable_send_button() @property @@ -214,7 +206,6 @@ def do_generate(self, image_override=None, prompt_override=None, callback=None, } } ) - self.scroll_to_bottom() def on_token_signal(self, val): self.handle_token_signal(val) @@ -315,31 +306,12 @@ def describe_image(self, image, callback): generator_name="visualqa" ) - def add_loading_widget(self): - if not self._has_loading_widget: - self._has_loading_widget = True - if self._loading_widget is None: - self._loading_widget = LoadingWidget() - self.ui.scrollAreaWidgetContents.layout().addWidget( - self._loading_widget - ) - - def remove_loading_widget(self): - if self._has_loading_widget: - try: - self.ui.scrollAreaWidgetContents.layout().removeWidget(self._loading_widget) - except RuntimeError: - pass - self._has_loading_widget = False - def add_message_to_conversation( self, name, message, is_bot, - first_message=True, - use_loading_widget=True, - action:LLMActionType=LLMActionType.CHAT + first_message=True ): if not first_message: # get the last widget from the scrollAreaWidgetContents.layout() @@ -357,16 +329,10 @@ def add_message_to_conversation( self.remove_spacer() - if is_bot and use_loading_widget: - self.remove_loading_widget() - if message != "": widget = MessageWidget(name=name, message=message, is_bot=is_bot) self.ui.scrollAreaWidgetContents.layout().addWidget(widget) - if not is_bot and use_loading_widget: - self.add_loading_widget() - self.add_spacer() # automatically scroll to the bottom of the scrollAreaWidgetContents