Skip to content

Commit

Permalink
Complete revamp of settings_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
w4ffl35 committed Jan 11, 2024
1 parent 24e537e commit 1fe1b48
Show file tree
Hide file tree
Showing 80 changed files with 2,555 additions and 4,851 deletions.
2 changes: 1 addition & 1 deletion src/airunner/aihandler/auto_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from airunner.aihandler.logger import Logger as logger
from airunner.aihandler.settings_manager import SettingsManager
from airunner.data.managers import SettingsManager


class AutoImport:
Expand Down
2 changes: 1 addition & 1 deletion src/airunner/aihandler/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from airunner.aihandler.llm import LLM
from airunner.aihandler.logger import Logger as logger
from airunner.aihandler.runner import SDRunner
from airunner.aihandler.settings_manager import SettingsManager
from airunner.data.managers import SettingsManager
from airunner.aihandler.tts import TTS


Expand Down
2 changes: 1 addition & 1 deletion src/airunner/aihandler/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from airunner.aihandler.mixins.scheduler_mixin import SchedulerMixin
from airunner.aihandler.mixins.txttovideo_mixin import TexttovideoMixin
from airunner.aihandler.settings import LOG_LEVEL, AIRUNNER_ENVIRONMENT
from airunner.aihandler.settings_manager import SettingsManager
from airunner.data.managers import SettingsManager
from airunner.prompt_builder.prompt_data import PromptData
from airunner.scripts.realesrgan.main import RealESRGAN
from airunner.aihandler.logger import Logger
Expand Down
180 changes: 97 additions & 83 deletions src/airunner/aihandler/settings_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from airunner.aihandler.qtvar import StringVar, IntVar, BooleanVar, FloatVar, DictVar
from airunner.data.models import LLMGenerator, Settings, GeneratorSetting, AIModel, Pipeline, ControlnetModel, ImageFilter, \
SavedPrompt, StandardImageWidgetSettings
from airunner.utils import save_session, get_session
from airunner.aihandler.logger import Logger as logger
from airunner.data.models import Document
from airunner.data.session_scope import session_scope, path_settings_scope


document = None
_app = None
Expand All @@ -17,7 +18,6 @@
"FLOAT": FloatVar,
"JSON": DictVar,
}
session = get_session()


class SettingsSignal(QObject):
Expand All @@ -40,7 +40,8 @@ def prompts(self):
Return Prompt objects from the database
:return:
"""
return session.query(SavedPrompt).all()
with session_scope() as session:
return session.query(SavedPrompt).all()

def create_saved_prompt(self, prompt, negative_prompt):
if prompt == "" and negative_prompt == "":
Expand All @@ -50,102 +51,110 @@ def create_saved_prompt(self, prompt, negative_prompt):
prompt=prompt,
negative_prompt=negative_prompt
)
session.add(saved_prompt)
if save_session(session):
with session_scope() as session:
session.add(saved_prompt)
self.changed_signal.emit("saved_prompt", saved_prompt, self)

def delete_prompt(self, saved_prompt):
session.delete(saved_prompt)
session.commit()
with session_scope() as session:
session.delete(saved_prompt)
self.save_and_emit("saved_prompt", saved_prompt)
self.changed_signal.emit("saved_prompt", saved_prompt, self)

def save_and_emit(self, key, value):
self.save()
self.changed_signal.emit(key, value, self)

def available_models_by_category(self, category):
categories = [category]
if category in ["img2img", "txt2vid"]:
categories.append("txt2img")
return session.query(AIModel).filter(
AIModel.category.in_(categories),
AIModel.enabled.is_(True)
).all()
with session_scope() as session:
return session.query(AIModel).filter(
AIModel.category.in_(categories),
AIModel.enabled.is_(True)
).all()

def set_model_enabled(self, key, model, enabled):
session.query(AIModel).filter_by(
name=model["name"],
path=model["path"],
branch=model["branch"],
version=model["version"],
category=model["category"]
).update({"enabled": enabled == 2})
with session_scope() as session:
session.query(AIModel).filter_by(
name=model["name"],
path=model["path"],
branch=model["branch"],
version=model["version"],
category=model["category"]
).update({"enabled": enabled == 2})
self.save_settings()

def available_pipeline_by_section(self, pipeline_action, version, category):
return session.query(Pipeline).filter_by(
category=category,
pipeline_action=pipeline_action,
version=version
).first()
with session_scope() as session:
return session.query(Pipeline).filter_by(
category=category,
pipeline_action=pipeline_action,
version=version
).first()

def available_model_names(self, pipeline_action, category):
# returns a list of names of models
# that match the pipeline_action and category
names = []
models = session.query(AIModel).filter_by(
pipeline_action=pipeline_action,
category=category,
enabled=True
).all()
for model in models:
if model.name not in names:
names.append(model.name)
return names
with session_scope() as session:
names = []
models = session.query(AIModel).filter_by(
pipeline_action=pipeline_action,
category=category,
enabled=True
).all()
for model in models:
if model.name not in names:
names.append(model.name)
return names

def add_model(self, model_data):
model = AIModel(**model_data)
session.add(model)
session.commit()
with session_scope() as session:
model = AIModel(**model_data)
session.add(model)

def delete_model(self, model):
session.delete(model)
session.commit()
with session_scope() as session:
session.delete(model)

def update_model(self, model):
session.add(model)
session.commit()
with session_scope() as session:
session.add(model)

def get_image_filter(self, name):
return session.query(ImageFilter).filter_by(name=name).first()
with session_scope() as session:
return session.query(ImageFilter).filter_by(name=name).first()

def get_image_filters(self):
return session.query(ImageFilter).all()
with session_scope() as session:
return session.query(ImageFilter).all()

@property
def standard_image_widget_settings(self):
standard_image_widget_settings = session.query(StandardImageWidgetSettings).first()
if standard_image_widget_settings is None:
standard_image_widget_settings = StandardImageWidgetSettings()
session.add(standard_image_widget_settings)
session.commit()
return standard_image_widget_settings
with session_scope() as session:
standard_image_widget_settings = session.query(StandardImageWidgetSettings).first()
if standard_image_widget_settings is None:
standard_image_widget_settings = StandardImageWidgetSettings()
session.add(standard_image_widget_settings)
return standard_image_widget_settings

@property
def pipelines(self):
return session.query(Settings).all()
with session_scope() as session:
return session.query(Settings).all()

@property
def models(self):
return session.query(AIModel).filter_by(enabled=True)
with session_scope() as session:
return session.query(AIModel).filter_by(enabled=True)

def models_by_pipeline_action(self, pipeline_action):
return self.models.filter_by(pipeline_action=pipeline_action).all()

@property
def controlnet_models(self):
return session.query(ControlnetModel).filter_by(enabled=True)
with session_scope() as session:
return session.query(ControlnetModel).filter_by(enabled=True)

def controlnet_model_by_name(self, name):
return self.controlnet_models.filter_by(name=name).first()
Expand All @@ -169,11 +178,12 @@ def model_categories(self):

def get_pipeline_classname(self, pipeline_action, version, category):
try:
return session.query(Pipeline).filter_by(
category=category,
pipeline_action=pipeline_action,
version=version
).first().classname
with session_scope() as session:
return session.query(Pipeline).filter_by(
category=category,
pipeline_action=pipeline_action,
version=version
).first().classname
except AttributeError:
logger.error(f"Unable to find pipeline classname for {pipeline_action} {version} {category}")
return None
Expand All @@ -200,34 +210,35 @@ def generator(self):

@property
def llm_generator_setting(self):
llm_generator = session.query(LLMGenerator).filter(LLMGenerator.name == self.current_llm_generator).first()
return llm_generator.generator_settings[0]
with session_scope() as session:
llm_generator = session.query(LLMGenerator).filter(LLMGenerator.name == self.current_llm_generator).first()
return llm_generator.generator_settings[0]

_generator = None

def find_generator(self, generator_section, generator_name):
# using sqlalchemy, query the document.settings.generator_settings column
# and find any with GeneratorSettings.section == self.generator_section and GeneratorSettings.generator_name == self.generator_name
# return the first result
if self.generator_settings_override_id:
generator_settings = session.query(GeneratorSetting).filter_by(
id=self.generator_settings_override_id
).first()
else:
generator_settings = session.query(GeneratorSetting).filter_by(
is_preset=0
).first()
if generator_settings is None:
if not generator_section or generator_section == "" or not generator_name or generator_name == "":
return None
# generator_settings = GeneratorSetting(
# section=generator_section,
# generator_name=generator_name,
# is_preset=False
# )
# session.add(generator_settings)
# session.commit()
return generator_settings
with session_scope() as session:
if self.generator_settings_override_id:
generator_settings = session.query(GeneratorSetting).filter_by(
id=self.generator_settings_override_id
).first()
else:
generator_settings = session.query(GeneratorSetting).filter_by(
is_preset=0
).first()
if generator_settings is None:
if not generator_section or generator_section == "" or not generator_name or generator_name == "":
return None
# generator_settings = GeneratorSetting(
# section=generator_section,
# generator_name=generator_name,
# is_preset=False
# )
# session.add(generator_settings)
return generator_settings

def __init__(self, app=None, *args, **kwargs):
global _app, document
Expand All @@ -237,8 +248,8 @@ def __init__(self, app=None, *args, **kwargs):
_app = app
document = _app.document
else:
session = get_session()
document = session.query(Document).first()
with session_scope() as session:
document = session.query(Document).first()

super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -273,8 +284,11 @@ def get_database_value(self, name):
return getattr(document.settings, name)

def __getattr__(self, name):
if document and hasattr(document.settings, name):
return getattr(document.settings, name)
with session_scope() as session:
session.add(document)
session.add(document.settings)
if document and hasattr(document.settings, name):
return getattr(document.settings, name)
return None

def __setattr__(self, name, value):
Expand Down Expand Up @@ -311,4 +325,4 @@ def current_tab_action(self):
return self.current_section_stablediffusion

def save(self):
session.commit()
pass
8 changes: 4 additions & 4 deletions src/airunner/aihandler/transformer_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from transformers import InstructBlipForConditionalGeneration
from transformers import InstructBlipProcessor

from airunner.aihandler.settings_manager import SettingsManager
from airunner.data.managers import SettingsManager
from airunner.data.models import LLMGenerator
from airunner.utils import get_session
from airunner.data.session_scope import session_scope
from airunner.aihandler.logger import Logger


Expand Down Expand Up @@ -66,10 +66,10 @@ class TransformerRunner(QObject):
@property
def generator(self):
try:
session = get_session()
if not self._generator or self.current_generator_name != self.requested_generator_name:
self.current_generator_name = self.requested_generator_name
self._generator = session.query(LLMGenerator).filter_by(name=self.current_generator_name).first()
with session_scope() as session:
self._generator = session.query(LLMGenerator).filter_by(name=self.current_generator_name).first()
return self._generator
except Exception as e:
Logger.error(e)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Renamed GridSettings size to cell_size
Revision ID: 1aa0b181a56e
Revises: 77bcb51efb33
Create Date: 2024-01-10 12:27:35.247277
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '1aa0b181a56e'
down_revision: Union[str, None] = '77bcb51efb33'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('grid_settings', sa.Column('cell_size', sa.Integer(), nullable=True))
op.drop_column('grid_settings', 'size')
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('grid_settings', sa.Column('size', sa.INTEGER(), nullable=True))
op.drop_column('grid_settings', 'cell_size')
# ### end Alembic commands ###
Loading

0 comments on commit 1fe1b48

Please sign in to comment.