Skip to content

Instantly share code, notes, and snippets.

@f0k
Created February 2, 2017 14:21
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save f0k/f3190ebba6c53887d598d03119ca2066 to your computer and use it in GitHub Desktop.
Save f0k/f3190ebba6c53887d598d03119ca2066 to your computer and use it in GitHub Desktop.
Lasagne WGAN example
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Example employing Lasagne for digit generation using the MNIST dataset and
Wasserstein Generative Adversarial Networks
(WGANs, see https://arxiv.org/abs/1701.07875 for the paper and
https://github.com/martinarjovsky/WassersteinGAN for the "official" code).
It is based on a DCGAN example:
https://gist.github.com/f0k/738fa2eedd9666b78404ed1751336f56
This, in turn, is based on the MNIST example in Lasagne:
https://lasagne.readthedocs.io/en/latest/user/tutorial.html
Jan Schlüter, 2017-02-02
"""
from __future__ import print_function
import sys
import os
import time
import numpy as np
import theano
import theano.tensor as T
import lasagne
# ################## Download and prepare the MNIST dataset ##################
# This is just some way of getting the MNIST dataset from an online location
# and loading it into numpy arrays. It doesn't involve Lasagne at all.
def load_dataset():
# We first define a download function, supporting both Python 2 and 3.
if sys.version_info[0] == 2:
from urllib import urlretrieve
else:
from urllib.request import urlretrieve
def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
print("Downloading %s" % filename)
urlretrieve(source + filename, filename)
# We then define functions for loading MNIST images and labels.
# For convenience, they also download the requested files if needed.
import gzip
def load_mnist_images(filename):
if not os.path.exists(filename):
download(filename)
# Read the inputs in Yann LeCun's binary format.
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
# The inputs are vectors now, we reshape them to monochrome 2D images,
# following the shape convention: (examples, channels, rows, columns)
data = data.reshape(-1, 1, 28, 28)
# The inputs come as bytes, we convert them to float32 in range [0,1].
# (Actually to range [0, 255/256], for compatibility to the version
# provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.)
return data / np.float32(256)
def load_mnist_labels(filename):
if not os.path.exists(filename):
download(filename)
# Read the labels in Yann LeCun's binary format.
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=8)
# The labels are vectors of integers now, that's exactly what we want.
return data
# We can now download and read the training and test set images and labels.
X_train = load_mnist_images('train-images-idx3-ubyte.gz')
y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')
# We reserve the last 10000 training examples for validation.
X_train, X_val = X_train[:-10000], X_train[-10000:]
y_train, y_val = y_train[:-10000], y_train[-10000:]
# We just return all the arrays in order, as expected in main().
# (It doesn't matter how we do this as long as we can read them again.)
return X_train, y_train, X_val, y_val, X_test, y_test
# ##################### Build the neural network model #######################
# We create two models: The generator and the critic network.
# The models are the same as in the Lasagne DCGAN example, except that the
# discriminator is now a critic with linear output instead of sigmoid output.
def build_generator(input_var=None):
from lasagne.layers import InputLayer, ReshapeLayer, DenseLayer
try:
from lasagne.layers import TransposedConv2DLayer as Deconv2DLayer
except ImportError:
raise ImportError("Your Lasagne is too old. Try the bleeding-edge "
"version: http://lasagne.readthedocs.io/en/latest/"
"user/installation.html#bleeding-edge-version")
try:
from lasagne.layers.dnn import batch_norm_dnn as batch_norm
except ImportError:
from lasagne.layers import batch_norm
from lasagne.nonlinearities import sigmoid
# input: 100dim
layer = InputLayer(shape=(None, 100), input_var=input_var)
# fully-connected layer
layer = batch_norm(DenseLayer(layer, 1024))
# project and reshape
layer = batch_norm(DenseLayer(layer, 128*7*7))
layer = ReshapeLayer(layer, ([0], 128, 7, 7))
# two fractional-stride convolutions
layer = batch_norm(Deconv2DLayer(layer, 64, 5, stride=2, crop='same',
output_size=14))
layer = Deconv2DLayer(layer, 1, 5, stride=2, crop='same', output_size=28,
nonlinearity=sigmoid)
print ("Generator output:", layer.output_shape)
return layer
def build_critic(input_var=None):
from lasagne.layers import (InputLayer, Conv2DLayer, ReshapeLayer,
DenseLayer)
try:
from lasagne.layers.dnn import batch_norm_dnn as batch_norm
except ImportError:
from lasagne.layers import batch_norm
from lasagne.nonlinearities import LeakyRectify
lrelu = LeakyRectify(0.2)
# input: (None, 1, 28, 28)
layer = InputLayer(shape=(None, 1, 28, 28), input_var=input_var)
# two convolutions
layer = batch_norm(Conv2DLayer(layer, 64, 5, stride=2, pad='same',
nonlinearity=lrelu))
layer = batch_norm(Conv2DLayer(layer, 128, 5, stride=2, pad='same',
nonlinearity=lrelu))
# fully-connected layer
layer = batch_norm(DenseLayer(layer, 1024, nonlinearity=lrelu))
# output layer (linear and without bias)
layer = DenseLayer(layer, 1, nonlinearity=None, b=None)
print ("critic output:", layer.output_shape)
return layer
# ############################# Batch iterator ###############################
# This is just a simple helper function iterating over training data in
# mini-batches of a particular size, optionally in random order. It assumes
# data is available as numpy arrays. For big datasets, you could load numpy
# arrays as memory-mapped files (np.load(..., mmap_mode='r')), or write your
# own custom data iteration function. For small datasets, you can also copy
# them to GPU at once for slightly improved performance. This would involve
# several changes in the main program, though, and is not demonstrated here.
def iterate_minibatches(inputs, targets, batchsize, shuffle=False,
forever=False):
assert len(inputs) == len(targets)
if shuffle:
indices = np.arange(len(inputs))
while True:
if shuffle:
np.random.shuffle(indices)
for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
if shuffle:
excerpt = indices[start_idx:start_idx + batchsize]
else:
excerpt = slice(start_idx, start_idx + batchsize)
yield inputs[excerpt], targets[excerpt]
if not forever:
break
# ############################## Main program ################################
# Everything else will be handled in our main program now. We could pull out
# more functions to better separate the code, but it wouldn't make it any
# easier to read.
def main(num_epochs=1000, epochsize=100, batchsize=64, initial_eta=5e-5,
clip=0.01):
# Load the dataset
print("Loading data...")
X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()
# Prepare Theano variables for inputs and targets
noise_var = T.matrix('noise')
input_var = T.tensor4('inputs')
# Create neural network model
print("Building model and compiling functions...")
generator = build_generator(noise_var)
critic = build_critic(input_var)
# Create expression for passing real data through the critic
real_out = lasagne.layers.get_output(critic)
# Create expression for passing fake data through the critic
fake_out = lasagne.layers.get_output(critic,
lasagne.layers.get_output(generator))
# Create score expressions to be maximized (i.e., negative losses)
generator_score = fake_out.mean()
critic_score = real_out.mean() - fake_out.mean()
# Create update expressions for training
generator_params = lasagne.layers.get_all_params(generator, trainable=True)
critic_params = lasagne.layers.get_all_params(critic, trainable=True)
eta = theano.shared(lasagne.utils.floatX(initial_eta))
generator_updates = lasagne.updates.rmsprop(
-generator_score, generator_params, learning_rate=eta)
critic_updates = lasagne.updates.rmsprop(
-critic_score, critic_params, learning_rate=eta)
# Clip critic parameters in a limited range around zero (except biases)
for param in lasagne.layers.get_all_params(critic, trainable=True,
regularizable=True):
critic_updates[param] = T.clip(critic_updates[param], -clip, clip)
# Instantiate a symbolic noise generator to use for training
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
srng = RandomStreams(seed=np.random.randint(2147462579, size=6))
noise = srng.uniform((batchsize, 100))
# Compile functions performing a training step on a mini-batch (according
# to the updates dictionary) and returning the corresponding score:
generator_train_fn = theano.function([], generator_score,
givens={noise_var: noise},
updates=generator_updates)
critic_train_fn = theano.function([input_var], critic_score,
givens={noise_var: noise},
updates=critic_updates)
# Compile another function generating some data
gen_fn = theano.function([noise_var],
lasagne.layers.get_output(generator,
deterministic=True))
# Finally, launch the training loop.
print("Starting training...")
# We create an infinite supply of batches (as an iterable generator):
batches = iterate_minibatches(X_train, y_train, batchsize, shuffle=True,
forever=True)
# We iterate over epochs:
generator_updates = 0
for epoch in range(num_epochs):
start_time = time.time()
# In each epoch, we do `epochsize` generator updates. Usually, the
# critic is updated 5 times before every generator update. For the
# first 25 generator updates and every 500 generator updates, the
# critic is updated 100 times instead, following the authors' code.
critic_scores = []
generator_scores = []
for _ in range(epochsize):
if (generator_updates < 25) or (generator_updates % 500 == 0):
critic_runs = 100
else:
critic_runs = 5
for _ in range(critic_runs):
batch = next(batches)
inputs, targets = batch
critic_scores.append(critic_train_fn(inputs))
generator_scores.append(generator_train_fn())
generator_updates += 1
# Then we print the results for this epoch:
print("Epoch {} of {} took {:.3f}s".format(
epoch + 1, num_epochs, time.time() - start_time))
print(" generator score:\t\t{}".format(np.mean(generator_scores)))
print(" Wasserstein distance:\t\t{}".format(np.mean(critic_scores)))
# And finally, we plot some generated data
samples = gen_fn(lasagne.utils.floatX(np.random.rand(42, 100)))
try:
import matplotlib.pyplot as plt
except ImportError:
pass
else:
plt.imsave('wgan_mnist_samples.png',
(samples.reshape(6, 7, 28, 28)
.transpose(0, 2, 1, 3)
.reshape(6*28, 7*28)),
cmap='gray')
# After half the epochs, we start decaying the learn rate towards zero
if epoch >= num_epochs // 2:
progress = float(epoch) / num_epochs
eta.set_value(lasagne.utils.floatX(initial_eta*2*(1 - progress)))
# Optionally, you could now dump the network weights to a file like this:
np.savez('wgan_mnist_gen.npz', *lasagne.layers.get_all_param_values(generator))
np.savez('wgan_mnist_crit.npz', *lasagne.layers.get_all_param_values(critic))
#
# And load them again later on like this:
# with np.load('model.npz') as f:
# param_values = [f['arr_%d' % i] for i in range(len(f.files))]
# lasagne.layers.set_all_param_values(network, param_values)
if __name__ == '__main__':
if ('--help' in sys.argv) or ('-h' in sys.argv):
print("Trains a WGAN on MNIST using Lasagne.")
print("Usage: %s [EPOCHS [EPOCHSIZE]]" % sys.argv[0])
print()
print("EPOCHS: number of training epochs to perform (default: 1000)")
print("EPOCHSIZE: number of generator updates per epoch (default: 100)")
else:
kwargs = {}
if len(sys.argv) > 1:
kwargs['num_epochs'] = int(sys.argv[1])
if len(sys.argv) > 2:
kwargs['epochsize'] = int(sys.argv[2])
main(**kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment