
@RomanSteinberg
Created December 20, 2017 11:40
VAE (MNIST)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Variational autoencoder (VAE) on MNIST: a 784-500-20 Gaussian encoder and a
20-500-784 Bernoulli decoder, built with tf.keras on TensorFlow 1.x."""
import tensorflow as tf
from tensorflow import keras as ke
from tensorflow.examples.tutorials.mnist import input_data


class VAELoss(ke.layers.Layer):
    """Identity layer that attaches the negative ELBO to the model via add_loss."""
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(VAELoss, self).__init__(**kwargs)

    def vae_loss(self, input_var, mu, logstd, reconstruction):
        # Bernoulli log-likelihood of the input under the decoder output,
        # summed over pixels.
        log_likelihood = tf.reduce_sum(
            input_var * tf.log(reconstruction + 1e-9) +
            (1 - input_var) * tf.log(1 - reconstruction + 1e-9),
            axis=1)
        # KL(q(z|x) || N(0, I)) in closed form.
        KL_term = -.5 * tf.reduce_sum(
            1 + 2 * logstd - tf.pow(mu, 2) - tf.exp(2 * logstd), axis=1)
        variational_lower_bound = tf.reduce_mean(log_likelihood - KL_term)
        return -variational_lower_bound

    def call(self, inputs, **kwargs):
        loss = self.vae_loss(*inputs)
        self.add_loss(loss, inputs=inputs)
        return inputs[0]
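
# The quantity minimized above is the negative evidence lower bound (ELBO):
#   -ELBO = -E_{q(z|x)}[log p(x|z)] + KL(q(z|x) || p(z))
# With a Bernoulli decoder the first term is the summed binary cross-entropy,
# and with q(z|x) = N(mu, exp(logstd)^2) and p(z) = N(0, I) the KL term has the
# closed form used in vae_loss:
#   KL = -0.5 * sum_i (1 + 2*logstd_i - mu_i^2 - exp(2*logstd_i))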

class VAE:
    def __init__(self):
        self.latent_dim = 20
        self.h_dim = 500
        self.image_shape = 784  # 28 * 28 MNIST images, flattened
        # The decoder layers are created once and stored so that the training
        # model and the standalone generator share the same weights.
        self.decoder_layers = [ke.layers.Dense(self.h_dim, activation='tanh'),
                               ke.layers.Dense(self.image_shape, activation='sigmoid')]

    def fc_encoder(self, previous):
        return ke.layers.Dense(self.h_dim, activation='tanh')(previous)

    def fc_decoder(self, previous):
        out = previous
        for layer in self.decoder_layers:
            out = layer(out)
        return out
    def distribution_layers(self, previous):
        mu = ke.layers.Dense(self.latent_dim, activation='tanh')(previous)
        logstd = ke.layers.Dense(self.latent_dim, activation='tanh')(previous)
        # Draw per-example noise (one row per batch element, not a single shared
        # row) and scale by sigma = exp(logstd), matching the KL term in VAELoss.
        la = lambda args: args[0] + tf.random_normal(tf.shape(args[0])) * tf.exp(args[1])
        sample = ke.layers.Lambda(la)([mu, logstd])
        return sample, mu, logstd
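
    # The Lambda above is the reparameterization trick: sampling
    # z ~ N(mu, sigma^2) directly is not differentiable, so it draws
    # eps ~ N(0, I) and computes z = mu + eps * sigma, letting gradients flow
    # into mu and logstd. Note the tanh activations bound mu and logstd to
    # (-1, 1); a linear activation is the more common choice, but the original
    # layout is kept here.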

    def loss(self, input_var, mu, logstd, reconstruction):
        return VAELoss()([input_var, mu, logstd, reconstruction])

    def create_train_pass(self):
        inp = ke.layers.Input(shape=(self.image_shape,))
        h = self.fc_encoder(inp)
        z, mu, logstd = self.distribution_layers(h)
        decoded = self.fc_decoder(z)
        out = self.loss(inp, mu, logstd, decoded)
        model = ke.models.Model(inputs=inp, outputs=out)
        gen_inp = ke.layers.Input(shape=(self.latent_dim,))
        reconstruction = self.fc_decoder(gen_inp)
        generator = ke.models.Model(inputs=gen_inp, outputs=reconstruction)
        return model, generator
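
    # The generator is a second Model built over the same Dense objects stored
    # in self.decoder_layers, so once training updates those weights the
    # generator decodes latent vectors with them; saving it alone suffices for
    # sampling new digits.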

    def create_gen_path(self):
        # Unused stub: the generator is already built inside create_train_pass.
        pass


def train(train_data):
    vae = VAE()
    model, generator = vae.create_train_pass()
    # loss=None: the objective is attached to the graph by VAELoss.add_loss,
    # so no target tensors are passed to fit().
    model.compile(optimizer='adam', loss=None)
    model.fit(train_data, shuffle=True, batch_size=128, epochs=1, verbose=1)
    generator.save('gen.h5')


def main():
    mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
    train(mnist.train.images)


if __name__ == '__main__':
    main()
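
# A minimal sampling sketch (assuming 'gen.h5' was produced by train() above
# and latent_dim is 20, as set in VAE.__init__):
#
#   import numpy as np
#   from tensorflow import keras as ke
#   generator = ke.models.load_model('gen.h5')
#   z = np.random.normal(size=(16, 20))                # 16 draws from the N(0, I) prior
#   digits = generator.predict(z).reshape(-1, 28, 28)  # pixel values in [0, 1]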