#!/usr/bin/env python3

"""Generates an image from a text prompt using CLIP guidance, a diffusion prior
over CLIP image embeddings, and a Deep Image Prior network as the image
parameterization."""

import argparse
from collections import namedtuple
from copy import deepcopy
import io
import json
import math
import pickle
import random
import sys
import time

from einops import rearrange
import kornia.augmentation as K
from madgrad import MADGRAD
import numpy as np
from PIL import Image
import requests
import torch
from torch import nn, optim
from torch.nn import functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as T
from tqdm import trange, tqdm

from CLIP import clip

sys.path.append('./deep-image-prior')
from models import *
from utils.sr_utils import *

sys.path.append('./ResizeRight')
from resize_right import resize, interp_methods

sys.path.append('./v-diffusion-pytorch')
from diffusion import sampling, utils as diff_utils

import clip_prior_1


class ReplaceGrad(torch.autograd.Function):
    """Passes x_forward through unchanged on the forward pass, but routes the
    gradient to x_backward on the backward pass."""

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply


def clamp(x, min_value, max_value):
    return max(min(x, max_value), min_value)


class MakeCutouts(nn.Module):
    """Extracts random augmented square crops ("cutouts") from an image batch
    for scoring with CLIP."""

    def __init__(self, cut_size, cutn):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        # self.cut_pow = cut_pow
        self.augs = T.Compose([
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, translate=0.1, p=0.8, padding_mode='border', resample='bilinear'),
            K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
            K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
            K.RandomGrayscale(p=0.15),
        ])

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        long_size, short_size = max(sideX, sideY), min(sideX, sideY)
        min_size = min(short_size, self.cut_size)
        # Pad the image out to a square, filling the border with the
        # per-channel mean color.
        pad_x, pad_y = long_size - sideX, long_size - sideY
        input_zero_padded = F.pad(input, (pad_x, pad_x, pad_y, pad_y), 'constant')
        input_mask = F.pad(torch.zeros_like(input), (pad_x, pad_x, pad_y, pad_y), 'constant', 1.)
        input_padded = input_zero_padded + input_mask * input.mean(dim=[2, 3], keepdim=True)
        cutouts = []
        for cn in range(self.cutn):
            # The last quarter of the cutouts cover the whole padded image.
            if cn >= self.cutn - self.cutn // 4:
                size = long_size
            else:
                size = clamp(int(short_size * torch.zeros([]).normal_(mean=.8, std=.3)), min_size, long_size)
                # size = int(torch.rand([])**self.cut_pow * (short_size - min_size) + min_size)
            offsetx = torch.randint(min(0, sideX - size), abs(sideX - size) + 1, ()) + pad_x
            offsety = torch.randint(min(0, sideY - size), abs(sideY - size) + 1, ()) + pad_y
            cutout = input_padded[:, :, offsety:offsety + size, offsetx:offsetx + size]
            # cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
            cutouts.append(resize(cutout, out_shape=(self.cut_size, self.cut_size), by_convs=True, pad_mode='reflect'))
        return self.augs(torch.cat(cutouts))


class Prompt(nn.Module):
    """A spherical distance loss between input embeddings and a target embedding."""

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        # replace_grad clamps the gradient at the stop value: once a distance
        # falls below stop, it no longer contributes gradient.
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
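
# Note on the distance used in Prompt: dists is half the squared geodesic
# distance on the unit hypersphere. For unit vectors u, v with chord length
# c = ||u - v||, the angle between them is theta = 2 * arcsin(c / 2), so the
# quantity 2 * arcsin(c / 2) ** 2 computed above equals theta ** 2 / 2.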


def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')


def parse_prompt(prompt):
    """Parses a prompt of the form 'text:weight:stop', where weight and stop
    are optional and the text may itself be a URL."""
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 3)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])
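
# parse_prompt examples (weight defaults to 1, stop defaults to -inf):
#   parse_prompt('a cat')         -> ('a cat', 1.0, float('-inf'))
#   parse_prompt('a cat:2')       -> ('a cat', 2.0, float('-inf'))
#   parse_prompt('a cat:2:-0.5')  -> ('a cat', 2.0, -0.5)
#   parse_prompt('https://example.com/img.png:0.5')
#                                 -> ('https://example.com/img.png', 0.5, float('-inf'))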


def resize_image(image, out_size):
    ratio = image.size[0] / image.size[1]
    area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
    size = round((area * ratio)**0.5), round((area / ratio)**0.5)
    return image.resize(size, Image.LANCZOS)


class CaptureOutput:
    """Captures a layer's output activations using a forward hook."""

    def __init__(self, module):
        self.output = None
        self.handle = module.register_forward_hook(self)

    def __call__(self, module, input, output):
        self.output = output

    def __del__(self):
        self.handle.remove()

    def get_output(self):
        return self.output
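
# CaptureOutput usage sketch (hypothetical; the class is not used below):
#   cap = CaptureOutput(clip_model.visual)
#   _ = clip_model.encode_image(some_images)
#   activations = cap.get_output()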


def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the text prompt')
    # p.add_argument('--images', type=str, default=[], nargs='*', metavar='IMAGE',
    #                help='the image prompts')
    p.add_argument('--clip-model', type=str, default='ViT-B/32',
                   help='the CLIP model to use')
    p.add_argument('--cutn', type=int, default=32,
                   help='the number of random crops per iteration, per CLIP model')
    p.add_argument('--iterations', '-i', type=int, default=500,
                   help='the number of iterations')
    p.add_argument('--lr', type=float, default=1e-3,
                   help='the learning rate')
    p.add_argument('--lr-decay', type=float, default=0.995,
                   help='the learning rate decay coefficient')
    p.add_argument('--param-noise-strength', type=float, default=0.,
                   help='the starting parameter noise strength')
    p.add_argument('--input-noise-strength', type=float, default=0.,
                   help='the starting input noise strength')
    p.add_argument('--offset-type', type=str, choices=['full', '1x1', 'none'], default='none',
                   help='the offset layer for the deformable convolutions (none to disable them)')
    p.add_argument('--offset-lr-fac', type=float, default=1.,
                   help='the learning rate factor for the offset layers')
    p.add_argument('--display-freq', type=int, default=25,
                   help='display every this many steps')
    p.add_argument('--size', '-s', type=int, nargs=2, default=[512, 512],
                   help='the output image size')
    p.add_argument('--num-scales', type=int, default=7,
                   help='the number of Deep Image Prior feature map scales')
    p.add_argument('--seed', type=int, default=None,
                   help='the random seed')
    args = p.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    optimizer_type = 'MADGRAD'  # Adam, MADGRAD
    sideX, sideY = args.size  # Resolution

    clip_model = clip.load(args.clip_model, device=device)[0]
    clip_model = clip_model.eval().requires_grad_(False)
    clip_size = clip_model.visual.input_resolution
    make_cutouts = MakeCutouts(clip_size, args.cutn)
    normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                            std=[0.26862954, 0.26130258, 0.27577711])

    itt = 0

    clip_prior_ckpt = 'clip_prior_1_0030000.pth'
    clip_prior = clip_prior_1.CLIPDiffusionPrior(clip_model.visual.output_dim, 768, 12)
    clip_prior.load_state_dict(torch.load(clip_prior_ckpt, map_location='cpu')['model_ema'])
    clip_prior = clip_prior.to(device).eval().requires_grad_(False)

    seed = random.randint(0, 2**32) if args.seed is None else args.seed
    print('Using random seed:', seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    # Initialize DIP skip network
    input_depth = 32
    net = get_hq_skip_net(
        input_depth,
        skip_n33d=192,
        skip_n33u=192,
        skip_n11=4,
        num_scales=args.num_scales,
        offset_type=args.offset_type
    ).to(device)
    print('Deep Image Prior parameters:', sum(p.numel() for p in net.parameters()))

    # Initialize input noise
    input_scale = 0.1
    net_input = torch.randn([1, input_depth, sideY, sideX], device=device)

    # Encode prompts with CLIP
    # prompts = [[] for _ in clip_models]
    # prompts = []
    # weights_sum = abs(sum([parse_prompt(prompt)[1] for prompt in [*args.prompts, *args.images]]))
    # if weights_sum < 1e-3:
    #     raise RuntimeError('The weights must not sum to 0.')
    # for prompt in args.prompts:
    #     txt, weight, stop = parse_prompt(prompt)
    #     embed = clip_model.encode_text(clip.tokenize(txt).to(device)).float()
    #     prompts.append(Prompt(embed, weight / weights_sum, stop).to(device))
    # for prompt in args.images:
    #     path, weight, stop = parse_prompt(prompt)
    #     img = resize_image(Image.open(fetch(path)).convert('RGB'), (sideX, sideY))
    #     img = TF.to_tensor(img)[None].to(device)
    #     batch = make_cutouts(img)
    #     embed = clip_model.encode_image(normalize(batch)).float()
    #     prompts.append(Prompt(embed, weight / weights_sum, stop).to(device))
    text_embed = clip_model.encode_text(clip.tokenize(args.prompt).to(device)).float()
    noise = torch.randn_like(text_embed)
    t = torch.linspace(1, 0, 1000 + 1)[:-1]
    steps = diff_utils.get_ddpm_schedule(t)
    target_embed = sampling.sample(clip_prior_1.pred_to_v(clip_prior), noise, steps, 1., {'cond': text_embed})
    prompts = [Prompt(target_embed)]
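
    # Note on the block above: the diffusion prior is sampled from Gaussian
    # noise over a 1000-step DDPM schedule, conditioned on the CLIP text
    # embedding, to produce a predicted CLIP image embedding; that sampled
    # embedding is the sole optimization target for the image cutouts below.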

    params = [{'params': get_non_offset_params(net), 'lr': args.lr},
              {'params': get_offset_params(net), 'lr': args.lr * args.offset_lr_fac}]
    if optimizer_type == 'Adam':
        opt = optim.Adam(params, args.lr)
    elif optimizer_type == 'MADGRAD':
        opt = MADGRAD(params, args.lr, momentum=0.9)

    scaler = torch.cuda.amp.GradScaler()

    try:
        for _ in trange(args.iterations):
            opt.zero_grad(set_to_none=True)

            # Optionally blend the fixed network input with fresh noise,
            # ramping the noise down over the course of training.
            noise_ramp = 1 - min(1, itt / args.iterations)
            net_input_noised = net_input
            if args.input_noise_strength:
                phi = min(1, noise_ramp * args.input_noise_strength) * math.pi / 2
                noise = torch.randn_like(net_input)
                net_input_noised = net_input * math.cos(phi) + noise * math.sin(phi)

            with torch.cuda.amp.autocast():
                out = net(net_input_noised * input_scale).float()

            losses = []
            # for i, clip_model in enumerate(clip_models):
            cutouts = normalize(make_cutouts(out))
            with torch.cuda.amp.autocast(False):
                image_embeds = clip_model.encode_image(cutouts).float()
            for prompt in prompts:
                losses.append(prompt(image_embeds))  # * clip_model.weight)

            loss = sum(losses, out.new_zeros([]))
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            # Optionally perturb the parameters, also on a decaying ramp.
            if args.param_noise_strength:
                with torch.no_grad():
                    noise_ramp = 1 - min(1, itt / args.iterations)
                    for group in opt.param_groups:
                        for param in group['params']:
                            param += torch.randn_like(param) * group['lr'] * args.param_noise_strength * noise_ramp

            itt += 1

            if itt % args.display_freq == 0:
                with torch.inference_mode():
                    image = TF.to_pil_image(out[0].clamp(0, 1))
                    losses_str = ', '.join([f'{loss.item():g}' for loss in losses])
                    tqdm.write(f'i: {itt}, loss: {loss.item():g}, losses: {losses_str}')
                    image.save(f'out_{itt:05}.png')

            for group in opt.param_groups:
                group['lr'] = args.lr_decay * group['lr']
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
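
# Example invocation (a sketch; the script filename is hypothetical, and this
# assumes the deep-image-prior, ResizeRight, and v-diffusion-pytorch repos are
# cloned into the working directory alongside clip_prior_1.py and the
# checkpoint clip_prior_1_0030000.pth):
#   python dip_clip_prior.py "a matte painting of a castle" --iterations 500 -s 512 512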