Skip to content

Instantly share code, notes, and snippets.

@lotabout

lotabout/DCWGAN.py

Created Mar 29, 2018
Embed
What would you like to do?
WGAN-GP (Wasserstein GAN with gradient penalty) implementation
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable, grad
import torch.optim as optim
import torchvision.utils as vutils
from path import Path
class Config(object):
    """Hyper-parameters and paths for WGAN-GP training, read as class attributes."""

    # ---- hardware ---------------------------------------------------------
    NUM_OF_GPU = 2                          # GPUs handed to nn.parallel.data_parallel by the nets
    USE_CUDA = torch.cuda.is_available()    # evaluated once, when this class is created
    NUM_WORKERS = 8                         # DataLoader worker processes

    # ---- model ------------------------------------------------------------
    Z_CHANNELS = 100                        # channels of the (N, Z, 1, 1) noise fed to the generator
    G_FEATURES = 64                         # base feature-map width of the generator (_NetG)
    D_FEATURES = 64                         # base feature-map width of the critic (_NetD)
    OUTPUT_CHANNELS = 3                     # channels of generated / real images

    # ---- data -------------------------------------------------------------
    DATA_ROOT = './data'                    # root folder handed to datasets.ImageFolder
    IMAGE_SIZE = 96                         # Resize / CenterCrop target
    BATCH_SIZE = 64

    # ---- optimisation -----------------------------------------------------
    LR = 0.0002                             # Adam learning rate (both nets)
    BETA1 = 0.5                             # Adam beta1 (both nets)
    LAMBDA = 10                             # weight of the gradient penalty in the critic loss
    CRITIC_ITERS = 5                        # the generator is updated once every CRITIC_ITERS batches

    # ---- schedule, outputs and checkpoints -------------------------------
    EPOCHES = 5000                          # number of training epochs
    EPOCHES_TO_SAVE = 20                    # save both state dicts every this many epochs
    DEBUG_FOLDER = './debug'                # sample images and checkpoints are written here
    PRE_TRAINED_G = None                    # optional generator state-dict path to resume from
    PRE_TRAINED_D = None                    # optional critic state-dict path to resume from
import os
# Create the folder for sample images and checkpoints up front.
# exist_ok=True tolerates only an already-existing directory; any other
# failure (e.g. permission denied) raises here instead of surfacing later
# when the first image or checkpoint is written.
os.makedirs(Config.DEBUG_FOLDER, exist_ok=True)
# custom weights initialization called on netG and netD
def weights_init(module):
    """DCGAN-style initialization; intended for ``net.apply(weights_init)``.

    Modules are matched by class name:

    - names containing 'Conv' (Conv2d, ConvTranspose2d, ...): weight ~ N(0, 0.02);
      any bias is left untouched.
    - names containing 'BatchNorm': weight ~ N(1, 0.02), bias = 0.

    All other modules are left unchanged. Parameters that are absent or
    ``None`` (e.g. BatchNorm with ``affine=False``) are skipped rather than
    raising ``AttributeError``.
    """
    classname = type(module).__name__
    if 'Conv' in classname:
        weight = getattr(module, 'weight', None)
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        weight = getattr(module, 'weight', None)
        bias = getattr(module, 'bias', None)
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)
class AverageMeter(object):
    """Running (optionally weighted) average of a scalar metric.

    Attributes, all 0 after construction or ``reset()``:
        val:   the most recently recorded value.
        sum:   weighted sum of all recorded values.
        count: total weight recorded so far.
        avg:   ``sum / count``, refreshed on every ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the value observed for ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
class _NetG(nn.Module):
"""docstring for _NetG"""
def __init__(self, config):
super(_NetG, self).__init__()
self.num_of_gpu = config.NUM_OF_GPU
Z_CHANNELS = config.Z_CHANNELS
G_FEATURES = config.G_FEATURES
OUTPUT_CHANNELS = config.OUTPUT_CHANNELS
self.net = nn.Sequential(
#input is noise Z, going into a convolution (Z_CHANNELS * 1 * 1)
nn.ConvTranspose2d(Z_CHANNELS, G_FEATURES*8, kernel_size=4, bias=False),
nn.BatchNorm2d(G_FEATURES*8),
nn.ReLU(inplace=True),
# state size: (G_FEATURES*8) x 4 x 4
nn.ConvTranspose2d(G_FEATURES*8, G_FEATURES*4, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(G_FEATURES*4),
nn.ReLU(inplace=True),
# state size: (G_FEATURES*4) x 8 x 8
nn.ConvTranspose2d(G_FEATURES*4, G_FEATURES*2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(G_FEATURES*2),
nn.ReLU(inplace=True),
# state size: (G_FEATURES*4) x 16 x 16
nn.ConvTranspose2d(G_FEATURES*2, G_FEATURES, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(G_FEATURES),
nn.ReLU(inplace=True),
# state size: (G_FEATURES*4) x 32 x 32
nn.ConvTranspose2d(G_FEATURES, OUTPUT_CHANNELS, kernel_size=5, stride=3, padding=1, bias=False),
nn.Tanh()
# state size: (OUTPUT_CHANNLE) x 96 x 96
)
def forward(self, input):
if isinstance(input.data, torch.cuda.FloatTensor) and self.num_of_gpu > 1:
output = nn.parallel.data_parallel(self.net, input, range(self.num_of_gpu))
else:
output = self.net(input)
return output
class _NetD(nn.Module):
def __init__(self, config):
super(_NetD, self).__init__()
self.num_of_gpu = config.NUM_OF_GPU
Z_CHANNELS = config.Z_CHANNELS
D_FEATURES = config.D_FEATURES
OUTPUT_CHANNELS = config.OUTPUT_CHANNELS
self.net = nn.Sequential(
# input is: (OUTPUT_CHANNLE) x 96 x 96
nn.Conv2d(OUTPUT_CHANNELS, D_FEATURES, kernel_size=5, stride=3, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size: (D_FEATURES) x 32 x 32
nn.Conv2d(D_FEATURES, D_FEATURES*2, kernel_size=4, stride=2, padding=1, bias=False),
# nn.BatchNorm2d(D_FEATURES*2),
nn.LeakyReLU(0.2, inplace=True),
# state size: (D_FEATURES*2) x 16 x 16
nn.Conv2d(D_FEATURES*2, D_FEATURES*4, kernel_size=4, stride=2, padding=1, bias=False),
# nn.BatchNorm2d(D_FEATURES*4),
nn.LeakyReLU(0.2, inplace=True),
# state size: (D_FEATURES*4) x 8 x 8
nn.Conv2d(D_FEATURES*4, D_FEATURES*8, kernel_size=4, stride=2, padding=1, bias=False),
# nn.BatchNorm2d(D_FEATURES*8),
nn.LeakyReLU(0.2, inplace=True),
# state size: (D_FEATURES*8) x 4 x 4
nn.Conv2d(D_FEATURES*8, 1, kernel_size=4, stride=1, padding=0, bias=False),
)
def forward(self, input):
if isinstance(input.data, torch.cuda.FloatTensor) and self.num_of_gpu > 1:
output = nn.parallel.data_parallel(self.net, input, range(self.num_of_gpu))
else:
output = self.net(input)
return output.view(-1, 1).squeeze(1)
def cal_gradient_penalty(netD, real_data, fake_data, config=None):
    """Compute the WGAN-GP gradient penalty  mean_i (||grad_x D(x_hat_i)||_2 - 1)^2.

    Each x_hat_i = alpha_i * real_i + (1 - alpha_i) * fake_i lies on the line
    between paired real and fake samples, with one alpha_i ~ U[0, 1) per sample.

    Args:
        netD: critic; called on a batch shaped like ``real_data`` and
            expected to return one score per sample.
        real_data: batch of real samples, shape (N, ...).
        fake_data: batch of generated samples, same shape and device as
            ``real_data``. Any autograd history is discarded.
        config: ignored; kept so existing ``config=...`` call sites keep
            working. Batch size, device and dtype are taken from ``real_data``,
            so a final, smaller batch is handled correctly.

    Returns:
        0-dim tensor with the mean penalty (not yet multiplied by LAMBDA).
        It is built with ``create_graph=True``, so it can be back-propagated
        into netD's parameters.
    """
    batch_size = real_data.size(0)
    # One mixing coefficient per sample, broadcast over all remaining dims.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)
    # A fresh leaf: we want d(critic)/d(interpolates), not gradients w.r.t.
    # whatever produced real_data / fake_data.
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach().requires_grad_(True)
    disc_interpolates = netD(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,  # the penalty must itself be differentiable w.r.t. netD's weights
        retain_graph=True,
    )[0]
    # Norm over ALL feature dims of each sample. norm(dim=1) on a 4-D image
    # gradient would reduce over channels only.
    gradient_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()
def _build_dataloader(config):
    """Return a shuffled DataLoader over the ImageFolder at ``config.DATA_ROOT``.

    Images are resized and center-cropped to IMAGE_SIZE, then normalized
    channel-wise from [0, 1] to [-1, 1], matching the generator's Tanh range.
    """
    transformers = transforms.Compose([
        transforms.Resize(config.IMAGE_SIZE),
        transforms.CenterCrop(config.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = datasets.ImageFolder(root=config.DATA_ROOT, transform=transformers)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )


def _build_net(net_cls, config, checkpoint):
    """Instantiate ``net_cls(config)``; load ``checkpoint`` if given, else apply weights_init."""
    net = net_cls(config)
    if checkpoint is not None:
        # map_location='cpu' lets checkpoints saved on a GPU load on a
        # CPU-only machine; the caller moves the model to the training device.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        net.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    else:
        net.apply(weights_init)
    return net


def train(config=None, **kwargs):
    """Train the WGAN-GP generator (netG) and critic (netD) described by ``config``.

    The critic is updated on every batch. The generator is updated once every
    CRITIC_ITERS batches, counted globally across epochs. Every 100 batches a
    grid of real images and a grid of samples from a fixed noise batch are
    written to DEBUG_FOLDER. Both state dicts are saved every EPOCHES_TO_SAVE
    epochs.

    Args:
        config: ``Config``-like object; a fresh ``Config()`` when omitted.
        **kwargs: accepted for backward compatibility and ignored.
    """
    import os

    if config is None:
        config = Config()
    # The module-level makedirs only covers Config.DEBUG_FOLDER; a caller's
    # config may point elsewhere.
    os.makedirs(config.DEBUG_FOLDER, exist_ok=True)
    device = torch.device('cuda' if config.USE_CUDA else 'cpu')

    dataloader = _build_dataloader(config)
    netG = _build_net(_NetG, config, config.PRE_TRAINED_G).to(device)
    netD = _build_net(_NetD, config, config.PRE_TRAINED_D).to(device)

    # Same noise every time, so saved sample grids are comparable across epochs.
    fixed_noise = torch.randn(config.BATCH_SIZE, config.Z_CHANNELS, 1, 1, device=device)

    optimizerG = optim.Adam(netG.parameters(), lr=config.LR, betas=(config.BETA1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=config.LR, betas=(config.BETA1, 0.999))

    for epoch in range(config.EPOCHES):
        Wasserstein_Ds = AverageMeter()
        D_costs = AverageMeter()
        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            batch_size = images.size(0)

            # ==================================================
            # critic update: minimize D(fake) - D(real) + LAMBDA * gradient penalty
            netD.zero_grad()
            D_real = netD(images).mean()
            noise = torch.randn(batch_size, config.Z_CHANNELS, 1, 1, device=device)
            fake = netG(noise)
            # detach so the critic loss does not back-propagate into netG
            D_fake = netD(fake.detach()).mean()
            penalty = cal_gradient_penalty(netD, images, fake.detach(), config=config) * config.LAMBDA
            D_cost = D_fake - D_real + penalty
            D_cost.backward()
            optimizerD.step()

            # .item() replaces the pre-0.4 `.data[0]`, which fails on 0-dim tensors.
            d_cost_value = D_cost.item()
            wasserstein_value = (D_real - D_fake).item()
            D_costs.update(d_cost_value)
            Wasserstein_Ds.update(wasserstein_value)

            # ==================================================
            # generator update: minimize -D(G(z)), only every CRITIC_ITERS batches
            if (batch_idx + epoch * len(dataloader)) % config.CRITIC_ITERS == 0:
                netG.zero_grad()
                noise = torch.randn(batch_size, config.Z_CHANNELS, 1, 1, device=device)
                G_cost = -netD(netG(noise)).mean()
                # netD's parameters also receive gradients here; they are
                # cleared by netD.zero_grad() before the next critic step.
                G_cost.backward()
                optimizerG.step()

            if batch_idx % config.CRITIC_ITERS == 0:
                print(f'[{epoch}/{config.EPOCHES}] [{batch_idx}/{len(dataloader)}] '
                      f'Wasserstein_D: {wasserstein_value:.4f}/{Wasserstein_Ds.avg:.4f} '
                      f'Loss: {-d_cost_value:.4f}/{-D_costs.avg:.4f}')

            if batch_idx % 100 == 0:
                vutils.save_image(images[:64], f'{config.DEBUG_FOLDER}/real_samples.png', normalize=True)
                # Sampling only: no autograd graph is needed.
                with torch.no_grad():
                    samples = netG(fixed_noise)
                vutils.save_image(samples[:64], f'{config.DEBUG_FOLDER}/fake_samples_epoch_{epoch:03d}.png', normalize=True)

        # do checkpointing
        if (epoch + 1) % config.EPOCHES_TO_SAVE == 0:
            torch.save(netG.state_dict(), f'{config.DEBUG_FOLDER}/netG_epoch_{epoch}.pth')
            torch.save(netD.state_dict(), f'{config.DEBUG_FOLDER}/netD_epoch_{epoch}.pth')
if __name__ == '__main__':
    # Guard the entry point so importing this module does not start training.
    # It is also required for DataLoader worker processes (NUM_WORKERS > 0)
    # on platforms that start them with the 'spawn' method (Windows, macOS).
    config = Config()
    train(config=config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment