#!/usr/bin/env python3
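"""Generate an image from a text prompt by optimizing a Deep Image Prior network under CLIP
guidance, using a diffusion prior to map the CLIP text embedding to a target image embedding."""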
import argparse
from collections import namedtuple
from copy import deepcopy
import io
import json
import math
import pickle
import random
import sys
import time
from einops import rearrange
import kornia.augmentation as K
from madgrad import MADGRAD
import numpy as np
from PIL import Image
import requests
import torch
from torch import nn, optim
from torch.nn import functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as T
from tqdm import trange, tqdm
from CLIP import clip
sys.path.append('./deep-image-prior')
from models import *
from utils.sr_utils import *
sys.path.append('./ResizeRight')
from resize_right import resize, interp_methods
sys.path.append('./v-diffusion-pytorch')
from diffusion import sampling, utils as diff_utils
import clip_prior_1


class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply


def clamp(x, min_value, max_value):
    return max(min(x, max_value), min_value)
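
# MakeCutouts pads the input image to a square filled with its mean color, then takes cutn random
# square crops: the last quarter are full-frame, the rest are sampled around 0.8 * short_size, and
# all are resized to the CLIP input resolution and augmented together as a batch.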
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        # self.cut_pow = cut_pow
        self.augs = T.Compose([
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, translate=0.1, p=0.8, padding_mode='border', resample='bilinear'),
            K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
            K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
            K.RandomGrayscale(p=0.15),
        ])

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        long_size, short_size = max(sideX, sideY), min(sideX, sideY)
        min_size = min(short_size, self.cut_size)
        pad_x, pad_y = long_size - sideX, long_size - sideY
        input_zero_padded = F.pad(input, (pad_x, pad_x, pad_y, pad_y), 'constant')
        input_mask = F.pad(torch.zeros_like(input), (pad_x, pad_x, pad_y, pad_y), 'constant', 1.)
        input_padded = input_zero_padded + input_mask * input.mean(dim=[2, 3], keepdim=True)
        cutouts = []
        for cn in range(self.cutn):
            if cn >= self.cutn - self.cutn // 4:
                size = long_size
            else:
                size = clamp(int(short_size * torch.zeros([]).normal_(mean=.8, std=.3)), min_size, long_size)
                # size = int(torch.rand([])**self.cut_pow * (short_size - min_size) + min_size)
            offsetx = torch.randint(min(0, sideX - size), abs(sideX - size) + 1, ()) + pad_x
            offsety = torch.randint(min(0, sideY - size), abs(sideY - size) + 1, ()) + pad_y
            cutout = input_padded[:, :, offsety:offsety + size, offsetx:offsetx + size]
            # cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
            cutouts.append(resize(cutout, out_shape=(self.cut_size, self.cut_size), by_convs=True, pad_mode='reflect'))
        return self.augs(torch.cat(cutouts))
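
# Prompt scores a batch of CLIP image embeddings against a target embedding with half the squared
# great-circle distance between the L2-normalized vectors, 2 * arcsin(||a - b|| / 2)**2; a negative
# weight flips the sign, and replace_grad routes gradients through torch.maximum(dists, stop), so
# gradients stop flowing once a distance falls below the stop threshold (never, with the default -inf).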
class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()


def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')


def parse_prompt(prompt):
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 3)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])


def resize_image(image, out_size):
    ratio = image.size[0] / image.size[1]
    area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
    size = round((area * ratio)**0.5), round((area / ratio)**0.5)
    return image.resize(size, Image.LANCZOS)


class CaptureOutput:
    """Captures a layer's output activations using a forward hook."""

    def __init__(self, module):
        self.output = None
        self.handle = module.register_forward_hook(self)

    def __call__(self, module, input, output):
        self.output = output

    def __del__(self):
        self.handle.remove()

    def get_output(self):
        return self.output


def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the text prompt')
    # p.add_argument('--images', type=str, default=[], nargs='*', metavar='IMAGE',
    #                help='the image prompts')
    p.add_argument('--clip-model', type=str, default='ViT-B/32',
                   help='the CLIP model to use')
    p.add_argument('--cutn', type=int, default=32,
                   help='the number of random crops per iteration, per CLIP model')
    p.add_argument('--iterations', '-i', type=int, default=500,
                   help='the number of iterations')
    p.add_argument('--lr', type=float, default=1e-3,
                   help='the learning rate')
    p.add_argument('--lr-decay', type=float, default=0.995,
                   help='the learning rate decay coefficient')
    p.add_argument('--param-noise-strength', type=float, default=0.,
                   help='the starting parameter noise strength')
    p.add_argument('--input-noise-strength', type=float, default=0.,
                   help='the starting input noise strength')
    p.add_argument('--offset-type', type=str, choices=['full', '1x1', 'none'], default='none',
                   help='the offset layer for the deformable convolutions (none to disable them)')
    p.add_argument('--offset-lr-fac', type=float, default=1.,
                   help='the learning rate factor for the offset layers')
    p.add_argument('--display-freq', type=int, default=25,
                   help='display every this many steps')
    p.add_argument('--size', '-s', type=int, nargs=2, default=[512, 512],
                   help='the output image size')
    p.add_argument('--num-scales', type=int, default=7,
                   help='the number of Deep Image Prior feature map scales')
    p.add_argument('--seed', type=int, default=None,
                   help='the random seed')
    args = p.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    optimizer_type = 'MADGRAD'  # Adam, MADGRAD
    sideX, sideY = args.size  # Resolution

    clip_model = clip.load(args.clip_model, device=device)[0]
    clip_model = clip_model.eval().requires_grad_(False)
    clip_size = clip_model.visual.input_resolution
    make_cutouts = MakeCutouts(clip_size, args.cutn)
    normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                            std=[0.26862954, 0.26130258, 0.27577711])
    itt = 0

    clip_prior_ckpt = 'clip_prior_1_0030000.pth'
    clip_prior = clip_prior_1.CLIPDiffusionPrior(clip_model.visual.output_dim, 768, 12)
    clip_prior.load_state_dict(torch.load(clip_prior_ckpt, map_location='cpu')['model_ema'])
    clip_prior = clip_prior.to(device).eval().requires_grad_(False)

    seed = random.randint(0, 2**32) if args.seed is None else args.seed
    print('Using random seed:', seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    # Initialize DIP skip network
    input_depth = 32
    net = get_hq_skip_net(
        input_depth,
        skip_n33d=192,
        skip_n33u=192,
        skip_n11=4,
        num_scales=args.num_scales,
        offset_type=args.offset_type
    ).to(device)
    print('Deep Image Prior parameters:', sum(p.numel() for p in net.parameters()))

    # Initialize input noise
    input_scale = 0.1
    net_input = torch.randn([1, input_depth, sideY, sideX], device=device)

    # Encode prompts with CLIP
    # prompts = [[] for _ in clip_models]
    # prompts = []
    # weights_sum = abs(sum([parse_prompt(prompt)[1] for prompt in [*args.prompts, *args.images]]))
    # if weights_sum < 1e-3:
    #     raise RuntimeError('The weights must not sum to 0.')
    # for prompt in args.prompts:
    #     txt, weight, stop = parse_prompt(prompt)
    #     embed = clip_model.encode_text(clip.tokenize(txt).to(device)).float()
    #     prompts.append(Prompt(embed, weight / weights_sum, stop).to(device))
    # for prompt in args.images:
    #     path, weight, stop = parse_prompt(prompt)
    #     img = resize_image(Image.open(fetch(path)).convert('RGB'), (sideX, sideY))
    #     img = TF.to_tensor(img)[None].to(device)
    #     batch = make_cutouts(img)
    #     embed = clip_model.encode_image(normalize(batch)).float()
    #     prompts.append(Prompt(embed, weight / weights_sum, stop).to(device))
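
    # The diffusion prior (weights loaded from clip_prior_1_0030000.pth) maps the CLIP text
    # embedding to a predicted CLIP image embedding: starting from Gaussian noise, it runs 1000
    # DDPM sampling steps conditioned on the text embedding, and the result becomes the single
    # target embedding the cutouts are optimized toward.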
    text_embed = clip_model.encode_text(clip.tokenize(args.prompt).to(device)).float()
    noise = torch.randn_like(text_embed)
    t = torch.linspace(1, 0, 1000 + 1)[:-1]
    steps = diff_utils.get_ddpm_schedule(t)
    target_embed = sampling.sample(clip_prior_1.pred_to_v(clip_prior), noise, steps, 1., {'cond': text_embed})
    prompts = [Prompt(target_embed)]

    params = [{'params': get_non_offset_params(net), 'lr': args.lr},
              {'params': get_offset_params(net), 'lr': args.lr * args.offset_lr_fac}]
    if optimizer_type == 'Adam':
        opt = optim.Adam(params, args.lr)
    elif optimizer_type == 'MADGRAD':
        opt = MADGRAD(params, args.lr, momentum=0.9)

    scaler = torch.cuda.amp.GradScaler()
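
    # Training loop: optionally blend decaying Gaussian noise into the network input, render the
    # image with the DIP network under autocast, embed random cutouts with CLIP, and minimize
    # their spherical distance to the prior's target embedding; optional parameter noise and the
    # learning rate both decay over the course of the run.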
    try:
        for _ in trange(args.iterations):
            opt.zero_grad(set_to_none=True)

            noise_ramp = 1 - min(1, itt / args.iterations)
            net_input_noised = net_input
            if args.input_noise_strength:
                phi = min(1, noise_ramp * args.input_noise_strength) * math.pi / 2
                noise = torch.randn_like(net_input)
                net_input_noised = net_input * math.cos(phi) + noise * math.sin(phi)

            with torch.cuda.amp.autocast():
                out = net(net_input_noised * input_scale).float()

            losses = []
            # for i, clip_model in enumerate(clip_models):
            cutouts = normalize(make_cutouts(out))
            with torch.cuda.amp.autocast(False):
                image_embeds = clip_model.encode_image(cutouts).float()
            for prompt in prompts:
                losses.append(prompt(image_embeds))  # * clip_model.weight

            loss = sum(losses, out.new_zeros([]))
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            if args.param_noise_strength:
                with torch.no_grad():
                    noise_ramp = 1 - min(1, itt / args.iterations)
                    for group in opt.param_groups:
                        for param in group['params']:
                            param += torch.randn_like(param) * group['lr'] * args.param_noise_strength * noise_ramp

            itt += 1

            if itt % args.display_freq == 0:
                with torch.inference_mode():
                    image = TF.to_pil_image(out[0].clamp(0, 1))
                losses_str = ', '.join([f'{loss.item():g}' for loss in losses])
                tqdm.write(f'i: {itt}, loss: {loss.item():g}, losses: {losses_str}')
                image.save(f'out_{itt:05}.png')

            for group in opt.param_groups:
                group['lr'] = args.lr_decay * group['lr']
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
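
# Example invocation (the script filename is a placeholder; adjust it to wherever this gist is saved):
#   python dip_clip_prior.py 'a watercolor painting of a lighthouse at dusk' --iterations 500 --size 512 512
# The script expects the deep-image-prior, ResizeRight, and v-diffusion-pytorch repositories cloned
# next to it, plus the clip_prior_1 module and the clip_prior_1_0030000.pth checkpoint referenced above.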