Modification of @crowsonkb's script for use with dalle2-pytorch.
# ======================================================================
# NOTE: You will need to install some stuff before you begin...
# ======================================================================
# clone https://github.com/crowsonkb/deep-image-prior
# pip install lucid's package of resize-right
# pip install madgrad
# ======================================================================
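# A rough sketch of those setup steps (assumed commands; package names and the
# working-directory layout may need adjusting for your environment):
#   git clone https://github.com/crowsonkb/deep-image-prior
#   pip install resize-right      # lucid's packaging of ResizeRight
#   pip install madgrad
#   # dalle2-pytorch, clip, and kornia (imported below) are assumed installed
# ======================================================================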
import sys
sys.path.append('./deep-image-prior')
from models import *
from resize_right import resize
from madgrad import MADGRAD
# ======================================================================
from utils.sr_utils import *
import argparse
import math
import os
import random
import clip
import kornia.augmentation as K
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from torch import nn, optim
from torch.nn import functional as F
from tqdm import tqdm, trange
class ReplaceGrad(torch.autograd.Function):
    # passes x_forward through unchanged, but routes the gradient to x_backward;
    # used below to clamp the prompt loss at its `stop` value
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply

class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        # self.cut_pow = cut_pow
        self.augs = T.Compose([
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, translate=0.1, p=0.8,
                           padding_mode='border', resample='bilinear'),
            K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
            K.ColorJitter(brightness=0.1, contrast=0.1,
                          saturation=0.1, hue=0.1, p=0.7),
            K.RandomGrayscale(p=0.15),
        ])

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        long_size, short_size = max(sideX, sideY), min(sideX, sideY)
        min_size = min(short_size, self.cut_size)
        pad_x, pad_y = long_size - sideX, long_size - sideY
        input_zero_padded = F.pad(input, (pad_x, pad_x, pad_y, pad_y), 'constant')
        input_mask = F.pad(torch.zeros_like(input), (pad_x, pad_x, pad_y, pad_y), 'constant', 1.)
        input_padded = input_zero_padded + input_mask * input.mean(dim=[2, 3], keepdim=True)
        cutouts = []
        for cn in range(self.cutn):
            if cn >= self.cutn - self.cutn // 4:
                size = long_size
            else:
                size = clamp(int(short_size * torch.zeros([]).normal_(mean=.8, std=.3)),
                             min_size, long_size)
                # size = int(torch.rand([])**self.cut_pow * (short_size - min_size) + min_size)
            offsetx = torch.randint(min(0, sideX - size), abs(sideX - size) + 1, ()) + pad_x
            offsety = torch.randint(min(0, sideY - size), abs(sideY - size) + 1, ()) + pad_y
            cutout = input_padded[:, :, offsety:offsety + size, offsetx:offsetx + size]
            # cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
            cutouts.append(resize(cutout, out_shape=(self.cut_size, self.cut_size),
                                  by_convs=True, pad_mode='reflect'))
        return self.augs(torch.cat(cutouts))

class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        # spherical distance loss (proportional to the squared angle between the embeddings)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()

def clamp(x, min_value, max_value):
    return max(min(x, max_value), min_value)

def load_dip(input_depth, num_scales, offset_type, device):
    dip_net = get_hq_skip_net(
        input_depth,
        skip_n33d=192,
        skip_n33u=192,
        skip_n11=4,
        num_scales=num_scales,
        offset_type=offset_type
    ).to(device)

    return dip_net

def load_prior(path, device):
    """
    You may need to tweak this manually for now if your model is different.
    """
    # load in the diffusion prior
    state_dict = torch.load(path, map_location=device)['model']

    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=6,
        dim_head=64,
        heads=8,
        normformer=False
    ).to(device)

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=None,
        image_embed_dim=768,
        timesteps=100,
        cond_drop_prob=0.1,
        loss_type="l2",
    ).to(device)

    diffusion_prior.load_state_dict(state_dict, strict=True)

    return diffusion_prior
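# The DiffusionPriorNetwork/DiffusionPrior settings above are assumed to match
# the checkpoint this script was written against; dim=768 corresponds to the
# CLIP ViT-L/14 embedding size. If your prior was trained with a different
# configuration, mirror its training settings here.
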
def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the text prompt')
    p.add_argument('--clip-model', type=str, default='ViT-L/14',
                   help='the CLIP model to use')
    p.add_argument('--cutn', type=int, default=32,
                   help='the number of random crops per iteration, per CLIP model')
    p.add_argument('--iterations', '-i', type=int, default=500,
                   help='the number of iterations')
    p.add_argument('--lr', type=float, default=1e-3,
                   help='the learning rate')
    p.add_argument('--lr-decay', type=float, default=0.995,
                   help='the learning rate decay coefficient')
    p.add_argument('--param-noise-strength', type=float, default=0.,
                   help='the starting parameter noise strength')
    p.add_argument('--input-noise-strength', type=float, default=0.,
                   help='the starting input noise strength')
    p.add_argument('--offset-type', type=str, choices=['full', '1x1', 'none'], default='none',
                   help='the offset layer for the deformable convolutions (none to disable them)')
    p.add_argument('--offset-lr-fac', type=float, default=1.,
                   help='the learning rate factor for the offset layers')
    p.add_argument('--display-freq', type=int, default=25,
                   help='display every this many steps')
    p.add_argument('--size', '-s', type=int, nargs=2, default=[512, 512],
                   help='the output image size')
    p.add_argument('--num-scales', type=int, default=7,
                   help='the number of Deep Image Prior feature map scales')
    p.add_argument('--seed', type=int, default=None,
                   help='the random seed')
    p.add_argument('--gpu', type=int, default=0,
                   help='which gpu to use')
    p.add_argument('--model', type=str, default='model.pth',
                   help='path to your diffusion prior checkpoint')
    p.add_argument('--optimizer-type', default='MADGRAD', type=str, choices=['MADGRAD', 'ADAM'],
                   help='the optimizer for the DIP network [MADGRAD | ADAM]')
    args = p.parse_args()

    # ==================================================================
    # Set some stuff up for later
    # ==================================================================
    # get device
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # random seed
    seed = random.randint(0, 2**32) if args.seed is None else args.seed
    print('Using random seed:', seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    # set resolution
    sideX, sideY = args.size

    # make sure the output directories exist (the loop below saves into them)
    os.makedirs('samples', exist_ok=True)
    os.makedirs('results', exist_ok=True)

    # load the clip model for guidance
    clip_model = clip.load(args.clip_model, device=device)[0]
    clip_model = clip_model.eval().requires_grad_(False)
    clip_size = clip_model.visual.input_resolution
    make_cutouts = MakeCutouts(clip_size, args.cutn)
    normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                            std=[0.26862954, 0.26130258, 0.27577711])

    # iteration counter
    iterations = 0

    # ==================================================================
    # Model Creation
    # ==================================================================
    diffusion_prior = load_prior(path=args.model, device=device)

    # Initialize DIP skip network
    input_depth = 32
    net = load_dip(input_depth=input_depth, num_scales=args.num_scales,
                   offset_type=args.offset_type, device=device)

    # Initialize input noise
    input_scale = 0.1
    net_input = torch.randn([1, input_depth, sideY, sideX], device=device)

    # encode text
    text_embed = clip_model.encode_text(clip.tokenize(args.prompt).to(device)).float()

    # get some noise
    noise = torch.randn_like(text_embed)

    # sample an image embedding from the diffusion prior, conditioned on the text embedding
    text_cond = dict(text_embed=text_embed)
    # NOTE: dimension is hard coded here, your model may vary
    target_embed = diffusion_prior.p_sample_loop((1, 768), text_cond=text_cond)

    # store prompts for use while sampling
    prompts = [Prompt(target_embed)]

    # make a param config for the DIP net
    params = [{'params': get_non_offset_params(net), 'lr': args.lr},
              {'params': get_offset_params(net), 'lr': args.lr * args.offset_lr_fac}]

    # select optimizer
    if args.optimizer_type == 'ADAM':
        opt = optim.Adam(params, args.lr)
    elif args.optimizer_type == 'MADGRAD':
        opt = MADGRAD(params, args.lr, momentum=0.9)

    # define scaler for DIP network
    scaler = torch.cuda.amp.GradScaler()

    # attempt to get the money
    try:
        for _ in trange(args.iterations):
            opt.zero_grad(set_to_none=True)
            noise_ramp = 1 - min(1, iterations / args.iterations)

            net_input_noised = net_input
            if args.input_noise_strength:
                phi = min(1, noise_ramp * args.input_noise_strength) * math.pi / 2
                noise = torch.randn_like(net_input)
                net_input_noised = net_input * math.cos(phi) + noise * math.sin(phi)

            with torch.cuda.amp.autocast():
                out = net(net_input_noised * input_scale).float()

            losses = []
            cutouts = normalize(make_cutouts(out))
            with torch.cuda.amp.autocast(False):
                image_embeds = clip_model.encode_image(cutouts).float()
            for prompt in prompts:
                losses.append(prompt(image_embeds))  # * clip_model.weight)

            loss = sum(losses, out.new_zeros([]))
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            if args.param_noise_strength:
                with torch.no_grad():
                    noise_ramp = 1 - min(1, iterations / args.iterations)
                    for group in opt.param_groups:
                        for param in group['params']:
                            param += torch.randn_like(param) * group['lr'] * \
                                args.param_noise_strength * noise_ramp

            iterations += 1

            if iterations % args.display_freq == 0:
                with torch.inference_mode():
                    image = TF.to_pil_image(out[0].clamp(0, 1))
                    if iterations % args.display_freq == 0:
                        losses_str = ', '.join([f'{loss.item():g}' for loss in losses])
                        tqdm.write(f'i: {iterations}, loss: {loss.item():g}, losses: {losses_str}')
                        if iterations < args.iterations:
                            image.save(f'samples/out_{iterations:05}.png')
                        else:
                            # money got
                            image.save(f"results/{args.prompt}.png")

            for group in opt.param_groups:
                group['lr'] = args.lr_decay * group['lr']
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
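# A hypothetical invocation (the script filename, checkpoint path, and prompt
# below are examples only; the checkpoint is expected to store its weights
# under a 'model' key, as loaded in load_prior above):
#
#   python dip_dalle2.py "a matte painting of a castle at sunset" \
#       --model prior_checkpoint.pth --iterations 500 --size 512 512 --gpu 0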