diff --git a/setup.py b/setup.py index a1742d29d..599e9b46b 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ "tensorflow==2.17.0", "DeepCache==0.1.1", "tensorflow==2.17.0", + "alembic==1.13.3", # LLM Dependencies "transformers==4.43.4", diff --git a/src/airunner/aihandler/base_handler.py b/src/airunner/aihandler/base_handler.py index bffdf1047..82ff2717c 100644 --- a/src/airunner/aihandler/base_handler.py +++ b/src/airunner/aihandler/base_handler.py @@ -27,13 +27,23 @@ def __init__(self, *args, **kwargs): MediatorMixin.__init__(self) SettingsMixin.__init__(self) super().__init__(*args, **kwargs) - self._requested_action = None + self._requested_action:ModelAction = ModelAction.NONE + + @property + def requested_action(self): + return self._requested_action + + @requested_action.setter + def requested_action(self, value): + self._requested_action = value def handle_requested_action(self): if self._requested_action is ModelAction.LOAD: self.load() + self._requested_action = ModelAction.NONE if self._requested_action is ModelAction.CLEAR: self.unload() + self._requested_action = ModelAction.NONE def load(self): pass diff --git a/src/airunner/aihandler/models/settings_models.py b/src/airunner/aihandler/models/settings_models.py index 82dbad687..048342480 100644 --- a/src/airunner/aihandler/models/settings_models.py +++ b/src/airunner/aihandler/models/settings_models.py @@ -183,6 +183,8 @@ class GeneratorSettings(Base): negative_original_size = Column(JSON, default={"width": 512, "height": 512}) negative_target_size = Column(JSON, default={"width": 512, "height": 512}) + lora_scale = Column(Integer, default=100) + class ControlnetImageSettings(Base): __tablename__ = 'controlnet_image_settings' @@ -468,7 +470,7 @@ class Lora(Base): __tablename__ = 'lora' id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String, nullable=False) - scale = Column(Float, nullable=False) + scale = Column(Integer, nullable=False) enabled = Column(Boolean, nullable=False) loaded = Column(Boolean, default=False, nullable=False) trigger_word = Column(String, nullable=True) diff --git a/src/airunner/aihandler/stablediffusion/sd_handler.py b/src/airunner/aihandler/stablediffusion/sd_handler.py index f213761de..b0def6b07 100644 --- a/src/airunner/aihandler/stablediffusion/sd_handler.py +++ b/src/airunner/aihandler/stablediffusion/sd_handler.py @@ -1,5 +1,5 @@ import os -from typing import Any, List +from typing import Any, List, Dict import diffusers import numpy as np @@ -24,9 +24,10 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizerFast from airunner.aihandler.base_handler import BaseHandler +from airunner.aihandler.models.settings_models import Schedulers, Lora from airunner.enums import ( SDMode, StableDiffusionVersion, GeneratorSection, ModelStatus, ModelType, SignalCode, HandlerState, - EngineResponseCode + EngineResponseCode, ModelAction ) from airunner.exceptions import PipeNotLoadedException, InterruptedException from airunner.settings import MIN_NUM_INFERENCE_STEPS_IMG2IMG @@ -73,11 +74,11 @@ def __init__(self, *args, **kwargs): self._current_prompt_2: str = "" self._current_negative_prompt_2: str = "" self._tokenizer: CLIPTokenizerFast = None - self._generator: torch.Generator = None + self._generator = None self._latents = None self._textual_inversion_manager: DiffusersTextualInversionManager = None self._compel_proc: Compel = None - self._loaded_lora: List = [] + self._loaded_lora: Dict = {} self._disabled_lora: List = [] self._loaded_embeddings: List = [] self._current_state: HandlerState = HandlerState.UNINITIALIZED @@ -238,6 +239,21 @@ def model_path(self) -> str: ) ) + @property + def lora_base_path(self) -> str: + return os.path.expanduser( + os.path.join( + self.path_settings_cached.base_path, + "art/models", + self.generator_settings_cached.version, + "lora" + ) + ) + + @property + def lora_scale(self) -> float: + return self.generator_settings_cached.lora_scale / 100.0 + @property def data_type(self) -> torch.dtype: return torch.float16 @@ -352,10 +368,15 @@ def unload_controlnet(self): return self._unload_controlnet() - def load_stable_diffusion(self): + def reload(self): + self.logger.debug("Reloading stable diffusion") + self.unload() + self.load() + + def load(self): if self.sd_is_loading or self.sd_is_loaded: return - self.unload_stable_diffusion() + self.unload() self.change_model_status(ModelType.SD, ModelStatus.LOADING) self._load_safety_checker() self._load_tokenizer() @@ -370,14 +391,23 @@ def load_stable_diffusion(self): self._make_memory_efficient() self._finalize_load_stable_diffusion() - def unload_stable_diffusion(self): - if self.sd_is_loading or self.sd_is_unloaded: + def unload(self): + if ( + self.sd_is_loading or + self.sd_is_unloaded + ): return + elif self._current_state in ( + HandlerState.PREPARING_TO_GENERATE, + HandlerState.GENERATING + ): + self.interrupt_image_generation() + self.requested_action = ModelAction.CLEAR self.change_model_status(ModelType.SD, ModelStatus.LOADING) self._unload_safety_checker() self._unload_scheduler() self._unload_controlnet() - self._unload_lora() + self._unload_loras() self._unload_emebeddings() self._unload_compel() self._unload_tokenizer() @@ -389,7 +419,7 @@ def unload_stable_diffusion(self): self.change_model_status(ModelType.SD, ModelStatus.UNLOADED) def handle_generate_signal(self, message: dict=None): - self.load_stable_diffusion() + self.load() self._clear_cached_properties() self._swap_pipeline() if self._current_state not in ( @@ -414,15 +444,29 @@ def handle_generate_signal(self, message: dict=None): 'message': response }) self._current_state = HandlerState.READY + clear_memory() + self.handle_requested_action() - def load_lora(self): + def reload_lora(self): + if self.model_status[ModelType.SD] is not ModelStatus.LOADED or self._current_state in ( + HandlerState.PREPARING_TO_GENERATE, + HandlerState.GENERATING + ): + return + self.change_model_status(ModelType.SD, ModelStatus.LOADING) + self._unload_loras() self._load_lora() + self.emit_signal(SignalCode.LORA_UPDATED_SIGNAL) + self.change_model_status(ModelType.SD, ModelStatus.LOADED) def load_embeddings(self): self._load_embeddings() def interrupt_image_generation(self): - if self._current_state == HandlerState.GENERATING: + if self._current_state in ( + HandlerState.PREPARING_TO_GENERATE, + HandlerState.GENERATING + ): self.do_interrupt_image_generation = True def _swap_pipeline(self): @@ -437,8 +481,7 @@ def _generate(self): model = self.generator_settings_cached.model if self._current_model != model: if self._pipe is not None: - self.unload_stable_diffusion() - self.load_stable_diffusion() + self.reload() if self._pipe is None: raise PipeNotLoadedException() self._load_prompt_embeds() @@ -611,13 +654,12 @@ def _load_tokenizer(self): self.logger.error(f"Failed to load tokenizer") self.logger.error(e) - def _load_generator(self, seed=None): + def _load_generator(self): self.logger.debug("Loading generator") - if not self._generator is None: - seed = seed or int(self.generator_settings_cached.seed) + if self._generator is None: + seed = int(self.generator_settings_cached.seed) self._generator = torch.Generator(device=self._device) self._generator.manual_seed(seed) - return self._generator def _load_controlnet(self): if not self.controlnet_enabled or self.controlnet_is_loading: @@ -687,27 +729,28 @@ def _load_scheduler(self): "txt2img/scheduler/scheduler_config.json" ) ) - for scheduler in self.schedulers: - if scheduler.display_name == scheduler_name: - scheduler_name = scheduler.display_name - scheduler_class_name = scheduler.name - scheduler_class = getattr(diffusers, scheduler_class_name) - try: - self.scheduler = scheduler_class.from_pretrained( - scheduler_path, - subfolder="scheduler", - local_files_only=True - ) - self.change_model_status(ModelType.SCHEDULER, ModelStatus.LOADED) - self.current_scheduler_name = scheduler_name - self.logger.debug(f"Loaded scheduler {scheduler_name}") - except Exception as e: - self.logger.error(f"Failed to load scheduler {scheduler_name}: {e}") - self.change_model_status(ModelType.SCHEDULER, ModelStatus.FAILED) - return - if self._pipe: - self._pipe.scheduler = self.scheduler - return scheduler + session = self.db_handler.get_db_session() + scheduler = session.query(Schedulers).filter_by(display_name=scheduler_name).first() + if not scheduler: + self.logger.error(f"Failed to find scheduler {scheduler_name}") + return None + scheduler_class_name = scheduler.name + scheduler_class = getattr(diffusers, scheduler_class_name) + try: + self.scheduler = scheduler_class.from_pretrained( + scheduler_path, + subfolder="scheduler", + local_files_only=True + ) + self.change_model_status(ModelType.SCHEDULER, ModelStatus.LOADED) + self.current_scheduler_name = scheduler_name + self.logger.debug(f"Loaded scheduler {scheduler_name}") + except Exception as e: + self.logger.error(f"Failed to load scheduler {scheduler_name}: {e}") + self.change_model_status(ModelType.SCHEDULER, ModelStatus.FAILED) + return + if self._pipe: + self._pipe.scheduler = self.scheduler def _load_pipe(self): self._current_model = self.generator_settings_cached.model @@ -759,77 +802,54 @@ def _load_pipe(self): self.logger.error(f"Failed to load model to device: {e}") def _load_lora(self): - self._loaded_lora = [] - available_lora = self.get_lora_by_version(self.generator_settings_cached.version) - if len(available_lora) == 0: + session = self.db_handler.get_db_session() + enabled_lora = session.query(Lora).filter_by( + version=self.generator_settings_cached.version, + enabled=True + ).all() + for lora in enabled_lora: + self._load_lora_weights(lora) + + def _load_lora_weights(self, lora: Lora): + if lora in self._disabled_lora or lora.path in self._loaded_lora: return - self._remove_lora_from_pipe() - if self._pipe is not None: - for lora in available_lora: - if lora.enabled: - self._apply_lora(lora, available_lora) - if len(self.loaded_lora): - self.logger.debug("LoRA loaded") - - def _apply_lora(self, lora, available_lora): - if lora.path in self._disabled_lora: - return - - if not self._has_lora_changed(available_lora): - return - do_disable_lora = False - + filename = os.path.basename(lora.path) try: - filename = lora.path.split("/")[-1] - for _lora in self.loaded_lora: - if _lora.name == lora.name and _lora.path == lora.path: - return - base_path:str = self.path_settings_cached.base_path - version:str = self.generator_settings_cached.version - lora_path = os.path.expanduser( - os.path.join( - base_path, - "art/models", - version, - "lora" - ) - ) + lora_base_path = self.lora_base_path + self.logger.info(f"Loading LORA weights from {lora_base_path}/{filename}") + adapter_name = os.path.splitext(filename)[0] self._pipe.load_lora_weights( - lora_path, - weight_name=filename + lora_base_path, + weight_name=filename, + adapter_name=adapter_name ) - self.loaded_lora.append(lora) + self._loaded_lora[lora.path] = lora except AttributeError as _e: self.logger.warning("This model does not support LORA") do_disable_lora = True except RuntimeError: - self.logger.warning("LORA could not be loaded") + self.logger.warning(f"LORA {filename} could not be loaded") do_disable_lora = True except ValueError: - self.logger.warning("LORA could not be loaded") + self.logger.warning(f"LORA {filename} could not be loaded") do_disable_lora = True - if do_disable_lora: - self._disabled_lora.append(lora.path) - - def _has_lora_changed(self, available_lora): - """ - Check if there are any changes in the available LORA compared to the loaded LORA. - Return True if there are changes, otherwise False. - """ - loaded_lora_paths = {lora.path for lora in self.loaded_lora} - - for lora in available_lora: - if lora.enabled and lora.path not in loaded_lora_paths: - return True - - return False - - def _remove_lora_from_pipe(self): - self.loaded_lora = [] - if self._pipe is not None: - self._pipe.unload_lora_weights() + self._disabled_lora.append(lora) + + def _set_lora_adapters(self): + self.logger.debug("Setting LORA adapters") + session = self.db_handler.get_db_session() + loaded_lora_id = [l.id for l in self._loaded_lora.values()] + enabled_lora = session.query(Lora).filter(Lora.id.in_(loaded_lora_id)).all() + adapter_weights = [] + adapter_names = [] + for lora in enabled_lora: + adapter_weights.append(lora.scale / 100.0) + adapter_names.append(os.path.splitext(os.path.basename(lora.path))[0]) + if len(adapter_weights) > 0: + self._pipe.set_adapters(adapter_names, adapter_weights=adapter_weights) + self.logger.debug("LORA adapters set") def _load_embeddings(self): if not self._pipe: @@ -939,7 +959,7 @@ def _finalize_load_stable_diffusion(self): else: self.logger.error("Something went wrong with Stable Diffusion loading") self.change_model_status(ModelType.SD, ModelStatus.FAILED) - self.unload_stable_diffusion() + self.unload() if ( self._controlnet is not None @@ -1097,14 +1117,23 @@ def _unload_controlnet_processor(self): del self._controlnet_processor self._controlnet_processor = None - def _unload_lora(self): + def _unload_loras(self): self.logger.debug("Unloading lora") - self._remove_lora_from_pipe() if self._pipe is not None: self._pipe.unload_lora_weights() - self._loaded_lora = [] + self._loaded_lora = {} self._disabled_lora = [] + def _unload_lora(self, lora:Lora): + if lora.path in self._loaded_lora: + self.logger.debug(f"Unloading LORA {lora.path}") + del self._loaded_lora[lora.path] + if len(self._loaded_lora) > 0: + self._set_lora_adapters() + else: + self._unload_loras() + clear_memory() + def _unload_emebeddings(self): self.logger.debug("Unloading embeddings") self._loaded_embeddings = [] @@ -1248,6 +1277,9 @@ def _prepare_data(self) -> dict: -self.canvas_settings.pos_x, -self.canvas_settings.pos_y ) + + self._set_seed() + args = dict( outpaint_box_rect=active_rect, width=int(self.application_settings_cached.working_width), @@ -1260,6 +1292,12 @@ def _prepare_data(self) -> dict: callback_on_step_end=self.__interrupt_callback, ) + if len(self._loaded_lora) > 0: + args.update(cross_attention_kwargs=dict( + scale=self.lora_scale, + )) + self._set_lora_adapters() + if self.generator_settings_cached.use_compel: args.update(dict( prompt_embeds=self._prompt_embeds, @@ -1341,8 +1379,6 @@ def _prepare_data(self) -> dict: def _set_seed(self): seed = self.generator_settings_cached.seed self._generator.manual_seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) def _callback(self, step: int, _time_step, latents): self.emit_signal(SignalCode.SD_PROGRESS_SIGNAL, { diff --git a/src/airunner/alembic/versions/4626ae0d0601_convert_lora_scale_column_to_integer.py b/src/airunner/alembic/versions/4626ae0d0601_convert_lora_scale_column_to_integer.py new file mode 100644 index 000000000..a1b9f3ac0 --- /dev/null +++ b/src/airunner/alembic/versions/4626ae0d0601_convert_lora_scale_column_to_integer.py @@ -0,0 +1,107 @@ +"""convert lora scale column to integer + +Revision ID: 4626ae0d0601 +Revises: 7e744b48e075 +Create Date: 2024-10-04 08:04:43.962151 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import sqlite +from sqlalchemy.sql import table, column + +# revision identifiers, used by Alembic. +revision: str = '4626ae0d0601' +down_revision: Union[str, None] = '7e744b48e075' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade(): + # Check if the scale_temp column already exists + conn = op.get_bind() + inspector = sa.inspect(conn) + columns = [col['name'] for col in inspector.get_columns('lora')] + + if 'scale_temp' not in columns: + # Create a temporary column to store the integer values + op.add_column('lora', sa.Column('scale_temp', sa.Integer(), nullable=True)) + + # Convert existing float values to integers and store them in the temporary column + lora_table = table('lora', column('scale', sa.Float), column('scale_temp', sa.Integer)) + op.execute( + lora_table.update().values( + scale_temp=(lora_table.c.scale * 100).cast(sa.Integer) + ) + ) + + # Create a new table with the desired schema + op.create_table( + 'lora_new', + sa.Column('id', sa.Integer, primary_key=True, autoincrement=True), + sa.Column('name', sa.String, nullable=False), + sa.Column('scale', sa.Integer, nullable=False), + sa.Column('enabled', sa.Boolean, nullable=False), + sa.Column('loaded', sa.Boolean, default=False, nullable=False), + sa.Column('trigger_word', sa.String, nullable=True), + sa.Column('path', sa.String, nullable=True), + sa.Column('version', sa.String, nullable=True) + ) + + # Copy data from the old table to the new table + op.execute( + 'INSERT INTO lora_new (id, name, scale, enabled, loaded, trigger_word, path, version) ' + 'SELECT id, name, scale_temp, enabled, loaded, trigger_word, path, version FROM lora' + ) + + # Drop the old table + op.drop_table('lora') + + # Rename the new table to the original table name + op.rename_table('lora_new', 'lora') + + +def downgrade(): + # Check if the scale_temp column already exists + conn = op.get_bind() + inspector = sa.inspect(conn) + columns = [col['name'] for col in inspector.get_columns('lora')] + + if 'scale_temp' not in columns: + # Create a temporary column to store the float values + op.add_column('lora', sa.Column('scale_temp', sa.Float(), nullable=True)) + + # Convert existing integer values to floats and store them in the temporary column + lora_table = table('lora', column('scale', sa.Integer), column('scale_temp', sa.Float)) + op.execute( + lora_table.update().values( + scale_temp=(lora_table.c.scale / 100).cast(sa.Float) + ) + ) + + # Create a new table with the desired schema + op.create_table( + 'lora_new', + sa.Column('id', sa.Integer, primary_key=True, autoincrement=True), + sa.Column('name', sa.String, nullable=False), + sa.Column('scale', sa.Float, nullable=False), + sa.Column('enabled', sa.Boolean, nullable=False), + sa.Column('loaded', sa.Boolean, default=False, nullable=False), + sa.Column('trigger_word', sa.String, nullable=True), + sa.Column('path', sa.String, nullable=True), + sa.Column('version', sa.String, nullable=True) + ) + + # Copy data from the old table to the new table + op.execute( + 'INSERT INTO lora_new (id, name, scale, enabled, loaded, trigger_word, path, version) ' + 'SELECT id, name, scale_temp, enabled, loaded, trigger_word, path, version FROM lora' + ) + + # Drop the old table + op.drop_table('lora') + + # Rename the new table to the original table name + op.rename_table('lora_new', 'lora') diff --git a/src/airunner/alembic/versions/8ebcece37db8_add_lora_scale_to_generator_settings.py b/src/airunner/alembic/versions/8ebcece37db8_add_lora_scale_to_generator_settings.py new file mode 100644 index 000000000..4a75ec95b --- /dev/null +++ b/src/airunner/alembic/versions/8ebcece37db8_add_lora_scale_to_generator_settings.py @@ -0,0 +1,26 @@ +"""Add lora_scale to generator_settings + +Revision ID: 8ebcece37db8 +Revises: 4626ae0d0601 +Create Date: 2024-10-04 10:16:43.172811 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import sqlite + +# revision identifiers, used by Alembic. +revision: str = '8ebcece37db8' +down_revision: Union[str, None] = '4626ae0d0601' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade(): + op.add_column('generator_settings', sa.Column('lora_scale', sa.Integer, default=100)) + +def downgrade(): + op.drop_column('generator_settings', 'lora_scale') + # ### end Alembic commands ### diff --git a/src/airunner/enums.py b/src/airunner/enums.py index b4979c217..cdfd163a4 100644 --- a/src/airunner/enums.py +++ b/src/airunner/enums.py @@ -152,6 +152,7 @@ class SignalCode(Enum): LLM_PROCESS_STT_AUDIO_SIGNAL = "llm_process_stt_audio" LORA_ADD_SIGNAL = "add_lora_signal" LORA_UPDATE_SIGNAL = "update_lora_signal" + LORA_UPDATED_SIGNAL = "lora_updated_signal" LORA_DELETE_SIGNAL = "delete_lora_signal" SET_CANVAS_COLOR_SIGNAL = "set_canvas_color_signal" UPDATE_SCENE_SIGNAL = "update_scene_signal" @@ -234,6 +235,7 @@ class SignalCode(Enum): CONVERSATION_DELETED = enum.auto() KEYBOARD_SHORTCUTS_UPDATED = enum.auto() + LORA_STATUS_CHANGED = enum.auto() class EngineResponseCode(Enum): STATUS = 100 diff --git a/src/airunner/styles/dark_theme/styles.qss b/src/airunner/styles/dark_theme/styles.qss index 2819a3184..8a426444a 100644 --- a/src/airunner/styles/dark_theme/styles.qss +++ b/src/airunner/styles/dark_theme/styles.qss @@ -94,6 +94,14 @@ QCheckBox:indicator:checked { height: 12px; } +QCheckBox:disabled { + color: #333; +} + +QCheckBox:indicator:disabled { + border: 1px solid #333; +} + QPushButton { border-radius: none; padding: 4px 4px 3px 4px; diff --git a/src/airunner/widgets/llm/loading_widget.py b/src/airunner/widgets/llm/loading_widget.py index 5e1228582..0fc74d224 100644 --- a/src/airunner/widgets/llm/loading_widget.py +++ b/src/airunner/widgets/llm/loading_widget.py @@ -16,4 +16,10 @@ def __init__(self, *args, **kwargs): movie = QMovie(os.path.join(HERE, "../../icons/dark/Spinner-1s-200px.gif")) movie.setScaledSize(QSize(64, 64)) # Resize the GIF self.ui.label.setMovie(movie) # Set the QMovie object to the label - movie.start() # Start the animation \ No newline at end of file + movie.start() # Start the animation + + def set_size(self, spinner_size: QSize, label_size: QSize): + self.ui.label.movie().setScaledSize(spinner_size) + self.ui.label.setFixedSize(label_size) + self.ui.label.update() + self.ui.label.repaint() diff --git a/src/airunner/widgets/llm/templates/loading.ui b/src/airunner/widgets/llm/templates/loading.ui index accbdfcde..2546b6647 100644 --- a/src/airunner/widgets/llm/templates/loading.ui +++ b/src/airunner/widgets/llm/templates/loading.ui @@ -14,13 +14,28 @@ Form + + 0 + + + 0 + + + 0 + + + 0 + + + 0 + - Qt::AlignCenter + Qt::AlignmentFlag::AlignCenter diff --git a/src/airunner/widgets/llm/templates/loading_ui.py b/src/airunner/widgets/llm/templates/loading_ui.py index ea2905cdb..2eb946005 100644 --- a/src/airunner/widgets/llm/templates/loading_ui.py +++ b/src/airunner/widgets/llm/templates/loading_ui.py @@ -24,10 +24,12 @@ def setupUi(self, loading_message): loading_message.setObjectName(u"loading_message") loading_message.resize(605, 707) self.gridLayout = QGridLayout(loading_message) + self.gridLayout.setSpacing(0) self.gridLayout.setObjectName(u"gridLayout") + self.gridLayout.setContentsMargins(0, 0, 0, 0) self.label = QLabel(loading_message) self.label.setObjectName(u"label") - self.label.setAlignment(Qt.AlignCenter) + self.label.setAlignment(Qt.AlignmentFlag.AlignCenter) self.gridLayout.addWidget(self.label, 0, 0, 1, 1) diff --git a/src/airunner/widgets/lora/lora_container_widget.py b/src/airunner/widgets/lora/lora_container_widget.py index 963342410..5f5808c12 100644 --- a/src/airunner/widgets/lora/lora_container_widget.py +++ b/src/airunner/widgets/lora/lora_container_widget.py @@ -1,9 +1,9 @@ import os -from PySide6.QtCore import Slot +from PySide6.QtCore import Slot, QSize from PySide6.QtWidgets import QWidget, QSizePolicy -from airunner.enums import SignalCode +from airunner.enums import SignalCode, ModelType, ModelStatus from airunner.utils.models.scan_path_for_items import scan_path_for_lora from airunner.widgets.base_widget import BaseWidget from airunner.widgets.lora.lora_widget import LoraWidget @@ -13,21 +13,83 @@ class LoraContainerWidget(BaseWidget): widget_class_ = Ui_lora_container lora_loaded = False - total_lora_by_section = {} + # total_lora_by_section = {} search_filter = "" spacer = None def __init__(self, *args, **kwargs): + self._version = None super().__init__(*args, **kwargs) - self.loras = None self.initialized = False - self._version = None self.register(SignalCode.APPLICATION_SETTINGS_CHANGED_SIGNAL, self.on_application_settings_changed_signal) + self.register(SignalCode.LORA_UPDATED_SIGNAL, self.on_lora_updated_signal) + self.register(SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal) + self.register(SignalCode.LORA_STATUS_CHANGED, self.on_lora_modified) + self.register(SignalCode.LORA_DELETE_SIGNAL, self.on_lora_modified) + self.ui.loading_icon.hide() + self.ui.loading_icon.set_size(spinner_size=QSize(30, 30), label_size=QSize(24, 24)) + self._apply_button_enabled = False + self.ui.apply_lora_button.setEnabled(self._apply_button_enabled) + + @Slot() + def scan_for_lora(self): + # clear all lora widgets + loras = scan_path_for_lora(self.path_settings.base_path) + self.update_loras(loras) + self.load_lora() + + @Slot() + def apply_lora(self): + self._apply_button_enabled = False + self.ui.apply_lora_button.setEnabled(self._apply_button_enabled) + self.emit_signal(SignalCode.LORA_UPDATE_SIGNAL) + self._disable_form() + + def on_lora_modified(self): + self._apply_button_enabled = True + self.ui.apply_lora_button.setEnabled(self._apply_button_enabled) + + def on_model_status_changed_signal(self, data): + model = data["model"] + status = data["status"] + if model is ModelType.SD: + if status is ModelStatus.LOADING: + self.ui.loading_icon.show() + self._disable_form() + else: + self.ui.loading_icon.hide() + self._enable_form() + + def _disable_form(self): + self.ui.apply_lora_button.setEnabled(self._apply_button_enabled) + self.ui.lora_scale_slider.setEnabled(False) + self.ui.toggleAllLora.setEnabled(False) + self.ui.loading_icon.show() + self._toggle_lora_widgets(False) + + def _enable_form(self): + self.ui.apply_lora_button.setEnabled(self._apply_button_enabled) + self.ui.lora_scale_slider.setEnabled(True) + self.ui.toggleAllLora.setEnabled(True) + self.ui.loading_icon.hide() + self._toggle_lora_widgets(True) + + def _toggle_lora_widgets(self, enable: bool): + for i in range(self.ui.lora_scroll_area.widget().layout().count()): + lora_widget = self.ui.lora_scroll_area.widget().layout().itemAt(i).widget() + if isinstance(lora_widget, LoraWidget): + if enable: + lora_widget.enable_lora_widget() + else: + lora_widget.disable_lora_widget() def on_application_settings_changed_signal(self): self.load_lora() + def on_lora_updated_signal(self): + self._enable_form() + def toggle_all(self, val): lora_widgets = [ self.ui.lora_scroll_area.widget().layout().itemAt(i).widget() @@ -36,18 +98,15 @@ def toggle_all(self, val): ] for lora_widget in lora_widgets: lora_widget.ui.enabledCheckbox.blockSignals(True) - lora_widget.action_toggled_lora_enabled(val, emit_signal=False) + lora_widget.action_toggled_lora_enabled(val) lora_widget.ui.enabledCheckbox.blockSignals(False) - for lora in self.lora: - lora.enabled = val - self.update_lora(lora) - self.emit_signal(SignalCode.LORA_UPDATE_SIGNAL) def showEvent(self, event): if not self.initialized: self.register(SignalCode.LORA_DELETE_SIGNAL, self.delete_lora) self.scan_for_lora() self.initialized = True + self.load_lora() def load_lora(self): version = self.generator_settings.version @@ -111,13 +170,6 @@ def delete_lora(self, data: dict): os.remove(os.path.join(dirpath, file)) break - @Slot() - def scan_for_lora(self): - # clear all lora widgets - loras = scan_path_for_lora(self.path_settings.base_path) - self.update_loras(loras) - self.load_lora() - def available_lora(self, action): available_lora = [] for lora in self.lora: @@ -125,83 +177,81 @@ def available_lora(self, action): available_lora.append(lora) return available_lora - def get_available_loras(self, tab_name): - lora_path = os.path.expanduser( - os.path.join( - self.path_settings.base_path, - "art/models", - self._version, - "lora" - ) - ) - if not os.path.exists(lora_path): - return [] - available_lora = self.get_list_of_available_loras(tab_name, lora_path, lora_names=self.lora) - return available_lora - - def get_list_of_available_loras(self, tab_name, lora_path, lora_names=None): - self.total_lora_by_section = { - "total": 0, - "enabled": 0 - } - - if lora_names is None: - lora_names = [] - if not os.path.exists(lora_path): - return lora_names - possible_line_endings = ["ckpt", "safetensors", "bin"] - new_loras = [] - from airunner.aihandler.models.settings_models import Lora - - for lora_file in os.listdir(lora_path): - if os.path.isdir(os.path.join(lora_path, lora_file)): - lora_names = self.get_list_of_available_loras(tab_name, os.path.join(lora_path, lora_file), lora_names) - if lora_file.split(".")[-1] in possible_line_endings: - name = lora_file.split(".")[0] - scale = 100.0 - enabled = True - trigger_word = "" - for lora in self.lora: - if lora.name == name: - scale = lora.scale - enabled = lora.enabled - trigger_word = lora.trigger_word if trigger_word in lora else "" - self.total_lora_by_section["total"] += 1 - if enabled: - self.total_lora_by_section["enabled"] += 1 - break - new_lora = Lora( - name=name, - scale=scale, - enabled=enabled, - loaded=False, - trigger_word=trigger_word - ) - self.create_lora(new_lora) - - # check if name already in lora_names: - for old_lora in lora_names: - name = old_lora["name"] - found = False - for new_lora in new_loras: - if new_lora.name == name: - found = True - break - if not found: - lora_names.remove(old_lora) - merge_lora = [] - for new_lora in new_loras: - name = new_lora.name - found = False - for current_lora in lora_names: - if current_lora["name"] == name: - found = True - if not found: - merge_lora.append(new_lora) - lora_names.extend(merge_lora) - return lora_names - - lora_tab_container = None + # def get_available_loras(self, tab_name): + # lora_path = os.path.expanduser( + # os.path.join( + # self.path_settings.base_path, + # "art/models", + # self._version, + # "lora" + # ) + # ) + # if not os.path.exists(lora_path): + # return [] + # available_lora = self.get_list_of_available_loras(tab_name, lora_path, lora_names=self.lora) + # return available_lora + + # def get_list_of_available_loras(self, tab_name, lora_path, lora_names=None): + # self.total_lora_by_section = { + # "total": 0, + # "enabled": 0 + # } + # + # if lora_names is None: + # lora_names = [] + # if not os.path.exists(lora_path): + # return lora_names + # possible_line_endings = ["ckpt", "safetensors", "bin"] + # new_loras = [] + # from airunner.aihandler.models.settings_models import Lora + # + # for lora_file in os.listdir(lora_path): + # if os.path.isdir(os.path.join(lora_path, lora_file)): + # lora_names = self.get_list_of_available_loras(tab_name, os.path.join(lora_path, lora_file), lora_names) + # if lora_file.split(".")[-1] in possible_line_endings: + # name = lora_file.split(".")[0] + # scale = 100.0 + # enabled = True + # trigger_word = "" + # for lora in self.lora: + # if lora.name == name: + # scale = lora.scale + # enabled = lora.enabled + # trigger_word = lora.trigger_word if trigger_word in lora else "" + # self.total_lora_by_section["total"] += 1 + # if enabled: + # self.total_lora_by_section["enabled"] += 1 + # break + # new_lora = Lora( + # name=name, + # scale=scale, + # enabled=enabled, + # loaded=False, + # trigger_word=trigger_word + # ) + # self.create_lora(new_lora) + # + # # check if name already in lora_names: + # for old_lora in lora_names: + # name = old_lora["name"] + # found = False + # for new_lora in new_loras: + # if new_lora.name == name: + # found = True + # break + # if not found: + # lora_names.remove(old_lora) + # merge_lora = [] + # for new_lora in new_loras: + # name = new_lora.name + # found = False + # for current_lora in lora_names: + # if current_lora["name"] == name: + # found = True + # if not found: + # merge_lora.append(new_lora) + # lora_names.extend(merge_lora) + # return lora_names def initialize_lora_trigger_words(self): for lora in self.lora: @@ -234,20 +284,10 @@ def toggle_lora(self, lora, value, tab_name): lora_object = self.loras[n] lora_object.enabled = value == 2 self.update_lora(lora_object) - if value == 2: - self.total_lora_by_section["enabled"] += 1 - else: - self.total_lora_by_section["enabled"] -= 1 - self.update_lora_tab_name(tab_name) - - def update_lora_tab_name(self, tab_name): - # if tab_name not in self.total_lora_by_section: - # self.total_lora_by_section[tab_name] = {"total": 0, "enabled": 0} - # self.tabs[tab_name].PromptTabsSection.setTabText( - # 2, - # f'LoRA ({self.total_lora_by_section[tab_name]["enabled"]}/{self.total_lora_by_section[tab_name]["total"]})' - # ) - pass + # if value == 2: + # self.total_lora_by_section["enabled"] += 1 + # else: + # self.total_lora_by_section["enabled"] -= 1 def handle_lora_slider(self, lora, lora_widget, value, tab_name): float_val = value / 100 @@ -268,6 +308,7 @@ def handle_lora_spinbox(self, lora, lora_widget, value, tab_name): def search_text_changed(self, val): self.search_filter = val + self._version = None self.load_lora() def clear_lora_widgets(self): diff --git a/src/airunner/widgets/lora/lora_widget.py b/src/airunner/widgets/lora/lora_widget.py index b1241864f..7ea53e2b5 100644 --- a/src/airunner/widgets/lora/lora_widget.py +++ b/src/airunner/widgets/lora/lora_widget.py @@ -1,5 +1,6 @@ -from PySide6.QtCore import QTimer +from PySide6.QtCore import QTimer, Slot +from airunner.aihandler.models.settings_models import Lora from airunner.enums import SignalCode from airunner.widgets.base_widget import BaseWidget from airunner.widgets.lora.lora_trigger_word_widget import LoraTriggerWordWidget @@ -18,7 +19,7 @@ def __init__(self, *args, **kwargs): self.icons = [ ("recycle-bin-line-icon", "delete_button"), ] - self.current_lora = kwargs.pop("lora", None) + self.current_lora:Lora = kwargs.pop("lora") super().__init__(*args, **kwargs) name = self.current_lora.name enabled = self.current_lora.enabled @@ -28,7 +29,7 @@ def __init__(self, *args, **kwargs): self.ui.enabledCheckbox.blockSignals(True) self.ui.trigger_word_edit.blockSignals(True) - self.ui.enabledCheckbox.setTitle(name) + self.ui.enabledCheckbox.setText(name) self.ui.enabledCheckbox.setChecked(enabled) self.ui.trigger_word_edit.setText(trigger_word) @@ -39,6 +40,41 @@ def __init__(self, *args, **kwargs): # Defer the creation of trigger word widgets self.create_trigger_word_widgets(self.current_lora, defer=True) + self.ui.scale_slider.setProperty("table_id", self.current_lora.id) + + def disable_lora_widget(self): + self.ui.enabledCheckbox.setEnabled(False) + self.ui.delete_button.setEnabled(False) + self.ui.scale_slider.setEnabled(False) + + def enable_lora_widget(self): + self.ui.enabledCheckbox.setEnabled(True) + self.ui.delete_button.setEnabled(True) + self.ui.scale_slider.setEnabled(True) + + @Slot(bool) + def action_toggled_lora_enabled(self, val): + self.current_lora.enabled = val + self.ui.enabledCheckbox.blockSignals(True) + self.ui.enabledCheckbox.setChecked(val) + self.ui.enabledCheckbox.blockSignals(False) + self.update_lora(self.current_lora) + self.emit_signal(SignalCode.LORA_STATUS_CHANGED) + + @Slot(str) + def action_text_changed_trigger_word(self, val): + self.current_lora.trigger_word = val + self.update_lora(self.current_lora) + + @Slot() + def action_clicked_button_deleted(self): + self.emit_signal( + SignalCode.LORA_DELETE_SIGNAL, + { + "lora_widget": self + } + ) + def create_trigger_word_widgets(self, lora, defer=False): if defer: # Defer the creation of trigger word widgets @@ -47,36 +83,17 @@ def create_trigger_word_widgets(self, lora, defer=False): self._create_trigger_word_widgets(lora) def _create_trigger_word_widgets(self, lora): - for i in reversed(range(self.ui.enabledCheckbox.layout().count())): - widget = self.ui.enabledCheckbox.layout().itemAt(i).widget() + for i in reversed(range(self.ui.lora_container.layout().count())): + widget = self.ui.lora_container.layout().itemAt(i).widget() if isinstance(widget, LoraTriggerWordWidget): widget.deleteLater() for word in lora.trigger_word.split(","): if word.strip() == "": continue widget = LoraTriggerWordWidget(trigger_word=word) - self.ui.enabledCheckbox.layout().addWidget(widget) + self.ui.lora_container.layout().addWidget(widget) def action_changed_trigger_words(self, val): self.current_lora.trigger_word = val self.update_lora(self.current_lora) self.create_trigger_word_widgets(self.current_lora) - - def action_toggled_lora_enabled(self, val, emit_signal=True): - self.ui.enabledCheckbox.setChecked(val) - self.current_lora.enabled = val - self.update_lora(self.current_lora) - if emit_signal: - self.emit_signal(SignalCode.LORA_UPDATE_SIGNAL) - - def action_text_changed_trigger_word(self, val): - self.current_lora.trigger_word = val - self.update_lora(self.current_lora) - - def action_clicked_button_deleted(self): - self.emit_signal( - SignalCode.LORA_DELETE_SIGNAL, - { - "lora_widget": self - } - ) diff --git a/src/airunner/widgets/lora/templates/lora.ui b/src/airunner/widgets/lora/templates/lora.ui index 4fe4363cc..926dae820 100644 --- a/src/airunner/widgets/lora/templates/lora.ui +++ b/src/airunner/widgets/lora/templates/lora.ui @@ -6,8 +6,8 @@ 0 0 - 437 - 49 + 518 + 80 @@ -52,13 +52,7 @@ 0 - - - LoRA name here - - - true - + 0 @@ -72,10 +66,20 @@ 0 - - 0 + + 5 + + + 10 + + + enabledCheckbox + + + + <html><head/><body><p>Some LoRA require a trigger word to activate.</p><p>Make a note here for your records.</p></body></html> @@ -85,7 +89,7 @@ - + @@ -114,6 +118,58 @@ + + + + + 0 + 0 + + + + + 0 + 0 + + + + 0 + + + 100 + + + 0.000000000000000 + + + 1.000000000000000 + + + true + + + lora.scale + + + 1 + + + 5 + + + 0.010000000000000 + + + 0.100000000000000 + + + Scale + + + 0 + + + @@ -121,6 +177,14 @@ + + + SliderWidget + QWidget +
airunner/widgets/slider/slider_widget
+ 1 +
+
@@ -132,12 +196,12 @@ action_changed_trigger_words(QString) - 137 - 66 + 401 + 43 - 127 - 0 + 404 + -14 @@ -148,47 +212,15 @@ action_toggled_lora_enabled(bool) - 47 - 33 + 99 + 2 - 45 + 88 0 - - trigger_word_edit - textEdited(QString) - lora - action_text_changed_trigger_word(QString) - - - 215 - 66 - - - 412 - 0 - - - - - delete_button - clicked() - lora - action_clicked_button_deleted() - - - 239 - 54 - - - 261 - -6 - - - action_toggled_lora_enabled(bool) diff --git a/src/airunner/widgets/lora/templates/lora_container.ui b/src/airunner/widgets/lora/templates/lora_container.ui index 4ec9c4cd3..be50a859d 100644 --- a/src/airunner/widgets/lora/templates/lora_container.ui +++ b/src/airunner/widgets/lora/templates/lora_container.ui @@ -32,61 +32,14 @@ 10 - - - - - true - - - - Lora - - - - - - - 10 - - - 10 - - - - - Search - - - - - - - - 9 - - - - Toggle all - - - false - - - false - - - - - - + Scan for LoRA - + @@ -108,7 +61,7 @@ 0 0 583 - 719 + 665 @@ -134,16 +87,200 @@ - - - - Qt::Orientation::Horizontal + + + + 10 - + + 0 + + + 0 + + + 10 + + + 10 + + + + + 10 + + + + + + true + + + + Lora + + + + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + 0 + 0 + + + + + 0 + 0 + + + + + + + + Apply Lora + + + + + + + + + Qt::Orientation::Horizontal + + + + + + + + 0 + 0 + + + + + 0 + 0 + + + + 0 + + + 100 + + + 0.000000000000000 + + + 1.000000000000000 + + + true + + + generator_settings.lora_scale + + + 1 + + + 5 + + + 0.010000000000000 + + + 0.100000000000000 + + + Scale + + + + + + + 10 + + + 0 + + + 10 + + + + + Search + + + + + + + + 9 + + + + Toggle all + + + false + + + false + + + + + + + + + Qt::Orientation::Horizontal + + + + - + + + SliderWidget + QWidget +
airunner/widgets/slider/slider_widget
+ 1 +
+ + LoadingWidget + QWidget +
airunner/widgets/llm/loading_widget
+ 1 +
+
+ + + toggleAllLora @@ -152,8 +289,8 @@ toggle_all(bool) - 547 - 17 + 562 + 91 52 @@ -168,8 +305,8 @@ search_text_changed(QString) - 190 - 20 + 192 + 92 277 @@ -185,7 +322,7 @@ 399 - 752 + 830 486 @@ -193,10 +330,27 @@ + + apply_lora_button + clicked() + lora_container + apply_lora() + + + 532 + 17 + + + 474 + -7 + + + toggle_all(bool) search_text_changed(QString) scan_for_lora() + apply_lora() diff --git a/src/airunner/widgets/lora/templates/lora_container_ui.py b/src/airunner/widgets/lora/templates/lora_container_ui.py index eb19f5087..6aa62b495 100644 --- a/src/airunner/widgets/lora/templates/lora_container_ui.py +++ b/src/airunner/widgets/lora/templates/lora_container_ui.py @@ -17,7 +17,12 @@ QPalette, QPixmap, QRadialGradient, QTransform) from PySide6.QtWidgets import (QApplication, QCheckBox, QFrame, QGridLayout, QHBoxLayout, QLabel, QLineEdit, QPushButton, - QScrollArea, QSizePolicy, QWidget) + QScrollArea, QSizePolicy, QSpacerItem, QVBoxLayout, + QWidget) + +from airunner.widgets.llm.loading_widget import LoadingWidget +from airunner.widgets.slider.slider_widget import SliderWidget +import airunner.resources_light_rc class Ui_lora_container(object): def setupUi(self, lora_container): @@ -29,50 +34,22 @@ def setupUi(self, lora_container): self.gridLayout.setHorizontalSpacing(0) self.gridLayout.setVerticalSpacing(10) self.gridLayout.setContentsMargins(0, 0, 0, 0) - self.label = QLabel(lora_container) - self.label.setObjectName(u"label") - font = QFont() - font.setBold(True) - self.label.setFont(font) - - self.gridLayout.addWidget(self.label, 0, 0, 1, 1) - - self.horizontalLayout = QHBoxLayout() - self.horizontalLayout.setSpacing(10) - self.horizontalLayout.setObjectName(u"horizontalLayout") - self.horizontalLayout.setContentsMargins(-1, -1, 10, -1) - self.lineEdit = QLineEdit(lora_container) - self.lineEdit.setObjectName(u"lineEdit") - - self.horizontalLayout.addWidget(self.lineEdit) - - self.toggleAllLora = QCheckBox(lora_container) - self.toggleAllLora.setObjectName(u"toggleAllLora") - font1 = QFont() - font1.setPointSize(9) - self.toggleAllLora.setFont(font1) - self.toggleAllLora.setChecked(False) - self.toggleAllLora.setTristate(False) - - self.horizontalLayout.addWidget(self.toggleAllLora) - - - self.gridLayout.addLayout(self.horizontalLayout, 2, 0, 1, 1) - self.pushButton = QPushButton(lora_container) self.pushButton.setObjectName(u"pushButton") - self.gridLayout.addWidget(self.pushButton, 4, 0, 1, 1) + self.gridLayout.addWidget(self.pushButton, 8, 0, 1, 1) self.lora_scroll_area = QScrollArea(lora_container) self.lora_scroll_area.setObjectName(u"lora_scroll_area") - self.lora_scroll_area.setFont(font1) + font = QFont() + font.setPointSize(9) + self.lora_scroll_area.setFont(font) self.lora_scroll_area.setFrameShape(QFrame.Shape.NoFrame) self.lora_scroll_area.setFrameShadow(QFrame.Shadow.Plain) self.lora_scroll_area.setWidgetResizable(True) self.scrollAreaWidgetContents = QWidget() self.scrollAreaWidgetContents.setObjectName(u"scrollAreaWidgetContents") - self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 583, 719)) + self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 583, 665)) self.gridLayout_2 = QGridLayout(self.scrollAreaWidgetContents) self.gridLayout_2.setObjectName(u"gridLayout_2") self.gridLayout_2.setHorizontalSpacing(0) @@ -80,29 +57,118 @@ def setupUi(self, lora_container): self.gridLayout_2.setContentsMargins(0, 0, 10, 0) self.lora_scroll_area.setWidget(self.scrollAreaWidgetContents) - self.gridLayout.addWidget(self.lora_scroll_area, 3, 0, 1, 1) + self.gridLayout.addWidget(self.lora_scroll_area, 7, 0, 1, 1) + + self.verticalLayout = QVBoxLayout() + self.verticalLayout.setSpacing(10) + self.verticalLayout.setObjectName(u"verticalLayout") + self.verticalLayout.setContentsMargins(0, 0, 10, 10) + self.horizontalLayout_2 = QHBoxLayout() + self.horizontalLayout_2.setObjectName(u"horizontalLayout_2") + self.horizontalLayout_2.setContentsMargins(-1, -1, 10, -1) + self.label = QLabel(lora_container) + self.label.setObjectName(u"label") + font1 = QFont() + font1.setBold(True) + self.label.setFont(font1) + + self.horizontalLayout_2.addWidget(self.label) + + self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + + self.horizontalLayout_2.addItem(self.horizontalSpacer) + + self.loading_icon = LoadingWidget(lora_container) + self.loading_icon.setObjectName(u"loading_icon") + sizePolicy = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.loading_icon.sizePolicy().hasHeightForWidth()) + self.loading_icon.setSizePolicy(sizePolicy) + self.loading_icon.setMinimumSize(QSize(0, 0)) + + self.horizontalLayout_2.addWidget(self.loading_icon) + + self.apply_lora_button = QPushButton(lora_container) + self.apply_lora_button.setObjectName(u"apply_lora_button") + + self.horizontalLayout_2.addWidget(self.apply_lora_button) + + + self.verticalLayout.addLayout(self.horizontalLayout_2) self.line = QFrame(lora_container) self.line.setObjectName(u"line") self.line.setFrameShape(QFrame.Shape.HLine) self.line.setFrameShadow(QFrame.Shadow.Sunken) - self.gridLayout.addWidget(self.line, 1, 0, 1, 1) + self.verticalLayout.addWidget(self.line) + + self.lora_scale_slider = SliderWidget(lora_container) + self.lora_scale_slider.setObjectName(u"lora_scale_slider") + sizePolicy.setHeightForWidth(self.lora_scale_slider.sizePolicy().hasHeightForWidth()) + self.lora_scale_slider.setSizePolicy(sizePolicy) + self.lora_scale_slider.setMinimumSize(QSize(0, 0)) + self.lora_scale_slider.setProperty("slider_minimum", 0) + self.lora_scale_slider.setProperty("slider_maximum", 100) + self.lora_scale_slider.setProperty("spinbox_minimum", 0.000000000000000) + self.lora_scale_slider.setProperty("spinbox_maximum", 1.000000000000000) + self.lora_scale_slider.setProperty("display_as_float", True) + self.lora_scale_slider.setProperty("slider_single_step", 1) + self.lora_scale_slider.setProperty("slider_page_step", 5) + self.lora_scale_slider.setProperty("spinbox_single_step", 0.010000000000000) + self.lora_scale_slider.setProperty("spinbox_page_step", 0.100000000000000) + + self.verticalLayout.addWidget(self.lora_scale_slider) + + self.horizontalLayout = QHBoxLayout() + self.horizontalLayout.setSpacing(10) + self.horizontalLayout.setObjectName(u"horizontalLayout") + self.horizontalLayout.setContentsMargins(0, -1, 10, -1) + self.lineEdit = QLineEdit(lora_container) + self.lineEdit.setObjectName(u"lineEdit") + + self.horizontalLayout.addWidget(self.lineEdit) + + self.toggleAllLora = QCheckBox(lora_container) + self.toggleAllLora.setObjectName(u"toggleAllLora") + self.toggleAllLora.setFont(font) + self.toggleAllLora.setChecked(False) + self.toggleAllLora.setTristate(False) + + self.horizontalLayout.addWidget(self.toggleAllLora) + + + self.verticalLayout.addLayout(self.horizontalLayout) + + self.line_2 = QFrame(lora_container) + self.line_2.setObjectName(u"line_2") + self.line_2.setFrameShape(QFrame.Shape.HLine) + self.line_2.setFrameShadow(QFrame.Shadow.Sunken) + + self.verticalLayout.addWidget(self.line_2) + + + self.gridLayout.addLayout(self.verticalLayout, 2, 0, 1, 1) self.retranslateUi(lora_container) self.toggleAllLora.toggled.connect(lora_container.toggle_all) self.lineEdit.textEdited.connect(lora_container.search_text_changed) self.pushButton.clicked.connect(lora_container.scan_for_lora) + self.apply_lora_button.clicked.connect(lora_container.apply_lora) QMetaObject.connectSlotsByName(lora_container) # setupUi def retranslateUi(self, lora_container): lora_container.setWindowTitle(QCoreApplication.translate("lora_container", u"Form", None)) + self.pushButton.setText(QCoreApplication.translate("lora_container", u"Scan for LoRA", None)) self.label.setText(QCoreApplication.translate("lora_container", u"Lora", None)) + self.apply_lora_button.setText(QCoreApplication.translate("lora_container", u"Apply Lora", None)) + self.lora_scale_slider.setProperty("settings_property", QCoreApplication.translate("lora_container", u"generator_settings.lora_scale", None)) + self.lora_scale_slider.setProperty("label_text", QCoreApplication.translate("lora_container", u"Scale", None)) self.lineEdit.setPlaceholderText(QCoreApplication.translate("lora_container", u"Search", None)) self.toggleAllLora.setText(QCoreApplication.translate("lora_container", u"Toggle all", None)) - self.pushButton.setText(QCoreApplication.translate("lora_container", u"Scan for LoRA", None)) # retranslateUi diff --git a/src/airunner/widgets/lora/templates/lora_ui.py b/src/airunner/widgets/lora/templates/lora_ui.py index 446d8bcf4..c0a643f39 100644 --- a/src/airunner/widgets/lora/templates/lora_ui.py +++ b/src/airunner/widgets/lora/templates/lora_ui.py @@ -15,15 +15,17 @@ QFont, QFontDatabase, QGradient, QIcon, QImage, QKeySequence, QLinearGradient, QPainter, QPalette, QPixmap, QRadialGradient, QTransform) -from PySide6.QtWidgets import (QApplication, QGridLayout, QGroupBox, QLineEdit, +from PySide6.QtWidgets import (QApplication, QCheckBox, QGridLayout, QLineEdit, QPushButton, QSizePolicy, QWidget) + +from airunner.widgets.slider.slider_widget import SliderWidget import airunner.resources_light_rc class Ui_lora(object): def setupUi(self, lora): if not lora.objectName(): lora.setObjectName(u"lora") - lora.resize(437, 49) + lora.resize(518, 80) sizePolicy = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) @@ -40,19 +42,24 @@ def setupUi(self, lora): self.gridLayout = QGridLayout() self.gridLayout.setObjectName(u"gridLayout") self.gridLayout.setHorizontalSpacing(0) - self.enabledCheckbox = QGroupBox(lora) - self.enabledCheckbox.setObjectName(u"enabledCheckbox") - self.enabledCheckbox.setCheckable(True) - self.gridLayout_3 = QGridLayout(self.enabledCheckbox) - self.gridLayout_3.setSpacing(0) + self.lora_container = QWidget(lora) + self.lora_container.setObjectName(u"lora_container") + self.gridLayout_3 = QGridLayout(self.lora_container) self.gridLayout_3.setObjectName(u"gridLayout_3") + self.gridLayout_3.setHorizontalSpacing(5) + self.gridLayout_3.setVerticalSpacing(10) self.gridLayout_3.setContentsMargins(0, 0, 0, 0) - self.trigger_word_edit = QLineEdit(self.enabledCheckbox) + self.enabledCheckbox = QCheckBox(self.lora_container) + self.enabledCheckbox.setObjectName(u"enabledCheckbox") + + self.gridLayout_3.addWidget(self.enabledCheckbox, 0, 0, 1, 1) + + self.trigger_word_edit = QLineEdit(self.lora_container) self.trigger_word_edit.setObjectName(u"trigger_word_edit") - self.gridLayout_3.addWidget(self.trigger_word_edit, 0, 0, 1, 1) + self.gridLayout_3.addWidget(self.trigger_word_edit, 2, 0, 1, 1) - self.delete_button = QPushButton(self.enabledCheckbox) + self.delete_button = QPushButton(self.lora_container) self.delete_button.setObjectName(u"delete_button") self.delete_button.setMinimumSize(QSize(24, 24)) self.delete_button.setMaximumSize(QSize(24, 24)) @@ -61,10 +68,28 @@ def setupUi(self, lora): icon.addFile(u":/icons/light/recycle-bin-line-icon.svg", QSize(), QIcon.Normal, QIcon.Off) self.delete_button.setIcon(icon) - self.gridLayout_3.addWidget(self.delete_button, 0, 1, 1, 1) + self.gridLayout_3.addWidget(self.delete_button, 2, 2, 1, 1) + + self.scale_slider = SliderWidget(self.lora_container) + self.scale_slider.setObjectName(u"scale_slider") + sizePolicy.setHeightForWidth(self.scale_slider.sizePolicy().hasHeightForWidth()) + self.scale_slider.setSizePolicy(sizePolicy) + self.scale_slider.setMinimumSize(QSize(0, 0)) + self.scale_slider.setProperty("slider_minimum", 0) + self.scale_slider.setProperty("slider_maximum", 100) + self.scale_slider.setProperty("spinbox_minimum", 0.000000000000000) + self.scale_slider.setProperty("spinbox_maximum", 1.000000000000000) + self.scale_slider.setProperty("display_as_float", True) + self.scale_slider.setProperty("slider_single_step", 1) + self.scale_slider.setProperty("slider_page_step", 5) + self.scale_slider.setProperty("spinbox_single_step", 0.010000000000000) + self.scale_slider.setProperty("spinbox_page_step", 0.100000000000000) + self.scale_slider.setProperty("table_id", 0) + + self.gridLayout_3.addWidget(self.scale_slider, 1, 0, 1, 3) - self.gridLayout.addWidget(self.enabledCheckbox, 0, 0, 2, 1) + self.gridLayout.addWidget(self.lora_container, 0, 0, 2, 1) self.gridLayout_2.addLayout(self.gridLayout, 0, 0, 1, 1) @@ -73,15 +98,13 @@ def setupUi(self, lora): self.retranslateUi(lora) self.trigger_word_edit.textChanged.connect(lora.action_changed_trigger_words) self.enabledCheckbox.toggled.connect(lora.action_toggled_lora_enabled) - self.trigger_word_edit.textEdited.connect(lora.action_text_changed_trigger_word) - self.delete_button.clicked.connect(lora.action_clicked_button_deleted) QMetaObject.connectSlotsByName(lora) # setupUi def retranslateUi(self, lora): lora.setWindowTitle(QCoreApplication.translate("lora", u"Form", None)) - self.enabledCheckbox.setTitle(QCoreApplication.translate("lora", u"LoRA name here", None)) + self.enabledCheckbox.setText(QCoreApplication.translate("lora", u"enabledCheckbox", None)) #if QT_CONFIG(tooltip) self.trigger_word_edit.setToolTip(QCoreApplication.translate("lora", u"

Some LoRA require a trigger word to activate.

Make a note here for your records.

", None)) #endif // QT_CONFIG(tooltip) @@ -90,5 +113,7 @@ def retranslateUi(self, lora): self.delete_button.setToolTip(QCoreApplication.translate("lora", u"Delete model", None)) #endif // QT_CONFIG(tooltip) self.delete_button.setText("") + self.scale_slider.setProperty("settings_property", QCoreApplication.translate("lora", u"lora.scale", None)) + self.scale_slider.setProperty("label_text", QCoreApplication.translate("lora", u"Scale", None)) # retranslateUi diff --git a/src/airunner/widgets/slider/slider_widget.py b/src/airunner/widgets/slider/slider_widget.py index be882a0fb..227947d34 100644 --- a/src/airunner/widgets/slider/slider_widget.py +++ b/src/airunner/widgets/slider/slider_widget.py @@ -1,6 +1,8 @@ from typing import Any, List from PySide6.QtCore import Slot from PySide6.QtWidgets import QDoubleSpinBox + +from airunner.aihandler.models.settings_models import Lora from airunner.enums import SignalCode from airunner.widgets.base_widget import BaseWidget from airunner.widgets.slider.templates.slider_ui import Ui_slider_widget @@ -15,8 +17,10 @@ class SliderWidget(BaseWidget): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.settings_property = None - self.register(SignalCode.APPLICATION_MAIN_WINDOW_LOADED_SIGNAL, self.on_main_window_loaded_signal) - self.register(SignalCode.WINDOW_LOADED_SIGNAL, self.on_main_window_loaded_signal) + self.table_id = None + self.table_name = None + self.table_column = None + self.table_item = None self.ui.slider.sliderReleased.connect(self.handle_slider_release) # Connect valueChanged signal self.ui.slider_spinbox.valueChanged.connect(self.handle_spinbox_change) # Connect valueChanged signal self._callback = None @@ -102,11 +106,9 @@ def spinbox_minimum(self): def spinbox_minimum(self, val): self.ui.slider_spinbox.setMinimum(val) - def on_main_window_loaded_signal(self): - try: - self.init() - except RuntimeError as e: - self.logger.error(f"Error initializing SliderWidget: {e}") + def showEvent(self, event): + self.init() + super().showEvent(event) def init(self, **kwargs): self.is_loading = True @@ -119,6 +121,9 @@ def init(self, **kwargs): spinbox_maximum = kwargs.get("spinbox_maximum", self.property("spinbox_maximum") or 100.0) current_value = None settings_property = kwargs.get("settings_property", self.property("settings_property") or None) + self.table_id = self.property("table_id") or None + if self.table_id is not None: + self.table_name, self.table_column = settings_property.split(".") label_text = kwargs.get("label_text", self.property("label_text") or "") display_as_float = kwargs.get("display_as_float", self.property("display_as_float") or False) @@ -134,7 +139,13 @@ def init(self, **kwargs): divide_by = self.property("divide_by") or 1.0 - if current_value is None: + if self.table_id is not None and self.table_name is not None and self.table_column is not None: + session = self.db_handler.get_db_session() + if self.table_name == "lora": + self.table_item = session.query(Lora).filter_by(id=self.table_id).first() + current_value = getattr(self.table_item, self.table_column) + session.close() + elif current_value is None: if settings_property is not None: current_value = self.get_settings_value(settings_property) else: @@ -189,6 +200,8 @@ def slider_callback(self, attr_name, value=None, widget=None): self.set_settings_value(attr_name, value) def get_settings_value(self, settings_property): + if self.table_item is not None: + return getattr(self.table_item, self.table_column) keys = settings_property.split(".") if len(keys) == 1: @@ -202,10 +215,15 @@ def get_settings_value(self, settings_property): return getattr(obj, keys[1]) def set_settings_value(self, settings_property: str, val: Any): - if settings_property is None: - return - keys = settings_property.split(".") - self.update_settings_by_name(keys[0], keys[1], val) + if self.table_item is not None: + session = self.db_handler.get_db_session() + setattr(self.table_item, self.table_column, val) + session.add(self.table_item) + session.commit() + session.close() + elif settings_property is not None: + keys = settings_property.split(".") + self.update_settings_by_name(keys[0], keys[1], val) def _update_dict_recursively(self, data: dict, keys: List[str], val: Any) -> dict: if len(keys) == 1: diff --git a/src/airunner/widgets/slider/templates/slider.ui b/src/airunner/widgets/slider/templates/slider.ui index 14a3221f2..52aa0bd3e 100644 --- a/src/airunner/widgets/slider/templates/slider.ui +++ b/src/airunner/widgets/slider/templates/slider.ui @@ -7,7 +7,7 @@ 0 0 548 - 50 + 38 @@ -80,21 +80,18 @@ - 10 + 0 - 10 + 0 - 10 + 0 - 10 - - - 10 + 0 - + 0 diff --git a/src/airunner/widgets/slider/templates/slider_ui.py b/src/airunner/widgets/slider/templates/slider_ui.py index d7e004d58..a80f394ce 100644 --- a/src/airunner/widgets/slider/templates/slider_ui.py +++ b/src/airunner/widgets/slider/templates/slider_ui.py @@ -22,7 +22,7 @@ class Ui_slider_widget(object): def setupUi(self, slider_widget): if not slider_widget.objectName(): slider_widget.setObjectName(u"slider_widget") - slider_widget.resize(548, 50) + slider_widget.resize(548, 38) sizePolicy = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) @@ -45,10 +45,9 @@ def setupUi(self, slider_widget): font.setPointSize(9) self.groupBox.setFont(font) self.gridLayout_2 = QGridLayout(self.groupBox) + self.gridLayout_2.setSpacing(0) self.gridLayout_2.setObjectName(u"gridLayout_2") - self.gridLayout_2.setHorizontalSpacing(10) - self.gridLayout_2.setVerticalSpacing(0) - self.gridLayout_2.setContentsMargins(10, 10, 10, 10) + self.gridLayout_2.setContentsMargins(0, 0, 0, 0) self.slider = QSlider(self.groupBox) self.slider.setObjectName(u"slider") sizePolicy.setHeightForWidth(self.slider.sizePolicy().hasHeightForWidth()) diff --git a/src/airunner/windows/main/settings_mixin.py b/src/airunner/windows/main/settings_mixin.py index 7e6b1f74e..a1186479f 100644 --- a/src/airunner/windows/main/settings_mixin.py +++ b/src/airunner/windows/main/settings_mixin.py @@ -190,7 +190,11 @@ def outpaint_mask(self): ####################################### def get_lora_by_version(self, version): - return [lora for lora in self.lora if lora.version == version] + session = self.db_handler.get_db_session() + try: + return session.query(Lora).filter_by(version=version).all() + finally: + session.close() def delete_lora_by_name(self, name, version): self.db_handler.delete_lora_by_name(name, version) diff --git a/src/airunner/workers/sd_worker.py b/src/airunner/workers/sd_worker.py index 56fe557e2..e5405cc98 100644 --- a/src/airunner/workers/sd_worker.py +++ b/src/airunner/workers/sd_worker.py @@ -92,8 +92,12 @@ def get_embeddings(self, message): self.sd.get_embeddings(message) def on_update_lora_signal(self): + thread = threading.Thread(target=self._reload_lora) + thread.start() + + def _reload_lora(self): if self.sd: - self.sd.load_lora() + self.sd.reload_lora() def on_update_embeddings_signal(self): if self.sd: @@ -124,14 +128,14 @@ def on_unload_stablediffusion_signal(self, data=None): thread.start() def _load_sd(self, data:dict=None): - self.sd.load_stable_diffusion() + self.sd.load() if data: callback = data.get("callback", None) if callback is not None: callback(data) def _unload_sd(self, data:dict=None): - self.sd.unload_stable_diffusion() + self.sd.unload() if data: callback = data.get("callback", None) if callback is not None: @@ -157,7 +161,7 @@ def start_worker_thread(self): from airunner.aihandler.stablediffusion.sd_handler import SDHandler self.sd = SDHandler() if self.application_settings.sd_enabled: - self.sd.load_stable_diffusion() + self.sd.load() def handle_message(self, message): if self.sd: