Skip to content

Commit

Permalink
adds directory watcher to automatically update available lora
Browse files Browse the repository at this point in the history
  • Loading branch information
w4ffl35 committed Oct 4, 2024
1 parent b163d52 commit d0a0a77
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
22 changes: 17 additions & 5 deletions src/airunner/utils/models/scan_path_for_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from airunner.aihandler.models.settings_db_handler import SettingsDBHandler
from airunner.aihandler.models.settings_models import Lora, Embedding

def scan_path_for_lora(base_path) -> List[Lora]:
def scan_path_for_lora(base_path) -> bool:
lora_added = False
lora_deleted = False

db_handler = SettingsDBHandler()
items = []
for versionpath, versionnames, versionfiles in os.walk(os.path.expanduser(os.path.join(base_path, "art/models"))):
version = versionpath.split("/")[-1]
lora_path = os.path.expanduser(
Expand All @@ -19,6 +21,13 @@ def scan_path_for_lora(base_path) -> List[Lora]:
)
if not os.path.exists(lora_path):
continue

session = db_handler.get_db_session()
existing_lora = session.query(Lora).all()
for lora in existing_lora:
if not os.path.exists(lora.path):
session.delete(lora)
lora_deleted = True
for dirpath, dirnames, filenames in os.walk(lora_path):
for file in filenames:
if file.endswith(".ckpt") or file.endswith(".safetensors") or file.endswith(".pt"):
Expand All @@ -35,9 +44,12 @@ def scan_path_for_lora(base_path) -> List[Lora]:
trigger_word="",
version=version
)
db_handler.add_lora(item)
items.append(item)
return items
session.add(item)
lora_added = True
if lora_deleted or lora_added:
session.commit()
session.close()
return lora_deleted or lora_added

def scan_path_for_embeddings(base_path) -> List[Embedding]:
db_handler = SettingsDBHandler()
Expand Down
40 changes: 32 additions & 8 deletions src/airunner/widgets/lora/lora_container_widget.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from PySide6.QtCore import Slot, QSize
from PySide6.QtCore import Slot, QSize, QObject, QThread, Signal
from PySide6.QtWidgets import QWidget, QSizePolicy

from airunner.enums import SignalCode, ModelType, ModelStatus
Expand All @@ -10,6 +10,21 @@
from airunner.widgets.lora.templates.lora_container_ui import Ui_lora_container


class DirectoryWatcher(QObject):
scan_completed = Signal(bool)

def __init__(self, base_path: str, scan_function: callable):
super().__init__()
self.base_path = base_path
self._scan_function = scan_function

def run(self):
while True:
force_reload = self._scan_function(self.base_path)
self.scan_completed.emit(force_reload)
QThread.sleep(1)


class LoraContainerWidget(BaseWidget):
widget_class_ = Ui_lora_container
lora_loaded = False
Expand All @@ -30,13 +45,22 @@ def __init__(self, *args, **kwargs):
self.ui.loading_icon.set_size(spinner_size=QSize(30, 30), label_size=QSize(24, 24))
self._apply_button_enabled = False
self.ui.apply_lora_button.setEnabled(self._apply_button_enabled)
self._scanner_worker = DirectoryWatcher(self.path_settings.base_path, scan_path_for_lora)
self._scanner_worker.scan_completed.connect(self.on_scan_completed)
self._scanner_thread = QThread()
self._scanner_worker.moveToThread(self._scanner_thread)
self._scanner_thread.started.connect(self._scanner_worker.run)
self._scanner_thread.start()

@Slot(bool)
def on_scan_completed(self, force_reload: bool):
self.load_lora(force_reload=force_reload)

@Slot()
def scan_for_lora(self):
# clear all lora widgets
loras = scan_path_for_lora(self.path_settings.base_path)
self.update_loras(loras)
self.load_lora()
force_reload = scan_path_for_lora(self.path_settings.base_path)
self.load_lora(force_reload=force_reload)

@Slot()
def apply_lora(self):
Expand Down Expand Up @@ -107,9 +131,10 @@ def showEvent(self, event):
self.initialized = True
self.load_lora()

def load_lora(self):
def load_lora(self, force_reload=False):
version = self.generator_settings.version
if self._version is None or self._version != version:
if self._version is None or self._version != version or force_reload:
print("LOAD LORA")
self._version = version
self.clear_lora_widgets()
loras = self.get_lora_by_version(self._version)
Expand Down Expand Up @@ -227,8 +252,7 @@ def handle_lora_spinbox(self, lora, lora_widget, value, tab_name):

def search_text_changed(self, val):
self.search_filter = val
self._version = None
self.load_lora()
self.load_lora(force_reload=True)

def clear_lora_widgets(self):
if self.spacer:
Expand Down

0 comments on commit d0a0a77

Please sign in to comment.