Skip to content

Commit

Permalink
Merge pull request #134 from shockz0rz/feature/add-vae-clip-skip
Browse files Browse the repository at this point in the history
Add VAE and CLIP Skip selections to the SD Common Options tab.
  • Loading branch information
Interpause authored Mar 25, 2023
2 parents 3b82fed + ff5a007 commit 272e184
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 0 deletions.
1 change: 1 addition & 0 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ async def get_state():
"scripts_img2img": get_scripts_metadata(True),
"face_restorers": [model.name() for model in shared.face_restorers],
"sd_models": modules.sd_models.checkpoint_tiles(), # yes internal API has spelling error
"sd_vaes": ["None", "Automatic" ] + (list(modules.sd_vae.vae_dict))
}


Expand Down
6 changes: 6 additions & 0 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class BaseOptions(BaseModel):
class GenerationOptions(BaseModel):
sd_model: str = "model.ckpt"
"""Model to use for generation."""
sd_vae: str = "Automatic"
"""VAE to use for generation."""

clip_skip: int = 1
"""CLIP layers to skip during generation."""

script: str = "None"
"""Which script to use."""
script_args: list = Field(default_factory=list)
Expand Down
2 changes: 2 additions & 0 deletions backend/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class ConfigResponse(PluginOptions):
"""List of available face restorers."""
sd_models: List[str]
"""List of available models."""
sd_vaes: List[str]
"""List of available VAEs."""


class ImageResponse(BaseModel):
Expand Down
7 changes: 7 additions & 0 deletions backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ def prepare_backend(opt: BaseModel):
shared.opts.sd_model_checkpoint = opt.sd_model
modules.sd_models.reload_model_weights(shared.sd_model)

if hasattr(opt, "sd_vae"):
shared.opts.sd_vae = opt.sd_vae
modules.sd_vae.reload_vae_weights()

if hasattr(opt, "clip_skip"):
shared.opts.CLIP_stop_at_last_layers = opt.clip_skip

if hasattr(opt, "upscaler_name"):
shared.opts.upscaler_for_img2img = opt.upscaler_name

Expand Down
4 changes: 4 additions & 0 deletions frontends/krita/krita_diff/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def common_params(self, has_selection):
# its fine to stuff extra stuff here; pydantic will shave off irrelevant params
params = dict(
sd_model=self.cfg("sd_model", str),
sd_vae=self.cfg("sd_vae", str),
clip_skip=self.cfg("clip_skip", int),
batch_count=self.cfg("sd_batch_count", int),
batch_size=self.cfg("sd_batch_size", int),
base_size=self.cfg("sd_base_size", int),
Expand All @@ -246,6 +248,7 @@ def cb(obj):
assert len(obj["samplers_img2img"]) > 0
assert len(obj["face_restorers"]) > 0
assert len(obj["sd_models"]) > 0
assert len(obj["sd_vaes"]) > 0
assert len(obj["scripts_txt2img"]) > 0
assert len(obj["scripts_img2img"]) > 0
except:
Expand All @@ -269,6 +272,7 @@ def cb(obj):
self.cfg.set("inpaint_script_list", list(obj["scripts_img2img"].keys()))
self.cfg.set("face_restorer_model_list", obj["face_restorers"])
self.cfg.set("sd_model_list", obj["sd_models"])
self.cfg.set("sd_vae_list", ["Automatic", "None"] + obj["sd_vaes"])

# extension script cfg
obj["scripts_inpaint"] = obj["scripts_img2img"]
Expand Down
3 changes: 3 additions & 0 deletions frontends/krita/krita_diff/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class Defaults:

sd_model_list: List[str] = field(default_factory=lambda: [ERROR_MSG])
sd_model: str = "model.ckpt"
sd_vae_list: List[str] = field(default_factory=lambda: [ERROR_MSG])
sd_vae: str = "Automatic"
clip_skip: int = 1
sd_batch_size: int = 1
sd_batch_count: int = 1
sd_base_size: int = 512
Expand Down
16 changes: 16 additions & 0 deletions frontends/krita/krita_diff/pages/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def __init__(self, *args, **kwargs):
script.cfg, "sd_model_list", "sd_model", label="SD model:"
)

# VAE list
self.sd_vae_layout = QComboBoxLayout(
script.cfg, "sd_vae_list", "sd_vae", label="VAE:"
)

# Clip skip
self.clip_skip_layout = QSpinBoxLayout(
script.cfg, "clip_skip", label="Clip skip:", min=1, max=12, step=1
)

# batch size & count
self.batch_count_layout = QSpinBoxLayout(
script.cfg, "sd_batch_count", label="Batch count:", min=1, max=9999, step=1
Expand Down Expand Up @@ -83,6 +93,8 @@ def __init__(self, *args, **kwargs):
layout.addLayout(self.codeformer_weight_layout)
layout.addLayout(checkboxes_layout)
layout.addLayout(self.sd_model_layout)
layout.addLayout(self.sd_vae_layout)
layout.addLayout(self.clip_skip_layout)
layout.addLayout(batch_layout)
layout.addLayout(size_layout)
layout.addWidget(self.interrupt_btn)
Expand All @@ -92,6 +104,8 @@ def __init__(self, *args, **kwargs):

def cfg_init(self):
self.sd_model_layout.cfg_init()
self.sd_vae_layout.cfg_init()
self.clip_skip_layout.cfg_init()
self.batch_count_layout.cfg_init()
self.batch_size_layout.cfg_init()
self.base_size_layout.cfg_init()
Expand All @@ -106,6 +120,8 @@ def cfg_init(self):

def cfg_connect(self):
self.sd_model_layout.cfg_connect()
self.sd_vae_layout.cfg_connect()
self.clip_skip_layout.cfg_connect()
self.batch_count_layout.cfg_connect()
self.batch_size_layout.cfg_connect()
self.base_size_layout.cfg_connect()
Expand Down

0 comments on commit 272e184

Please sign in to comment.