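# Mixture of experts of VAEs on MNIST: NUM_DIGITS VAE "experts" are trained jointly with a
# manager network that learns to route each image to the expert that models it best.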
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt

np.random.seed(42)
tf.set_random_seed(42)

mnist = input_data.read_data_sets('MNIST_data')

INPUT_SIZE = 28 * 28
NUM_DIGITS = 10

params = {
    'manager_layers': [128],      # the manager will be implemented using a simple feed forward network
    'encoder_layers': [128],      # ... and so will the encoder
    'decoder_layers': [128],      # ... and the decoder as well (a CNN would do better, but let's keep it concise)
    'activation': tf.nn.sigmoid,  # the activation function used by all subnetworks
    'decoder_std': 0.5,           # the standard deviation of P(x|z) discussed in the first post of the series
    'z_dim': 10,                  # the dimension of the latent space
    'balancing_weight': 0.1,      # how much the balancing term contributes to the loss
    'epochs': 100,
    'batch_size': 100,
    'learning_rate': 0.001
}
class VAE(object):
    _ID = 0

    def __init__(self, params, images):
        self._id = VAE._ID
        VAE._ID += 1
        self._params = params

        encoder_mu, encoder_var = self.encode(images)

        # sample z using the reparameterization trick
        eps = tf.random_normal(shape=[tf.shape(images)[0], self._params['z_dim']],
                               mean=0.0,
                               stddev=1.0)
        z = encoder_mu + tf.sqrt(encoder_var) * eps

        self.decoded_images = self.decode(z)
        self.loss = self._calculate_loss(images,
                                         self.decoded_images,
                                         encoder_mu,
                                         encoder_var)

    def encode(self, images):
        with tf.variable_scope('encode_{}'.format(self._id), reuse=tf.AUTO_REUSE):
            x = images
            for layer in self._params['encoder_layers']:
                x = tf.layers.dense(x,
                                    layer,
                                    activation=self._params['activation'])
            mu = tf.layers.dense(x, self._params['z_dim'])
            var = 1e-5 + tf.exp(tf.layers.dense(x, self._params['z_dim']))
        return mu, var

    def decode(self, z):
        with tf.variable_scope('decode_{}'.format(self._id), reuse=tf.AUTO_REUSE):
            for layer in self._params['decoder_layers']:
                z = tf.layers.dense(z,
                                    layer,
                                    activation=self._params['activation'])
            mu = tf.layers.dense(z, INPUT_SIZE)
        return tf.nn.sigmoid(mu)

    def _calculate_loss(self, images, decoded_images, encoder_mu, encoder_var):
        # reconstruction term: -log P(x|z) under a Gaussian decoder
        loss_reconstruction = -tf.reduce_sum(
            tf.contrib.distributions.Normal(
                decoded_images,
                self._params['decoder_std']
            ).log_prob(images),
            axis=1
        )

        # KL divergence between the approximate posterior and the standard normal prior
        loss_prior = -0.5 * tf.reduce_sum(
            1 + tf.log(encoder_var) - encoder_mu ** 2 - encoder_var,
            axis=1
        )

        return loss_reconstruction + loss_prior
class Manager(object):
    def __init__(self, params, experts, images):
        self._params = params
        self._experts = experts
        probs = self.calc_probs(images)
        self.expected_expert_loss, self.balancing_loss, self.loss = self._calculate_loss(probs)

    def calc_probs(self, images):
        with tf.variable_scope('prob', reuse=tf.AUTO_REUSE):
            x = images
            for layer in self._params['manager_layers']:
                x = tf.layers.dense(x,
                                    layer,
                                    activation=self._params['activation'])
            logits = tf.layers.dense(x, len(self._experts))
            probs = tf.nn.softmax(logits)
        return probs

    def _calculate_loss(self, probs):
        losses = tf.concat([tf.reshape(expert.loss, [-1, 1])
                            for expert in self._experts], axis=1)

        # each expert's loss, weighted by the probability the manager assigns to that expert
        expected_expert_loss = tf.reduce_mean(tf.reduce_sum(losses * probs, axis=1), axis=0)

        # the balancing term penalizes an uneven assignment of examples to experts
        experts_importance = tf.reduce_sum(probs, axis=0)
        _, experts_importance_var = tf.nn.moments(experts_importance, axes=[0])
        balancing_loss = experts_importance_var

        loss = expected_expert_loss + self._params['balancing_weight'] * balancing_loss
        return expected_expert_loss, balancing_loss, loss
images = tf.placeholder(tf.float32, [None, INPUT_SIZE])
experts = [VAE(params, images) for _ in range(NUM_DIGITS)]
manager = Manager(params, experts, images)

train_op = tf.train.AdamOptimizer(params['learning_rate']).minimize(manager.loss)

samples = []
expected_expert_losses = []
balancing_losses = []
losses = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(params['epochs']):
        # train over the batches
        for _ in range(mnist.train.num_examples // params['batch_size']):
            batch_images, batch_digits = mnist.train.next_batch(params['batch_size'])
            sess.run(train_op, feed_dict={images: batch_images})

        # keep track of the loss
        expected_expert_loss, balancing_loss, loss = sess.run(
            [manager.expected_expert_loss, manager.balancing_loss, manager.loss],
            {images: mnist.train.images}
        )
        expected_expert_losses.append(expected_expert_loss)
        balancing_losses.append(balancing_loss)
        losses.append(loss)

        # generate random samples so we can have a look later on
        # (note: calling decode here adds new ops to the graph on every epoch)
        sample_z = np.random.randn(1, params['z_dim'])
        gen_samples = sess.run([expert.decode(tf.constant(sample_z, dtype='float32'))
                                for expert in experts])
        samples.append(gen_samples)
plt.subplot(131)
plt.plot(expected_expert_losses)
plt.title('expected expert loss', y=1.07)
plt.subplot(132)
plt.plot(balancing_losses)
plt.title('balancing loss', y=1.07)
plt.subplot(133)
plt.plot(losses)
plt.title('total loss', y=1.07)
plt.tight_layout()
plt.show()
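# --- not part of the original gist: a minimal sketch for inspecting the collected samples ---
# samples[-1] holds one generated image per expert from the last epoch, each of shape
# (1, INPUT_SIZE); plotting them side by side shows what each expert has specialized in.
plt.figure(figsize=(2 * NUM_DIGITS, 2))
for i, expert_sample in enumerate(samples[-1]):
    plt.subplot(1, NUM_DIGITS, i + 1)
    plt.imshow(expert_sample.reshape(28, 28), cmap='gray')
    plt.axis('off')
    plt.title('expert {}'.format(i))
plt.tight_layout()
plt.show()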