@sergeyprokudin
Created June 7, 2019 14:49
Basic variational autoencoder in Keras
import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Dense, Lambda, Flatten, Reshape, Dropout
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras.layers.merge import concatenate
from keras.losses import mean_squared_error
def variational_autoencoder(n_input_features, latent_space_size=64, hlayer_size=256,
                            lr=1.0e-3, kl_weight=0.1):
    """Build a basic variational autoencoder: returns (encoder, decoder, full model)."""
    # Encoder: two dense layers followed by separate heads for the mean and
    # log-variance of the approximate posterior Q(z|X)
    encoder_input = Input(shape=[n_input_features])
    encoder_seq = Sequential()
    encoder_seq.add(Dense(hlayer_size, activation='relu', input_shape=[n_input_features, ]))
    encoder_seq.add(Dense(hlayer_size, activation='relu'))
    encoder_hidden = encoder_seq(encoder_input)
    encoder_mu = Dense(latent_space_size, activation='linear')(encoder_hidden)
    encoder_log_sigma = Dense(latent_space_size, activation='linear')(encoder_hidden)

    def _sample_z(args):
        """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""
        mu, log_sigma = args
        eps = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
        return mu + K.exp(log_sigma / 2) * eps

    encoder_output = Lambda(_sample_z)([encoder_mu, encoder_log_sigma])

    # Decoder: maps a latent code back to the input space
    decoder_input = Input(shape=[latent_space_size])
    decoder_seq = Sequential()
    decoder_seq.add(Dense(hlayer_size, activation='relu', input_shape=[latent_space_size]))
    decoder_seq.add(Dense(hlayer_size, activation='relu'))
    decoder_seq.add(Dense(n_input_features, activation='linear'))

    encoder_model = Model(inputs=encoder_input, outputs=encoder_output)
    decoder_model = Model(inputs=decoder_input, outputs=decoder_seq(decoder_input))
    # The full model outputs [mu, log_sigma, reconstruction] concatenated, so the
    # loss function can recover all three pieces from a single output tensor.
    full_model = Model(inputs=encoder_input,
                       outputs=concatenate([encoder_mu, encoder_log_sigma, decoder_seq(encoder_output)]))

    def _vae_loss(y_true, model_output):
        """Loss = reconstruction loss + weighted KL loss for each element of the minibatch."""
        encoder_mu = model_output[:, 0:latent_space_size]
        encoder_log_sigma = model_output[:, latent_space_size:latent_space_size*2]
        y_pred = model_output[:, latent_space_size*2:]
        # E[log P(X|z)]: mean squared error corresponds to modelling P(X|z) as a Gaussian
        # recon = K.sum(K.binary_crossentropy(y_true, y_pred), axis=1)  # alternative for binary data
        recon = mean_squared_error(y_true, y_pred)
        # D_KL(Q(z|X) || P(z)): available in closed form since both distributions are Gaussian
        kl = 0.5 * K.sum(K.exp(encoder_log_sigma) + K.square(encoder_mu) - 1. - encoder_log_sigma, axis=1)
        return recon + kl_weight*kl

    opt = Adam(lr=lr)
    full_model.compile(optimizer=opt, loss=_vae_loss)

    return encoder_model, decoder_model, full_model
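
A minimal usage sketch (not part of the original gist): it trains the full model on synthetic random data, purely for illustration, and then samples new points by decoding draws from the standard normal prior. The custom loss slices mu, log_sigma and the reconstruction out of the concatenated output, so the inputs themselves serve as targets; this relies on TF1-era Keras not enforcing an exact target/output shape match for user-defined losses.

if __name__ == '__main__':
    import numpy as np

    n_features = 32
    x_train = np.random.rand(1000, n_features).astype('float32')  # synthetic, illustrative data

    encoder, decoder, vae = variational_autoencoder(n_features,
                                                    latent_space_size=8,
                                                    hlayer_size=64)

    # Train: the model is already compiled with the VAE loss inside the builder
    vae.fit(x_train, x_train, batch_size=64, epochs=5, verbose=1)

    # Generate new samples by decoding draws from the prior P(z) = N(0, I)
    z = np.random.normal(size=(5, 8)).astype('float32')
    generated = decoder.predict(z)
    print(generated.shape)  # (5, 32)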