Last active
January 23, 2020 15:58
-
-
Save albusdemens/749b02cb184c06198d3b6a46dba4abe5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import random | |
import math | |
from tqdm import tqdm | |
import numpy as np | |
from PIL import Image | |
import torch | |
from torch import nn, optim | |
from torch.nn import functional as F | |
from torch.autograd import Variable, grad | |
from torch.utils.data import DataLoader | |
from torchvision import datasets, transforms, utils | |
from torch.utils.tensorboard import SummaryWriter | |
from dataset import MultiResolutionDataset | |
from model import StyledGenerator, Discriminator | |
def requires_grad(model, flag=True):
    """Enable or disable gradient tracking for every parameter of *model*."""
    for _, param in model.named_parameters():
        param.requires_grad = flag
def accumulate(model1, model2, decay=0.999):
    """Update model1's parameters as an exponential moving average of model2's.

    Each parameter of *model1* becomes ``decay * p1 + (1 - decay) * p2``.
    Called with ``decay=0`` this copies model2's parameters into model1
    (used at startup to initialise the running generator).
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        # Fix: Tensor.add_(scalar, tensor) is deprecated (and removed in
        # recent PyTorch); pass the tensor with the scalar as `alpha`
        # instead — same arithmetic, supported API.
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
def sample_data(dataset, batch_size, image_size=4):
    """Return a shuffled DataLoader over *dataset* at the given resolution.

    The multi-resolution dataset is switched to *image_size* in place
    before the loader is built; incomplete trailing batches are dropped.
    """
    dataset.resolution = image_size
    return DataLoader(
        dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=4,
        drop_last=True,
    )
def adjust_lr(optimizer, lr):
    """Set every param group's learning rate to *lr* scaled by its 'mult'.

    Groups without a 'mult' key (e.g. the main generator parameters) get
    the base rate; the style-network group carries mult=0.01.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group.get('mult', 1)
def cyclical_lr(stepsize, min_lr, max_lr):
    """Build a triangular cyclical-LR lambda suitable for LambdaLR.

    The returned callable maps an iteration count to a learning rate that
    ramps linearly from *min_lr* up to *max_lr* over *stepsize* iterations
    and back down again, repeating every ``2 * stepsize`` iterations.
    """

    def scaler(cycle):
        # Per-cycle amplitude; identity gives the plain triangular shape.
        return 1.

    def relative(it, half_cycle):
        # Which cycle we are in, and the distance (0..1) from its peak.
        cycle = math.floor(1 + it / (2 * half_cycle))
        x = abs(it / half_cycle - 2 * cycle + 1)
        return max(0, (1 - x)) * scaler(cycle)

    def lr_lambda(it):
        return min_lr + (max_lr - min_lr) * relative(it, stepsize)

    return lr_lambda
def train(args, dataset, generator, discriminator):
    """Run progressive-growing GAN training with alternating D/G updates.

    Starts at args.init_size, fades in each new resolution via `alpha`,
    and doubles the resolution every ``args.phase * 2`` real samples until
    args.max_size. Relies on module-level globals created in the
    ``__main__`` block: g_optimizer, d_optimizer, scheduler_g, scheduler_d,
    g_running, writer, code_size and n_critic.
    """
    # Progression step: 0 -> 4px, 1 -> 8px, ... (resolution = 4 * 2**step).
    step = int(math.log2(args.init_size)) - 2
    resolution = 4 * 2 ** step
    loader = sample_data(
        dataset, args.batch.get(resolution, args.batch_default), resolution
    )
    data_loader = iter(loader)
    adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
    adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))
    # Effectively "train forever"; progression is driven by used_sample.
    pbar = tqdm(range(3_000_000))
    # Phase 1 of each iteration trains only the discriminator.
    requires_grad(generator, False)
    requires_grad(discriminator, True)
    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0
    alpha = 0
    used_sample = 0
    max_step = int(math.log2(args.max_size)) - 2
    final_progress = False
    for i in pbar:
        discriminator.zero_grad()
        #print(resolution, args.lr.get(resolution))
        #print(args.lr.get(resolution))
        # NOTE(review): LR_start is None when --sched is off (args.lr == {});
        # only used by the commented-out part of state_msg below.
        LR_start = args.lr.get(resolution)
        # Fade-in coefficient for the newly added resolution block,
        # ramping 0 -> 1 over args.phase real samples.
        alpha = min(1, 1 / args.phase * (used_sample + 1))
        #idx_cycle = int((i - (250_000 * ((resolution - 8) / 8))) / 5000)
        #LR_cyclical = step_size[idx_cycle]
        #print(LR_cyclical)
        if (resolution == args.init_size and args.ckpt is None) or final_progress:
            alpha = 1
        # Grow to the next resolution once 2 phases' worth of samples seen.
        if used_sample > args.phase * 2:
            used_sample = 0
            step += 1
            if step > max_step:
                step = max_step
                final_progress = True
                ckpt_step = step + 1
            else:
                alpha = 0
                ckpt_step = step
            resolution = 4 * 2 ** step
            loader = sample_data(
                dataset, args.batch.get(resolution, args.batch_default), resolution
            )
            data_loader = iter(loader)
            # Checkpoint everything at each resolution transition.
            torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'g_optimizer': g_optimizer.state_dict(),
                    'd_optimizer': d_optimizer.state_dict(),
                    'g_running': g_running.state_dict(),
                },
                f'checkpoint/train_step-{ckpt_step}.model',
            )
            adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
            adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))
            # NOTE(review): duplicate of the line above — harmless
            # (idempotent) but presumably a copy/paste slip.
            adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))
        try:
            real_image = next(data_loader)
        except (OSError, StopIteration):
            # Restart the loader when the epoch ends (or a worker hiccups).
            data_loader = iter(loader)
            real_image = next(data_loader)
        used_sample += real_image.shape[0]
        b_size = real_image.size(0)
        real_image = real_image.cuda()
        # ---- Discriminator update on real images ----
        if args.loss == 'wgan-gp':
            real_predict = discriminator(real_image, step=step, alpha=alpha)
            # Small drift penalty keeps D's output magnitude bounded.
            real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
            (-real_predict).backward()
        elif args.loss == 'r1':
            real_image.requires_grad = True
            real_scores = discriminator(real_image, step=step, alpha=alpha)
            real_predict = F.softplus(-real_scores).mean()
            real_predict.backward(retain_graph=True)
            # R1 penalty: squared gradient norm of D w.r.t. real images.
            grad_real = grad(
                outputs=real_scores.sum(), inputs=real_image, create_graph=True
            )[0]
            grad_penalty = (
                grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
            ).mean()
            grad_penalty = 10 / 2 * grad_penalty  # gamma/2 with gamma = 10
            grad_penalty.backward()
            if i%10 == 0:
                grad_loss_val = grad_penalty.item()
        # Style-mixing regularization: 90% of batches use two latent codes.
        if args.mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
                4, b_size, code_size, device='cuda'
            ).chunk(4, 0)
            gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)]
            gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)]
        else:
            gen_in1, gen_in2 = torch.randn(2, b_size, code_size, device='cuda').chunk(
                2, 0
            )
            gen_in1 = gen_in1.squeeze(0)
            gen_in2 = gen_in2.squeeze(0)
        # ---- Discriminator update on fake images ----
        fake_image = generator(gen_in1, step=step, alpha=alpha)
        fake_predict = discriminator(fake_image, step=step, alpha=alpha)
        if args.loss == 'wgan-gp':
            fake_predict = fake_predict.mean()
            fake_predict.backward()
            # WGAN-GP gradient penalty on a random real/fake interpolation.
            eps = torch.rand(b_size, 1, 1, 1).cuda()
            x_hat = eps * real_image.data + (1 - eps) * fake_image.data
            x_hat.requires_grad = True
            hat_predict = discriminator(x_hat, step=step, alpha=alpha)
            grad_x_hat = grad(
                outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
            )[0]
            grad_penalty = (
                (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
            ).mean()
            grad_penalty = 10 * grad_penalty  # lambda = 10
            grad_penalty.backward()
            if i%10 == 0:
                grad_loss_val = grad_penalty.item()
            disc_loss_val = (-real_predict + fake_predict).item()
        elif args.loss == 'r1':
            fake_predict = F.softplus(fake_predict).mean()
            fake_predict.backward()
            if i%10 == 0:
                disc_loss_val = (real_predict + fake_predict).item()
        d_optimizer.step()
        scheduler_d.step()
        # ---- Generator update (every n_critic discriminator steps) ----
        if (i + 1) % n_critic == 0:
            generator.zero_grad()
            requires_grad(generator, True)
            requires_grad(discriminator, False)
            fake_image = generator(gen_in2, step=step, alpha=alpha)
            predict = discriminator(fake_image, step=step, alpha=alpha)
            if args.loss == 'wgan-gp':
                loss = -predict.mean()
            elif args.loss == 'r1':
                loss = F.softplus(-predict).mean()
            if i%10 == 0:
                gen_loss_val = loss.item()
            loss.backward()
            g_optimizer.step()
            scheduler_g.step()
            # Track an EMA copy of the generator for sampling/checkpoints.
            accumulate(g_running, generator.module)
            requires_grad(generator, False)
            requires_grad(discriminator, True)
        # Periodically save a grid of samples from the EMA generator.
        if (i + 1) % 100 == 0:
            images = []
            gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))
            with torch.no_grad():
                for _ in range(gen_i):
                    images.append(
                        g_running(
                            torch.randn(gen_j, code_size).cuda(), step=step, alpha=alpha
                        ).data.cpu()
                    )
            utils.save_image(
                torch.cat(images, 0),
                f'sample/{str(i + 1).zfill(6)}.png',
                nrow=gen_i,
                normalize=True,
                range=(-1, 1),
            )
        if (i + 1) % 10000 == 0:
            torch.save(
                g_running.state_dict(), f'checkpoint/{str(i + 1).zfill(6)}.model'
            )
        state_msg = (
            f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};'
            f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}' #; LR: {LR_start:.8f}'
        )
        pbar.set_description(state_msg)
        # ================================================================== #
        #                        Tensorboard Logging                         #
        # ================================================================== #
        # NOTE(review): args.batch.get(resolution) is None whenever --sched
        # is off or the resolution is missing from the table, which would
        # make `i * None` raise here — TODO confirm intended usage.
        if (i + 1) % 100 == 0:
            writer.add_scalar('Loss/G', gen_loss_val, i * args.batch.get(resolution))
            writer.add_scalar('Loss/D', disc_loss_val, i * args.batch.get(resolution))
            writer.add_scalar('Step/pixel_size', (4 * 2 ** step), i * args.batch.get(resolution))
            print(args.batch.get(resolution))
if __name__ == '__main__':
    # Dimensionality of the latent (style) code fed to the generator.
    code_size = 512
    # NOTE(review): batch_size is never read; batch sizes come from
    # args.batch / args.batch_default set further below.
    batch_size = 16
    # Number of discriminator updates per generator update.
    n_critic = 1
    parser = argparse.ArgumentParser(description='Progressive Growing of GANs')
    parser.add_argument('path', type=str, help='path of specified dataset')
    parser.add_argument(
        '--phase',
        type=int,
        default=250_000,  # Original: 600_000
        help='number of samples used for each training phases',
    )
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    # NOTE(review): '--b_size' is typed float and never used; batch sizes
    # are taken from the per-resolution table instead.
    parser.add_argument('--b_size', default=256, type=float, help='batch size')
    parser.add_argument('--sched', action='store_true', help='use lr scheduling')
    parser.add_argument('--init_size', default=8, type=int, help='initial image size')
    parser.add_argument('--max_size', default=512, type=int, help='max image size')
    parser.add_argument(
        '--ckpt', default=None, type=str, help='load from previous checkpoints'
    )
    parser.add_argument(
        '--no_from_rgb_activate',
        action='store_true',
        help='use activate in from_rgb (original implementation)',
    )
    parser.add_argument(
        '--mixing', action='store_true', help='use mixing regularization'
    )
    parser.add_argument(
        '--loss',
        type=str,
        default='wgan-gp',
        choices=['wgan-gp', 'r1'],
        help='class of gan loss',
    )
    args = parser.parse_args()
    # DataParallel splits batches across GPUs; train() reaches the raw
    # networks through .module when saving/loading state.
    generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
    discriminator = nn.DataParallel(
        Discriminator(from_rgb_activate=not args.no_from_rgb_activate)
    ).cuda()
    # EMA copy of the generator used for sampling and checkpoints; kept in
    # eval mode and updated via accumulate() inside train().
    g_running = StyledGenerator(code_size).cuda()
    g_running.train(False)
    g_optimizer = optim.Adam(
        generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99)
        # generator.module.generator.parameters(), lr = LR_cyclical, betas = (0.0, 0.99)
    )
    # The style-mapping network trains 100x slower than the synthesis
    # network; 'mult' is the per-group scale read by adjust_lr().
    g_optimizer.add_param_group(
        {
            'params': generator.module.style.parameters(),
            'lr': args.lr * 0.01,
            # 'lr': LR_cyclical * 0.01,
            'mult': 0.01,
        }
    )
    # Triangular cyclical LR: half-cycle length (iterations) and LR bounds.
    step_size = 5*256
    end_lr = 10**-1
    factor = 10**5
    clr = cyclical_lr(step_size, min_lr=end_lr / factor, max_lr=end_lr)
    # One lambda per param group (synthesis network + style network).
    scheduler_g = torch.optim.lr_scheduler.LambdaLR(g_optimizer, [clr, clr])
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))
    scheduler_d = torch.optim.lr_scheduler.LambdaLR(d_optimizer, [clr])
    # Initialise g_running as an exact copy of the generator (decay = 0).
    accumulate(g_running, generator.module, 0)
    if args.ckpt is not None:
        # Resume all networks and optimizer states from a checkpoint.
        ckpt = torch.load(args.ckpt)
        generator.module.load_state_dict(ckpt['generator'])
        discriminator.module.load_state_dict(ckpt['discriminator'])
        g_running.load_state_dict(ckpt['g_running'])
        g_optimizer.load_state_dict(ckpt['g_optimizer'])
        d_optimizer.load_state_dict(ckpt['d_optimizer'])
    # Normalize images to [-1, 1] to match the generator's tanh-style range
    # (implied by the range=(-1, 1) used when saving samples).
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
    dataset = MultiResolutionDataset(args.path, transform)
    # Rebind args.lr from the scalar CLI value to a per-resolution table;
    # train() falls back to 0.001 / batch_default for missing resolutions.
    if args.sched:
        args.lr = {8: 0.001, 16: 0.001, 32: 0.001, 64: 0.001, 128: 0.0015, 256: 0.002, 512: 0.003}
        args.batch = {8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32, 512: 16}
    else:
        args.lr = {}
        args.batch = {}
    # (rows, cols) of the periodic sample grid per resolution.
    args.gen_sample = {512: (8, 4), 1024: (4, 2)}
    args.batch_default = 32
    # Define the log directory where to write information
    writer = SummaryWriter('runs/StyleGAN_training')
    train(args, dataset, generator, discriminator)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment