Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
A "reverse" version of the k_euler sampler for Stable Diffusion, which finds the noise that will reconstruct the supplied image
import torch
import numpy as np
import k_diffusion as K
from PIL import Image
from torch import autocast
from einops import rearrange, repeat
def pil_img_to_torch(pil_img, half=False):
image = np.array(pil_img).astype(np.float32) / 255.0
image = rearrange(torch.from_numpy(image), 'h w c -> c h w')
if half:
image = image.half()
return (2.0 * image - 1.0).unsqueeze(0)
def pil_img_to_latent(model, img, batch_size=1, device='cuda', half=True):
init_image = pil_img_to_torch(img, half=half).to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
if half:
return model.get_first_stage_encoding(model.encode_first_stage(init_image.half()))
return model.get_first_stage_encoding(model.encode_first_stage(init_image))
def find_noise_for_image(model, img, prompt, steps=200, cond_scale=0.0, verbose=False, normalize=True):
x = pil_img_to_latent(img, batch_size=1, device='cuda', half=True)
with torch.no_grad():
with autocast('cuda'):
uncond = model.get_learned_conditioning([''])
cond = model.get_learned_conditioning([prompt])
s_in = x.new_ones([x.shape[0]])
dnw = K.external.CompVisDenoiser(model)
sigmas = dnw.get_sigmas(steps).flip(0)
if verbose:
print(sigmas)
with torch.no_grad():
with autocast('cuda'):
for i in trange(1, len(sigmas)):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
cond_in = torch.cat([uncond, cond])
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
if i == 1:
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
else:
t = dnw.sigma_to_t(sigma_in)
eps = model.apply_model(x_in * c_in, t, cond=cond_in)
denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale
if i == 1:
d = (x - denoised) / (2 * sigmas[i])
else:
d = (x - denoised) / sigmas[i - 1]
dt = sigmas[i] - sigmas[i - 1]
x = x + d * dt
if normalize:
return (x / x.std()) * sigmas[-1]
else:
return x
@newsbubbles
Copy link

newsbubbles commented Sep 14, 2022

What is the k_diffusion library you are using for this? Can I get a github or pypi?

@parsec501
Copy link

parsec501 commented Sep 14, 2022

What is the k_diffusion library you are using for this? Can I get a github or pypi?

It's an alternative sampler

https://github.com/crowsonkb/k-diffusion

@newsbubbles
Copy link

newsbubbles commented Sep 14, 2022

What is the k_diffusion library you are using for this? Can I get a github or pypi?

It's an alternative sampler

https://github.com/crowsonkb/k-diffusion

Thanks a lot.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment