# Last active July 5, 2017 14:50
import tensorflow as tf
import tflearn
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import cv2
import tflearn.datasets.mnist as mnist
# Module-level MNIST splits (labels one-hot encoded); main() reads trainX/testX.
(trainX, trainY,
 testX, testY) = mnist.load_data(one_hot=True)
class VAE(object):
    """Convolutional variational autoencoder built with TFLearn (TF1 graph mode).

    The training graph (encoder -> reparameterized sampler -> decoder) and its
    ``tflearn.Trainer`` are built in ``__init__``. Inference graphs that share
    the same weights are built lazily on the first ``generate``/``encode``
    call:
      * a generator (latent vector -> image);
      * a recognition model (image -> *sampled* latent vector).

    Args:
        learning_rate: Adam learning rate.
        batch_size: mini-batch size used by the trainer.
        latent_dim: dimensionality of the latent code z.
        binary_data: if True the decoder ends in a sigmoid and the
            reconstruction loss is Bernoulli cross-entropy; otherwise the
            decoder output is linear and the loss is summed squared error.
        img_shape: image shape as (height, width, channels), i.e. NHWC
            without the batch axis.
        tensorboard_dir: TensorBoard log directory for the trainer.
    """

    def __init__(self,
                 learning_rate=0.001,
                 batch_size=256,
                 latent_dim=2,
                 binary_data=True,
                 img_shape=(28, 28, 1),  # H, W, D
                 tensorboard_dir='/tmp/tflearn_logs/'):
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.binary_data = binary_data
        # Fresh list per instance: avoids a shared mutable default while
        # keeping the list type the layers below are given.
        self.img_shape = list(img_shape)
        self.tensorboard_dir = tensorboard_dir
        self.full_graph = False
        # fit() relies on self.training_model, so build it up front.
        self._build_training_model()

    def print_shape(self, tensor):
        """Debug helper: print the static shape of ``tensor``."""
        print(tensor.get_shape().as_list())

    def conv_encode(self, input_data):
        """Encode a batch of images into a (batch, features) matrix.

        Conv(32 filters, 5x5, ELU) -> 2x2 max-pool -> local response
        normalization -> flatten.
        """
        conv1 = tflearn.conv_2d(input_data, 32, 5, activation='elu')
        pool1 = tflearn.max_pool_2d(conv1, 2)
        norm1 = tflearn.local_response_normalization(pool1)
        # Flatten with the tensor's own static feature size and a dynamic
        # batch axis. Reshaping to [self.curr_batch_size, 6272] hard-coded
        # 28x28 input and tied every encoder graph to the *training*
        # placeholder's batch size, so inference demanded a train_data feed.
        flat_dim = int(np.prod(norm1.get_shape().as_list()[1:]))
        return tf.reshape(norm1, [-1, flat_dim])

    def conv_decode(self, z_sampled):
        """Decode latent codes into images shaped (batch, *img_shape)."""
        recon_activ = 'sigmoid' if self.binary_data else 'linear'
        img_H, img_W, img_D = self.img_shape
        # Half-resolution seed map (ceil, so a SAME-padded stride-2 transposed
        # conv maps it back to exactly img_H x img_W). 14x14 for 28x28 input.
        half_H, half_W = (img_H + 1) // 2, (img_W + 1) // 2
        fc1 = tflearn.fully_connected(z_sampled, half_H * half_W, activation='elu')
        fc1 = tf.reshape(fc1, [-1, half_H, half_W, 1])
        deconv2 = tflearn.conv_2d_transpose(fc1, img_D, 3, self.img_shape, 2,
                                            activation=recon_activ)
        return deconv2

    def _encode(self, input_data, is_training):
        """Return (z_mean, z_log_var) for ``input_data``.

        The training build (is_training=True) creates the 'Encoder' variables;
        later builds (is_training=False) reuse them.
        """
        with tf.variable_scope('Encoder', reuse=not is_training):
            net = self.conv_encode(input_data)
            z_mean = tflearn.fully_connected(net, self.latent_dim)
            z_std = tflearn.fully_connected(net, self.latent_dim)
        return z_mean, z_std

    def _decode(self, z_sampled, is_training):
        """Decode ``z_sampled``; reuses the 'Decoder' variables unless training."""
        with tf.variable_scope('Decoder', reuse=not is_training):
            net = self.conv_decode(z_sampled)
        return net

    def _sample_z(self, z_mean, z_std):
        """Reparameterization trick: z = z_mean + exp(z_std / 2) * eps.

        Despite its name, ``z_std`` is used as the log-variance (consistent
        with ``_compute_latent_loss``).
        """
        # Noise shaped like z_mean: the batch size comes from the tensor being
        # sampled, not from the training placeholder.
        eps = tf.random_normal(tf.shape(z_mean))
        return z_mean + tf.exp(z_std / 2) * eps

    def _compute_latent_loss(self, z_mean, z_std):
        """KL(q(z|x) || N(0, I)) summed over latent dims and the batch.

        ``z_std`` is the log-variance.
        """
        latent_loss = 1 + z_std - tf.square(z_mean) - tf.exp(z_std)
        latent_loss = -0.5 * tf.reduce_sum(latent_loss)
        return latent_loss

    def _compute_recon_loss(self, recon_data, input_data):
        """Reconstruction loss summed over pixels and the batch (minimized).

        binary_data=True: Bernoulli negative log-likelihood.
        binary_data=False: sum of squared errors.
        """
        recon_data = tf.reshape(recon_data, [tf.shape(recon_data)[0], -1])
        input_data = tf.reshape(input_data, [tf.shape(input_data)[0], -1])
        if self.binary_data:
            # 1e-10 guards log(0) when the sigmoid saturates.
            log_lik = (input_data * tf.log(1e-10 + recon_data)
                       + (1 - input_data) * tf.log(1e-10 + 1 - recon_data))
            return -tf.reduce_sum(log_lik)
        # Squared error must be *added* to the loss; negating it (as before)
        # makes the optimizer maximize reconstruction error.
        return tf.reduce_sum(tf.square(recon_data - input_data))

    def _build_training_model(self):
        """Build the training graph and ``self.training_model`` (a tflearn.Trainer)."""
        self.train_data = tflearn.input_data(shape=[None, *self.img_shape], name='train_data')
        # Kept for backward compatibility; no graph below depends on it anymore.
        self.curr_batch_size = tf.shape(self.train_data)[0]
        z_mean, z_std = self._encode(self.train_data, True)
        z_sampled = self._sample_z(z_mean, z_std)
        recon_data = self._decode(z_sampled, True)
        loss = self._compute_latent_loss(z_mean, z_std) + self._compute_recon_loss(recon_data, self.train_data)
        optimizer = tflearn.optimizers.Adam(self.learning_rate).get_tensor()
        trainop = tflearn.TrainOp(loss=loss, optimizer=optimizer, batch_size=self.batch_size, name='VAE_trainer')
        self.training_model = tflearn.Trainer(train_ops=trainop, tensorboard_dir=self.tensorboard_dir)

    def _build_full_graph(self):
        """Build generator and recognition models sharing the trained weights."""
        # Generator: latent vector -> image.
        self.input_noise = tflearn.input_data(shape=[None, self.latent_dim], name='input_noise')
        decoded_noise = self._decode(self.input_noise, False)
        self.generator_model = tflearn.DNN(decoded_noise, session=self.training_model.session)
        # Recognition: image -> sampled latent vector (stochastic, via _sample_z).
        self.input_data = tflearn.input_data(shape=[None, *self.img_shape], name='input_data')
        encoded_data = self._sample_z(*self._encode(self.input_data, False))
        self.recognition_model = tflearn.DNN(encoded_data, session=self.training_model.session)
        self.full_graph = True

    def fit(self, trainX, testX, n_epoch=100):
        """Train on ``trainX``, validating on ``testX``.

        Args:
            trainX, testX: arrays of images, one per row; each is reshaped
                to (n, *img_shape).
            n_epoch: number of training epochs.
        """
        n_train, n_test = trainX.shape[0], testX.shape[0]
        trainX = trainX.reshape((n_train, *self.img_shape))
        testX = testX.reshape((n_test, *self.img_shape))
        self.training_model.fit({self.train_data: trainX}, n_epoch,
                                {self.train_data: testX}, run_id='VAE')

    def generate(self, input_noise=None):
        """Decode latent vectors into images.

        Args:
            input_noise: array-like reshaped to (-1, latent_dim); if None, one
                standard-normal latent vector is drawn.
        Returns:
            The generator model's prediction, one image per latent row.
        """
        if not self.full_graph:
            self._build_full_graph()
        if input_noise is None:
            input_noise = np.random.normal(size=(1, self.latent_dim))
        input_noise = np.asarray(input_noise).reshape((-1, self.latent_dim))
        return self.generator_model.predict({self.input_noise: input_noise})

    def encode(self, input_data):
        """Encode images (reshaped to (-1, *img_shape)) into sampled latent vectors."""
        if not self.full_graph:
            self._build_full_graph()
        input_data = np.asarray(input_data).reshape((-1, *self.img_shape))
        return self.recognition_model.predict({self.input_data: input_data})

    def format_img(self, img):
        """Convert an (H, W, D) image into a float32 RGB array for display.

        Single-channel input is expanded to three channels first; the channel
        order is then swapped BGR -> RGB (multi-channel input is assumed BGR).
        """
        img = np.array(img, np.float32)
        if self.img_shape[-1] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def img_transition(self, A, B, step=10, show=True):
        """Decode a linear interpolation between the encodings of A and B.

        Args:
            A, B: single images (anything reshapeable to img_shape).
            step: number of interpolation points, endpoints included.
            show: if True, display the strip with matplotlib.
        Returns:
            The ``step`` decoded images laid side by side, as an RGB array.
        """
        enc_A = np.asarray(self.encode(A))[0]
        enc_B = np.asarray(self.encode(B))[0]
        # Rows = interpolation steps, columns = latent dimensions.
        trans_lst = np.array([np.linspace(a, b, step)
                              for a, b in zip(enc_A, enc_B)]).T
        # img_shape is (H, W, D), matching the NHWC tensors the decoder emits.
        img_H, img_W, img_D = self.img_shape
        figure = np.ones((img_H, img_W * step, img_D))
        for i, trans_vec in enumerate(trans_lst):
            figure[:, i*img_W:(i+1)*img_W, :] = np.asarray(self.generate(trans_vec))[0]
        output_img = self.format_img(figure)
        if show:
            plt.imshow(output_img)
            plt.show()
        return output_img

    def show_2D_latent_space(self, graph_shape=(30, 30), show=True):
        """Decode a grid over a 2-D latent space into one mosaic image.

        Grid points are standard-normal quantiles of evenly spaced
        probabilities in [0.05, 0.95].

        Args:
            graph_shape: (rows, cols) of the grid.
            show: if True, display the mosaic with matplotlib.
        Returns:
            The mosaic as an RGB array.
        Raises:
            ValueError: if ``latent_dim`` is not 2.
        """
        if self.latent_dim != 2:
            raise ValueError('show_2D_latent_space requires latent_dim == 2, '
                             'got %d' % self.latent_dim)
        img_H, img_W, img_D = self.img_shape
        n_rows, n_cols = graph_shape
        figure = np.ones((img_H * n_rows, img_W * n_cols, img_D))
        # Stay strictly inside (0, 1): norm.ppf(0) / norm.ppf(1) are -inf / +inf.
        X = norm.ppf(np.linspace(0.05, 0.95, n_rows))
        Y = norm.ppf(np.linspace(0.05, 0.95, n_cols))
        # Decode the whole grid in one batched call instead of rows*cols calls.
        grid = np.array([[x, y] for x in X for y in Y])
        recons = np.asarray(self.generate(grid))
        for k, recon in enumerate(recons):
            i, j = divmod(k, n_cols)
            figure[i*img_H:(i+1)*img_H, j*img_W:(j+1)*img_W, :] = recon
        output_img = self.format_img(figure)
        if show:
            plt.imshow(output_img)
            plt.show()
        return output_img

    def save(self, model_path='training_model.sav'):
        """Save the model weights to ``model_path`` via the tflearn Trainer."""
        self.training_model.save(model_path)

    def load(self, model_path='training_model.sav'):
        """Restore weights previously written by ``save``."""
        self.training_model.restore(model_path)
        # NOTE(review): Trainer.restore may open a fresh session; drop the
        # inference wrappers so they are rebuilt against the current session.
        self.full_graph = False
def main():
    """Train the VAE for one epoch on MNIST and show two latent interpolations."""
    vae = VAE()
    vae.fit(trainX, testX, 1)
    vae.img_transition(trainX[4], trainX[100])
    vae.img_transition(testX[4], testX[100])
# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()
