@t-vi
Forked from simopal6/gan_failure_normalization.py
Last active December 1, 2017 09:47
import torch
from PIL import Image
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.optim
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
# Options
opt_bad = True
opt_dataset = "parent_of_n02510455"
opt_batch_size = 16
opt_penalty = 10
# Setup data transforms
load_size = 80
crop_size = 64
# "Bad" normalization uses ImageNet statistics, so inputs do not land in [-1, 1]
# and do not match the generator's Tanh output range; "good" normalization does.
if opt_bad:
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
else:
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
train_transform = transforms.Compose([
    transforms.Scale(load_size),
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# Load dataset
train_dataset = datasets.ImageFolder(root = opt_dataset, transform = train_transform)
# Create loader
loader = DataLoader(train_dataset, batch_size = opt_batch_size, shuffle = True, num_workers = 4, pin_memory = True, drop_last = True)
# Discriminator ("critic") -- Wasserstein
class WDiscriminator(nn.Module):
    def __init__(self, isize, ndf, nc = 3):
        super(WDiscriminator, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        main = nn.Sequential()
        # input is nc x isize x isize
        main.add_module('initial.conv.{0}-{1}'.format(nc, ndf), nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial.relu.{0}'.format(ndf), nn.LeakyReLU(0.2, inplace=True))
        csize, cndf = isize // 2, ndf
        # Reduce map size
        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid.{0}-{1}.conv'.format(in_feat, out_feat), nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module('pyramid.{0}.batchnorm'.format(out_feat), nn.InstanceNorm2d(out_feat, affine=True))
            main.add_module('pyramid.{0}.relu'.format(out_feat), nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize // 2
        # state size. K x 4 x 4
        main.add_module('final.{0}-{1}.conv'.format(cndf, 1), nn.Conv2d(cndf, 1, 4, 1, 0, bias=False))
        self.main = main

    def forward(self, input):
        output = self.main(input)
        # return per-sample critic scores; the Lipschitz penalty below needs them unreduced
        #output = output.mean(0)
        return output.view(-1)
# Generator -- Wasserstein
class WGenerator(nn.Module):
    def __init__(self, isize, nz, ngf, nc = 3):
        super(WGenerator, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        cngf, tisize = ngf//2, 4
        while tisize != isize:
            cngf = cngf * 2
            tisize = tisize * 2
        main = nn.Sequential()
        # input is Z, going into a convolution
        main.add_module('initial.{0}-{1}.convt'.format(nz, cngf), nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
        main.add_module('initial.{0}.batchnorm'.format(cngf), nn.InstanceNorm2d(cngf, affine=True))
        main.add_module('initial.{0}.relu'.format(cngf), nn.ReLU(True))
        csize = 4
        while csize < isize//2:
            main.add_module('pyramid.{0}-{1}.convt'.format(cngf, cngf//2), nn.ConvTranspose2d(cngf, cngf//2, 4, 2, 1, bias=False))
            main.add_module('pyramid.{0}.batchnorm'.format(cngf//2), nn.InstanceNorm2d(cngf//2, affine=True))
            main.add_module('pyramid.{0}.relu'.format(cngf//2), nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2
        main.add_module('final.{0}-{1}.convt'.format(cngf, nc), nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final.{0}.tanh'.format(nc), nn.Tanh())
        self.main = main

    def forward(self, input):
        output = self.main(input)
        return output
# Custom weight initialization
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('InstanceNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
# Create generator model/optimizer
g_net = WGenerator(nz=100, isize=64, ngf=64)
g_net.apply(weights_init)
g_optimizer = torch.optim.RMSprop(g_net.parameters(), lr=1e-4)
# Create discriminator model/optimizer
d_net = WDiscriminator(isize=64, ndf=64)
d_net.apply(weights_init)
d_optimizer = torch.optim.RMSprop(d_net.parameters(), lr=5e-5)
# Setup CUDA
g_net.cuda()
d_net.cuda()
# Debug options
save_images_every = 20
cnt = 0
# Auxiliary variables
noise = torch.FloatTensor(opt_batch_size, 100, 1, 1)
one = torch.FloatTensor([1])
minus_one = one*-1
noise = noise.cuda()
one = one.cuda()
minus_one = minus_one.cuda()
for p in d_net.parameters():
    torch.nn.init.normal(p.data, std=0.1)
# Start training
d_net.train()
g_net.train()
g_iterations = 0
for epoch in range(0, 10000):
    # Keep track of losses
    d_real_loss_sum = 0
    d_fake_loss_sum = 0
    d_loss_cnt = 0
    g_loss_sum = 0
    g_loss_cnt = 0
    # Get data iterator
    data_iter = iter(loader)
    data_len = len(loader)
    data_i = 0
    # Process until data ends
    while data_i < data_len:
        # Compute gradients w.r.t. discriminator parameters only
        for p in d_net.parameters(): p.requires_grad = True
        for p in g_net.parameters(): p.requires_grad = False
        # Set number of discriminator iterations
        d_iters = 500 if g_iterations <= 25 or g_iterations % 500 == 0 else 5
        # Perform discriminator iterations
        d_i = 0
        while data_i < data_len and d_i < d_iters:
            # Increase discriminator iterations
            d_i += 1
            # Weight clipping from the original WGAN is disabled;
            # the Lipschitz penalty below takes its place
            #for p in d_net.parameters():
            #    p.data.clamp_(-0.01, 0.01)
            # Get data (keep a reference to the data on the host)
            real_input_cpu, _ = next(data_iter)
            real_input = real_input_cpu
            data_i += 1
            # Move to GPU
            real_input = real_input.cuda(async = True)
            # Wrap for autograd
            real_input = Variable(real_input)
            # Reset gradients
            d_optimizer.zero_grad()
            g_optimizer.zero_grad()
            # Forward (discriminator, real)
            d_real_loss_vec = d_net(real_input)
            d_real_loss = d_real_loss_vec.mean(0).view(1)
            d_real_loss_sum += d_real_loss.data[0]
            # Backward (discriminator, real)
            #d_real_loss.backward(one)
            # Forward (discriminator, fake)
            noise.normal_(0, 1)
            noise_v = Variable(noise, volatile = True)
            g_output = Variable(g_net(noise_v).data)
            d_fake_loss_vec = d_net(g_output)
            d_fake_loss = d_fake_loss_vec.mean(0).view(1)
            d_fake_loss_sum += d_fake_loss.data[0]
            # Backward (discriminator, fake)
            #d_fake_loss.backward(minus_one)
            # Lipschitz penalty: per-pair distance between real and fake samples,
            # estimated Lipschitz quotient, one-sided quadratic penalty above 1
            dist = (((g_output - real_input).view(g_output.size(0), -1)**2).sum(1) + 1e-6)**0.5
            lip_est = (d_fake_loss_vec - d_real_loss_vec).abs() / (dist + 1e-6)
            lip_loss = opt_penalty * ((lip_est - 1).clamp(min=0)**2).mean(0).view(1)
            # Critic objective: minimize D(real) - D(fake) plus the penalty
            d_loss = d_real_loss - d_fake_loss + lip_loss
            d_loss.backward()
            # Update discriminator
            d_optimizer.step()
            # Update loss count
            d_loss_cnt += 1
        # Don't compute gradients w.r.t. discriminator parameters
        for p in d_net.parameters(): p.requires_grad = False
        for p in g_net.parameters(): p.requires_grad = True
        # Forward (generator); objective: minimize D(fake), matching the critic's sign convention
        noise.normal_(0, 1)
        noise_v = Variable(noise)
        g_output = g_net(noise_v)
        g_loss = d_net(g_output).mean(0).view(-1)
        g_loss_sum += g_loss.data[0]
        g_loss_cnt += 1
        # Backward (generator)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        # Increase generator iterations
        g_iterations += 1
        # Save images every once in a while
        cnt += 1
        if cnt % save_images_every == 0:
            # Move generator output to host
            g_output_cpu = g_output.data.cpu()
            # Normalize images between 0 and 1
            real_input_cpu = (real_input_cpu - real_input_cpu.min()) / (real_input_cpu.max() - real_input_cpu.min())
            g_output_cpu = (g_output_cpu - g_output_cpu.min()) / (g_output_cpu.max() - g_output_cpu.min())
            # Save images
            Image.fromarray(torchvision.utils.make_grid(real_input_cpu, nrow = 4).permute(1, 2, 0).mul(255).byte().numpy()).save("real_input_" + ("bad" if opt_bad else "good") + ".png")
            Image.fromarray(torchvision.utils.make_grid(g_output_cpu, nrow = 4).permute(1, 2, 0).mul(255).byte().numpy()).save("g_output_" + ("bad" if opt_bad else "good") + ".png")
    # Print losses at the end of the epoch
    print("Epoch {0}: GL={1:.4f}, DRL={2:.4f}, DFL={3:.4f}".format(epoch, g_loss_sum / g_loss_cnt, d_real_loss_sum / d_loss_cnt, d_fake_loss_sum / d_loss_cnt))
t-vi commented Dec 1, 2017

Uses SLOGAN Lipschitz penalty.
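
For reference, the penalty this refers to is the term built from d_real_loss_vec, d_fake_loss_vec, and the per-pair distance dist in the training loop above: an estimated Lipschitz quotient between paired real and fake samples, penalized quadratically only where it exceeds 1 (per-sample pairs rather than the gradient norm of WGAN-GP). Below is a minimal standalone sketch of that term; the function name and tensor shapes are my own illustration, not part of the gist.

import torch

def lipschitz_penalty(d_real_vec, d_fake_vec, real, fake, weight=10.0, eps=1e-6):
    # per-pair Euclidean distance between real and fake images, flattened per sample
    dist = ((fake - real).view(fake.size(0), -1).pow(2).sum(1) + eps).sqrt()
    # estimated Lipschitz quotient |D(fake) - D(real)| / ||fake - real||
    lip_est = (d_fake_vec - d_real_vec).abs() / (dist + eps)
    # one-sided penalty: only quotients above 1 contribute, quadratically
    return weight * (lip_est - 1).clamp(min=0).pow(2).mean()

# illustrative usage with random tensors (shapes are assumptions, not from the gist)
real = torch.randn(16, 3, 64, 64)
fake = torch.randn(16, 3, 64, 64)
penalty = lipschitz_penalty(torch.randn(16), torch.randn(16), real, fake)
print(penalty)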
