diff --git a/setup.py b/setup.py
index 5ecc7d35f..fa055813e 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
setup(
name="airunner",
- version="3.0.21",
+ version="3.1.0",
author="Capsize LLC",
description="A Stable Diffusion GUI",
long_description=open("README.md", "r", encoding="utf-8").read(),
diff --git a/src/airunner/app.py b/src/airunner/app.py
index add0cdd5b..dd1415127 100644
--- a/src/airunner/app.py
+++ b/src/airunner/app.py
@@ -32,7 +32,6 @@
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.data.models.settings_models import ApplicationSettings, AIModels
from airunner.windows.main.main_window import MainWindow
-from airunner.handlers.logger import Logger
class App(
@@ -55,7 +54,6 @@ def __init__(
"""
self.main_window_class_ = main_window_class or MainWindow
self.app = None
- self.logger = Logger(prefix=self.__class__.__name__)
self.defendatron = defendatron
self.splash = None
@@ -95,9 +93,7 @@ def create_paths(self):
"images"
)
))
- session = self.db_handler.get_db_session()
- versions = session.query(distinct(AIModels.version)).filter(AIModels.category == 'stablediffusion').all()
- session.close()
+ versions = self.session.query(distinct(AIModels.version)).filter(AIModels.category == 'stablediffusion').all()
for version in versions:
os.makedirs(
os.path.join(models_path, version[0], "embeddings"),
@@ -110,9 +106,7 @@ def create_paths(self):
os.makedirs(images_path, exist_ok=True)
def run_setup_wizard(self):
- session = self.db_handler.get_db_session()
- application_settings = session.query(ApplicationSettings).first()
- session.close()
+ application_settings = self.session.query(ApplicationSettings).first()
if application_settings.run_setup_wizard:
AppInstaller()
diff --git a/src/airunner/app_installer.py b/src/airunner/app_installer.py
index d03a6a2af..51305ca2b 100644
--- a/src/airunner/app_installer.py
+++ b/src/airunner/app_installer.py
@@ -19,7 +19,6 @@
from airunner.windows.download_wizard.download_wizard_window import DownloadWizardWindow
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.windows.setup_wizard.setup_wizard_window import SetupWizardWindow
-from airunner.handlers.logger import Logger
class AppInstaller(
@@ -43,7 +42,6 @@ def __init__(
self.download_wizard = None
self.app = None
self.close_on_cancel = close_on_cancel
- self.logger = Logger(prefix=self.__class__.__name__)
"""
Mediator and Settings mixins are initialized here, enabling the application
diff --git a/src/airunner/data/models/database_handler.py b/src/airunner/data/models/database_handler.py
deleted file mode 100644
index 519dd6304..000000000
--- a/src/airunner/data/models/database_handler.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import os
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-from airunner.data.models.settings_models import Base
-
-class DatabaseHandler:
- def __init__(self, db_path=os.path.expanduser(
- os.path.join(
- "~",
- ".local",
- "share",
- "airunner",
- "data",
- "airunner.db"
- )
- )):
- self.db_path = db_path
- self.engine = create_engine(f'sqlite:///{self.db_path}')
- Base.metadata.create_all(self.engine)
- self.Session = sessionmaker(bind=self.engine)
- self.conversation_id = None
-
- def get_db_session(self):
- return self.Session()
diff --git a/src/airunner/data/models/settings_db_handler.py b/src/airunner/data/models/settings_db_handler.py
deleted file mode 100644
index 8104d5a6a..000000000
--- a/src/airunner/data/models/settings_db_handler.py
+++ /dev/null
@@ -1,601 +0,0 @@
-import datetime
-from typing import List
-
-from sqlalchemy.orm import joinedload
-
-from airunner.data.models.database_handler import DatabaseHandler
-from airunner.data.models.settings_models import Chatbot, AIModels, Schedulers, Lora, PathSettings, SavedPrompt, \
- Embedding, PromptTemplate, ControlnetModel, FontSetting, PipelineModel, ShortcutKeys, \
- GeneratorSettings, WindowSettings, ApplicationSettings, ActiveGridSettings, ControlnetSettings, \
- ImageToImageSettings, OutpaintSettings, DrawingPadSettings, MetadataSettings, \
- LLMGeneratorSettings, TTSSettings, SpeechT5Settings, EspeakSettings, STTSettings, BrushSettings, GridSettings, \
- MemorySettings, Message, Conversation, Summary
-
-
-class SettingsDBHandler(DatabaseHandler):
- #######################################
- ### SCHEDULERS ###
- #######################################
- def load_schedulers(self) -> List[Schedulers]:
- session = self.get_db_session()
- try:
- return session.query(Schedulers).all()
- finally:
- session.close()
-
- #######################################
- ### SETTINGS ###
- #######################################
- def load_settings_from_db(self, model_class_):
- session = self.get_db_session()
- try:
- settings = session.query(model_class_).first()
- if settings is None:
- settings = self.create_new_settings(model_class_)
- finally:
- session.close()
- return settings
-
- def update_setting(self, model_class_, name, value):
- session = self.get_db_session()
- try:
- setting = session.query(model_class_).order_by(model_class_.id.desc()).first()
- if setting:
- setattr(setting, name, value)
- session.commit()
- finally:
- session.close()
-
- def save_generator_settings(self, generator_settings: GeneratorSettings):
- session = self.get_db_session()
- try:
- query = session.query(GeneratorSettings).filter_by(
- id=generator_settings.id
- ).first()
- if query:
- for key in generator_settings.__dict__.keys():
- if key != "_sa_instance_state":
- setattr(query, key, getattr(generator_settings, key))
- else:
- session.add(generator_settings)
- session.commit()
- finally:
- session.close()
-
- def reset_settings(self):
- session = self.get_db_session()
- try:
- # Delete all entries from the model class
- session.query(ApplicationSettings).delete()
- session.query(ActiveGridSettings).delete()
- session.query(ControlnetSettings).delete()
- session.query(ImageToImageSettings).delete()
- session.query(OutpaintSettings).delete()
- session.query(DrawingPadSettings).delete()
- session.query(MetadataSettings).delete()
- session.query(GeneratorSettings).delete()
- session.query(LLMGeneratorSettings).delete()
- session.query(TTSSettings).delete()
- session.query(SpeechT5Settings).delete()
- session.query(EspeakSettings).delete()
- session.query(STTSettings).delete()
- session.query(BrushSettings).delete()
- session.query(GridSettings).delete()
- session.query(PathSettings).delete()
- session.query(MemorySettings).delete()
- # Commit the changes
- session.commit()
- finally:
- session
-
- def create_new_settings(self, model_class_):
- session = self.get_db_session()
- try:
- new_settings = model_class_()
- session.add(new_settings)
- session.commit()
- session.refresh(new_settings)
- finally:
- session.close()
- return new_settings
-
- #######################################
- ### SAVED PROMPTS ###
- #######################################
- def get_saved_prompt_by_id(self, prompt_id) -> SavedPrompt:
- session = self.get_db_session()
- try:
- return session.query(SavedPrompt).filter_by(id=prompt_id).first()
- finally:
- session.close()
-
- def update_saved_prompt(self, saved_prompt: SavedPrompt):
- session = self.get_db_session()
- try:
- query = session.query(SavedPrompt).filter_by(
- id=saved_prompt.id
- ).first()
- if query:
- for key in saved_prompt.__dict__.keys():
- if key != "_sa_instance_state":
- setattr(query, key, getattr(saved_prompt, key))
- else:
- session.add(saved_prompt)
- session.commit()
- finally:
- session.close()
-
- def create_saved_prompt(self, data: dict):
- session = self.get_db_session()
- try:
- new_saved_prompt = SavedPrompt(**data)
- session.add(new_saved_prompt)
- session.commit()
- finally:
- session.close()
-
- def load_saved_prompts(self) -> List[SavedPrompt]:
- session = self.get_db_session()
- try:
- return session.query(SavedPrompt).all()
- finally:
- session.close()
-
- def load_font_settings(self) -> List[FontSetting]:
- session = self.get_db_session()
- try:
- return session.query(FontSetting).all()
- finally:
- session.close()
-
- def get_font_setting_by_name(self, name) -> FontSetting:
- session = self.get_db_session()
- try:
- return session.query(FontSetting).filter_by(name=name).first()
- finally:
- session.close()
-
- def update_font_setting(self, font_setting: FontSetting):
- session = self.get_db_session()
- try:
- query = session.query(FontSetting).filter_by(
- name=font_setting.name
- ).first()
- if query:
- for key in font_setting.__dict__.keys():
- if key != "_sa_instance_state":
- setattr(query, key, getattr(font_setting, key))
- else:
- session.add(font_setting)
- session.commit()
- finally:
- session.close()
-
- #######################################
- ### AI MODELS ###
- #######################################
- def load_ai_models(self) -> List[AIModels]:
- session = self.get_db_session()
- try:
- return session.query(AIModels).all()
- finally:
- session.close()
-
- def update_ai_models(self, models: List[AIModels]):
- for model in models:
- self.update_ai_model(model)
-
- def update_ai_model(self, model: AIModels):
- session = self.get_db_session()
- try:
- query = session.query(AIModels).filter_by(
- name=model.name,
- path=model.path,
- branch=model.branch,
- version=model.version,
- category=model.category,
- pipeline_action=model.pipeline_action,
- enabled=model.enabled,
- model_type=model.model_type,
- is_default=model.is_default
- ).first()
- if query:
- for key in model.__dict__.keys():
- if key != "_sa_instance_state":
- setattr(query, key, getattr(model, key))
- else:
- session.add(model)
- session.commit()
- finally:
- session.close()
-
- #######################################
- ### CHATBOTS ###
- #######################################
- def load_chatbots(self) -> List[Chatbot]:
- session = self.get_db_session()
- try:
- settings = session.query(Chatbot).all()
- return settings
- finally:
- session.close()
-
- def delete_chatbot_by_name(self, chatbot_name):
- session = self.get_db_session()
- try:
- session.query(Chatbot).filter_by(name=chatbot_name).delete()
- session.commit()
- finally:
- session.close()
-
- def create_chatbot(self, chatbot_name):
- session = self.get_db_session()
- try:
- new_chatbot = Chatbot(name=chatbot_name)
- session.add(new_chatbot)
- session.commit()
- finally:
- session.close()
-
- def reset_path_settings(self):
- session = self.get_db_session()
- try:
- # Delete all entries from PathSettings
- session.query(PathSettings).delete()
-
- # Create a new PathSettings instance with default values
- self.set_default_values(PathSettings)
- # Commit the changes
- session.commit()
- finally:
- session.close()
-
- def set_default_values(self, model_name_):
- session = self.get_db_session()
- try:
- default_values = {}
- for column in model_name_.__table__.columns:
- if column.default is not None:
- default_values[column.name] = column.default.arg
- session.execute(
- model_name_.__table__.insert(),
- [default_values]
- )
- session.commit()
- finally:
- session.close()
-
- #######################################
- ### LORA ###
- #######################################
- def load_lora(self) -> List[Lora]:
- session = self.get_db_session()
- try:
- return session.query(Lora).all()
- finally:
- session.close()
-
- def get_lora_by_name(self, name):
- session = self.get_db_session()
- try:
- return session.query(Lora).filter_by(name=name).first()
- finally:
- session.close()
-
-
- def add_lora(self, lora: Lora):
- session = self.get_db_session()
- try:
- session.add(lora)
- session.commit()
- finally:
- session.close()
-
- def delete_lora(self, lora: Lora):
- session = self.get_db_session()
- try:
- session.query(Lora).filter_by(name=lora.name).delete()
- session.commit()
- finally:
- session.close()
-
- def update_lora(self, lora: Lora):
- session = self.get_db_session()
- try:
- query = session.query(Lora).filter_by(name=lora.name).first()
- if query:
- for key in lora.__dict__.keys():
- if key != "_sa_instance_state":
- setattr(query, key, getattr(lora, key))
- else:
- session.add(lora)
- session.commit()
- finally:
- session.close()
-
- def update_loras(self, loras: List[Lora]):
- session = self.get_db_session()
- try:
- for lora in loras:
- query = session.query(Lora).filter_by(name=lora.name).first()
- if query:
- for key in lora.__dict__.keys():
- if key != "_sa_instance_state":
- setattr(query, key, getattr(lora, key))
- else:
- session.add(lora)
- session.commit()
- finally:
- session.close()
-
- def create_lora(self, lora: Lora):
- session = self.get_db_session()
- try:
- session.add(lora)
- session.commit()
- finally:
- session.close()
-
- def delete_lora_by_name(self, lora_name, version):
- session = self.get_db_session()
- try:
- session.query(Lora).filter_by(name=lora_name, version=version).delete()
- session.commit()
- finally:
- session.close()
-
- #######################################
- ### EMBEDDINGS ###
- #######################################
- def delete_embedding(self, embedding: Embedding):
- session = self.get_db_session()
- try:
- session.query(Embedding).filter_by(
- name=embedding.name,
- path=embedding.path,
- branch=embedding.branch,
- version=embedding.version,
- category=embedding.category,
- pipeline_action=embedding.pipeline_action,
- enabled=embedding.enabled,
- model_type=embedding.model_type,
- is_default=embedding.is_default
- ).delete()
- session.commit()
- finally:
- session.close()
-
- def load_embeddings(self) -> List[Embedding]:
- session = self.get_db_session()
- try:
- return session.query(Embedding).all()
- finally:
- session.close()
-
- def update_embeddings(self, embeddings: List[Embedding]):
- session = self.get_db_session()
- try:
- for embedding in embeddings:
- query = session.query(Embedding).filter_by(
- name=embedding.name,
- path=embedding.path,
- branch=embedding.branch,
- version=embedding.version,
- category=embedding.category,
- pipeline_action=embedding.pipeline_action,
- enabled=embedding.enabled,
- model_type=embedding.model_type,
- is_default=embedding.is_default
- ).first()
- if query:
- for key in embedding.__dict__.keys():
- if key != "_sa_instance_state":
- setattr(query, key, getattr(embedding, key))
- else:
- session.add(embedding)
- session.commit()
- finally:
- session.close()
-
- def get_embedding_by_name(self, name):
- session = self.get_db_session()
- try:
- return session.query(Embedding).filter_by(name=name).first()
- finally:
- session.close()
-
- def add_embedding(self, embedding: Embedding):
- session = self.get_db_session()
- try:
- session.add(embedding)
- session.commit()
- finally:
- session.close()
-
- #######################################
- ### PROMPT TEMPLATES ###
- #######################################
- def load_prompt_templates(self) -> List[PromptTemplate]:
- session = self.get_db_session()
- try:
- return session.query(PromptTemplate).all()
- finally:
- session.close()
-
- def get_prompt_template_by_name(self, name) -> PromptTemplate:
- session = self.get_db_session()
- try:
- return session.query(PromptTemplate).filter_by(template_name=name).first()
- finally:
- session.close()
-
-
- #######################################
- ### CONTROLNET MODELS ###
- #######################################
- def load_controlnet_models(self) -> List[ControlnetModel]:
- session = self.get_db_session()
- try:
- return session.query(ControlnetModel).all()
- finally:
- session.close()
-
- def controlnet_model_by_name(self, name) -> ControlnetModel:
- session = self.get_db_session()
- try:
- return session.query(ControlnetModel).filter_by(name=name).first()
- finally:
- session.close()
-
- def load_pipelines(self) -> List[PipelineModel]:
- session = self.get_db_session()
- try:
- return session.query(PipelineModel).all()
- finally:
- session.close()
-
-
- def load_shortcut_keys(self) -> List[ShortcutKeys]:
- session = self.get_db_session()
- try:
- return session.query(ShortcutKeys).all()
- finally:
- session.close()
-
-
- def load_window_settings(self) -> WindowSettings:
- session = self.get_db_session()
- try:
- return session.query(WindowSettings).first()
- finally:
- session.close()
-
-
- def save_window_settings(self, window_settings: WindowSettings):
- session = self.get_db_session()
- try:
- query = session.query(WindowSettings).first()
- if query:
- for key in window_settings.__dict__.keys():
- if key != "_sa_instance_state":
- setattr(query, key, getattr(window_settings, key))
- else:
- session.add(window_settings)
- session.commit()
- finally:
- session.close()
-
- def save_object(self, database_object):
- session = self.get_db_session()
- session.add(database_object)
- session.commit()
- session.close()
-
-
- def load_history_from_db(self, conversation_id):
- with self.get_db_session() as session:
- messages = session.query(Message).filter_by(
- conversation_id=conversation_id
- ).order_by(Message.timestamp).all()
- results = [
- {
- "role": message.role,
- "content": message.content,
- "name": message.name,
- "is_bot": message.is_bot,
- "timestamp": message.timestamp,
- "conversation_id": message.conversation_id
- } for message in messages
- ]
- return results
-
- def add_message_to_history(self, content, role, name, is_bot, conversation_id):
- timestamp = datetime.datetime.now() # Ensure timestamp is a datetime object
- with self.get_db_session() as session:
- llm_generator_settings = session.query(LLMGeneratorSettings).first()
- message = Message(
- role=role,
- content=content,
- name=name,
- is_bot=is_bot,
- timestamp=timestamp,
- conversation_id=conversation_id,
- chatbot_id=llm_generator_settings.current_chatbot
- )
- session.add(message)
- session.commit()
-
- def get_chatbot_by_id(self, chatbot_id) -> Chatbot:
- session = self.get_db_session()
- try:
- chatbot = session.query(Chatbot).filter_by(id=chatbot_id).options(joinedload(Chatbot.target_files)).first()
- if chatbot is None:
- chatbot = session.query(Chatbot).options(joinedload(Chatbot.target_files)).first()
- finally:
- session.close()
- return chatbot
-
- def create_conversation(self):
- with self.get_db_session() as session:
- conversation = Conversation(
- timestamp=datetime.datetime.now(datetime.timezone.utc),
- title=""
- )
- session.add(conversation)
- session.commit()
- return conversation.id
-
- def update_conversation_title(self, conversation_id, title):
- with self.get_db_session() as session:
- conversation = session.query(Conversation).filter_by(id=conversation_id).first()
- if conversation:
- conversation.title = title
- session.commit()
-
- def add_summary(self, content, conversation_id):
- timestamp = datetime.datetime.now() # Ensure timestamp is a datetime object
- with self.get_db_session() as session:
- summary = Summary(
- content=content,
- timestamp=timestamp,
- conversation_id=conversation_id
- )
- session.add(summary)
- session.commit()
-
- def create_conversation_with_messages(self, messages):
- conversation_id = self.create_conversation()
- for message in messages:
- self.add_message_to_history(
- content=message["content"],
- role=message["role"],
- name=message["name"],
- is_bot=message["is_bot"],
- conversation_id=conversation_id
- )
- return conversation_id
-
- def get_all_conversations(self):
- session = self.Session()
- conversations = session.query(Conversation).all()
- session.close()
- return conversations
-
- def delete_conversation(self, conversation_id):
- session = self.Session()
- try:
- session.query(Message).filter_by(conversation_id=conversation_id).delete()
- session.query(Summary).filter_by(conversation_id=conversation_id).delete()
- session.query(Conversation).filter_by(id=conversation_id).delete()
- session.commit()
- except Exception as e:
- session.rollback()
- print(f"Error deleting conversation: {e}")
- finally:
- session.close()
-
- def get_most_recent_conversation_id(self):
- session = self.Session()
- conversation = session.query(Conversation).order_by(Conversation.timestamp.desc()).first()
- session.close()
- return conversation.id if conversation else None
diff --git a/src/airunner/handlers/base_handler.py b/src/airunner/handlers/base_handler.py
index 0fbe3177f..1e718a3ca 100644
--- a/src/airunner/handlers/base_handler.py
+++ b/src/airunner/handlers/base_handler.py
@@ -2,7 +2,6 @@
from PySide6.QtCore import QObject
from airunner.enums import HandlerType, SignalCode, ModelType, ModelStatus, ModelAction
from airunner.mediator_mixin import MediatorMixin
-from airunner.handlers.logger import Logger
from airunner.utils.get_torch_device import get_torch_device
from airunner.windows.main.settings_mixin import SettingsMixin
@@ -23,7 +22,6 @@ class BaseHandler(
def __init__(self, *args, **kwargs):
self._model_status = {model_type: ModelStatus.UNLOADED for model_type in ModelType}
self.use_gpu = True
- self.logger = Logger(prefix=self.__class__.__name__)
MediatorMixin.__init__(self)
SettingsMixin.__init__(self)
super().__init__(*args, **kwargs)
diff --git a/src/airunner/handlers/llm/agent/base_agent.py b/src/airunner/handlers/llm/agent/base_agent.py
index 8cb664a36..645936a62 100644
--- a/src/airunner/handlers/llm/agent/base_agent.py
+++ b/src/airunner/handlers/llm/agent/base_agent.py
@@ -9,6 +9,7 @@
from PySide6.QtCore import QObject
from llama_index.core import Settings
+from llama_index.core.base.llms.types import ChatMessage
from llama_index.readers.file import EpubReader, PDFReader, MarkdownReader
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
@@ -23,7 +24,6 @@
from airunner.handlers.llm.custom_embedding import CustomEmbedding
from airunner.handlers.llm.agent.html_file_reader import HtmlFileReader
from airunner.handlers.llm.agent.external_condition_stopping_criteria import ExternalConditionStoppingCriteria
-from airunner.handlers.logger import Logger
from airunner.mediator_mixin import MediatorMixin
from airunner.enums import (
SignalCode,
@@ -38,13 +38,19 @@
from airunner.workers.agent_worker import AgentWorker
+class RefreshContextChatEngine(ContextChatEngine):
+ def stream_chat(self, *args, system_prompt:str=None, **kwargs):
+ if system_prompt:
+ self._prefix_messages[0] = ChatMessage(content=system_prompt, role=self._llm.metadata.system_role)
+ return super().stream_chat(*args, **kwargs)
+
+
class BaseAgent(
QObject,
MediatorMixin,
SettingsMixin
):
def __init__(self, *args, **kwargs):
- self.logger = Logger(prefix=self.__class__.__name__)
MediatorMixin.__init__(self)
SettingsMixin.__init__(self)
self.model = kwargs.pop("model", None)
@@ -84,19 +90,33 @@ def __init__(self, *args, **kwargs):
self.action = LLMActionType.CHAT
self.rendered_template = None
self.tokenizer = kwargs.pop("tokenizer", None)
- self.streamer = TextIteratorStreamer(self.tokenizer)
+ self._streamer = None
self.chat_template = kwargs.pop("chat_template", "")
self.is_mistral = kwargs.pop("is_mistral", True)
self.conversation_id = None
self.conversation_title = None
- self.history = self.db_handler.load_history_from_db(self.conversation_id) # Load history by conversation ID
+ self.history = self.load_history_from_db(self.conversation_id) # Load history by conversation ID
super().__init__(*args, **kwargs)
self.prompt = ""
self.thread = None
- self.do_interrupt = False
+ self._do_interrupt = False
self.response_worker = create_worker(AgentWorker)
self.load_rag(model=self.model, tokenizer=self.tokenizer)
+ @property
+ def do_interrupt(self):
+ return self._do_interrupt
+
+ @do_interrupt.setter
+ def do_interrupt(self, value):
+ self._do_interrupt = value
+
+ @property
+ def streamer(self):
+ if self._streamer is None:
+ self._streamer = TextIteratorStreamer(self.tokenizer)
+ return self._streamer
+
@property
def available_actions(self):
return {
@@ -123,13 +143,85 @@ def bot_mood(self) -> str:
def bot_mood(self, value: str):
chatbot = self.chatbot
chatbot.bot_mood = value
- self.db_handler.save_object(chatbot)
+ self.save_object(chatbot)
self.emit_signal(SignalCode.BOT_MOOD_UPDATED)
@property
def bot_personality(self) -> str:
return self.chatbot.bot_personality
+ @property
+ def override_parameters(self):
+ generate_kwargs = prepare_llm_generate_kwargs(self.llm_generator_settings)
+ return generate_kwargs if self.llm_generator_settings.override_parameters else {}
+
+ @property
+ def system_instructions(self):
+ return self.chatbot.system_instructions
+
+ @property
+ def generator_settings(self) -> dict:
+ return prepare_llm_generate_kwargs(self.chatbot)
+
+ @property
+ def device(self):
+ return get_torch_device(self.memory_settings.default_gpu_llm)
+
+ @property
+ def target_files(self):
+ return [
+ target_file.file_path for target_file in self.chatbot.target_files
+ ]
+
+ @property
+ def query_instruction(self):
+ if self.__state == AgentState.SEARCH:
+ return self.__query_instruction
+ elif self.__state == AgentState.CHAT:
+ return "Search through the chat history for anything relevant to the query."
+
+ @property
+ def text_instruction(self):
+ if self.__state == AgentState.SEARCH:
+ return self.__text_instruction
+ elif self.__state == AgentState.CHAT:
+ return "Use the text to respond to the user"
+
+ @property
+ def index(self):
+ if self.__state == AgentState.SEARCH:
+ return self.__index
+ elif self.__state == AgentState.CHAT:
+ return self.__chat_history_index
+
+ @property
+ def llm(self):
+ if self.__llm is None:
+ try:
+ if self.llm_generator_settings.use_api:
+ self.__llm = self.__model
+ else:
+ self.__llm = HuggingFaceLLM(model=self.__model, tokenizer=self.__tokenizer)
+ except Exception as e:
+ self.logger.error(f"Error loading LLM: {str(e)}")
+ return self.__llm
+
+ @property
+ def chat_engine(self):
+ return self.__chat_engine
+
+ @property
+ def is_llama_instruct(self):
+ return True
+
+ @property
+ def use_cuda(self):
+ return torch.cuda.is_available()
+
+ @property
+ def cuda_index(self):
+ return 0
+
def unload(self):
self.unload_rag()
del self.model
@@ -144,39 +236,29 @@ def clear_history(self):
self.conversation_id = None
self.conversation_title = None
- def update_conversation_title(self, title):
+ def _update_conversation_title(self, title):
self.conversation_title = title
- self.db_handler.update_conversation_title(self.conversation_id, title)
+ self.update_conversation_title(self.conversation_id, title)
- def create_conversation(self):
+ def _create_conversation(self):
# Get the most recent conversation ID
- recent_conversation_id = self.db_handler.get_most_recent_conversation_id()
+ recent_conversation_id = self.get_most_recent_conversation_id()
# Check if there are messages for the most recent conversation ID
if recent_conversation_id is not None:
- messages = self.db_handler.load_history_from_db(recent_conversation_id)
+ messages = self.load_history_from_db(recent_conversation_id)
if not messages:
self.conversation_id = recent_conversation_id
return
# If there are messages or no recent conversation ID, create a new conversation
- self.conversation_id = self.db_handler.create_conversation()
+ self.conversation_id = self.create_conversation()
def interrupt_process(self):
self.do_interrupt = True
def do_interrupt_process(self):
- interrupt = self.do_interrupt
- self.streamer = TextIteratorStreamer(self.tokenizer)
- return interrupt
-
- @property
- def use_cuda(self):
- return torch.cuda.is_available()
-
- @property
- def cuda_index(self):
- return 0
+ return self.do_interrupt
def mood(self, botname: str, bot_mood: str, use_mood: bool) -> str:
return (
@@ -261,12 +343,6 @@ def build_system_prompt(
]
elif action is LLMActionType.GENERATE_IMAGE:
- prompt_template = self.get_prompt_template_by_name("image")
- # system_prompt = [
- # prompt_template.guardrails,
- # prompt_template.system,
- # self.history_prompt()
- # ]
system_prompt = [
(
"You are an image generator. "
@@ -285,7 +361,7 @@ def build_system_prompt(
"You will only return JSON strings.\n"
"You will not return any other data types.\n"
"You are an artist, so use your imagination and keep things interesting.\n"
- "You will not respond in a conversational manner or with additonal notes or information.\n"
+ "You will not respond in a conversational manner or with additional notes or information.\n"
f"Only return one JSON block. Do not generate instructions or additional information.\n"
"You must never break the rules.\n"
"Here is a description of the attributes: \n"
@@ -322,7 +398,6 @@ def build_system_prompt(
prompt_template = self.get_prompt_template_by_name("update_mood")
system_instructions = prompt_template.system
system_prompt = [
- guardrails_prompt,
system_instructions,
self.names_prompt(use_names, botname, username),
self.mood(botname, bot_mood, use_mood),
@@ -339,7 +414,7 @@ def build_system_prompt(
self.names_prompt(use_names, botname, username),
self.mood(botname, bot_mood, use_mood),
self.personality_prompt(bot_personality, use_personality),
- self.history_prompt(),
+ # self.history_prompt(),
]
elif action is LLMActionType.QUIT_APPLICATION:
@@ -420,70 +495,6 @@ def get_rendered_template(
rendered_template = rendered_template.replace("{{ " + key + " }}", value)
return rendered_template
- @property
- def override_parameters(self):
- generate_kwargs = prepare_llm_generate_kwargs(self.llm_generator_settings)
- return generate_kwargs if self.llm_generator_settings.override_parameters else {}
-
- @property
- def system_instructions(self):
- return self.chatbot.system_instructions
-
- @property
- def generator_settings(self) -> dict:
- return prepare_llm_generate_kwargs(self.chatbot)
-
- @property
- def device(self):
- return get_torch_device(self.memory_settings.default_gpu_llm)
-
- @property
- def target_files(self):
- return [
- target_file.file_path for target_file in self.chatbot.target_files
- ]
-
- @property
- def query_instruction(self):
- if self.__state == AgentState.SEARCH:
- return self.__query_instruction
- elif self.__state == AgentState.CHAT:
- return "Search through the chat history for anything relevant to the query."
-
- @property
- def text_instruction(self):
- if self.__state == AgentState.SEARCH:
- return self.__text_instruction
- elif self.__state == AgentState.CHAT:
- return "Use the text to respond to the user"
-
- @property
- def index(self):
- if self.__state == AgentState.SEARCH:
- return self.__index
- elif self.__state == AgentState.CHAT:
- return self.__chat_history_index
-
- @property
- def llm(self):
- if self.__llm is None:
- try:
- if self.llm_generator_settings.use_api:
- self.__llm = self.__model
- else:
- self.__llm = HuggingFaceLLM(model=self.__model, tokenizer=self.__tokenizer)
- except Exception as e:
- self.logger.error(f"Error loading LLM: {str(e)}")
- return self.__llm
-
- @property
- def chat_engine(self):
- return self.__chat_engine
-
- @property
- def is_llama_instruct(self):
- return True
-
def run(
self,
prompt: str,
@@ -501,10 +512,9 @@ def run(
self.logger.debug("Running...")
self.prompt = prompt
- streamer = self.streamer
if self.conversation_id is None:
- self.create_conversation()
+ self._create_conversation()
self.set_conversation_title()
# Add the user's message to history
@@ -512,7 +522,10 @@ def run(
LLMActionType.APPLICATION_COMMAND,
LLMActionType.UPDATE_MOOD
):
- self.add_message_to_history(self.prompt, LLMChatRole.HUMAN)
+ self.add_message_to_history(
+ self.prompt,
+ LLMChatRole.HUMAN
+ )
self.rendered_template = self.get_rendered_template(action)
@@ -522,34 +535,14 @@ def run(
).to(self.device)
kwargs.update(
- streamer=streamer,
action=action,
do_emit_response=True
)
- if streamer:
- self.run_with_thread(
- model_inputs,
- **kwargs
- )
- else:
- self.emit_signal(SignalCode.UNBLOCK_TTS_GENERATOR_SIGNAL)
- stopping_criteria = ExternalConditionStoppingCriteria(self.do_interrupt_process)
- data = self.prepare_generate_data(model_inputs, stopping_criteria)
- res = self.model.generate(**data)
- response = self.tokenizer.decode(res[0])
- self.emit_signal(
- SignalCode.LLM_TEXT_STREAMED_SIGNAL,
- dict(
- message=response,
- is_first_message=True,
- is_end_of_message=True,
- name=self.botname,
- action=action
- )
- )
-
- return response
+ self.run_with_thread(
+ model_inputs,
+ **kwargs
+ )
def prepare_generate_data(self, model_inputs, stopping_criteria):
data = dict(
@@ -574,13 +567,22 @@ def run_with_thread(
stopping_criteria = ExternalConditionStoppingCriteria(self.do_interrupt_process)
data = self.prepare_generate_data(model_inputs, stopping_criteria)
- streamer = kwargs.get("streamer", self.streamer)
- data["streamer"] = streamer
- # if "attention_mask" in data:
- # del data["attention_mask"]
+ if self.do_interrupt:
+ self.do_interrupt = False
+ self.emit_signal(
+ SignalCode.LLM_TEXT_STREAMED_SIGNAL,
+ dict(
+ message="",
+ is_first_message=False,
+ is_end_of_message=False,
+ name=self.botname,
+ action=LLMActionType.CHAT
+ )
+ )
+ return
- self.do_interrupt = False
+ data["streamer"] = kwargs.get("streamer", self.streamer)
if action is not LLMActionType.PERFORM_RAG_SEARCH:
try:
@@ -620,14 +622,14 @@ def run_with_thread(
)
is_first_message = False
- if streamer and action in (
+ if action in (
LLMActionType.CHAT,
LLMActionType.GENERATE_IMAGE,
LLMActionType.UPDATE_MOOD,
LLMActionType.SUMMARIZE,
LLMActionType.APPLICATION_COMMAND
):
- for new_text in streamer:
+ for new_text in self.streamer:
# strip all newlines from new_text
streamed_template += new_text
if self.is_mistral:
@@ -693,7 +695,10 @@ def run_with_thread(
)
data.update(self.override_parameters)
self.llm.generate_kwargs = data
- response = self.chat_engine.stream_chat(message=self.prompt)
+ response = self.chat_engine.stream_chat(
+ message=self.prompt,
+ system_prompt=self.rendered_template
+ )
is_first_message = True
is_end_of_message = False
for new_text in response.response_gen:
@@ -710,6 +715,16 @@ def run_with_thread(
)
)
is_first_message = False
+ self.emit_signal(
+ SignalCode.LLM_TEXT_STREAMED_SIGNAL,
+ dict(
+ message="",
+ is_first_message=False,
+ is_end_of_message=True,
+ name=self.botname,
+ action=action
+ )
+ )
if streamed_template is not None:
if action is LLMActionType.CHAT:
@@ -727,7 +742,7 @@ def run_with_thread(
)
elif action is LLMActionType.SUMMARIZE:
- self.update_conversation_title(streamed_template)
+ self._update_conversation_title(streamed_template)
return self.run(
prompt=self.prompt,
action=LLMActionType.CHAT,
@@ -780,22 +795,14 @@ def add_message_to_history(
name = self.username
is_bot = False
+
if role is LLMChatRole.ASSISTANT and content:
content = content.replace(f"{self.botname}:", "")
content = content.replace(f"{self.botname}", "")
is_bot = True
name = self.botname
- self.history.append({
- "role": role.value,
- "content": content,
- "name": name,
- "is_bot": is_bot,
- "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
- "conversation_id": self.conversation_id # Use the stored conversation ID
- })
-
- self.db_handler.add_message_to_history(
+ message = self.save_message(
content,
role.value,
name,
@@ -803,19 +810,28 @@ def add_message_to_history(
self.conversation_id
)
+ self.history.append({
+ "role": message.role,
+ "content": message.content,
+ "name": name,
+ "is_bot": message.is_bot,
+ "timestamp": message.timestamp,
+ "conversation_id": message.conversation_id
+ })
+
def on_load_conversation(self, message):
self.history = []
self.conversation_id = message["conversation_id"]
- self.history = self.db_handler.load_history_from_db(self.conversation_id)
+ self.history = self.load_history_from_db(self.conversation_id)
self.set_conversation_title()
self.emit_signal(SignalCode.SET_CONVERSATION, {
"messages": self.history
})
def set_conversation_title(self):
- session = self.db_handler.get_db_session()
- self.conversation_title = session.query(Conversation).filter_by(id=self.conversation_id).first().title
- session.close()
+
+ self.conversation_title = self.session.query(Conversation).filter_by(id=self.conversation_id).first().title
+
def load_rag(self, model, tokenizer):
self.__model = model
@@ -913,7 +929,7 @@ def __load_prompt_helper(self):
def __load_context_chat_engine(self):
try:
- self.__chat_engine = ContextChatEngine.from_defaults(
+ self.__chat_engine = RefreshContextChatEngine.from_defaults(
retriever=self.__retriever,
chat_history=self.history,
memory=None,
diff --git a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py
index 655c28098..de54cd33a 100644
--- a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py
+++ b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py
@@ -388,7 +388,8 @@ def _do_generate(self, prompt: str, action: LLMActionType):
prompt,
action
)
- self._send_final_message()
+ if action is LLMActionType.CHAT:
+ self._send_final_message()
def _emit_streamed_text_signal(self, **kwargs):
self.logger.debug("Emitting streamed text signal")
diff --git a/src/airunner/handlers/stablediffusion/civit_ai_download_worker.py b/src/airunner/handlers/stablediffusion/civit_ai_download_worker.py
index f4356b6c5..ca439a030 100644
--- a/src/airunner/handlers/stablediffusion/civit_ai_download_worker.py
+++ b/src/airunner/handlers/stablediffusion/civit_ai_download_worker.py
@@ -6,7 +6,6 @@
from airunner.handlers.logger import Logger
from airunner.enums import SignalCode
from airunner.mediator_mixin import MediatorMixin
-from facehuggershield.huggingface.settings import DEFAULT_HF_ENDPOINT
from airunner.windows.main.settings_mixin import SettingsMixin
logger = Logger(prefix="DownloadWorker")
diff --git a/src/airunner/handlers/stablediffusion/download_civitai.py b/src/airunner/handlers/stablediffusion/download_civitai.py
index ad41dd852..56d8bf836 100644
--- a/src/airunner/handlers/stablediffusion/download_civitai.py
+++ b/src/airunner/handlers/stablediffusion/download_civitai.py
@@ -4,7 +4,6 @@
from PySide6.QtCore import QThread
from airunner.handlers.logger import Logger
from airunner.handlers.stablediffusion.civit_ai_download_worker import CivitAIDownloadWorker
-from airunner.handlers.stablediffusion.download_worker import DownloadWorker
from airunner.enums import SignalCode
from airunner.mediator_mixin import MediatorMixin
from airunner.windows.main.settings_mixin import SettingsMixin
diff --git a/src/airunner/handlers/stablediffusion/sd_handler.py b/src/airunner/handlers/stablediffusion/sd_handler.py
index 009dafcb9..95c3c6eb4 100644
--- a/src/airunner/handlers/stablediffusion/sd_handler.py
+++ b/src/airunner/handlers/stablediffusion/sd_handler.py
@@ -50,7 +50,6 @@
class SDHandler(BaseHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._session = self.db_handler.get_db_session()
self._controlnet_model = None
self._controlnet: ControlNetModel = None
self._controlnet_processor: Any = None
@@ -206,7 +205,7 @@ def path_settings_cached(self):
@property
def generator_settings_cached(self):
if self._generator_settings is None:
- self._generator_settings = self._session.query(
+ self._generator_settings = self.session.query(
GeneratorSettings
).first()
return self._generator_settings
@@ -235,12 +234,12 @@ def controlnet_model(self) -> ControlnetModel:
self._controlnet_model.version != self.generator_settings_cached.version or
self._controlnet_model.display_name != self.controlnet_settings_cached.controlnet
):
- session = self.db_handler.get_db_session()
- self._controlnet_model = session.query(ControlnetModel).filter_by(
+
+ self._controlnet_model = self.session.query(ControlnetModel).filter_by(
display_name=self.controlnet_settings_cached.controlnet,
version=self.generator_settings_cached.version
).first()
- session.close()
+
return self._controlnet_model
@property
@@ -866,8 +865,8 @@ def _load_scheduler(self, scheduler=None):
"scheduler_config.json"
)
)
- session = self.db_handler.get_db_session()
- scheduler = session.query(Schedulers).filter_by(display_name=scheduler_name).first()
+
+ scheduler = self.session.query(Schedulers).filter_by(display_name=scheduler_name).first()
if not scheduler:
self.logger.error(f"Failed to find scheduler {scheduler_name}")
return None
@@ -941,8 +940,8 @@ def _load_pipe(self):
self.logger.error(f"Failed to load model to device: {e}")
def _load_lora(self):
- session = self.db_handler.get_db_session()
- enabled_lora = session.query(Lora).filter_by(
+
+ enabled_lora = self.session.query(Lora).filter_by(
version=self.generator_settings_cached.version,
enabled=True
).all()
@@ -979,9 +978,9 @@ def _load_lora_weights(self, lora: Lora):
def _set_lora_adapters(self):
self.logger.debug("Setting LORA adapters")
- session = self.db_handler.get_db_session()
+
loaded_lora_id = [l.id for l in self._loaded_lora.values()]
- enabled_lora = session.query(Lora).filter(Lora.id.in_(loaded_lora_id)).all()
+ enabled_lora = self.session.query(Lora).filter(Lora.id.in_(loaded_lora_id)).all()
adapter_weights = []
adapter_names = []
for lora in enabled_lora:
@@ -1002,11 +1001,11 @@ def _load_embeddings(self):
self._pipe.unload_textual_inversion()
except RuntimeError as e:
self.logger.error(f"Failed to unload embeddings: {e}")
- session = self.db_handler.get_db_session()
- embeddings = session.query(Embedding).filter_by(
+
+ embeddings = self.session.query(Embedding).filter_by(
version=self.generator_settings_cached.version
).all()
- session.close()
+
for embedding in embeddings:
embedding_path = embedding.path
if embedding.active and embedding_path not in self._loaded_embeddings:
diff --git a/src/airunner/handlers/stt/whisper_handler.py b/src/airunner/handlers/stt/whisper_handler.py
index 07ccf58de..26f87f42c 100644
--- a/src/airunner/handlers/stt/whisper_handler.py
+++ b/src/airunner/handlers/stt/whisper_handler.py
@@ -71,13 +71,16 @@ def process_audio(self, audio_data):
if transcription:
self._send_transcription(transcription)
- def load(self):
+ def load(self, retry:bool = False):
if self.stt_is_loading or self.stt_is_loaded:
return
self.logger.debug("Loading Whisper (text-to-speech)")
self.unload()
self.change_model_status(ModelType.STT, ModelStatus.LOADING)
self._load_model()
+ # unsure why this is failing to load occasionally - this is a hack
+ if self._model is None and retry is False:
+ return self.load(retry=True)
self._load_processor()
self._load_feature_extractor()
if (
diff --git a/src/airunner/handlers/tts/speecht5_tts_handler.py b/src/airunner/handlers/tts/speecht5_tts_handler.py
index f0190f5ef..ec19e88cc 100644
--- a/src/airunner/handlers/tts/speecht5_tts_handler.py
+++ b/src/airunner/handlers/tts/speecht5_tts_handler.py
@@ -241,8 +241,6 @@ def _do_generate(self, message):
self.logger.debug("Processing inputs...")
- print(text)
-
inputs = self._processor(
text=text,
return_tensors="pt",
@@ -262,7 +260,7 @@ def _do_generate(self, message):
vocoder=self._vocoder,
max_length=100
)
- except RuntimeError as e:
+ except Exception as e:
self.logger.error("Failed to generate speech")
self.logger.error(e)
self._cancel_generated_speech = False
diff --git a/src/airunner/settings.py b/src/airunner/settings.py
index 364262a63..5b50a3dcb 100644
--- a/src/airunner/settings.py
+++ b/src/airunner/settings.py
@@ -31,7 +31,7 @@
)
ORGANIZATION = "Capsize Games"
APPLICATION_NAME = "AI Runner"
-LOG_LEVEL = logging.DEBUG
+LOG_LEVEL = logging.ERROR
DEFAULT_LLM_HF_PATH = "w4ffl35/Mistral-7B-Instruct-v0.3-4bit"
DEFAULT_STT_HF_PATH = "openai/whisper-tiny"
DEFAULT_IMAGE_SYSTEM_PROMPT = "\n".join([
diff --git a/src/airunner/utils/models/scan_path_for_items.py b/src/airunner/utils/models/scan_path_for_items.py
index 22a82f9fa..5c6913a9d 100644
--- a/src/airunner/utils/models/scan_path_for_items.py
+++ b/src/airunner/utils/models/scan_path_for_items.py
@@ -1,13 +1,14 @@
import os
-from airunner.data.models.settings_db_handler import SettingsDBHandler
from airunner.data.models.settings_models import Lora, Embedding
+from airunner.windows.main.settings_mixin import SettingsMixin
+
def scan_path_for_lora(base_path) -> bool:
lora_added = False
lora_deleted = False
- db_handler = SettingsDBHandler()
+ db_handler = SettingsMixin()
for versionpath, versionnames, versionfiles in os.walk(os.path.expanduser(os.path.join(base_path, "art/models"))):
version = versionpath.split("/")[-1]
lora_path = os.path.expanduser(
@@ -21,11 +22,10 @@ def scan_path_for_lora(base_path) -> bool:
if not os.path.exists(lora_path):
continue
- session = db_handler.get_db_session()
- existing_lora = session.query(Lora).all()
+ existing_lora = db_handler.session.query(Lora).all()
for lora in existing_lora:
if not os.path.exists(lora.path):
- session.delete(lora)
+ db_handler.session.delete(lora)
lora_deleted = True
for dirpath, dirnames, filenames in os.walk(lora_path):
for file in filenames:
@@ -43,17 +43,16 @@ def scan_path_for_lora(base_path) -> bool:
trigger_word="",
version=version
)
- session.add(item)
+ db_handler.session.add(item)
lora_added = True
if lora_deleted or lora_added:
- session.commit()
- session.close()
+ db_handler.session.commit()
return lora_deleted or lora_added
def scan_path_for_embeddings(base_path) -> bool:
embedding_added = False
embedding_deleted = False
- db_handler = SettingsDBHandler()
+ db_handler = SettingsMixin()
items = []
for versionpath, versionnames, versionfiles in os.walk(os.path.expanduser(os.path.join(base_path, "art/models"))):
version = versionpath.split("/")[-1]
@@ -67,11 +66,10 @@ def scan_path_for_embeddings(base_path) -> bool:
)
if not os.path.exists(embedding_path):
continue
- session = db_handler.get_db_session()
- existing_embeddings = session.query(Embedding).all()
+ existing_embeddings = db_handler.session.query(Embedding).all()
for embedding in existing_embeddings:
if not os.path.exists(embedding.path):
- session.delete(embedding)
+ db_handler.session.delete(embedding)
embedding_deleted = True
for dirpath, dirnames, filenames in os.walk(embedding_path):
for file in filenames:
@@ -88,9 +86,8 @@ def scan_path_for_embeddings(base_path) -> bool:
active=False,
trigger_word=""
)
- session.add(item)
+ db_handler.session.add(item)
embedding_added = True
if embedding_deleted or embedding_added:
- session.commit()
- session.close()
+ db_handler.session.commit()
return embedding_deleted or embedding_added
diff --git a/src/airunner/utils/network/huggingface_downloader.py b/src/airunner/utils/network/huggingface_downloader.py
index 3bbfb548b..18a65e7b5 100644
--- a/src/airunner/utils/network/huggingface_downloader.py
+++ b/src/airunner/utils/network/huggingface_downloader.py
@@ -1,6 +1,5 @@
from typing import Callable
from PySide6.QtCore import QObject, QThread, Signal
-from airunner.handlers.logger import Logger
from airunner.handlers.stablediffusion.download_worker import DownloadWorker
from airunner.enums import SignalCode
from airunner.mediator_mixin import MediatorMixin
@@ -20,7 +19,6 @@ def __init__(self, callback=None):
super(HuggingfaceDownloader, self).__init__()
self.thread = None
self.worker = None
- self.logger = Logger(prefix="HuggingfaceDownloader")
self.downloading = False
self.thread = QThread()
diff --git a/src/airunner/widgets/base_widget.py b/src/airunner/widgets/base_widget.py
index 99b83a675..32313b996 100644
--- a/src/airunner/widgets/base_widget.py
+++ b/src/airunner/widgets/base_widget.py
@@ -3,11 +3,9 @@
from PySide6 import QtGui
from PySide6.QtWidgets import QWidget
-from airunner.handlers.logger import Logger
from airunner.enums import CanvasToolName
from airunner.windows.main.settings_mixin import SettingsMixin
from airunner.mediator_mixin import MediatorMixin
-from airunner.settings import DARK_THEME_NAME, LIGHT_THEME_NAME
from airunner.utils.create_worker import create_worker
@@ -31,7 +29,6 @@ def is_dark(self):
return self.application_settings.dark_mode_enabled
def __init__(self, *args, **kwargs):
- self.logger = Logger(prefix=self.__class__.__name__)
MediatorMixin.__init__(self)
SettingsMixin.__init__(self)
super().__init__(*args, **kwargs)
diff --git a/src/airunner/widgets/canvas/brush_scene.py b/src/airunner/widgets/canvas/brush_scene.py
index e97109c5b..82b148a0f 100644
--- a/src/airunner/widgets/canvas/brush_scene.py
+++ b/src/airunner/widgets/canvas/brush_scene.py
@@ -198,8 +198,8 @@ def mousePressEvent(self, event):
return super().mousePressEvent(event)
def _handle_left_mouse_release(self, event) -> bool:
- session = self.db_handler.get_db_session()
- drawing_pad_settings = session.query(DrawingPadSettings).first()
+
+ drawing_pad_settings = self.session.query(DrawingPadSettings).first()
if self.drawing_pad_settings.mask_layer_enabled:
mask_image: Image = ImageQt.fromqimage(self.mask_image)
# Ensure mask is fully opaque
@@ -215,8 +215,8 @@ def _handle_left_mouse_release(self, event) -> bool:
self.current_tool is CanvasToolName.ERASER
)):
self.emit_signal(SignalCode.GENERATE_MASK)
- session.commit()
- session.close()
+ self.session.commit()
+
self.emit_signal(SignalCode.CANVAS_IMAGE_UPDATED_SIGNAL)
if self.drawing_pad_settings.mask_layer_enabled:
self.initialize_image()
diff --git a/src/airunner/widgets/canvas/custom_scene.py b/src/airunner/widgets/canvas/custom_scene.py
index f0cd95474..16b1443ac 100644
--- a/src/airunner/widgets/canvas/custom_scene.py
+++ b/src/airunner/widgets/canvas/custom_scene.py
@@ -10,7 +10,6 @@
from PySide6.QtGui import QPixmap, QPainter
from PySide6.QtWidgets import QGraphicsScene, QGraphicsPixmapItem, QFileDialog, QGraphicsSceneMouseEvent
-from airunner.handlers.logger import Logger
from airunner.enums import SignalCode, CanvasToolName, GeneratorSection, EngineResponseCode
from airunner.mediator_mixin import MediatorMixin
from airunner.settings import VALID_IMAGE_FILES
@@ -30,7 +29,6 @@ class CustomScene(
):
def __init__(self, canvas_type: str):
self.canvas_type = canvas_type
- self.logger = Logger(prefix=self.__class__.__name__)
MediatorMixin.__init__(self)
SettingsMixin.__init__(self)
self.image_backup = None
diff --git a/src/airunner/widgets/canvas/custom_view.py b/src/airunner/widgets/canvas/custom_view.py
index 46345708c..95d9521f9 100644
--- a/src/airunner/widgets/canvas/custom_view.py
+++ b/src/airunner/widgets/canvas/custom_view.py
@@ -9,7 +9,6 @@
from airunner.enums import CanvasToolName, SignalCode, CanvasType
from airunner.mediator_mixin import MediatorMixin
from airunner.utils.convert_image_to_base64 import convert_image_to_base64
-from airunner.utils.create_worker import create_worker
from airunner.utils.snap_to_grid import snap_to_grid
from airunner.widgets.canvas.brush_scene import BrushScene
from airunner.widgets.canvas.custom_scene import CustomScene
diff --git a/src/airunner/widgets/controlnet/controlnet_settings_widget.py b/src/airunner/widgets/controlnet/controlnet_settings_widget.py
index 9f7ac2fe2..027835751 100644
--- a/src/airunner/widgets/controlnet/controlnet_settings_widget.py
+++ b/src/airunner/widgets/controlnet/controlnet_settings_widget.py
@@ -23,11 +23,10 @@ def _load_controlnet_models(self):
if self._version is None or self._version != self.generator_settings.version:
self._version = self.generator_settings.version
current_index = 0
- session = self.db_handler.get_db_session()
- controlnet_models = session.query(ControlnetModel).filter_by(
+
+ controlnet_models = self.session.query(ControlnetModel).filter_by(
version=self.generator_settings.version
).all()
- session.close()
self.ui.controlnet.blockSignals(True)
self.ui.controlnet.clear()
for index, item in enumerate(controlnet_models):
diff --git a/src/airunner/widgets/embeddings/embedding_widget.py b/src/airunner/widgets/embeddings/embedding_widget.py
index 8ba4a4ec9..ce660a707 100644
--- a/src/airunner/widgets/embeddings/embedding_widget.py
+++ b/src/airunner/widgets/embeddings/embedding_widget.py
@@ -40,7 +40,7 @@ def action_clicked_button_deleted(self):
)
def update_embedding(self, embedding: Embedding):
- self.db_handler.save_object(embedding)
+ self.save_object(embedding)
@Slot(bool)
def action_toggled_embedding(self, val, emit_signal=True):
diff --git a/src/airunner/widgets/embeddings/embeddings_container_widget.py b/src/airunner/widgets/embeddings/embeddings_container_widget.py
index eca6789ce..38bce8795 100644
--- a/src/airunner/widgets/embeddings/embeddings_container_widget.py
+++ b/src/airunner/widgets/embeddings/embeddings_container_widget.py
@@ -25,7 +25,7 @@ def __init__(self, *args, **kwargs):
self.register(SignalCode.EMBEDDING_UPDATED_SIGNAL, self.on_embedding_updated_signal)
self.register(SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal)
self.register(SignalCode.EMBEDDING_STATUS_CHANGED, self.on_embedding_modified)
- self.register(SignalCode.EMBEDDING_DELETE_SIGNAL, self.delete_embedding)
+ self.register(SignalCode.EMBEDDING_DELETE_SIGNAL, self._delete_embedding)
self.ui.loading_icon.hide()
self.ui.loading_icon.set_size(spinner_size=QSize(30, 30), label_size=QSize(24, 24))
self._apply_button_enabled = False
@@ -92,7 +92,7 @@ def on_application_settings_changed_signal(self):
def on_embedding_updated_signal(self):
self._enable_form()
- def delete_embedding(self, data):
+ def _delete_embedding(self, data):
self._deleting = True
embedding_widget = data["embedding_widget"]
@@ -113,10 +113,10 @@ def delete_embedding(self, data):
break
# Remove lora from database
- session = self.db_handler.get_db_session()
- session.delete(embedding_widget.embedding)
- session.commit()
- session.close()
+
+ self.session.delete(embedding_widget.embedding)
+ self.session.commit()
+
self._apply_button_enabled = True
self.ui.apply_embeddings_button.setEnabled(self._apply_button_enabled)
@@ -143,7 +143,7 @@ def load_embeddings(self, force_reload:bool=False):
if self.search_filter.lower() in embedding.name.lower()
]
for embedding in filtered_embeddings:
- self.add_embedding(embedding)
+ self._add_embedding(embedding)
self.add_spacer()
def remove_spacer(self):
@@ -160,7 +160,7 @@ def add_spacer(self):
self.remove_spacer()
self.ui.scrollAreaWidgetContents.layout().addWidget(self.spacer)
- def add_embedding(self, embedding):
+ def _add_embedding(self, embedding):
if embedding is None:
return
embedding_widget = EmbeddingWidget(embedding=embedding)
diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py
index 0c78df71f..1357a2d0a 100644
--- a/src/airunner/widgets/generator_form/generator_form_widget.py
+++ b/src/airunner/widgets/generator_form/generator_form_widget.py
@@ -483,13 +483,13 @@ def stop_progress_bar(self, do_clear=False):
progressbar.setFormat("Complete")
def _set_keyboard_shortcuts(self):
- session = self.db_handler.get_db_session()
- generate_image_key = session.query(ShortcutKeys).filter_by(display_name="Generate Image").first()
- interrupt_key = session.query(ShortcutKeys).filter_by(display_name="Interrupt").first()
+
+ generate_image_key = self.session.query(ShortcutKeys).filter_by(display_name="Generate Image").first()
+ interrupt_key = self.session.query(ShortcutKeys).filter_by(display_name="Interrupt").first()
if generate_image_key:
self.ui.generate_button.setShortcut(generate_image_key.key)
self.ui.generate_button.setToolTip(f"{generate_image_key.display_name} ({generate_image_key.text})")
if interrupt_key:
self.ui.interrupt_button.setShortcut(interrupt_key.key)
self.ui.interrupt_button.setToolTip(f"{interrupt_key.display_name} ({interrupt_key.text})")
- session.close()
+
diff --git a/src/airunner/widgets/keyboard_shortcuts/keyboard_shortcuts_widget.py b/src/airunner/widgets/keyboard_shortcuts/keyboard_shortcuts_widget.py
index e3d25ac03..4c65c945c 100644
--- a/src/airunner/widgets/keyboard_shortcuts/keyboard_shortcuts_widget.py
+++ b/src/airunner/widgets/keyboard_shortcuts/keyboard_shortcuts_widget.py
@@ -69,12 +69,12 @@ def get_shortcut(self, shortcut_key: ShortcutKeys, line_edit, event, index):
shortcut_key.key = event.key()
shortcut_key.modifiers = event.modifiers().value
- session = self.db_handler.get_db_session()
- session.add(shortcut_key)
+
+ self.session.add(shortcut_key)
# clear existing key if it exists
- existing_keys = session.query(ShortcutKeys).filter(
+ existing_keys = self.session.query(ShortcutKeys).filter(
ShortcutKeys.text == shortcut_key.text,
ShortcutKeys.id != shortcut_key.id
).all()
@@ -82,7 +82,7 @@ def get_shortcut(self, shortcut_key: ShortcutKeys, line_edit, event, index):
existing_key.text = ""
existing_key.key = 0
existing_key.modifiers = 0
- session.add(existing_key)
+ self.session.add(existing_key)
for i, widget in enumerate(self.shortcut_key_widgets):
if i == index:
@@ -91,8 +91,8 @@ def get_shortcut(self, shortcut_key: ShortcutKeys, line_edit, event, index):
widget.line_edit.setText("")
line_edit.setText(shortcut_key.text)
- session.commit()
- session.close()
+ self.session.commit()
+
self.pressed_keys.clear()
self.emit_signal(SignalCode.KEYBOARD_SHORTCUTS_UPDATED)
@@ -102,18 +102,18 @@ def get_key_text(self, event):
return key_sequence.toString(QtGui.QKeySequence.SequenceFormat.NativeText)
def save_shortcuts(self):
- session = self.db_handler.get_db_session()
+
for k, v in enumerate(self.shortcut_keys):
# Ensure v.modifiers is a list
if not isinstance(v.modifiers, list):
v.modifiers = []
- session.query(ShortcutKeys).filter(ShortcutKeys.id == v.id).update({
+ self.session.query(ShortcutKeys).filter(ShortcutKeys.id == v.id).update({
"text": v.text,
"key": v.key,
"modifiers": ",".join(v.modifiers) # Convert list to comma-separated string
})
- session.commit()
- session.close()
+ self.session.commit()
+
def clear_shortcut_setting(self, key=""):
for index, v in enumerate(self.shortcut_keys):
diff --git a/src/airunner/widgets/llm/bot_preferences.py b/src/airunner/widgets/llm/bot_preferences.py
index 523616332..2742ee960 100644
--- a/src/airunner/widgets/llm/bot_preferences.py
+++ b/src/airunner/widgets/llm/bot_preferences.py
@@ -94,10 +94,10 @@ def create_new_chatbot_clicked(self):
self.load_saved_chatbots()
def saved_chatbots_changed(self, val):
- session = self.db_handler.get_db_session()
- chatbot = session.query(Chatbot).filter(Chatbot.name == val).first()
+
+ chatbot = self.session.query(Chatbot).filter(Chatbot.name == val).first()
chatbot_id = chatbot.id
- session.close()
+
self.update_llm_generator_settings("current_chatbot", chatbot_id)
self.load_form_elements()
self.emit_signal(SignalCode.CHATBOT_CHANGED)
@@ -176,10 +176,10 @@ def load_documents(self):
layout.addWidget(widget)
def delete_document(self, target_file:TargetFiles):
- session = self.db_handler.get_db_session()
- session.delete(target_file)
- session.commit()
- session.close()
+
+ self.session.delete(target_file)
+ self.session.commit()
+
self.load_documents()
self.emit_signal(SignalCode.RAG_RELOAD_INDEX_SIGNAL)
@@ -190,4 +190,4 @@ def update_chatbot(self, key, val):
except TypeError:
self.logger.error(f"Attribute {key} does not exist in Chatbot")
return
- self.db_handler.save_object(chatbot)
+ self.save_object(chatbot)
diff --git a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py
index 2503cd5f2..371551d5d 100644
--- a/src/airunner/widgets/llm/chat_prompt_widget.py
+++ b/src/airunner/widgets/llm/chat_prompt_widget.py
@@ -1,4 +1,4 @@
-from PySide6.QtCore import Slot, QTimer
+from PySide6.QtCore import Slot, QTimer, QPropertyAnimation
from PySide6.QtWidgets import QSpacerItem, QSizePolicy
from PySide6.QtCore import Qt
@@ -14,6 +14,8 @@ class ChatPromptWidget(BaseWidget):
def __init__(self, *args, **kwargs):
super().__init__()
+ self.scroll_bar = None
+ self._loading_widget = None
self.conversation = None
self.is_modal = True
self.generating = False
@@ -29,11 +31,13 @@ def __init__(self, *args, **kwargs):
self.messages_spacer = None
self.chat_loaded = False
self.conversation_id = None
+ self._has_loading_widget = False
self.ui.action.blockSignals(True)
self.ui.action.addItem("Auto")
self.ui.action.addItem("Chat")
self.ui.action.addItem("Image")
+ self.ui.action.addItem("RAG")
action = LLMActionType[self.action]
if action is LLMActionType.APPLICATION_COMMAND:
self.ui.action.setCurrentIndex(0)
@@ -41,6 +45,8 @@ def __init__(self, *args, **kwargs):
self.ui.action.setCurrentIndex(1)
elif action is LLMActionType.GENERATE_IMAGE:
self.ui.action.setCurrentIndex(2)
+ elif action is LLMActionType.PERFORM_RAG_SEARCH:
+ self.ui.action.setCurrentIndex(3)
self.ui.action.blockSignals(False)
self.originalKeyPressEvent = None
self.originalKeyPressEvent = self.ui.prompt.keyPressEvent
@@ -51,6 +57,7 @@ def __init__(self, *args, **kwargs):
self.register(SignalCode.CONVERSATION_DELETED, self.on_conversation_deleted)
self.held_message = None
self._disabled = False
+ self.scroll_animation = None
@Slot(str)
def handle_token_signal(self, val: str):
@@ -95,6 +102,7 @@ def _set_conversation_widgets(self, messages):
first_message=True,
use_loading_widget=False
)
+ self.scroll_to_bottom()
def on_hear_signal(self, data: dict):
transcription = data["transcription"]
@@ -122,7 +130,8 @@ def on_add_bot_message_to_conversation(self, data: dict):
name=name,
message=message,
is_bot=True,
- first_message=is_first_message
+ first_message=is_first_message,
+ action=data["action"]
)
if is_end_of_message:
@@ -143,9 +152,10 @@ def _clear_conversation(self):
self.conversation_history = []
self._clear_conversation_widgets()
self._create_conversation()
+ self.remove_loading_widget()
def _create_conversation(self):
- conversation_id = self.db_handler.create_conversation()
+ conversation_id = self.create_conversation()
self.emit_signal(SignalCode.LLM_CLEAR_HISTORY_SIGNAL, {
"conversation_id": conversation_id
})
@@ -163,6 +173,7 @@ def interrupt_button_clicked(self):
self.emit_signal(SignalCode.INTERRUPT_PROCESS_SIGNAL)
self.stop_progress_bar()
self.generating = False
+ self.remove_loading_widget()
self.enable_send_button()
@property
@@ -203,6 +214,7 @@ def do_generate(self, image_override=None, prompt_override=None, callback=None,
}
}
)
+ self.scroll_to_bottom()
def on_token_signal(self, val):
self.handle_token_signal(val)
@@ -243,6 +255,8 @@ def llm_action_changed(self, val: str):
llm_action_value = LLMActionType.CHAT
elif val == "Image":
llm_action_value = LLMActionType.GENERATE_IMAGE
+ elif val == "RAG":
+ llm_action_value = LLMActionType.PERFORM_RAG_SEARCH
else:
llm_action_value = LLMActionType.APPLICATION_COMMAND
self.update_llm_generator_settings("action", llm_action_value.name)
@@ -302,18 +316,21 @@ def describe_image(self, image, callback):
)
def add_loading_widget(self):
- self.ui.scrollAreaWidgetContents.layout().addWidget(
- LoadingWidget()
- )
+ if not self._has_loading_widget:
+ self._has_loading_widget = True
+ if self._loading_widget is None:
+ self._loading_widget = LoadingWidget()
+ self.ui.scrollAreaWidgetContents.layout().addWidget(
+ self._loading_widget
+ )
def remove_loading_widget(self):
- # remove the last LoadingWidget from scrollAreaWidgetContents.layout()
- for i in range(self.ui.scrollAreaWidgetContents.layout().count()):
- current_widget = self.ui.scrollAreaWidgetContents.layout().itemAt(i).widget()
- if isinstance(current_widget, LoadingWidget):
- self.ui.scrollAreaWidgetContents.layout().removeWidget(current_widget)
- current_widget.deleteLater()
- break
+ if self._has_loading_widget:
+ try:
+ self.ui.scrollAreaWidgetContents.layout().removeWidget(self._loading_widget)
+ except RuntimeError:
+ pass
+ self._has_loading_widget = False
def add_message_to_conversation(
self,
@@ -321,7 +338,8 @@ def add_message_to_conversation(
message,
is_bot,
first_message=True,
- use_loading_widget=True
+ use_loading_widget=True,
+ action:LLMActionType=LLMActionType.CHAT
):
if not first_message:
# get the last widget from the scrollAreaWidgetContents.layout()
@@ -370,6 +388,17 @@ def action_button_clicked_generate_characters(self):
pass
def scroll_to_bottom(self):
- self.ui.chat_container.verticalScrollBar().setValue(
- self.ui.chat_container.verticalScrollBar().maximum()
- )
+ if self.scroll_bar is None:
+ self.scroll_bar = self.ui.chat_container.verticalScrollBar()
+
+ if self.scroll_animation is None:
+ self.scroll_animation = QPropertyAnimation(self.scroll_bar, b"value")
+ self.scroll_animation.setDuration(500)
+
+ # Stop any ongoing animation
+ if self.scroll_animation and self.scroll_animation.state() == QPropertyAnimation.State.Running:
+ self.scroll_animation.stop()
+
+ self.scroll_animation.setStartValue(self.scroll_bar.value())
+ self.scroll_animation.setEndValue(self.scroll_bar.maximum())
+ self.scroll_animation.start()
diff --git a/src/airunner/widgets/llm/llm_history_widget.py b/src/airunner/widgets/llm/llm_history_widget.py
index 1d4fece2b..e814b72c1 100644
--- a/src/airunner/widgets/llm/llm_history_widget.py
+++ b/src/airunner/widgets/llm/llm_history_widget.py
@@ -19,7 +19,7 @@ def showEvent(self, event):
self.load_conversations()
def load_conversations(self):
- conversations = self.db_handler.get_all_conversations()
+ conversations = self.get_all_conversations()
layout = self.ui.gridLayout_2
if layout is None:
@@ -41,17 +41,16 @@ def load_conversations(self):
layout = QVBoxLayout(self.ui.scrollAreaWidgetContents)
self.ui.scrollAreaWidgetContents.setLayout(layout)
- session = self.db_handler.get_db_session()
for conversation in conversations:
h_layout = QHBoxLayout()
button = QPushButton(conversation.title)
button.clicked.connect(lambda _, c=conversation: self.on_conversation_click(c))
# Extract chatbot_id from the first message of the conversation
- first_message = session.query(Message).filter_by(conversation_id=conversation.id).first()
+ first_message = self.session.query(Message).filter_by(conversation_id=conversation.id).first()
chatbot_name = "Unknown"
if first_message and first_message.chatbot_id:
- chatbot = self.db_handler.get_chatbot_by_id(first_message.chatbot_id)
+ chatbot = self.get_chatbot_by_id(first_message.chatbot_id)
if chatbot:
chatbot_name = chatbot.name
@@ -65,27 +64,25 @@ def load_conversations(self):
container_widget = QWidget()
container_widget.setLayout(h_layout)
layout.addWidget(container_widget)
- session.close()
# Add a vertical spacer at the end
layout.addItem(self.spacer)
- self.ui.conversations_scroll_area.setLayout(layout)
+ self.ui.scrollAreaWidgetContents.setLayout(layout)
def on_conversation_click(self, conversation):
- session = self.db_handler.get_db_session()
- first_message = session.query(Message).filter_by(conversation_id=conversation.id).first()
+ first_message = self.session.query(Message).filter_by(conversation_id=conversation.id).first()
chatbot_id = first_message.chatbot_id
- session.query(LLMGeneratorSettings).update({"current_chatbot": chatbot_id})
- session.commit()
- session.close()
+ self.session.query(LLMGeneratorSettings).update({"current_chatbot": chatbot_id})
+ self.session.commit()
self.emit_signal(SignalCode.LOAD_CONVERSATION, {
- "conversation_id": conversation.id
+ "conversation_id": conversation.id,
+ "chatbot_id": chatbot_id
})
def on_delete_conversation(self, layout, conversation):
conversation_id = conversation.id
- self.db_handler.delete_conversation(conversation_id)
+ self.delete_conversation(conversation_id)
for i in reversed(range(layout.count())):
widget = layout.itemAt(i).widget()
if widget:
diff --git a/src/airunner/widgets/llm/llm_settings_widget.py b/src/airunner/widgets/llm/llm_settings_widget.py
index f437eeb1b..7949ab7a3 100644
--- a/src/airunner/widgets/llm/llm_settings_widget.py
+++ b/src/airunner/widgets/llm/llm_settings_widget.py
@@ -184,4 +184,4 @@ def update_chatbot(self, key, val):
except TypeError:
self.logger.error(f"Attribute {key} does not exist in Chatbot")
return
- self.db_handler.save_object(chatbot)
+ self.save_object(chatbot)
diff --git a/src/airunner/widgets/llm/templates/message.ui b/src/airunner/widgets/llm/templates/message.ui
index e84caf44f..57deb697f 100644
--- a/src/airunner/widgets/llm/templates/message.ui
+++ b/src/airunner/widgets/llm/templates/message.ui
@@ -6,8 +6,8 @@
0
0
- 418
- 902
+ 464
+ 433
@@ -26,15 +26,36 @@
Form
+
+ 0
+
+
+ 0
+
+
+ 0
+
+
+ 0
+
+
+ 0
+
-
+
+ 10
+
+
+ 10
+
-
TextLabel
- Qt::AlignBottom|Qt::AlignLeading|Qt::AlignLeft
+ Qt::AlignmentFlag::AlignBottom|Qt::AlignmentFlag::AlignLeading|Qt::AlignmentFlag::AlignLeft
@@ -56,16 +77,16 @@
border-radius: 5px; border: 5px solid #1f1f1f; background-color: #1f1f1f; color: #ffffff;
- QFrame::NoFrame
+ QFrame::Shape::NoFrame
- QFrame::Plain
+ QFrame::Shadow::Plain
- Qt::ScrollBarAlwaysOff
+ Qt::ScrollBarPolicy::ScrollBarAlwaysOff
- Qt::ScrollBarAlwaysOff
+ Qt::ScrollBarPolicy::ScrollBarAlwaysOff
@@ -75,7 +96,7 @@
TextLabel
- Qt::AlignBottom|Qt::AlignLeading|Qt::AlignLeft
+ Qt::AlignmentFlag::AlignBottom|Qt::AlignmentFlag::AlignLeading|Qt::AlignmentFlag::AlignLeft
diff --git a/src/airunner/widgets/llm/templates/message_ui.py b/src/airunner/widgets/llm/templates/message_ui.py
index 9012162b2..25657dc5b 100644
--- a/src/airunner/widgets/llm/templates/message_ui.py
+++ b/src/airunner/widgets/llm/templates/message_ui.py
@@ -22,7 +22,7 @@ class Ui_message(object):
def setupUi(self, message):
if not message.objectName():
message.setObjectName(u"message")
- message.resize(418, 902)
+ message.resize(464, 433)
sizePolicy = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.MinimumExpanding)
sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0)
@@ -30,12 +30,16 @@ def setupUi(self, message):
message.setSizePolicy(sizePolicy)
message.setMinimumSize(QSize(0, 40))
self.gridLayout = QGridLayout(message)
+ self.gridLayout.setSpacing(0)
self.gridLayout.setObjectName(u"gridLayout")
+ self.gridLayout.setContentsMargins(0, 0, 0, 0)
self.horizontalLayout = QHBoxLayout()
+ self.horizontalLayout.setSpacing(10)
self.horizontalLayout.setObjectName(u"horizontalLayout")
+ self.horizontalLayout.setContentsMargins(-1, -1, -1, 10)
self.user_name = QLabel(message)
self.user_name.setObjectName(u"user_name")
- self.user_name.setAlignment(Qt.AlignBottom|Qt.AlignLeading|Qt.AlignLeft)
+ self.user_name.setAlignment(Qt.AlignmentFlag.AlignBottom|Qt.AlignmentFlag.AlignLeading|Qt.AlignmentFlag.AlignLeft)
self.horizontalLayout.addWidget(self.user_name)
@@ -48,16 +52,16 @@ def setupUi(self, message):
self.content.setSizePolicy(sizePolicy1)
self.content.setMinimumSize(QSize(0, 40))
self.content.setStyleSheet(u"border-radius: 5px; border: 5px solid #1f1f1f; background-color: #1f1f1f; color: #ffffff;")
- self.content.setFrameShape(QFrame.NoFrame)
- self.content.setFrameShadow(QFrame.Plain)
- self.content.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
- self.content.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+ self.content.setFrameShape(QFrame.Shape.NoFrame)
+ self.content.setFrameShadow(QFrame.Shadow.Plain)
+ self.content.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
+ self.content.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
self.horizontalLayout.addWidget(self.content)
self.bot_name = QLabel(message)
self.bot_name.setObjectName(u"bot_name")
- self.bot_name.setAlignment(Qt.AlignBottom|Qt.AlignLeading|Qt.AlignLeft)
+ self.bot_name.setAlignment(Qt.AlignmentFlag.AlignBottom|Qt.AlignmentFlag.AlignLeading|Qt.AlignmentFlag.AlignLeft)
self.horizontalLayout.addWidget(self.bot_name)
diff --git a/src/airunner/widgets/lora/lora_container_widget.py b/src/airunner/widgets/lora/lora_container_widget.py
index 2e6625880..7cc2fc328 100644
--- a/src/airunner/widgets/lora/lora_container_widget.py
+++ b/src/airunner/widgets/lora/lora_container_widget.py
@@ -25,7 +25,7 @@ def __init__(self, *args, **kwargs):
self.register(SignalCode.LORA_UPDATED_SIGNAL, self.on_lora_updated_signal)
self.register(SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal)
self.register(SignalCode.LORA_STATUS_CHANGED, self.on_lora_modified)
- self.register(SignalCode.LORA_DELETE_SIGNAL, self.delete_lora)
+ self.register(SignalCode.LORA_DELETE_SIGNAL, self._delete_lora)
self.ui.loading_icon.hide()
self.ui.loading_icon.set_size(spinner_size=QSize(30, 30), label_size=QSize(24, 24))
self._apply_button_enabled = False
@@ -44,13 +44,13 @@ def _scan_path_for_lora(self, path) -> bool:
@Slot(bool)
def on_scan_completed(self, force_reload: bool):
- self.load_lora(force_reload=force_reload)
+ self._load_lora(force_reload=force_reload)
@Slot()
def scan_for_lora(self):
# clear all lora widgets
force_reload = scan_path_for_lora(self.path_settings.base_path)
- self.load_lora(force_reload=force_reload)
+ self._load_lora(force_reload=force_reload)
@Slot()
def apply_lora(self):
@@ -94,7 +94,7 @@ def _toggle_lora_widgets(self, enable: bool):
lora_widget.disable_lora_widget()
def on_application_settings_changed_signal(self):
- self.load_lora()
+ self._load_lora()
def on_lora_updated_signal(self):
self._enable_form()
@@ -114,9 +114,9 @@ def showEvent(self, event):
if not self.initialized:
self.scan_for_lora()
self.initialized = True
- self.load_lora()
+ self._load_lora()
- def load_lora(self, force_reload=False):
+ def _load_lora(self, force_reload=False):
version = self.generator_settings.version
if self._version is None or self._version != version or force_reload:
@@ -130,7 +130,7 @@ def load_lora(self, force_reload=False):
if self.search_filter.lower() in lora.name.lower()
]
for lora in filtered_loras:
- self.add_lora(lora)
+ self._add_lora(lora)
self.add_spacer()
def remove_spacer(self):
@@ -147,13 +147,13 @@ def add_spacer(self):
self.remove_spacer()
self.ui.scrollAreaWidgetContents.layout().addWidget(self.spacer)
- def add_lora(self, lora):
+ def _add_lora(self, lora):
if lora is None:
return
lora_widget = LoraWidget(lora=lora)
self.ui.scrollAreaWidgetContents.layout().addWidget(lora_widget)
- def delete_lora(self, data: dict):
+ def _delete_lora(self, data: dict):
self._deleting = True
lora_widget = data["lora_widget"]
@@ -174,14 +174,14 @@ def delete_lora(self, data: dict):
break
# Remove lora from database
- session = self.db_handler.get_db_session()
- session.delete(lora_widget.current_lora)
- session.commit()
- session.close()
+
+ self.session.delete(lora_widget.current_lora)
+ self.session.commit()
+
self._apply_button_enabled = True
self.ui.apply_lora_button.setEnabled(self._apply_button_enabled)
- self.load_lora(force_reload=True)
+ self._load_lora(force_reload=True)
self._deleting = False
def available_lora(self, action):
@@ -235,7 +235,7 @@ def handle_lora_spinbox(self, lora, lora_widget, value, tab_name):
def search_text_changed(self, val):
self.search_filter = val
- self.load_lora(force_reload=True)
+ self._load_lora(force_reload=True)
def clear_lora_widgets(self):
if self.spacer:
diff --git a/src/airunner/widgets/model_manager/import_widget.py b/src/airunner/widgets/model_manager/import_widget.py
index 3fc52dd41..5a1e110c0 100644
--- a/src/airunner/widgets/model_manager/import_widget.py
+++ b/src/airunner/widgets/model_manager/import_widget.py
@@ -123,7 +123,7 @@ def download_model(self):
self.create_lora(new_lora)
elif model_type == "TextualInversion":
# name = file_path.split("/")[-1].split(".")[0]
- # embedding_exists = session.query(Embedding).filter_by(
+ # embedding_exists = self.session.query(Embedding).filter_by(
# name=name,
# path=file_path,
# ).first()
@@ -134,7 +134,7 @@ def download_model(self):
# active=True,
# tags=trained_words,
# )
- # session.add(new_embedding)
+ # self.session.add(new_embedding)
# TODO: handle textual inversion
pass
elif model_type == "VAE":
diff --git a/src/airunner/widgets/slider/slider_widget.py b/src/airunner/widgets/slider/slider_widget.py
index c6e207da5..166ab927b 100644
--- a/src/airunner/widgets/slider/slider_widget.py
+++ b/src/airunner/widgets/slider/slider_widget.py
@@ -139,11 +139,11 @@ def init(self, **kwargs):
divide_by = self.property("divide_by") or 1.0
if self.table_id is not None and self.table_name is not None and self.table_column is not None:
- session = self.db_handler.get_db_session()
+
if self.table_name == "lora":
- self.table_item = session.query(Lora).filter_by(id=self.table_id).first()
+ self.table_item = self.session.query(Lora).filter_by(id=self.table_id).first()
current_value = getattr(self.table_item, self.table_column)
- session.close()
+
elif current_value is None:
if settings_property is not None:
current_value = self.get_settings_value(settings_property)
@@ -215,11 +215,11 @@ def get_settings_value(self, settings_property):
def set_settings_value(self, settings_property: str, val: Any):
if self.table_item is not None:
- session = self.db_handler.get_db_session()
+
setattr(self.table_item, self.table_column, val)
- session.add(self.table_item)
- session.commit()
- session.close()
+ self.session.add(self.table_item)
+ self.session.commit()
+
elif settings_property is not None:
keys = settings_property.split(".")
self.update_settings_by_name(keys[0], keys[1], val)
diff --git a/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py b/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py
index 300d19f8c..60ea48a50 100644
--- a/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py
+++ b/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py
@@ -68,15 +68,15 @@ def handle_pipeline_changed(self, val):
val = GeneratorSection.TXT2IMG.value
elif val == f"{GeneratorSection.INPAINT.value} / {GeneratorSection.OUTPAINT.value}":
val = GeneratorSection.INPAINT.value
- session = self.db_handler.get_db_session()
- generator_settings = session.query(GeneratorSettings).first()
+
+ generator_settings = self.session.query(GeneratorSettings).first()
do_reload = False
if val == GeneratorSection.TXT2IMG.value:
- model = session.query(AIModels).filter(
+ model = self.session.query(AIModels).filter(
AIModels.id == generator_settings.model
).first()
if model.pipeline_action == GeneratorSection.INPAINT.value:
- model = session.query(AIModels).filter(
+ model = self.session.query(AIModels).filter(
AIModels.version == generator_settings.version,
AIModels.pipeline_action == val,
AIModels.enabled == True,
@@ -88,8 +88,8 @@ def handle_pipeline_changed(self, val):
generator_settings.model = None
do_reload = True
generator_settings.pipeline_action = val
- session.commit()
- session.close()
+ self.session.commit()
+
self.load_versions()
self.load_models()
if do_reload:
@@ -100,9 +100,9 @@ def handle_pipeline_changed(self, val):
def handle_version_changed(self, val):
self.update_generator_settings("version", val)
- session = self.db_handler.get_db_session()
- generator_settings = session.query(GeneratorSettings).first()
- model = session.query(AIModels).filter(
+
+ generator_settings = self.session.query(GeneratorSettings).first()
+ model = self.session.query(AIModels).filter(
AIModels.version == val,
AIModels.pipeline_action == generator_settings.pipeline_action,
AIModels.enabled == True,
@@ -110,15 +110,15 @@ def handle_version_changed(self, val):
).first()
generator_settings.version = val
generator_settings.model = model.id
- session.commit()
- session.close()
+ self.session.commit()
+
self.load_models()
if self.application_settings.sd_enabled:
self.emit_signal(SignalCode.SD_LOAD_SIGNAL, {
"do_reload": True
})
- def load_pipelines(self):
+ def _load_pipelines(self):
self.ui.pipeline.blockSignals(True)
self.ui.pipeline.clear()
pipeline_names = [
@@ -148,7 +148,7 @@ def load_versions(self):
def on_models_changed_signal(self):
try:
- self.load_pipelines()
+ self._load_pipelines()
self.load_versions()
self.load_models()
self.load_schedulers()
@@ -162,14 +162,14 @@ def load_models(self):
self.ui.model.blockSignals(True)
self.clear_models()
image_generator = ImageGenerator.STABLEDIFFUSION.value
- session = self.db_handler.get_db_session()
- generator_settings = session.query(GeneratorSettings).first()
+
+ generator_settings = self.session.query(GeneratorSettings).first()
pipeline = generator_settings.pipeline_action
version = generator_settings.version
pipeline_actions = [GeneratorSection.TXT2IMG.value]
if pipeline == GeneratorSection.INPAINT.value:
pipeline_actions.append(GeneratorSection.INPAINT.value)
- models = session.query(AIModels).filter(
+ models = self.session.query(AIModels).filter(
AIModels.category == image_generator,
AIModels.pipeline_action.in_(pipeline_actions),
AIModels.version == version,
@@ -180,10 +180,10 @@ def load_models(self):
if model_id is None and len(models) > 0:
current_model = models[0]
generator_settings.model = current_model.id
- session.commit()
+ self.session.commit()
for model in models:
self.ui.model.addItem(model.name, model.id)
- session.close()
+
if model_id:
index = self.ui.model.findData(model_id)
if index != -1:
diff --git a/src/airunner/windows/filter_window.py b/src/airunner/windows/filter_window.py
index cca995734..ec0bce6c9 100644
--- a/src/airunner/windows/filter_window.py
+++ b/src/airunner/windows/filter_window.py
@@ -24,7 +24,6 @@ def __init__(self, image_filter_id):
"""
super().__init__(exec=False)
- self.session = self.db_handler.get_db_session()
self.image_filter = self.session.query(ImageFilter).options(joinedload(ImageFilter.image_filter_values)).get(
image_filter_id)
self.image_filter_model_name = self.image_filter.name
diff --git a/src/airunner/windows/installer/installer_window.py b/src/airunner/windows/installer/installer_window.py
deleted file mode 100644
index 910f508af..000000000
--- a/src/airunner/windows/installer/installer_window.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from PySide6.QtWidgets import QWizard
-from airunner.mediator_mixin import MediatorMixin
-from airunner.windows.installer.completion_page import CompletionPage
-from airunner.windows.installer.confirmation_page import ConfirmationPage
-from airunner.windows.installer.download_page import DownloadPage
-from airunner.windows.main.settings_mixin import SettingsMixin
-
-
-class InstallerWindow(
- QWizard,
- MediatorMixin,
- SettingsMixin
-):
- def __init__(self, *args):
- MediatorMixin.__init__(self)
- SettingsMixin.__init__(self)
- super(InstallerWindow, self).__init__(*args)
-
- self.page_ids = {}
-
- self.do_download_sd_models = True
- self.do_download_controlnet_models = True
- self.do_download_llm = True
- self.do_download_tts_models = True
- self.do_download_stt_models = True
-
- self.pages = {
- "confirmation_page": ConfirmationPage(self),
- "download_page": DownloadPage(self),
- "completion_page": CompletionPage(self),
- }
-
- for page_name, page in self.pages.items():
- self.addPage(page)
-
- self.setWindowTitle("AI Runner Setup Wizard")
-
- @property
- def download_settings(self):
- return {
- "compile_with_pyinstaller": False,
- "download_ai_runner": True,
- "download_sd": self.do_download_sd_models,
- "download_controlnet": self.do_download_controlnet_models,
- "download_llm": self.do_download_llm,
- "download_tts": self.do_download_tts_models,
- "download_stt": self.do_download_stt_models,
- }
-
- def addPage(self, page):
- page_id = super().addPage(page)
- self.page_ids[page] = page_id
- return page_id
\ No newline at end of file
diff --git a/src/airunner/windows/main/ai_model_mixin.py b/src/airunner/windows/main/ai_model_mixin.py
index 0e93baaeb..0d4ac1aeb 100644
--- a/src/airunner/windows/main/ai_model_mixin.py
+++ b/src/airunner/windows/main/ai_model_mixin.py
@@ -4,9 +4,6 @@
class AIModelMixin:
- def ai_model_get_by_filter(self, filter_dict):
- return [item for item in self.ai_models if all(item.get(k) == v for k, v in filter_dict.items())]
-
def on_ai_model_delete_signal(self, item: dict):
self.ai_models = [existing_item for existing_item in self.ai_models if existing_item.name != item.name]
self.update_settings("ai_models", self.ai_models)
diff --git a/src/airunner/windows/main/embedding_mixin.py b/src/airunner/windows/main/embedding_mixin.py
index 67eec3b11..f10b389eb 100644
--- a/src/airunner/windows/main/embedding_mixin.py
+++ b/src/airunner/windows/main/embedding_mixin.py
@@ -30,12 +30,12 @@ def delete_missing_embeddings(self):
embeddings = self.get_embeddings()
for embedding in embeddings:
if not os.path.exists(embedding["path"]):
- self.delete_embedding(embedding)
+ self._delete_embedding(embedding)
- def delete_embedding(self, embedding):
+ def _delete_embedding(self, embedding):
for index, _embedding in enumerate(self.embeddings):
if _embedding.name == embedding.name and _embedding.path == embedding.path:
- self.delete_embedding(embedding)
+ self._delete_embedding(embedding)
return
def scan_for_embeddings(self):
diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py
index 33a7f3bbc..200b94d47 100644
--- a/src/airunner/windows/main/main_window.py
+++ b/src/airunner/windows/main/main_window.py
@@ -4,14 +4,14 @@
import urllib
import webbrowser
from functools import partial
-from pathlib import Path
import requests
from PIL import Image
from PySide6 import QtGui
from PySide6.QtCore import (
Slot,
- Signal, QProcess, QSettings
+ Signal,
+ QProcess
)
from PySide6.QtGui import QGuiApplication, QKeySequence
from PySide6.QtWidgets import (
@@ -24,7 +24,6 @@
from airunner.handlers.llm.agent.actions.bash_execute import bash_execute
from airunner.handlers.llm.agent.actions.show_path import show_path
-from airunner.handlers.logger import Logger
from airunner.data.models.settings_models import ShortcutKeys, ImageFilter, DrawingPadSettings
from airunner.app_installer import AppInstaller
from airunner.settings import (
@@ -32,8 +31,6 @@
STATUS_NORMAL_COLOR_LIGHT,
STATUS_NORMAL_COLOR_DARK,
NSFW_CONTENT_DETECTED_MESSAGE,
- ORGANIZATION,
- APPLICATION_NAME, DARK_THEME_NAME, LIGHT_THEME_NAME
)
from airunner.enums import (
SignalCode,
@@ -151,10 +148,9 @@ def __init__(
self.listening = False
self.initialized = False
self._model_status = {model_type: ModelStatus.UNLOADED for model_type in ModelType}
- self.logger = Logger(prefix=self.__class__.__name__)
- self.logger.debug("Starting AI Runnner")
MediatorMixin.__init__(self)
SettingsMixin.__init__(self)
+ self.logger.debug("Starting AI Runnner")
super().__init__(*args, **kwargs)
self._updating_settings = True
PipelineMixin.__init__(self)
@@ -951,12 +947,12 @@ def on_history_updated(self, data):
self.ui.actionRedo.setEnabled(data["redo"] != 0)
def _set_keyboard_shortcuts(self):
- session = self.db_handler.get_db_session()
- quit_key = session.query(ShortcutKeys).filter_by(display_name="Quit").first()
- brush_key = session.query(ShortcutKeys).filter_by(display_name="Brush").first()
- eraser_key = session.query(ShortcutKeys).filter_by(display_name="Eraser").first()
- move_tool_key = session.query(ShortcutKeys).filter_by(display_name="Move Tool").first()
- select_tool_key = session.query(ShortcutKeys).filter_by(display_name="Select Tool").first()
+
+ quit_key = self.session.query(ShortcutKeys).filter_by(display_name="Quit").first()
+ brush_key = self.session.query(ShortcutKeys).filter_by(display_name="Brush").first()
+ eraser_key = self.session.query(ShortcutKeys).filter_by(display_name="Eraser").first()
+ move_tool_key = self.session.query(ShortcutKeys).filter_by(display_name="Move Tool").first()
+ select_tool_key = self.session.query(ShortcutKeys).filter_by(display_name="Select Tool").first()
if quit_key is not None:
key_sequence = QKeySequence(quit_key.key | quit_key.modifiers)
@@ -983,7 +979,7 @@ def _set_keyboard_shortcuts(self):
self.ui.actionToggle_Selection.setShortcut(key_sequence)
self.ui.actionToggle_Selection.setToolTip(f"{select_tool_key.display_name} ({select_tool_key.text})")
- session.close()
+
def _initialize_workers(self):
self.logger.debug("Initializing worker manager")
@@ -997,12 +993,12 @@ def _initialize_workers(self):
def _initialize_filter_actions(self):
# add more filters:
- session = self.db_handler.get_db_session()
- image_filters = session.query(ImageFilter).all()
+
+ image_filters = self.session.query(ImageFilter).all()
for image_filter in image_filters:
action = self.ui.menuFilters.addAction(image_filter.display_name)
action.triggered.connect(partial(self.display_filter_window, image_filter))
- session.close()
+
def display_filter_window(self, image_filter):
FilterWindow(image_filter.id)
@@ -1115,8 +1111,8 @@ def _generate_drawingpad_mask(self):
height = self.active_grid_settings.height
img = Image.new("RGB", (width, height), (0, 0, 0))
base64_image = convert_image_to_base64(img)
- session = self.db_handler.get_db_session()
- drawing_pad_settings = session.query(DrawingPadSettings).first()
+
+ drawing_pad_settings = self.session.query(DrawingPadSettings).first()
drawing_pad_settings.mask = base64_image
- session.commit()
- session.close()
+ self.session.commit()
+
diff --git a/src/airunner/windows/main/settings_mixin.py b/src/airunner/windows/main/settings_mixin.py
index 3fd715cc6..46b121119 100644
--- a/src/airunner/windows/main/settings_mixin.py
+++ b/src/airunner/windows/main/settings_mixin.py
@@ -1,39 +1,69 @@
import logging
-from typing import List
-
-from sqlalchemy.orm import joinedload
-
-from airunner.data.models.settings_db_handler import SettingsDBHandler
-from airunner.data.models.settings_models import ApplicationSettings, LLMGeneratorSettings, GeneratorSettings, \
- ControlnetSettings, BrushSettings, DrawingPadSettings, GridSettings, ActiveGridSettings, \
- ImageToImageSettings, OutpaintSettings, PathSettings, MemorySettings, Chatbot, \
- AIModels, Schedulers, Lora, ShortcutKeys, SavedPrompt, SpeechT5Settings, TTSSettings, EspeakSettings, \
- MetadataSettings, Embedding, STTSettings, PromptTemplate, ControlnetModel, FontSetting, PipelineModel, TargetFiles, \
- ImageFilterValue, WhisperSettings
+import datetime
+import os
+from typing import List, Type
+
+from sqlalchemy import create_engine
+from sqlalchemy.orm import joinedload, sessionmaker
+
+from airunner.data.models.settings_models import Chatbot, AIModels, Schedulers, Lora, PathSettings, SavedPrompt, \
+ Embedding, PromptTemplate, ControlnetModel, FontSetting, PipelineModel, ShortcutKeys, \
+ GeneratorSettings, WindowSettings, ApplicationSettings, ActiveGridSettings, ControlnetSettings, \
+ ImageToImageSettings, OutpaintSettings, DrawingPadSettings, MetadataSettings, \
+ LLMGeneratorSettings, TTSSettings, SpeechT5Settings, EspeakSettings, STTSettings, BrushSettings, GridSettings, \
+ MemorySettings, Message, Conversation, Summary, ImageFilterValue, TargetFiles, WhisperSettings, Base
from airunner.enums import SignalCode
+from airunner.handlers.logger import Logger
+from airunner.settings import LOG_LEVEL
from airunner.utils.convert_base64_to_image import convert_base64_to_image
class SettingsMixin:
def __init__(self):
logging.debug("Initializing SettingsMixin instance")
- self.db_handler = SettingsDBHandler()
+ self.db_path = os.path.expanduser(
+ os.path.join(
+ "~",
+ ".local",
+ "share",
+ "airunner",
+ "data",
+ "airunner.db"
+ )
+ )
+ self.engine = create_engine(f'sqlite:///{self.db_path}')
+ Base.metadata.create_all(self.engine)
+ self._session = None
+ self.Session = sessionmaker(bind=self.engine)
+ self.conversation_id = None
+ self.logger = Logger(prefix=self.__class__.__name__, log_level=LOG_LEVEL)
+
+ @property
+ def session(self):
+ if self._session is None:
+ self._session = self.Session()
+ return self._session
+
+ def close_session(self):
+ if self._session is not None:
+ self._session.close()
+ self._session = None
@property
def stt_settings(self) -> STTSettings:
- return self.db_handler.load_settings_from_db(STTSettings)
+ return self.load_settings_from_db(STTSettings)
@property
def application_settings(self) -> ApplicationSettings:
- return self.db_handler.load_settings_from_db(ApplicationSettings)
+ return self.load_settings_from_db(ApplicationSettings)
@property
def whisper_settings(self) -> WhisperSettings:
- return self.db_handler.load_settings_from_db(WhisperSettings)
+ return self.load_settings_from_db(WhisperSettings)
@property
def llm_generator_settings(self) -> LLMGeneratorSettings:
- settings = self.db_handler.load_settings_from_db(LLMGeneratorSettings)
+ settings = self.load_settings_from_db(LLMGeneratorSettings)
if settings.current_chatbot == 0:
chatbots = self.chatbots
if len(chatbots) > 0:
@@ -43,103 +73,103 @@ def llm_generator_settings(self) -> LLMGeneratorSettings:
@property
def generator_settings(self) -> GeneratorSettings:
- return self.db_handler.load_settings_from_db(GeneratorSettings)
+ return self.load_settings_from_db(GeneratorSettings)
@property
def controlnet_settings(self) -> ControlnetSettings:
- return self.db_handler.load_settings_from_db(ControlnetSettings)
+ return self.load_settings_from_db(ControlnetSettings)
@property
def image_to_image_settings(self) -> ImageToImageSettings:
- return self.db_handler.load_settings_from_db(ImageToImageSettings)
+ return self.load_settings_from_db(ImageToImageSettings)
@property
def outpaint_settings(self) -> OutpaintSettings:
- return self.db_handler.load_settings_from_db(OutpaintSettings)
+ return self.load_settings_from_db(OutpaintSettings)
@property
def drawing_pad_settings(self) -> DrawingPadSettings:
- return self.db_handler.load_settings_from_db(DrawingPadSettings)
+ return self.load_settings_from_db(DrawingPadSettings)
@property
def brush_settings(self) -> BrushSettings:
- return self.db_handler.load_settings_from_db(BrushSettings)
+ return self.load_settings_from_db(BrushSettings)
@property
def grid_settings(self) -> GridSettings:
- return self.db_handler.load_settings_from_db(GridSettings)
+ return self.load_settings_from_db(GridSettings)
@property
def active_grid_settings(self) -> ActiveGridSettings:
- return self.db_handler.load_settings_from_db(ActiveGridSettings)
+ return self.load_settings_from_db(ActiveGridSettings)
@property
def path_settings(self) -> PathSettings:
- return self.db_handler.load_settings_from_db(PathSettings)
+ return self.load_settings_from_db(PathSettings)
@property
def memory_settings(self) -> MemorySettings:
- return self.db_handler.load_settings_from_db(MemorySettings)
+ return self.load_settings_from_db(MemorySettings)
@property
- def chatbots(self) -> List[Chatbot]:
- return self.db_handler.load_chatbots()
+ def chatbots(self) -> List[Type[Chatbot]]:
+ return self.load_chatbots()
@property
- def ai_models(self) -> List[AIModels]:
- return self.db_handler.load_ai_models()
+ def ai_models(self) -> List[Type[AIModels]]:
+ return self.load_ai_models()
@property
- def schedulers(self) -> List[Schedulers]:
- return self.db_handler.load_schedulers()
+ def schedulers(self) -> List[Type[Schedulers]]:
+ return self.load_schedulers()
@property
- def lora(self) -> List[Lora]:
- return self.db_handler.load_lora()
+ def lora(self) -> List[Type[Lora]]:
+ return self.load_lora()
@property
- def shortcut_keys(self) -> List[ShortcutKeys]:
- return self.db_handler.load_shortcut_keys()
+ def shortcut_keys(self) -> List[Type[ShortcutKeys]]:
+ return self.load_shortcut_keys()
@property
def speech_t5_settings(self) -> SpeechT5Settings:
- return self.db_handler.load_settings_from_db(SpeechT5Settings)
+ return self.load_settings_from_db(SpeechT5Settings)
@property
def tts_settings(self) -> TTSSettings:
- return self.db_handler.load_settings_from_db(TTSSettings)
+ return self.load_settings_from_db(TTSSettings)
@property
def espeak_settings(self) -> EspeakSettings:
- return self.db_handler.load_settings_from_db(EspeakSettings)
+ return self.load_settings_from_db(EspeakSettings)
@property
def metadata_settings(self) -> MetadataSettings:
- return self.db_handler.load_settings_from_db(MetadataSettings)
+ return self.load_settings_from_db(MetadataSettings)
@property
- def embeddings(self) -> List[Embedding]:
- return self.db_handler.load_embeddings()
+ def embeddings(self) -> List[Type[Embedding]]:
+ return self.session.query(Embedding).all()
@property
- def prompt_templates(self) -> List[PromptTemplate]:
- return self.db_handler.load_prompt_templates()
+ def prompt_templates(self) -> List[Type[PromptTemplate]]:
+ return self.load_prompt_templates()
@property
def controlnet_models(self):
- return self.db_handler.load_controlnet_models()
+ return self.load_controlnet_models()
@property
- def saved_prompts(self) -> List[SavedPrompt]:
- return self.db_handler.load_saved_prompts()
+ def saved_prompts(self) -> List[Type[SavedPrompt]]:
+ return self.load_saved_prompts()
@property
- def font_settings(self) -> List[FontSetting]:
- return self.db_handler.load_font_settings()
+ def font_settings(self) -> List[Type[FontSetting]]:
+ return self.load_font_settings()
@property
- def pipelines(self) -> List[PipelineModel]:
- return self.db_handler.load_pipelines()
+ def pipelines(self) -> List[Type[PipelineModel]]:
+ return self.load_pipelines()
@property
def drawing_pad_image(self):
@@ -191,103 +221,31 @@ def outpaint_mask(self):
@property
def image_filter_values(self):
- session = self.db_handler.get_db_session()
- try:
- return session.query(ImageFilterValue).all()
- finally:
- session.close()
-
- #######################################
- ### LORA ###
- #######################################
+ return self.session.query(ImageFilterValue).all()
def get_lora_by_version(self, version):
- session = self.db_handler.get_db_session()
- try:
- return session.query(Lora).filter_by(version=version).all()
- finally:
- session.close()
+ return self.session.query(Lora).filter_by(version=version).all()
- def delete_lora_by_name(self, name, version):
- self.db_handler.delete_lora_by_name(name, version)
-
- def create_lora(self, lora: Lora):
- self.db_handler.create_lora(lora)
-
- def update_lora(self, lora: Lora):
- self.db_handler.update_lora(lora)
- self.__settings_updated()
-
- def update_loras(self, loras: List[Lora]):
- self.db_handler.update_loras(loras)
- self.__settings_updated()
-
- #######################################
- ### EMBEDDINGS ###
- #######################################
def get_embeddings_by_version(self, version):
return [embedding for embedding in self.embeddings if embedding.version == version]
- def delete_embedding(self, embedding: Embedding):
- self.db_handler.delete_embedding(embedding)
-
- def update_embeddings(self, embeddings: List[Embedding]):
- self.db_handler.update_embeddings(embeddings)
- self.__settings_updated()
-
- #######################################
- ### CHATBOT ###
- #######################################
@property
- def chatbot(self) -> Chatbot:
+ def chatbot(self) -> Type[Chatbot]:
return self.get_chatbot_by_id(
self.llm_generator_settings.current_chatbot
)
- def get_chatbot_by_id(self, chatbot_id) -> Chatbot:
- chatbot = None
- session = self.db_handler.get_db_session()
- try:
- chatbot = session.query(Chatbot).filter_by(id=chatbot_id).options(joinedload(Chatbot.target_files)).first()
- if chatbot is None:
- chatbot = session.query(Chatbot).options(joinedload(Chatbot.target_files)).first()
- finally:
- session.close()
- return chatbot
-
- def delete_chatbot_by_name(self, chatbot_name):
- self.db_handler.delete_chatbot_by_name(chatbot_name)
-
- def create_chatbot(self, chatbot_name):
- self.db_handler.create_chatbot(chatbot_name)
+ @property
+ def window_settings(self):
+ return self.load_window_settings()
def add_chatbot_document_to_chatbot(self, chatbot, file_path):
- session = self.db_handler.get_db_session()
- try:
- document = session.query(TargetFiles).filter_by(chatbot_id=chatbot.id, file_path=file_path).first()
- if document is None:
- document = TargetFiles(file_path=file_path, chatbot_id=chatbot.id)
- session.merge(document) # Use merge instead of add
- session.commit()
- finally:
- session.close()
-
- #######################################
- ### SAVED PROMPTS ###
- #######################################
- def create_saved_prompt(self, data: dict):
- self.db_handler.create_saved_prompt(data)
+ document = self.session.query(TargetFiles).filter_by(chatbot_id=chatbot.id, file_path=file_path).first()
+ if document is None:
+ document = TargetFiles(file_path=file_path, chatbot_id=chatbot.id)
+ self.session.merge(document) # Use merge instead of add
+ self.session.commit()
- def update_saved_prompt(self, saved_prompt: SavedPrompt):
- self.db_handler.update_saved_prompt(saved_prompt)
- self.__settings_updated()
-
- def get_saved_prompt_by_id(self, prompt_id) -> SavedPrompt:
- self.db_handler.get_saved_prompt_by_id(prompt_id)
-
- #######################################
- ### SETTINGS ###
- #######################################
def update_settings_by_name(self, setting_name, column_name, val):
if setting_name == "application_settings":
self.update_application_settings(column_name, val)
@@ -321,158 +279,450 @@ def update_settings_by_name(self, setting_name, column_name, val):
logging.error(f"Invalid setting name: {setting_name}")
def update_application_settings(self, column_name, val):
- self.db_handler.update_setting(ApplicationSettings, column_name, val)
+ self.update_setting(ApplicationSettings, column_name, val)
self.__settings_updated()
- #######################################
- ### TTS Settings ###
- #######################################
def update_espeak_settings(self, column_name, val):
- self.db_handler.update_setting(EspeakSettings, column_name, val)
+ self.update_setting(EspeakSettings, column_name, val)
self.__settings_updated()
def update_tts_settings(self, column_name, val):
- self.db_handler.update_setting(TTSSettings, column_name, val)
+ self.update_setting(TTSSettings, column_name, val)
self.__settings_updated()
def update_speech_t5_settings(self, column_name, val):
- self.db_handler.update_setting(SpeechT5Settings, column_name, val)
+ self.update_setting(SpeechT5Settings, column_name, val)
self.__settings_updated()
-
- #######################################
- ### CONTROLNET ###
- #######################################
def update_controlnet_settings(self, column_name, val):
- self.db_handler.update_setting(ControlnetSettings, column_name, val)
+ self.update_setting(ControlnetSettings, column_name, val)
self.__settings_updated()
def update_brush_settings(self, column_name, val):
- self.db_handler.update_setting(BrushSettings, column_name, val)
+ self.update_setting(BrushSettings, column_name, val)
self.__settings_updated()
def update_image_to_image_settings(self, column_name, val):
- self.db_handler.update_setting(ImageToImageSettings, column_name, val)
+ self.update_setting(ImageToImageSettings, column_name, val)
self.__settings_updated()
def update_outpaint_settings(self, column_name, val):
- self.db_handler.update_setting(OutpaintSettings, column_name, val)
+ self.update_setting(OutpaintSettings, column_name, val)
self.__settings_updated()
def update_drawing_pad_settings(self, column_name, val):
- self.db_handler.update_setting(DrawingPadSettings, column_name, val)
+ self.update_setting(DrawingPadSettings, column_name, val)
self.__settings_updated()
def update_grid_settings(self, column_name, val):
- self.db_handler.update_setting(GridSettings, column_name, val)
+ self.update_setting(GridSettings, column_name, val)
self.__settings_updated()
def update_active_grid_settings(self, column_name, val):
- self.db_handler.update_setting(ActiveGridSettings, column_name, val)
+ self.update_setting(ActiveGridSettings, column_name, val)
self.__settings_updated()
- #######################################
- ### PATH ###
- #######################################
def update_path_settings(self, column_name, val):
- self.db_handler.update_setting(PathSettings, column_name, val)
+ self.update_setting(PathSettings, column_name, val)
self.__settings_updated()
- def reset_path_settings(self):
- self.db_handler.reset_path_settings()
-
def update_memory_settings(self, column_name, val):
- self.db_handler.update_setting(MemorySettings, column_name, val)
+ self.update_setting(MemorySettings, column_name, val)
self.__settings_updated()
def update_metadata_settings(self, column_name, val):
- self.db_handler.update_setting(MetadataSettings, column_name, val)
+ self.update_setting(MetadataSettings, column_name, val)
self.__settings_updated()
def update_llm_generator_settings(self, column_name, val):
- self.db_handler.update_setting(LLMGeneratorSettings, column_name, val)
+ self.update_setting(LLMGeneratorSettings, column_name, val)
self.__settings_updated()
def update_whisper_settings(self, column_name, val):
- self.db_handler.update_setting(WhisperSettings, column_name, val)
+ self.update_setting(WhisperSettings, column_name, val)
self.__settings_updated()
- def __settings_updated(self):
- self.emit_signal(SignalCode.APPLICATION_SETTINGS_CHANGED_SIGNAL)
+ def update_ai_models(self, models: List[AIModels]):
+ for model in models:
+ self.update_ai_model(model)
+ self.__settings_updated()
+
+ def update_ai_model(self, model: AIModels):
+ query = self.session.query(AIModels).filter_by(
+ name=model.name,
+ path=model.path,
+ branch=model.branch,
+ version=model.version,
+ category=model.category,
+ pipeline_action=model.pipeline_action,
+ enabled=model.enabled,
+ model_type=model.model_type,
+ is_default=model.is_default
+ ).first()
+ if query:
+ for key in model.__dict__.keys():
+ if key != "_sa_instance_state":
+ setattr(query, key, getattr(model, key))
+ else:
+ self.session.add(model)
+ self.session.commit()
+ self.__settings_updated()
+
+ def update_generator_settings(self, column_name, val):
+ generator_settings = self.generator_settings
+ setattr(generator_settings, column_name, val)
+ self.save_generator_settings(generator_settings)
+
+ def update_controlnet_image_settings(self, column_name, val):
+ controlnet_settings = self.controlnet_settings
+ setattr(controlnet_settings, column_name, val)
+ self.update_controlnet_settings(column_name, val)
+
+ def load_schedulers(self) -> list[Type[Schedulers]]:
+ return self.session.query(Schedulers).all()
+
+ def load_settings_from_db(self, model_class_):
+ settings = self.session.query(model_class_).first()
+ if settings is None:
+ settings = self.create_new_settings(model_class_)
+ return settings
+
+ def update_setting(self, model_class_, name, value):
+ setting = self.session.query(model_class_).order_by(model_class_.id.desc()).first()
+ if setting:
+ setattr(setting, name, value)
+ self.session.commit()
+
+ def save_generator_settings(self, generator_settings: GeneratorSettings):
+ query = self.session.query(GeneratorSettings).filter_by(
+ id=generator_settings.id
+ ).first()
+ if query:
+ for key in generator_settings.__dict__.keys():
+ if key != "_sa_instance_state":
+ setattr(query, key, getattr(generator_settings, key))
+ else:
+ self.session.add(generator_settings)
+ self.session.commit()
+ self.__settings_updated()
def reset_settings(self):
- self.db_handler.reset_settings()
+ # Delete all entries from the model class
+ self.session.query(ApplicationSettings).delete()
+ self.session.query(ActiveGridSettings).delete()
+ self.session.query(ControlnetSettings).delete()
+ self.session.query(ImageToImageSettings).delete()
+ self.session.query(OutpaintSettings).delete()
+ self.session.query(DrawingPadSettings).delete()
+ self.session.query(MetadataSettings).delete()
+ self.session.query(GeneratorSettings).delete()
+ self.session.query(LLMGeneratorSettings).delete()
+ self.session.query(TTSSettings).delete()
+ self.session.query(SpeechT5Settings).delete()
+ self.session.query(EspeakSettings).delete()
+ self.session.query(STTSettings).delete()
+ self.session.query(BrushSettings).delete()
+ self.session.query(GridSettings).delete()
+ self.session.query(PathSettings).delete()
+ self.session.query(MemorySettings).delete()
+ # Commit the changes
+ self.session.commit()
+
+ def create_new_settings(self, model_class_):
+ new_settings = model_class_()
+ self.session.add(new_settings)
+ self.session.commit()
+ self.session.refresh(new_settings)
+ return new_settings
+
+ def get_saved_prompt_by_id(self, prompt_id) -> Type[SavedPrompt]:
+ return self.session.query(SavedPrompt).filter_by(id=prompt_id).first()
- #######################################
- ### AI MODELS ###
- #######################################
- def update_ai_models(self, models: List[AIModels]):
- self.db_handler.update_ai_models(models)
+ def update_saved_prompt(self, saved_prompt: SavedPrompt):
+ query = self.session.query(SavedPrompt).filter_by(
+ id=saved_prompt.id
+ ).first()
+ if query:
+ for key in saved_prompt.__dict__.keys():
+ if key != "_sa_instance_state":
+ setattr(query, key, getattr(saved_prompt, key))
+ else:
+ self.session.add(saved_prompt)
+ self.session.commit()
self.__settings_updated()
- def update_ai_model(self, model: AIModels):
- self.db_handler.update_ai_model(model)
+ def create_saved_prompt(self, data: dict):
+ new_saved_prompt = SavedPrompt(**data)
+ self.session.add(new_saved_prompt)
+ self.session.commit()
+
+ def load_saved_prompts(self) -> List[Type[SavedPrompt]]:
+ return self.session.query(SavedPrompt).all()
+
+ def load_font_settings(self) -> List[Type[FontSetting]]:
+ return self.session.query(FontSetting).all()
+
+ def get_font_setting_by_name(self, name) -> Type[FontSetting]:
+ return self.session.query(FontSetting).filter_by(name=name).first()
+
+ def update_font_setting(self, font_setting: Type[FontSetting]):
+ query = self.session.query(FontSetting).filter_by(
+ name=font_setting.name
+ ).first()
+ if query:
+ for key in font_setting.__dict__.keys():
+ if key != "_sa_instance_state":
+ setattr(query, key, getattr(font_setting, key))
+ else:
+ self.session.add(font_setting)
+ self.session.commit()
self.__settings_updated()
- #######################################
- ### PROMPT TEMPLATES ###
- #######################################
- def get_prompt_template_by_name(self, name) -> PromptTemplate:
- return self.db_handler.get_prompt_template_by_name(name)
+ def load_ai_models(self) -> List[Type[AIModels]]:
+ return self.session.query(AIModels).all()
- #######################################
- ### CONTROLNET MODELS ###
- #######################################
- def controlnet_model_by_name(self, name) -> ControlnetModel:
- return self.db_handler.controlnet_model_by_name(name)
+ def load_chatbots(self) -> List[Type[Chatbot]]:
+ settings = self.session.query(Chatbot).all()
+ return settings
- def get_font_setting_by_name(self, name) -> FontSetting:
- return self.db_handler.get_font_setting_by_name(name)
+ def delete_chatbot_by_name(self, chatbot_name):
+ self.session.query(Chatbot).filter_by(name=chatbot_name).delete()
+ self.session.commit()
- def update_font_setting(self, font_setting: FontSetting):
- self.db_handler.update_font_setting(font_setting)
+ def create_chatbot(self, chatbot_name):
+ new_chatbot = Chatbot(name=chatbot_name)
+ self.session.add(new_chatbot)
+ self.session.commit()
+
+ def reset_path_settings(self):
+ self.session.query(PathSettings).delete()
+ self.set_default_values(PathSettings)
+ self.session.commit()
+
+ def set_default_values(self, model_name_):
+ default_values = {}
+ for column in model_name_.__table__.columns:
+ if column.default is not None:
+ default_values[column.name] = column.default.arg
+ self.session.execute(
+ model_name_.__table__.insert(),
+ [default_values]
+ )
+ self.session.commit()
+
+ def load_lora(self) -> List[Type[Lora]]:
+ return self.session.query(Lora).all()
+
+ def get_lora_by_name(self, name):
+ return self.session.query(Lora).filter_by(name=name).first()
+
+ def add_lora(self, lora: Lora):
+ self.session.add(lora)
+ self.session.commit()
+
+ def delete_lora(self, lora: Lora):
+ self.session.query(Lora).filter_by(name=lora.name).delete()
+ self.session.commit()
+
+ def update_lora(self, lora: Lora):
+ query = self.session.query(Lora).filter_by(name=lora.name).first()
+ if query:
+ for key in lora.__dict__.keys():
+ if key != "_sa_instance_state":
+ setattr(query, key, getattr(lora, key))
+ else:
+ self.session.add(lora)
+ self.session.commit()
self.__settings_updated()
- def ai_model_get_by_filter(self, filter_dict: dict) -> List[AIModels]:
- results = []
- models = self.db_handler.load_ai_models()
- for item in models:
- match = True
- for k, v in filter_dict.items():
- if isinstance(item, dict):
- if item.get(k, "") != v:
- match = False
- break
- else:
- if not hasattr(item, k) or getattr(item, k) != v:
- match = False
- break
- if match:
- results.append(item)
- return results
+ def update_loras(self, loras: List[Lora]):
+ for lora in loras:
+ query = self.session.query(Lora).filter_by(name=lora.name).first()
+ if query:
+ for key in lora.__dict__.keys():
+ if key != "_sa_instance_state":
+ setattr(query, key, getattr(lora, key))
+ else:
+ self.session.add(lora)
+ self.session.commit()
+ self.__settings_updated()
- #######################################
- ### GENERATOR SETTINGS ###
- #######################################
- def update_generator_settings(self, column_name, val):
- generator_settings = self.generator_settings
- setattr(generator_settings, column_name, val)
- self.save_generator_settings(generator_settings)
+ def create_lora(self, lora: Lora):
+ self.session.add(lora)
+ self.session.commit()
+
+ def delete_lora_by_name(self, lora_name, version):
+ self.session.query(Lora).filter_by(name=lora_name, version=version).delete()
+ self.session.commit()
- def save_generator_settings(self, generator_settings):
- self.db_handler.save_generator_settings(generator_settings)
+ def delete_embedding(self, embedding: Embedding):
+ self.session.query(Embedding).filter_by(
+ name=embedding.name,
+ path=embedding.path,
+ branch=embedding.branch,
+ version=embedding.version,
+ category=embedding.category,
+ pipeline_action=embedding.pipeline_action,
+ enabled=embedding.enabled,
+ model_type=embedding.model_type,
+ is_default=embedding.is_default
+ ).delete()
+ self.session.commit()
+
+ def update_embeddings(self, embeddings: List[Embedding]):
+ for embedding in embeddings:
+ query = self.session.query(Embedding).filter_by(
+ name=embedding.name,
+ path=embedding.path,
+ branch=embedding.branch,
+ version=embedding.version,
+ category=embedding.category,
+ pipeline_action=embedding.pipeline_action,
+ enabled=embedding.enabled,
+ model_type=embedding.model_type,
+ is_default=embedding.is_default
+ ).first()
+ if query:
+ for key in embedding.__dict__.keys():
+ if key != "_sa_instance_state":
+ setattr(query, key, getattr(embedding, key))
+ else:
+ self.session.add(embedding)
+ self.session.commit()
self.__settings_updated()
+ def get_embedding_by_name(self, name):
+ return self.session.query(Embedding).filter_by(name=name).first()
- #######################################
- ### WINDOW SETTINGS ###
- #######################################
- @property
- def window_settings(self):
- return self.db_handler.load_window_settings()
+ def add_embedding(self, embedding: Embedding):
+ self.session.add(embedding)
+ self.session.commit()
+
+ def load_prompt_templates(self) -> List[Type[PromptTemplate]]:
+ return self.session.query(PromptTemplate).all()
+
+ def get_prompt_template_by_name(self, name) -> Type[PromptTemplate]:
+ return self.session.query(PromptTemplate).filter_by(template_name=name).first()
+
+ def load_controlnet_models(self) -> List[Type[ControlnetModel]]:
+ return self.session.query(ControlnetModel).all()
+
+ def controlnet_model_by_name(self, name) -> Type[ControlnetModel]:
+ return self.session.query(ControlnetModel).filter_by(name=name).first()
+ def load_pipelines(self) -> List[Type[PipelineModel]]:
+ return self.session.query(PipelineModel).all()
+
+ def load_shortcut_keys(self) -> List[Type[ShortcutKeys]]:
+ return self.session.query(ShortcutKeys).all()
+
+ def load_window_settings(self) -> Type[WindowSettings]:
+ return self.session.query(WindowSettings).first()
def save_window_settings(self, column_name, val):
window_settings = self.window_settings
setattr(window_settings, column_name, val)
- self.db_handler.save_window_settings(window_settings)
+ query = self.session.query(WindowSettings).first()
+ if query:
+ for key in window_settings.__dict__.keys():
+ if key != "_sa_instance_state":
+ setattr(query, key, getattr(window_settings, key))
+ else:
+ self.session.add(window_settings)
+ self.session.commit()
+
+ def save_object(self, database_object):
+ self.session.add(database_object)
+ self.session.commit()
+
+ def load_history_from_db(self, conversation_id):
+ messages = self.session.query(Message).filter_by(
+ conversation_id=conversation_id
+ ).order_by(Message.timestamp).all()
+ results = [
+ {
+ "role": message.role,
+ "content": message.content,
+ "name": message.name,
+ "is_bot": message.is_bot,
+ "timestamp": message.timestamp,
+ "conversation_id": message.conversation_id
+ } for message in messages
+ ]
+ return results
+
+ def save_message(self, content, role, name, is_bot, conversation_id) -> Message:
+ timestamp = datetime.datetime.now() # Ensure timestamp is a datetime object
+ llm_generator_settings = self.session.query(LLMGeneratorSettings).first()
+ message = Message(
+ role=role,
+ content=content,
+ name=name,
+ is_bot=is_bot,
+ timestamp=timestamp,
+ conversation_id=conversation_id,
+ chatbot_id=llm_generator_settings.current_chatbot
+ )
+ self.session.add(message)
+ self.session.commit()
+ return message
+
+ def get_chatbot_by_id(self, chatbot_id) -> Type[Chatbot]:
+ chatbot = self.session.query(Chatbot).filter_by(id=chatbot_id).options(joinedload(Chatbot.target_files)).first()
+ if chatbot is None:
+ chatbot = self.session.query(Chatbot).options(joinedload(Chatbot.target_files)).first()
+ return chatbot
+
+ def create_conversation(self):
+ conversation = Conversation(
+ timestamp=datetime.datetime.now(datetime.timezone.utc),
+ title=""
+ )
+ self.session.add(conversation)
+ self.session.commit()
+ return conversation.id
+
+ def update_conversation_title(self, conversation_id, title):
+ conversation = self.session.query(Conversation).filter_by(id=conversation_id).first()
+ if conversation:
+ conversation.title = title
+ self.session.commit()
+
+ def add_summary(self, content, conversation_id):
+ timestamp = datetime.datetime.now() # Ensure timestamp is a datetime object
+ summary = Summary(
+ content=content,
+ timestamp=timestamp,
+ conversation_id=conversation_id
+ )
+ self.session.add(summary)
+ self.session.commit()
+
+ def create_conversation_with_messages(self, messages):
+ conversation_id = self.create_conversation()
+ for message in messages:
+ self.add_message_to_history(
+ content=message["content"],
+ role=message["role"],
+ name=message["name"],
+ is_bot=message["is_bot"],
+ conversation_id=conversation_id
+ )
+ return conversation_id
+
+ def get_all_conversations(self):
+ conversations = self.session.query(Conversation).all()
+ return conversations
+
+ def delete_conversation(self, conversation_id):
+ self.session.query(Message).filter_by(conversation_id=conversation_id).delete()
+ self.session.query(Summary).filter_by(conversation_id=conversation_id).delete()
+ self.session.query(Conversation).filter_by(id=conversation_id).delete()
+ self.session.commit()
+
+ def get_most_recent_conversation_id(self):
+ conversation = self.session.query(Conversation).order_by(Conversation.timestamp.desc()).first()
+ return conversation.id if conversation else None
+
+ def __settings_updated(self):
+ self.emit_signal(SignalCode.APPLICATION_SETTINGS_CHANGED_SIGNAL)
diff --git a/src/airunner/windows/prompt_browser/prompt_widget.py b/src/airunner/windows/prompt_browser/prompt_widget.py
index 0f4e09005..da0aeb86d 100644
--- a/src/airunner/windows/prompt_browser/prompt_widget.py
+++ b/src/airunner/windows/prompt_browser/prompt_widget.py
@@ -34,10 +34,10 @@ def action_clicked_button_load(self):
})
def action_clicked_button_delete(self):
- session = self.db_handler.get_db_session()
- session.delete(self.saved_prompt)
- session.commit()
- session.close()
+
+ self.session.delete(self.saved_prompt)
+ self.session.commit()
+
self.deleteLater()
def save_prompt(self):
diff --git a/src/airunner/windows/setup_wizard/installation_settings/install_page.py b/src/airunner/windows/setup_wizard/installation_settings/install_page.py
index 3ce47066f..3097023bd 100644
--- a/src/airunner/windows/setup_wizard/installation_settings/install_page.py
+++ b/src/airunner/windows/setup_wizard/installation_settings/install_page.py
@@ -78,13 +78,13 @@ def download_stable_diffusion(self):
"label": "Downloading Stable Diffusion models..."
})
- session = self.db_handler.get_db_session()
- models = session.query(AIModels).filter(
+
+ models = self.session.query(AIModels).filter(
AIModels.category == "stablediffusion",
AIModels.is_default == 1,
AIModels.version != "SDXL Turbo"
).all()
- session.close()
+
self.total_models_in_current_step += len(models)
for model in models:
@@ -198,12 +198,12 @@ def download_controlnet_processors(self):
print(f"Error downloading {filename}: {e}")
def download_llms(self):
- session = self.db_handler.get_db_session()
- models = session.query(AIModels).filter(
+
+ models = self.session.query(AIModels).filter(
AIModels.category == "llm",
AIModels.is_default == 1
).all()
- session.close()
+
self.total_models_in_current_step += len(models)
for model in models:
files = LLM_FILE_BOOTSTRAP_DATA[model.path]["files"]
@@ -409,12 +409,12 @@ def __init__(self, parent):
if self.application_settings.stable_diffusion_agreement_checked:
self.total_steps += 1
- session = self.db_handler.get_db_session()
+
- controlnet_model_count = session.query(func.count(ControlnetModel.id.distinct())).scalar()
- controlnet_version_count = session.query(func.count(ControlnetModel.version.distinct())).scalar()
+ controlnet_model_count = self.session.query(func.count(ControlnetModel.id.distinct())).scalar()
+ controlnet_version_count = self.session.query(func.count(ControlnetModel.version.distinct())).scalar()
- llm_model_count = session.query(func.count(AIModels.id)).filter(AIModels.category == 'llm').scalar()
+ llm_model_count = self.session.query(func.count(AIModels.id)).filter(AIModels.category == 'llm').scalar()
self.total_steps += controlnet_model_count * controlnet_version_count
self.total_steps += llm_model_count
diff --git a/src/airunner/windows/setup_wizard/model_setup/choose_model_style.py b/src/airunner/windows/setup_wizard/model_setup/choose_model_style.py
index 692842fb6..5aa7b501c 100644
--- a/src/airunner/windows/setup_wizard/model_setup/choose_model_style.py
+++ b/src/airunner/windows/setup_wizard/model_setup/choose_model_style.py
@@ -1,4 +1,4 @@
-from airunner.windows.setup_wizard.setup_wizard_window import BaseWizard
+from airunner.windows.setup_wizard.base_wizard import BaseWizard
from airunner.windows.setup_wizard.model_setup.stable_diffusion_setup.templates.choose_style_ui import Ui_choose_model_style
diff --git a/src/airunner/windows/setup_wizard/model_setup/choose_model_version.py b/src/airunner/windows/setup_wizard/model_setup/choose_model_version.py
index eff676599..258dd1e58 100644
--- a/src/airunner/windows/setup_wizard/model_setup/choose_model_version.py
+++ b/src/airunner/windows/setup_wizard/model_setup/choose_model_version.py
@@ -1,5 +1,5 @@
from airunner.windows.setup_wizard.model_setup.stable_diffusion_setup.templates.choose_version_ui import Ui_choose_model_version
-from airunner.windows.setup_wizard.setup_wizard_window import BaseWizard
+from airunner.windows.setup_wizard.base_wizard import BaseWizard
class ChooseModelVersion(BaseWizard):
diff --git a/src/airunner/worker_manager.py b/src/airunner/worker_manager.py
deleted file mode 100644
index fe0d7dbea..000000000
--- a/src/airunner/worker_manager.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from PySide6.QtCore import QObject, Signal
-
-from airunner.mediator_mixin import MediatorMixin
-from airunner.windows.main.settings_mixin import SettingsMixin
-from airunner.handlers.logger import Logger
-from airunner.utils.create_worker import create_worker
-from airunner.workers.audio_capture_worker import AudioCaptureWorker
-from airunner.workers.audio_processor_worker import AudioProcessorWorker
-from airunner.workers.llm_generate_worker import LLMGenerateWorker
-from airunner.workers.mask_generator_worker import MaskGeneratorWorker
-from airunner.workers.sd_worker import SDWorker
-from airunner.workers.tts_generator_worker import TTSGeneratorWorker
-from airunner.workers.tts_vocalizer_worker import TTSVocalizerWorker
-
-
-class WorkerManager(QObject, MediatorMixin, SettingsMixin):
- """
- The engine is responsible for processing requests and offloading
- them to the appropriate AI model controller.
- """
- # Signals
- request_signal_status = Signal(str)
- image_generated_signal = Signal(dict)
-
- def __init__(self):
- MediatorMixin.__init__(self)
- SettingsMixin.__init__(self)
- super().__init__()
- self.logger = Logger(prefix=self.__class__.__name__)
- self._sd_worker = None
- self._llm_request_worker = None
- self._llm_generate_worker = None
- self._tts_generator_worker = None
- self._tts_vocalizer_worker = None
- self._stt_audio_capture_worker = None
- self._stt_audio_processor_worker = None
-
- self.register_sd_workers()
- self.register_llm_workers()
- self.register_tts_workers()
- self.register_stt_workers()
diff --git a/src/airunner/workers/worker.py b/src/airunner/workers/worker.py
index a0e31f838..291b152fa 100644
--- a/src/airunner/workers/worker.py
+++ b/src/airunner/workers/worker.py
@@ -5,7 +5,6 @@
from PySide6.QtCore import Signal, QThread, QObject
from airunner.enums import QueueType, SignalCode, WorkerState
-from airunner.handlers.logger import Logger
from airunner.mediator_mixin import MediatorMixin
from airunner.settings import SLEEP_TIME_IN_MS
from airunner.windows.main.settings_mixin import SettingsMixin
@@ -22,7 +21,6 @@ def __init__(self, signals=None):
SettingsMixin.__init__(self)
super().__init__()
self.state = WorkerState.HALTED
- self.logger = Logger(prefix=self.__class__.__name__)
self.running = False
self.queue = queue.Queue()
self.items = {}