Lasagne WGAN example employing the "improved training" (gradient penalty) strategy
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Example employing Lasagne for digit generation using the MNIST dataset and
Wasserstein Generative Adversarial Networks
(WGANs, see https://arxiv.org/abs/1701.07875 for the paper and
https://github.com/martinarjovsky/WassersteinGAN for the "official" code),
employing the gradient norm penalty to enforce a Lipschitz-1 critic
(WGAN with gradient penalty, see 'Improved Training of Wasserstein GANs',
https://arxiv.org/abs/1704.00028 and
https://github.com/igul222/improved_wgan_training/ for the "official code").
It is based on a WGAN example by Jan Schlüter:
https://gist.github.com/f0k/f3190ebba6c53887d598d03119ca2066
which in turn is based on a DCGAN example:
https://gist.github.com/f0k/738fa2eedd9666b78404ed1751336f56
which is itself based on the MNIST example in Lasagne:
https://lasagne.readthedocs.io/en/latest/user/tutorial.html
Simon Kohl, 2017-05-18
"""
from __future__ import print_function
import sys
import os
import time
import numpy as np
import theano
import theano.tensor as T
import lasagne
# ################## Download and prepare the MNIST dataset ##################
# This is just some way of getting the MNIST dataset from an online location
# and loading it into numpy arrays. It doesn't involve Lasagne at all.
def load_dataset():
# We first define a download function, supporting both Python 2 and 3.
if sys.version_info[0] == 2:
from urllib import urlretrieve
else:
from urllib.request import urlretrieve
def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
print("Downloading %s" % filename)
urlretrieve(source + filename, filename)
# We then define functions for loading MNIST images and labels.
# For convenience, they also download the requested files if needed.
import gzip
def load_mnist_images(filename):
if not os.path.exists(filename):
download(filename)
# Read the inputs in Yann LeCun's binary format.
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
        # The inputs are vectors now; we reshape them to monochrome 2D images,
# following the shape convention: (examples, channels, rows, columns)
data = data.reshape(-1, 1, 28, 28)
        # The inputs come as bytes; we convert them to float32 in range [0,1].
        # (Actually to range [0, 255/256], for compatibility with the version
# provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.)
return data / np.float32(256)
def load_mnist_labels(filename):
if not os.path.exists(filename):
download(filename)
# Read the labels in Yann LeCun's binary format.
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=8)
        # The labels are a vector of integers now; that's exactly what we want.
return data
# We can now download and read the training and test set images and labels.
X_train = load_mnist_images('train-images-idx3-ubyte.gz')
y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')
# We reserve the last 10000 training examples for validation.
X_train, X_val = X_train[:-10000], X_train[-10000:]
y_train, y_val = y_train[:-10000], y_train[-10000:]
# We just return all the arrays in order, as expected in main().
# (It doesn't matter how we do this as long as we can read them again.)
return X_train, y_train, X_val, y_val, X_test, y_test
# ##################### Build the neural network model #######################
# We create two models: The generator and the critic network.
# The models are the same as in the Lasagne DCGAN example, except that the
# discriminator is now a critic with linear output instead of sigmoid output.
def build_generator(input_var=None):
from lasagne.layers import InputLayer, ReshapeLayer, DenseLayer
try:
from lasagne.layers import TransposedConv2DLayer as Deconv2DLayer
except ImportError:
raise ImportError("Your Lasagne is too old. Try the bleeding-edge "
"version: http://lasagne.readthedocs.io/en/latest/"
"user/installation.html#bleeding-edge-version")
try:
from lasagne.layers.dnn import batch_norm_dnn as batch_norm
except ImportError:
from lasagne.layers import batch_norm
from lasagne.nonlinearities import sigmoid
# input: 100dim
layer = InputLayer(shape=(None, 100), input_var=input_var)
# fully-connected layer
layer = batch_norm(DenseLayer(layer, 1024))
# project and reshape
layer = batch_norm(DenseLayer(layer, 128*7*7))
layer = ReshapeLayer(layer, ([0], 128, 7, 7))
# two fractional-stride convolutions
layer = batch_norm(Deconv2DLayer(layer, 64, 5, stride=2, crop='same',
output_size=14))
layer = Deconv2DLayer(layer, 1, 5, stride=2, crop='same', output_size=28,
nonlinearity=sigmoid)
print ("Generator output:", layer.output_shape)
return layer
def build_critic(input_var=None):
from lasagne.layers import (InputLayer, Conv2DLayer, ReshapeLayer,
DenseLayer)
from lasagne.nonlinearities import LeakyRectify
lrelu = LeakyRectify(0.2)
# input: (None, 1, 28, 28)
layer = InputLayer(shape=(None, 1, 28, 28), input_var=input_var)
# two convolutions
layer = Conv2DLayer(layer, 64, 5, stride=2, pad='same',
nonlinearity=lrelu)
layer = Conv2DLayer(layer, 128, 5, stride=2, pad='same',
nonlinearity=lrelu)
# fully-connected layer
layer = DenseLayer(layer, 1024, nonlinearity=lrelu)
# output layer (linear and without bias)
layer = DenseLayer(layer, 1, nonlinearity=None, b=None)
print ("critic output:", layer.output_shape)
return layer
# ############################# Batch iterator ###############################
# This is just a simple helper function iterating over training data in
# mini-batches of a particular size, optionally in random order. It assumes
# data is available as numpy arrays. For big datasets, you could load numpy
# arrays as memory-mapped files (np.load(..., mmap_mode='r')), or write your
# own custom data iteration function. For small datasets, you can also copy
# them to GPU at once for slightly improved performance. This would involve
# several changes in the main program, though, and is not demonstrated here.
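# As a minimal sketch of the memory-mapping option mentioned above (assuming
# the arrays were previously saved with np.save to a hypothetical 'X_train.npy'):
#     X_train = np.load('X_train.npy', mmap_mode='r')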
def iterate_minibatches(inputs, targets, batchsize, shuffle=False,
forever=False):
assert len(inputs) == len(targets)
if shuffle:
indices = np.arange(len(inputs))
while True:
if shuffle:
np.random.shuffle(indices)
for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
if shuffle:
excerpt = indices[start_idx:start_idx + batchsize]
else:
excerpt = slice(start_idx, start_idx + batchsize)
yield inputs[excerpt], targets[excerpt]
if not forever:
break
# ############################## Main program ################################
# Everything else will be handled in our main program now. We could pull out
# more functions to better separate the code, but it wouldn't make it any
# easier to read.
def main(num_epochs=1000, epochsize=100, batchsize=64, initial_eta=5e-5, alpha=10):
# Load the dataset
print("Loading data...")
X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()
# Prepare Theano variables for inputs and targets
noise_var = T.matrix('noise')
input_var = T.tensor4('inputs')
# Prepare theano random stream
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
srng = RandomStreams(seed=np.random.randint(2147462579, size=6))
# Create neural network model
print("Building model and compiling functions...")
generator = build_generator(noise_var)
critic = build_critic(input_var)
# Create expression for passing real data through the critic
real_out = lasagne.layers.get_output(critic)
# Create expression for passing fake data through the critic
gen_out = lasagne.layers.get_output(generator)
fake_out = lasagne.layers.get_output(critic, gen_out)
# Create score expressions to be maximized (i.e., negative losses)
generator_score = fake_out.mean()
critic_score = real_out.mean() - fake_out.mean()
    # Penalize the critic's gradient norm as a soft constraint to enforce the Lipschitz-1 criterion
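    # Following 'Improved Training of Wasserstein GANs', this penalty term is
    #   alpha * E[(||grad_xhat D(xhat)||_2 - 1)^2]
    # with xhat = eps * x_real + (1 - eps) * x_fake and eps ~ U[0, 1]
    # (the paper calls the weight lambda and likewise uses a value of 10).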
epsilon = srng.uniform((batchsize, 1, 1, 1), low=0., high=1.)
interpolates = epsilon * input_var + (1 - epsilon) * gen_out
interpolates_out_summed = lasagne.layers.get_output(critic, interpolates).sum()
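    # Summing the critic outputs over the batch before differentiating gives
    # the per-example gradients in a single theano.grad call, since each
    # interpolate only influences its own critic output.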
gradients = theano.grad(interpolates_out_summed, wrt=interpolates)
slopes = T.sqrt(T.sum(T.sqr(gradients), axis=(1, 2, 3)))
gradient_penalty = T.mean((slopes-1.)**2)
full_critic_score = critic_score - alpha * gradient_penalty
# Create update expressions for training
generator_params = lasagne.layers.get_all_params(generator, trainable=True)
critic_params = lasagne.layers.get_all_params(critic, trainable=True)
eta = theano.shared(lasagne.utils.floatX(initial_eta))
generator_updates = lasagne.updates.rmsprop(
-generator_score, generator_params, learning_rate=eta)
critic_updates = lasagne.updates.rmsprop(
-full_critic_score, critic_params, learning_rate=eta)
# Instantiate a symbolic noise generator to use for training
noise = srng.uniform((batchsize, 100))
# Compile functions performing a training step on a mini-batch (according
# to the updates dictionary) and returning the corresponding score:
generator_train_fn = theano.function([], generator_score,
givens={noise_var: noise},
updates=generator_updates)
    critic_train_fn = theano.function([input_var], [critic_score, gradient_penalty],
givens={noise_var: noise},
updates=critic_updates)
# Compile another function generating some data
gen_fn = theano.function([noise_var],
lasagne.layers.get_output(generator,
deterministic=True))
# Finally, launch the training loop.
print("Starting training...")
# We create an infinite supply of batches (as an iterable generator):
batches = iterate_minibatches(X_train, y_train, batchsize, shuffle=True,
forever=True)
# We iterate over epochs:
generator_updates = 0
for epoch in range(num_epochs):
start_time = time.time()
# In each epoch, we do `epochsize` generator updates. Usually, the
# critic is updated 5 times before every generator update. For the
# first 25 generator updates and every 500 generator updates, the
# critic is updated 100 times instead, following the authors' code.
critic_scores = []
gradient_penalty_scores = []
generator_scores = []
for _ in range(epochsize):
if (generator_updates < 25) or (generator_updates % 500 == 0):
critic_runs = 100
else:
critic_runs = 5
for _ in range(critic_runs):
batch = next(batches)
inputs, targets = batch
c_score, grad_penalty = critic_train_fn(inputs)
critic_scores.append(c_score)
gradient_penalty_scores.append(grad_penalty)
generator_scores.append(generator_train_fn())
generator_updates += 1
# Then we print the results for this epoch:
print("Epoch {} of {} took {:.3f}s".format(
epoch + 1, num_epochs, time.time() - start_time))
print(" generator score:\t\t{}".format(np.mean(generator_scores)))
print(" Wasserstein distance:\t\t{}".format(np.mean(critic_scores)))
print(" gradient norm penalty:\t\t{}".format(np.mean(gradient_penalty_scores)))
# And finally, we plot some generated data
samples = gen_fn(lasagne.utils.floatX(np.random.rand(42, 100)))
try:
import matplotlib.pyplot as plt
except ImportError:
pass
else:
plt.imsave('improved_wgan_mnist_samples.png',
(samples.reshape(6, 7, 28, 28)
.transpose(0, 2, 1, 3)
.reshape(6*28, 7*28)),
cmap='gray')
        # After half the epochs, we start decaying the learning rate towards zero
if epoch >= num_epochs // 2:
progress = float(epoch) / num_epochs
eta.set_value(lasagne.utils.floatX(initial_eta*2*(1 - progress)))
# Optionally, you could now dump the network weights to a file like this:
np.savez('wgan_mnist_gen.npz', *lasagne.layers.get_all_param_values(generator))
np.savez('wgan_mnist_crit.npz', *lasagne.layers.get_all_param_values(critic))
#
# And load them again later on like this:
# with np.load('model.npz') as f:
# param_values = [f['arr_%d' % i] for i in range(len(f.files))]
# lasagne.layers.set_all_param_values(network, param_values)
if __name__ == '__main__':
if ('--help' in sys.argv) or ('-h' in sys.argv):
print("Trains a WGAN on MNIST using Lasagne.")
print("Usage: %s [EPOCHS [EPOCHSIZE]]" % sys.argv[0])
print()
print("EPOCHS: number of training epochs to perform (default: 1000)")
print("EPOCHSIZE: number of generator updates per epoch (default: 100)")
else:
kwargs = {}
if len(sys.argv) > 1:
kwargs['num_epochs'] = int(sys.argv[1])
if len(sys.argv) > 2:
kwargs['epochsize'] = int(sys.argv[2])
main(**kwargs)
f0k commented May 18, 2017

Cool, thanks for posting!

The original code uses a single set of generated samples for the adversarial loss (https://github.com/igul222/improved_wgan_training/blob/master/gan_mnist.py#L107) and the Lipschitz penalty (https://github.com/igul222/improved_wgan_training/blob/master/gan_mnist.py#L143), while you generate an additional set of samples (https://gist.github.com/SimonKohl/27e499681ec4f01c1d1d1909f3f4bd71#file-improved_wgan_mnist-py-L210). It should be faster to train if you reuse the data as well (i.e., do a single lasagne.layers.get_output(generator) call, assign it to a variable, and reuse it).
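
For reference, a minimal sketch of that reuse, using the variable names from the gist above (generator, critic, srng, input_var, batchsize as defined in main()):

    # one batch of generated samples ...
    gen_out = lasagne.layers.get_output(generator)
    # ... feeds the adversarial term ...
    fake_out = lasagne.layers.get_output(critic, gen_out)
    # ... and the interpolates used for the gradient penalty
    epsilon = srng.uniform((batchsize, 1, 1, 1), low=0., high=1.)
    interpolates = epsilon * input_var + (1 - epsilon) * gen_out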

SimonKohl (Author) commented

Thanks, good point! Changed that, should be fine now.
