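# Progressive-growing StyleGAN training script: trains StyledGenerator and
# Discriminator (model.py) on a MultiResolutionDataset (dataset.py), keeps an
# exponential moving average of the generator (g_running) for sampling, and adds
# a cyclical learning-rate schedule plus TensorBoard logging.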
import argparse
import random
import math
from tqdm import tqdm
import numpy as np
from PIL import Image
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from torch.utils.tensorboard import SummaryWriter
from dataset import MultiResolutionDataset
from model import StyledGenerator, Discriminator
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

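# Exponential moving average of model2's parameters into model1
# (used to keep g_running as a running average of the generator).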
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(1 - decay, par2[k].data)

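# Build a DataLoader that serves the dataset at the current progressive-growing resolution.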
def sample_data(dataset, batch_size, image_size=4):
    dataset.resolution = image_size
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4, drop_last=True)

    return loader

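# Set the learning rate of every parameter group, respecting each group's 'mult'
# factor (the style network uses mult=0.01).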
def adjust_lr(optimizer, lr):
    for group in optimizer.param_groups:
        mult = group.get('mult', 1)
        group['lr'] = lr * mult

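# Triangular cyclical learning rate (Smith, "Cyclical Learning Rates for Training
# Neural Networks"): the returned lambda oscillates between min_lr and max_lr with
# a period of 2 * stepsize iterations.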
def cyclical_lr(stepsize, min_lr, max_lr):

    # Scaler: we can adapt this if we do not want the triangular CLR
    scaler = lambda x: 1.

    # Lambda function to calculate the LR
    lr_lambda = lambda it: min_lr + (max_lr - min_lr) * relative(it, stepsize)

    # Additional function to see where on the cycle we are
    def relative(it, stepsize):
        cycle = math.floor(1 + it / (2 * stepsize))
        x = abs(it / stepsize - 2 * cycle + 1)
        return max(0, (1 - x)) * scaler(cycle)

    return lr_lambda

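# Main training procedure: progressively grow the resolution while alternating
# discriminator and generator updates.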
def train(args, dataset, generator, discriminator):
    step = int(math.log2(args.init_size)) - 2
    resolution = 4 * 2 ** step
    loader = sample_data(
        dataset, args.batch.get(resolution, args.batch_default), resolution
    )
    data_loader = iter(loader)

    adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
    adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

    pbar = tqdm(range(3_000_000))

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0

    alpha = 0
    used_sample = 0

    max_step = int(math.log2(args.max_size)) - 2
    final_progress = False

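    # Main loop: the discriminator is updated every iteration and the generator
    # every n_critic iterations.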
    for i in pbar:
        discriminator.zero_grad()

        #print(resolution, args.lr.get(resolution))
        #print(args.lr.get(resolution))
        LR_start = args.lr.get(resolution)

        alpha = min(1, 1 / args.phase * (used_sample + 1))

        #idx_cycle = int((i - (250_000 * ((resolution - 8) / 8))) / 5000)
        #LR_cyclical = step_size[idx_cycle]
        #print(LR_cyclical)

        if (resolution == args.init_size and args.ckpt is None) or final_progress:
            alpha = 1

        if used_sample > args.phase * 2:
            used_sample = 0
            step += 1

            if step > max_step:
                step = max_step
                final_progress = True
                ckpt_step = step + 1

            else:
                alpha = 0
                ckpt_step = step

            resolution = 4 * 2 ** step

            loader = sample_data(
                dataset, args.batch.get(resolution, args.batch_default), resolution
            )
            data_loader = iter(loader)

            torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'g_optimizer': g_optimizer.state_dict(),
                    'd_optimizer': d_optimizer.state_dict(),
                    'g_running': g_running.state_dict(),
                },
                f'checkpoint/train_step-{ckpt_step}.model',
            )

            adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
            adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))
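
        # Draw the next batch of real images; re-create the iterator when the
        # DataLoader is exhausted.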
        try:
            real_image = next(data_loader)

        except (OSError, StopIteration):
            data_loader = iter(loader)
            real_image = next(data_loader)

        used_sample += real_image.shape[0]

        b_size = real_image.size(0)
        real_image = real_image.cuda()

        if args.loss == 'wgan-gp':
            real_predict = discriminator(real_image, step=step, alpha=alpha)
            real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
            (-real_predict).backward()

        elif args.loss == 'r1':
            real_image.requires_grad = True
            real_scores = discriminator(real_image, step=step, alpha=alpha)
            real_predict = F.softplus(-real_scores).mean()
            real_predict.backward(retain_graph=True)

            grad_real = grad(
                outputs=real_scores.sum(), inputs=real_image, create_graph=True
            )[0]
            grad_penalty = (
                grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
            ).mean()
            grad_penalty = 10 / 2 * grad_penalty
            grad_penalty.backward()

            if i % 10 == 0:
                grad_loss_val = grad_penalty.item()
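
        # Sample latent codes for the generator; with --mixing, draw two codes per
        # sample for style-mixing regularization.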
        if args.mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
                4, b_size, code_size, device='cuda'
            ).chunk(4, 0)
            gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)]
            gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)]

        else:
            gen_in1, gen_in2 = torch.randn(2, b_size, code_size, device='cuda').chunk(
                2, 0
            )
            gen_in1 = gen_in1.squeeze(0)
            gen_in2 = gen_in2.squeeze(0)
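
        # Discriminator loss on generated images (WGAN-GP gradient penalty or R1,
        # depending on --loss).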
        fake_image = generator(gen_in1, step=step, alpha=alpha)
        fake_predict = discriminator(fake_image, step=step, alpha=alpha)

        if args.loss == 'wgan-gp':
            fake_predict = fake_predict.mean()
            fake_predict.backward()

            eps = torch.rand(b_size, 1, 1, 1).cuda()
            x_hat = eps * real_image.data + (1 - eps) * fake_image.data
            x_hat.requires_grad = True
            hat_predict = discriminator(x_hat, step=step, alpha=alpha)
            grad_x_hat = grad(
                outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
            )[0]
            grad_penalty = (
                (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
            ).mean()
            grad_penalty = 10 * grad_penalty
            grad_penalty.backward()

            if i % 10 == 0:
                grad_loss_val = grad_penalty.item()
                disc_loss_val = (-real_predict + fake_predict).item()

        elif args.loss == 'r1':
            fake_predict = F.softplus(fake_predict).mean()
            fake_predict.backward()

            if i % 10 == 0:
                disc_loss_val = (real_predict + fake_predict).item()

        d_optimizer.step()
        scheduler_d.step()
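
        # Generator update, performed every n_critic discriminator steps.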
        if (i + 1) % n_critic == 0:
            generator.zero_grad()

            requires_grad(generator, True)
            requires_grad(discriminator, False)

            fake_image = generator(gen_in2, step=step, alpha=alpha)
            predict = discriminator(fake_image, step=step, alpha=alpha)

            if args.loss == 'wgan-gp':
                loss = -predict.mean()

            elif args.loss == 'r1':
                loss = F.softplus(-predict).mean()

            if i % 10 == 0:
                gen_loss_val = loss.item()

            loss.backward()
            g_optimizer.step()
            scheduler_g.step()
            accumulate(g_running, generator.module)

            requires_grad(generator, False)
            requires_grad(discriminator, True)
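
        # Periodically save image samples (every 100 iterations) and a checkpoint of
        # the running-average generator (every 10,000 iterations).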
        if (i + 1) % 100 == 0:
            images = []

            gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))

            with torch.no_grad():
                for _ in range(gen_i):
                    images.append(
                        g_running(
                            torch.randn(gen_j, code_size).cuda(), step=step, alpha=alpha
                        ).data.cpu()
                    )

            utils.save_image(
                torch.cat(images, 0),
                f'sample/{str(i + 1).zfill(6)}.png',
                nrow=gen_i,
                normalize=True,
                range=(-1, 1),
            )

        if (i + 1) % 10000 == 0:
            torch.save(
                g_running.state_dict(), f'checkpoint/{str(i + 1).zfill(6)}.model'
            )

        state_msg = (
            f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};'
            f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}'  #; LR: {LR_start:.8f}'
        )

        pbar.set_description(state_msg)
        # ================================================================== #
        #                        Tensorboard Logging                         #
        # ================================================================== #

        if (i + 1) % 100 == 0:
            writer.add_scalar('Loss/G', gen_loss_val, i * args.batch.get(resolution, args.batch_default))
            writer.add_scalar('Loss/D', disc_loss_val, i * args.batch.get(resolution, args.batch_default))
            writer.add_scalar('Step/pixel_size', 4 * 2 ** step, i * args.batch.get(resolution, args.batch_default))
            print(args.batch.get(resolution, args.batch_default))

if __name__ == '__main__':
    code_size = 512
    batch_size = 16
    n_critic = 1

    parser = argparse.ArgumentParser(description='Progressive Growing of GANs')

    parser.add_argument('path', type=str, help='path of specified dataset')
    parser.add_argument(
        '--phase',
        type=int,
        default=250_000,  # Original: 600_000
        help='number of samples used for each training phase',
    )
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--b_size', default=256, type=int, help='batch size')
    parser.add_argument('--sched', action='store_true', help='use lr scheduling')
    parser.add_argument('--init_size', default=8, type=int, help='initial image size')
    parser.add_argument('--max_size', default=512, type=int, help='max image size')
    parser.add_argument(
        '--ckpt', default=None, type=str, help='load from previous checkpoints'
    )
    parser.add_argument(
        '--no_from_rgb_activate',
        action='store_true',
        help='use activate in from_rgb (original implementation)',
    )
    parser.add_argument(
        '--mixing', action='store_true', help='use mixing regularization'
    )
    parser.add_argument(
        '--loss',
        type=str,
        default='wgan-gp',
        choices=['wgan-gp', 'r1'],
        help='class of gan loss',
    )

    args = parser.parse_args()

    generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
    discriminator = nn.DataParallel(
        Discriminator(from_rgb_activate=not args.no_from_rgb_activate)
    ).cuda()
    g_running = StyledGenerator(code_size).cuda()
    g_running.train(False)

    g_optimizer = optim.Adam(
        generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99)
        # generator.module.generator.parameters(), lr=LR_cyclical, betas=(0.0, 0.99)
    )
    g_optimizer.add_param_group(
        {
            'params': generator.module.style.parameters(),
            'lr': args.lr * 0.01,
            # 'lr': LR_cyclical * 0.01,
            'mult': 0.01,
        }
    )
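
    # Cyclical LR schedule. Note that LambdaLR multiplies each parameter group's base
    # learning rate (args.lr, and args.lr * 0.01 for the style network) by the value
    # returned by clr, so these values act as multiplicative factors rather than
    # absolute learning rates.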
    step_size = 5 * 256
    end_lr = 10 ** -1
    factor = 10 ** 5
    clr = cyclical_lr(step_size, min_lr=end_lr / factor, max_lr=end_lr)
    scheduler_g = torch.optim.lr_scheduler.LambdaLR(g_optimizer, [clr, clr])

    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))
    scheduler_d = torch.optim.lr_scheduler.LambdaLR(d_optimizer, [clr])

    accumulate(g_running, generator.module, 0)

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)

        generator.module.load_state_dict(ckpt['generator'])
        discriminator.module.load_state_dict(ckpt['discriminator'])
        g_running.load_state_dict(ckpt['g_running'])
        g_optimizer.load_state_dict(ckpt['g_optimizer'])
        d_optimizer.load_state_dict(ckpt['d_optimizer'])

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform)

    if args.sched:
        args.lr = {8: 0.001, 16: 0.001, 32: 0.001, 64: 0.001, 128: 0.0015, 256: 0.002, 512: 0.003}
        args.batch = {8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32, 512: 16}

    else:
        args.lr = {}
        args.batch = {}

    args.gen_sample = {512: (8, 4), 1024: (4, 2)}
    args.batch_default = 32

    # Define the log directory where to write information
    writer = SummaryWriter('runs/StyleGAN_training')

    train(args, dataset, generator, discriminator)
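
# Example invocation (the dataset path is a placeholder; the data must be prepared
# for MultiResolutionDataset, see dataset.py):
#   python train.py /path/to/prepared_dataset --sched --mixing --loss r1
# Samples are written to sample/ and checkpoints to checkpoint/, so both directories
# must exist before training starts.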