Skip to content

Instantly share code, notes, and snippets.

@tencia
Last active February 25, 2020 13:09
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save tencia/8eccd0cc150a9e0bdef9 to your computer and use it in GitHub Desktop.
Save tencia/8eccd0cc150a9e0bdef9 to your computer and use it in GitHub Desktop.
Variational Auto-Encoder using Lasagne
import sys
import os
import numpy as np
import theano
import theano.tensor as T
import lasagne as nn
import time
from PIL import Image
from scipy.stats import norm
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
# ############################################################################
# Tencia Lee
# Some code borrowed from:
# https://github.com/Lasagne/Lasagne/blob/master/examples/mnist.py
#
# Implementation of variational autoencoder (AEVB) algorithm as in:
# [1] arXiv:1312.6114 [stat.ML] (Diederik P Kingma, Max Welling 2013)
# ################## Download and prepare the MNIST dataset ##################
# For the linked MNIST data, the autoencoder learns well only in binary mode.
# This is most likely due to the distribution of the values. Most pixels are
# either very close to 0, or very close to 1.
#
# Running this code with default settings should produce a manifold similar
# to the example in this directory. An animation of the manifold's evolution
# can be found here: https://youtu.be/pgmnCU_DxzM
def load_dataset():
if sys.version_info[0] == 2:
from urllib import urlretrieve
else:
from urllib.request import urlretrieve
def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
print("Downloading %s" % filename)
urlretrieve(source + filename, filename)
import gzip
def load_mnist_images(filename):
if not os.path.exists(filename):
download(filename)
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
data = data.reshape(-1, 1, 28, 28).transpose(0,1,3,2)
return data / np.float32(255)
X_train = load_mnist_images('train-images-idx3-ubyte.gz')
X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
X_train, X_val = X_train[:-10000], X_train[-10000:]
return X_train, X_val, X_test
# ############################# Output images ################################
# image processing using PIL
def get_image_array(X, index, shp=(28,28), channels=1):
ret = (X[index] * 255.).reshape(channels,shp[0],shp[1]) \
.transpose(2,1,0).clip(0,255).astype(np.uint8)
if channels == 1:
ret = ret.reshape(shp[1], shp[0])
return ret
def get_image_pair(X, Xpr, channels=1, idx=-1):
mode = 'RGB' if channels == 3 else 'L'
shp=X[0][0].shape
i = np.random.randint(X.shape[0]) if idx == -1 else idx
orig = Image.fromarray(get_image_array(X, i, shp, channels), mode=mode)
ret = Image.new(mode, (orig.size[0], orig.size[1]*2))
ret.paste(orig, (0,0))
new = Image.fromarray(get_image_array(Xpr, i, shp, channels), mode=mode)
ret.paste(new, (0, orig.size[1]))
return ret
# ############################# Batch iterator ###############################
def iterate_minibatches(inputs, batchsize, shuffle=False):
if shuffle:
indices = np.arange(len(inputs))
np.random.shuffle(indices)
for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
if shuffle:
excerpt = indices[start_idx:start_idx + batchsize]
else:
excerpt = slice(start_idx, start_idx + batchsize)
yield inputs[excerpt]
# ##################### Custom layer for middle of VCAE ######################
# This layer takes the mu and sigma (both DenseLayers) and combines them with
# a random vector epsilon to sample values for a multivariate Gaussian
class GaussianSampleLayer(nn.layers.MergeLayer):
def __init__(self, mu, logsigma, rng=None, **kwargs):
self.rng = rng if rng else RandomStreams(nn.random.get_rng().randint(1,2147462579))
super(GaussianSampleLayer, self).__init__([mu, logsigma], **kwargs)
def get_output_shape_for(self, input_shapes):
return input_shapes[0]
def get_output_for(self, inputs, withNoise=True, **kwargs):
mu, logsigma = inputs
shape=(self.input_shapes[0][0] or inputs[0].shape[0],
self.input_shapes[0][1] or inputs[0].shape[1])
if withNoise:
return mu + T.exp(logsigma) * self.rng.normal(shape)
return mu
# ############################## Build Model #################################
# encoder has 1 hidden layer, where we get mu and sigma for Z given an inp X
# continuous decoder has 1 hidden layer, where we get mu and sigma for X given code Z
# binary decoder has 1 hidden layer, where we calculate p(X=1)
# once we have (mu, sigma) for Z, we sample L times
# Then L separate outputs are constructed and the final layer averages them
def build_vae(inputvar, L=2, binary=True, imgshape=(28,28), channels=1, z_dim=2, n_hid=1024):
x_dim = imgshape[0] * imgshape[1] * channels
l_input = nn.layers.InputLayer(shape=(None,channels,imgshape[0], imgshape[1]),
input_var=inputvar, name='input')
l_enc_hid = nn.layers.DenseLayer(l_input, num_units=n_hid,
nonlinearity=nn.nonlinearities.tanh if binary else T.nnet.softplus,
name='enc_hid')
l_enc_mu = nn.layers.DenseLayer(l_enc_hid, num_units=z_dim,
nonlinearity = None, name='enc_mu')
l_enc_logsigma = nn.layers.DenseLayer(l_enc_hid, num_units=z_dim,
nonlinearity = None, name='enc_logsigma')
l_dec_mu_list = []
l_dec_logsigma_list = []
l_output_list = []
# tie the weights of all L versions so they are the "same" layer
W_dec_hid = None
b_dec_hid = None
W_dec_mu = None
b_dec_mu = None
W_dec_ls = None
b_dec_ls = None
for i in xrange(L):
l_Z = GaussianSampleLayer(l_enc_mu, l_enc_logsigma, name='Z')
l_dec_hid = nn.layers.DenseLayer(l_Z, num_units=n_hid,
nonlinearity = nn.nonlinearities.tanh if binary else T.nnet.softplus,
W=nn.init.GlorotUniform() if W_dec_hid is None else W_dec_hid,
b=nn.init.Constant(0.) if b_dec_hid is None else b_dec_hid,
name='dec_hid')
if binary:
l_output = nn.layers.DenseLayer(l_dec_hid, num_units = x_dim,
nonlinearity = nn.nonlinearities.sigmoid,
W = nn.init.GlorotUniform() if W_dec_mu is None else W_dec_mu,
b = nn.init.Constant(0.) if b_dec_mu is None else b_dec_mu,
name = 'dec_output')
l_output_list.append(l_output)
if W_dec_hid is None:
W_dec_hid = l_dec_hid.W
b_dec_hid = l_dec_hid.b
W_dec_mu = l_output.W
b_dec_mu = l_output.b
else:
l_dec_mu = nn.layers.DenseLayer(l_dec_hid, num_units=x_dim,
nonlinearity = None,
W = nn.init.GlorotUniform() if W_dec_mu is None else W_dec_mu,
b = nn.init.Constant(0) if b_dec_mu is None else b_dec_mu,
name = 'dec_mu')
# relu_shift is for numerical stability - if training data has any
# dimensions where stdev=0, allowing logsigma to approach -inf
# will cause the loss function to become NAN. So we set the limit
# stdev >= exp(-1 * relu_shift)
relu_shift = 10
l_dec_logsigma = nn.layers.DenseLayer(l_dec_hid, num_units=x_dim,
W = nn.init.GlorotUniform() if W_dec_ls is None else W_dec_ls,
b = nn.init.Constant(0) if b_dec_ls is None else b_dec_ls,
nonlinearity = lambda a: T.nnet.relu(a+relu_shift)-relu_shift,
name='dec_logsigma')
l_output = GaussianSampleLayer(l_dec_mu, l_dec_logsigma,
name='dec_output')
l_dec_mu_list.append(l_dec_mu)
l_dec_logsigma_list.append(l_dec_logsigma)
l_output_list.append(l_output)
if W_dec_hid is None:
W_dec_hid = l_dec_hid.W
b_dec_hid = l_dec_hid.b
W_dec_mu = l_dec_mu.W
b_dec_mu = l_dec_mu.b
W_dec_ls = l_dec_logsigma.W
b_dec_ls = l_dec_logsigma.b
l_output = nn.layers.ElemwiseSumLayer(l_output_list, coeffs=1./L, name='output')
return l_enc_mu, l_enc_logsigma, l_dec_mu_list, l_dec_logsigma_list, l_output_list, l_output
# ############################## Main program ################################
def log_likelihood(tgt, mu, ls):
return T.sum(-(np.float32(0.5 * np.log(2 * np.pi)) + ls)
- 0.5 * T.sqr(tgt - mu) / T.exp(2 * ls))
def main(L=2, z_dim=2, n_hid=1024, num_epochs=300, binary=True):
print("Loading data...")
X_train, X_val, X_test = load_dataset()
width, height = X_train.shape[2], X_train.shape[3]
input_var = T.tensor4('inputs')
# Create VAE model
print("Building model and compiling functions...")
print("L = {}, z_dim = {}, n_hid = {}, binary={}".format(L, z_dim, n_hid, binary))
x_dim = width * height
l_z_mu, l_z_ls, l_x_mu_list, l_x_ls_list, l_x_list, l_x = \
build_vae(input_var, L=L, binary=binary, z_dim=z_dim, n_hid=n_hid)
def build_loss(deterministic, withNoise=True):
layer_outputs = nn.layers.get_output([l_z_mu, l_z_ls] + l_x_mu_list + l_x_ls_list
+ l_x_list, deterministic=deterministic)
z_mu = layer_outputs[0]
z_ls = layer_outputs[1]
x_mu = [] if binary else layer_outputs[2:2+L]
x_ls = [] if binary else layer_outputs[2+L:2+2*L]
x_list = layer_outputs[2:2+L] if binary else layer_outputs[2+2*L:2+3*L]
# Loss expression has two parts as specified in [1]
# kl_div = KL divergence between p_theta(z) and p(z|x)
# - divergence between prior distr and approx posterior of z given x
# - or how likely we are to see this z when accounting for Gaussian prior
# logpxz = log p(x|z)
# - log-likelihood of x given z
# - in binary case logpxz = cross-entropy
# - in continuous case, is log-likelihood of seeing the target x under the
# Gaussian distribution parameterized by dec_mu, sigma = exp(dec_logsigma)
kl_div = 0.5 * T.sum(1 + 2*z_ls - T.sqr(z_mu) - T.exp(2 * z_ls))
if binary:
logpxz = sum(nn.objectives.binary_crossentropy(x,
input_var.flatten(2)).sum() for x in x_list) * (-1./L)
prediction = nn.layers.get_output(l_x, deterministic=deterministic,
withNoise=withNoise)
else:
logpxz = sum(log_likelihood(input_var.flatten(2), mu, ls)
for mu, ls in zip(x_mu, x_ls))/L
prediction = T.sum(nn.layers.get_output(l_x_mu_list, deterministic=deterministic,
withNoise=withNoise), axis=0)/L
loss = -1 * (logpxz + kl_div)
return loss, prediction
# If there are dropout layers etc these functions return masked or non-masked expressions
# depending on if they will be used for training or validation/test err calcs
loss, _ = build_loss(deterministic=False)
test_loss, test_prediction = build_loss(deterministic=True, withNoise=False)
# functions for generating images given a code (used for visualization)
# for an given code z, we deterministically take x_mu as the generated data
# (no Gaussian noise is used to either encode or decode).
z_var = T.vector()
if binary:
generated_x = nn.layers.get_output(l_x, {l_z_mu:z_var}, withNoise=False,
deterministic=True)
else:
generated_x = nn.layers.get_output(l_x_mu_list[0], {l_z_mu:z_var}, withNoise=False,
deterministic=True)
gen_fn = theano.function([z_var], generated_x)
# ADAM updates
params = nn.layers.get_all_params(l_x, trainable=True)
updates = nn.updates.adam(loss, params, learning_rate=1e-4)
train_fn = theano.function([input_var], loss, updates=updates)
val_fn = theano.function([input_var], test_loss)
print("Starting training...")
batch_size = 100
for epoch in range(num_epochs):
train_err = 0
train_batches = 0
start_time = time.time()
for batch in iterate_minibatches(X_train, batch_size, shuffle=True):
this_err = train_fn(batch)
train_err += this_err
train_batches += 1
val_err = 0
val_batches = 0
for batch in iterate_minibatches(X_val, batch_size, shuffle=False):
err = val_fn(batch)
val_err += err
val_batches += 1
print("Epoch {} of {} took {:.3f}s".format(
epoch + 1, num_epochs, time.time() - start_time))
print(" training loss:\t\t{:.6f}".format(train_err / train_batches))
print(" validation loss:\t\t{:.6f}".format(val_err / val_batches))
test_err = 0
test_batches = 0
for batch in iterate_minibatches(X_test, batch_size, shuffle=False):
err = val_fn(batch)
test_err += err
test_batches += 1
test_err /= test_batches
print("Final results:")
print(" test loss:\t\t\t{:.6f}".format(test_err))
# save some example pictures so we can see what it's done
example_batch_size = 20
X_comp = X_test[:example_batch_size]
pred_fn = theano.function([input_var], test_prediction)
X_pred = pred_fn(X_comp).reshape(-1, 1, width, height)
for i in range(20):
get_image_pair(X_comp, X_pred, idx=i, channels=1).save('output_{}.jpg'.format(i))
# save the parameters so they can be loaded for next time
print("Saving")
fn = 'params_{:.6f}'.format(test_err)
np.savez(fn + '.npz', *nn.layers.get_all_param_values(l_x))
# sample from latent space if it's 2d
if z_dim == 2:
im = Image.new('L', (width*19,height*19))
for (x,y),val in np.ndenumerate(np.zeros((19,19))):
z = np.asarray([norm.ppf(0.05*(x+1)), norm.ppf(0.05*(y+1))],
dtype=theano.config.floatX)
x_gen = gen_fn(z).reshape(-1, 1, width, height)
im.paste(Image.fromarray(get_image_array(x_gen,0)), (x*width,y*height))
im.save('gen.jpg')
if __name__ == '__main__':
# Arguments - integers, except for binary/continous. Default uses binary.
# Run with option --continuous for continuous output.
import argparse
parser = argparse.ArgumentParser(description='Command line options')
parser.add_argument('--num_epochs', type=int, dest='num_epochs')
parser.add_argument('--L', type=int, dest='L')
parser.add_argument('--z_dim', type=int, dest='z_dim')
parser.add_argument('--n_hid', type=int, dest='n_hid')
parser.add_argument('--binary', dest='binary', action='store_true')
parser.add_argument('--continuous', dest='binary', action='store_false')
parser.set_defaults(binary=True)
args = parser.parse_args(sys.argv[1:])
main(**{k:v for (k,v) in vars(args).items() if v is not None})
@tencia
Copy link
Author

tencia commented Dec 19, 2015

Hi, sorry I never got a notification of your comment. I believe a full run through the script (300 epochs) is more than enough to get an image similar to the one in the video (it is a different embedding every time). Looking at my saved params, I think I got it down to a loss of around 13,958.89 (combined log-likelihood). Hope that helps!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment