From d141fb103e2f9ad76e249e9b2e5f0e49bc73061e Mon Sep 17 00:00:00 2001 From: Andrew F <59752493+shockz0rz@users.noreply.github.com> Date: Thu, 2 Mar 2023 17:12:28 -0700 Subject: [PATCH 1/2] Add VAE selection support to SD Common Options tab. --- backend/app.py | 1 + backend/config.py | 2 ++ backend/structs.py | 2 ++ backend/utils.py | 4 ++++ frontends/krita/krita_diff/client.py | 3 +++ frontends/krita/krita_diff/defaults.py | 2 ++ frontends/krita/krita_diff/pages/common.py | 8 ++++++++ 7 files changed, 22 insertions(+) diff --git a/backend/app.py b/backend/app.py index 69e9a9a9..9c238ff2 100644 --- a/backend/app.py +++ b/backend/app.py @@ -95,6 +95,7 @@ async def get_state(): "scripts_img2img": get_scripts_metadata(True), "face_restorers": [model.name() for model in shared.face_restorers], "sd_models": modules.sd_models.checkpoint_tiles(), # yes internal API has spelling error + "sd_vaes": ["None", "Automatic" ] + (list(modules.sd_vae.vae_dict)) } diff --git a/backend/config.py b/backend/config.py index 38d56bbd..25d320a9 100644 --- a/backend/config.py +++ b/backend/config.py @@ -26,6 +26,8 @@ class BaseOptions(BaseModel): class GenerationOptions(BaseModel): sd_model: str = "model.ckpt" """Model to use for generation.""" + sd_vae: str = "Automatic" + """VAE to use for generation.""" script: str = "None" """Which script to use.""" script_args: list = Field(default_factory=list) diff --git a/backend/structs.py b/backend/structs.py index c4cbe191..8f6c2f95 100644 --- a/backend/structs.py +++ b/backend/structs.py @@ -71,6 +71,8 @@ class ConfigResponse(PluginOptions): """List of available face restorers.""" sd_models: List[str] """List of available models.""" + sd_vaes: List[str] + """List of available VAEs.""" class ImageResponse(BaseModel): diff --git a/backend/utils.py b/backend/utils.py index 89ede617..df52a70d 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -85,6 +85,10 @@ def prepare_backend(opt: BaseModel): shared.opts.sd_model_checkpoint = opt.sd_model modules.sd_models.reload_model_weights(shared.sd_model) + if hasattr(opt, "sd_vae"): + shared.opts.sd_vae = opt.sd_vae + modules.sd_vae.reload_vae_weights() + if hasattr(opt, "upscaler_name"): shared.opts.upscaler_for_img2img = opt.upscaler_name diff --git a/frontends/krita/krita_diff/client.py b/frontends/krita/krita_diff/client.py index 0d3062d5..7ccd7bbd 100644 --- a/frontends/krita/krita_diff/client.py +++ b/frontends/krita/krita_diff/client.py @@ -220,6 +220,7 @@ def common_params(self, has_selection): # its fine to stuff extra stuff here; pydantic will shave off irrelevant params params = dict( sd_model=self.cfg("sd_model", str), + sd_vae=self.cfg("sd_vae", str), batch_count=self.cfg("sd_batch_count", int), batch_size=self.cfg("sd_batch_size", int), base_size=self.cfg("sd_base_size", int), @@ -246,6 +247,7 @@ def cb(obj): assert len(obj["samplers_img2img"]) > 0 assert len(obj["face_restorers"]) > 0 assert len(obj["sd_models"]) > 0 + assert len(obj["sd_vaes"]) > 0 assert len(obj["scripts_txt2img"]) > 0 assert len(obj["scripts_img2img"]) > 0 except: @@ -267,6 +269,7 @@ def cb(obj): self.cfg.set("inpaint_script_list", list(obj["scripts_img2img"].keys())) self.cfg.set("face_restorer_model_list", obj["face_restorers"]) self.cfg.set("sd_model_list", obj["sd_models"]) + self.cfg.set("sd_vae_list", ["Automatic", "None"] + obj["sd_vaes"]) # extension script cfg obj["scripts_inpaint"] = obj["scripts_img2img"] diff --git a/frontends/krita/krita_diff/defaults.py b/frontends/krita/krita_diff/defaults.py index 7b7090ac..6fdb4c3f 100644 --- a/frontends/krita/krita_diff/defaults.py +++ b/frontends/krita/krita_diff/defaults.py @@ -63,6 +63,8 @@ class Defaults: sd_model_list: List[str] = field(default_factory=lambda: [ERROR_MSG]) sd_model: str = "model.ckpt" + sd_vae_list: List[str] = field(default_factory=lambda: [ERROR_MSG]) + sd_vae: str = "Automatic" sd_batch_size: int = 1 sd_batch_count: int = 1 sd_base_size: int = 512 diff --git a/frontends/krita/krita_diff/pages/common.py b/frontends/krita/krita_diff/pages/common.py index 2cf0c108..b7898cc5 100644 --- a/frontends/krita/krita_diff/pages/common.py +++ b/frontends/krita/krita_diff/pages/common.py @@ -21,6 +21,11 @@ def __init__(self, *args, **kwargs): script.cfg, "sd_model_list", "sd_model", label="SD model:" ) + # VAE list + self.sd_vae_layout = QComboBoxLayout( + script.cfg, "sd_vae_list", "sd_vae", label="VAE:" + ) + # batch size & count self.batch_count_layout = QSpinBoxLayout( script.cfg, "sd_batch_count", label="Batch count:", min=1, max=9999, step=1 @@ -83,6 +88,7 @@ def __init__(self, *args, **kwargs): layout.addLayout(self.codeformer_weight_layout) layout.addLayout(checkboxes_layout) layout.addLayout(self.sd_model_layout) + layout.addLayout(self.sd_vae_layout) layout.addLayout(batch_layout) layout.addLayout(size_layout) layout.addWidget(self.interrupt_btn) @@ -92,6 +98,7 @@ def __init__(self, *args, **kwargs): def cfg_init(self): self.sd_model_layout.cfg_init() + self.sd_vae_layout.cfg_init() self.batch_count_layout.cfg_init() self.batch_size_layout.cfg_init() self.base_size_layout.cfg_init() @@ -106,6 +113,7 @@ def cfg_init(self): def cfg_connect(self): self.sd_model_layout.cfg_connect() + self.sd_vae_layout.cfg_connect() self.batch_count_layout.cfg_connect() self.batch_size_layout.cfg_connect() self.base_size_layout.cfg_connect() From ff5a007671864ef0e24e2c7ce8aa132b803590dd Mon Sep 17 00:00:00 2001 From: Andrew F <59752493+shockz0rz@users.noreply.github.com> Date: Thu, 2 Mar 2023 21:04:56 -0700 Subject: [PATCH 2/2] Add CLIP skip spin box to Common Options tab. (partial fix for #16) --- backend/config.py | 4 ++++ backend/utils.py | 3 +++ frontends/krita/krita_diff/client.py | 1 + frontends/krita/krita_diff/defaults.py | 1 + frontends/krita/krita_diff/pages/common.py | 8 ++++++++ 5 files changed, 17 insertions(+) diff --git a/backend/config.py b/backend/config.py index 25d320a9..a9944529 100644 --- a/backend/config.py +++ b/backend/config.py @@ -28,6 +28,10 @@ class GenerationOptions(BaseModel): """Model to use for generation.""" sd_vae: str = "Automatic" """VAE to use for generation.""" + + clip_skip: int = 1 + """CLIP layers to skip during generation.""" + script: str = "None" """Which script to use.""" script_args: list = Field(default_factory=list) diff --git a/backend/utils.py b/backend/utils.py index df52a70d..cff7c6d5 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -89,6 +89,9 @@ def prepare_backend(opt: BaseModel): shared.opts.sd_vae = opt.sd_vae modules.sd_vae.reload_vae_weights() + if hasattr(opt, "clip_skip"): + shared.opts.CLIP_stop_at_last_layers = opt.clip_skip + if hasattr(opt, "upscaler_name"): shared.opts.upscaler_for_img2img = opt.upscaler_name diff --git a/frontends/krita/krita_diff/client.py b/frontends/krita/krita_diff/client.py index 7ccd7bbd..98b3e4c7 100644 --- a/frontends/krita/krita_diff/client.py +++ b/frontends/krita/krita_diff/client.py @@ -221,6 +221,7 @@ def common_params(self, has_selection): params = dict( sd_model=self.cfg("sd_model", str), sd_vae=self.cfg("sd_vae", str), + clip_skip=self.cfg("clip_skip", int), batch_count=self.cfg("sd_batch_count", int), batch_size=self.cfg("sd_batch_size", int), base_size=self.cfg("sd_base_size", int), diff --git a/frontends/krita/krita_diff/defaults.py b/frontends/krita/krita_diff/defaults.py index 6fdb4c3f..02508df2 100644 --- a/frontends/krita/krita_diff/defaults.py +++ b/frontends/krita/krita_diff/defaults.py @@ -65,6 +65,7 @@ class Defaults: sd_model: str = "model.ckpt" sd_vae_list: List[str] = field(default_factory=lambda: [ERROR_MSG]) sd_vae: str = "Automatic" + clip_skip: int = 1 sd_batch_size: int = 1 sd_batch_count: int = 1 sd_base_size: int = 512 diff --git a/frontends/krita/krita_diff/pages/common.py b/frontends/krita/krita_diff/pages/common.py index b7898cc5..f0463a95 100644 --- a/frontends/krita/krita_diff/pages/common.py +++ b/frontends/krita/krita_diff/pages/common.py @@ -26,6 +26,11 @@ def __init__(self, *args, **kwargs): script.cfg, "sd_vae_list", "sd_vae", label="VAE:" ) + # Clip skip + self.clip_skip_layout = QSpinBoxLayout( + script.cfg, "clip_skip", label="Clip skip:", min=1, max=12, step=1 + ) + # batch size & count self.batch_count_layout = QSpinBoxLayout( script.cfg, "sd_batch_count", label="Batch count:", min=1, max=9999, step=1 @@ -89,6 +94,7 @@ def __init__(self, *args, **kwargs): layout.addLayout(checkboxes_layout) layout.addLayout(self.sd_model_layout) layout.addLayout(self.sd_vae_layout) + layout.addLayout(self.clip_skip_layout) layout.addLayout(batch_layout) layout.addLayout(size_layout) layout.addWidget(self.interrupt_btn) @@ -99,6 +105,7 @@ def __init__(self, *args, **kwargs): def cfg_init(self): self.sd_model_layout.cfg_init() self.sd_vae_layout.cfg_init() + self.clip_skip_layout.cfg_init() self.batch_count_layout.cfg_init() self.batch_size_layout.cfg_init() self.base_size_layout.cfg_init() @@ -114,6 +121,7 @@ def cfg_init(self): def cfg_connect(self): self.sd_model_layout.cfg_connect() self.sd_vae_layout.cfg_connect() + self.clip_skip_layout.cfg_connect() self.batch_count_layout.cfg_connect() self.batch_size_layout.cfg_connect() self.base_size_layout.cfg_connect()