Basic variational autoencoder in Keras
from keras import backend as K
from keras.layers import Input, Dense, Lambda, concatenate
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras.losses import mean_squared_error


def variational_autoencoder(n_input_features, latent_space_size=64, hlayer_size=256,
                            lr=1.0e-3, kl_weight=0.1):
    # Encoder: two shared dense ReLU layers feeding separate mu and log-sigma heads
    encoder_input = Input(shape=[n_input_features])
    encoder_seq = Sequential()
    encoder_seq.add(Dense(hlayer_size, activation='relu', input_shape=[n_input_features]))
    encoder_seq.add(Dense(hlayer_size, activation='relu'))
    encoder_hidden = encoder_seq(encoder_input)
    encoder_mu = Dense(latent_space_size, activation='linear')(encoder_hidden)
    encoder_log_sigma = Dense(latent_space_size, activation='linear')(encoder_hidden)

    def _sample_z(args):
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
        mu, log_sigma = args
        eps = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
        return mu + K.exp(log_sigma / 2) * eps

    encoder_output = Lambda(_sample_z)([encoder_mu, encoder_log_sigma])

    # Decoder: maps a latent sample back to the input space
    decoder_input = Input(shape=[latent_space_size])
    decoder_seq = Sequential()
    decoder_seq.add(Dense(hlayer_size, activation='relu', input_shape=[latent_space_size]))
    decoder_seq.add(Dense(hlayer_size, activation='relu'))
    decoder_seq.add(Dense(n_input_features, activation='linear'))

    encoder_model = Model(inputs=encoder_input, outputs=encoder_output)
    decoder_model = Model(inputs=decoder_input, outputs=decoder_seq(decoder_input))
    # The full model outputs [mu, log_sigma, reconstruction] concatenated, so the
    # custom loss below can recover all three pieces from a single output tensor
    full_model = Model(inputs=encoder_input,
                       outputs=concatenate([encoder_mu, encoder_log_sigma, decoder_seq(encoder_output)]))

    def _vae_loss(y_true, model_output):
        """Calculate loss = reconstruction loss + KL loss for each minibatch element."""
        encoder_mu = model_output[:, 0:latent_space_size]
        encoder_log_sigma = model_output[:, latent_space_size:latent_space_size*2]
        y_pred = model_output[:, latent_space_size*2:]
        # E[log P(X|z)]: MSE corresponds to modelling P(X|z) as a Gaussian;
        # for binary data, use binary cross-entropy instead:
        # recon = K.sum(K.binary_crossentropy(y_true, y_pred), axis=1)
        recon = mean_squared_error(y_true, y_pred)
        # D_KL(Q(z|X) || P(z)): closed form since both distributions are Gaussian
        kl = 0.5 * K.sum(K.exp(encoder_log_sigma) + K.square(encoder_mu) - 1. - encoder_log_sigma, axis=1)
        return recon + kl_weight*kl

    opt = Adam(lr=lr)
    full_model.compile(optimizer=opt, loss=_vae_loss)
    return encoder_model, decoder_model, full_model
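
A minimal usage sketch (not part of the original gist): the random array below is a hypothetical stand-in for real flat feature vectors such as flattened images, and the shapes and hyperparameters are illustrative assumptions. Because _vae_loss slices the concatenated output itself, the model is trained with the inputs as targets.

import numpy as np

n_features = 784  # assumed input dimensionality, e.g. flattened 28x28 images
encoder, decoder, vae = variational_autoencoder(n_features, latent_space_size=64)

x_train = np.random.rand(1024, n_features).astype('float32')  # placeholder data

# Targets are the inputs themselves; _vae_loss splits the concatenated
# [mu, log_sigma, reconstruction] output internally
vae.fit(x_train, x_train, batch_size=128, epochs=10)

z = encoder.predict(x_train[:10])                  # encode data to latent samples
x_new = decoder.predict(np.random.randn(10, 64))   # decode draws from the N(0, I) prior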