Skip to content

Instantly share code, notes, and snippets.

@simopal6
Created November 29, 2017 16:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save simopal6/c6484df00d5747dfe33f7ed67383c6fd to your computer and use it in GitHub Desktop.
Save simopal6/c6484df00d5747dfe33f7ed67383c6fd to your computer and use it in GitHub Desktop.
# Define options
import argparse
arg_parser = argparse.ArgumentParser(description="WGAN")
# Dataset options
arg_parser.add_argument('-d', '--dataset', default="parent_of_n02510455", help="dataset directory (for ImageFolder)")
# Training options
arg_parser.add_argument('--bad', action="store_true", help='use "bad" normalization values')
arg_parser.add_argument('-b', '--batch-size', type=int, default=16, help="batch size")
# Backend options
arg_parser.add_argument('--no-cuda', action="store_true", help="disable CUDA")
# Read options into the module-level `opt` namespace used by the rest of the script
opt = arg_parser.parse_args()
print(opt)
# Imports
import torch
from PIL import Image
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.optim
import torch.backends.cudnn as cudnn; cudnn.benchmark = True
# Setup data transforms
load_size = 80   # shorter image side after resizing
crop_size = 64   # final training crop fed to the networks
if opt.bad:
    # Presumably the standard ImageNet statistics -- "bad" here because they do
    # not match the generator's tanh output range
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
else:
    # Maps pixel values to roughly [-1, 1], matching the generator's tanh output
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
train_transform = transforms.Compose([
    # transforms.Scale was deprecated and later removed; Resize is the modern name
    transforms.Resize(load_size),
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# Load dataset (ImageFolder expects one subdirectory per class under root)
train_dataset = datasets.ImageFolder(root=opt.dataset, transform=train_transform)
# Create loader; pinned host memory speeds up host-to-GPU copies when CUDA is used
loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True,
                    num_workers=4, pin_memory=not opt.no_cuda, drop_last=True)
# Discriminator ("critic") -- Wasserstein
class WDiscriminator(nn.Module):
    """DCGAN-style critic for WGAN.

    Strided 4x4 convolutions halve the spatial size (and double the channels)
    until the map is 4x4, then a final convolution produces one score per
    sample. forward() returns the batch-mean score as a 1-element tensor.
    """

    def __init__(self, isize, ndf, nc=3):
        """isize: input spatial size (multiple of 16); ndf: base channel count;
        nc: input channels."""
        super(WDiscriminator, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        main = nn.Sequential()
        # input is nc x isize x isize
        # NOTE: module names use '_' instead of the original '.' -- current
        # PyTorch add_module() rejects dotted names (state_dict keys change)
        main.add_module('initial_conv_{0}-{1}'.format(nc, ndf), nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial_relu_{0}'.format(ndf), nn.LeakyReLU(0.2, inplace=True))
        # Integer division: plain '/' yields a float under Python 3
        csize, cndf = isize // 2, ndf
        # Reduce map size, doubling the channel count each time
        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid_{0}-{1}_conv'.format(in_feat, out_feat), nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module('pyramid_{0}_batchnorm'.format(out_feat), nn.BatchNorm2d(out_feat))
            main.add_module('pyramid_{0}_relu'.format(out_feat), nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize // 2
        # state size. K x 4 x 4 -> one scalar per sample
        main.add_module('final_{0}-{1}_conv'.format(cndf, 1), nn.Conv2d(cndf, 1, 4, 1, 0, bias=False))
        self.main = main

    def forward(self, input):
        """Return the mean critic score over the batch, shaped (1,)."""
        output = self.main(input)
        # Average over the batch dimension: the WGAN loss is the mean score
        output = output.mean(0)
        return output.view(1)
# Generator -- Wasserstein
class WGenerator(nn.Module):
    """DCGAN-style generator for WGAN.

    Transposed 4x4 convolutions upsample a latent of shape (N, nz, 1, 1) to an
    image of shape (N, nc, isize, isize), squashed to [-1, 1] by a final tanh.
    """

    def __init__(self, isize, nz, ngf, nc=3):
        """isize: output spatial size (multiple of 16); nz: latent channels;
        ngf: base channel count; nc: output channels."""
        super(WGenerator, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        # Find the first layer's channel count so that repeatedly doubling the
        # spatial size from 4 reaches isize. NOTE(review): this loop assumes
        # isize is 4 * a power of two, otherwise it never terminates -- the
        # assert above does not guarantee that (e.g. isize=48).
        cngf, tisize = ngf // 2, 4
        while tisize != isize:
            cngf = cngf * 2
            tisize = tisize * 2
        main = nn.Sequential()
        # input is Z, going into a convolution
        # NOTE: module names use '_' instead of the original '.' -- current
        # PyTorch add_module() rejects dotted names (state_dict keys change)
        main.add_module('initial_{0}-{1}_convt'.format(nz, cngf), nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
        main.add_module('initial_{0}_batchnorm'.format(cngf), nn.BatchNorm2d(cngf))
        main.add_module('initial_{0}_relu'.format(cngf), nn.ReLU(True))
        csize = 4  # current spatial size (unused 'cndf' local removed)
        # Double the spatial size (halving the channels) until isize/2
        while csize < isize // 2:
            main.add_module('pyramid_{0}-{1}_convt'.format(cngf, cngf // 2), nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False))
            main.add_module('pyramid_{0}_batchnorm'.format(cngf // 2), nn.BatchNorm2d(cngf // 2))
            main.add_module('pyramid_{0}_relu'.format(cngf // 2), nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2
        # Final upsampling to nc channels; tanh maps outputs into [-1, 1]
        main.add_module('final_{0}-{1}_convt'.format(cngf, nc), nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final_{0}_tanh'.format(nc), nn.Tanh())
        self.main = main

    def forward(self, input):
        """Map a (N, nz, 1, 1) latent batch to (N, nc, isize, isize) images."""
        output = self.main(input)
        return output
# Custom weight initialization
def weights_init(m):
    """DCGAN initialization, applied via Module.apply():
    conv weights ~ N(0, 0.02); batch-norm weights ~ N(1, 0.02), biases 0.
    Other module types are left untouched."""
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
# Create generator model/optimizer
g_net = WGenerator(nz=100, isize=64, ngf=64)
g_net.apply(weights_init)
# RMSprop with a small learning rate (no momentum-based optimizer), as in the WGAN recipe
g_optimizer = torch.optim.RMSprop(g_net.parameters(), lr=5e-5)
# Create discriminator model/optimizer
d_net = WDiscriminator(isize=64, ndf=64)
d_net.apply(weights_init)
d_optimizer = torch.optim.RMSprop(d_net.parameters(), lr=5e-5)
# Setup CUDA
if not opt.no_cuda:
    g_net.cuda()
    d_net.cuda()
    print("Copied to CUDA")
# Debug options
save_images_every = 20  # dump image grids every 20 generator iterations (see training loop)
cnt = 0                 # generator-iteration counter used for the periodic dump
# Auxiliary variables
noise = torch.FloatTensor(opt.batch_size, 100, 1, 1)  # latent batch, refilled in-place each step
one = torch.FloatTensor([1])  # gradient seed passed to backward() on the scalar critic output
minus_one = one*-1            # negated seed: flips the sign of the backward pass
if not opt.no_cuda:
    noise = noise.cuda()
    one = one.cuda()
    minus_one = minus_one.cuda()
# Start training
d_net.train()
g_net.train()
g_iterations = 0
for epoch in range(0, 10000):
    # Keep track of losses (epoch averages are printed at the end)
    d_real_loss_sum = 0
    d_fake_loss_sum = 0
    d_loss_cnt = 0
    g_loss_sum = 0
    g_loss_cnt = 0
    # Get data iterator
    data_iter = iter(loader)
    data_len = len(loader)
    data_i = 0
    # Process until data ends
    while data_i < data_len:
        # Compute gradients for discriminator
        for p in d_net.parameters():
            p.requires_grad = True
        # WGAN schedule: many critic steps early on (and periodically), 5 otherwise
        d_iters = 100 if g_iterations <= 25 or g_iterations % 500 == 0 else 5
        # Perform discriminator iterations
        d_i = 0
        while data_i < data_len and d_i < d_iters:
            # Increase discriminator iterations
            d_i += 1
            # Clamp parameters to a cube (weight clipping, the WGAN Lipschitz constraint)
            for p in d_net.parameters():
                p.data.clamp_(-0.01, 0.01)
            # Get data (keep reference to data on host for the debug image dump).
            # py3 fix: iterators have no .next() method -- use the next() builtin.
            real_input_cpu, _ = next(data_iter)
            real_input = real_input_cpu
            data_i += 1
            # Check CUDA. 'async' is a reserved word since Python 3.7 (hard
            # syntax error); the keyword argument was renamed non_blocking.
            if not opt.no_cuda:
                real_input = real_input.cuda(non_blocking=True)
            # Reset gradients
            d_optimizer.zero_grad()
            g_optimizer.zero_grad()
            # Forward (discriminator, real). .item() replaces the removed
            # .data[0] indexing of 0/1-dim tensors.
            d_real_loss = d_net(real_input)
            d_real_loss_sum += d_real_loss.item()
            # Backward (discriminator, real)
            d_real_loss.backward(one)
            # Forward (discriminator, fake). torch.no_grad() replaces the
            # removed volatile=True flag: the generator graph is not built
            # while training the critic.
            noise.normal_(0, 1)
            with torch.no_grad():
                g_output = g_net(noise)
            d_fake_loss = d_net(g_output)
            d_fake_loss_sum += d_fake_loss.item()
            # Backward (discriminator, fake) -- negated seed flips the sign
            d_fake_loss.backward(minus_one)
            # Update discriminator
            d_optimizer.step()
            # Update loss count
            d_loss_cnt += 1
        # Don't compute gradients w.r.t. parameters for discriminator
        for p in d_net.parameters():
            p.requires_grad = False
        # Forward (generator)
        noise.normal_(0, 1)
        g_output = g_net(noise)
        g_loss = d_net(g_output)
        g_loss_sum += g_loss.item()
        g_loss_cnt += 1
        # Backward (generator)
        g_loss.backward(one)
        g_optimizer.step()
        # Increase generator iterations
        g_iterations += 1
        # Save images every once in a while
        cnt += 1
        if cnt % save_images_every == 0:
            # Move generator output to host
            g_output_cpu = g_output.data.cpu()
            # Min-max normalize images between 0 and 1 for saving
            real_input_cpu = (real_input_cpu - real_input_cpu.min())/(real_input_cpu.max() - real_input_cpu.min())
            g_output_cpu = (g_output_cpu - g_output_cpu.min())/(g_output_cpu.max() - g_output_cpu.min())
            # Save grids of real and generated images
            Image.fromarray(torchvision.utils.make_grid(real_input_cpu, nrow = 4).permute(1,2,0).mul(255).byte().numpy()).save("real_input_" + ("bad" if opt.bad else "good") + ".png")
            Image.fromarray(torchvision.utils.make_grid(g_output_cpu, nrow = 4).permute(1,2,0).mul(255).byte().numpy()).save("g_output_" + ("bad" if opt.bad else "good") + ".png")
    # Print losses at the end of the epoch
    print("Epoch {0}: GL={1:.4f}, DRL={2:.4f}, DFL={3:.4f}".format(epoch, g_loss_sum/g_loss_cnt, d_real_loss_sum/d_loss_cnt, d_fake_loss_sum/d_loss_cnt))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment