From bbd5764ab6f1315bd0110d8151ba27b6bf6228bf Mon Sep 17 00:00:00 2001
From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com>
Date: Tue, 9 Jan 2024 14:49:29 -0700
Subject: [PATCH] Adds version to embedding and lora files

fixes #358
---
 ...95f_adds_version_to_embedding_and_lora_.py | 32 +++++++++++++++++++
 src/airunner/data/models.py                   |  2 ++
 .../embeddings/embeddings_container_widget.py |  3 +-
 .../widgets/lora/lora_container_widget.py     |  5 +--
 4 files changed, 39 insertions(+), 3 deletions(-)
 create mode 100644 src/airunner/alembic/versions/6d98d892995f_adds_version_to_embedding_and_lora_.py

diff --git a/src/airunner/alembic/versions/6d98d892995f_adds_version_to_embedding_and_lora_.py b/src/airunner/alembic/versions/6d98d892995f_adds_version_to_embedding_and_lora_.py
new file mode 100644
index 000000000..1e7eb19b3
--- /dev/null
+++ b/src/airunner/alembic/versions/6d98d892995f_adds_version_to_embedding_and_lora_.py
@@ -0,0 +1,32 @@
+"""Adds version to embedding and lora tables
+
+Revision ID: 6d98d892995f
+Revises: 55c1dff62eba
+Create Date: 2024-01-09 14:40:48.810380
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '6d98d892995f'
+down_revision: Union[str, None] = '55c1dff62eba'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('embeddings', sa.Column('version', sa.String(), nullable=True))
+    op.add_column('loras', sa.Column('version', sa.String(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('loras', 'version')
+    op.drop_column('embeddings', 'version')
+    # ### end Alembic commands ###
diff --git a/src/airunner/data/models.py b/src/airunner/data/models.py
index fa10c5f11..5ac73cdd0 100644
--- a/src/airunner/data/models.py
+++ b/src/airunner/data/models.py
@@ -103,6 +103,7 @@ class Embedding(BaseModel):
     path = Column(String)
     tags = Column(String)
     active = Column(Boolean, default=True)
+    version = Column(String, default="SD 1.5")
 
     __table_args__ = (
         UniqueConstraint('name', 'path', name='name_path_unique'),
@@ -302,6 +303,7 @@ class Lora(BaseModel):
     enabled = Column(Boolean, default=True)
     loaded = Column(Boolean, default=False)
     trigger_word = Column(String, default="")
+    version = Column(String, default="SD 1.5")
 
     @classmethod
     def get_all(cls, session):
diff --git a/src/airunner/widgets/embeddings/embeddings_container_widget.py b/src/airunner/widgets/embeddings/embeddings_container_widget.py
index b11e4a11c..3d9b5f8fc 100644
--- a/src/airunner/widgets/embeddings/embeddings_container_widget.py
+++ b/src/airunner/widgets/embeddings/embeddings_container_widget.py
@@ -127,13 +127,14 @@ def scan_for_embeddings(self):
         if os.path.exists(embeddings_path):
             for root, dirs, _ in os.walk(embeddings_path):
                 for dir in dirs:
+                    version = dir.split("/")[-1]
                     path = os.path.join(root, dir)
                     for entry in os.scandir(path):
                         if entry.is_file() and entry.name.endswith((".ckpt", ".safetensors", ".pt")):
                             name = os.path.splitext(entry.name)[0]
                             embedding = session.query(Embedding).filter_by(name=name).first()
                             if not embedding:
-                                embedding = Embedding(name=name, path=entry.path)
+                                embedding = Embedding(name=name, path=entry.path, version=version)
                                 session.add(embedding)
                                 session.commit()
         self.load_embeddings()
diff --git a/src/airunner/widgets/lora/lora_container_widget.py b/src/airunner/widgets/lora/lora_container_widget.py
index b516499a8..6d8e0fc4d 100644
--- a/src/airunner/widgets/lora/lora_container_widget.py
+++ b/src/airunner/widgets/lora/lora_container_widget.py
@@ -58,11 +58,12 @@ def scan_for_lora(self):
         session = get_session()
         lora_path = self.settings_manager.path_settings.lora_path
         for dirpath, dirnames, filenames in os.walk(lora_path):
+            # get version from dirpath
+            version = dirpath.split("/")[-1]
             for file in filenames:
                 if file.endswith(".ckpt") or file.endswith(".safetensors") or file.endswith(".pt"):
-                    print("adding lora to session")
                     name = file.replace(".ckpt", "").replace(".safetensors", "").replace(".pt", "")
-                    lora = Lora(name=name, path=os.path.join(dirpath, file), enabled=True, scale=100.0)
+                    lora = Lora(name=name, path=os.path.join(dirpath, file), enabled=True, scale=100.0, version=version)
                     session.add(lora)
 
         save_session(session)
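
Context: both scan methods in this patch derive the version string from the last path component of the directory being walked, so the change assumes models are organized under version-named folders (e.g. lora/SD 1.5/detail.safetensors) with "/" separators. Below is a minimal standalone sketch of that walk; the root path is hypothetical, and os.path.basename is shown as a separator-agnostic equivalent of the split("/")[-1] used in the patch, not as part of the patch itself.

    import os

    # Hypothetical root; in the patch this comes from
    # settings_manager.path_settings.lora_path (or embeddings_path).
    lora_path = "/home/user/airunner/lora"

    for dirpath, dirnames, filenames in os.walk(lora_path):
        # Same result as dirpath.split("/")[-1] on POSIX-style paths,
        # but also handles "\" separators on Windows.
        version = os.path.basename(dirpath)
        for filename in filenames:
            if filename.endswith((".ckpt", ".safetensors", ".pt")):
                name = os.path.splitext(filename)[0]
                print(f"{name}: version {version}")

With a folder named "SD 1.5" this prints "detail: version SD 1.5", matching the default value the new columns fall back to when no folder hint is present.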