Last active
July 5, 2017 14:50
-
-
Save wermarter/466e9585579ef65927fa934fe4e0ffd4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import tflearn | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.stats import norm | |
import cv2 | |
import tflearn.datasets.mnist as mnist | |
# Download (if needed) and load MNIST; images come back flattened to 784-dim
# float vectors and labels are one-hot encoded.
trainX, trainY, testX, testY = mnist.load_data(one_hot=True)
# Directory where tflearn.Trainer writes its TensorBoard summaries.
TENSORBOARD_DIR='./vae_log'
class VAE(object):
    """Convolutional variational autoencoder (TF1-style graph via tflearn).

    Encoder: one conv/pool/LRN stage flattened into two linear heads that
    produce the latent mean and log-variance.  Decoder: a linear layer
    reshaped to a half-resolution map, then one stride-2 transposed
    convolution back to ``img_shape``.

    Defects fixed relative to the original gist:
      * latent sampling and flattening derived their batch size from the
        *training* placeholder (``self.curr_batch_size``), so the standalone
        ``encode``/``generate`` graphs errored unless ``train_data`` was also
        fed — the exact error the author annotated.  Batch sizes now come
        from the tensors being processed.
      * the non-binary reconstruction loss was negated, giving a negative,
        unbounded-below objective.
      * ``img_shape`` used a mutable default argument shared by instances.
      * ``show_2D_latent_space`` sampled ``norm.ppf(0)``/``norm.ppf(1)``
        (i.e. +/-inf) at the grid edges.
    """

    def __init__(self,
                 learning_rate=0.001,
                 batch_size=256,
                 latent_dim=2,
                 binary_data=True,
                 img_shape=None  # [W, H, D]; defaults to [28, 28, 1]
                 ):
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        # binary_data selects Bernoulli reconstruction (sigmoid output +
        # cross-entropy) vs real-valued (linear output + squared error).
        self.binary_data = binary_data
        # Copy to avoid the shared-mutable-default pitfall.
        self.img_shape = [28, 28, 1] if img_shape is None else list(img_shape)
        self.full_graph = False
        self._build_training_model()

    def print_shape(self, tensor):
        # Debug helper: print a tensor's static shape.
        print(tensor.get_shape())

    def _flat_conv_dim(self):
        # Size of the flattened encoder feature map: a single 2x2 max-pool
        # halves W and H, and conv_2d emits 32 channels.
        # (28x28 input -> 14*14*32 = 6272, the constant the gist hard-coded.)
        return (self.img_shape[0] // 2) * (self.img_shape[1] // 2) * 32

    def conv_encode(self, input_data):
        """Conv/pool/LRN feature extractor, flattened to [batch, features]."""
        conv1 = tflearn.conv_2d(input_data, 32, 5, activation='elu')
        pool1 = tflearn.max_pool_2d(conv1, 2)
        norm1 = tflearn.local_response_normalization(pool1)
        # -1 lets TF infer the batch dimension, so this subgraph no longer
        # depends on the training placeholder's batch size.
        return tf.reshape(norm1, [-1, self._flat_conv_dim()])

    def conv_decode(self, z_sampled):
        """Map latent codes back to images of shape ``img_shape``."""
        recon_activ = 'sigmoid' if self.binary_data else 'linear'
        half_w, half_h = self.img_shape[0] // 2, self.img_shape[1] // 2
        fc1 = tflearn.fully_connected(z_sampled, half_w * half_h,
                                      activation='elu')
        fc1 = tf.reshape(fc1, [-1, half_h, half_w, 1])
        # Stride-2 transposed conv upsamples back to the full image size.
        return tflearn.conv_2d_transpose(fc1, 1, 3, self.img_shape, 2,
                                         activation=recon_activ)

    def _encode(self, input_data, is_training):
        """Return (z_mean, z_log_var); reuses weights on the inference pass."""
        with tf.variable_scope('Encoder', reuse=not is_training):
            net = self.conv_encode(input_data)
            self.print_shape(net)
            z_mean = tflearn.fully_connected(net, self.latent_dim)
            z_std = tflearn.fully_connected(net, self.latent_dim)
        return z_mean, z_std

    def _decode(self, z_sampled, is_training):
        """Decode latent codes; reuses weights on the inference pass."""
        with tf.variable_scope('Decoder', reuse=not is_training):
            net = self.conv_decode(z_sampled)
        return net

    def _sample_z(self, z_mean, z_std):
        """Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)."""
        # The noise shape comes from z_mean itself rather than the training
        # placeholder, so sampling also works in the recognition graph
        # (fixes the "need to feed train_data" error the gist noted).
        eps = tf.random_normal(tf.shape(z_mean))
        return z_mean + tf.exp(z_std / 2) * eps

    def _compute_latent_loss(self, z_mean, z_std):
        # KL(N(mu, sigma^2) || N(0, I)) with z_std holding log-variance,
        # summed over batch and latent dimensions.
        latent_loss = 1 + z_std - tf.square(z_mean) - tf.exp(z_std)
        return -0.5 * tf.reduce_sum(latent_loss)

    def _compute_recon_loss(self, recon_data, input_data):
        """Reconstruction term, summed over batch and pixels."""
        batch = tf.shape(input_data)[0]
        recon_data = tf.reshape(recon_data, [batch, -1])
        input_data = tf.reshape(input_data, [batch, -1])
        if self.binary_data:
            # Bernoulli log-likelihood; 1e-10 guards log(0).
            loglik = input_data * tf.log(1e-10 + recon_data) \
                + (1 - input_data) * tf.log(1e-10 + 1 - recon_data)
            return -tf.reduce_sum(loglik)
        # Squared error.  The gist negated this branch too, producing a
        # negative, unbounded-below loss — fixed here.
        return tf.reduce_sum(tf.square(recon_data - input_data))

    def _build_training_model(self):
        """Build the training graph and wrap it in a tflearn Trainer."""
        self.train_data = tflearn.input_data(shape=[None, *self.img_shape],
                                             name='train_data')
        # Kept only for backward compatibility with external readers; the
        # graph itself no longer depends on it, which is what lets the
        # inference graphs run without feeding train_data.
        self.curr_batch_size = tf.shape(self.train_data)[0]
        z_mean, z_std = self._encode(self.train_data, True)
        z_sampled = self._sample_z(z_mean, z_std)
        recon_data = self._decode(z_sampled, True)
        loss = self._compute_latent_loss(z_mean, z_std) \
            + self._compute_recon_loss(recon_data, self.train_data)
        optimizer = tflearn.optimizers.Adam(self.learning_rate).get_tensor()
        trainop = tflearn.TrainOp(loss=loss, optimizer=optimizer,
                                  batch_size=self.batch_size,
                                  name='VAE_trainer')
        self.training_model = tflearn.Trainer(train_ops=trainop,
                                              tensorboard_dir=TENSORBOARD_DIR)

    def _build_full_graph(self):
        """Build standalone generator/recognition graphs sharing the trained
        session and variables (variable scopes are entered with reuse=True)."""
        # Generator: latent noise -> image.
        self.input_noise = tflearn.input_data(shape=[None, self.latent_dim],
                                              name='input_noise')
        decoded_noise = self._decode(self.input_noise, False)
        self.generator_model = tflearn.DNN(decoded_noise,
                                           session=self.training_model.session)
        # Recognition: image -> sampled latent code.
        self.input_data = tflearn.input_data(shape=[None, *self.img_shape],
                                             name='input_data')
        encoded_data = self._sample_z(*self._encode(self.input_data, False))
        self.recognition_model = tflearn.DNN(encoded_data,
                                             session=self.training_model.session)
        self.full_graph = True

    def fit(self, trainX, testX, n_epoch=100):
        """Train on trainX with testX as the validation feed.

        Inputs may be flat (e.g. [n, 784]); they are reshaped to img_shape.
        """
        n_train, n_test = trainX.shape[0], testX.shape[0]
        trainX = trainX.reshape((n_train, *self.img_shape))
        testX = testX.reshape((n_test, *self.img_shape))
        self.training_model.fit({self.train_data: trainX}, n_epoch,
                                {self.train_data: testX}, run_id='VAE')

    def generate(self, input_noise=None):
        """Decode latent vectors into images; samples one z ~ N(0, I) if None."""
        if not self.full_graph:
            self._build_full_graph()
        if input_noise is None:
            input_noise = np.random.normal(size=(1, self.latent_dim))
        else:
            input_noise = input_noise.reshape((-1, self.latent_dim))
        return self.generator_model.predict({self.input_noise: input_noise})

    def encode(self, input_data):
        """Encode images to sampled latent codes.

        Works without feeding train_data now: the recognition graph no
        longer references the training placeholder's batch size.
        """
        if not self.full_graph:
            self._build_full_graph()
        input_data = input_data.reshape((-1, *self.img_shape))
        return self.recognition_model.predict({self.input_data: input_data})

    def format_img(self, img):
        """Convert a float image (grayscale or BGR) to RGB for matplotlib."""
        img = np.array(img, np.float32)
        if self.img_shape[-1] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def img_transition(self, A, B, step=10):
        """Display `step` decodings interpolated linearly between A and B
        in latent space."""
        enc_A = self.encode(A)[0]
        enc_B = self.encode(B)[0]
        # One linspace per latent coordinate, transposed to (step, latent_dim).
        trans_lst = np.array([np.linspace(a, b, step)
                              for a, b in zip(enc_A, enc_B)]).T
        img_W, img_H, img_D = self.img_shape
        figure = np.ones((img_H, img_W * step, img_D))
        for i, trans_vec in enumerate(trans_lst):
            figure[:, i * img_W:(i + 1) * img_W, :] = self.generate(trans_vec)
        plt.figure()
        plt.imshow(self.format_img(figure))
        plt.show()

    def show_2D_latent_space(self, graph_shape=(30, 30)):
        """Decode a grid of 2-D latent points into one tiled image.

        Only meaningful when latent_dim == 2 (extra dims are not filled).
        """
        img_W, img_H, img_D = self.img_shape
        figure = np.ones((img_H * graph_shape[0], img_W * graph_shape[1], img_D))
        # Map a uniform grid through the inverse Gaussian CDF so the grid
        # follows the prior's mass.  Interior quantiles only: ppf(0) and
        # ppf(1) are +/-inf (original bug).
        X = norm.ppf(np.linspace(0.05, 0.95, graph_shape[0]))
        Y = norm.ppf(np.linspace(0.05, 0.95, graph_shape[1]))
        for i, x in enumerate(X):
            for j, y in enumerate(Y):
                _recon = self.generate(np.array([[x, y]]))
                figure[i * img_H:(i + 1) * img_H,
                       j * img_W:(j + 1) * img_W, :] = _recon
        plt.figure()
        plt.imshow(self.format_img(figure))
        plt.show()

    def save(self, model_path='training_model.sav'):
        """Persist the trained weights."""
        self.training_model.save(model_path)

    def load(self, model_path='training_model.sav'):
        """Restore weights and build the inference graphs immediately."""
        self.training_model.restore(model_path)
        self._build_full_graph()
def main():
    """Train a VAE on MNIST briefly, then visualize what it learned."""
    model = VAE()
    model.fit(trainX, testX, 1)
    model.show_2D_latent_space()
    # Latent-space interpolation between two samples, from each split.
    for images in (trainX, testX):
        model.img_transition(images[4], images[100])
    model.save()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment