diff --git a/setup.py b/setup.py index 5ecc7d35f..fa055813e 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="airunner", - version="3.0.21", + version="3.1.0", author="Capsize LLC", description="A Stable Diffusion GUI", long_description=open("README.md", "r", encoding="utf-8").read(), diff --git a/src/airunner/app.py b/src/airunner/app.py index add0cdd5b..dd1415127 100644 --- a/src/airunner/app.py +++ b/src/airunner/app.py @@ -32,7 +32,6 @@ from airunner.windows.main.settings_mixin import SettingsMixin from airunner.data.models.settings_models import ApplicationSettings, AIModels from airunner.windows.main.main_window import MainWindow -from airunner.handlers.logger import Logger class App( @@ -55,7 +54,6 @@ def __init__( """ self.main_window_class_ = main_window_class or MainWindow self.app = None - self.logger = Logger(prefix=self.__class__.__name__) self.defendatron = defendatron self.splash = None @@ -95,9 +93,7 @@ def create_paths(self): "images" ) )) - session = self.db_handler.get_db_session() - versions = session.query(distinct(AIModels.version)).filter(AIModels.category == 'stablediffusion').all() - session.close() + versions = self.session.query(distinct(AIModels.version)).filter(AIModels.category == 'stablediffusion').all() for version in versions: os.makedirs( os.path.join(models_path, version[0], "embeddings"), @@ -110,9 +106,7 @@ def create_paths(self): os.makedirs(images_path, exist_ok=True) def run_setup_wizard(self): - session = self.db_handler.get_db_session() - application_settings = session.query(ApplicationSettings).first() - session.close() + application_settings = self.session.query(ApplicationSettings).first() if application_settings.run_setup_wizard: AppInstaller() diff --git a/src/airunner/app_installer.py b/src/airunner/app_installer.py index d03a6a2af..51305ca2b 100644 --- a/src/airunner/app_installer.py +++ b/src/airunner/app_installer.py @@ -19,7 +19,6 @@ from airunner.windows.download_wizard.download_wizard_window import DownloadWizardWindow from airunner.windows.main.settings_mixin import SettingsMixin from airunner.windows.setup_wizard.setup_wizard_window import SetupWizardWindow -from airunner.handlers.logger import Logger class AppInstaller( @@ -43,7 +42,6 @@ def __init__( self.download_wizard = None self.app = None self.close_on_cancel = close_on_cancel - self.logger = Logger(prefix=self.__class__.__name__) """ Mediator and Settings mixins are initialized here, enabling the application diff --git a/src/airunner/data/models/database_handler.py b/src/airunner/data/models/database_handler.py deleted file mode 100644 index 519dd6304..000000000 --- a/src/airunner/data/models/database_handler.py +++ /dev/null @@ -1,24 +0,0 @@ -import os -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from airunner.data.models.settings_models import Base - -class DatabaseHandler: - def __init__(self, db_path=os.path.expanduser( - os.path.join( - "~", - ".local", - "share", - "airunner", - "data", - "airunner.db" - ) - )): - self.db_path = db_path - self.engine = create_engine(f'sqlite:///{self.db_path}') - Base.metadata.create_all(self.engine) - self.Session = sessionmaker(bind=self.engine) - self.conversation_id = None - - def get_db_session(self): - return self.Session() diff --git a/src/airunner/data/models/settings_db_handler.py b/src/airunner/data/models/settings_db_handler.py deleted file mode 100644 index 8104d5a6a..000000000 --- a/src/airunner/data/models/settings_db_handler.py +++ /dev/null @@ -1,601 +0,0 @@ -import datetime -from typing import List - -from sqlalchemy.orm import joinedload - -from airunner.data.models.database_handler import DatabaseHandler -from airunner.data.models.settings_models import Chatbot, AIModels, Schedulers, Lora, PathSettings, SavedPrompt, \ - Embedding, PromptTemplate, ControlnetModel, FontSetting, PipelineModel, ShortcutKeys, \ - GeneratorSettings, WindowSettings, ApplicationSettings, ActiveGridSettings, ControlnetSettings, \ - ImageToImageSettings, OutpaintSettings, DrawingPadSettings, MetadataSettings, \ - LLMGeneratorSettings, TTSSettings, SpeechT5Settings, EspeakSettings, STTSettings, BrushSettings, GridSettings, \ - MemorySettings, Message, Conversation, Summary - - -class SettingsDBHandler(DatabaseHandler): - ####################################### - ### SCHEDULERS ### - ####################################### - def load_schedulers(self) -> List[Schedulers]: - session = self.get_db_session() - try: - return session.query(Schedulers).all() - finally: - session.close() - - ####################################### - ### SETTINGS ### - ####################################### - def load_settings_from_db(self, model_class_): - session = self.get_db_session() - try: - settings = session.query(model_class_).first() - if settings is None: - settings = self.create_new_settings(model_class_) - finally: - session.close() - return settings - - def update_setting(self, model_class_, name, value): - session = self.get_db_session() - try: - setting = session.query(model_class_).order_by(model_class_.id.desc()).first() - if setting: - setattr(setting, name, value) - session.commit() - finally: - session.close() - - def save_generator_settings(self, generator_settings: GeneratorSettings): - session = self.get_db_session() - try: - query = session.query(GeneratorSettings).filter_by( - id=generator_settings.id - ).first() - if query: - for key in generator_settings.__dict__.keys(): - if key != "_sa_instance_state": - setattr(query, key, getattr(generator_settings, key)) - else: - session.add(generator_settings) - session.commit() - finally: - session.close() - - def reset_settings(self): - session = self.get_db_session() - try: - # Delete all entries from the model class - session.query(ApplicationSettings).delete() - session.query(ActiveGridSettings).delete() - session.query(ControlnetSettings).delete() - session.query(ImageToImageSettings).delete() - session.query(OutpaintSettings).delete() - session.query(DrawingPadSettings).delete() - session.query(MetadataSettings).delete() - session.query(GeneratorSettings).delete() - session.query(LLMGeneratorSettings).delete() - session.query(TTSSettings).delete() - session.query(SpeechT5Settings).delete() - session.query(EspeakSettings).delete() - session.query(STTSettings).delete() - session.query(BrushSettings).delete() - session.query(GridSettings).delete() - session.query(PathSettings).delete() - session.query(MemorySettings).delete() - # Commit the changes - session.commit() - finally: - session - - def create_new_settings(self, model_class_): - session = self.get_db_session() - try: - new_settings = model_class_() - session.add(new_settings) - session.commit() - session.refresh(new_settings) - finally: - session.close() - return new_settings - - ####################################### - ### SAVED PROMPTS ### - ####################################### - def get_saved_prompt_by_id(self, prompt_id) -> SavedPrompt: - session = self.get_db_session() - try: - return session.query(SavedPrompt).filter_by(id=prompt_id).first() - finally: - session.close() - - def update_saved_prompt(self, saved_prompt: SavedPrompt): - session = self.get_db_session() - try: - query = session.query(SavedPrompt).filter_by( - id=saved_prompt.id - ).first() - if query: - for key in saved_prompt.__dict__.keys(): - if key != "_sa_instance_state": - setattr(query, key, getattr(saved_prompt, key)) - else: - session.add(saved_prompt) - session.commit() - finally: - session.close() - - def create_saved_prompt(self, data: dict): - session = self.get_db_session() - try: - new_saved_prompt = SavedPrompt(**data) - session.add(new_saved_prompt) - session.commit() - finally: - session.close() - - def load_saved_prompts(self) -> List[SavedPrompt]: - session = self.get_db_session() - try: - return session.query(SavedPrompt).all() - finally: - session.close() - - def load_font_settings(self) -> List[FontSetting]: - session = self.get_db_session() - try: - return session.query(FontSetting).all() - finally: - session.close() - - def get_font_setting_by_name(self, name) -> FontSetting: - session = self.get_db_session() - try: - return session.query(FontSetting).filter_by(name=name).first() - finally: - session.close() - - def update_font_setting(self, font_setting: FontSetting): - session = self.get_db_session() - try: - query = session.query(FontSetting).filter_by( - name=font_setting.name - ).first() - if query: - for key in font_setting.__dict__.keys(): - if key != "_sa_instance_state": - setattr(query, key, getattr(font_setting, key)) - else: - session.add(font_setting) - session.commit() - finally: - session.close() - - ####################################### - ### AI MODELS ### - ####################################### - def load_ai_models(self) -> List[AIModels]: - session = self.get_db_session() - try: - return session.query(AIModels).all() - finally: - session.close() - - def update_ai_models(self, models: List[AIModels]): - for model in models: - self.update_ai_model(model) - - def update_ai_model(self, model: AIModels): - session = self.get_db_session() - try: - query = session.query(AIModels).filter_by( - name=model.name, - path=model.path, - branch=model.branch, - version=model.version, - category=model.category, - pipeline_action=model.pipeline_action, - enabled=model.enabled, - model_type=model.model_type, - is_default=model.is_default - ).first() - if query: - for key in model.__dict__.keys(): - if key != "_sa_instance_state": - setattr(query, key, getattr(model, key)) - else: - session.add(model) - session.commit() - finally: - session.close() - - ####################################### - ### CHATBOTS ### - ####################################### - def load_chatbots(self) -> List[Chatbot]: - session = self.get_db_session() - try: - settings = session.query(Chatbot).all() - return settings - finally: - session.close() - - def delete_chatbot_by_name(self, chatbot_name): - session = self.get_db_session() - try: - session.query(Chatbot).filter_by(name=chatbot_name).delete() - session.commit() - finally: - session.close() - - def create_chatbot(self, chatbot_name): - session = self.get_db_session() - try: - new_chatbot = Chatbot(name=chatbot_name) - session.add(new_chatbot) - session.commit() - finally: - session.close() - - def reset_path_settings(self): - session = self.get_db_session() - try: - # Delete all entries from PathSettings - session.query(PathSettings).delete() - - # Create a new PathSettings instance with default values - self.set_default_values(PathSettings) - # Commit the changes - session.commit() - finally: - session.close() - - def set_default_values(self, model_name_): - session = self.get_db_session() - try: - default_values = {} - for column in model_name_.__table__.columns: - if column.default is not None: - default_values[column.name] = column.default.arg - session.execute( - model_name_.__table__.insert(), - [default_values] - ) - session.commit() - finally: - session.close() - - ####################################### - ### LORA ### - ####################################### - def load_lora(self) -> List[Lora]: - session = self.get_db_session() - try: - return session.query(Lora).all() - finally: - session.close() - - def get_lora_by_name(self, name): - session = self.get_db_session() - try: - return session.query(Lora).filter_by(name=name).first() - finally: - session.close() - - - def add_lora(self, lora: Lora): - session = self.get_db_session() - try: - session.add(lora) - session.commit() - finally: - session.close() - - def delete_lora(self, lora: Lora): - session = self.get_db_session() - try: - session.query(Lora).filter_by(name=lora.name).delete() - session.commit() - finally: - session.close() - - def update_lora(self, lora: Lora): - session = self.get_db_session() - try: - query = session.query(Lora).filter_by(name=lora.name).first() - if query: - for key in lora.__dict__.keys(): - if key != "_sa_instance_state": - setattr(query, key, getattr(lora, key)) - else: - session.add(lora) - session.commit() - finally: - session.close() - - def update_loras(self, loras: List[Lora]): - session = self.get_db_session() - try: - for lora in loras: - query = session.query(Lora).filter_by(name=lora.name).first() - if query: - for key in lora.__dict__.keys(): - if key != "_sa_instance_state": - setattr(query, key, getattr(lora, key)) - else: - session.add(lora) - session.commit() - finally: - session.close() - - def create_lora(self, lora: Lora): - session = self.get_db_session() - try: - session.add(lora) - session.commit() - finally: - session.close() - - def delete_lora_by_name(self, lora_name, version): - session = self.get_db_session() - try: - session.query(Lora).filter_by(name=lora_name, version=version).delete() - session.commit() - finally: - session.close() - - ####################################### - ### EMBEDDINGS ### - ####################################### - def delete_embedding(self, embedding: Embedding): - session = self.get_db_session() - try: - session.query(Embedding).filter_by( - name=embedding.name, - path=embedding.path, - branch=embedding.branch, - version=embedding.version, - category=embedding.category, - pipeline_action=embedding.pipeline_action, - enabled=embedding.enabled, - model_type=embedding.model_type, - is_default=embedding.is_default - ).delete() - session.commit() - finally: - session.close() - - def load_embeddings(self) -> List[Embedding]: - session = self.get_db_session() - try: - return session.query(Embedding).all() - finally: - session.close() - - def update_embeddings(self, embeddings: List[Embedding]): - session = self.get_db_session() - try: - for embedding in embeddings: - query = session.query(Embedding).filter_by( - name=embedding.name, - path=embedding.path, - branch=embedding.branch, - version=embedding.version, - category=embedding.category, - pipeline_action=embedding.pipeline_action, - enabled=embedding.enabled, - model_type=embedding.model_type, - is_default=embedding.is_default - ).first() - if query: - for key in embedding.__dict__.keys(): - if key != "_sa_instance_state": - setattr(query, key, getattr(embedding, key)) - else: - session.add(embedding) - session.commit() - finally: - session.close() - - def get_embedding_by_name(self, name): - session = self.get_db_session() - try: - return session.query(Embedding).filter_by(name=name).first() - finally: - session.close() - - def add_embedding(self, embedding: Embedding): - session = self.get_db_session() - try: - session.add(embedding) - session.commit() - finally: - session.close() - - ####################################### - ### PROMPT TEMPLATES ### - ####################################### - def load_prompt_templates(self) -> List[PromptTemplate]: - session = self.get_db_session() - try: - return session.query(PromptTemplate).all() - finally: - session.close() - - def get_prompt_template_by_name(self, name) -> PromptTemplate: - session = self.get_db_session() - try: - return session.query(PromptTemplate).filter_by(template_name=name).first() - finally: - session.close() - - - ####################################### - ### CONTROLNET MODELS ### - ####################################### - def load_controlnet_models(self) -> List[ControlnetModel]: - session = self.get_db_session() - try: - return session.query(ControlnetModel).all() - finally: - session.close() - - def controlnet_model_by_name(self, name) -> ControlnetModel: - session = self.get_db_session() - try: - return session.query(ControlnetModel).filter_by(name=name).first() - finally: - session.close() - - def load_pipelines(self) -> List[PipelineModel]: - session = self.get_db_session() - try: - return session.query(PipelineModel).all() - finally: - session.close() - - - def load_shortcut_keys(self) -> List[ShortcutKeys]: - session = self.get_db_session() - try: - return session.query(ShortcutKeys).all() - finally: - session.close() - - - def load_window_settings(self) -> WindowSettings: - session = self.get_db_session() - try: - return session.query(WindowSettings).first() - finally: - session.close() - - - def save_window_settings(self, window_settings: WindowSettings): - session = self.get_db_session() - try: - query = session.query(WindowSettings).first() - if query: - for key in window_settings.__dict__.keys(): - if key != "_sa_instance_state": - setattr(query, key, getattr(window_settings, key)) - else: - session.add(window_settings) - session.commit() - finally: - session.close() - - def save_object(self, database_object): - session = self.get_db_session() - session.add(database_object) - session.commit() - session.close() - - - def load_history_from_db(self, conversation_id): - with self.get_db_session() as session: - messages = session.query(Message).filter_by( - conversation_id=conversation_id - ).order_by(Message.timestamp).all() - results = [ - { - "role": message.role, - "content": message.content, - "name": message.name, - "is_bot": message.is_bot, - "timestamp": message.timestamp, - "conversation_id": message.conversation_id - } for message in messages - ] - return results - - def add_message_to_history(self, content, role, name, is_bot, conversation_id): - timestamp = datetime.datetime.now() # Ensure timestamp is a datetime object - with self.get_db_session() as session: - llm_generator_settings = session.query(LLMGeneratorSettings).first() - message = Message( - role=role, - content=content, - name=name, - is_bot=is_bot, - timestamp=timestamp, - conversation_id=conversation_id, - chatbot_id=llm_generator_settings.current_chatbot - ) - session.add(message) - session.commit() - - def get_chatbot_by_id(self, chatbot_id) -> Chatbot: - session = self.get_db_session() - try: - chatbot = session.query(Chatbot).filter_by(id=chatbot_id).options(joinedload(Chatbot.target_files)).first() - if chatbot is None: - chatbot = session.query(Chatbot).options(joinedload(Chatbot.target_files)).first() - finally: - session.close() - return chatbot - - def create_conversation(self): - with self.get_db_session() as session: - conversation = Conversation( - timestamp=datetime.datetime.now(datetime.timezone.utc), - title="" - ) - session.add(conversation) - session.commit() - return conversation.id - - def update_conversation_title(self, conversation_id, title): - with self.get_db_session() as session: - conversation = session.query(Conversation).filter_by(id=conversation_id).first() - if conversation: - conversation.title = title - session.commit() - - def add_summary(self, content, conversation_id): - timestamp = datetime.datetime.now() # Ensure timestamp is a datetime object - with self.get_db_session() as session: - summary = Summary( - content=content, - timestamp=timestamp, - conversation_id=conversation_id - ) - session.add(summary) - session.commit() - - def create_conversation_with_messages(self, messages): - conversation_id = self.create_conversation() - for message in messages: - self.add_message_to_history( - content=message["content"], - role=message["role"], - name=message["name"], - is_bot=message["is_bot"], - conversation_id=conversation_id - ) - return conversation_id - - def get_all_conversations(self): - session = self.Session() - conversations = session.query(Conversation).all() - session.close() - return conversations - - def delete_conversation(self, conversation_id): - session = self.Session() - try: - session.query(Message).filter_by(conversation_id=conversation_id).delete() - session.query(Summary).filter_by(conversation_id=conversation_id).delete() - session.query(Conversation).filter_by(id=conversation_id).delete() - session.commit() - except Exception as e: - session.rollback() - print(f"Error deleting conversation: {e}") - finally: - session.close() - - def get_most_recent_conversation_id(self): - session = self.Session() - conversation = session.query(Conversation).order_by(Conversation.timestamp.desc()).first() - session.close() - return conversation.id if conversation else None diff --git a/src/airunner/handlers/base_handler.py b/src/airunner/handlers/base_handler.py index 0fbe3177f..1e718a3ca 100644 --- a/src/airunner/handlers/base_handler.py +++ b/src/airunner/handlers/base_handler.py @@ -2,7 +2,6 @@ from PySide6.QtCore import QObject from airunner.enums import HandlerType, SignalCode, ModelType, ModelStatus, ModelAction from airunner.mediator_mixin import MediatorMixin -from airunner.handlers.logger import Logger from airunner.utils.get_torch_device import get_torch_device from airunner.windows.main.settings_mixin import SettingsMixin @@ -23,7 +22,6 @@ class BaseHandler( def __init__(self, *args, **kwargs): self._model_status = {model_type: ModelStatus.UNLOADED for model_type in ModelType} self.use_gpu = True - self.logger = Logger(prefix=self.__class__.__name__) MediatorMixin.__init__(self) SettingsMixin.__init__(self) super().__init__(*args, **kwargs) diff --git a/src/airunner/handlers/llm/agent/base_agent.py b/src/airunner/handlers/llm/agent/base_agent.py index 8cb664a36..645936a62 100644 --- a/src/airunner/handlers/llm/agent/base_agent.py +++ b/src/airunner/handlers/llm/agent/base_agent.py @@ -9,6 +9,7 @@ from PySide6.QtCore import QObject from llama_index.core import Settings +from llama_index.core.base.llms.types import ChatMessage from llama_index.readers.file import EpubReader, PDFReader, MarkdownReader from llama_index.core import SimpleDirectoryReader from llama_index.core.node_parser import SentenceSplitter @@ -23,7 +24,6 @@ from airunner.handlers.llm.custom_embedding import CustomEmbedding from airunner.handlers.llm.agent.html_file_reader import HtmlFileReader from airunner.handlers.llm.agent.external_condition_stopping_criteria import ExternalConditionStoppingCriteria -from airunner.handlers.logger import Logger from airunner.mediator_mixin import MediatorMixin from airunner.enums import ( SignalCode, @@ -38,13 +38,19 @@ from airunner.workers.agent_worker import AgentWorker +class RefreshContextChatEngine(ContextChatEngine): + def stream_chat(self, *args, system_prompt:str=None, **kwargs): + if system_prompt: + self._prefix_messages[0] = ChatMessage(content=system_prompt, role=self._llm.metadata.system_role) + return super().stream_chat(*args, **kwargs) + + class BaseAgent( QObject, MediatorMixin, SettingsMixin ): def __init__(self, *args, **kwargs): - self.logger = Logger(prefix=self.__class__.__name__) MediatorMixin.__init__(self) SettingsMixin.__init__(self) self.model = kwargs.pop("model", None) @@ -84,19 +90,33 @@ def __init__(self, *args, **kwargs): self.action = LLMActionType.CHAT self.rendered_template = None self.tokenizer = kwargs.pop("tokenizer", None) - self.streamer = TextIteratorStreamer(self.tokenizer) + self._streamer = None self.chat_template = kwargs.pop("chat_template", "") self.is_mistral = kwargs.pop("is_mistral", True) self.conversation_id = None self.conversation_title = None - self.history = self.db_handler.load_history_from_db(self.conversation_id) # Load history by conversation ID + self.history = self.load_history_from_db(self.conversation_id) # Load history by conversation ID super().__init__(*args, **kwargs) self.prompt = "" self.thread = None - self.do_interrupt = False + self._do_interrupt = False self.response_worker = create_worker(AgentWorker) self.load_rag(model=self.model, tokenizer=self.tokenizer) + @property + def do_interrupt(self): + return self._do_interrupt + + @do_interrupt.setter + def do_interrupt(self, value): + self._do_interrupt = value + + @property + def streamer(self): + if self._streamer is None: + self._streamer = TextIteratorStreamer(self.tokenizer) + return self._streamer + @property def available_actions(self): return { @@ -123,13 +143,85 @@ def bot_mood(self) -> str: def bot_mood(self, value: str): chatbot = self.chatbot chatbot.bot_mood = value - self.db_handler.save_object(chatbot) + self.save_object(chatbot) self.emit_signal(SignalCode.BOT_MOOD_UPDATED) @property def bot_personality(self) -> str: return self.chatbot.bot_personality + @property + def override_parameters(self): + generate_kwargs = prepare_llm_generate_kwargs(self.llm_generator_settings) + return generate_kwargs if self.llm_generator_settings.override_parameters else {} + + @property + def system_instructions(self): + return self.chatbot.system_instructions + + @property + def generator_settings(self) -> dict: + return prepare_llm_generate_kwargs(self.chatbot) + + @property + def device(self): + return get_torch_device(self.memory_settings.default_gpu_llm) + + @property + def target_files(self): + return [ + target_file.file_path for target_file in self.chatbot.target_files + ] + + @property + def query_instruction(self): + if self.__state == AgentState.SEARCH: + return self.__query_instruction + elif self.__state == AgentState.CHAT: + return "Search through the chat history for anything relevant to the query." + + @property + def text_instruction(self): + if self.__state == AgentState.SEARCH: + return self.__text_instruction + elif self.__state == AgentState.CHAT: + return "Use the text to respond to the user" + + @property + def index(self): + if self.__state == AgentState.SEARCH: + return self.__index + elif self.__state == AgentState.CHAT: + return self.__chat_history_index + + @property + def llm(self): + if self.__llm is None: + try: + if self.llm_generator_settings.use_api: + self.__llm = self.__model + else: + self.__llm = HuggingFaceLLM(model=self.__model, tokenizer=self.__tokenizer) + except Exception as e: + self.logger.error(f"Error loading LLM: {str(e)}") + return self.__llm + + @property + def chat_engine(self): + return self.__chat_engine + + @property + def is_llama_instruct(self): + return True + + @property + def use_cuda(self): + return torch.cuda.is_available() + + @property + def cuda_index(self): + return 0 + def unload(self): self.unload_rag() del self.model @@ -144,39 +236,29 @@ def clear_history(self): self.conversation_id = None self.conversation_title = None - def update_conversation_title(self, title): + def _update_conversation_title(self, title): self.conversation_title = title - self.db_handler.update_conversation_title(self.conversation_id, title) + self.update_conversation_title(self.conversation_id, title) - def create_conversation(self): + def _create_conversation(self): # Get the most recent conversation ID - recent_conversation_id = self.db_handler.get_most_recent_conversation_id() + recent_conversation_id = self.get_most_recent_conversation_id() # Check if there are messages for the most recent conversation ID if recent_conversation_id is not None: - messages = self.db_handler.load_history_from_db(recent_conversation_id) + messages = self.load_history_from_db(recent_conversation_id) if not messages: self.conversation_id = recent_conversation_id return # If there are messages or no recent conversation ID, create a new conversation - self.conversation_id = self.db_handler.create_conversation() + self.conversation_id = self.create_conversation() def interrupt_process(self): self.do_interrupt = True def do_interrupt_process(self): - interrupt = self.do_interrupt - self.streamer = TextIteratorStreamer(self.tokenizer) - return interrupt - - @property - def use_cuda(self): - return torch.cuda.is_available() - - @property - def cuda_index(self): - return 0 + return self.do_interrupt def mood(self, botname: str, bot_mood: str, use_mood: bool) -> str: return ( @@ -261,12 +343,6 @@ def build_system_prompt( ] elif action is LLMActionType.GENERATE_IMAGE: - prompt_template = self.get_prompt_template_by_name("image") - # system_prompt = [ - # prompt_template.guardrails, - # prompt_template.system, - # self.history_prompt() - # ] system_prompt = [ ( "You are an image generator. " @@ -285,7 +361,7 @@ def build_system_prompt( "You will only return JSON strings.\n" "You will not return any other data types.\n" "You are an artist, so use your imagination and keep things interesting.\n" - "You will not respond in a conversational manner or with additonal notes or information.\n" + "You will not respond in a conversational manner or with additional notes or information.\n" f"Only return one JSON block. Do not generate instructions or additional information.\n" "You must never break the rules.\n" "Here is a description of the attributes: \n" @@ -322,7 +398,6 @@ def build_system_prompt( prompt_template = self.get_prompt_template_by_name("update_mood") system_instructions = prompt_template.system system_prompt = [ - guardrails_prompt, system_instructions, self.names_prompt(use_names, botname, username), self.mood(botname, bot_mood, use_mood), @@ -339,7 +414,7 @@ def build_system_prompt( self.names_prompt(use_names, botname, username), self.mood(botname, bot_mood, use_mood), self.personality_prompt(bot_personality, use_personality), - self.history_prompt(), + # self.history_prompt(), ] elif action is LLMActionType.QUIT_APPLICATION: @@ -420,70 +495,6 @@ def get_rendered_template( rendered_template = rendered_template.replace("{{ " + key + " }}", value) return rendered_template - @property - def override_parameters(self): - generate_kwargs = prepare_llm_generate_kwargs(self.llm_generator_settings) - return generate_kwargs if self.llm_generator_settings.override_parameters else {} - - @property - def system_instructions(self): - return self.chatbot.system_instructions - - @property - def generator_settings(self) -> dict: - return prepare_llm_generate_kwargs(self.chatbot) - - @property - def device(self): - return get_torch_device(self.memory_settings.default_gpu_llm) - - @property - def target_files(self): - return [ - target_file.file_path for target_file in self.chatbot.target_files - ] - - @property - def query_instruction(self): - if self.__state == AgentState.SEARCH: - return self.__query_instruction - elif self.__state == AgentState.CHAT: - return "Search through the chat history for anything relevant to the query." - - @property - def text_instruction(self): - if self.__state == AgentState.SEARCH: - return self.__text_instruction - elif self.__state == AgentState.CHAT: - return "Use the text to respond to the user" - - @property - def index(self): - if self.__state == AgentState.SEARCH: - return self.__index - elif self.__state == AgentState.CHAT: - return self.__chat_history_index - - @property - def llm(self): - if self.__llm is None: - try: - if self.llm_generator_settings.use_api: - self.__llm = self.__model - else: - self.__llm = HuggingFaceLLM(model=self.__model, tokenizer=self.__tokenizer) - except Exception as e: - self.logger.error(f"Error loading LLM: {str(e)}") - return self.__llm - - @property - def chat_engine(self): - return self.__chat_engine - - @property - def is_llama_instruct(self): - return True - def run( self, prompt: str, @@ -501,10 +512,9 @@ def run( self.logger.debug("Running...") self.prompt = prompt - streamer = self.streamer if self.conversation_id is None: - self.create_conversation() + self._create_conversation() self.set_conversation_title() # Add the user's message to history @@ -512,7 +522,10 @@ def run( LLMActionType.APPLICATION_COMMAND, LLMActionType.UPDATE_MOOD ): - self.add_message_to_history(self.prompt, LLMChatRole.HUMAN) + self.add_message_to_history( + self.prompt, + LLMChatRole.HUMAN + ) self.rendered_template = self.get_rendered_template(action) @@ -522,34 +535,14 @@ def run( ).to(self.device) kwargs.update( - streamer=streamer, action=action, do_emit_response=True ) - if streamer: - self.run_with_thread( - model_inputs, - **kwargs - ) - else: - self.emit_signal(SignalCode.UNBLOCK_TTS_GENERATOR_SIGNAL) - stopping_criteria = ExternalConditionStoppingCriteria(self.do_interrupt_process) - data = self.prepare_generate_data(model_inputs, stopping_criteria) - res = self.model.generate(**data) - response = self.tokenizer.decode(res[0]) - self.emit_signal( - SignalCode.LLM_TEXT_STREAMED_SIGNAL, - dict( - message=response, - is_first_message=True, - is_end_of_message=True, - name=self.botname, - action=action - ) - ) - - return response + self.run_with_thread( + model_inputs, + **kwargs + ) def prepare_generate_data(self, model_inputs, stopping_criteria): data = dict( @@ -574,13 +567,22 @@ def run_with_thread( stopping_criteria = ExternalConditionStoppingCriteria(self.do_interrupt_process) data = self.prepare_generate_data(model_inputs, stopping_criteria) - streamer = kwargs.get("streamer", self.streamer) - data["streamer"] = streamer - # if "attention_mask" in data: - # del data["attention_mask"] + if self.do_interrupt: + self.do_interrupt = False + self.emit_signal( + SignalCode.LLM_TEXT_STREAMED_SIGNAL, + dict( + message="", + is_first_message=False, + is_end_of_message=False, + name=self.botname, + action=LLMActionType.CHAT + ) + ) + return - self.do_interrupt = False + data["streamer"] = kwargs.get("streamer", self.streamer) if action is not LLMActionType.PERFORM_RAG_SEARCH: try: @@ -620,14 +622,14 @@ def run_with_thread( ) is_first_message = False - if streamer and action in ( + if action in ( LLMActionType.CHAT, LLMActionType.GENERATE_IMAGE, LLMActionType.UPDATE_MOOD, LLMActionType.SUMMARIZE, LLMActionType.APPLICATION_COMMAND ): - for new_text in streamer: + for new_text in self.streamer: # strip all newlines from new_text streamed_template += new_text if self.is_mistral: @@ -693,7 +695,10 @@ def run_with_thread( ) data.update(self.override_parameters) self.llm.generate_kwargs = data - response = self.chat_engine.stream_chat(message=self.prompt) + response = self.chat_engine.stream_chat( + message=self.prompt, + system_prompt=self.rendered_template + ) is_first_message = True is_end_of_message = False for new_text in response.response_gen: @@ -710,6 +715,16 @@ def run_with_thread( ) ) is_first_message = False + self.emit_signal( + SignalCode.LLM_TEXT_STREAMED_SIGNAL, + dict( + message="", + is_first_message=False, + is_end_of_message=True, + name=self.botname, + action=action + ) + ) if streamed_template is not None: if action is LLMActionType.CHAT: @@ -727,7 +742,7 @@ def run_with_thread( ) elif action is LLMActionType.SUMMARIZE: - self.update_conversation_title(streamed_template) + self._update_conversation_title(streamed_template) return self.run( prompt=self.prompt, action=LLMActionType.CHAT, @@ -780,22 +795,14 @@ def add_message_to_history( name = self.username is_bot = False + if role is LLMChatRole.ASSISTANT and content: content = content.replace(f"{self.botname}:", "") content = content.replace(f"{self.botname}", "") is_bot = True name = self.botname - self.history.append({ - "role": role.value, - "content": content, - "name": name, - "is_bot": is_bot, - "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "conversation_id": self.conversation_id # Use the stored conversation ID - }) - - self.db_handler.add_message_to_history( + message = self.save_message( content, role.value, name, @@ -803,19 +810,28 @@ def add_message_to_history( self.conversation_id ) + self.history.append({ + "role": message.role, + "content": message.content, + "name": name, + "is_bot": message.is_bot, + "timestamp": message.timestamp, + "conversation_id": message.conversation_id + }) + def on_load_conversation(self, message): self.history = [] self.conversation_id = message["conversation_id"] - self.history = self.db_handler.load_history_from_db(self.conversation_id) + self.history = self.load_history_from_db(self.conversation_id) self.set_conversation_title() self.emit_signal(SignalCode.SET_CONVERSATION, { "messages": self.history }) def set_conversation_title(self): - session = self.db_handler.get_db_session() - self.conversation_title = session.query(Conversation).filter_by(id=self.conversation_id).first().title - session.close() + + self.conversation_title = self.session.query(Conversation).filter_by(id=self.conversation_id).first().title + def load_rag(self, model, tokenizer): self.__model = model @@ -913,7 +929,7 @@ def __load_prompt_helper(self): def __load_context_chat_engine(self): try: - self.__chat_engine = ContextChatEngine.from_defaults( + self.__chat_engine = RefreshContextChatEngine.from_defaults( retriever=self.__retriever, chat_history=self.history, memory=None, diff --git a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py index 655c28098..de54cd33a 100644 --- a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py +++ b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py @@ -388,7 +388,8 @@ def _do_generate(self, prompt: str, action: LLMActionType): prompt, action ) - self._send_final_message() + if action is LLMActionType.CHAT: + self._send_final_message() def _emit_streamed_text_signal(self, **kwargs): self.logger.debug("Emitting streamed text signal") diff --git a/src/airunner/handlers/stablediffusion/civit_ai_download_worker.py b/src/airunner/handlers/stablediffusion/civit_ai_download_worker.py index f4356b6c5..ca439a030 100644 --- a/src/airunner/handlers/stablediffusion/civit_ai_download_worker.py +++ b/src/airunner/handlers/stablediffusion/civit_ai_download_worker.py @@ -6,7 +6,6 @@ from airunner.handlers.logger import Logger from airunner.enums import SignalCode from airunner.mediator_mixin import MediatorMixin -from facehuggershield.huggingface.settings import DEFAULT_HF_ENDPOINT from airunner.windows.main.settings_mixin import SettingsMixin logger = Logger(prefix="DownloadWorker") diff --git a/src/airunner/handlers/stablediffusion/download_civitai.py b/src/airunner/handlers/stablediffusion/download_civitai.py index ad41dd852..56d8bf836 100644 --- a/src/airunner/handlers/stablediffusion/download_civitai.py +++ b/src/airunner/handlers/stablediffusion/download_civitai.py @@ -4,7 +4,6 @@ from PySide6.QtCore import QThread from airunner.handlers.logger import Logger from airunner.handlers.stablediffusion.civit_ai_download_worker import CivitAIDownloadWorker -from airunner.handlers.stablediffusion.download_worker import DownloadWorker from airunner.enums import SignalCode from airunner.mediator_mixin import MediatorMixin from airunner.windows.main.settings_mixin import SettingsMixin diff --git a/src/airunner/handlers/stablediffusion/sd_handler.py b/src/airunner/handlers/stablediffusion/sd_handler.py index 009dafcb9..95c3c6eb4 100644 --- a/src/airunner/handlers/stablediffusion/sd_handler.py +++ b/src/airunner/handlers/stablediffusion/sd_handler.py @@ -50,7 +50,6 @@ class SDHandler(BaseHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._session = self.db_handler.get_db_session() self._controlnet_model = None self._controlnet: ControlNetModel = None self._controlnet_processor: Any = None @@ -206,7 +205,7 @@ def path_settings_cached(self): @property def generator_settings_cached(self): if self._generator_settings is None: - self._generator_settings = self._session.query( + self._generator_settings = self.session.query( GeneratorSettings ).first() return self._generator_settings @@ -235,12 +234,12 @@ def controlnet_model(self) -> ControlnetModel: self._controlnet_model.version != self.generator_settings_cached.version or self._controlnet_model.display_name != self.controlnet_settings_cached.controlnet ): - session = self.db_handler.get_db_session() - self._controlnet_model = session.query(ControlnetModel).filter_by( + + self._controlnet_model = self.session.query(ControlnetModel).filter_by( display_name=self.controlnet_settings_cached.controlnet, version=self.generator_settings_cached.version ).first() - session.close() + return self._controlnet_model @property @@ -866,8 +865,8 @@ def _load_scheduler(self, scheduler=None): "scheduler_config.json" ) ) - session = self.db_handler.get_db_session() - scheduler = session.query(Schedulers).filter_by(display_name=scheduler_name).first() + + scheduler = self.session.query(Schedulers).filter_by(display_name=scheduler_name).first() if not scheduler: self.logger.error(f"Failed to find scheduler {scheduler_name}") return None @@ -941,8 +940,8 @@ def _load_pipe(self): self.logger.error(f"Failed to load model to device: {e}") def _load_lora(self): - session = self.db_handler.get_db_session() - enabled_lora = session.query(Lora).filter_by( + + enabled_lora = self.session.query(Lora).filter_by( version=self.generator_settings_cached.version, enabled=True ).all() @@ -979,9 +978,9 @@ def _load_lora_weights(self, lora: Lora): def _set_lora_adapters(self): self.logger.debug("Setting LORA adapters") - session = self.db_handler.get_db_session() + loaded_lora_id = [l.id for l in self._loaded_lora.values()] - enabled_lora = session.query(Lora).filter(Lora.id.in_(loaded_lora_id)).all() + enabled_lora = self.session.query(Lora).filter(Lora.id.in_(loaded_lora_id)).all() adapter_weights = [] adapter_names = [] for lora in enabled_lora: @@ -1002,11 +1001,11 @@ def _load_embeddings(self): self._pipe.unload_textual_inversion() except RuntimeError as e: self.logger.error(f"Failed to unload embeddings: {e}") - session = self.db_handler.get_db_session() - embeddings = session.query(Embedding).filter_by( + + embeddings = self.session.query(Embedding).filter_by( version=self.generator_settings_cached.version ).all() - session.close() + for embedding in embeddings: embedding_path = embedding.path if embedding.active and embedding_path not in self._loaded_embeddings: diff --git a/src/airunner/handlers/stt/whisper_handler.py b/src/airunner/handlers/stt/whisper_handler.py index 07ccf58de..26f87f42c 100644 --- a/src/airunner/handlers/stt/whisper_handler.py +++ b/src/airunner/handlers/stt/whisper_handler.py @@ -71,13 +71,16 @@ def process_audio(self, audio_data): if transcription: self._send_transcription(transcription) - def load(self): + def load(self, retry:bool = False): if self.stt_is_loading or self.stt_is_loaded: return self.logger.debug("Loading Whisper (text-to-speech)") self.unload() self.change_model_status(ModelType.STT, ModelStatus.LOADING) self._load_model() + # unsure why this is failing to load occasionally - this is a hack + if self._model is None and retry is False: + return self.load(retry=True) self._load_processor() self._load_feature_extractor() if ( diff --git a/src/airunner/handlers/tts/speecht5_tts_handler.py b/src/airunner/handlers/tts/speecht5_tts_handler.py index f0190f5ef..ec19e88cc 100644 --- a/src/airunner/handlers/tts/speecht5_tts_handler.py +++ b/src/airunner/handlers/tts/speecht5_tts_handler.py @@ -241,8 +241,6 @@ def _do_generate(self, message): self.logger.debug("Processing inputs...") - print(text) - inputs = self._processor( text=text, return_tensors="pt", @@ -262,7 +260,7 @@ def _do_generate(self, message): vocoder=self._vocoder, max_length=100 ) - except RuntimeError as e: + except Exception as e: self.logger.error("Failed to generate speech") self.logger.error(e) self._cancel_generated_speech = False diff --git a/src/airunner/settings.py b/src/airunner/settings.py index 364262a63..5b50a3dcb 100644 --- a/src/airunner/settings.py +++ b/src/airunner/settings.py @@ -31,7 +31,7 @@ ) ORGANIZATION = "Capsize Games" APPLICATION_NAME = "AI Runner" -LOG_LEVEL = logging.DEBUG +LOG_LEVEL = logging.ERROR DEFAULT_LLM_HF_PATH = "w4ffl35/Mistral-7B-Instruct-v0.3-4bit" DEFAULT_STT_HF_PATH = "openai/whisper-tiny" DEFAULT_IMAGE_SYSTEM_PROMPT = "\n".join([ diff --git a/src/airunner/utils/models/scan_path_for_items.py b/src/airunner/utils/models/scan_path_for_items.py index 22a82f9fa..5c6913a9d 100644 --- a/src/airunner/utils/models/scan_path_for_items.py +++ b/src/airunner/utils/models/scan_path_for_items.py @@ -1,13 +1,14 @@ import os -from airunner.data.models.settings_db_handler import SettingsDBHandler from airunner.data.models.settings_models import Lora, Embedding +from airunner.windows.main.settings_mixin import SettingsMixin + def scan_path_for_lora(base_path) -> bool: lora_added = False lora_deleted = False - db_handler = SettingsDBHandler() + db_handler = SettingsMixin() for versionpath, versionnames, versionfiles in os.walk(os.path.expanduser(os.path.join(base_path, "art/models"))): version = versionpath.split("/")[-1] lora_path = os.path.expanduser( @@ -21,11 +22,10 @@ def scan_path_for_lora(base_path) -> bool: if not os.path.exists(lora_path): continue - session = db_handler.get_db_session() - existing_lora = session.query(Lora).all() + existing_lora = db_handler.session.query(Lora).all() for lora in existing_lora: if not os.path.exists(lora.path): - session.delete(lora) + db_handler.session.delete(lora) lora_deleted = True for dirpath, dirnames, filenames in os.walk(lora_path): for file in filenames: @@ -43,17 +43,16 @@ def scan_path_for_lora(base_path) -> bool: trigger_word="", version=version ) - session.add(item) + db_handler.session.add(item) lora_added = True if lora_deleted or lora_added: - session.commit() - session.close() + db_handler.session.commit() return lora_deleted or lora_added def scan_path_for_embeddings(base_path) -> bool: embedding_added = False embedding_deleted = False - db_handler = SettingsDBHandler() + db_handler = SettingsMixin() items = [] for versionpath, versionnames, versionfiles in os.walk(os.path.expanduser(os.path.join(base_path, "art/models"))): version = versionpath.split("/")[-1] @@ -67,11 +66,10 @@ def scan_path_for_embeddings(base_path) -> bool: ) if not os.path.exists(embedding_path): continue - session = db_handler.get_db_session() - existing_embeddings = session.query(Embedding).all() + existing_embeddings = db_handler.session.query(Embedding).all() for embedding in existing_embeddings: if not os.path.exists(embedding.path): - session.delete(embedding) + db_handler.session.delete(embedding) embedding_deleted = True for dirpath, dirnames, filenames in os.walk(embedding_path): for file in filenames: @@ -88,9 +86,8 @@ def scan_path_for_embeddings(base_path) -> bool: active=False, trigger_word="" ) - session.add(item) + db_handler.session.add(item) embedding_added = True if embedding_deleted or embedding_added: - session.commit() - session.close() + db_handler.session.commit() return embedding_deleted or embedding_added diff --git a/src/airunner/utils/network/huggingface_downloader.py b/src/airunner/utils/network/huggingface_downloader.py index 3bbfb548b..18a65e7b5 100644 --- a/src/airunner/utils/network/huggingface_downloader.py +++ b/src/airunner/utils/network/huggingface_downloader.py @@ -1,6 +1,5 @@ from typing import Callable from PySide6.QtCore import QObject, QThread, Signal -from airunner.handlers.logger import Logger from airunner.handlers.stablediffusion.download_worker import DownloadWorker from airunner.enums import SignalCode from airunner.mediator_mixin import MediatorMixin @@ -20,7 +19,6 @@ def __init__(self, callback=None): super(HuggingfaceDownloader, self).__init__() self.thread = None self.worker = None - self.logger = Logger(prefix="HuggingfaceDownloader") self.downloading = False self.thread = QThread() diff --git a/src/airunner/widgets/base_widget.py b/src/airunner/widgets/base_widget.py index 99b83a675..32313b996 100644 --- a/src/airunner/widgets/base_widget.py +++ b/src/airunner/widgets/base_widget.py @@ -3,11 +3,9 @@ from PySide6 import QtGui from PySide6.QtWidgets import QWidget -from airunner.handlers.logger import Logger from airunner.enums import CanvasToolName from airunner.windows.main.settings_mixin import SettingsMixin from airunner.mediator_mixin import MediatorMixin -from airunner.settings import DARK_THEME_NAME, LIGHT_THEME_NAME from airunner.utils.create_worker import create_worker @@ -31,7 +29,6 @@ def is_dark(self): return self.application_settings.dark_mode_enabled def __init__(self, *args, **kwargs): - self.logger = Logger(prefix=self.__class__.__name__) MediatorMixin.__init__(self) SettingsMixin.__init__(self) super().__init__(*args, **kwargs) diff --git a/src/airunner/widgets/canvas/brush_scene.py b/src/airunner/widgets/canvas/brush_scene.py index e97109c5b..82b148a0f 100644 --- a/src/airunner/widgets/canvas/brush_scene.py +++ b/src/airunner/widgets/canvas/brush_scene.py @@ -198,8 +198,8 @@ def mousePressEvent(self, event): return super().mousePressEvent(event) def _handle_left_mouse_release(self, event) -> bool: - session = self.db_handler.get_db_session() - drawing_pad_settings = session.query(DrawingPadSettings).first() + + drawing_pad_settings = self.session.query(DrawingPadSettings).first() if self.drawing_pad_settings.mask_layer_enabled: mask_image: Image = ImageQt.fromqimage(self.mask_image) # Ensure mask is fully opaque @@ -215,8 +215,8 @@ def _handle_left_mouse_release(self, event) -> bool: self.current_tool is CanvasToolName.ERASER )): self.emit_signal(SignalCode.GENERATE_MASK) - session.commit() - session.close() + self.session.commit() + self.emit_signal(SignalCode.CANVAS_IMAGE_UPDATED_SIGNAL) if self.drawing_pad_settings.mask_layer_enabled: self.initialize_image() diff --git a/src/airunner/widgets/canvas/custom_scene.py b/src/airunner/widgets/canvas/custom_scene.py index f0cd95474..16b1443ac 100644 --- a/src/airunner/widgets/canvas/custom_scene.py +++ b/src/airunner/widgets/canvas/custom_scene.py @@ -10,7 +10,6 @@ from PySide6.QtGui import QPixmap, QPainter from PySide6.QtWidgets import QGraphicsScene, QGraphicsPixmapItem, QFileDialog, QGraphicsSceneMouseEvent -from airunner.handlers.logger import Logger from airunner.enums import SignalCode, CanvasToolName, GeneratorSection, EngineResponseCode from airunner.mediator_mixin import MediatorMixin from airunner.settings import VALID_IMAGE_FILES @@ -30,7 +29,6 @@ class CustomScene( ): def __init__(self, canvas_type: str): self.canvas_type = canvas_type - self.logger = Logger(prefix=self.__class__.__name__) MediatorMixin.__init__(self) SettingsMixin.__init__(self) self.image_backup = None diff --git a/src/airunner/widgets/canvas/custom_view.py b/src/airunner/widgets/canvas/custom_view.py index 46345708c..95d9521f9 100644 --- a/src/airunner/widgets/canvas/custom_view.py +++ b/src/airunner/widgets/canvas/custom_view.py @@ -9,7 +9,6 @@ from airunner.enums import CanvasToolName, SignalCode, CanvasType from airunner.mediator_mixin import MediatorMixin from airunner.utils.convert_image_to_base64 import convert_image_to_base64 -from airunner.utils.create_worker import create_worker from airunner.utils.snap_to_grid import snap_to_grid from airunner.widgets.canvas.brush_scene import BrushScene from airunner.widgets.canvas.custom_scene import CustomScene diff --git a/src/airunner/widgets/controlnet/controlnet_settings_widget.py b/src/airunner/widgets/controlnet/controlnet_settings_widget.py index 9f7ac2fe2..027835751 100644 --- a/src/airunner/widgets/controlnet/controlnet_settings_widget.py +++ b/src/airunner/widgets/controlnet/controlnet_settings_widget.py @@ -23,11 +23,10 @@ def _load_controlnet_models(self): if self._version is None or self._version != self.generator_settings.version: self._version = self.generator_settings.version current_index = 0 - session = self.db_handler.get_db_session() - controlnet_models = session.query(ControlnetModel).filter_by( + + controlnet_models = self.session.query(ControlnetModel).filter_by( version=self.generator_settings.version ).all() - session.close() self.ui.controlnet.blockSignals(True) self.ui.controlnet.clear() for index, item in enumerate(controlnet_models): diff --git a/src/airunner/widgets/embeddings/embedding_widget.py b/src/airunner/widgets/embeddings/embedding_widget.py index 8ba4a4ec9..ce660a707 100644 --- a/src/airunner/widgets/embeddings/embedding_widget.py +++ b/src/airunner/widgets/embeddings/embedding_widget.py @@ -40,7 +40,7 @@ def action_clicked_button_deleted(self): ) def update_embedding(self, embedding: Embedding): - self.db_handler.save_object(embedding) + self.save_object(embedding) @Slot(bool) def action_toggled_embedding(self, val, emit_signal=True): diff --git a/src/airunner/widgets/embeddings/embeddings_container_widget.py b/src/airunner/widgets/embeddings/embeddings_container_widget.py index eca6789ce..38bce8795 100644 --- a/src/airunner/widgets/embeddings/embeddings_container_widget.py +++ b/src/airunner/widgets/embeddings/embeddings_container_widget.py @@ -25,7 +25,7 @@ def __init__(self, *args, **kwargs): self.register(SignalCode.EMBEDDING_UPDATED_SIGNAL, self.on_embedding_updated_signal) self.register(SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal) self.register(SignalCode.EMBEDDING_STATUS_CHANGED, self.on_embedding_modified) - self.register(SignalCode.EMBEDDING_DELETE_SIGNAL, self.delete_embedding) + self.register(SignalCode.EMBEDDING_DELETE_SIGNAL, self._delete_embedding) self.ui.loading_icon.hide() self.ui.loading_icon.set_size(spinner_size=QSize(30, 30), label_size=QSize(24, 24)) self._apply_button_enabled = False @@ -92,7 +92,7 @@ def on_application_settings_changed_signal(self): def on_embedding_updated_signal(self): self._enable_form() - def delete_embedding(self, data): + def _delete_embedding(self, data): self._deleting = True embedding_widget = data["embedding_widget"] @@ -113,10 +113,10 @@ def delete_embedding(self, data): break # Remove lora from database - session = self.db_handler.get_db_session() - session.delete(embedding_widget.embedding) - session.commit() - session.close() + + self.session.delete(embedding_widget.embedding) + self.session.commit() + self._apply_button_enabled = True self.ui.apply_embeddings_button.setEnabled(self._apply_button_enabled) @@ -143,7 +143,7 @@ def load_embeddings(self, force_reload:bool=False): if self.search_filter.lower() in embedding.name.lower() ] for embedding in filtered_embeddings: - self.add_embedding(embedding) + self._add_embedding(embedding) self.add_spacer() def remove_spacer(self): @@ -160,7 +160,7 @@ def add_spacer(self): self.remove_spacer() self.ui.scrollAreaWidgetContents.layout().addWidget(self.spacer) - def add_embedding(self, embedding): + def _add_embedding(self, embedding): if embedding is None: return embedding_widget = EmbeddingWidget(embedding=embedding) diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py index 0c78df71f..1357a2d0a 100644 --- a/src/airunner/widgets/generator_form/generator_form_widget.py +++ b/src/airunner/widgets/generator_form/generator_form_widget.py @@ -483,13 +483,13 @@ def stop_progress_bar(self, do_clear=False): progressbar.setFormat("Complete") def _set_keyboard_shortcuts(self): - session = self.db_handler.get_db_session() - generate_image_key = session.query(ShortcutKeys).filter_by(display_name="Generate Image").first() - interrupt_key = session.query(ShortcutKeys).filter_by(display_name="Interrupt").first() + + generate_image_key = self.session.query(ShortcutKeys).filter_by(display_name="Generate Image").first() + interrupt_key = self.session.query(ShortcutKeys).filter_by(display_name="Interrupt").first() if generate_image_key: self.ui.generate_button.setShortcut(generate_image_key.key) self.ui.generate_button.setToolTip(f"{generate_image_key.display_name} ({generate_image_key.text})") if interrupt_key: self.ui.interrupt_button.setShortcut(interrupt_key.key) self.ui.interrupt_button.setToolTip(f"{interrupt_key.display_name} ({interrupt_key.text})") - session.close() + diff --git a/src/airunner/widgets/keyboard_shortcuts/keyboard_shortcuts_widget.py b/src/airunner/widgets/keyboard_shortcuts/keyboard_shortcuts_widget.py index e3d25ac03..4c65c945c 100644 --- a/src/airunner/widgets/keyboard_shortcuts/keyboard_shortcuts_widget.py +++ b/src/airunner/widgets/keyboard_shortcuts/keyboard_shortcuts_widget.py @@ -69,12 +69,12 @@ def get_shortcut(self, shortcut_key: ShortcutKeys, line_edit, event, index): shortcut_key.key = event.key() shortcut_key.modifiers = event.modifiers().value - session = self.db_handler.get_db_session() - session.add(shortcut_key) + + self.session.add(shortcut_key) # clear existing key if it exists - existing_keys = session.query(ShortcutKeys).filter( + existing_keys = self.session.query(ShortcutKeys).filter( ShortcutKeys.text == shortcut_key.text, ShortcutKeys.id != shortcut_key.id ).all() @@ -82,7 +82,7 @@ def get_shortcut(self, shortcut_key: ShortcutKeys, line_edit, event, index): existing_key.text = "" existing_key.key = 0 existing_key.modifiers = 0 - session.add(existing_key) + self.session.add(existing_key) for i, widget in enumerate(self.shortcut_key_widgets): if i == index: @@ -91,8 +91,8 @@ def get_shortcut(self, shortcut_key: ShortcutKeys, line_edit, event, index): widget.line_edit.setText("") line_edit.setText(shortcut_key.text) - session.commit() - session.close() + self.session.commit() + self.pressed_keys.clear() self.emit_signal(SignalCode.KEYBOARD_SHORTCUTS_UPDATED) @@ -102,18 +102,18 @@ def get_key_text(self, event): return key_sequence.toString(QtGui.QKeySequence.SequenceFormat.NativeText) def save_shortcuts(self): - session = self.db_handler.get_db_session() + for k, v in enumerate(self.shortcut_keys): # Ensure v.modifiers is a list if not isinstance(v.modifiers, list): v.modifiers = [] - session.query(ShortcutKeys).filter(ShortcutKeys.id == v.id).update({ + self.session.query(ShortcutKeys).filter(ShortcutKeys.id == v.id).update({ "text": v.text, "key": v.key, "modifiers": ",".join(v.modifiers) # Convert list to comma-separated string }) - session.commit() - session.close() + self.session.commit() + def clear_shortcut_setting(self, key=""): for index, v in enumerate(self.shortcut_keys): diff --git a/src/airunner/widgets/llm/bot_preferences.py b/src/airunner/widgets/llm/bot_preferences.py index 523616332..2742ee960 100644 --- a/src/airunner/widgets/llm/bot_preferences.py +++ b/src/airunner/widgets/llm/bot_preferences.py @@ -94,10 +94,10 @@ def create_new_chatbot_clicked(self): self.load_saved_chatbots() def saved_chatbots_changed(self, val): - session = self.db_handler.get_db_session() - chatbot = session.query(Chatbot).filter(Chatbot.name == val).first() + + chatbot = self.session.query(Chatbot).filter(Chatbot.name == val).first() chatbot_id = chatbot.id - session.close() + self.update_llm_generator_settings("current_chatbot", chatbot_id) self.load_form_elements() self.emit_signal(SignalCode.CHATBOT_CHANGED) @@ -176,10 +176,10 @@ def load_documents(self): layout.addWidget(widget) def delete_document(self, target_file:TargetFiles): - session = self.db_handler.get_db_session() - session.delete(target_file) - session.commit() - session.close() + + self.session.delete(target_file) + self.session.commit() + self.load_documents() self.emit_signal(SignalCode.RAG_RELOAD_INDEX_SIGNAL) @@ -190,4 +190,4 @@ def update_chatbot(self, key, val): except TypeError: self.logger.error(f"Attribute {key} does not exist in Chatbot") return - self.db_handler.save_object(chatbot) + self.save_object(chatbot) diff --git a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py index 2503cd5f2..371551d5d 100644 --- a/src/airunner/widgets/llm/chat_prompt_widget.py +++ b/src/airunner/widgets/llm/chat_prompt_widget.py @@ -1,4 +1,4 @@ -from PySide6.QtCore import Slot, QTimer +from PySide6.QtCore import Slot, QTimer, QPropertyAnimation from PySide6.QtWidgets import QSpacerItem, QSizePolicy from PySide6.QtCore import Qt @@ -14,6 +14,8 @@ class ChatPromptWidget(BaseWidget): def __init__(self, *args, **kwargs): super().__init__() + self.scroll_bar = None + self._loading_widget = None self.conversation = None self.is_modal = True self.generating = False @@ -29,11 +31,13 @@ def __init__(self, *args, **kwargs): self.messages_spacer = None self.chat_loaded = False self.conversation_id = None + self._has_loading_widget = False self.ui.action.blockSignals(True) self.ui.action.addItem("Auto") self.ui.action.addItem("Chat") self.ui.action.addItem("Image") + self.ui.action.addItem("RAG") action = LLMActionType[self.action] if action is LLMActionType.APPLICATION_COMMAND: self.ui.action.setCurrentIndex(0) @@ -41,6 +45,8 @@ def __init__(self, *args, **kwargs): self.ui.action.setCurrentIndex(1) elif action is LLMActionType.GENERATE_IMAGE: self.ui.action.setCurrentIndex(2) + elif action is LLMActionType.PERFORM_RAG_SEARCH: + self.ui.action.setCurrentIndex(3) self.ui.action.blockSignals(False) self.originalKeyPressEvent = None self.originalKeyPressEvent = self.ui.prompt.keyPressEvent @@ -51,6 +57,7 @@ def __init__(self, *args, **kwargs): self.register(SignalCode.CONVERSATION_DELETED, self.on_conversation_deleted) self.held_message = None self._disabled = False + self.scroll_animation = None @Slot(str) def handle_token_signal(self, val: str): @@ -95,6 +102,7 @@ def _set_conversation_widgets(self, messages): first_message=True, use_loading_widget=False ) + self.scroll_to_bottom() def on_hear_signal(self, data: dict): transcription = data["transcription"] @@ -122,7 +130,8 @@ def on_add_bot_message_to_conversation(self, data: dict): name=name, message=message, is_bot=True, - first_message=is_first_message + first_message=is_first_message, + action=data["action"] ) if is_end_of_message: @@ -143,9 +152,10 @@ def _clear_conversation(self): self.conversation_history = [] self._clear_conversation_widgets() self._create_conversation() + self.remove_loading_widget() def _create_conversation(self): - conversation_id = self.db_handler.create_conversation() + conversation_id = self.create_conversation() self.emit_signal(SignalCode.LLM_CLEAR_HISTORY_SIGNAL, { "conversation_id": conversation_id }) @@ -163,6 +173,7 @@ def interrupt_button_clicked(self): self.emit_signal(SignalCode.INTERRUPT_PROCESS_SIGNAL) self.stop_progress_bar() self.generating = False + self.remove_loading_widget() self.enable_send_button() @property @@ -203,6 +214,7 @@ def do_generate(self, image_override=None, prompt_override=None, callback=None, } } ) + self.scroll_to_bottom() def on_token_signal(self, val): self.handle_token_signal(val) @@ -243,6 +255,8 @@ def llm_action_changed(self, val: str): llm_action_value = LLMActionType.CHAT elif val == "Image": llm_action_value = LLMActionType.GENERATE_IMAGE + elif val == "RAG": + llm_action_value = LLMActionType.PERFORM_RAG_SEARCH else: llm_action_value = LLMActionType.APPLICATION_COMMAND self.update_llm_generator_settings("action", llm_action_value.name) @@ -302,18 +316,21 @@ def describe_image(self, image, callback): ) def add_loading_widget(self): - self.ui.scrollAreaWidgetContents.layout().addWidget( - LoadingWidget() - ) + if not self._has_loading_widget: + self._has_loading_widget = True + if self._loading_widget is None: + self._loading_widget = LoadingWidget() + self.ui.scrollAreaWidgetContents.layout().addWidget( + self._loading_widget + ) def remove_loading_widget(self): - # remove the last LoadingWidget from scrollAreaWidgetContents.layout() - for i in range(self.ui.scrollAreaWidgetContents.layout().count()): - current_widget = self.ui.scrollAreaWidgetContents.layout().itemAt(i).widget() - if isinstance(current_widget, LoadingWidget): - self.ui.scrollAreaWidgetContents.layout().removeWidget(current_widget) - current_widget.deleteLater() - break + if self._has_loading_widget: + try: + self.ui.scrollAreaWidgetContents.layout().removeWidget(self._loading_widget) + except RuntimeError: + pass + self._has_loading_widget = False def add_message_to_conversation( self, @@ -321,7 +338,8 @@ def add_message_to_conversation( message, is_bot, first_message=True, - use_loading_widget=True + use_loading_widget=True, + action:LLMActionType=LLMActionType.CHAT ): if not first_message: # get the last widget from the scrollAreaWidgetContents.layout() @@ -370,6 +388,17 @@ def action_button_clicked_generate_characters(self): pass def scroll_to_bottom(self): - self.ui.chat_container.verticalScrollBar().setValue( - self.ui.chat_container.verticalScrollBar().maximum() - ) + if self.scroll_bar is None: + self.scroll_bar = self.ui.chat_container.verticalScrollBar() + + if self.scroll_animation is None: + self.scroll_animation = QPropertyAnimation(self.scroll_bar, b"value") + self.scroll_animation.setDuration(500) + + # Stop any ongoing animation + if self.scroll_animation and self.scroll_animation.state() == QPropertyAnimation.State.Running: + self.scroll_animation.stop() + + self.scroll_animation.setStartValue(self.scroll_bar.value()) + self.scroll_animation.setEndValue(self.scroll_bar.maximum()) + self.scroll_animation.start() diff --git a/src/airunner/widgets/llm/llm_history_widget.py b/src/airunner/widgets/llm/llm_history_widget.py index 1d4fece2b..e814b72c1 100644 --- a/src/airunner/widgets/llm/llm_history_widget.py +++ b/src/airunner/widgets/llm/llm_history_widget.py @@ -19,7 +19,7 @@ def showEvent(self, event): self.load_conversations() def load_conversations(self): - conversations = self.db_handler.get_all_conversations() + conversations = self.get_all_conversations() layout = self.ui.gridLayout_2 if layout is None: @@ -41,17 +41,16 @@ def load_conversations(self): layout = QVBoxLayout(self.ui.scrollAreaWidgetContents) self.ui.scrollAreaWidgetContents.setLayout(layout) - session = self.db_handler.get_db_session() for conversation in conversations: h_layout = QHBoxLayout() button = QPushButton(conversation.title) button.clicked.connect(lambda _, c=conversation: self.on_conversation_click(c)) # Extract chatbot_id from the first message of the conversation - first_message = session.query(Message).filter_by(conversation_id=conversation.id).first() + first_message = self.session.query(Message).filter_by(conversation_id=conversation.id).first() chatbot_name = "Unknown" if first_message and first_message.chatbot_id: - chatbot = self.db_handler.get_chatbot_by_id(first_message.chatbot_id) + chatbot = self.get_chatbot_by_id(first_message.chatbot_id) if chatbot: chatbot_name = chatbot.name @@ -65,27 +64,25 @@ def load_conversations(self): container_widget = QWidget() container_widget.setLayout(h_layout) layout.addWidget(container_widget) - session.close() # Add a vertical spacer at the end layout.addItem(self.spacer) - self.ui.conversations_scroll_area.setLayout(layout) + self.ui.scrollAreaWidgetContents.setLayout(layout) def on_conversation_click(self, conversation): - session = self.db_handler.get_db_session() - first_message = session.query(Message).filter_by(conversation_id=conversation.id).first() + first_message = self.session.query(Message).filter_by(conversation_id=conversation.id).first() chatbot_id = first_message.chatbot_id - session.query(LLMGeneratorSettings).update({"current_chatbot": chatbot_id}) - session.commit() - session.close() + self.session.query(LLMGeneratorSettings).update({"current_chatbot": chatbot_id}) + self.session.commit() self.emit_signal(SignalCode.LOAD_CONVERSATION, { - "conversation_id": conversation.id + "conversation_id": conversation.id, + "chatbot_id": chatbot_id }) def on_delete_conversation(self, layout, conversation): conversation_id = conversation.id - self.db_handler.delete_conversation(conversation_id) + self.delete_conversation(conversation_id) for i in reversed(range(layout.count())): widget = layout.itemAt(i).widget() if widget: diff --git a/src/airunner/widgets/llm/llm_settings_widget.py b/src/airunner/widgets/llm/llm_settings_widget.py index f437eeb1b..7949ab7a3 100644 --- a/src/airunner/widgets/llm/llm_settings_widget.py +++ b/src/airunner/widgets/llm/llm_settings_widget.py @@ -184,4 +184,4 @@ def update_chatbot(self, key, val): except TypeError: self.logger.error(f"Attribute {key} does not exist in Chatbot") return - self.db_handler.save_object(chatbot) + self.save_object(chatbot) diff --git a/src/airunner/widgets/llm/templates/message.ui b/src/airunner/widgets/llm/templates/message.ui index e84caf44f..57deb697f 100644 --- a/src/airunner/widgets/llm/templates/message.ui +++ b/src/airunner/widgets/llm/templates/message.ui @@ -6,8 +6,8 @@ 0 0 - 418 - 902 + 464 + 433 @@ -26,15 +26,36 @@ Form + + 0 + + + 0 + + + 0 + + + 0 + + + 0 + + + 10 + + + 10 + TextLabel - Qt::AlignBottom|Qt::AlignLeading|Qt::AlignLeft + Qt::AlignmentFlag::AlignBottom|Qt::AlignmentFlag::AlignLeading|Qt::AlignmentFlag::AlignLeft @@ -56,16 +77,16 @@ border-radius: 5px; border: 5px solid #1f1f1f; background-color: #1f1f1f; color: #ffffff; - QFrame::NoFrame + QFrame::Shape::NoFrame - QFrame::Plain + QFrame::Shadow::Plain - Qt::ScrollBarAlwaysOff + Qt::ScrollBarPolicy::ScrollBarAlwaysOff - Qt::ScrollBarAlwaysOff + Qt::ScrollBarPolicy::ScrollBarAlwaysOff @@ -75,7 +96,7 @@ TextLabel - Qt::AlignBottom|Qt::AlignLeading|Qt::AlignLeft + Qt::AlignmentFlag::AlignBottom|Qt::AlignmentFlag::AlignLeading|Qt::AlignmentFlag::AlignLeft diff --git a/src/airunner/widgets/llm/templates/message_ui.py b/src/airunner/widgets/llm/templates/message_ui.py index 9012162b2..25657dc5b 100644 --- a/src/airunner/widgets/llm/templates/message_ui.py +++ b/src/airunner/widgets/llm/templates/message_ui.py @@ -22,7 +22,7 @@ class Ui_message(object): def setupUi(self, message): if not message.objectName(): message.setObjectName(u"message") - message.resize(418, 902) + message.resize(464, 433) sizePolicy = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.MinimumExpanding) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) @@ -30,12 +30,16 @@ def setupUi(self, message): message.setSizePolicy(sizePolicy) message.setMinimumSize(QSize(0, 40)) self.gridLayout = QGridLayout(message) + self.gridLayout.setSpacing(0) self.gridLayout.setObjectName(u"gridLayout") + self.gridLayout.setContentsMargins(0, 0, 0, 0) self.horizontalLayout = QHBoxLayout() + self.horizontalLayout.setSpacing(10) self.horizontalLayout.setObjectName(u"horizontalLayout") + self.horizontalLayout.setContentsMargins(-1, -1, -1, 10) self.user_name = QLabel(message) self.user_name.setObjectName(u"user_name") - self.user_name.setAlignment(Qt.AlignBottom|Qt.AlignLeading|Qt.AlignLeft) + self.user_name.setAlignment(Qt.AlignmentFlag.AlignBottom|Qt.AlignmentFlag.AlignLeading|Qt.AlignmentFlag.AlignLeft) self.horizontalLayout.addWidget(self.user_name) @@ -48,16 +52,16 @@ def setupUi(self, message): self.content.setSizePolicy(sizePolicy1) self.content.setMinimumSize(QSize(0, 40)) self.content.setStyleSheet(u"border-radius: 5px; border: 5px solid #1f1f1f; background-color: #1f1f1f; color: #ffffff;") - self.content.setFrameShape(QFrame.NoFrame) - self.content.setFrameShadow(QFrame.Plain) - self.content.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - self.content.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.content.setFrameShape(QFrame.Shape.NoFrame) + self.content.setFrameShadow(QFrame.Shadow.Plain) + self.content.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.content.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) self.horizontalLayout.addWidget(self.content) self.bot_name = QLabel(message) self.bot_name.setObjectName(u"bot_name") - self.bot_name.setAlignment(Qt.AlignBottom|Qt.AlignLeading|Qt.AlignLeft) + self.bot_name.setAlignment(Qt.AlignmentFlag.AlignBottom|Qt.AlignmentFlag.AlignLeading|Qt.AlignmentFlag.AlignLeft) self.horizontalLayout.addWidget(self.bot_name) diff --git a/src/airunner/widgets/lora/lora_container_widget.py b/src/airunner/widgets/lora/lora_container_widget.py index 2e6625880..7cc2fc328 100644 --- a/src/airunner/widgets/lora/lora_container_widget.py +++ b/src/airunner/widgets/lora/lora_container_widget.py @@ -25,7 +25,7 @@ def __init__(self, *args, **kwargs): self.register(SignalCode.LORA_UPDATED_SIGNAL, self.on_lora_updated_signal) self.register(SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal) self.register(SignalCode.LORA_STATUS_CHANGED, self.on_lora_modified) - self.register(SignalCode.LORA_DELETE_SIGNAL, self.delete_lora) + self.register(SignalCode.LORA_DELETE_SIGNAL, self._delete_lora) self.ui.loading_icon.hide() self.ui.loading_icon.set_size(spinner_size=QSize(30, 30), label_size=QSize(24, 24)) self._apply_button_enabled = False @@ -44,13 +44,13 @@ def _scan_path_for_lora(self, path) -> bool: @Slot(bool) def on_scan_completed(self, force_reload: bool): - self.load_lora(force_reload=force_reload) + self._load_lora(force_reload=force_reload) @Slot() def scan_for_lora(self): # clear all lora widgets force_reload = scan_path_for_lora(self.path_settings.base_path) - self.load_lora(force_reload=force_reload) + self._load_lora(force_reload=force_reload) @Slot() def apply_lora(self): @@ -94,7 +94,7 @@ def _toggle_lora_widgets(self, enable: bool): lora_widget.disable_lora_widget() def on_application_settings_changed_signal(self): - self.load_lora() + self._load_lora() def on_lora_updated_signal(self): self._enable_form() @@ -114,9 +114,9 @@ def showEvent(self, event): if not self.initialized: self.scan_for_lora() self.initialized = True - self.load_lora() + self._load_lora() - def load_lora(self, force_reload=False): + def _load_lora(self, force_reload=False): version = self.generator_settings.version if self._version is None or self._version != version or force_reload: @@ -130,7 +130,7 @@ def load_lora(self, force_reload=False): if self.search_filter.lower() in lora.name.lower() ] for lora in filtered_loras: - self.add_lora(lora) + self._add_lora(lora) self.add_spacer() def remove_spacer(self): @@ -147,13 +147,13 @@ def add_spacer(self): self.remove_spacer() self.ui.scrollAreaWidgetContents.layout().addWidget(self.spacer) - def add_lora(self, lora): + def _add_lora(self, lora): if lora is None: return lora_widget = LoraWidget(lora=lora) self.ui.scrollAreaWidgetContents.layout().addWidget(lora_widget) - def delete_lora(self, data: dict): + def _delete_lora(self, data: dict): self._deleting = True lora_widget = data["lora_widget"] @@ -174,14 +174,14 @@ def delete_lora(self, data: dict): break # Remove lora from database - session = self.db_handler.get_db_session() - session.delete(lora_widget.current_lora) - session.commit() - session.close() + + self.session.delete(lora_widget.current_lora) + self.session.commit() + self._apply_button_enabled = True self.ui.apply_lora_button.setEnabled(self._apply_button_enabled) - self.load_lora(force_reload=True) + self._load_lora(force_reload=True) self._deleting = False def available_lora(self, action): @@ -235,7 +235,7 @@ def handle_lora_spinbox(self, lora, lora_widget, value, tab_name): def search_text_changed(self, val): self.search_filter = val - self.load_lora(force_reload=True) + self._load_lora(force_reload=True) def clear_lora_widgets(self): if self.spacer: diff --git a/src/airunner/widgets/model_manager/import_widget.py b/src/airunner/widgets/model_manager/import_widget.py index 3fc52dd41..5a1e110c0 100644 --- a/src/airunner/widgets/model_manager/import_widget.py +++ b/src/airunner/widgets/model_manager/import_widget.py @@ -123,7 +123,7 @@ def download_model(self): self.create_lora(new_lora) elif model_type == "TextualInversion": # name = file_path.split("/")[-1].split(".")[0] - # embedding_exists = session.query(Embedding).filter_by( + # embedding_exists = self.session.query(Embedding).filter_by( # name=name, # path=file_path, # ).first() @@ -134,7 +134,7 @@ def download_model(self): # active=True, # tags=trained_words, # ) - # session.add(new_embedding) + # self.session.add(new_embedding) # TODO: handle textual inversion pass elif model_type == "VAE": diff --git a/src/airunner/widgets/slider/slider_widget.py b/src/airunner/widgets/slider/slider_widget.py index c6e207da5..166ab927b 100644 --- a/src/airunner/widgets/slider/slider_widget.py +++ b/src/airunner/widgets/slider/slider_widget.py @@ -139,11 +139,11 @@ def init(self, **kwargs): divide_by = self.property("divide_by") or 1.0 if self.table_id is not None and self.table_name is not None and self.table_column is not None: - session = self.db_handler.get_db_session() + if self.table_name == "lora": - self.table_item = session.query(Lora).filter_by(id=self.table_id).first() + self.table_item = self.session.query(Lora).filter_by(id=self.table_id).first() current_value = getattr(self.table_item, self.table_column) - session.close() + elif current_value is None: if settings_property is not None: current_value = self.get_settings_value(settings_property) @@ -215,11 +215,11 @@ def get_settings_value(self, settings_property): def set_settings_value(self, settings_property: str, val: Any): if self.table_item is not None: - session = self.db_handler.get_db_session() + setattr(self.table_item, self.table_column, val) - session.add(self.table_item) - session.commit() - session.close() + self.session.add(self.table_item) + self.session.commit() + elif settings_property is not None: keys = settings_property.split(".") self.update_settings_by_name(keys[0], keys[1], val) diff --git a/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py b/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py index 300d19f8c..60ea48a50 100644 --- a/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py +++ b/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py @@ -68,15 +68,15 @@ def handle_pipeline_changed(self, val): val = GeneratorSection.TXT2IMG.value elif val == f"{GeneratorSection.INPAINT.value} / {GeneratorSection.OUTPAINT.value}": val = GeneratorSection.INPAINT.value - session = self.db_handler.get_db_session() - generator_settings = session.query(GeneratorSettings).first() + + generator_settings = self.session.query(GeneratorSettings).first() do_reload = False if val == GeneratorSection.TXT2IMG.value: - model = session.query(AIModels).filter( + model = self.session.query(AIModels).filter( AIModels.id == generator_settings.model ).first() if model.pipeline_action == GeneratorSection.INPAINT.value: - model = session.query(AIModels).filter( + model = self.session.query(AIModels).filter( AIModels.version == generator_settings.version, AIModels.pipeline_action == val, AIModels.enabled == True, @@ -88,8 +88,8 @@ def handle_pipeline_changed(self, val): generator_settings.model = None do_reload = True generator_settings.pipeline_action = val - session.commit() - session.close() + self.session.commit() + self.load_versions() self.load_models() if do_reload: @@ -100,9 +100,9 @@ def handle_pipeline_changed(self, val): def handle_version_changed(self, val): self.update_generator_settings("version", val) - session = self.db_handler.get_db_session() - generator_settings = session.query(GeneratorSettings).first() - model = session.query(AIModels).filter( + + generator_settings = self.session.query(GeneratorSettings).first() + model = self.session.query(AIModels).filter( AIModels.version == val, AIModels.pipeline_action == generator_settings.pipeline_action, AIModels.enabled == True, @@ -110,15 +110,15 @@ def handle_version_changed(self, val): ).first() generator_settings.version = val generator_settings.model = model.id - session.commit() - session.close() + self.session.commit() + self.load_models() if self.application_settings.sd_enabled: self.emit_signal(SignalCode.SD_LOAD_SIGNAL, { "do_reload": True }) - def load_pipelines(self): + def _load_pipelines(self): self.ui.pipeline.blockSignals(True) self.ui.pipeline.clear() pipeline_names = [ @@ -148,7 +148,7 @@ def load_versions(self): def on_models_changed_signal(self): try: - self.load_pipelines() + self._load_pipelines() self.load_versions() self.load_models() self.load_schedulers() @@ -162,14 +162,14 @@ def load_models(self): self.ui.model.blockSignals(True) self.clear_models() image_generator = ImageGenerator.STABLEDIFFUSION.value - session = self.db_handler.get_db_session() - generator_settings = session.query(GeneratorSettings).first() + + generator_settings = self.session.query(GeneratorSettings).first() pipeline = generator_settings.pipeline_action version = generator_settings.version pipeline_actions = [GeneratorSection.TXT2IMG.value] if pipeline == GeneratorSection.INPAINT.value: pipeline_actions.append(GeneratorSection.INPAINT.value) - models = session.query(AIModels).filter( + models = self.session.query(AIModels).filter( AIModels.category == image_generator, AIModels.pipeline_action.in_(pipeline_actions), AIModels.version == version, @@ -180,10 +180,10 @@ def load_models(self): if model_id is None and len(models) > 0: current_model = models[0] generator_settings.model = current_model.id - session.commit() + self.session.commit() for model in models: self.ui.model.addItem(model.name, model.id) - session.close() + if model_id: index = self.ui.model.findData(model_id) if index != -1: diff --git a/src/airunner/windows/filter_window.py b/src/airunner/windows/filter_window.py index cca995734..ec0bce6c9 100644 --- a/src/airunner/windows/filter_window.py +++ b/src/airunner/windows/filter_window.py @@ -24,7 +24,6 @@ def __init__(self, image_filter_id): """ super().__init__(exec=False) - self.session = self.db_handler.get_db_session() self.image_filter = self.session.query(ImageFilter).options(joinedload(ImageFilter.image_filter_values)).get( image_filter_id) self.image_filter_model_name = self.image_filter.name diff --git a/src/airunner/windows/installer/installer_window.py b/src/airunner/windows/installer/installer_window.py deleted file mode 100644 index 910f508af..000000000 --- a/src/airunner/windows/installer/installer_window.py +++ /dev/null @@ -1,53 +0,0 @@ -from PySide6.QtWidgets import QWizard -from airunner.mediator_mixin import MediatorMixin -from airunner.windows.installer.completion_page import CompletionPage -from airunner.windows.installer.confirmation_page import ConfirmationPage -from airunner.windows.installer.download_page import DownloadPage -from airunner.windows.main.settings_mixin import SettingsMixin - - -class InstallerWindow( - QWizard, - MediatorMixin, - SettingsMixin -): - def __init__(self, *args): - MediatorMixin.__init__(self) - SettingsMixin.__init__(self) - super(InstallerWindow, self).__init__(*args) - - self.page_ids = {} - - self.do_download_sd_models = True - self.do_download_controlnet_models = True - self.do_download_llm = True - self.do_download_tts_models = True - self.do_download_stt_models = True - - self.pages = { - "confirmation_page": ConfirmationPage(self), - "download_page": DownloadPage(self), - "completion_page": CompletionPage(self), - } - - for page_name, page in self.pages.items(): - self.addPage(page) - - self.setWindowTitle("AI Runner Setup Wizard") - - @property - def download_settings(self): - return { - "compile_with_pyinstaller": False, - "download_ai_runner": True, - "download_sd": self.do_download_sd_models, - "download_controlnet": self.do_download_controlnet_models, - "download_llm": self.do_download_llm, - "download_tts": self.do_download_tts_models, - "download_stt": self.do_download_stt_models, - } - - def addPage(self, page): - page_id = super().addPage(page) - self.page_ids[page] = page_id - return page_id \ No newline at end of file diff --git a/src/airunner/windows/main/ai_model_mixin.py b/src/airunner/windows/main/ai_model_mixin.py index 0e93baaeb..0d4ac1aeb 100644 --- a/src/airunner/windows/main/ai_model_mixin.py +++ b/src/airunner/windows/main/ai_model_mixin.py @@ -4,9 +4,6 @@ class AIModelMixin: - def ai_model_get_by_filter(self, filter_dict): - return [item for item in self.ai_models if all(item.get(k) == v for k, v in filter_dict.items())] - def on_ai_model_delete_signal(self, item: dict): self.ai_models = [existing_item for existing_item in self.ai_models if existing_item.name != item.name] self.update_settings("ai_models", self.ai_models) diff --git a/src/airunner/windows/main/embedding_mixin.py b/src/airunner/windows/main/embedding_mixin.py index 67eec3b11..f10b389eb 100644 --- a/src/airunner/windows/main/embedding_mixin.py +++ b/src/airunner/windows/main/embedding_mixin.py @@ -30,12 +30,12 @@ def delete_missing_embeddings(self): embeddings = self.get_embeddings() for embedding in embeddings: if not os.path.exists(embedding["path"]): - self.delete_embedding(embedding) + self._delete_embedding(embedding) - def delete_embedding(self, embedding): + def _delete_embedding(self, embedding): for index, _embedding in enumerate(self.embeddings): if _embedding.name == embedding.name and _embedding.path == embedding.path: - self.delete_embedding(embedding) + self._delete_embedding(embedding) return def scan_for_embeddings(self): diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py index 33a7f3bbc..200b94d47 100644 --- a/src/airunner/windows/main/main_window.py +++ b/src/airunner/windows/main/main_window.py @@ -4,14 +4,14 @@ import urllib import webbrowser from functools import partial -from pathlib import Path import requests from PIL import Image from PySide6 import QtGui from PySide6.QtCore import ( Slot, - Signal, QProcess, QSettings + Signal, + QProcess ) from PySide6.QtGui import QGuiApplication, QKeySequence from PySide6.QtWidgets import ( @@ -24,7 +24,6 @@ from airunner.handlers.llm.agent.actions.bash_execute import bash_execute from airunner.handlers.llm.agent.actions.show_path import show_path -from airunner.handlers.logger import Logger from airunner.data.models.settings_models import ShortcutKeys, ImageFilter, DrawingPadSettings from airunner.app_installer import AppInstaller from airunner.settings import ( @@ -32,8 +31,6 @@ STATUS_NORMAL_COLOR_LIGHT, STATUS_NORMAL_COLOR_DARK, NSFW_CONTENT_DETECTED_MESSAGE, - ORGANIZATION, - APPLICATION_NAME, DARK_THEME_NAME, LIGHT_THEME_NAME ) from airunner.enums import ( SignalCode, @@ -151,10 +148,9 @@ def __init__( self.listening = False self.initialized = False self._model_status = {model_type: ModelStatus.UNLOADED for model_type in ModelType} - self.logger = Logger(prefix=self.__class__.__name__) - self.logger.debug("Starting AI Runnner") MediatorMixin.__init__(self) SettingsMixin.__init__(self) + self.logger.debug("Starting AI Runnner") super().__init__(*args, **kwargs) self._updating_settings = True PipelineMixin.__init__(self) @@ -951,12 +947,12 @@ def on_history_updated(self, data): self.ui.actionRedo.setEnabled(data["redo"] != 0) def _set_keyboard_shortcuts(self): - session = self.db_handler.get_db_session() - quit_key = session.query(ShortcutKeys).filter_by(display_name="Quit").first() - brush_key = session.query(ShortcutKeys).filter_by(display_name="Brush").first() - eraser_key = session.query(ShortcutKeys).filter_by(display_name="Eraser").first() - move_tool_key = session.query(ShortcutKeys).filter_by(display_name="Move Tool").first() - select_tool_key = session.query(ShortcutKeys).filter_by(display_name="Select Tool").first() + + quit_key = self.session.query(ShortcutKeys).filter_by(display_name="Quit").first() + brush_key = self.session.query(ShortcutKeys).filter_by(display_name="Brush").first() + eraser_key = self.session.query(ShortcutKeys).filter_by(display_name="Eraser").first() + move_tool_key = self.session.query(ShortcutKeys).filter_by(display_name="Move Tool").first() + select_tool_key = self.session.query(ShortcutKeys).filter_by(display_name="Select Tool").first() if quit_key is not None: key_sequence = QKeySequence(quit_key.key | quit_key.modifiers) @@ -983,7 +979,7 @@ def _set_keyboard_shortcuts(self): self.ui.actionToggle_Selection.setShortcut(key_sequence) self.ui.actionToggle_Selection.setToolTip(f"{select_tool_key.display_name} ({select_tool_key.text})") - session.close() + def _initialize_workers(self): self.logger.debug("Initializing worker manager") @@ -997,12 +993,12 @@ def _initialize_workers(self): def _initialize_filter_actions(self): # add more filters: - session = self.db_handler.get_db_session() - image_filters = session.query(ImageFilter).all() + + image_filters = self.session.query(ImageFilter).all() for image_filter in image_filters: action = self.ui.menuFilters.addAction(image_filter.display_name) action.triggered.connect(partial(self.display_filter_window, image_filter)) - session.close() + def display_filter_window(self, image_filter): FilterWindow(image_filter.id) @@ -1115,8 +1111,8 @@ def _generate_drawingpad_mask(self): height = self.active_grid_settings.height img = Image.new("RGB", (width, height), (0, 0, 0)) base64_image = convert_image_to_base64(img) - session = self.db_handler.get_db_session() - drawing_pad_settings = session.query(DrawingPadSettings).first() + + drawing_pad_settings = self.session.query(DrawingPadSettings).first() drawing_pad_settings.mask = base64_image - session.commit() - session.close() + self.session.commit() + diff --git a/src/airunner/windows/main/settings_mixin.py b/src/airunner/windows/main/settings_mixin.py index 3fd715cc6..46b121119 100644 --- a/src/airunner/windows/main/settings_mixin.py +++ b/src/airunner/windows/main/settings_mixin.py @@ -1,39 +1,69 @@ import logging -from typing import List - -from sqlalchemy.orm import joinedload - -from airunner.data.models.settings_db_handler import SettingsDBHandler -from airunner.data.models.settings_models import ApplicationSettings, LLMGeneratorSettings, GeneratorSettings, \ - ControlnetSettings, BrushSettings, DrawingPadSettings, GridSettings, ActiveGridSettings, \ - ImageToImageSettings, OutpaintSettings, PathSettings, MemorySettings, Chatbot, \ - AIModels, Schedulers, Lora, ShortcutKeys, SavedPrompt, SpeechT5Settings, TTSSettings, EspeakSettings, \ - MetadataSettings, Embedding, STTSettings, PromptTemplate, ControlnetModel, FontSetting, PipelineModel, TargetFiles, \ - ImageFilterValue, WhisperSettings +import datetime +import os +from typing import List, Type + +from sqlalchemy import create_engine +from sqlalchemy.orm import joinedload, sessionmaker + +from airunner.data.models.settings_models import Chatbot, AIModels, Schedulers, Lora, PathSettings, SavedPrompt, \ + Embedding, PromptTemplate, ControlnetModel, FontSetting, PipelineModel, ShortcutKeys, \ + GeneratorSettings, WindowSettings, ApplicationSettings, ActiveGridSettings, ControlnetSettings, \ + ImageToImageSettings, OutpaintSettings, DrawingPadSettings, MetadataSettings, \ + LLMGeneratorSettings, TTSSettings, SpeechT5Settings, EspeakSettings, STTSettings, BrushSettings, GridSettings, \ + MemorySettings, Message, Conversation, Summary, ImageFilterValue, TargetFiles, WhisperSettings, Base from airunner.enums import SignalCode +from airunner.handlers.logger import Logger +from airunner.settings import LOG_LEVEL from airunner.utils.convert_base64_to_image import convert_base64_to_image class SettingsMixin: def __init__(self): logging.debug("Initializing SettingsMixin instance") - self.db_handler = SettingsDBHandler() + self.db_path = os.path.expanduser( + os.path.join( + "~", + ".local", + "share", + "airunner", + "data", + "airunner.db" + ) + ) + self.engine = create_engine(f'sqlite:///{self.db_path}') + Base.metadata.create_all(self.engine) + self._session = None + self.Session = sessionmaker(bind=self.engine) + self.conversation_id = None + self.logger = Logger(prefix=self.__class__.__name__, log_level=LOG_LEVEL) + + @property + def session(self): + if self._session is None: + self._session = self.Session() + return self._session + + def close_session(self): + if self._session is not None: + self._session.close() + self._session = None @property def stt_settings(self) -> STTSettings: - return self.db_handler.load_settings_from_db(STTSettings) + return self.load_settings_from_db(STTSettings) @property def application_settings(self) -> ApplicationSettings: - return self.db_handler.load_settings_from_db(ApplicationSettings) + return self.load_settings_from_db(ApplicationSettings) @property def whisper_settings(self) -> WhisperSettings: - return self.db_handler.load_settings_from_db(WhisperSettings) + return self.load_settings_from_db(WhisperSettings) @property def llm_generator_settings(self) -> LLMGeneratorSettings: - settings = self.db_handler.load_settings_from_db(LLMGeneratorSettings) + settings = self.load_settings_from_db(LLMGeneratorSettings) if settings.current_chatbot == 0: chatbots = self.chatbots if len(chatbots) > 0: @@ -43,103 +73,103 @@ def llm_generator_settings(self) -> LLMGeneratorSettings: @property def generator_settings(self) -> GeneratorSettings: - return self.db_handler.load_settings_from_db(GeneratorSettings) + return self.load_settings_from_db(GeneratorSettings) @property def controlnet_settings(self) -> ControlnetSettings: - return self.db_handler.load_settings_from_db(ControlnetSettings) + return self.load_settings_from_db(ControlnetSettings) @property def image_to_image_settings(self) -> ImageToImageSettings: - return self.db_handler.load_settings_from_db(ImageToImageSettings) + return self.load_settings_from_db(ImageToImageSettings) @property def outpaint_settings(self) -> OutpaintSettings: - return self.db_handler.load_settings_from_db(OutpaintSettings) + return self.load_settings_from_db(OutpaintSettings) @property def drawing_pad_settings(self) -> DrawingPadSettings: - return self.db_handler.load_settings_from_db(DrawingPadSettings) + return self.load_settings_from_db(DrawingPadSettings) @property def brush_settings(self) -> BrushSettings: - return self.db_handler.load_settings_from_db(BrushSettings) + return self.load_settings_from_db(BrushSettings) @property def grid_settings(self) -> GridSettings: - return self.db_handler.load_settings_from_db(GridSettings) + return self.load_settings_from_db(GridSettings) @property def active_grid_settings(self) -> ActiveGridSettings: - return self.db_handler.load_settings_from_db(ActiveGridSettings) + return self.load_settings_from_db(ActiveGridSettings) @property def path_settings(self) -> PathSettings: - return self.db_handler.load_settings_from_db(PathSettings) + return self.load_settings_from_db(PathSettings) @property def memory_settings(self) -> MemorySettings: - return self.db_handler.load_settings_from_db(MemorySettings) + return self.load_settings_from_db(MemorySettings) @property - def chatbots(self) -> List[Chatbot]: - return self.db_handler.load_chatbots() + def chatbots(self) -> List[Type[Chatbot]]: + return self.load_chatbots() @property - def ai_models(self) -> List[AIModels]: - return self.db_handler.load_ai_models() + def ai_models(self) -> List[Type[AIModels]]: + return self.load_ai_models() @property - def schedulers(self) -> List[Schedulers]: - return self.db_handler.load_schedulers() + def schedulers(self) -> List[Type[Schedulers]]: + return self.load_schedulers() @property - def lora(self) -> List[Lora]: - return self.db_handler.load_lora() + def lora(self) -> List[Type[Lora]]: + return self.load_lora() @property - def shortcut_keys(self) -> List[ShortcutKeys]: - return self.db_handler.load_shortcut_keys() + def shortcut_keys(self) -> List[Type[ShortcutKeys]]: + return self.load_shortcut_keys() @property def speech_t5_settings(self) -> SpeechT5Settings: - return self.db_handler.load_settings_from_db(SpeechT5Settings) + return self.load_settings_from_db(SpeechT5Settings) @property def tts_settings(self) -> TTSSettings: - return self.db_handler.load_settings_from_db(TTSSettings) + return self.load_settings_from_db(TTSSettings) @property def espeak_settings(self) -> EspeakSettings: - return self.db_handler.load_settings_from_db(EspeakSettings) + return self.load_settings_from_db(EspeakSettings) @property def metadata_settings(self) -> MetadataSettings: - return self.db_handler.load_settings_from_db(MetadataSettings) + return self.load_settings_from_db(MetadataSettings) @property - def embeddings(self) -> List[Embedding]: - return self.db_handler.load_embeddings() + def embeddings(self) -> List[Type[Embedding]]: + return self.session.query(Embedding).all() @property - def prompt_templates(self) -> List[PromptTemplate]: - return self.db_handler.load_prompt_templates() + def prompt_templates(self) -> List[Type[PromptTemplate]]: + return self.load_prompt_templates() @property def controlnet_models(self): - return self.db_handler.load_controlnet_models() + return self.load_controlnet_models() @property - def saved_prompts(self) -> List[SavedPrompt]: - return self.db_handler.load_saved_prompts() + def saved_prompts(self) -> List[Type[SavedPrompt]]: + return self.load_saved_prompts() @property - def font_settings(self) -> List[FontSetting]: - return self.db_handler.load_font_settings() + def font_settings(self) -> List[Type[FontSetting]]: + return self.load_font_settings() @property - def pipelines(self) -> List[PipelineModel]: - return self.db_handler.load_pipelines() + def pipelines(self) -> List[Type[PipelineModel]]: + return self.load_pipelines() @property def drawing_pad_image(self): @@ -191,103 +221,31 @@ def outpaint_mask(self): @property def image_filter_values(self): - session = self.db_handler.get_db_session() - try: - return session.query(ImageFilterValue).all() - finally: - session.close() - - ####################################### - ### LORA ### - ####################################### + return self.session.query(ImageFilterValue).all() def get_lora_by_version(self, version): - session = self.db_handler.get_db_session() - try: - return session.query(Lora).filter_by(version=version).all() - finally: - session.close() + return self.session.query(Lora).filter_by(version=version).all() - def delete_lora_by_name(self, name, version): - self.db_handler.delete_lora_by_name(name, version) - - def create_lora(self, lora: Lora): - self.db_handler.create_lora(lora) - - def update_lora(self, lora: Lora): - self.db_handler.update_lora(lora) - self.__settings_updated() - - def update_loras(self, loras: List[Lora]): - self.db_handler.update_loras(loras) - self.__settings_updated() - - ####################################### - ### EMBEDDINGS ### - ####################################### def get_embeddings_by_version(self, version): return [embedding for embedding in self.embeddings if embedding.version == version] - def delete_embedding(self, embedding: Embedding): - self.db_handler.delete_embedding(embedding) - - def update_embeddings(self, embeddings: List[Embedding]): - self.db_handler.update_embeddings(embeddings) - self.__settings_updated() - - ####################################### - ### CHATBOT ### - ####################################### @property - def chatbot(self) -> Chatbot: + def chatbot(self) -> Type[Chatbot]: return self.get_chatbot_by_id( self.llm_generator_settings.current_chatbot ) - def get_chatbot_by_id(self, chatbot_id) -> Chatbot: - chatbot = None - session = self.db_handler.get_db_session() - try: - chatbot = session.query(Chatbot).filter_by(id=chatbot_id).options(joinedload(Chatbot.target_files)).first() - if chatbot is None: - chatbot = session.query(Chatbot).options(joinedload(Chatbot.target_files)).first() - finally: - session.close() - return chatbot - - def delete_chatbot_by_name(self, chatbot_name): - self.db_handler.delete_chatbot_by_name(chatbot_name) - - def create_chatbot(self, chatbot_name): - self.db_handler.create_chatbot(chatbot_name) + @property + def window_settings(self): + return self.load_window_settings() def add_chatbot_document_to_chatbot(self, chatbot, file_path): - session = self.db_handler.get_db_session() - try: - document = session.query(TargetFiles).filter_by(chatbot_id=chatbot.id, file_path=file_path).first() - if document is None: - document = TargetFiles(file_path=file_path, chatbot_id=chatbot.id) - session.merge(document) # Use merge instead of add - session.commit() - finally: - session.close() - - ####################################### - ### SAVED PROMPTS ### - ####################################### - def create_saved_prompt(self, data: dict): - self.db_handler.create_saved_prompt(data) + document = self.session.query(TargetFiles).filter_by(chatbot_id=chatbot.id, file_path=file_path).first() + if document is None: + document = TargetFiles(file_path=file_path, chatbot_id=chatbot.id) + self.session.merge(document) # Use merge instead of add + self.session.commit() - def update_saved_prompt(self, saved_prompt: SavedPrompt): - self.db_handler.update_saved_prompt(saved_prompt) - self.__settings_updated() - - def get_saved_prompt_by_id(self, prompt_id) -> SavedPrompt: - self.db_handler.get_saved_prompt_by_id(prompt_id) - - ####################################### - ### SETTINGS ### - ####################################### def update_settings_by_name(self, setting_name, column_name, val): if setting_name == "application_settings": self.update_application_settings(column_name, val) @@ -321,158 +279,450 @@ def update_settings_by_name(self, setting_name, column_name, val): logging.error(f"Invalid setting name: {setting_name}") def update_application_settings(self, column_name, val): - self.db_handler.update_setting(ApplicationSettings, column_name, val) + self.update_setting(ApplicationSettings, column_name, val) self.__settings_updated() - ####################################### - ### TTS Settings ### - ####################################### def update_espeak_settings(self, column_name, val): - self.db_handler.update_setting(EspeakSettings, column_name, val) + self.update_setting(EspeakSettings, column_name, val) self.__settings_updated() def update_tts_settings(self, column_name, val): - self.db_handler.update_setting(TTSSettings, column_name, val) + self.update_setting(TTSSettings, column_name, val) self.__settings_updated() def update_speech_t5_settings(self, column_name, val): - self.db_handler.update_setting(SpeechT5Settings, column_name, val) + self.update_setting(SpeechT5Settings, column_name, val) self.__settings_updated() - - ####################################### - ### CONTROLNET ### - ####################################### def update_controlnet_settings(self, column_name, val): - self.db_handler.update_setting(ControlnetSettings, column_name, val) + self.update_setting(ControlnetSettings, column_name, val) self.__settings_updated() def update_brush_settings(self, column_name, val): - self.db_handler.update_setting(BrushSettings, column_name, val) + self.update_setting(BrushSettings, column_name, val) self.__settings_updated() def update_image_to_image_settings(self, column_name, val): - self.db_handler.update_setting(ImageToImageSettings, column_name, val) + self.update_setting(ImageToImageSettings, column_name, val) self.__settings_updated() def update_outpaint_settings(self, column_name, val): - self.db_handler.update_setting(OutpaintSettings, column_name, val) + self.update_setting(OutpaintSettings, column_name, val) self.__settings_updated() def update_drawing_pad_settings(self, column_name, val): - self.db_handler.update_setting(DrawingPadSettings, column_name, val) + self.update_setting(DrawingPadSettings, column_name, val) self.__settings_updated() def update_grid_settings(self, column_name, val): - self.db_handler.update_setting(GridSettings, column_name, val) + self.update_setting(GridSettings, column_name, val) self.__settings_updated() def update_active_grid_settings(self, column_name, val): - self.db_handler.update_setting(ActiveGridSettings, column_name, val) + self.update_setting(ActiveGridSettings, column_name, val) self.__settings_updated() - ####################################### - ### PATH ### - ####################################### def update_path_settings(self, column_name, val): - self.db_handler.update_setting(PathSettings, column_name, val) + self.update_setting(PathSettings, column_name, val) self.__settings_updated() - def reset_path_settings(self): - self.db_handler.reset_path_settings() - def update_memory_settings(self, column_name, val): - self.db_handler.update_setting(MemorySettings, column_name, val) + self.update_setting(MemorySettings, column_name, val) self.__settings_updated() def update_metadata_settings(self, column_name, val): - self.db_handler.update_setting(MetadataSettings, column_name, val) + self.update_setting(MetadataSettings, column_name, val) self.__settings_updated() def update_llm_generator_settings(self, column_name, val): - self.db_handler.update_setting(LLMGeneratorSettings, column_name, val) + self.update_setting(LLMGeneratorSettings, column_name, val) self.__settings_updated() def update_whisper_settings(self, column_name, val): - self.db_handler.update_setting(WhisperSettings, column_name, val) + self.update_setting(WhisperSettings, column_name, val) self.__settings_updated() - def __settings_updated(self): - self.emit_signal(SignalCode.APPLICATION_SETTINGS_CHANGED_SIGNAL) + def update_ai_models(self, models: List[AIModels]): + for model in models: + self.update_ai_model(model) + self.__settings_updated() + + def update_ai_model(self, model: AIModels): + query = self.session.query(AIModels).filter_by( + name=model.name, + path=model.path, + branch=model.branch, + version=model.version, + category=model.category, + pipeline_action=model.pipeline_action, + enabled=model.enabled, + model_type=model.model_type, + is_default=model.is_default + ).first() + if query: + for key in model.__dict__.keys(): + if key != "_sa_instance_state": + setattr(query, key, getattr(model, key)) + else: + self.session.add(model) + self.session.commit() + self.__settings_updated() + + def update_generator_settings(self, column_name, val): + generator_settings = self.generator_settings + setattr(generator_settings, column_name, val) + self.save_generator_settings(generator_settings) + + def update_controlnet_image_settings(self, column_name, val): + controlnet_settings = self.controlnet_settings + setattr(controlnet_settings, column_name, val) + self.update_controlnet_settings(column_name, val) + + def load_schedulers(self) -> list[Type[Schedulers]]: + return self.session.query(Schedulers).all() + + def load_settings_from_db(self, model_class_): + settings = self.session.query(model_class_).first() + if settings is None: + settings = self.create_new_settings(model_class_) + return settings + + def update_setting(self, model_class_, name, value): + setting = self.session.query(model_class_).order_by(model_class_.id.desc()).first() + if setting: + setattr(setting, name, value) + self.session.commit() + + def save_generator_settings(self, generator_settings: GeneratorSettings): + query = self.session.query(GeneratorSettings).filter_by( + id=generator_settings.id + ).first() + if query: + for key in generator_settings.__dict__.keys(): + if key != "_sa_instance_state": + setattr(query, key, getattr(generator_settings, key)) + else: + self.session.add(generator_settings) + self.session.commit() + self.__settings_updated() def reset_settings(self): - self.db_handler.reset_settings() + # Delete all entries from the model class + self.session.query(ApplicationSettings).delete() + self.session.query(ActiveGridSettings).delete() + self.session.query(ControlnetSettings).delete() + self.session.query(ImageToImageSettings).delete() + self.session.query(OutpaintSettings).delete() + self.session.query(DrawingPadSettings).delete() + self.session.query(MetadataSettings).delete() + self.session.query(GeneratorSettings).delete() + self.session.query(LLMGeneratorSettings).delete() + self.session.query(TTSSettings).delete() + self.session.query(SpeechT5Settings).delete() + self.session.query(EspeakSettings).delete() + self.session.query(STTSettings).delete() + self.session.query(BrushSettings).delete() + self.session.query(GridSettings).delete() + self.session.query(PathSettings).delete() + self.session.query(MemorySettings).delete() + # Commit the changes + self.session.commit() + + def create_new_settings(self, model_class_): + new_settings = model_class_() + self.session.add(new_settings) + self.session.commit() + self.session.refresh(new_settings) + return new_settings + + def get_saved_prompt_by_id(self, prompt_id) -> Type[SavedPrompt]: + return self.session.query(SavedPrompt).filter_by(id=prompt_id).first() - ####################################### - ### AI MODELS ### - ####################################### - def update_ai_models(self, models: List[AIModels]): - self.db_handler.update_ai_models(models) + def update_saved_prompt(self, saved_prompt: SavedPrompt): + query = self.session.query(SavedPrompt).filter_by( + id=saved_prompt.id + ).first() + if query: + for key in saved_prompt.__dict__.keys(): + if key != "_sa_instance_state": + setattr(query, key, getattr(saved_prompt, key)) + else: + self.session.add(saved_prompt) + self.session.commit() self.__settings_updated() - def update_ai_model(self, model: AIModels): - self.db_handler.update_ai_model(model) + def create_saved_prompt(self, data: dict): + new_saved_prompt = SavedPrompt(**data) + self.session.add(new_saved_prompt) + self.session.commit() + + def load_saved_prompts(self) -> List[Type[SavedPrompt]]: + return self.session.query(SavedPrompt).all() + + def load_font_settings(self) -> List[Type[FontSetting]]: + return self.session.query(FontSetting).all() + + def get_font_setting_by_name(self, name) -> Type[FontSetting]: + return self.session.query(FontSetting).filter_by(name=name).first() + + def update_font_setting(self, font_setting: Type[FontSetting]): + query = self.session.query(FontSetting).filter_by( + name=font_setting.name + ).first() + if query: + for key in font_setting.__dict__.keys(): + if key != "_sa_instance_state": + setattr(query, key, getattr(font_setting, key)) + else: + self.session.add(font_setting) + self.session.commit() self.__settings_updated() - ####################################### - ### PROMPT TEMPLATES ### - ####################################### - def get_prompt_template_by_name(self, name) -> PromptTemplate: - return self.db_handler.get_prompt_template_by_name(name) + def load_ai_models(self) -> List[Type[AIModels]]: + return self.session.query(AIModels).all() - ####################################### - ### CONTROLNET MODELS ### - ####################################### - def controlnet_model_by_name(self, name) -> ControlnetModel: - return self.db_handler.controlnet_model_by_name(name) + def load_chatbots(self) -> List[Type[Chatbot]]: + settings = self.session.query(Chatbot).all() + return settings - def get_font_setting_by_name(self, name) -> FontSetting: - return self.db_handler.get_font_setting_by_name(name) + def delete_chatbot_by_name(self, chatbot_name): + self.session.query(Chatbot).filter_by(name=chatbot_name).delete() + self.session.commit() - def update_font_setting(self, font_setting: FontSetting): - self.db_handler.update_font_setting(font_setting) + def create_chatbot(self, chatbot_name): + new_chatbot = Chatbot(name=chatbot_name) + self.session.add(new_chatbot) + self.session.commit() + + def reset_path_settings(self): + self.session.query(PathSettings).delete() + self.set_default_values(PathSettings) + self.session.commit() + + def set_default_values(self, model_name_): + default_values = {} + for column in model_name_.__table__.columns: + if column.default is not None: + default_values[column.name] = column.default.arg + self.session.execute( + model_name_.__table__.insert(), + [default_values] + ) + self.session.commit() + + def load_lora(self) -> List[Type[Lora]]: + return self.session.query(Lora).all() + + def get_lora_by_name(self, name): + return self.session.query(Lora).filter_by(name=name).first() + + def add_lora(self, lora: Lora): + self.session.add(lora) + self.session.commit() + + def delete_lora(self, lora: Lora): + self.session.query(Lora).filter_by(name=lora.name).delete() + self.session.commit() + + def update_lora(self, lora: Lora): + query = self.session.query(Lora).filter_by(name=lora.name).first() + if query: + for key in lora.__dict__.keys(): + if key != "_sa_instance_state": + setattr(query, key, getattr(lora, key)) + else: + self.session.add(lora) + self.session.commit() self.__settings_updated() - def ai_model_get_by_filter(self, filter_dict: dict) -> List[AIModels]: - results = [] - models = self.db_handler.load_ai_models() - for item in models: - match = True - for k, v in filter_dict.items(): - if isinstance(item, dict): - if item.get(k, "") != v: - match = False - break - else: - if not hasattr(item, k) or getattr(item, k) != v: - match = False - break - if match: - results.append(item) - return results + def update_loras(self, loras: List[Lora]): + for lora in loras: + query = self.session.query(Lora).filter_by(name=lora.name).first() + if query: + for key in lora.__dict__.keys(): + if key != "_sa_instance_state": + setattr(query, key, getattr(lora, key)) + else: + self.session.add(lora) + self.session.commit() + self.__settings_updated() - ####################################### - ### GENERATOR SETTINGS ### - ####################################### - def update_generator_settings(self, column_name, val): - generator_settings = self.generator_settings - setattr(generator_settings, column_name, val) - self.save_generator_settings(generator_settings) + def create_lora(self, lora: Lora): + self.session.add(lora) + self.session.commit() + + def delete_lora_by_name(self, lora_name, version): + self.session.query(Lora).filter_by(name=lora_name, version=version).delete() + self.session.commit() - def save_generator_settings(self, generator_settings): - self.db_handler.save_generator_settings(generator_settings) + def delete_embedding(self, embedding: Embedding): + self.session.query(Embedding).filter_by( + name=embedding.name, + path=embedding.path, + branch=embedding.branch, + version=embedding.version, + category=embedding.category, + pipeline_action=embedding.pipeline_action, + enabled=embedding.enabled, + model_type=embedding.model_type, + is_default=embedding.is_default + ).delete() + self.session.commit() + + def update_embeddings(self, embeddings: List[Embedding]): + for embedding in embeddings: + query = self.session.query(Embedding).filter_by( + name=embedding.name, + path=embedding.path, + branch=embedding.branch, + version=embedding.version, + category=embedding.category, + pipeline_action=embedding.pipeline_action, + enabled=embedding.enabled, + model_type=embedding.model_type, + is_default=embedding.is_default + ).first() + if query: + for key in embedding.__dict__.keys(): + if key != "_sa_instance_state": + setattr(query, key, getattr(embedding, key)) + else: + self.session.add(embedding) + self.session.commit() self.__settings_updated() + def get_embedding_by_name(self, name): + return self.session.query(Embedding).filter_by(name=name).first() - ####################################### - ### WINDOW SETTINGS ### - ####################################### - @property - def window_settings(self): - return self.db_handler.load_window_settings() + def add_embedding(self, embedding: Embedding): + self.session.add(embedding) + self.session.commit() + + def load_prompt_templates(self) -> List[Type[PromptTemplate]]: + return self.session.query(PromptTemplate).all() + + def get_prompt_template_by_name(self, name) -> Type[PromptTemplate]: + return self.session.query(PromptTemplate).filter_by(template_name=name).first() + + def load_controlnet_models(self) -> List[Type[ControlnetModel]]: + return self.session.query(ControlnetModel).all() + + def controlnet_model_by_name(self, name) -> Type[ControlnetModel]: + return self.session.query(ControlnetModel).filter_by(name=name).first() + def load_pipelines(self) -> List[Type[PipelineModel]]: + return self.session.query(PipelineModel).all() + + def load_shortcut_keys(self) -> List[Type[ShortcutKeys]]: + return self.session.query(ShortcutKeys).all() + + def load_window_settings(self) -> Type[WindowSettings]: + return self.session.query(WindowSettings).first() def save_window_settings(self, column_name, val): window_settings = self.window_settings setattr(window_settings, column_name, val) - self.db_handler.save_window_settings(window_settings) + query = self.session.query(WindowSettings).first() + if query: + for key in window_settings.__dict__.keys(): + if key != "_sa_instance_state": + setattr(query, key, getattr(window_settings, key)) + else: + self.session.add(window_settings) + self.session.commit() + + def save_object(self, database_object): + self.session.add(database_object) + self.session.commit() + + def load_history_from_db(self, conversation_id): + messages = self.session.query(Message).filter_by( + conversation_id=conversation_id + ).order_by(Message.timestamp).all() + results = [ + { + "role": message.role, + "content": message.content, + "name": message.name, + "is_bot": message.is_bot, + "timestamp": message.timestamp, + "conversation_id": message.conversation_id + } for message in messages + ] + return results + + def save_message(self, content, role, name, is_bot, conversation_id) -> Message: + timestamp = datetime.datetime.now() # Ensure timestamp is a datetime object + llm_generator_settings = self.session.query(LLMGeneratorSettings).first() + message = Message( + role=role, + content=content, + name=name, + is_bot=is_bot, + timestamp=timestamp, + conversation_id=conversation_id, + chatbot_id=llm_generator_settings.current_chatbot + ) + self.session.add(message) + self.session.commit() + return message + + def get_chatbot_by_id(self, chatbot_id) -> Type[Chatbot]: + chatbot = self.session.query(Chatbot).filter_by(id=chatbot_id).options(joinedload(Chatbot.target_files)).first() + if chatbot is None: + chatbot = self.session.query(Chatbot).options(joinedload(Chatbot.target_files)).first() + return chatbot + + def create_conversation(self): + conversation = Conversation( + timestamp=datetime.datetime.now(datetime.timezone.utc), + title="" + ) + self.session.add(conversation) + self.session.commit() + return conversation.id + + def update_conversation_title(self, conversation_id, title): + conversation = self.session.query(Conversation).filter_by(id=conversation_id).first() + if conversation: + conversation.title = title + self.session.commit() + + def add_summary(self, content, conversation_id): + timestamp = datetime.datetime.now() # Ensure timestamp is a datetime object + summary = Summary( + content=content, + timestamp=timestamp, + conversation_id=conversation_id + ) + self.session.add(summary) + self.session.commit() + + def create_conversation_with_messages(self, messages): + conversation_id = self.create_conversation() + for message in messages: + self.add_message_to_history( + content=message["content"], + role=message["role"], + name=message["name"], + is_bot=message["is_bot"], + conversation_id=conversation_id + ) + return conversation_id + + def get_all_conversations(self): + conversations = self.session.query(Conversation).all() + return conversations + + def delete_conversation(self, conversation_id): + self.session.query(Message).filter_by(conversation_id=conversation_id).delete() + self.session.query(Summary).filter_by(conversation_id=conversation_id).delete() + self.session.query(Conversation).filter_by(id=conversation_id).delete() + self.session.commit() + + def get_most_recent_conversation_id(self): + conversation = self.session.query(Conversation).order_by(Conversation.timestamp.desc()).first() + return conversation.id if conversation else None + + def __settings_updated(self): + self.emit_signal(SignalCode.APPLICATION_SETTINGS_CHANGED_SIGNAL) diff --git a/src/airunner/windows/prompt_browser/prompt_widget.py b/src/airunner/windows/prompt_browser/prompt_widget.py index 0f4e09005..da0aeb86d 100644 --- a/src/airunner/windows/prompt_browser/prompt_widget.py +++ b/src/airunner/windows/prompt_browser/prompt_widget.py @@ -34,10 +34,10 @@ def action_clicked_button_load(self): }) def action_clicked_button_delete(self): - session = self.db_handler.get_db_session() - session.delete(self.saved_prompt) - session.commit() - session.close() + + self.session.delete(self.saved_prompt) + self.session.commit() + self.deleteLater() def save_prompt(self): diff --git a/src/airunner/windows/setup_wizard/installation_settings/install_page.py b/src/airunner/windows/setup_wizard/installation_settings/install_page.py index 3ce47066f..3097023bd 100644 --- a/src/airunner/windows/setup_wizard/installation_settings/install_page.py +++ b/src/airunner/windows/setup_wizard/installation_settings/install_page.py @@ -78,13 +78,13 @@ def download_stable_diffusion(self): "label": "Downloading Stable Diffusion models..." }) - session = self.db_handler.get_db_session() - models = session.query(AIModels).filter( + + models = self.session.query(AIModels).filter( AIModels.category == "stablediffusion", AIModels.is_default == 1, AIModels.version != "SDXL Turbo" ).all() - session.close() + self.total_models_in_current_step += len(models) for model in models: @@ -198,12 +198,12 @@ def download_controlnet_processors(self): print(f"Error downloading {filename}: {e}") def download_llms(self): - session = self.db_handler.get_db_session() - models = session.query(AIModels).filter( + + models = self.session.query(AIModels).filter( AIModels.category == "llm", AIModels.is_default == 1 ).all() - session.close() + self.total_models_in_current_step += len(models) for model in models: files = LLM_FILE_BOOTSTRAP_DATA[model.path]["files"] @@ -409,12 +409,12 @@ def __init__(self, parent): if self.application_settings.stable_diffusion_agreement_checked: self.total_steps += 1 - session = self.db_handler.get_db_session() + - controlnet_model_count = session.query(func.count(ControlnetModel.id.distinct())).scalar() - controlnet_version_count = session.query(func.count(ControlnetModel.version.distinct())).scalar() + controlnet_model_count = self.session.query(func.count(ControlnetModel.id.distinct())).scalar() + controlnet_version_count = self.session.query(func.count(ControlnetModel.version.distinct())).scalar() - llm_model_count = session.query(func.count(AIModels.id)).filter(AIModels.category == 'llm').scalar() + llm_model_count = self.session.query(func.count(AIModels.id)).filter(AIModels.category == 'llm').scalar() self.total_steps += controlnet_model_count * controlnet_version_count self.total_steps += llm_model_count diff --git a/src/airunner/windows/setup_wizard/model_setup/choose_model_style.py b/src/airunner/windows/setup_wizard/model_setup/choose_model_style.py index 692842fb6..5aa7b501c 100644 --- a/src/airunner/windows/setup_wizard/model_setup/choose_model_style.py +++ b/src/airunner/windows/setup_wizard/model_setup/choose_model_style.py @@ -1,4 +1,4 @@ -from airunner.windows.setup_wizard.setup_wizard_window import BaseWizard +from airunner.windows.setup_wizard.base_wizard import BaseWizard from airunner.windows.setup_wizard.model_setup.stable_diffusion_setup.templates.choose_style_ui import Ui_choose_model_style diff --git a/src/airunner/windows/setup_wizard/model_setup/choose_model_version.py b/src/airunner/windows/setup_wizard/model_setup/choose_model_version.py index eff676599..258dd1e58 100644 --- a/src/airunner/windows/setup_wizard/model_setup/choose_model_version.py +++ b/src/airunner/windows/setup_wizard/model_setup/choose_model_version.py @@ -1,5 +1,5 @@ from airunner.windows.setup_wizard.model_setup.stable_diffusion_setup.templates.choose_version_ui import Ui_choose_model_version -from airunner.windows.setup_wizard.setup_wizard_window import BaseWizard +from airunner.windows.setup_wizard.base_wizard import BaseWizard class ChooseModelVersion(BaseWizard): diff --git a/src/airunner/worker_manager.py b/src/airunner/worker_manager.py deleted file mode 100644 index fe0d7dbea..000000000 --- a/src/airunner/worker_manager.py +++ /dev/null @@ -1,41 +0,0 @@ -from PySide6.QtCore import QObject, Signal - -from airunner.mediator_mixin import MediatorMixin -from airunner.windows.main.settings_mixin import SettingsMixin -from airunner.handlers.logger import Logger -from airunner.utils.create_worker import create_worker -from airunner.workers.audio_capture_worker import AudioCaptureWorker -from airunner.workers.audio_processor_worker import AudioProcessorWorker -from airunner.workers.llm_generate_worker import LLMGenerateWorker -from airunner.workers.mask_generator_worker import MaskGeneratorWorker -from airunner.workers.sd_worker import SDWorker -from airunner.workers.tts_generator_worker import TTSGeneratorWorker -from airunner.workers.tts_vocalizer_worker import TTSVocalizerWorker - - -class WorkerManager(QObject, MediatorMixin, SettingsMixin): - """ - The engine is responsible for processing requests and offloading - them to the appropriate AI model controller. - """ - # Signals - request_signal_status = Signal(str) - image_generated_signal = Signal(dict) - - def __init__(self): - MediatorMixin.__init__(self) - SettingsMixin.__init__(self) - super().__init__() - self.logger = Logger(prefix=self.__class__.__name__) - self._sd_worker = None - self._llm_request_worker = None - self._llm_generate_worker = None - self._tts_generator_worker = None - self._tts_vocalizer_worker = None - self._stt_audio_capture_worker = None - self._stt_audio_processor_worker = None - - self.register_sd_workers() - self.register_llm_workers() - self.register_tts_workers() - self.register_stt_workers() diff --git a/src/airunner/workers/worker.py b/src/airunner/workers/worker.py index a0e31f838..291b152fa 100644 --- a/src/airunner/workers/worker.py +++ b/src/airunner/workers/worker.py @@ -5,7 +5,6 @@ from PySide6.QtCore import Signal, QThread, QObject from airunner.enums import QueueType, SignalCode, WorkerState -from airunner.handlers.logger import Logger from airunner.mediator_mixin import MediatorMixin from airunner.settings import SLEEP_TIME_IN_MS from airunner.windows.main.settings_mixin import SettingsMixin @@ -22,7 +21,6 @@ def __init__(self, signals=None): SettingsMixin.__init__(self) super().__init__() self.state = WorkerState.HALTED - self.logger = Logger(prefix=self.__class__.__name__) self.running = False self.queue = queue.Queue() self.items = {}