Created
November 29, 2017 16:25
-
-
Save simopal6/c6484df00d5747dfe33f7ed67383c6fd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Command-line interface: every knob of the WGAN training script.
import argparse

parser = argparse.ArgumentParser(description="WGAN")
# -- data -------------------------------------------------------------------
parser.add_argument('-d', '--dataset',
                    default="parent_of_n02510455",
                    help="dataset directory (for ImageFolder)")
# -- training ---------------------------------------------------------------
parser.add_argument('--bad', action="store_true",
                    help='use "bad" normalization values')
parser.add_argument('-b', '--batch-size', type=int, default=16,
                    help="batch size")
# -- backend ----------------------------------------------------------------
parser.add_argument('--no-cuda', action="store_true", help="disable CUDA")

# Parse once at import time; `opt` is read throughout the rest of the script.
opt = parser.parse_args()
print(opt)
# Imports | |
import torch | |
from PIL import Image | |
from torch.utils.data import DataLoader | |
import torchvision | |
from torchvision import transforms, datasets | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import torch.optim | |
import torch.backends.cudnn as cudnn; cudnn.benchmark = True | |
# Data pipeline: resize -> random 64x64 crop -> random flip -> tensor -> normalize.
load_size = 80
crop_size = 64

# Normalization statistics: ImageNet-style values when --bad is given,
# otherwise the plain (x - 0.5) / 0.5 mapping into (-1, 1), which matches
# the generator's tanh output range.
if opt.bad:
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
else:
    mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

train_transform = transforms.Compose([
    transforms.Scale(load_size),  # NOTE(review): pre-`Resize` torchvision API; update if torchvision is upgraded
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Folder-per-class dataset; the class labels are ignored by the GAN loop.
train_dataset = datasets.ImageFolder(root=opt.dataset, transform=train_transform)

# Loader: pinned memory only when CUDA is in use; drop the ragged last batch
# so the fixed-size `noise` buffer always matches the real batch.
loader = DataLoader(
    train_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=not opt.no_cuda,
    drop_last=True,
)
# Discriminator ("critic") -- Wasserstein
class WDiscriminator(nn.Module):
    """Convolutional Wasserstein critic (DCGAN-style).

    Strided 4x4 convolutions halve the feature map until it reaches 4x4,
    then a final valid 4x4 convolution reduces it to one scalar per image.
    The forward pass returns the mean score over the batch as a 1-element
    tensor.

    Args:
        isize: input spatial size; must be a multiple of 16 (it is halved
               down to 4, so a power-of-two size such as 64 is expected).
        ndf:   number of feature maps after the first convolution.
        nc:    number of input channels (default 3, RGB).
    """

    def __init__(self, isize, ndf, nc=3):
        super(WDiscriminator, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        main = nn.Sequential()
        # Input is nc x isize x isize.
        # FIX: module names use '_'/'-' instead of '.' -- current PyTorch
        # rejects '.' in add_module names (it would corrupt state_dict keys).
        main.add_module('initial_conv_{0}-{1}'.format(nc, ndf),
                        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial_relu_{0}'.format(ndf),
                        nn.LeakyReLU(0.2, inplace=True))
        # FIX: '//' keeps csize an int under Python 3 ('/' yields a float).
        csize, cndf = isize // 2, ndf
        # Halve the map and double the channels until the map is 4x4.
        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid_{0}-{1}_conv'.format(in_feat, out_feat),
                            nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module('pyramid_{0}_batchnorm'.format(out_feat),
                            nn.BatchNorm2d(out_feat))
            main.add_module('pyramid_{0}_relu'.format(out_feat),
                            nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize // 2
        # State size: cndf x 4 x 4 -> one scalar per image.
        main.add_module('final_{0}-{1}_conv'.format(cndf, 1),
                        nn.Conv2d(cndf, 1, 4, 1, 0, bias=False))
        self.main = main

    def forward(self, input):
        """Return the batch-mean critic score as a tensor of shape (1,)."""
        output = self.main(input)
        output = output.mean(0)
        return output.view(1)
# Generator -- Wasserstein
class WGenerator(nn.Module):
    """DCGAN-style transposed-convolution generator.

    Maps a latent batch of shape (batch, nz, 1, 1) to images of shape
    (batch, nc, isize, isize) squashed into (-1, 1) by a final tanh.

    Args:
        isize: output spatial size; must be a multiple of 16 and a power of
               two (the network doubles a 4x4 map until it matches isize).
        nz:    latent dimensionality.
        ngf:   base number of generator feature maps.
        nc:    number of output channels (default 3, RGB).
    """

    def __init__(self, isize, nz, ngf, nc=3):
        super(WGenerator, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        # Channel count of the first (4x4) layer: double until the upsampled
        # 4x4 map would reach isize.
        cngf, tisize = ngf // 2, 4
        while tisize < isize:
            cngf = cngf * 2
            tisize = tisize * 2
        # FIX: a non-power-of-two isize previously made this loop spin
        # forever (tisize stepping over isize); fail loudly instead.
        assert tisize == isize, "isize has to be a power of 2"
        main = nn.Sequential()
        # Input is Z, going into a convolution.
        # FIX: module names avoid '.' -- current PyTorch rejects it in
        # add_module names (it would corrupt state_dict keys).
        main.add_module('initial_{0}-{1}_convt'.format(nz, cngf),
                        nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
        main.add_module('initial_{0}_batchnorm'.format(cngf), nn.BatchNorm2d(cngf))
        main.add_module('initial_{0}_relu'.format(cngf), nn.ReLU(True))
        # Upsample 4x4 -> isize/2, halving channels at each step.
        # (The original also bound an unused local 'cndf' here -- removed.)
        csize = 4
        while csize < isize // 2:
            main.add_module('pyramid_{0}-{1}_convt'.format(cngf, cngf // 2),
                            nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False))
            main.add_module('pyramid_{0}_batchnorm'.format(cngf // 2),
                            nn.BatchNorm2d(cngf // 2))
            main.add_module('pyramid_{0}_relu'.format(cngf // 2), nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2
        # Final upsampling to isize, with tanh squashing into (-1, 1).
        main.add_module('final_{0}-{1}_convt'.format(cngf, nc),
                        nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final_{0}_tanh'.format(nc), nn.Tanh())
        self.main = main

    def forward(self, input):
        """Generate images from latent input of shape (batch, nz, 1, 1)."""
        output = self.main(input)
        return output
# Custom weight initialization, meant to be applied via net.apply(weights_init).
def weights_init(m):
    """DCGAN-style init: conv weights ~ N(0, 0.02); batchnorm gains ~ N(1, 0.02), biases 0."""
    name = type(m).__name__
    if 'Conv' in name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
# Create generator model/optimizer
# NOTE(review): RMSprop at lr=5e-5 matches the original WGAN recipe.
g_net = WGenerator(nz=100, isize=64, ngf=64)
g_net.apply(weights_init)
g_optimizer = torch.optim.RMSprop(g_net.parameters(), lr=5e-5)
# Create discriminator model/optimizer
d_net = WDiscriminator(isize=64, ndf=64)
d_net.apply(weights_init)
d_optimizer = torch.optim.RMSprop(d_net.parameters(), lr=5e-5)
# Setup CUDA (moves both nets to the GPU unless --no-cuda was given)
if not opt.no_cuda:
    g_net.cuda()
    d_net.cuda()
    print("Copied to CUDA")
# Debug options
save_images_every = 20  # dump image grids every N outer-loop iterations
cnt = 0                 # outer-loop iteration counter driving those dumps
# Auxiliary variables:
# - `noise`: reusable latent batch buffer, refilled in place each iteration
# - `one` / `minus_one`: +/-1 seeds passed to backward() to select the sign
#   of each Wasserstein loss term in the training loop below
noise = torch.FloatTensor(opt.batch_size, 100, 1, 1)
one = torch.FloatTensor([1])
minus_one = one*-1
if not opt.no_cuda:
    noise = noise.cuda()
    one = one.cuda()
    minus_one = minus_one.cuda()
# Start training (legacy pre-0.4 PyTorch API: Variable / volatile / .data)
d_net.train()
g_net.train()
g_iterations = 0
for epoch in range(0, 10000):
    # Per-epoch running losses for the end-of-epoch report
    d_real_loss_sum = 0
    d_fake_loss_sum = 0
    d_loss_cnt = 0
    g_loss_sum = 0
    g_loss_cnt = 0
    # Get data iterator
    data_iter = iter(loader)
    data_len = len(loader)
    data_i = 0
    # Process until data ends
    while data_i < data_len:
        # Re-enable gradients w.r.t. critic parameters (they are frozen
        # below while the generator is updated)
        for p in d_net.parameters(): p.requires_grad = True
        # Critic iterations per generator step: many while warming up (and
        # periodically), few otherwise
        d_iters = 100 if g_iterations <= 25 or g_iterations % 500 == 0 else 5
        # Perform discriminator iterations
        d_i = 0
        while data_i < data_len and d_i < d_iters:
            d_i += 1
            # Clamp parameters to a cube (WGAN weight clipping)
            for p in d_net.parameters():
                p.data.clamp_(-0.01, 0.01)
            # Get data (keep a reference to the host copy for image dumps).
            # FIX: next(data_iter), not data_iter.next() -- the .next()
            # method is Python 2 only and raises AttributeError on Python 3.
            (real_input_cpu, _) = next(data_iter)
            real_input = real_input_cpu
            data_i += 1
            # FIX: 'async' became a reserved keyword in Python 3.7
            # (SyntaxError); non_blocking is the current name of this flag.
            if not opt.no_cuda: real_input = real_input.cuda(non_blocking=True)
            # Wrap for autograd
            real_input = Variable(real_input)
            # Reset gradients on both optimizers
            d_optimizer.zero_grad()
            g_optimizer.zero_grad()
            # Forward (discriminator, real); backward(one) seeds the
            # gradient of the mean critic score with +1
            d_real_loss = d_net(real_input)
            d_real_loss_sum += d_real_loss.data[0]
            d_real_loss.backward(one)
            # Forward (discriminator, fake): generate without tracking
            # generator gradients (volatile), then detach via .data
            noise.normal_(0, 1)
            noise_v = Variable(noise, volatile=True)
            g_output = Variable(g_net(noise_v).data)
            d_fake_loss = d_net(g_output)
            d_fake_loss_sum += d_fake_loss.data[0]
            # Backward (discriminator, fake) with the opposite sign
            d_fake_loss.backward(minus_one)
            # Update discriminator
            d_optimizer.step()
            d_loss_cnt += 1
        # Freeze critic parameters while the generator updates
        for p in d_net.parameters(): p.requires_grad = False
        # Forward (generator): critic score of a fresh fake batch
        noise.normal_(0, 1)
        noise_v = Variable(noise)
        g_output = g_net(noise_v)
        g_loss = d_net(g_output)
        g_loss_sum += g_loss.data[0]
        g_loss_cnt += 1
        # Backward (generator) and step
        g_loss.backward(one)
        g_optimizer.step()
        g_iterations += 1
        # Save image grids every `save_images_every` outer iterations
        cnt += 1
        if cnt % save_images_every == 0:
            # Move generator output to host
            g_output_cpu = g_output.data.cpu()
            # Min-max normalize both batches into [0, 1] for visualization
            real_input_cpu = (real_input_cpu - real_input_cpu.min())/(real_input_cpu.max() - real_input_cpu.min())
            g_output_cpu = (g_output_cpu - g_output_cpu.min())/(g_output_cpu.max() - g_output_cpu.min())
            # Save 4-wide grids of real and generated images
            Image.fromarray(torchvision.utils.make_grid(real_input_cpu, nrow = 4).permute(1,2,0).mul(255).byte().numpy()).save("real_input_" + ("bad" if opt.bad else "good") + ".png")
            Image.fromarray(torchvision.utils.make_grid(g_output_cpu, nrow = 4).permute(1,2,0).mul(255).byte().numpy()).save("g_output_" + ("bad" if opt.bad else "good") + ".png")
    # Print losses at the end of the epoch
    print("Epoch {0}: GL={1:.4f}, DRL={2:.4f}, DFL={3:.4f}".format(epoch, g_loss_sum/g_loss_cnt, d_real_loss_sum/d_loss_cnt, d_fake_loss_sum/d_loss_cnt))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment