Last active: April 1, 2019 20:50
-
-
Save yoel-zeldes/57310216750a81500cb873588e036420 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from tensorflow.examples.tutorials.mnist import input_data | |
import matplotlib.pyplot as plt | |
# Flattened MNIST image size (28x28 pixels) and the number of digit classes,
# which is also the number of experts trained below.
INPUT_SIZE = 28 ** 2
NUM_DIGITS = 10

# Seed both random number generators for reproducible runs.
# NOTE(review): keep the seeding ahead of read_data_sets -- the MNIST helper
# may derive NumPy's seed from TF's graph-level seed when it builds its
# datasets (confirm for the installed TF version).
tf.set_random_seed(42)
np.random.seed(42)
mnist = input_data.read_data_sets('MNIST_data')
# Hyper-parameters shared by the manager and every expert VAE.
params = dict(
    # Hidden-layer widths of the three sub-networks. All of them are plain
    # feed-forward networks (a CNN decoder would do better, but an MLP keeps
    # the example concise).
    manager_layers=[128],
    encoder_layers=[128],
    decoder_layers=[128],
    # Non-linearity applied after every hidden layer of every sub-network.
    activation=tf.nn.sigmoid,
    # Fixed standard deviation of the Gaussian likelihood P(x|z).
    decoder_std=0.5,
    # Dimensionality of each expert's latent space.
    z_dim=10,
    # Weight of the expert-balancing term in the manager's total loss.
    balancing_weight=0.1,
    # Optimisation settings.
    epochs=100,
    batch_size=100,
    learning_rate=0.001,
)
class VAE(object):
    """One expert: a fully-connected variational autoencoder over flat images.

    Constructing an instance adds the encoder, a reparameterised latent
    sample, the decoder and the per-example loss for ``images`` to the
    default TF graph. Every instance gets its own variable scopes
    (``encode_<id>`` / ``decode_<id>``), so experts never share weights.

    Attributes:
        decoded_images: sigmoid decoder output for ``images`` (the mean of
            P(x|z)), shape ``[batch, INPUT_SIZE]``.
        loss: per-example reconstruction negative log-likelihood plus the
            KL divergence to a standard-normal prior, shape ``[batch]``.
    """

    # Class-wide counter handing out a unique id (and so unique variable
    # scope names) to every instance.
    _ID = 0

    def __init__(self, params, images):
        # Claim the next id so this expert's scopes don't collide with others.
        self._id = VAE._ID
        VAE._ID += 1
        self._params = params

        mu, var = self.encode(images)
        # Reparameterisation trick: z = mu + sqrt(var) * eps, eps ~ N(0, I).
        # This keeps the sample differentiable w.r.t. mu and var.
        noise = tf.random_normal(
            shape=[tf.shape(images)[0], self._params['z_dim']],
            mean=0.0,
            stddev=1.0,
        )
        latent = mu + tf.sqrt(var) * noise
        self.decoded_images = self.decode(latent)
        self.loss = self._calculate_loss(images, self.decoded_images, mu, var)

    def encode(self, images):
        """Map ``images`` to the parameters of the approximate posterior q(z|x).

        Returns:
            ``(mu, var)``: two ``[batch, z_dim]`` tensors. ``var`` is made
            strictly positive by exponentiating a dense output and adding 1e-5.
        """
        with tf.variable_scope('encode_%d' % self._id, reuse=tf.AUTO_REUSE):
            hidden = images
            for width in self._params['encoder_layers']:
                hidden = tf.layers.dense(hidden, width,
                                         activation=self._params['activation'])
            mu = tf.layers.dense(hidden, self._params['z_dim'])
            # exp() guarantees positivity. The small floor keeps the later
            # tf.sqrt / tf.log calls on the variance away from zero.
            raw_log_var = tf.layers.dense(hidden, self._params['z_dim'])
            var = 1e-5 + tf.exp(raw_log_var)
        return mu, var

    def decode(self, z):
        """Map latent codes ``z`` to per-pixel means in (0, 1).

        The variable scope is opened with ``reuse=tf.AUTO_REUSE``, so later
        calls reuse this expert's decoder weights. That allows decoding
        freshly drawn latents.
        """
        with tf.variable_scope('decode_%d' % self._id, reuse=tf.AUTO_REUSE):
            hidden = z
            for width in self._params['decoder_layers']:
                hidden = tf.layers.dense(hidden, width,
                                         activation=self._params['activation'])
            pixel_logits = tf.layers.dense(hidden, INPUT_SIZE)
            return tf.nn.sigmoid(pixel_logits)

    def _calculate_loss(self, images, decoded_images, encoder_mu, encoder_var):
        """Return the per-example loss as a ``[batch]`` tensor.

        loss = -sum_pixels log N(images; decoded_images, decoder_std)
               + KL(N(encoder_mu, encoder_var) || N(0, I))
        """
        # Gaussian likelihood with a fixed std, summed over pixels.
        likelihood = tf.contrib.distributions.Normal(
            decoded_images,
            self._params['decoder_std'],
        )
        loss_reconstruction = -tf.reduce_sum(likelihood.log_prob(images),
                                             axis=1)
        # Closed-form KL divergence to the standard-normal prior, summed over
        # latent dimensions.
        loss_prior = -0.5 * tf.reduce_sum(
            1 + tf.log(encoder_var) - encoder_mu ** 2 - encoder_var,
            axis=1,
        )
        return loss_reconstruction + loss_prior
class Manager(object):
    """Gating network that softly assigns every image to one of the experts.

    Attributes:
        expected_expert_loss: batch mean of the experts' losses weighted by
            the manager's probabilities (scalar).
        balancing_loss: variance, across experts, of the total probability
            mass each expert receives in the batch (scalar).
        loss: ``expected_expert_loss + balancing_weight * balancing_loss``.
    """

    def __init__(self, params, experts, images):
        self._params = params
        self._experts = experts
        probs = self.calc_probs(images)
        (self.expected_expert_loss,
         self.balancing_loss,
         self.loss) = self._calculate_loss(probs)

    def calc_probs(self, images):
        """Return a ``[batch, len(experts)]`` softmax over experts per image.

        NOTE(review): the scope name 'prob' is fixed and opened with
        AUTO_REUSE, so a second Manager in the same graph would share these
        weights.
        """
        with tf.variable_scope('prob', reuse=tf.AUTO_REUSE):
            hidden = images
            for width in self._params['manager_layers']:
                hidden = tf.layers.dense(hidden, width,
                                         activation=self._params['activation'])
            expert_logits = tf.layers.dense(hidden, len(self._experts))
            return tf.nn.softmax(expert_logits)

    def _calculate_loss(self, probs):
        """Build the manager's loss terms from the gating ``probs``.

        Returns:
            ``(expected_expert_loss, balancing_loss, loss)``.
        """
        # Column j holds expert j's per-example loss: [batch, n_experts].
        per_expert_columns = [tf.reshape(expert.loss, [-1, 1])
                              for expert in self._experts]
        losses = tf.concat(per_expert_columns, axis=1)
        # Mixture objective: probability-weighted sum over experts, then the
        # mean over the batch.
        expected_expert_loss = tf.reduce_mean(
            tf.reduce_sum(losses * probs, axis=1), axis=0)
        # An expert's "importance" is the probability mass it gets in the
        # batch. Penalising the variance of importance pushes the manager to
        # spread examples across experts.
        # NOTE(review): importance is a sum over the batch, so this term's
        # scale grows with the number of images fed.
        importance = tf.reduce_sum(probs, axis=0)
        _, balancing_loss = tf.nn.moments(importance, axes=[0])
        loss = (expected_expert_loss
                + self._params['balancing_weight'] * balancing_loss)
        return expected_expert_loss, balancing_loss, loss
# ---- Graph construction ---------------------------------------------------
# Flattened input images. The batch dimension is left open so the same graph
# serves both mini-batch training and full-training-set evaluation.
images = tf.placeholder(tf.float32, [None, INPUT_SIZE])
# One VAE expert per digit class, all reading the same images.
experts = [VAE(params, images) for _ in range(NUM_DIGITS)]
manager = Manager(params, experts, images)
train_op = tf.train.AdamOptimizer(params['learning_rate']).minimize(manager.loss)

# Sampling ops are built once, outside the training loop. Calling
# expert.decode(...) inside the loop would add a fresh set of decoder ops to
# the graph on every epoch, growing it without bound.
sample_z_input = tf.placeholder(tf.float32, [None, params['z_dim']])
sample_images = [expert.decode(sample_z_input) for expert in experts]

samples = []                 # per epoch: list of one decoded sample per expert
expected_expert_losses = []
balancing_losses = []
losses = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Floor division: with ``/`` Python 3 yields a float and range() raises
    # TypeError (``//`` also behaves identically under Python 2).
    batches_per_epoch = mnist.train.num_examples // params['batch_size']
    for epoch in range(params['epochs']):
        # Train over the batches. Labels are not used (unsupervised).
        for _ in range(batches_per_epoch):
            batch_images = mnist.train.next_batch(params['batch_size'])[0]
            sess.run(train_op, feed_dict={images: batch_images})
        # Keep track of the losses, evaluated on the whole training set.
        expected_expert_loss, balancing_loss, loss = sess.run(
            [manager.expected_expert_loss, manager.balancing_loss, manager.loss],
            {images: mnist.train.images},
        )
        expected_expert_losses.append(expected_expert_loss)
        balancing_losses.append(balancing_loss)
        losses.append(loss)
        # Decode one shared random latent with every expert so the samples
        # can be inspected later on.
        sample_z = np.random.randn(1, params['z_dim'])
        gen_samples = sess.run(sample_images, {sample_z_input: sample_z})
        samples.append(gen_samples)

# Plot the three loss curves side by side.
for position, (series, title) in enumerate(
        [(expected_expert_losses, 'expected expert loss'),
         (balancing_losses, 'balancing loss'),
         (losses, 'total loss')],
        start=1):
    plt.subplot(1, 3, position)
    plt.plot(series)
    plt.title(title, y=1.07)
plt.tight_layout()
# Needed to display the figure when run as a plain script; harmless inline.
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment