Skip to content

Commit

Permalink
add remote vae
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <[email protected]>
  • Loading branch information
vladmandic committed Feb 22, 2025
1 parent f8f987f commit 1b2d428
Show file tree
Hide file tree
Showing 21 changed files with 134 additions and 71 deletions.
24 changes: 15 additions & 9 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Change Log for SD.Next

## Update for 2025-02-20

Quick release refresh:
- remove ui splash screen on auth fail
- add `--extensions-dir` cli arg and `SD_EXTENSIONSDIR` env variable to specify extensions directory
- log full path when reading/saving `config.json`
- log full path to `sdnext.log`
- log system hostname in `sdnext.log`
- log extensions path in `sdnext.log`
## Update for 2025-02-22

- **Decode**
- Final step of image generation, VAE decode, is by far the most memory-intensive operation and can easily result in out-of-memory errors
What can be done? Well, *Huggingface* is now providing *free-of-charge* **remote-VAE-decode** service!
- How to use? The previous *Full quality* option in the UI is replaced with a VAE type selector: Full, Tiny, Remote
Currently supports SD15, SDXL and FLUX.1 with more models expected in the near future
Availability is limited, so if remote processing fails SD.Next will fall back to using the normal VAE decode process
- **Other**
- add `--extensions-dir` cli arg and `SD_EXTENSIONSDIR` env variable to specify extensions directory
- **Fixes**
- remove ui splash screen on auth fail
- log full config path, full log path, system name, extensions path
- zluda update
- fix zluda with pulid

## Update for 2025-02-18

Expand Down
2 changes: 1 addition & 1 deletion cli/run-benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def main():
"sampler_name": args.sampler,
"width": args.width,
"height": args.height,
"full_quality": not args.taesd,
"vae_type": 'Tiny' if args.taesd else 'Full',
"cfg_scale": 0,
"batch_size": 1,
"n_iter": 1,
Expand Down
4 changes: 2 additions & 2 deletions modules/control/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def control_run(state: str = '',
steps: int = 20, sampler_index: int = None,
seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1,
cfg_scale: float = 6.0, clip_skip: float = 1.0, image_cfg_scale: float = 6.0, diffusers_guidance_rescale: float = 0.7, pag_scale: float = 0.0, pag_adaptive: float = 0.5, cfg_end: float = 1.0,
full_quality: bool = True, tiling: bool = False, hidiffusion: bool = False,
vae_type: str = 'Full', tiling: bool = False, hidiffusion: bool = False,
detailer_enabled: bool = True, detailer_prompt: str = '', detailer_negative: str = '', detailer_steps: int = 10, detailer_strength: float = 0.3,
hdr_mode: int = 0, hdr_brightness: float = 0, hdr_color: float = 0, hdr_sharpen: float = 0, hdr_clamp: bool = False, hdr_boundary: float = 4.0, hdr_threshold: float = 0.95,
hdr_maximize: bool = False, hdr_max_center: float = 0.6, hdr_max_boundry: float = 1.0, hdr_color_picker: str = None, hdr_tint_ratio: float = 0,
Expand Down Expand Up @@ -292,7 +292,7 @@ def control_run(state: str = '',
diffusers_guidance_rescale = diffusers_guidance_rescale,
pag_scale = pag_scale,
pag_adaptive = pag_adaptive,
full_quality = full_quality,
vae_type = vae_type,
tiling = tiling,
hidiffusion = hidiffusion,
# detailer
Expand Down
4 changes: 2 additions & 2 deletions modules/images_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def latent(im, scale: float, selected_upscaler: upscaler.UpscalerData):
return im
else:
from modules.processing_vae import vae_encode, vae_decode
latents = vae_encode(im, shared.sd_model, full_quality=False) # TODO resize image: enable full VAE mode for resize-latent
latents = vae_encode(im, shared.sd_model, vae_type='Tiny') # TODO resize image: enable full VAE mode for resize-latent
latents = selected_upscaler.scaler.upscale(latents, scale, selected_upscaler.name)
im = vae_decode(latents, shared.sd_model, output_type='pil', full_quality=False)[0]
im = vae_decode(latents, shared.sd_model, output_type='pil', vae_type='Tiny')[0]
return im

def resize(im: Union[Image.Image, torch.Tensor], w, h):
Expand Down
4 changes: 2 additions & 2 deletions modules/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def img2img(id_task: str, state: str, mode: int,
sampler_index,
mask_blur, mask_alpha,
inpainting_fill,
full_quality, tiling, hidiffusion,
vae_type, tiling, hidiffusion,
detailer_enabled, detailer_prompt, detailer_negative, detailer_steps, detailer_strength,
n_iter, batch_size,
cfg_scale, image_cfg_scale,
Expand Down Expand Up @@ -241,7 +241,7 @@ def img2img(id_task: str, state: str, mode: int,
clip_skip=clip_skip,
width=width,
height=height,
full_quality=full_quality,
vae_type=vae_type,
tiling=tiling,
hidiffusion=hidiffusion,
detailer_enabled=detailer_enabled,
Expand Down
2 changes: 1 addition & 1 deletion modules/infotext.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def parse(infotext):
elif val == "False":
params[key] = False
elif key == 'VAE' and val == 'TAESD':
params["Full quality"] = False
params["VAE type"] = 'Tiny'
elif size is not None:
params[f"{key}-1"] = int(size.group(1))
params[f"{key}-2"] = int(size.group(2))
Expand Down
2 changes: 1 addition & 1 deletion modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
sd_models.reload_model_weights()
if p.override_settings.get('sd_vae', None) is not None:
if p.override_settings.get('sd_vae', None) == 'TAESD':
p.full_quality = False
p.vae_type = 'Tiny'
p.override_settings.pop('sd_vae', None)
if p.override_settings.get('Hires upscaler', None) is not None:
p.enable_hr = True
Expand Down
2 changes: 1 addition & 1 deletion modules/processing_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def task_specific_kwargs(p, model):
}
if model.__class__.__name__ == 'LatentConsistencyModelPipeline' and hasattr(p, 'init_images') and len(p.init_images) > 0:
p.ops.append('lcm')
init_latents = [processing_vae.vae_encode(image, model=shared.sd_model, full_quality=p.full_quality).squeeze(dim=0) for image in p.init_images]
init_latents = [processing_vae.vae_encode(image, model=shared.sd_model, vae_type=p.vae_type).squeeze(dim=0) for image in p.init_images]
init_latent = torch.stack(init_latents, dim=0).to(shared.device)
init_noise = p.denoising_strength * processing.create_random_tensors(init_latent.shape[1:], seeds=p.all_seeds, subseeds=p.all_subseeds, subseed_strength=p.subseed_strength, p=p)
init_latent = (1 - p.denoising_strength) * init_latent + init_noise
Expand Down
4 changes: 2 additions & 2 deletions modules/processing_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self,
styles: List[str] = [],
# vae
tiling: bool = False,
full_quality: bool = True,
vae_type: str = 'Full',
# other
hidiffusion: bool = False,
do_not_reload_embeddings: bool = False,
Expand Down Expand Up @@ -169,7 +169,7 @@ def __init__(self,
self.negative_prompt = negative_prompt
self.styles = styles
self.tiling = tiling
self.full_quality = full_quality
self.vae_type = vae_type
self.hidiffusion = hidiffusion
self.do_not_reload_embeddings = do_not_reload_embeddings
self.detailer_enabled = detailer_enabled
Expand Down
10 changes: 5 additions & 5 deletions modules/processing_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
if p.hr_force:
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
if 'Upscale' in shared.sd_model.__class__.__name__ or 'Flux' in shared.sd_model.__class__.__name__ or 'Kandinsky' in shared.sd_model.__class__.__name__:
output.images = processing_vae.vae_decode(latents=output.images, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.width, height=p.height)
output.images = processing_vae.vae_decode(latents=output.images, model=shared.sd_model, vae_type=p.vae_type, output_type='pil', width=p.width, height=p.height)
if p.is_control and hasattr(p, 'task_args') and p.task_args.get('image', None) is not None:
if hasattr(shared.sd_model, "vae") and output.images is not None and len(output.images) > 0:
output.images = processing_vae.vae_decode(latents=output.images, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.hr_upscale_to_x, height=p.hr_upscale_to_y) # controlnet cannnot deal with latent input
output.images = processing_vae.vae_decode(latents=output.images, model=shared.sd_model, vae_type=p.vae_type, output_type='pil', width=p.hr_upscale_to_x, height=p.hr_upscale_to_y) # controlnet cannnot deal with latent input
update_sampler(p, shared.sd_model, second_pass=True)
orig_denoise = p.denoising_strength
p.denoising_strength = strength
Expand Down Expand Up @@ -289,7 +289,7 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
noise_level = round(350 * p.denoising_strength)
output_type='latent'
if 'Upscale' in shared.sd_refiner.__class__.__name__ or 'Flux' in shared.sd_refiner.__class__.__name__ or 'Kandinsky' in shared.sd_refiner.__class__.__name__:
image = processing_vae.vae_decode(latents=image, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.width, height=p.height)
image = processing_vae.vae_decode(latents=image, model=shared.sd_model, vae_type=p.vae_type, output_type='pil', width=p.width, height=p.height)
p.extra_generation_params['Noise level'] = noise_level
output_type = 'np'
update_sampler(p, shared.sd_refiner, second_pass=True)
Expand Down Expand Up @@ -370,7 +370,7 @@ def process_decode(p: processing.StableDiffusionProcessing, output):
result_batch = processing_vae.vae_decode(
latents = output.images[i],
model = model,
full_quality = p.full_quality,
vae_type = p.vae_type,
width = width,
height = height,
frames = frames,
Expand All @@ -381,7 +381,7 @@ def process_decode(p: processing.StableDiffusionProcessing, output):
results = processing_vae.vae_decode(
latents = output.images,
model = model,
full_quality = p.full_quality,
vae_type = p.vae_type,
width = width,
height = height,
frames = frames,
Expand Down
28 changes: 11 additions & 17 deletions modules/processing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x


def decode_first_stage(model, x, full_quality=True):
def decode_first_stage(model, x):
if not shared.opts.keep_incomplete and (shared.state.skipped or shared.state.interrupted):
shared.log.debug(f'Decode VAE: skipped={shared.state.skipped} interrupted={shared.state.interrupted}')
x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device)
Expand All @@ -210,20 +210,14 @@ def decode_first_stage(model, x, full_quality=True):
shared.state.job = 'VAE'
with devices.autocast(disable = x.dtype==devices.dtype_vae):
try:
if full_quality:
if hasattr(model, 'decode_first_stage'):
# x_sample = model.decode_first_stage(x) * 0.5 + 0.5
x_sample = model.decode_first_stage(x)
elif hasattr(model, 'vae'):
x_sample = processing_vae.vae_decode(latents=x, model=model, output_type='np', full_quality=full_quality)
else:
x_sample = x
shared.log.error('Decode VAE unknown model')
if hasattr(model, 'decode_first_stage'):
# x_sample = model.decode_first_stage(x) * 0.5 + 0.5
x_sample = model.decode_first_stage(x)
elif hasattr(model, 'vae'):
x_sample = processing_vae.vae_decode(latents=x, model=model, output_type='np')
else:
from modules import sd_vae_taesd
x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device)
for i in range(len(x_sample)):
x_sample[i] = sd_vae_taesd.decode(x[i]) * 0.5 + 0.5
x_sample = x
shared.log.error('Decode VAE unknown model')
except Exception as e:
x_sample = x
shared.log.error(f'Decode VAE: {e}')
Expand Down Expand Up @@ -407,7 +401,7 @@ def resize_init_images(p):
def resize_hires(p, latents): # input=latents output=pil if not latent_upscaler else latent
if not torch.is_tensor(latents):
shared.log.warning('Hires: input is not tensor')
first_pass_images = processing_vae.vae_decode(latents=latents, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.width, height=p.height)
first_pass_images = processing_vae.vae_decode(latents=latents, model=shared.sd_model, vae_type=p.vae_type, output_type='pil', width=p.width, height=p.height)
return first_pass_images

if (p.hr_upscale_to_x == 0 or p.hr_upscale_to_y == 0) and hasattr(p, 'init_hr'):
Expand All @@ -418,7 +412,7 @@ def resize_hires(p, latents): # input=latents output=pil if not latent_upscaler
resized_image = images.resize_image(p.hr_resize_mode, latents, p.hr_upscale_to_x, p.hr_upscale_to_y, upscaler_name=p.hr_upscaler, context=p.hr_resize_context)
return resized_image

first_pass_images = processing_vae.vae_decode(latents=latents, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.width, height=p.height)
first_pass_images = processing_vae.vae_decode(latents=latents, model=shared.sd_model, vae_type=p.vae_type, output_type='pil', width=p.width, height=p.height)
resized_images = []
for img in first_pass_images:
resized_image = images.resize_image(p.hr_resize_mode, img, p.hr_upscale_to_x, p.hr_upscale_to_y, upscaler_name=p.hr_upscaler, context=p.hr_resize_context)
Expand Down Expand Up @@ -561,7 +555,7 @@ def save_intermediate(p, latents, suffix):
for i in range(len(latents)):
from modules.processing import create_infotext
info=create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, [], iteration=p.iteration, position_in_batch=i)
decoded = processing_vae.vae_decode(latents=latents, model=shared.sd_model, output_type='pil', full_quality=p.full_quality, width=p.width, height=p.height)
decoded = processing_vae.vae_decode(latents=latents, model=shared.sd_model, output_type='pil', vae_type=p.vae_type, width=p.width, height=p.height)
for j in range(len(decoded)):
images.save_image(decoded[j], path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix=suffix)

Expand Down
5 changes: 4 additions & 1 deletion modules/processing_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
"Batch": f'{p.n_iter}x{p.batch_size}' if p.n_iter > 1 or p.batch_size > 1 else None,
"Model": None if (not shared.opts.add_model_name_to_info) or (not shared.sd_model.sd_checkpoint_info.model_name) else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', ''),
"Model hash": getattr(p, 'sd_model_hash', None if (not shared.opts.add_model_hash_to_info) or (not shared.sd_model.sd_model_hash) else shared.sd_model.sd_model_hash),
"VAE": (None if not shared.opts.add_model_name_to_info or sd_vae.loaded_vae_file is None else os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0]) if p.full_quality else 'TAESD',
"Refiner prompt": p.refiner_prompt if len(p.refiner_prompt) > 0 else None,
"Refiner negative": p.refiner_negative if len(p.refiner_negative) > 0 else None,
"Styles": "; ".join(p.styles) if p.styles is not None and len(p.styles) > 0 else None,
Expand All @@ -71,6 +70,10 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
"Comment": comment,
"Operations": '; '.join(ops).replace('"', '') if len(p.ops) > 0 else 'none',
}
if p.vae_type == 'Full':
args["VAE"] = (None if not shared.opts.add_model_name_to_info or sd_vae.loaded_vae_file is None else os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0])
elif p.vae_type == 'Tiny':
args["VAE"] = 'TAESD'
if shared.opts.add_model_name_to_info and getattr(shared.sd_model, 'sd_checkpoint_info', None) is not None:
args["Model"] = shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')
if shared.opts.add_model_hash_to_info and getattr(shared.sd_model, 'sd_model_hash', None) is not None:
Expand Down
Loading

0 comments on commit 1b2d428

Please sign in to comment.