TensorFlow Probability MNIST VAE implementation using tf_utils (https://github.com/piojanu/tf_utils/blob/master/tf_utils/utils.py)
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tf_utils import AttrDict, lazy_property_with_scope

tfd = tfp.distributions
tfl = tf.layers
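
# NOTE: this script targets TF 1.x (tf.layers, tf.train.MonitoredSession,
# tf.data.Iterator.from_structure) plus a matching tensorflow_probability
# release; these APIs were removed or moved under tf.compat.v1 in TF 2.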

class Model(object):
    def __init__(self, data, config):
        self.data = data
        self.data_shape = list(self.data.shape[1:])
        self.config = config
        # Touch each lazy property once so the whole graph is built here, at
        # construction time.
        self.prior
        self.posterior
        self.code
        self.likelihood
        self.samples
        self.loss
        self.optimise

    @lazy_property_with_scope
    def prior(self):
        """Standard normal distribution prior."""
        return tfd.MultivariateNormalDiag(
            loc=tf.zeros(self.config.code_size),
            scale_diag=tf.ones(self.config.code_size))

    @lazy_property_with_scope(scope_name="encoder")
    def posterior(self):
        """a.k.a. the encoder."""
        x = tfl.Flatten()(self.data)
        x = tfl.Dense(self.config.hidden_size, activation='relu')(x)
        x = tfl.Dense(self.config.hidden_size, activation='relu')(x)
        loc = tfl.Dense(self.config.code_size)(x)
        scale = tfl.Dense(self.config.code_size, activation='softplus')(x)
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
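
    # Note: posterior.sample() is fully reparameterised in TFP, so gradients
    # flow from the loss through the sampled code back into the encoder (the
    # VAE reparameterisation trick).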
    @lazy_property_with_scope
    def code(self):
        """Sample code from the posterior."""
        return self.posterior.sample()
@lazy_property_with_scope(scope_name="decoder", reuse=tf.AUTO_REUSE)
def likelihood(self):
"""a.k.a the decoder."""
return self._make_decoder(self.code)
@lazy_property_with_scope(scope_name="decoder", reuse=tf.AUTO_REUSE)
def samples(self):
"""Generate examples."""
return self._make_decoder(self.prior.sample(self.config.n_samples)).mean()

    @lazy_property_with_scope
    def loss(self):
        """Negative evidence lower bound, averaged over the batch."""
        # ELBO = E_q(z|x)[log p(x|z)] - KL(q(z|x) || p(z)); the reconstruction
        # term uses a single posterior sample, and log_prob already sums over
        # pixels thanks to tfd.Independent.
        elbo = self.likelihood.log_prob(self.data) - tfd.kl_divergence(self.posterior, self.prior)
        return -tf.reduce_mean(elbo)

    @lazy_property_with_scope
    def optimise(self):
        """Adam optimiser for the loss (negative ELBO)."""
        return tf.train.AdamOptimizer(self.config.learning_rate).minimize(self.loss)

    def _make_decoder(self, code):
        x = tfl.Dense(self.config.hidden_size, activation='relu')(code)
        x = tfl.Dense(self.config.hidden_size, activation='relu')(x)
        logits = tfl.Dense(np.prod(self.data_shape))(x)
        logits = tf.reshape(logits, [-1] + self.data_shape)
        return tfd.Independent(tfd.Bernoulli(logits), 2)
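

# Shape note: tfd.Independent(tfd.Bernoulli(logits), 2) reinterprets the two
# trailing image dimensions (28x28) as event dimensions, so
# likelihood.log_prob(images) yields one scalar per example instead of one
# per pixel.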


def plot_codes(ax, codes, labels):
    ax.scatter(codes[:, 0], codes[:, 1], s=2, c=labels, alpha=0.1)
    ax.set_aspect('equal')
    ax.set_xlim(codes.min() - .1, codes.max() + .1)
    ax.set_ylim(codes.min() - .1, codes.max() + .1)
    ax.tick_params(
        axis='both', which='both', left=False, bottom=False,
        labelleft=False, labelbottom=False)


def plot_samples(ax, samples):
    for index, sample in enumerate(samples):
        ax[index].imshow(sample, cmap='gray')
        ax[index].axis('off')
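
# Each figure row is one epoch: the first column shows the 2-D test-set codes
# coloured by digit label, the remaining columns show digits decoded from
# prior samples.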


def create_datasets(train_set, test_set, config):
    train_dataset = tf.data.Dataset.from_tensor_slices(
        tf.convert_to_tensor(train_set, dtype=tf.float32)) \
        .map(lambda x: x / 255) \
        .shuffle(train_set.shape[0]) \
        .batch(config.batch_size)
    # The whole test set is evaluated as a single batch.
    test_dataset = tf.data.Dataset.from_tensor_slices(
        tf.convert_to_tensor(test_set, dtype=tf.float32)) \
        .map(lambda x: x / 255) \
        .batch(test_set.shape[0])

    iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                               train_dataset.output_shapes)
    next_batch = iterator.get_next()
    train_init_op = iterator.make_initializer(train_dataset)
    test_init_op = iterator.make_initializer(test_dataset)

    return next_batch, train_init_op, test_init_op
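
# A single reinitializable iterator feeds the model, so running train_init_op
# or test_init_op re-points the same graph at the corresponding dataset
# without placeholders or feed dicts.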


def train(model, train_init_op, test_init_op, test_labels, config):
    _, ax = plt.subplots(nrows=config.epochs, ncols=config.n_samples + 1, figsize=(10, 20))
    # MonitoredSession runs the variable initialisers itself, so no explicit
    # tf.global_variables_initializer() call is needed.
    with tf.train.MonitoredSession() as sess:
        for epoch in range(config.epochs):
            # Test
            sess.run(test_init_op)
            test_loss, test_codes, test_samples = sess.run(
                [model.loss, model.code, model.samples])
            # Plot
            ax[epoch, 0].set_ylabel('Epoch {}'.format(epoch))
            plot_codes(ax[epoch, 0], test_codes, test_labels)
            plot_samples(ax[epoch, 1:], test_samples)
            # Train
            train_losses = []
            sess.run(train_init_op)
            while True:
                try:
                    _, train_loss = sess.run([model.optimise, model.loss])
                    train_losses.append(train_loss)
                except tf.errors.OutOfRangeError:
                    break
            # Log
            print('Epoch: {:2d}/{:2d}, train loss: {:.3f}, test loss: {:.3f}'.format(
                epoch + 1, config.epochs, np.mean(train_losses), test_loss))
    plt.savefig('vae-mnist.png', dpi=300, transparent=True, bbox_inches='tight')


if __name__ == "__main__":
    config = AttrDict({
        "batch_size": 100,
        "epochs": 20,
        "n_samples": 10,
        "code_size": 2,
        "hidden_size": 200,
        "learning_rate": 0.001,
    })
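
    # code_size=2 keeps the latent space directly plottable; larger codes
    # reconstruct better but would need a projection (e.g. PCA or t-SNE) to
    # visualise.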
    (train_set, _), (test_set, test_labels) = tf.keras.datasets.mnist.load_data()
    train_set, test_set, test_labels = train_set[:], test_set[:2000], test_labels[:2000]  # DEBUG

    next_batch, train_init_op, test_init_op = create_datasets(train_set, test_set, config)
    model = Model(next_batch, config)
    train(model, train_init_op, test_init_op, test_labels, config)
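
    # Expected output: per-epoch train/test losses (the negative ELBO, in
    # nats per image) printed to stdout, and 'vae-mnist.png' with one row of
    # latent codes and generated digits per epoch.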