Skip to content

Instantly share code, notes, and snippets.

@wermarter
Last active July 5, 2017 14:50
Show Gist options
  • Save wermarter/466e9585579ef65927fa934fe4e0ffd4 to your computer and use it in GitHub Desktop.
Save wermarter/466e9585579ef65927fa934fe4e0ffd4 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import tflearn
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import cv2
import tflearn.datasets.mnist as mnist
# Directory handed to the tflearn Trainer for TensorBoard summaries.
TENSORBOARD_DIR = './vae_log'
# MNIST is loaded once at import time (one-hot labels); main() reads these globals.
(trainX, trainY, testX, testY) = mnist.load_data(one_hot=True)
class VAE(object):
    """Convolutional variational autoencoder built with tflearn (TF1 graph mode).

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over a ``latent_dim``-dimensional code; the decoder maps a code
    back to an image. Training minimises reconstruction loss + KL divergence,
    both summed over the batch.

    ``generate`` (code -> image) and ``encode`` (image -> sampled code) lazily
    build inference sub-graphs that reuse the trained 'Encoder'/'Decoder'
    variables and run in the training session.
    """

    def __init__(self,
                 learning_rate=0.001,
                 batch_size=256,
                 latent_dim=2,
                 binary_data=True,
                 img_shape=(28, 28, 1)):  # H, W, D (channels-last)
        """Store hyper-parameters and build the training graph.

        Args:
            learning_rate: Adam learning rate.
            batch_size: mini-batch size used by the tflearn TrainOp.
            latent_dim: dimensionality of the latent code.
            binary_data: if True the decoder ends in a sigmoid and uses a
                binary cross-entropy loss; otherwise a linear output with a
                squared-error loss.
            img_shape: image shape as (height, width, channels).
        """
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.binary_data = binary_data
        # Copy so instances never share/alias a caller's (or a default) list.
        self.img_shape = list(img_shape)
        self.full_graph = False
        self._build_training_model()

    def print_shape(self, tensor):
        """Debug helper: print a tensor's static shape."""
        print(tensor.get_shape())

    def conv_encode(self, input_data):
        """Conv -> max-pool -> LRN features, flattened to (batch, features).

        The batch dimension comes from ``input_data`` itself (-1), so the same
        encoder works for both the training and the recognition placeholder.
        """
        conv1 = tflearn.conv_2d(input_data, 32, 5, activation='elu')
        pool1 = tflearn.max_pool_2d(conv1, 2)
        norm1 = tflearn.local_response_normalization(pool1)
        # Static feature count (14*14*32 = 6272 for 28x28 input); the following
        # fully_connected layer needs it known at graph-construction time.
        flat_dim = int(np.prod(norm1.get_shape().as_list()[1:]))
        return tf.reshape(norm1, [-1, flat_dim])

    def conv_decode(self, z_sampled):
        """Dense -> reshape -> stride-2 transposed conv back to ``img_shape``."""
        recon_activ = 'sigmoid' if self.binary_data else 'linear'
        img_h, img_w, img_d = self.img_shape
        # A stride-2 'same' transposed conv maps ceil(n/2) -> n, so start from
        # the ceiling of half the target size (14x14 for MNIST).
        half_h, half_w = (img_h + 1) // 2, (img_w + 1) // 2
        fc1 = tflearn.fully_connected(z_sampled, half_h * half_w, activation='elu')
        fc1 = tf.reshape(fc1, [-1, half_h, half_w, 1])
        return tflearn.conv_2d_transpose(fc1, img_d, 3, list(self.img_shape), 2,
                                         activation=recon_activ)

    def _encode(self, input_data, is_training):
        """Return (z_mean, z_log_var) tensors for ``input_data``.

        With ``is_training=False`` the 'Encoder' scope is re-entered with
        reuse=True so the weights created for training are shared.
        """
        with tf.variable_scope('Encoder', reuse=not is_training):
            net = self.conv_encode(input_data)
            z_mean = tflearn.fully_connected(net, self.latent_dim)
            z_log_var = tflearn.fully_connected(net, self.latent_dim)
        return z_mean, z_log_var

    def _decode(self, z_sampled, is_training):
        """Decode latent codes; reuses 'Decoder' weights when not training."""
        with tf.variable_scope('Decoder', reuse=not is_training):
            net = self.conv_decode(z_sampled)
        return net

    def _sample_z(self, z_mean, z_log_var):
        """Reparameterisation: z = mean + exp(log_var / 2) * eps, eps ~ N(0, I).

        The noise shape follows ``z_mean`` so any batch size works without
        feeding the training placeholder.
        """
        eps = tf.random_normal(tf.shape(z_mean))
        return z_mean + tf.exp(z_log_var / 2) * eps

    def _compute_latent_loss(self, z_mean, z_log_var):
        """KL(N(mean, exp(log_var)) || N(0, I)), summed over batch and dimensions."""
        kl_terms = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        return -0.5 * tf.reduce_sum(kl_terms)

    def _compute_recon_loss(self, recon_data, input_data):
        """Reconstruction negative log-likelihood, summed over batch and pixels.

        binary_data=True:  Bernoulli NLL (binary cross-entropy).
        binary_data=False: unit-variance Gaussian NLL, i.e. 0.5 * squared
                           error (additive constant dropped).
        """
        n_pixels = int(np.prod(self.img_shape))
        recon_data = tf.reshape(recon_data, [-1, n_pixels])
        input_data = tf.reshape(input_data, [-1, n_pixels])
        if self.binary_data:
            eps = 1e-10  # keeps log() finite when the sigmoid saturates
            log_lik = (input_data * tf.log(eps + recon_data)
                       + (1 - input_data) * tf.log(eps + 1 - recon_data))
            return -tf.reduce_sum(log_lik)
        # Squared error is already a cost (not a log-likelihood): it must not
        # be negated, or minimising the total loss would maximise the error.
        return 0.5 * tf.reduce_sum(tf.square(recon_data - input_data))

    def _build_training_model(self):
        """Build placeholder -> encoder -> sampler -> decoder, the loss and the Trainer."""
        self.train_data = tflearn.input_data(shape=[None, *self.img_shape], name='train_data')
        # Kept for backward compatibility only: every op now derives its batch
        # size from its own input, so inference graphs never need train_data.
        self.curr_batch_size = tf.shape(self.train_data)[0]
        z_mean, z_log_var = self._encode(self.train_data, True)
        z_sampled = self._sample_z(z_mean, z_log_var)
        recon_data = self._decode(z_sampled, True)
        loss = (self._compute_latent_loss(z_mean, z_log_var)
                + self._compute_recon_loss(recon_data, self.train_data))
        optimizer = tflearn.optimizers.Adam(self.learning_rate).get_tensor()
        trainop = tflearn.TrainOp(loss=loss, optimizer=optimizer,
                                  batch_size=self.batch_size, name='VAE_trainer')
        self.training_model = tflearn.Trainer(train_ops=trainop, tensorboard_dir=TENSORBOARD_DIR)

    def _build_full_graph(self):
        """Build the generator and recognition models in the training session."""
        # Generator: latent code -> image.
        self.input_noise = tflearn.input_data(shape=[None, self.latent_dim], name='input_noise')
        decoded_noise = self._decode(self.input_noise, False)
        self.generator_model = tflearn.DNN(decoded_noise, session=self.training_model.session)
        # Recognition: image -> sampled latent code.
        self.input_data = tflearn.input_data(shape=[None, *self.img_shape], name='input_data')
        encoded_data = self._sample_z(*self._encode(self.input_data, False))
        self.recognition_model = tflearn.DNN(encoded_data, session=self.training_model.session)
        self.full_graph = True

    def _ensure_full_graph(self):
        """Build the inference models on first use; later calls are no-ops."""
        if not self.full_graph:
            self._build_full_graph()

    def fit(self, trainX, testX, n_epoch=100):
        """Train on image arrays (flat or image-shaped); ``testX`` is the validation set."""
        trainX = np.asarray(trainX).reshape((-1, *self.img_shape))
        testX = np.asarray(testX).reshape((-1, *self.img_shape))
        self.training_model.fit({self.train_data: trainX}, n_epoch, {self.train_data: testX},
                                run_id='VAE')

    def generate(self, input_noise=None):
        """Decode latent codes into images.

        Args:
            input_noise: array-like reshaped to (-1, latent_dim); if None a
                single code is drawn from N(0, I).

        Returns:
            The generator model's prediction for the whole batch.
        """
        self._ensure_full_graph()
        if input_noise is None:
            input_noise = np.random.normal(size=(1, self.latent_dim))
        else:
            input_noise = np.asarray(input_noise).reshape((-1, self.latent_dim))
        return self.generator_model.predict({self.input_noise: input_noise})

    def encode(self, input_data):
        """Map images (flat or image-shaped) to *sampled* latent codes, one row per image.

        The result is stochastic: the recognition model applies the
        reparameterisation sampler to the encoder's mean/log-variance.
        """
        self._ensure_full_graph()
        input_data = np.asarray(input_data).reshape((-1, *self.img_shape))
        return self.recognition_model.predict({self.input_data: input_data})

    def format_img(self, img):
        """Return ``img`` as a float32 3-channel array for matplotlib.

        Single-channel input is expanded with GRAY2BGR; the result is then
        passed through BGR2RGB (multi-channel input is treated as BGR).
        """
        img = np.array(img, np.float32)
        if self.img_shape[-1] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _latent_grid(graph_shape, margin=0.05):
        """Return a (rows*cols, 2) grid of 2-D codes at Gaussian quantiles.

        Quantiles span [margin, 1 - margin]; the ends 0 and 1 are excluded
        because norm.ppf maps them to -inf/+inf. Row-major: the first
        coordinate varies with the row, the second with the column.
        """
        n_rows, n_cols = graph_shape
        xs = norm.ppf(np.linspace(margin, 1. - margin, n_rows))
        ys = norm.ppf(np.linspace(margin, 1. - margin, n_cols))
        grid_x, grid_y = np.meshgrid(xs, ys, indexing='ij')
        return np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

    @staticmethod
    def _interpolate_latents(z_start, z_end, step):
        """Return ``step`` evenly spaced codes from z_start to z_end inclusive, shape (step, dim)."""
        z_start = np.asarray(z_start, dtype=float)
        z_end = np.asarray(z_end, dtype=float)
        alphas = np.linspace(0., 1., step)[:, np.newaxis]
        return (1. - alphas) * z_start + alphas * z_end

    @staticmethod
    def _tile_images(images, n_rows, n_cols):
        """Arrange ``n_rows * n_cols`` images of shape (H, W, D) into one mosaic.

        Image k is placed at row k // n_cols, column k % n_cols; the result
        has shape (n_rows * H, n_cols * W, D).
        """
        images = np.asarray(images)
        _, img_h, img_w, img_d = images.shape
        grid = images.reshape(n_rows, n_cols, img_h, img_w, img_d)
        return grid.transpose(0, 2, 1, 3, 4).reshape(n_rows * img_h, n_cols * img_w, img_d)

    def _show(self, figure):
        """Display an (H, W, D) image array in a new matplotlib figure."""
        plt.figure()
        plt.imshow(self.format_img(figure))
        plt.show()

    def img_transition(self, A, B, step=10):
        """Show ``step`` decodings along the straight latent path from A to B.

        A and B are single images (flat or image-shaped). Their codes come
        from ``encode``, which samples, so repeated calls can differ.
        """
        z_a = self.encode(A)[0]
        z_b = self.encode(B)[0]
        path = self._interpolate_latents(z_a, z_b, step)
        # One batched predict call instead of one per step.
        decoded = np.asarray(self.generate(path)).reshape(-1, *self.img_shape)
        self._show(self._tile_images(decoded, 1, step))

    def show_2D_latent_space(self, graph_shape=(30, 30)):
        """Show decodings of a grid of 2-D latent codes at Gaussian quantiles.

        Raises:
            ValueError: if the model's latent space is not 2-dimensional.
        """
        if self.latent_dim != 2:
            raise ValueError('show_2D_latent_space requires latent_dim == 2, got %d'
                             % self.latent_dim)
        n_rows, n_cols = graph_shape
        grid = self._latent_grid(graph_shape)
        decoded = np.asarray(self.generate(grid)).reshape(-1, *self.img_shape)
        self._show(self._tile_images(decoded, n_rows, n_cols))

    def save(self, model_path='training_model.sav'):
        """Save the trainer's variables to ``model_path``."""
        self.training_model.save(model_path)

    def load(self, model_path='training_model.sav'):
        """Restore variables from ``model_path`` and ensure the inference models exist."""
        self.training_model.restore(model_path)
        # Guarded so repeated loads don't create duplicate placeholders/models.
        self._ensure_full_graph()
def main():
    """Train a default VAE on MNIST for one epoch, visualise it, then save it."""
    model = VAE()
    model.fit(trainX, testX, n_epoch=1)
    model.show_2D_latent_space()
    # Same pair of indices from the training split first, then the test split.
    for images in (trainX, testX):
        model.img_transition(images[4], images[100])
    model.save()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment