From d0a0a773e07689d0c32c9718b533b63f0edd1687 Mon Sep 17 00:00:00 2001 From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com> Date: Fri, 4 Oct 2024 12:15:02 -0600 Subject: [PATCH] adds directory watcher to automatically update available lora --- .../utils/models/scan_path_for_items.py | 22 +++++++--- .../widgets/lora/lora_container_widget.py | 40 +++++++++++++++---- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/src/airunner/utils/models/scan_path_for_items.py b/src/airunner/utils/models/scan_path_for_items.py index 5ae4b3ac1..821792794 100644 --- a/src/airunner/utils/models/scan_path_for_items.py +++ b/src/airunner/utils/models/scan_path_for_items.py @@ -4,9 +4,11 @@ from airunner.aihandler.models.settings_db_handler import SettingsDBHandler from airunner.aihandler.models.settings_models import Lora, Embedding -def scan_path_for_lora(base_path) -> List[Lora]: +def scan_path_for_lora(base_path) -> bool: + lora_added = False + lora_deleted = False + db_handler = SettingsDBHandler() - items = [] for versionpath, versionnames, versionfiles in os.walk(os.path.expanduser(os.path.join(base_path, "art/models"))): version = versionpath.split("/")[-1] lora_path = os.path.expanduser( @@ -19,6 +21,13 @@ def scan_path_for_lora(base_path) -> List[Lora]: ) if not os.path.exists(lora_path): continue + + session = db_handler.get_db_session() + existing_lora = session.query(Lora).all() + for lora in existing_lora: + if not os.path.exists(lora.path): + session.delete(lora) + lora_deleted = True for dirpath, dirnames, filenames in os.walk(lora_path): for file in filenames: if file.endswith(".ckpt") or file.endswith(".safetensors") or file.endswith(".pt"): @@ -35,9 +44,12 @@ def scan_path_for_lora(base_path) -> List[Lora]: trigger_word="", version=version ) - db_handler.add_lora(item) - items.append(item) - return items + session.add(item) + lora_added = True + if lora_deleted or lora_added: + session.commit() + session.close() + return lora_deleted or lora_added def scan_path_for_embeddings(base_path) -> List[Embedding]: db_handler = SettingsDBHandler() diff --git a/src/airunner/widgets/lora/lora_container_widget.py b/src/airunner/widgets/lora/lora_container_widget.py index 7708ad04f..5147060bc 100644 --- a/src/airunner/widgets/lora/lora_container_widget.py +++ b/src/airunner/widgets/lora/lora_container_widget.py @@ -1,6 +1,6 @@ import os -from PySide6.QtCore import Slot, QSize +from PySide6.QtCore import Slot, QSize, QObject, QThread, Signal from PySide6.QtWidgets import QWidget, QSizePolicy from airunner.enums import SignalCode, ModelType, ModelStatus @@ -10,6 +10,21 @@ from airunner.widgets.lora.templates.lora_container_ui import Ui_lora_container +class DirectoryWatcher(QObject): + scan_completed = Signal(bool) + + def __init__(self, base_path: str, scan_function: callable): + super().__init__() + self.base_path = base_path + self._scan_function = scan_function + + def run(self): + while True: + force_reload = self._scan_function(self.base_path) + self.scan_completed.emit(force_reload) + QThread.sleep(1) + + class LoraContainerWidget(BaseWidget): widget_class_ = Ui_lora_container lora_loaded = False @@ -30,13 +45,22 @@ def __init__(self, *args, **kwargs): self.ui.loading_icon.set_size(spinner_size=QSize(30, 30), label_size=QSize(24, 24)) self._apply_button_enabled = False self.ui.apply_lora_button.setEnabled(self._apply_button_enabled) + self._scanner_worker = DirectoryWatcher(self.path_settings.base_path, scan_path_for_lora) + self._scanner_worker.scan_completed.connect(self.on_scan_completed) + self._scanner_thread = QThread() + self._scanner_worker.moveToThread(self._scanner_thread) + self._scanner_thread.started.connect(self._scanner_worker.run) + self._scanner_thread.start() + + @Slot(bool) + def on_scan_completed(self, force_reload: bool): + self.load_lora(force_reload=force_reload) @Slot() def scan_for_lora(self): # clear all lora widgets - loras = scan_path_for_lora(self.path_settings.base_path) - self.update_loras(loras) - self.load_lora() + force_reload = scan_path_for_lora(self.path_settings.base_path) + self.load_lora(force_reload=force_reload) @Slot() def apply_lora(self): @@ -107,9 +131,10 @@ def showEvent(self, event): self.initialized = True self.load_lora() - def load_lora(self): + def load_lora(self, force_reload=False): version = self.generator_settings.version - if self._version is None or self._version != version: + if self._version is None or self._version != version or force_reload: + print("LOAD LORA") self._version = version self.clear_lora_widgets() loras = self.get_lora_by_version(self._version) @@ -227,8 +252,7 @@ def handle_lora_spinbox(self, lora, lora_widget, value, tab_name): def search_text_changed(self, val): self.search_filter = val - self._version = None - self.load_lora() + self.load_lora(force_reload=True) def clear_lora_widgets(self): if self.spacer: