Last active
May 18, 2017 15:26
-
-
Save SimonKohl/27e499681ec4f01c1d1d1909f3f4bd71 to your computer and use it in GitHub Desktop.
Lasagne WGAN example employing "the improved training" strategy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
""" | |
Example employing Lasagne for digit generation using the MNIST dataset and | |
Wasserstein Generative Adversarial Networks | |
(WGANs, see https://arxiv.org/abs/1701.07875 for the paper and | |
https://github.com/martinarjovsky/WassersteinGAN for the "official" code), | |
employing the gradient norm penalty to enforce a Lipschitz-1 critic | |
(WGAN with gradient penalty, see 'Improved Training of Wasserstein GANs', | |
https://arxiv.org/abs/1704.00028 and | |
https://github.com/igul222/improved_wgan_training/ for the "official code"). | |
It is based on a WGAN example by Jan Schlüter: | |
https://gist.github.com/f0k/f3190ebba6c53887d598d03119ca2066 | |
Which in turn is based on a DCGAN example: | |
https://gist.github.com/f0k/738fa2eedd9666b78404ed1751336f56 | |
This, in turn, is based on the MNIST example in Lasagne: | |
https://lasagne.readthedocs.io/en/latest/user/tutorial.html | |
Simon Kohl, 2017-18-05 | |
""" | |
from __future__ import print_function | |
import sys | |
import os | |
import time | |
import numpy as np | |
import theano | |
import theano.tensor as T | |
import lasagne | |
# ################## Download and prepare the MNIST dataset ################## | |
# This is just some way of getting the MNIST dataset from an online location | |
# and loading it into numpy arrays. It doesn't involve Lasagne at all. | |
def load_dataset():
    """Download (if needed) and load MNIST, returning train/val/test splits.

    Returns:
        (X_train, y_train, X_val, y_val, X_test, y_test) as numpy arrays;
        images are float32 in [0, 255/256] with shape (N, 1, 28, 28) and
        labels are uint8 vectors.
    """
    # Pick the Python-2 or Python-3 home of urlretrieve.
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
    else:
        from urllib.request import urlretrieve
    import gzip

    def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
        # Fetch one of the gzipped IDX files into the working directory.
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)

    def load_mnist_images(filename):
        # Download on demand, then read the IDX image format (16-byte header).
        if not os.path.exists(filename):
            download(filename)
        with gzip.open(filename, 'rb') as f:
            raw = np.frombuffer(f.read(), np.uint8, offset=16)
        # Reshape flat bytes to (examples, channels, rows, columns) and scale
        # bytes to [0, 255/256], matching the pickled version provided at
        # http://deeplearning.net/data/mnist/mnist.pkl.gz.
        return raw.reshape(-1, 1, 28, 28) / np.float32(256)

    def load_mnist_labels(filename):
        # Download on demand, then read the IDX label format (8-byte header).
        if not os.path.exists(filename):
            download(filename)
        with gzip.open(filename, 'rb') as f:
            # Labels are already the integer vector we want.
            return np.frombuffer(f.read(), np.uint8, offset=8)

    X_train = load_mnist_images('train-images-idx3-ubyte.gz')
    y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
    X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
    y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')
    # Hold out the last 10000 training examples as a validation set and
    # return everything in the order main() expects.
    return (X_train[:-10000], y_train[:-10000],
            X_train[-10000:], y_train[-10000:],
            X_test, y_test)
# ##################### Build the neural network model ####################### | |
# We create two models: The generator and the critic network. | |
# The models are the same as in the Lasagne DCGAN example, except that the | |
# discriminator is now a critic with linear output instead of sigmoid output. | |
def build_generator(input_var=None):
    """Build the DCGAN-style generator mapping 100-dim noise to images.

    Output layer shape is (None, 1, 28, 28) with sigmoid activations in
    (0, 1), matching the scaled MNIST inputs.
    """
    from lasagne.layers import InputLayer, ReshapeLayer, DenseLayer
    try:
        from lasagne.layers import TransposedConv2DLayer as Deconv2DLayer
    except ImportError:
        raise ImportError("Your Lasagne is too old. Try the bleeding-edge "
                          "version: http://lasagne.readthedocs.io/en/latest/"
                          "user/installation.html#bleeding-edge-version")
    try:
        # Prefer the cuDNN batch normalization when it is available.
        from lasagne.layers.dnn import batch_norm_dnn as batch_norm
    except ImportError:
        from lasagne.layers import batch_norm
    from lasagne.nonlinearities import sigmoid

    # 100-dim noise -> fully-connected -> project and reshape to 7x7 maps.
    net = InputLayer(shape=(None, 100), input_var=input_var)
    net = batch_norm(DenseLayer(net, 1024))
    net = batch_norm(DenseLayer(net, 128 * 7 * 7))
    net = ReshapeLayer(net, ([0], 128, 7, 7))
    # Two fractional-stride (transposed) convolutions: 7x7 -> 14x14 -> 28x28.
    net = batch_norm(Deconv2DLayer(net, 64, 5, stride=2, crop='same',
                                   output_size=14))
    net = Deconv2DLayer(net, 1, 5, stride=2, crop='same', output_size=28,
                        nonlinearity=sigmoid)
    print("Generator output:", net.output_shape)
    return net
def build_critic(input_var=None):
    """Build the critic mapping (None, 1, 28, 28) images to a linear score.

    Unlike a GAN discriminator, the output is unbounded (no sigmoid) and
    bias-free, as required by the Wasserstein formulation.
    """
    from lasagne.layers import (InputLayer, Conv2DLayer, ReshapeLayer,
                                DenseLayer)
    from lasagne.nonlinearities import LeakyRectify
    leaky = LeakyRectify(0.2)

    net = InputLayer(shape=(None, 1, 28, 28), input_var=input_var)
    # Two strided convolutions: 28x28 -> 14x14 -> 7x7.
    net = Conv2DLayer(net, 64, 5, stride=2, pad='same', nonlinearity=leaky)
    net = Conv2DLayer(net, 128, 5, stride=2, pad='same', nonlinearity=leaky)
    net = DenseLayer(net, 1024, nonlinearity=leaky)
    # Linear, bias-free scalar output: a score, not a probability.
    net = DenseLayer(net, 1, nonlinearity=None, b=None)
    print("critic output:", net.output_shape)
    return net
# ############################# Batch iterator ############################### | |
# This is just a simple helper function iterating over training data in | |
# mini-batches of a particular size, optionally in random order. It assumes | |
# data is available as numpy arrays. For big datasets, you could load numpy | |
# arrays as memory-mapped files (np.load(..., mmap_mode='r')), or write your | |
# own custom data iteration function. For small datasets, you can also copy | |
# them to GPU at once for slightly improved performance. This would involve | |
# several changes in the main program, though, and is not demonstrated here. | |
def iterate_minibatches(inputs, targets, batchsize, shuffle=False,
                        forever=False):
    """Yield (inputs, targets) mini-batches of exactly `batchsize` examples.

    Trailing examples that do not fill a complete batch are skipped. With
    shuffle=True, a fresh random order is drawn for every pass over the data;
    with forever=True, passes repeat endlessly (an infinite generator).
    """
    assert len(inputs) == len(targets)
    count = len(inputs)
    order = np.arange(count) if shuffle else None
    while True:
        if shuffle:
            np.random.shuffle(order)
        for start in range(0, count - batchsize + 1, batchsize):
            if shuffle:
                # Fancy-index with the shuffled positions.
                sel = order[start:start + batchsize]
            else:
                # A plain slice avoids copying for in-order iteration.
                sel = slice(start, start + batchsize)
            yield inputs[sel], targets[sel]
        if not forever:
            break
# ############################## Main program ################################ | |
# Everything else will be handled in our main program now. We could pull out | |
# more functions to better separate the code, but it wouldn't make it any | |
# easier to read. | |
def main(num_epochs=1000, epochsize=100, batchsize=64, initial_eta=5e-5, alpha=10):
    """Train a WGAN with gradient penalty (WGAN-GP) on MNIST.

    Args:
        num_epochs: number of training epochs.
        epochsize: generator updates per epoch.
        batchsize: mini-batch size for both generator and critic.
        initial_eta: initial RMSProp learning rate; decayed linearly to zero
            over the second half of training.
        alpha: weight of the gradient penalty term (the paper's lambda).

    Side effects: downloads MNIST if absent, writes a sample-grid PNG each
    epoch, and dumps the final network weights to two .npz files.
    """
    # Load the dataset
    print("Loading data...")
    X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()
    # Prepare Theano variables for inputs and targets
    noise_var = T.matrix('noise')
    input_var = T.tensor4('inputs')
    # Prepare theano random stream
    from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
    # MRG31k3p takes a 6-element seed vector.
    srng = RandomStreams(seed=np.random.randint(2147462579, size=6))
    # Create neural network model
    print("Building model and compiling functions...")
    generator = build_generator(noise_var)
    critic = build_critic(input_var)
    # Create expression for passing real data through the critic
    real_out = lasagne.layers.get_output(critic)
    # Create expression for passing fake data through the critic
    # (gen_out is built once and reused below for the gradient penalty).
    gen_out = lasagne.layers.get_output(generator)
    fake_out = lasagne.layers.get_output(critic, gen_out)
    # Create score expressions to be maximized (i.e., negative losses)
    generator_score = fake_out.mean()
    critic_score = real_out.mean() - fake_out.mean()
    # Penalize gradient norm of critic as a soft constraint to enforce Lipschitz-1 criterium
    # Sample points uniformly along straight lines between real and generated
    # images (one epsilon per example, broadcast over channels/rows/cols),
    # then penalize the critic's gradient norm there for deviating from 1.
    epsilon = srng.uniform((batchsize, 1, 1, 1), low=0., high=1.)
    interpolates = epsilon * input_var + (1 - epsilon) * gen_out
    # Summing over the batch lets a single theano.grad call yield the
    # per-example gradients w.r.t. the interpolated inputs.
    interpolates_out_summed = lasagne.layers.get_output(critic, interpolates).sum()
    gradients = theano.grad(interpolates_out_summed, wrt=interpolates)
    slopes = T.sqrt(T.sum(T.sqr(gradients), axis=(1, 2, 3)))
    gradient_penalty = T.mean((slopes-1.)**2)
    full_critic_score = critic_score - alpha * gradient_penalty
    # Create update expressions for training
    generator_params = lasagne.layers.get_all_params(generator, trainable=True)
    critic_params = lasagne.layers.get_all_params(critic, trainable=True)
    eta = theano.shared(lasagne.utils.floatX(initial_eta))
    # RMSProp minimizes, so negate the scores we want to maximize.
    generator_updates = lasagne.updates.rmsprop(
        -generator_score, generator_params, learning_rate=eta)
    critic_updates = lasagne.updates.rmsprop(
        -full_critic_score, critic_params, learning_rate=eta)
    # Instantiate a symbolic noise generator to use for training
    noise = srng.uniform((batchsize, 100))
    # Compile functions performing a training step on a mini-batch (according
    # to the updates dictionary) and returning the corresponding score:
    generator_train_fn = theano.function([], generator_score,
                                         givens={noise_var: noise},
                                         updates=generator_updates)
    # Second output equals alpha * gradient_penalty (the penalty's
    # contribution to the loss), reported below as the penalty score.
    critic_train_fn = theano.function([input_var], [critic_score, critic_score - full_critic_score],
                                      givens={noise_var: noise},
                                      updates=critic_updates)
    # Compile another function generating some data
    gen_fn = theano.function([noise_var],
                             lasagne.layers.get_output(generator,
                                                       deterministic=True))
    # Finally, launch the training loop.
    print("Starting training...")
    # We create an infinite supply of batches (as an iterable generator):
    batches = iterate_minibatches(X_train, y_train, batchsize, shuffle=True,
                                  forever=True)
    # We iterate over epochs:
    # NOTE: this rebinds `generator_updates` (previously the updates dict,
    # which is no longer needed after compilation) as an update counter.
    generator_updates = 0
    for epoch in range(num_epochs):
        start_time = time.time()
        # In each epoch, we do `epochsize` generator updates. Usually, the
        # critic is updated 5 times before every generator update. For the
        # first 25 generator updates and every 500 generator updates, the
        # critic is updated 100 times instead, following the authors' code.
        critic_scores = []
        gradient_penalty_scores = []
        generator_scores = []
        for _ in range(epochsize):
            if (generator_updates < 25) or (generator_updates % 500 == 0):
                critic_runs = 100
            else:
                critic_runs = 5
            for _ in range(critic_runs):
                batch = next(batches)
                inputs, targets = batch
                c_score, grad_penalty = critic_train_fn(inputs)
                critic_scores.append(c_score)
                gradient_penalty_scores.append(grad_penalty)
            generator_scores.append(generator_train_fn())
            generator_updates += 1
        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(
            epoch + 1, num_epochs, time.time() - start_time))
        print("  generator score:\t\t{}".format(np.mean(generator_scores)))
        print("  Wasserstein distance:\t\t{}".format(np.mean(critic_scores)))
        print("  gradient norm penalty:\t\t{}".format(np.mean(gradient_penalty_scores)))
        # And finally, we plot some generated data
        # (42 samples arranged as a 6x7 grid image).
        samples = gen_fn(lasagne.utils.floatX(np.random.rand(42, 100)))
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            pass
        else:
            plt.imsave('improved_wgan_mnist_samples.png',
                       (samples.reshape(6, 7, 28, 28)
                               .transpose(0, 2, 1, 3)
                               .reshape(6*28, 7*28)),
                       cmap='gray')
        # After half the epochs, we start decaying the learn rate towards zero
        if epoch >= num_epochs // 2:
            progress = float(epoch) / num_epochs
            eta.set_value(lasagne.utils.floatX(initial_eta*2*(1 - progress)))
    # Optionally, you could now dump the network weights to a file like this:
    np.savez('wgan_mnist_gen.npz', *lasagne.layers.get_all_param_values(generator))
    np.savez('wgan_mnist_crit.npz', *lasagne.layers.get_all_param_values(critic))
    #
    # And load them again later on like this:
    # with np.load('model.npz') as f:
    #     param_values = [f['arr_%d' % i] for i in range(len(f.files))]
    #     lasagne.layers.set_all_param_values(network, param_values)
if __name__ == '__main__':
    argv = sys.argv
    if ('--help' in argv) or ('-h' in argv):
        # Show usage and exit without training.
        print("Trains a WGAN on MNIST using Lasagne.")
        print("Usage: %s [EPOCHS [EPOCHSIZE]]" % argv[0])
        print()
        print("EPOCHS: number of training epochs to perform (default: 1000)")
        print("EPOCHSIZE: number of generator updates per epoch (default: 100)")
    else:
        # Forward any positional arguments to main() as keyword overrides.
        overrides = {}
        for position, keyword in ((1, 'num_epochs'), (2, 'epochsize')):
            if len(argv) > position:
                overrides[keyword] = int(argv[position])
        main(**overrides)
Thanks, good point! Changed that, should be fine now.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Cool, thanks for posting!
The original code uses a single set of generated samples for the adversarial loss (https://github.com/igul222/improved_wgan_training/blob/master/gan_mnist.py#L107) and the Lipschitz penalty (https://github.com/igul222/improved_wgan_training/blob/master/gan_mnist.py#L143), while you generate an additional set of samples (https://gist.github.com/SimonKohl/27e499681ec4f01c1d1d1909f3f4bd71#file-improved_wgan_mnist-py-L210). It should be faster to train if you reuse the data as well (i.e., do a single
lasagne.layers.get_output(generator)
call, assign it to a variable, and reuse it).