Variational autoencoder (sketch)
# Author: Michael Eickenberg
# License: BSD 3-clause
# This is a rapidly written proof of concept. There may remain big bugs. I find it overfits quite quickly at the moment.
import numpy as np
import theano
theano.config.floatX = 'float32'
import theano.tensor as T
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
def gaussian_vae_gaussian_loss(decoder_input, encoder_input,
                               mu_decode, log_sigma_decode,
                               mu_encode, log_sigma_encode,
                               num_samples_per_example=1, random_seed=42):
    """Implements the regularized loss function for the variational
    autoencoder. Assumes latent variables and output variables are
    Gaussian and assumes that the functions computing their parameters
    are stored as theano expressions in mu_encode, mu_decode,
    log_sigma_*.

    Per data example, a number of samples can be drawn; these are
    indexed by a second axis, i.e. (batch, sample, feature)."""
    n_batch = mu_decode.shape[0]
    n_latent = mu_decode.shape[1]
    n_output = decoder_input.shape[1]

    log_sigma_decode_squared = 2 * log_sigma_decode
    sigma_decode_squared = T.exp(log_sigma_decode_squared)
    sigma_decode = T.exp(log_sigma_decode)
    log_sigma_encode_squared = 2 * log_sigma_encode
    sigma_encode_squared = T.exp(log_sigma_encode_squared)

    # regularize posterior towards unit Gaussian prior:
    negDKL = (log_sigma_decode_squared + 1 -
              (sigma_decode_squared + mu_decode ** 2)).sum(axis=-1) / 2.

    # reconstruction error. First write the expression for the samples,
    # then take the mean over them
    rng = RandomStreams(random_seed)
    gaussian_samples = rng.normal(
        (n_batch, num_samples_per_example, n_latent))
    # reparameterization: posterior samples as mu + sigma * standard normal noise
    posterior_samples = mu_decode.dimshuffle((0, 'x', 1)) + (
        sigma_decode.dimshuffle((0, 'x', 1)) * gaussian_samples)
    posterior_samples_reshaped = posterior_samples.reshape(
        (n_batch * num_samples_per_example, n_latent))
    output_shape = (n_batch, num_samples_per_example, n_output)

    sigma_encode_squared_of_sample = theano.clone(
        sigma_encode_squared,
        replace={encoder_input: posterior_samples_reshaped}
    ).reshape(output_shape)
    log_sigma_encode_squared_of_sample = theano.clone(
        log_sigma_encode_squared,
        replace={encoder_input: posterior_samples_reshaped}
    ).reshape(output_shape)
    mu_encode_of_sample = theano.clone(
        mu_encode,
        replace={encoder_input: posterior_samples_reshaped}
    ).reshape(output_shape)

    # squared error scaled by the output variance (quadratic term of the
    # Gaussian log-density); the original draft divided by sigma^2 before squaring
    gaussian_reconstruction_loss = ((decoder_input.dimshuffle((0, 'x', 1)) -
                                     mu_encode_of_sample) ** 2 /
                                    sigma_encode_squared_of_sample)
    log_blur_penalty = np.log(2 * np.pi) + log_sigma_encode_squared_of_sample
    reconstruction_error_estimate = -.5 * (log_blur_penalty +
                                           gaussian_reconstruction_loss
                                           ).mean(axis=1)
    return negDKL + reconstruction_error_estimate.sum(axis=-1)
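
# Illustrative addition (not part of the original gist): a plain-numpy reference for the
# analytic KL term used above, -DKL(N(mu, sigma^2) || N(0, 1)) per example, to make the
# closed form explicit: 0.5 * sum(1 + log sigma^2 - sigma^2 - mu^2).
def _neg_dkl_reference(mu, log_sigma):
    """Numpy sketch of the negDKL expression computed symbolically above."""
    log_sigma_sq = 2 * np.asarray(log_sigma)
    return 0.5 * (1 + log_sigma_sq - np.exp(log_sigma_sq) -
                  np.asarray(mu) ** 2).sum(axis=-1)
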
def gaussian_vae_binary_loss(decoder_input, encoder_input,
                             mu_decode, log_sigma_decode,
                             logit_encode,
                             num_samples_per_example=2,
                             random_state=42):
    """Same as above, but with binary output variables whose logits are
    given by the theano expression logit_encode."""
    n_batch = mu_decode.shape[0]
    n_latent = mu_decode.shape[1]
    n_output = decoder_input.shape[1]

    log_sigma_decode_squared = 2 * log_sigma_decode
    sigma_decode_squared = T.exp(log_sigma_decode_squared)
    sigma_decode = T.exp(log_sigma_decode)

    # regularize posterior towards unit Gaussian prior:
    negDKL = (log_sigma_decode_squared + 1 -
              (sigma_decode_squared + mu_decode ** 2)).sum(axis=-1) / 2.

    # reconstruction error. First write the expression for the samples,
    # then take the mean over them
    rng = RandomStreams(random_state)
    gaussian_samples = rng.normal(
        (n_batch, num_samples_per_example, n_latent))
    posterior_samples = mu_decode.dimshuffle((0, 'x', 1)) + (
        sigma_decode.dimshuffle((0, 'x', 1)) * gaussian_samples)
    posterior_samples_reshaped = posterior_samples.reshape(
        (n_batch * num_samples_per_example, n_latent))
    output_shape = (n_batch, num_samples_per_example, n_output)

    logit_of_sample = theano.clone(
        logit_encode, replace={encoder_input: posterior_samples_reshaped}
    ).reshape(output_shape)

    # Bernoulli log-likelihood written with +/-1 targets:
    # log p(x | logit) = -log(1 + exp(-(2x - 1) * logit))
    symmetric_input = 2 * decoder_input - 1
    log_loss = -T.log(1 + T.exp(
        -symmetric_input.dimshuffle((0, 'x', 1)) * logit_of_sample)
    ).mean(axis=1)
    return negDKL + log_loss.sum(axis=-1)
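
# Illustrative addition (not part of the original gist): the Bernoulli log-likelihood
# used above, written in numpy. With s = 2*x - 1 and logits l,
# log p(x | l) = -log(1 + exp(-s * l)), which equals
# x * log(sigmoid(l)) + (1 - x) * log(1 - sigmoid(l)).
def _bernoulli_loglik_reference(x, logits):
    """Numpy sketch of the per-feature log-likelihood computed symbolically above."""
    s = 2 * np.asarray(x) - 1
    return -np.log1p(np.exp(-s * np.asarray(logits)))
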
class MLPerceptron(object):
    """Very simple class to flexibly create multilayer perceptrons."""

    def __init__(self,
                 shape=(784, 1000, 100),
                 activation=('ReLU', None)):
        self.shape = shape
        self.activation = activation

    def build(self, input_expression=None):
        if input_expression is None:
            input_expression = T.matrix(dtype=theano.config.floatX)
        self.input_expression = input_expression
        self.shared_variables = []
        self.weights = []
        self.biases = []

        network_expression = input_expression
        for i, (s1, s2, a) in enumerate(
                zip(self.shape[:-1], self.shape[1:], self.activation)):
            W = theano.shared(
                np.random.rand(
                    s1, s2).astype(theano.config.floatX) * .2 - .1,
                name='W%d' % i)
            self.weights.append(W)
            self.shared_variables.append(W)
            b = theano.shared(
                np.zeros(s2).astype(theano.config.floatX),
                name='b%d' % i)
            self.biases.append(b)
            self.shared_variables.append(b)

            network_expression = network_expression.dot(W) + b
            if a == 'ReLU':
                network_expression = T.maximum(network_expression, 0)
            elif a == 'sigmoid':
                network_expression = T.nnet.sigmoid(network_expression)
            elif a is None:
                pass
            else:
                raise ValueError('Activation function not understood')

        self.expression = network_expression
        return self
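
# Illustrative usage (added, not part of the original gist): build() chains one affine
# layer per consecutive pair in `shape`, applying the matching entry of `activation`,
# e.g.
#     mlp = MLPerceptron((784, 256, 10), ('ReLU', None)).build()
#     f = theano.function([mlp.input_expression], mlp.expression)
# compiles a 784 -> 256 -> 10 network with a ReLU after the first layer and a linear
# output; the learnable parameters are collected in mlp.shared_variables.
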
if __name__ == "__main__":
    n_latent_variables = 5
    n_input = 784

    # Note on naming: in this script the "decoder" perceptron maps data to the
    # latent parameters (mu, log sigma), and the "encoder" perceptron maps
    # latent samples back to data space.
    decoder_perceptron = MLPerceptron(
        (n_input, 500, n_latent_variables * 2),
        ('ReLU', None)).build()
    encoder_perceptron = MLPerceptron((n_latent_variables, 500, n_input),
                                      ('ReLU', None)).build()

    decoder_mu = decoder_perceptron.expression[:, 0:n_latent_variables]
    decoder_log_sigma = decoder_perceptron.expression[
        :, n_latent_variables:2 * n_latent_variables]

    lower_bound = gaussian_vae_binary_loss(
        decoder_perceptron.input_expression,
        encoder_perceptron.input_expression,
        decoder_mu,
        decoder_log_sigma,
        encoder_perceptron.expression,
        num_samples_per_example=1)

    shared_variables = (decoder_perceptron.shared_variables +
                        encoder_perceptron.shared_variables)
    lower_bound_grad = T.grad(-lower_bound.mean(), wrt=shared_variables)

    f_lb = theano.function([decoder_perceptron.input_expression],
                           lower_bound)
    f_lb_grad = theano.function([decoder_perceptron.input_expression],
                                lower_bound_grad)
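
    # Illustrative check (added, not part of the original gist): evaluate the bound on
    # a few random inputs to confirm the graph compiles and yields one value per example.
    _x_check = np.random.rand(3, n_input).astype(theano.config.floatX)
    assert f_lb(_x_check).shape == (3,)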
    # rmsprop is from Kyle Kastner's gist on optimizers; an __init__.py may need
    # to be added to make the package importable.
    from optimizers.optimizers import rmsprop
    optimizer = rmsprop(shared_variables)
    updates = optimizer.updates(shared_variables, lower_bound_grad,
                                .001, .9)

    mnist_pkl = "/home/me/data/mnist/mnist.pkl"
    import pickle
    ((train_data, train_y),
     (val_data, val_y),
     (test_data, test_y)) = pickle.load(open(mnist_pkl))
    tr_d, tr_y, v_d, v_y, te_d, te_y = map(theano.shared,
                                           [train_data, train_y,
                                            val_data, val_y,
                                            test_data, test_y])

    batch_size = 100
    batch_index = T.iscalar()
    givens = [(decoder_perceptron.input_expression,
               tr_d[batch_index * batch_size:
                    (batch_index + 1) * batch_size])]
    train_function = theano.function([batch_index], lower_bound.mean(),
                                     updates=updates,
                                     givens=givens)
    val_function = theano.function([], lower_bound.mean(),
                                   givens=[
                                       (decoder_perceptron.input_expression,
                                        v_d)])
    import sys
    n_epochs = 1000
    iteration_train_values = []
    all_epoch_train_values = []
    epoch_train_values = []
    epoch_val_values = []

    import time
    t0 = time.time()
    epoch_times = []

    for e in range(n_epochs):
        iteration_train_values = []
        all_epoch_train_values.append(iteration_train_values)
        for i in range(len(train_data) // batch_size):
            l = train_function(i)
            # sys.stdout.write('%1.2f ' % l)
            # sys.stdout.flush()
            iteration_train_values.append(l)
        v = val_function()
        epoch_val_values.append(v)
        epoch_train_values.append(np.mean(iteration_train_values))
        epoch_times.append(time.time() - t0)  # elapsed seconds since training started
        print "Done epoch %d" % e
        print epoch_train_values[-1], v
    # sample images from latent codes, and compile the full autoencoding path
    generate = theano.function(
        [encoder_perceptron.input_expression],
        1. / (1 + T.exp(-encoder_perceptron.expression)))
    autoencoder = theano.clone(
        encoder_perceptron.expression,
        replace={encoder_perceptron.input_expression:
                 decoder_mu})
    f_autoencoder = theano.function([decoder_perceptron.input_expression],
                                    1. / (1 + T.exp(-autoencoder)))
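
    # Illustrative addition (not part of the original gist): draw digits by sampling
    # latent codes from the unit Gaussian prior and pushing them through the generative
    # network; f_autoencoder reconstructs held-out examples (assumes the MNIST arrays
    # loaded above are float32).
    prior_samples = np.random.randn(
        16, n_latent_variables).astype(theano.config.floatX)
    generated_digits = generate(prior_samples)  # shape (16, 784), values in (0, 1)
    reconstructions = f_autoencoder(v_d.get_value()[:16])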