Skip to content

Commit

Permalink
Merge pull request #336 from Capsize-Games/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
w4ffl35 authored Dec 20, 2023
2 parents 819df71 + c1c4bfe commit c8ccb1a
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 159 deletions.
9 changes: 8 additions & 1 deletion src/airunner/aihandler/download_civitai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import tqdm
import requests
from json.decoder import JSONDecodeError
from airunner.aihandler.logger import Logger


class DownloadCivitAI:
Expand All @@ -9,7 +11,12 @@ class DownloadCivitAI:
def get_json(model_id):
url = f"https://civitai.com/api/v1/models/{model_id}"
response = requests.get(url)
json = response.json()

try:
json = response.json()
except JSONDecodeError:
Logger.error(f"Failed to decode JSON from {url}")
print(response)
return json

def download_model(self, url, file_name, size_kb, callback):
Expand Down
34 changes: 24 additions & 10 deletions src/airunner/widgets/embeddings/embeddings_container_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(self, *args, **kwargs):
self.settings_manager.changed_signal.connect(self.handle_changed_signal)
self.app.message_var.my_signal.connect(self.message_handler)

self.load_embeddings()
self.scan_for_embeddings()

def disable_embedding(self, name, model_name):
Expand Down Expand Up @@ -87,6 +86,8 @@ def message_handler(self, response: dict):
self.handle_embedding_load_failed(message)

def load_embeddings(self):
self.clear_embedding_widgets()

session = get_session()
embeddings = session.query(Embedding).filter(
Embedding.name.like(f"%{self.search_filter}%") if self.search_filter != "" else True).all()
Expand All @@ -105,24 +106,37 @@ def add_embedding(self, embedding):

def action_clicked_button_scan_for_embeddings(self):
self.scan_for_embeddings()

def check_saved_embeddings(self):
session = get_session()
embeddings = session.query(Embedding).all()
for embedding in embeddings:
if not os.path.exists(embedding.path):
session.delete(embedding)

def scan_for_embeddings(self):
# recursively scan for embedding model files in the embeddings path
# for each embedding model file, create an Embedding model
# add the Embedding model to the database
# add the Embedding model to the UI
self.check_saved_embeddings()

session = get_session()
embeddings_path = self.settings_manager.path_settings.embeddings_path
with os.scandir(embeddings_path) as dir_object:
for entry in dir_object:
if entry.is_file(): # ckpt or safetensors file
# check if entry.name is in ckpt, safetensors or pt files:
if entry.name.endswith(".ckpt") or entry.name.endswith(".safetensors") or entry.name.endswith(".pt"):
name = entry.name.replace(".ckpt", "").replace(".safetensors", "").replace(".pt", "")
embedding = Embedding(name=name, path=entry.path)
session.add(embedding)
save_session(session)

if os.path.exists(embeddings_path):
for root, dirs, _ in os.walk(embeddings_path):
for dir in dirs:
path = os.path.join(root, dir)
for entry in os.scandir(path):
if entry.is_file() and entry.name.endswith((".ckpt", ".safetensors", ".pt")):
name = os.path.splitext(entry.name)[0]
embedding = session.query(Embedding).filter_by(name=name).first()
if not embedding:
embedding = Embedding(name=name, path=entry.path)
session.add(embedding)
session.commit()
self.load_embeddings()

def toggle_all_toggled(self, checked):
for i in range(self.ui.embeddings.widget().layout().count()):
Expand Down
113 changes: 91 additions & 22 deletions src/airunner/widgets/model_manager/import_widget.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading

from airunner.data.models import AIModel
from airunner.data.models import AIModel, Lora, Embedding
from airunner.utils import get_session
from airunner.widgets.base_widget import BaseWidget
from airunner.widgets.model_manager.templates.import_ui import Ui_import_model_widget
Expand Down Expand Up @@ -62,38 +62,81 @@ def download_model(self):
diffuser_model_version = model_version["baseModel"]
pipeline_class = self.settings_manager.get_pipeline_classname(pipeline_action, diffuser_model_version, category)
diffuser_model_versions = self.settings_manager.model_versions
file_path = self.download_path(file, diffuser_model_version) # path is the download path of the model
model_type = model_data["type"]
file_path = self.download_path(file, diffuser_model_version, pipeline_action, model_type) # path is the download path of the model

print("Name", name)
print("Path", file_path)
print("Branch", "main")
print("Version", diffuser_model_version)
print("Category", category)
print("Pipeline Action", pipeline_action)

trained_words = model_version.get("trained_words", [])
trained_words = ",".join(trained_words)

session = get_session()
model_exists = session.query(AIModel).filter_by(
name=name,
path=file_path,
branch="main",
version=diffuser_model_version,
category=category,
pipeline_action=pipeline_action,
).first()
if not model_exists:
new_model = AIModel(
if model_type == "Checkpoint":
model_exists = session.query(AIModel).filter_by(
name=name,
path=file_path,
branch="main",
version=diffuser_model_version,
category=category,
pipeline_action=pipeline_action,
enabled=True,
is_default=False
)
session.add(new_model)
session.commit()
).first()
if not model_exists:
new_model = AIModel(
name=name,
path=file_path,
branch="main",
version=diffuser_model_version,
category=category,
pipeline_action=pipeline_action,
enabled=True,
is_default=False
)
session.add(new_model)
session.commit()
elif model_type == "LORA":
lora_exists = session.query(Lora).filter_by(
name=name,
path=file_path,
).first()
if not lora_exists:
new_lora = Lora(
name=name,
path=file_path,
scale=1,
enabled=True,
loaded=False,
trigger_word=trained_words,
)
session.add(new_lora)
session.commit()
elif model_type == "TextualInversion":
embedding_exists = session.query(Embedding).filter_by(
name=name,
path=file_path,
).first()
if not embedding_exists:
new_embedding = Embedding(
name=name,
path=file_path,
enabled=True,
loaded=False,
trigger_word=trained_words,
)
session.add(new_embedding)
session.commit()
elif model_type == "VAE":
# todo save vae here
pass
elif model_type == "Controlnet":
# todo save controlnet here
pass
elif model_type == "Poses":
# todo save poses here
pass

print("starting download")
self.download_model_thread(download_url, file_path, size_kb)
Expand Down Expand Up @@ -163,8 +206,33 @@ def import_models(self):
def model_version_changed(self, index):
self.set_model_form_data()

def download_path(self, file, version):
path = self.settings_manager.path_settings.model_base_path
def download_path(self, file, version, pipeline_action, model_type):

if model_type == "LORA":
path = self.settings_manager.path_settings.lora_model_path
elif model_type == "Checkpoint":
if pipeline_action == "txt2img":
path = self.settings_manager.path_settings.txt2img_model_path
elif pipeline_action == "outpaint":
path = self.settings_manager.path_settings.outpaint_model_path
elif pipeline_action == "upscale":
path = self.settings_manager.path_settings.upscale_model_path
elif pipeline_action == "depth2img":
path = self.settings_manager.path_settings.depth2img_model_path
elif pipeline_action == "pix2pix":
path = self.settings_manager.path_settings.pix2pix_model_path
elif model_type == "TextualInversion":
path = self.settings_manager.path_settings.embeddings_model_path
elif model_type == "VAE":
# todo save vae here
pass
elif model_type == "Controlnet":
# todo save controlnet here
pass
elif model_type == "Poses":
# todo save poses here
pass

file_name = file["name"]
return f"{path}/{version}/{file_name}"

Expand All @@ -184,7 +252,7 @@ def set_model_form_data(self):
diffuser_model_version = model_version["baseModel"]
pipeline_class = self.settings_manager.get_pipeline_classname(pipeline_action, diffuser_model_version, category)
diffuser_model_versions = self.settings_manager.model_versions
path = self.download_path(file, diffuser_model_version) # path is the download path of the model
path = self.download_path(file, diffuser_model_version, pipeline_action, self.current_model_data["type"]) # path is the download path of the model

self.ui.model_form.set_model_form_data(
categories,
Expand All @@ -196,7 +264,8 @@ def set_model_form_data(self):
diffuser_model_version,
path,
self.current_model_data["name"],
model_data=self.current_model_data
model_data=self.current_model_data,
model_type=self.current_model_data["type"]
)

if self.is_civitai:
Expand Down
8 changes: 7 additions & 1 deletion src/airunner/widgets/model_manager/model_form_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def set_model_form_data(
diffuser_model_version,
path,
model_name,
model_data
model_data,
model_type
):
self.ui.category.clear()
self.ui.category.addItems(categories)
Expand All @@ -30,10 +31,15 @@ def set_model_form_data(
self.ui.diffuser_model_version.clear()
self.ui.diffuser_model_version.addItems(diffuser_model_versions)
self.ui.diffuser_model_version.setCurrentText(diffuser_model_version)
self.ui.model_type.clear()
self.ui.model_type.addItems(["Checkpoint", "LORA", "Embedding", "VAE", "Controlnet", "Pose"])
self.ui.pipeline_class_line_edit.setText(pipeline_class)
self.ui.enabled.setChecked(True)
self.ui.path_line_edit.setText(path)

# set current model type
self.ui.model_type.setCurrentText(model_type)

# clear the table
self.ui.model_data_table.clearContents()
self.ui.model_data_table.setRowCount(5)
Expand Down
Loading

0 comments on commit c8ccb1a

Please sign in to comment.