Skip to content

Instantly share code, notes, and snippets.

@kastnerkyle
Forked from trygvebw/find_noise.py
Created September 12, 2022 00:00
Show Gist options
  • Save kastnerkyle/0092c1321df080bdb4467db8adfe2453 to your computer and use it in GitHub Desktop.
Save kastnerkyle/0092c1321df080bdb4467db8adfe2453 to your computer and use it in GitHub Desktop.
A "reverse" version of the k_euler sampler for Stable Diffusion, which finds the noise that will reconstruct the supplied image
import torch
import k_diffusion as K
from PIL import Image
from torch import autocast
from einops import rearrange, repeat
def pil_img_to_latent(model, img, batch_size=1, device='cuda', half=True):
    """Encode an image into the model's first-stage latent representation.

    Args:
        model: Object providing ``encode_first_stage`` and
            ``get_first_stage_encoding``.
        img: Image to encode; handed unchanged to ``pil_img_to_torch``.
        batch_size: Number of copies of the image along the leading axis.
        device: Device the image tensor is moved to before encoding.
        half: Forwarded to ``pil_img_to_torch``; when true the batch is also
            cast with ``.half()`` before it is encoded.

    Returns:
        Whatever ``model.get_first_stage_encoding`` returns for the encoded
        batch.
    """
    pixels = pil_img_to_torch(img, half=half).to(device)

    # Expand the singleton leading axis to ``batch_size`` entries.
    batch = repeat(pixels, '1 ... -> b ...', b=batch_size)
    if half:
        batch = batch.half()

    encoded = model.encode_first_stage(batch)
    return model.get_first_stage_encoding(encoded)
def find_noise_for_image(model, pil_img, prompt, steps=150, cond_scale=1.0, verbose=False, half=True, normalize=False):
    """Find the noise latent that reconstructs ``pil_img`` for ``prompt``.

    Applies Euler-style update steps along the flipped sigma schedule,
    starting from the encoded image, so the latent is moved step by step
    along the schedule in reversed order.

    Args:
        model: Diffusion model providing ``get_learned_conditioning`` and
            ``apply_model``; it is also wrapped in
            ``K.external.CompVisDenoiser``.
        pil_img: Image to invert; passed unchanged to ``pil_img_to_latent``.
        prompt: Text used for the conditional prediction. The unconditional
            prediction always uses the empty string.
        steps: Step count handed to the denoiser's ``get_sigmas``.
        cond_scale: Weight blending the two predictions: ``0`` keeps only
            the unconditional one, ``1`` only the conditional one.
        verbose: When true, print the flipped sigma schedule.
        half: Forwarded to ``pil_img_to_latent``.
        normalize: When true, rescale the result so that its standard
            deviation equals the last sigma of the flipped schedule.

    Returns:
        The latent tensor after the last step (rescaled if ``normalize``).
    """
    # ``trange`` is not imported in the file's import block, so resolve it
    # here; fall back to the built-in ``range`` (no progress bar) when tqdm
    # is not installed.
    try:
        from tqdm.auto import trange
    except ImportError:
        trange = range

    x = pil_img_to_latent(model, pil_img, batch_size=1, device='cuda', half=half)

    with torch.no_grad(), autocast('cuda'):
        uncond = model.get_learned_conditioning([''])
        cond = model.get_learned_conditioning([prompt])

    s_in = x.new_ones([x.shape[0]])
    dnw = K.external.CompVisDenoiser(model)
    sigmas = dnw.get_sigmas(steps).flip(0)

    if verbose:
        print(sigmas)

    with torch.no_grad(), autocast('cuda'):
        # The conditioning pair never changes, so build it once.
        cond_in = torch.cat([uncond, cond])

        for i in trange(1, len(sigmas)):
            # Duplicate the latent so a single model call yields both the
            # unconditional and the conditional prediction.
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigmas[i] * s_in] * 2)

            c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
            t = dnw.sigma_to_t(sigma_in)

            eps = model.apply_model(x_in * c_in, t, cond=cond_in)
            denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
            denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale

            # Euler step from sigmas[i - 1] to sigmas[i].
            d = (x - denoised) / sigmas[i]
            dt = sigmas[i] - sigmas[i - 1]
            x = x + d * dt

        if normalize:
            return (x / x.std()) * sigmas[-1]
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment