Fix and enhancements for alternate img2img script for stable diffusion XL #16761

Open · wants to merge 5 commits into base: dev
167 changes: 108 additions & 59 deletions scripts/img2imgalt.py
@@ -11,8 +11,13 @@
import torch
import k_diffusion as K

def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
x = p.init_latent
# Debugging notes - for sd1.5, the apply_model method being called lives in modules.sd_hijack_utils and resolves to ldm.models.diffusion.ddpm.LatentDiffusion
# For sdxl, OpenAIWrapper will be called, which will call the underlying diffusion_model
# When controlnet is enabled, the underlying model is not available to use, therefore we skip it
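A minimal sketch of the dispatch the notes above describe; `call_unet` is a hypothetical helper (not part of this diff) showing the two call paths the script now distinguishes:

```python
def call_unet(sd_model, x_scaled, t, cond_in):
    # SDXL: call the OpenAIWrapper directly so the conditioning can be
    # passed in SDXL's {"crossattn", "vector"} form
    if sd_model.is_sdxl:
        return sd_model.model(x_scaled, t, {"crossattn": cond_in["crossattn"][0],
                                            "vector": cond_in["vector"][0]})
    # SD1.5: apply_model routes through the hijack in modules.sd_hijack_utils
    return sd_model.apply_model(x_scaled, t, cond=cond_in)
```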

@torch.no_grad()
def find_noise_for_image(p, cond, uncond, cfg_scale, steps, skip_sdxl_vector):
x = p.init_latent.clone()

s_in = x.new_ones([x.shape[0]])
if shared.sd_model.parameterization == "v":
@@ -30,41 +35,30 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):

x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigmas[i] * s_in] * 2)
cond_in = torch.cat([uncond, cond])

image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
if shared.sd_model.is_sdxl:
cond_in = {"crossattn": [torch.cat([uncond['crossattn'], cond['crossattn']])], "vector": [torch.cat([uncond['vector'], cond['vector']])]}
else:
cond_in = {"c_concat": [torch.cat([p.image_conditioning] * 2)], "c_crossattn": [torch.cat([uncond, cond])]}

c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
t = dnw.sigma_to_t(sigma_in)

eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)

denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale

d = (x - denoised) / sigmas[i]
dt = sigmas[i] - sigmas[i - 1]

x = x + d * dt
x += noise_from_model(x, t, dt, sigma_in, cond_in, cfg_scale, dnw, skip, skip_sdxl_vector)

sd_samplers_common.store_latent(x)

# This shouldn't be necessary, but solved some VRAM issues
del x_in, sigma_in, cond_in, c_out, c_in, t,
del eps, denoised_uncond, denoised_cond, denoised, d, dt
del x_in, sigma_in, cond_in, t, dt

shared.state.nextjob()

return x / x.std()


Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment"])
return x, sigmas[-1]
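For orientation, the loop above is the sampler's Euler update run in reverse: instead of integrating from high noise down to a clean latent, it walks the clean latent up the sigma schedule so that sampling forward from the result reproduces the input image. A minimal sketch under that reading, with `denoise` standing in for the CFG-combined model evaluation:

```python
import torch

def reverse_euler(x: torch.Tensor, sigmas: torch.Tensor, denoise) -> torch.Tensor:
    # sigmas are traversed from low to high, so dt is positive and each
    # step moves the latent toward higher noise levels
    for i in range(1, len(sigmas)):
        denoised = denoise(x, sigmas[i])   # model's estimate of the clean image
        d = (x - denoised) / sigmas[i]     # k-diffusion's Euler derivative
        x = x + d * (sigmas[i] - sigmas[i - 1])
    return x
```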


# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
x = p.init_latent
@torch.no_grad()
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps, correction_factor, skip_sdxl_vector):
x = p.init_latent.clone()

s_in = x.new_ones([x.shape[0]])
if shared.sd_model.parameterization == "v":
@@ -79,43 +73,78 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):

for i in trange(1, len(sigmas)):
shared.state.sampling_step += 1

x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
cond_in = torch.cat([uncond, cond])

image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}

c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
if shared.sd_model.is_sdxl:
cond_in = {"crossattn": [torch.cat([uncond['crossattn'], cond['crossattn']])], "vector": [torch.cat([uncond['vector'], cond['vector']])]}
else:
cond_in = {"c_concat": [torch.cat([p.image_conditioning] * 2)], "c_crossattn": [torch.cat([uncond, cond])]}

if i == 1:
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
dt = (sigmas[i] - sigmas[i - 1]) / (2 * sigmas[i])
else:
t = dnw.sigma_to_t(sigma_in)
dt = (sigmas[i] - sigmas[i - 1]) / sigmas[i - 1]

eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
noise = noise_from_model(x, t, dt, sigma_in, cond_in, cfg_scale, dnw, skip, skip_sdxl_vector)

denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
if correction_factor > 0: # runs model with previously calculated noise
recalculated_noise = noise_from_model(x + noise, t, dt, sigma_in, cond_in, cfg_scale, dnw, skip, skip_sdxl_vector)
noise = recalculated_noise * correction_factor + noise * (1 - correction_factor)

if i == 1:
d = (x - denoised) / (2 * sigmas[i])
x += noise

sd_samplers_common.store_latent(x)

shared.state.nextjob()

return x, sigmas[-1]
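The `correction_factor` blend above is effectively a Heun-style predictor-corrector: the model is evaluated a second time at the predicted point and the two noise estimates are mixed. The same step in isolation, assuming any callable `step` that returns the noise increment for a latent:

```python
import torch

def corrected_step(x: torch.Tensor, step, correction_factor: float) -> torch.Tensor:
    noise = step(x)  # first (Euler) estimate
    if correction_factor > 0:
        # re-evaluate at the predicted point and blend: 0 reproduces the
        # original single-evaluation behaviour, 0.5 is the UI default
        noise = step(x + noise) * correction_factor + noise * (1 - correction_factor)
    return x + noise
```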

@torch.no_grad()
def noise_from_model(x, t, dt, sigma_in, cond_in, cfg_scale, dnw, skip, skip_sdxl_vector):

if cfg_scale == 1: # Case where denoised_uncond should not be calculated - 50% speedup, also good for sdxl in experiments
x_in = x
sigma_in = sigma_in[1:2]
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
cond_in = {k:[v[0][1:2]] for k, v in cond_in.items()}
if shared.sd_model.is_sdxl:
num_classes_hack = shared.sd_model.model.diffusion_model.num_classes
if skip_sdxl_vector:
shared.sd_model.model.diffusion_model.num_classes = None
cond_in["vector"][0] = None
try:
eps = shared.sd_model.model(x_in * c_in, t[1:2], {"crossattn": cond_in["crossattn"][0], "vector": cond_in["vector"][0]})
finally:
shared.sd_model.model.diffusion_model.num_classes = num_classes_hack
else:
d = (x - denoised) / sigmas[i - 1]
eps = shared.sd_model.apply_model(x_in * c_in, t[1:2], cond=cond_in)

dt = sigmas[i] - sigmas[i - 1]
x = x + d * dt
return -eps * c_out * dt
else:
x_in = torch.cat([x] * 2)

sd_samplers_common.store_latent(x)
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]

# This shouldn't be necessary, but solved some VRAM issues
del x_in, sigma_in, cond_in, c_out, c_in, t,
del eps, denoised_uncond, denoised_cond, denoised, d, dt
if shared.sd_model.is_sdxl:
num_classes_hack = shared.sd_model.model.diffusion_model.num_classes
if skip_sdxl_vector:
shared.sd_model.model.diffusion_model.num_classes = None
cond_in["vector"][0] = None
try:
eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["crossattn"][0], "vector": cond_in["vector"][0]})
finally:
shared.sd_model.model.diffusion_model.num_classes = num_classes_hack
else:
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)

shared.state.nextjob()
denoised_uncond, denoised_cond = (eps * c_out).chunk(2)

denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale

return -denoised * dt

return x / sigmas[-1]
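The `num_classes` juggling above is how the script skips SDXL's pooled-vector conditioning: the UNet applies its label embedding only while `num_classes` is set, so the attribute is cleared for the call and restored in a `finally`. The same idea, sketched as a context manager (a hypothetical helper, not in the diff):

```python
import contextlib

@contextlib.contextmanager
def skip_vector_conditioning(diffusion_model):
    saved = diffusion_model.num_classes
    diffusion_model.num_classes = None  # UNet now ignores the "vector" input
    try:
        yield
    finally:
        diffusion_model.num_classes = saved  # always restore, even on error
```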
Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment", "second_order_correction", "skip_sdxl_vector"])


class Script(scripts.Script):
@@ -133,31 +162,38 @@ def ui(self, is_img2img):
* `CFG Scale` should be 2 or lower.
''')

override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler? (this method is built for it)", value=True, elem_id=self.elem_id("override_sampler"))
override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler? (this method is built for it)", value=False, elem_id=self.elem_id("override_sampler"))

override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`? (and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt"))
override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`? (and `negative prompt`)", value=False, elem_id=self.elem_id("override_prompt"))
original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt"))
original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt"))

override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps"))
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st"))
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=20, elem_id=self.elem_id("st"))

override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength"))

cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg"))
randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness"))
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=True, elem_id=self.elem_id("sigma_adjustment"))
second_order_correction = gr.Slider(label="Correct noise by running model again", minimum=0.0, maximum=1.0, step=0.01, value=0.5, elem_id=self.elem_id("second_order_correction"),
info="use 0 (disabled) for original script behaviour, 0.5 reccomended value. Runs the model again to recalculate noise and correct it by given factor. Higher adheres to original image more.")
noise_sigma_intensity = gr.Slider(label="Weight scaling std vs sigma based", minimum=-1.0, maximum=2.0, step=0.01, value=0.5, elem_id=self.elem_id("noise_sigma_intensity"),
info="use 1 for original script behaviour, 0.5 reccomended value. Decides whether to use fixed sigma value or dynamic standard deviation to scale noise. Lower gives softer images.")
skip_sdxl_vector = gr.Checkbox(label="Skip sdxl vectors", info="may cause distortion if false", value=True, elem_id=self.elem_id("skip_sdxl_vector"))

return [
info,
override_sampler,
override_prompt, original_prompt, original_negative_prompt,
override_steps, st,
override_strength,
cfg, randomness, sigma_adjustment,
cfg, randomness, sigma_adjustment, second_order_correction,
noise_sigma_intensity, skip_sdxl_vector
]

def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
@torch.no_grad()
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment, second_order_correction, noise_sigma_intensity, skip_sdxl_vector):
# Override
if override_sampler:
p.sampler_name = "Euler"
@@ -175,33 +211,46 @@ def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subs
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
and self.cache.original_prompt == original_prompt \
and self.cache.original_negative_prompt == original_negative_prompt \
and self.cache.sigma_adjustment == sigma_adjustment
and self.cache.sigma_adjustment == sigma_adjustment \
and self.cache.second_order_correction == second_order_correction \
and self.cache.skip_sdxl_vector == skip_sdxl_vector

same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100

rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

if same_everything:
rec_noise = self.cache.noise
rec_noise, sigma_val = self.cache.noise
else:
# This prevents a crash, because I don't know how to access the underlying .diffusion_model yet when controlnet is enabled.
# modules.sd_unet -> we're good
# scripts.hook -> we're cooked
if "scripts.hook" in str(shared.sd_model.model.diffusion_model.forward.__module__):
print("turn off any controlnets, do 1 pass and then turn controlnet back on to cache noise")
p.steps = 1
return sd_samplers.create_sampler(p.sampler_name, p.sd_model).sample_img2img(p, p.init_latent, rand_noise, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)

shared.state.job_count += 1
cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
if sigma_adjustment:
rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st)
rec_noise, sigma_val = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st, second_order_correction, skip_sdxl_vector)
else:
rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
rec_noise, sigma_val = find_noise_for_image(p, cond, uncond, cfg, st, skip_sdxl_vector)
self.cache = Cached((rec_noise, sigma_val), cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment, second_order_correction, skip_sdxl_vector)

rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
rec_noise = rec_noise / (rec_noise.std()*(1 - noise_sigma_intensity) + sigma_val*noise_sigma_intensity)

combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)

sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)

p.seed = p.seed + 1

sigmas = sampler.model_wrap.get_sigmas(p.steps)

noise_dt = combined_noise - (p.init_latent / sigmas[0])

p.seed = p.seed + 1

return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)

p.sample = sample_extra
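Two rescalings happen above before sampling: the reconstructed noise is normalised by a blend of its own standard deviation and the schedule's final sigma (controlled by `noise_sigma_intensity`), and the result is mixed with fresh random noise in a variance-preserving way. A condensed sketch of both (the function name is hypothetical):

```python
import torch

def combine_noise(rec_noise: torch.Tensor, rand_noise: torch.Tensor,
                  sigma_val: float, intensity: float, randomness: float) -> torch.Tensor:
    # interpolate between std-based (intensity=0) and sigma-based (intensity=1) scaling
    rec = rec_noise / (rec_noise.std() * (1 - intensity) + sigma_val * intensity)
    # variance-preserving mix: the denominator keeps the combined std roughly constant
    return ((1 - randomness) * rec + randomness * rand_noise) / (
        (randomness ** 2 + (1 - randomness) ** 2) ** 0.5)
```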