Skip to content

Instantly share code, notes, and snippets.

@wermarter
Last active June 22, 2017 03:08
Show Gist options
  • Save wermarter/318756a2f4cda35ebb178a932e1f8c38 to your computer and use it in GitHub Desktop.
Save wermarter/318756a2f4cda35ebb178a932e1f8c38 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import tflearn
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import tflearn.datasets.mnist as mnist
# Directory handed to the TFLearn trainer for its TensorBoard logs.
TENSORBOARD_DIR = './tmp/tflearn/vae'

# MNIST splits, loaded once at import time; labels are requested one-hot
# encoded (the VAE below only consumes the image arrays).
trainX, trainY, testX, testY = mnist.load_data(one_hot=True)
class VAE(object):
    """Variational autoencoder built with TFLearn on a TensorFlow 1.x graph.

    Three models share one set of weights (via variable-scope reuse) and one
    session:

    * ``training_model``    -- ``tflearn.Trainer`` optimising the VAE loss.
    * ``generator_model``   -- ``tflearn.DNN`` mapping latent noise to data.
    * ``recognition_model`` -- ``tflearn.DNN`` mapping data to a latent sample.

    Note on naming: the encoder output called ``z_std`` throughout is used as
    a *log-variance* (the sampling scale is ``exp(z_std / 2)`` and the KL term
    uses ``exp(z_std)``). The name and the ``enc_z_std`` variable scope are
    kept so existing checkpoints and callers are unaffected.
    """

    def __init__(self,
                 learning_rate=0.001,
                 batch_size=256,
                 input_dim=784,
                 latent_dim=2,
                 binary_data=True):
        """Store the hyper-parameters and build the graph immediately.

        Args:
            learning_rate: Learning rate passed to the Adam optimizer.
            batch_size: Batch size passed to the ``TrainOp``.
            input_dim: Size of one flattened input example.
            latent_dim: Size of the latent code.
            binary_data: If true, the decoder ends in a sigmoid and a
                Bernoulli cross-entropy reconstruction loss is used;
                otherwise a linear output with mean-square loss.
        """
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.binary_data = binary_data
        self._build_graph()

    def _encode(self, input_data, is_training):
        """Map ``input_data`` to the latent mean and log-variance tensors.

        ``is_training=True`` creates the encoder variables; ``False`` reuses
        the ones already created under the same scopes.
        """
        reuse = not is_training
        fc1 = tflearn.fully_connected(input_data, self.input_dim//2, activation='elu', scope='enc_fc1', reuse=reuse)
        fc2 = tflearn.fully_connected(fc1, self.input_dim//4, activation='elu', scope='enc_fc2', reuse=reuse)
        z_mean = tflearn.fully_connected(fc2, self.latent_dim, scope='enc_z_mean', reuse=reuse)
        z_std = tflearn.fully_connected(fc2, self.latent_dim, scope='enc_z_std', reuse=reuse)
        return z_mean, z_std

    def _sample_z(self, z_mean, z_std):
        """Reparameterised sample: ``z_mean + exp(z_std / 2) * eps``.

        ``eps`` is standard-normal noise with the same shape as ``z_mean``.
        """
        eps = tf.random_normal(tf.shape(z_mean))
        return z_mean + tf.exp(z_std / 2) * eps

    def _decode(self, z_sampled, is_training):
        """Map a latent tensor back to data space.

        ``is_training=True`` creates the decoder variables; ``False`` reuses
        the ones already created under the same scopes.
        """
        reuse = not is_training
        recon_activ = 'sigmoid' if self.binary_data else 'linear'
        fc1 = tflearn.fully_connected(z_sampled, self.input_dim//4, activation='elu', scope='dec_fc1', reuse=reuse)
        fc2 = tflearn.fully_connected(fc1, self.input_dim//2, activation='elu', scope='dec_fc2', reuse=reuse)
        recon = tflearn.fully_connected(fc2, self.input_dim, activation=recon_activ, scope='dec_recon', reuse=reuse)
        return recon

    def _compute_latent_loss(self, z_mean, z_std):
        """Per-example KL term, summed over the latent axis (axis 1)."""
        latent_loss = 1 + z_std - tf.square(z_mean) - tf.exp(z_std)
        latent_loss = -0.5 * tf.reduce_sum(latent_loss, 1)
        return latent_loss

    def _compute_recon_loss(self, recon_data, input_data):
        """Reconstruction loss between ``recon_data`` and ``input_data``.

        For binary data this is the per-example Bernoulli negative
        log-likelihood, summed over the feature axis so its shape matches the
        latent loss. Otherwise it is TFLearn's ``mean_square`` objective.
        """
        if self.binary_data:
            # 1e-10 keeps log() finite when the sigmoid saturates at 0 or 1.
            log_likelihood = (input_data * tf.log(1e-10 + recon_data)
                              + (1 - input_data) * tf.log(1e-10 + 1 - recon_data))
            # Negate (a loss is minimised) and reduce over features; the
            # original returned the un-negated, un-reduced [batch, input_dim]
            # tensor, which cannot be added to the [batch] latent loss.
            loss = -tf.reduce_sum(log_likelihood, 1)
        else:
            loss = tflearn.objectives.mean_square(recon_data, input_data)
        return loss

    def _build_graph(self):
        """Build the training, generator and recognition models."""
        # Build training model
        train_data = tflearn.input_data(shape=[None, self.input_dim], name='train_data')
        # Kept so fit() can key its feed dicts by the placeholder itself.
        self._train_data = train_data
        z_mean, z_std = self._encode(train_data, True)
        z_sampled = self._sample_z(z_mean, z_std)
        recon_data = self._decode(z_sampled, True)
        per_example_loss = (self._compute_latent_loss(z_mean, z_std)
                            + self._compute_recon_loss(recon_data, train_data))
        # The optimizer is given a single scalar: the mean over the batch.
        loss = tf.reduce_mean(per_example_loss)
        optimizer = tflearn.optimizers.Adam(self.learning_rate).get_tensor()
        trainop = tflearn.TrainOp(loss=loss, optimizer=optimizer, batch_size=self.batch_size, name='VAE_trainer')
        self.training_model = tflearn.Trainer(train_ops=trainop, tensorboard_dir=TENSORBOARD_DIR)
        session = self.training_model.session

        # Build generator model (decoder only, reusing the trained weights)
        input_noise = tflearn.input_data(shape=[None, self.latent_dim], name='input_noise')
        decoded_noise = self._decode(input_noise, False)
        self.generator_model = tflearn.DNN(decoded_noise, session=session)

        # Build recognition model (encoder + sampling, reusing the weights)
        input_data = tflearn.input_data(shape=[None, self.input_dim], name='input_data')
        encoded_data = self._sample_z(*self._encode(input_data, False))
        self.recognition_model = tflearn.DNN(encoded_data, session=session)

    def fit(self, X, testX, n_epoch=100):
        """Train on ``X`` for ``n_epoch`` epochs, validating on ``testX``."""
        self.training_model.fit({self._train_data: X},
                                n_epoch=n_epoch,
                                val_feed_dicts={self._train_data: testX},
                                run_id='VAE')

    def generate(self, input_noise=None):
        """Decode ``input_noise`` with the generator model.

        Args:
            input_noise: Array of latent codes. When ``None``, a single
                standard-normal code of shape ``(1, latent_dim)`` is drawn.

        Returns:
            Whatever ``generator_model.predict`` returns for the codes.
        """
        # `is None`, not `== None`: the latter compares elementwise on arrays
        # and makes the `if` raise for any array with more than one element.
        if input_noise is None:
            input_noise = np.random.normal(size=(1, self.latent_dim))
        return self.generator_model.predict({'input_noise': input_noise})

    def MNIST_latent_space(self, graph_shape=(30, 30)):
        """Plot a grid of digits decoded from a 2-D latent grid.

        Assumes ``latent_dim == 2`` and 28x28 images (``input_dim == 784``).

        Args:
            graph_shape: ``(rows, cols)`` number of decoded tiles.
        """
        n_rows, n_cols = graph_shape
        figure = np.ones((28 * n_rows, 28 * n_cols))
        # Stay strictly inside (0, 1): norm.ppf(0) and norm.ppf(1) are -inf
        # and +inf, which are not usable latent coordinates.
        X = norm.ppf(np.linspace(0.05, 0.95, n_rows))
        Y = norm.ppf(np.linspace(0.05, 0.95, n_cols))
        for i, x in enumerate(X):
            for j, y in enumerate(Y):
                _recon = self.generate(np.array([[x, y]]))
                # One flat prediction -> one 28x28 tile of the canvas.
                figure[i*28:(i+1)*28, j*28:(j+1)*28] = np.reshape(_recon, (28, 28))
        plt.figure(figsize=(10, 10))
        plt.imshow(figure)
        plt.show()
def main():
    """Train a default VAE on the MNIST splits and plot its latent grid."""
    vae = VAE()
    # The original passed an undefined name `X`; the training images loaded
    # at module level are `trainX`.
    vae.fit(trainX, testX)
    vae.MNIST_latent_space()
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment