Variational autoencoder (sketch)
# Author: Michael Eickenberg
# License: BSD 3-clause

# This is a rapidly written proof of concept; big bugs may remain.
# I find it overfits quite quickly at the moment.

import numpy as np
import theano
theano.config.floatX = 'float32'
import theano.tensor as T
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
def gaussian_vae_gaussian_loss(decoder_input, encoder_input,
                               mu_decode, log_sigma_decode,
                               mu_encode, log_sigma_encode,
                               num_samples_per_example=1, random_seed=42):
    """Regularized loss (variational lower bound) for a variational
    autoencoder with Gaussian latent variables and Gaussian output
    variables. The functions computing their parameters are assumed to be
    stored as theano expressions in mu_decode, mu_encode and log_sigma_*.

    Naming note: in this file the *_decode expressions map data to
    latent-variable parameters (the recognition model q(z|x)), and the
    *_encode expressions map latent samples back to data space (the
    generative model p(x|z)), i.e. the reverse of the usual
    encoder/decoder convention.

    Per data example, several samples can be drawn; they are indexed by a
    second axis, i.e. (batch, sample, feature)."""
    n_batch = mu_decode.shape[0]
    n_latent = mu_decode.shape[1]
    n_output = decoder_input.shape[1]

    log_sigma_decode_squared = 2 * log_sigma_decode
    sigma_decode_squared = T.exp(log_sigma_decode_squared)
    sigma_decode = T.exp(log_sigma_decode)
    log_sigma_encode_squared = 2 * log_sigma_encode
    sigma_encode_squared = T.exp(log_sigma_encode_squared)

    # regularize posterior towards unit Gaussian prior:
    negDKL = (log_sigma_decode_squared + 1 -
              (sigma_decode_squared + mu_decode ** 2)).sum(axis=-1) / 2.
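    # Closed form used above: for q(z|x) = N(mu, sigma^2) and prior N(0, 1),
    # KL(q || p) = .5 * sum(sigma^2 + mu^2 - 1 - log sigma^2) per example,
    # so negDKL is its negation, summed over the latent dimensions.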
    # reconstruction error. First write the expression for the samples
    # (reparametrization trick), then take the mean over them.
    rng = RandomStreams(random_seed)
    gaussian_samples = rng.normal(
        (n_batch, num_samples_per_example, n_latent))
    posterior_samples = mu_decode.dimshuffle((0, 'x', 1)) + (
        sigma_decode.dimshuffle((0, 'x', 1)) * gaussian_samples)
    posterior_samples_reshaped = posterior_samples.reshape(
        (n_batch * num_samples_per_example, n_latent))

    output_shape = (n_batch, num_samples_per_example, n_output)
    sigma_encode_squared_of_sample = theano.clone(
        sigma_encode_squared,
        replace={encoder_input: posterior_samples_reshaped}
    ).reshape(output_shape)
    log_sigma_encode_squared_of_sample = theano.clone(
        log_sigma_encode_squared,
        replace={encoder_input: posterior_samples_reshaped}
    ).reshape(output_shape)
    mu_encode_of_sample = theano.clone(
        mu_encode,
        replace={encoder_input: posterior_samples_reshaped}
    ).reshape(output_shape)
    # Gaussian log likelihood of the data under p(x|z):
    # log p(x|z) = -.5 * (log(2 * pi * sigma^2) + (x - mu)^2 / sigma^2) per feature
    gaussian_reconstruction_loss = ((decoder_input.dimshuffle((0, 'x', 1)) -
                                     mu_encode_of_sample) ** 2 /
                                    sigma_encode_squared_of_sample)
    log_blur_penalty = np.log(2 * np.pi) + log_sigma_encode_squared_of_sample
    log_likelihood_estimate = -.5 * (log_blur_penalty +
                                     gaussian_reconstruction_loss
                                     ).mean(axis=1)
    return negDKL + log_likelihood_estimate.sum(axis=-1)
def gaussian_vae_binary_loss(decoder_input, encoder_input,
                             mu_decode, log_sigma_decode,
                             logit_encode,
                             num_samples_per_example=2,
                             random_state=42):
    """Variational lower bound for Gaussian latent variables and binary
    output variables, where logit_encode is a theano expression for the
    logits of the Bernoulli output distribution given encoder_input."""
    n_batch = mu_decode.shape[0]
    n_latent = mu_decode.shape[1]
    n_output = decoder_input.shape[1]

    log_sigma_decode_squared = 2 * log_sigma_decode
    sigma_decode_squared = T.exp(log_sigma_decode_squared)
    sigma_decode = T.exp(log_sigma_decode)

    # regularize posterior towards unit Gaussian prior:
    negDKL = (log_sigma_decode_squared + 1 -
              (sigma_decode_squared + mu_decode ** 2)).sum(axis=-1) / 2.

    # reconstruction error. First write the expression for the samples,
    # then take the mean over them.
    rng = RandomStreams(random_state)
    gaussian_samples = rng.normal(
        (n_batch, num_samples_per_example, n_latent))
    posterior_samples = mu_decode.dimshuffle((0, 'x', 1)) + (
        sigma_decode.dimshuffle((0, 'x', 1)) * gaussian_samples)
    posterior_samples_reshaped = posterior_samples.reshape(
        (n_batch * num_samples_per_example, n_latent))

    output_shape = (n_batch, num_samples_per_example, n_output)
    logit_of_sample = theano.clone(
        logit_encode, replace={encoder_input: posterior_samples_reshaped}
    ).reshape(output_shape)
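    # With s = 2 * x - 1 in {-1, 1}, the Bernoulli log likelihood of x given
    # the logits is log sigmoid(s * logit) = -log(1 + exp(-s * logit)),
    # which is what the expression below computes (averaged over samples).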
    symmetric_input = 2 * decoder_input - 1
    log_likelihood_estimate = -T.log(1 + T.exp(
        -symmetric_input.dimshuffle((0, 'x', 1)) * logit_of_sample)
    ).mean(axis=1)
    return negDKL + log_likelihood_estimate.sum(axis=-1)
class MLPerceptron(object):
    """Very simple class to flexibly create multi-layer perceptrons.

    shape gives the layer sizes; activation has one entry per weight layer
    ('ReLU', 'sigmoid' or None for a linear layer)."""

    def __init__(self,
                 shape=(784, 1000, 100),
                 activation=('ReLU', None)):
        self.shape = shape
        self.activation = activation

    def build(self, input_expression=None):
        if input_expression is None:
            input_expression = T.matrix(dtype=theano.config.floatX)
        self.input_expression = input_expression
        self.shared_variables = []
        self.weights = []
        self.biases = []
        network_expression = input_expression
        for i, (s1, s2, a) in enumerate(
                zip(self.shape[:-1], self.shape[1:], self.activation)):
            # weights initialized uniformly in [-.1, .1), biases at zero
            W = theano.shared(
                np.random.rand(
                    s1, s2).astype(theano.config.floatX) * .2 - .1,
                name='W%d' % i)
            self.weights.append(W)
            self.shared_variables.append(W)
            b = theano.shared(
                np.zeros(s2).astype(theano.config.floatX),
                name='b%d' % i)
            self.biases.append(b)
            self.shared_variables.append(b)
            network_expression = network_expression.dot(W) + b
            if a == 'ReLU':
                network_expression = T.maximum(network_expression, 0)
            elif a == 'sigmoid':
                network_expression = T.nnet.sigmoid(network_expression)
            elif a is None:
                pass
            else:
                raise ValueError('Activation function not understood: %r' % a)
        self.expression = network_expression
        return self
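# Example usage (sketch): a two-layer perceptron mapping 784 inputs to 10
# linear outputs, with the symbolic graph available after build():
#
#     mlp = MLPerceptron(shape=(784, 500, 10), activation=('ReLU', None)).build()
#     f = theano.function([mlp.input_expression], mlp.expression)
#     y = f(np.zeros((2, 784), dtype=theano.config.floatX))  # shape (2, 10)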
if __name__ == "__main__":
    n_latent_variables = 5
    n_input = 784

    # recognition network: data -> (mu, log_sigma) of the latent posterior
    decoder_perceptron = MLPerceptron(
        (n_input, 500, n_latent_variables * 2),
        ('ReLU', None)).build()
    # generative network: latent samples -> logits of the data distribution
    encoder_perceptron = MLPerceptron((n_latent_variables, 500, n_input),
                                      ('ReLU', None)).build()
    decoder_mu = decoder_perceptron.expression[:, 0:n_latent_variables]
    decoder_log_sigma = decoder_perceptron.expression[
        :, n_latent_variables:2 * n_latent_variables]

    lower_bound = gaussian_vae_binary_loss(
        decoder_perceptron.input_expression,
        encoder_perceptron.input_expression,
        decoder_mu,
        decoder_log_sigma,
        encoder_perceptron.expression,
        num_samples_per_example=1)
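    # The lower bound is to be maximized, so gradients are taken of its
    # negative mean and handed to the optimizer below.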
    shared_variables = (decoder_perceptron.shared_variables +
                        encoder_perceptron.shared_variables)
    lower_bound_grad = T.grad(-lower_bound.mean(), wrt=shared_variables)

    f_lb = theano.function([decoder_perceptron.input_expression],
                           lower_bound)
    f_lb_grad = theano.function([decoder_perceptron.input_expression],
                                lower_bound_grad)

    # this is from Kyle Kastner's gist on optimizers; may need an __init__.py
    from optimizers.optimizers import rmsprop
    optimizer = rmsprop(shared_variables)
    updates = optimizer.updates(shared_variables, lower_bound_grad,
                                .001, .9)
    mnist_pkl = "/home/me/data/mnist/mnist.pkl"
    import pickle
    ((train_data, train_y),
     (val_data, val_y),
     (test_data, test_y)) = pickle.load(open(mnist_pkl, 'rb'))
    tr_d, tr_y, v_d, v_y, te_d, te_y = map(theano.shared,
                                           [train_data, train_y,
                                            val_data, val_y,
                                            test_data, test_y])
    batch_size = 100
    batch_index = T.iscalar()
    givens = [(decoder_perceptron.input_expression,
               tr_d[batch_index * batch_size:
                    (batch_index + 1) * batch_size])]
    train_function = theano.function([batch_index], lower_bound.mean(),
                                     updates=updates,
                                     givens=givens)
    val_function = theano.function([], lower_bound.mean(),
                                   givens=[
                                       (decoder_perceptron.input_expression,
                                        v_d)])
    import sys
    import time

    n_epochs = 1000
    iteration_train_values = []
    all_epoch_train_values = []
    epoch_train_values = []
    epoch_val_values = []
    t0 = time.time()
    epoch_times = []

    for e in range(n_epochs):
        iteration_train_values = []
        all_epoch_train_values.append(iteration_train_values)
        for i in range(len(train_data) // batch_size):
            l = train_function(i)
            # sys.stdout.write('%1.2f ' % l)
            # sys.stdout.flush()
            iteration_train_values.append(l)
        v = val_function()
        epoch_val_values.append(v)
        epoch_train_values.append(np.mean(iteration_train_values))
        epoch_times.append(time.time())
        print "Done epoch %d" % e
        print epoch_train_values[-1], v
    # map latent codes to Bernoulli means (sigmoid of the generative logits)
    generate = theano.function(
        [encoder_perceptron.input_expression],
        T.nnet.sigmoid(encoder_perceptron.expression))

    # full autoencoder: data -> posterior mean -> reconstructed Bernoulli means
    autoencoder = theano.clone(
        encoder_perceptron.expression,
        replace={encoder_perceptron.input_expression:
                 decoder_mu})
    f_autoencoder = theano.function([decoder_perceptron.input_expression],
                                    T.nnet.sigmoid(autoencoder))
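    # A minimal way to inspect results (sketch): sample latent codes from the
    # unit Gaussian prior and decode them, and reconstruct a few training
    # digits through the full autoencoder.
    prior_samples = np.random.randn(
        16, n_latent_variables).astype(theano.config.floatX)
    generated_images = generate(prior_samples)         # shape (16, 784)
    reconstructions = f_autoencoder(
        train_data[:16].astype(theano.config.floatX))  # shape (16, 784)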