This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
images = tf.placeholder(tf.float32, [None, input_size]) | |
digits = tf.placeholder(tf.int32, [None]) | |
# encode an image into a distribution over the latent space | |
encoder_mu, encoder_var = encoder(images, | |
params['encoder_layers']) | |
# sample a latent vector from the latent space - using the reparameterization trick | |
eps = tf.random_normal(shape=[tf.shape(images)[0], | |
params['z_dim']], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def encoder(x, layers):
    """Encode `x` into the mean and variance of a latent distribution.

    Args:
        x: input tensor fed to the first dense layer.
        layers: iterable of hidden-layer widths, applied in order.

    Returns:
        A `(mu, var)` tuple of tensors, each produced by a dense layer with
        `params['z_dim']` units.
    """
    hidden = x
    for units in layers:
        hidden = tf.layers.dense(hidden,
                                 units,
                                 activation=params['activation'])

    latent_units = params['z_dim']
    # The mean head is built before the variance head, keeping the order in
    # which the dense layers (and their variables) are created.
    mu = tf.layers.dense(hidden, latent_units)
    # exp() makes the variance positive; the 1e-5 offset keeps it above zero.
    unconstrained_var = tf.layers.dense(hidden, latent_units)
    var = 1e-5 + tf.exp(unconstrained_var)
    return mu, var
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
params = { | |
'encoder_layers': [128], # the encoder will be implemented using a simple feed forward network | |
'decoder_layers': [128], # and so will the decoder (CNN will be better, but I want to keep the code simple) | |
'digit_classification_layers': [128], # this is for the conditioning. I'll explain it later on | |
'activation': tf.nn.sigmoid, # the activation function used by all sub-networks | |
'decoder_std': 0.5, # the standard deviation of P(x|z) discussed in the first post | |
'z_dim': 10, # the dimension of the latent space | |
'digit_classification_weight': 10.0, # this is for the conditioning. I'll explain it later on | |
'epochs': 20, | |
'batch_size': 100, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Size constants used by the rest of the script.
num_digits = 10
input_size = 28 * 28  # used as the second dimension of the `images` placeholder

# Load MNIST through the TF tutorial helper, passing 'MNIST_data' as its
# directory argument.
mnist = input_data.read_data_sets('MNIST_data')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from tensorflow.examples.tutorials.mnist import input_data | |
import matplotlib.pyplot as plt | |
np.random.seed(42) | |
tf.set_random_seed(42) | |
%matplotlib inline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Heatmap of (learned probability - target probability) per number, with one
# column per entry of `learned_probs`.
plt.figure(figsize=(10, 2))
# On Python 3 dict.values() is a view object; np.array() would wrap it in a
# 0-d object array and the subtraction below would fail. Materialize it as a
# list first (harmless on Python 2). It is also loop-invariant, so build it once.
target_probs = np.array(list(number_to_prob.values()))
prob_errors = [np.array(learned_prob) - target_probs
               for learned_prob in learned_probs]
# Transpose so that rows are numbers and columns follow the order of
# `learned_probs`.
plt.imshow(np.transpose(prob_errors),
           cmap='bwr',
           aspect='auto',
           vmin=-2,
           vmax=2)
plt.xlabel('epoch')
plt.ylabel('number')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adversarial training: one discriminator step followed by
# GENERATOR_TRAINING_FACTOR generator steps per batch, recording the
# generator's probabilities after each epoch.
# NOTE(review): the nesting below was reconstructed from a paste that lost its
# indentation - confirm that the generator steps belong inside the batch loop
# and that the probabilities are recorded once per epoch.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    learned_probs = []
    for _ in range(EPOCHS):
        for _ in range(BATCHS_IN_EPOCH):
            sess.run(d_train_opt)
            for _ in range(GENERATOR_TRAINING_FACTOR):
                sess.run(g_train_opt)
        epoch_probs = sess.run(generated_probs)
        learned_probs.append(epoch_probs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the generator, then run the (shared) discriminator on both the real
# `value` tensor and the generated samples.
generated_outputs, generated_probs = generator()
discriminated_real = discriminator(value)
discriminated_generated = discriminator(generated_outputs)

# Real samples are labelled 1. Ops are created in the same order as the
# nested-call form: labels, cross-entropy, then mean.
real_labels = tf.ones_like(discriminated_real)
real_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    logits=discriminated_real, labels=real_labels)
d_loss_real = tf.reduce_mean(real_xent)

# Generated samples are labelled 0.
fake_labels = tf.zeros_like(discriminated_generated)
fake_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    logits=discriminated_generated, labels=fake_labels)
d_loss_fake = tf.reduce_mean(fake_xent)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def discriminator(x):
    """Return a single un-activated output (logit) per row of `x`.

    The fully connected layer is created inside the 'discriminator' variable
    scope with `tf.AUTO_REUSE`, so every call after the first reuses the same
    weights instead of creating new ones.

    Args:
        x: input tensor passed to the fully connected layer.

    Returns:
        The layer output with `num_outputs=1` and no activation function.
    """
    scope = tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE)
    with scope:
        logit = tf.contrib.layers.fully_connected(
            x, num_outputs=1, activation_fn=None)
    return logit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def generator():
    """Build the generator: trainable logits sampled through a relaxed one-hot.

    Creates a 'logits' variable (initialized to ones, one entry per key of
    `number_to_prob`) in the 'generator' variable scope and draws `BATCH_SIZE`
    samples from a `RelaxedOneHotCategorical` at `TEMPERATURE`.

    Returns:
        A `(generated, probs)` tuple: the sampled tensor and the softmax of
        the logits.
    """
    num_classes = len(number_to_prob)
    with tf.variable_scope('generator'):
        logits = tf.get_variable('logits',
                                 initializer=tf.ones([num_classes]))
        # Op creation order (distribution, softmax, sample) is kept as-is,
        # since the graph-level seed makes op seeds depend on it.
        relaxed_categorical = tf.contrib.distributions.RelaxedOneHotCategorical(
            TEMPERATURE, logits=logits)
        probs = tf.nn.softmax(logits)
        generated = relaxed_categorical.sample(BATCH_SIZE)
    return generated, probs