Modification of @crowsonkb's script for use with dalle2-pytorch.
# ======================================================================
# NOTE: You will need to install some stuff before you begin...
# ======================================================================
# clone https://github.com/crowsonkb/deep-image-prior
# pip install resize-right (lucidrains' pip packaging of ResizeRight)
# pip install madgrad
# ======================================================================
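# One possible setup, assuming you run this script from the directory that
# will contain the clone (dalle2-pytorch, kornia, and OpenAI's CLIP are also
# needed by the imports below):
#   git clone https://github.com/crowsonkb/deep-image-prior
#   pip install resize-right madgrad dalle2-pytorch kornia
#   pip install git+https://github.com/openai/CLIP.git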
import sys
sys.path.append('./deep-image-prior')
from models import *
from resize_right import resize
from madgrad import MADGRAD
# ======================================================================
from utils.sr_utils import *

import argparse
import math
import random

import clip
import kornia.augmentation as K
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from torch import nn, optim
from torch.nn import functional as F
from tqdm import tqdm, trange
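
# ReplaceGrad passes x_forward through unchanged on the forward pass but routes
# the incoming gradient to x_backward (summed to its shape); Prompt uses it
# below to implement the `stop` loss threshold.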
class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply
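
# MakeCutouts pads the image with its per-channel mean color, takes `cutn`
# random square crops (the last quarter of them at the full long-side size),
# resizes each crop to CLIP's input resolution with resize-right, and applies
# Kornia augmentations to the whole batch of cutouts.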
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        # self.cut_pow = cut_pow
        self.augs = T.Compose([
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, translate=0.1, p=0.8,
                           padding_mode='border', resample='bilinear'),
            K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
            K.ColorJitter(brightness=0.1, contrast=0.1,
                          saturation=0.1, hue=0.1, p=0.7),
            K.RandomGrayscale(p=0.15),
        ])

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        long_size, short_size = max(sideX, sideY), min(sideX, sideY)
        min_size = min(short_size, self.cut_size)
        pad_x, pad_y = long_size - sideX, long_size - sideY
        input_zero_padded = F.pad(input, (pad_x, pad_x, pad_y, pad_y), 'constant')
        input_mask = F.pad(torch.zeros_like(input), (pad_x, pad_x, pad_y, pad_y), 'constant', 1.)
        input_padded = input_zero_padded + input_mask * input.mean(dim=[2, 3], keepdim=True)
        cutouts = []
        for cn in range(self.cutn):
            if cn >= self.cutn - self.cutn // 4:
                size = long_size
            else:
                size = clamp(int(short_size * torch.zeros([]).normal_(mean=.8, std=.3)),
                             min_size, long_size)
                # size = int(torch.rand([])**self.cut_pow * (short_size - min_size) + min_size)
            offsetx = torch.randint(min(0, sideX - size), abs(sideX - size) + 1, ()) + pad_x
            offsety = torch.randint(min(0, sideY - size), abs(sideY - size) + 1, ()) + pad_y
            cutout = input_padded[:, :, offsety:offsety + size, offsetx:offsetx + size]
            # cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
            cutouts.append(resize(cutout, out_shape=(self.cut_size, self.cut_size),
                                  by_convs=True, pad_mode='reflect'))
        return self.augs(torch.cat(cutouts))
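
# Prompt holds a target CLIP embedding. forward() computes a spherical distance
# loss between the normalized image embeddings and the target, scaled by
# `weight`; the replace_grad / `stop` mechanism cuts the gradient once the
# distance falls below `stop` (a no-op with the default stop of -inf).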
class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()


def clamp(x, min_value, max_value):
    return max(min(x, max_value), min_value)
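
# Deep Image Prior skip network: a randomly initialized conv net whose
# parameters (not its input) are optimized so that its output image matches the
# CLIP target. get_hq_skip_net and the offset/non-offset parameter helpers used
# later are star-imported above from the cloned deep-image-prior repo.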
def load_dip(input_depth, num_scales, offset_type, device):
    dip_net = get_hq_skip_net(
        input_depth,
        skip_n33d=192,
        skip_n33u=192,
        skip_n11=4,
        num_scales=num_scales,
        offset_type=offset_type
    ).to(device)

    return dip_net
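
# The DiffusionPriorNetwork settings below must match the ones used to train
# the checkpoint being loaded; dim=768 matches the CLIP ViT-L/14 embedding size.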
def load_prior(path, device):
    """
    You may need to tweak this manually for now if your model is different
    """
    # load in the diffusion prior
    state_dict = torch.load(path, map_location=device)['model']

    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=6,
        dim_head=64,
        heads=8,
        normformer=False
    ).to(device)

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=None,
        image_embed_dim=768,
        timesteps=100,
        cond_drop_prob=0.1,
        loss_type="l2",
    ).to(device)

    diffusion_prior.load_state_dict(state_dict, strict=True)

    return diffusion_prior
def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the text prompt')
    p.add_argument('--clip-model', type=str, default='ViT-L/14',
                   help='the CLIP model to use')
    p.add_argument('--cutn', type=int, default=32,
                   help='the number of random crops per iteration, per CLIP model')
    p.add_argument('--iterations', '-i', type=int, default=500,
                   help='the number of iterations')
    p.add_argument('--lr', type=float, default=1e-3,
                   help='the learning rate')
    p.add_argument('--lr-decay', type=float, default=0.995,
                   help='the learning rate decay coefficient')
    p.add_argument('--param-noise-strength', type=float, default=0.,
                   help='the starting parameter noise strength')
    p.add_argument('--input-noise-strength', type=float, default=0.,
                   help='the starting input noise strength')
    p.add_argument('--offset-type', type=str, choices=['full', '1x1', 'none'], default='none',
                   help='the offset layer for the deformable convolutions (none to disable them)')
    p.add_argument('--offset-lr-fac', type=float, default=1.,
                   help='the learning rate factor for the offset layers')
    p.add_argument('--display-freq', type=int, default=25,
                   help='display every this many steps')
    p.add_argument('--size', '-s', type=int, nargs=2, default=[512, 512],
                   help='the output image size')
    p.add_argument('--num-scales', type=int, default=7,
                   help='the number of Deep Image Prior feature map scales')
    p.add_argument('--seed', type=int, default=None,
                   help='the random seed')
    p.add_argument('--gpu', type=int, default=0, help="Which gpu to use")
    p.add_argument('--model', type=str, default='model.pth',
                   help="Path to your diffusion prior checkpoint.")
    p.add_argument('--optimizer-type', default="MADGRAD", type=str, choices=['MADGRAD', 'ADAM'],
                   help="Choose optimizer for DIP network. [MADGRAD | ADAM]")
    args = p.parse_args()
    # ==================================================================
    # Set some stuff up for later
    # ==================================================================

    # get device
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # random seed
    seed = random.randint(0, 2**32 - 1) if args.seed is None else args.seed
    print('Using random seed:', seed)

    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    # set resolution
    sideX, sideY = args.size  # Resolution

    # load the clip model for guidance
    clip_model = clip.load(args.clip_model, device=device)[0]
    clip_model = clip_model.eval().requires_grad_(False)
    clip_size = clip_model.visual.input_resolution
    make_cutouts = MakeCutouts(clip_size, args.cutn)
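    # CLIP's published image normalization statistics (per-channel mean/std)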
    normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                            std=[0.26862954, 0.26130258, 0.27577711])

    # iterations
    iterations = 0
    # ==================================================================
    # Model Creation

    diffusion_prior = load_prior(path=args.model, device=device)

    # Initialize DIP skip network
    input_depth = 32
    net = load_dip(input_depth=input_depth, num_scales=args.num_scales,
                   offset_type=args.offset_type, device=device)

    # Initialize input noise
    input_scale = 0.1
    net_input = torch.randn([1, input_depth, sideY, sideX], device=device)
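    # net_input is the fixed latent noise that the DIP network decodes into the
    # output image; only the network's parameters are optimized, not this tensor.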
    # encode text
    text_embed = clip_model.encode_text(clip.tokenize(args.prompt).to(device)).float()

    # get some noise
    noise = torch.randn_like(text_embed)

    # sample from the diffusion prior
    text_cond = dict(text_embed=text_embed)

    # NOTE: dimension is hard coded here, your model may vary
    target_embed = diffusion_prior.p_sample_loop((1, 768), text_cond=text_cond)
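    # target_embed is a CLIP *image* embedding sampled from the diffusion prior
    # conditioned on the text embedding; the DIP image is optimized toward it
    # rather than toward the raw text embedding.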
    # store prompts for use while sampling
    prompts = [Prompt(target_embed)]
    # make a param config for the DIP net
    params = [{'params': get_non_offset_params(net), 'lr': args.lr},
              {'params': get_offset_params(net), 'lr': args.lr * args.offset_lr_fac}]

    # select optimizer
    if args.optimizer_type == 'ADAM':
        opt = optim.Adam(params, args.lr)
    elif args.optimizer_type == 'MADGRAD':
        opt = MADGRAD(params, args.lr, momentum=0.9)

    # define scaler for DIP network
    scaler = torch.cuda.amp.GradScaler()
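
    # create the output directories used by the save calls below, in case they
    # don't exist yet
    import os
    os.makedirs('samples', exist_ok=True)
    os.makedirs('results', exist_ok=True)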
    # attempt to get the money
    try:
        for _ in trange(args.iterations):
            opt.zero_grad(set_to_none=True)
            noise_ramp = 1 - min(1, iterations / args.iterations)
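
            # optionally mix the fixed latent with fresh Gaussian noise on a
            # quarter circle; the mixing angle phi anneals to zero over training
            # as noise_ramp goes from 1 to 0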
            net_input_noised = net_input
            if args.input_noise_strength:
                phi = min(1, noise_ramp * args.input_noise_strength) * math.pi / 2
                noise = torch.randn_like(net_input)
                net_input_noised = net_input * math.cos(phi) + noise * math.sin(phi)

            with torch.cuda.amp.autocast():
                out = net(net_input_noised * input_scale).float()

            losses = []
            cutouts = normalize(make_cutouts(out))
            with torch.cuda.amp.autocast(False):
                image_embeds = clip_model.encode_image(cutouts).float()
            for prompt in prompts:
                losses.append(prompt(image_embeds))  # * clip_model.weight
            loss = sum(losses, out.new_zeros([]))

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
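
            # optionally perturb the DIP parameters with Gaussian noise scaled
            # by the learning rate; like the input noise, it anneals to zero
            # over the course of training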
            if args.param_noise_strength:
                with torch.no_grad():
                    noise_ramp = 1 - min(1, iterations / args.iterations)
                    for group in opt.param_groups:
                        for param in group['params']:
                            param += torch.randn_like(param) * group['lr'] * \
                                args.param_noise_strength * noise_ramp

            iterations += 1

            if iterations % args.display_freq == 0:
                with torch.inference_mode():
                    image = TF.to_pil_image(out[0].clamp(0, 1))

                losses_str = ', '.join([f'{loss.item():g}' for loss in losses])
                tqdm.write(f'i: {iterations}, loss: {loss.item():g}, losses: {losses_str}')

                if iterations < args.iterations:
                    image.save(f'samples/out_{iterations:05}.png')
                else:
                    # money got
                    image.save(f"results/{args.prompt}.png")

            for group in opt.param_groups:
                group['lr'] = args.lr_decay * group['lr']
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()