@f0k
Created March 7, 2017 13:22
Lasagne LSGAN example
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Example employing Lasagne for digit generation using the MNIST dataset and
Least Squares Generative Adversarial Networks
(LSGANs, see https://arxiv.org/abs/1611.04076 for the paper).
It is based on a WGAN example:
https://gist.github.com/f0k/f3190ebba6c53887d598d03119ca2066
This, in turn, is based on a DCGAN example:
https://gist.github.com/f0k/738fa2eedd9666b78404ed1751336f56
This, in turn, is based on the MNIST example in Lasagne:
https://lasagne.readthedocs.io/en/latest/user/tutorial.html
Jan Schlüter, 2017-03-07
"""
from __future__ import print_function
import sys
import os
import time
import numpy as np
import theano
import theano.tensor as T
import lasagne
# ################## Download and prepare the MNIST dataset ##################
# This is just some way of getting the MNIST dataset from an online location
# and loading it into numpy arrays. It doesn't involve Lasagne at all.
def load_dataset():
    # We first define a download function, supporting both Python 2 and 3.
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
    else:
        from urllib.request import urlretrieve

    def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)

    # We then define functions for loading MNIST images and labels.
    # For convenience, they also download the requested files if needed.
    import gzip

    def load_mnist_images(filename):
        if not os.path.exists(filename):
            download(filename)
        # Read the inputs in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=16)
        # The inputs are vectors now, we reshape them to monochrome 2D images,
        # following the shape convention: (examples, channels, rows, columns)
        data = data.reshape(-1, 1, 28, 28)
        # The inputs come as bytes, we convert them to float32 in range [0,1].
        # (Actually to range [0, 255/256], for compatibility to the version
        # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.)
        return data / np.float32(256)

    def load_mnist_labels(filename):
        if not os.path.exists(filename):
            download(filename)
        # Read the labels in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=8)
        # The labels are vectors of integers now, that's exactly what we want.
        return data

    # We can now download and read the training and test set images and labels.
    X_train = load_mnist_images('train-images-idx3-ubyte.gz')
    y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
    X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
    y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')

    # We reserve the last 10000 training examples for validation.
    X_train, X_val = X_train[:-10000], X_train[-10000:]
    y_train, y_val = y_train[:-10000], y_train[-10000:]

    # We just return all the arrays in order, as expected in main().
    # (It doesn't matter how we do this as long as we can read them again.)
    return X_train, y_train, X_val, y_val, X_test, y_test

# ##################### Build the neural network model #######################
# We create two models: The generator and the critic network.
# The models are the same as in the Lasagne DCGAN example, except that the
# discriminator is now a critic with linear output instead of sigmoid output.
def build_generator(input_var=None):
    from lasagne.layers import InputLayer, ReshapeLayer, DenseLayer
    try:
        from lasagne.layers import TransposedConv2DLayer as Deconv2DLayer
    except ImportError:
        raise ImportError("Your Lasagne is too old. Try the bleeding-edge "
                          "version: http://lasagne.readthedocs.io/en/latest/"
                          "user/installation.html#bleeding-edge-version")
    try:
        from lasagne.layers.dnn import batch_norm_dnn as batch_norm
    except ImportError:
        from lasagne.layers import batch_norm
    from lasagne.nonlinearities import sigmoid
    # input: 100dim
    layer = InputLayer(shape=(None, 100), input_var=input_var)
    # fully-connected layer
    layer = batch_norm(DenseLayer(layer, 1024))
    # project and reshape
    layer = batch_norm(DenseLayer(layer, 128*7*7))
    layer = ReshapeLayer(layer, ([0], 128, 7, 7))
    # two fractional-stride convolutions
    layer = batch_norm(Deconv2DLayer(layer, 64, 5, stride=2, crop='same',
                                     output_size=14))
    layer = Deconv2DLayer(layer, 1, 5, stride=2, crop='same', output_size=28,
                          nonlinearity=sigmoid)
    print("Generator output:", layer.output_shape)
    return layer

def build_critic(input_var=None):
    from lasagne.layers import (InputLayer, Conv2DLayer, ReshapeLayer,
                                DenseLayer)
    try:
        from lasagne.layers.dnn import batch_norm_dnn as batch_norm
    except ImportError:
        from lasagne.layers import batch_norm
    from lasagne.nonlinearities import LeakyRectify
    lrelu = LeakyRectify(0.2)
    # input: (None, 1, 28, 28)
    layer = InputLayer(shape=(None, 1, 28, 28), input_var=input_var)
    # two convolutions
    layer = batch_norm(Conv2DLayer(layer, 64, 5, stride=2, pad='same',
                                   nonlinearity=lrelu))
    layer = batch_norm(Conv2DLayer(layer, 128, 5, stride=2, pad='same',
                                   nonlinearity=lrelu))
    # fully-connected layer
    layer = batch_norm(DenseLayer(layer, 1024, nonlinearity=lrelu))
    # output layer (linear)
    layer = DenseLayer(layer, 1, nonlinearity=None)
    print("critic output:", layer.output_shape)
    return layer

# ############################# Batch iterator ###############################
# This is just a simple helper function iterating over training data in
# mini-batches of a particular size, optionally in random order. It assumes
# data is available as numpy arrays. For big datasets, you could load numpy
# arrays as memory-mapped files (np.load(..., mmap_mode='r')), or write your
# own custom data iteration function. For small datasets, you can also copy
# them to GPU at once for slightly improved performance. This would involve
# several changes in the main program, though, and is not demonstrated here.
def iterate_minibatches(inputs, targets, batchsize, shuffle=False,
                        forever=False):
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.arange(len(inputs))
    while True:
        if shuffle:
            np.random.shuffle(indices)
        for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batchsize]
            else:
                excerpt = slice(start_idx, start_idx + batchsize)
            yield inputs[excerpt], targets[excerpt]
        if not forever:
            break

# ############################## Main program ################################
# Everything else will be handled in our main program now. We could pull out
# more functions to better separate the code, but it wouldn't make it any
# easier to read.
def main(num_epochs=1000, epochsize=100, batchsize=64, initial_eta=1e-4):
    # Load the dataset
    print("Loading data...")
    X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()

    # Prepare Theano variables for inputs and targets
    noise_var = T.matrix('noise')
    input_var = T.tensor4('inputs')

    # Create neural network model
    print("Building model and compiling functions...")
    generator = build_generator(noise_var)
    critic = build_critic(input_var)

    # Create expression for passing real data through the critic
    real_out = lasagne.layers.get_output(critic)
    # Create expression for passing fake data through the critic
    fake_out = lasagne.layers.get_output(critic,
                                         lasagne.layers.get_output(generator))

    # Create loss expressions to be minimized
    # a, b, c = -1, 1, 0  # Equation (8) in the paper
    a, b, c = 0, 1, 1  # Equation (9) in the paper
    generator_loss = lasagne.objectives.squared_error(fake_out, c).mean()
    critic_loss = (lasagne.objectives.squared_error(real_out, b).mean() +
                   lasagne.objectives.squared_error(fake_out, a).mean())
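    # (For reference: with these targets, the least-squares objectives from
    # the paper are
    #   critic:     1/2 * E_x[(D(x) - b)^2] + 1/2 * E_z[(D(G(z)) - a)^2]
    #   generator:  1/2 * E_z[(D(G(z)) - c)^2]
    # The factor 1/2 is dropped above; it only rescales the gradients and
    # does not change the optima.)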
    # Create update expressions for training
    generator_params = lasagne.layers.get_all_params(generator, trainable=True)
    critic_params = lasagne.layers.get_all_params(critic, trainable=True)
    eta = theano.shared(lasagne.utils.floatX(initial_eta))
    generator_updates = lasagne.updates.rmsprop(
            generator_loss, generator_params, learning_rate=eta)
    critic_updates = lasagne.updates.rmsprop(
            critic_loss, critic_params, learning_rate=eta)

    # Instantiate a symbolic noise generator to use for training
    from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
    srng = RandomStreams(seed=np.random.randint(2147462579, size=6))
    noise = srng.uniform((batchsize, 100))

    # Compile functions performing a training step on a mini-batch (according
    # to the updates dictionary) and returning the corresponding score:
    generator_train_fn = theano.function([], generator_loss,
                                         givens={noise_var: noise},
                                         updates=generator_updates)
    critic_train_fn = theano.function([input_var], critic_loss,
                                      givens={noise_var: noise},
                                      updates=critic_updates)

    # Compile another function generating some data
    gen_fn = theano.function([noise_var],
                             lasagne.layers.get_output(generator,
                                                       deterministic=True))
    # Finally, launch the training loop.
    print("Starting training...")
    # We create an infinite supply of batches (as an iterable generator):
    batches = iterate_minibatches(X_train, y_train, batchsize, shuffle=True,
                                  forever=True)
    # We iterate over epochs:
    generator_updates = 0
    for epoch in range(num_epochs):
        start_time = time.time()

        # In each epoch, we do `epochsize` generator and critic updates.
        critic_losses = []
        generator_losses = []
        for _ in range(epochsize):
            inputs, targets = next(batches)
            critic_losses.append(critic_train_fn(inputs))
            generator_losses.append(generator_train_fn())

        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(
            epoch + 1, num_epochs, time.time() - start_time))
        print("  generator loss: {}".format(np.mean(generator_losses)))
        print("  critic loss:    {}".format(np.mean(critic_losses)))

        # And finally, we plot some generated data
        samples = gen_fn(lasagne.utils.floatX(np.random.rand(42, 100)))
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            pass
        else:
            plt.imsave('lsgan_mnist_samples.png',
                       (samples.reshape(6, 7, 28, 28)
                               .transpose(0, 2, 1, 3)
                               .reshape(6*28, 7*28)),
                       cmap='gray')

        # After half the epochs, we start decaying the learning rate to zero
        if epoch >= num_epochs // 2:
            progress = float(epoch) / num_epochs
            eta.set_value(lasagne.utils.floatX(initial_eta*2*(1 - progress)))
    # Optionally, you could now dump the network weights to a file like this:
    np.savez('lsgan_mnist_gen.npz',
             *lasagne.layers.get_all_param_values(generator))
    np.savez('lsgan_mnist_crit.npz',
             *lasagne.layers.get_all_param_values(critic))
    #
    # And load them again later on like this:
    # with np.load('model.npz') as f:
    #     param_values = [f['arr_%d' % i] for i in range(len(f.files))]
    # lasagne.layers.set_all_param_values(network, param_values)

if __name__ == '__main__':
    if ('--help' in sys.argv) or ('-h' in sys.argv):
        print("Trains an LSGAN on MNIST using Lasagne.")
        print("Usage: %s [EPOCHS [EPOCHSIZE]]" % sys.argv[0])
        print()
        print("EPOCHS: number of training epochs to perform (default: 1000)")
        print("EPOCHSIZE: number of network updates per epoch (default: 100)")
    else:
        kwargs = {}
        if len(sys.argv) > 1:
            kwargs['num_epochs'] = int(sys.argv[1])
        if len(sys.argv) > 2:
            kwargs['epochsize'] = int(sys.argv[2])
        main(**kwargs)
@hma02

hma02 commented Mar 20, 2017

@f0k Thanks for the example. I tried adapting the code to train on the CIFAR-10 dataset but haven't had any luck yet. Just wondering if you have an example for CIFAR-10 as well. Thanks.


ghost commented Mar 22, 2017

Hey @hma02, I'm trying to do something similar. Please let me know if you get it working. What exactly is going wrong?

@hma02

hma02 commented Mar 24, 2017

@lucascaccia yes, I managed to generate something smooth finally. See the lsgan_cifar10.py code adapted from this gist.
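For anyone attempting the same adaptation, the minimal shape changes for 32x32 RGB images would look roughly like the sketch below. This is not hma02's lsgan_cifar10.py; the hypothetical build_cifar10_generator just mirrors build_generator() from the gist with one plausible choice of sizes:

def build_cifar10_generator(input_var=None):
    # same imports as in build_generator() above
    from lasagne.layers import InputLayer, ReshapeLayer, DenseLayer, batch_norm
    from lasagne.layers import TransposedConv2DLayer as Deconv2DLayer
    from lasagne.nonlinearities import sigmoid
    layer = InputLayer(shape=(None, 100), input_var=input_var)
    layer = batch_norm(DenseLayer(layer, 1024))
    # project to 8x8 feature maps instead of 7x7, then upsample twice to 32x32
    layer = batch_norm(DenseLayer(layer, 128*8*8))
    layer = ReshapeLayer(layer, ([0], 128, 8, 8))
    layer = batch_norm(Deconv2DLayer(layer, 64, 5, stride=2, crop='same',
                                     output_size=16))
    # three output channels for RGB instead of one
    layer = Deconv2DLayer(layer, 3, 5, stride=2, crop='same', output_size=32,
                          nonlinearity=sigmoid)
    return layer

The critic's InputLayer would change to shape=(None, 3, 32, 32) to match, and the plotting code at the end of main() would need to handle three channels.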

@pclucas14

Thanks for sharing @hma02, looking forward to trying it out! Also, is there any specific reason why rmsprop is used (and not adam)? Maybe it's a leftover from WGAN?

@hma02

hma02 commented Apr 7, 2017

@pclucas14 Just tested with lasagne.updates.adam. Within 200 epochs of training, both rmsprop and adam work for the lsgan_cifar10.py code. I just used the default hyperparameters. The adam training seems slightly noisier at the beginning. According to the LSGAN paper, both should work.

@f0k

f0k commented May 16, 2017

Also, is there any specific reason why rmsprop is used (and not adam)?

From p.12 of the paper, rmsprop is more stable than adam: "First, for BN_G with Adam, there is a chance for LSGANs to generate relatively good quality images. We test 10 times, and 5 of them succeeds to generate relatively good quality images. [...] Third, [...] for BN_G with RMSProp, both LSGANs and regular GANs learn the data distribution successfully, [...]."
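
For completeness, if someone still wants to try adam with this gist, the only change would be to the two update expressions in main(). This is just a sketch on top of the gist, not part of it; beta1=0.5 follows the DCGAN paper's recommendation rather than anything from the LSGAN paper:

    # replace the two rmsprop calls in main() with adam
    generator_updates = lasagne.updates.adam(
            generator_loss, generator_params, learning_rate=eta, beta1=0.5)
    critic_updates = lasagne.updates.adam(
            critic_loss, critic_params, learning_rate=eta, beta1=0.5)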
