Created December 21, 2021 19:31
-
-
Save ckrapu/eee1fa5279490b47aa5d3e03511c4f08 to your computer and use it in GitHub Desktop.
vae-skip-connection
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Sampling(layers.Layer):
    """Reparameterization-trick sampling layer for a VAE latent code.

    Given ``(z_mean, z_log_var)`` it returns
    ``z = z_mean + exp(0.5 * z_log_var) * eps`` with ``eps ~ N(0, 1)`` drawn
    elementwise, so gradients can flow back into ``z_mean`` and ``z_log_var``.
    """

    def call(self, inputs):
        """Draw one latent sample per element of ``z_mean``.

        Args:
            inputs: a pair ``(z_mean, z_log_var)``; ``z_log_var`` is assumed to
                have the same shape as ``z_mean`` (or be broadcastable to it).

        Returns:
            A tensor shaped like ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Take the noise shape from z_mean's full dynamic shape rather than a
        # hard-coded (batch, dim) pair, so inputs of any rank work. Match
        # z_mean's dtype so the sum/product below does not mix float32 noise
        # with float16 activations under a mixed-precision policy.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
def basic_encoder(latent_dim, p, width=256, n_layers=5, use_skip=True, dropout_p=0.0):
    """Build an MLP VAE encoder with optional residual connections.

    Args:
        latent_dim: size of the latent code produced by the two heads.
        p: number of input features (the model input has shape ``(p,)``).
        width: number of units in every hidden Dense layer.
        n_layers: ``n_layers - 1`` hidden blocks are stacked before the
            ``z_mean`` / ``z_log_var`` heads.
        use_skip: when True, every hidden block except the first adds its
            output to its input. The first block maps ``p`` features to
            ``width`` units, so its input and output sizes generally differ.
        dropout_p: rate passed to the Dropout layer inside each hidden block.

    Returns:
        A ``keras.Model`` named ``"encoder"`` whose outputs are
        ``[z_mean, z_log_var, z]``, with ``z`` sampled by ``Sampling``.
    """

    def hidden_block(h):
        # Linear Dense -> BatchNorm -> Dropout -> ReLU.
        h = layers.Dense(width, activation="linear")(h)
        h = keras.layers.BatchNormalization()(h)
        h = layers.Dropout(dropout_p)(h)
        return layers.Activation('relu')(h)

    inputs = keras.Input(shape=(p,))
    features = inputs
    for depth in range(1, n_layers):
        update = hidden_block(features)
        if use_skip and depth > 1:
            features = layers.Add()([features, update])
        else:
            features = update

    z_mean = layers.Dense(latent_dim, name="z_mean")(features)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(features)
    z = Sampling()([z_mean, z_log_var])
    return keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
def basic_decoder(latent_dim, p, width=256, n_layers=5, use_skip=True, dropout_p=0.0):
    """Build an MLP VAE decoder with optional residual connections.

    Args:
        latent_dim: size of the latent input (the model input has shape
            ``(latent_dim,)``).
        p: number of output features.
        width: number of units in every hidden Dense layer.
        n_layers: ``n_layers - 1`` hidden blocks are stacked before the final
            linear projection to ``p`` outputs.
        use_skip: when True, every hidden block except the first adds its
            output to its input. The first block maps ``latent_dim`` features
            to ``width`` units, so its sizes generally differ.
        dropout_p: rate passed to the Dropout layer inside each hidden block.

    Returns:
        A ``keras.Model`` named ``"decoder"`` producing a linear ``p``-vector.
    """
    z_in = keras.Input(shape=(latent_dim,))
    h = z_in
    for block_idx in range(n_layers - 1):
        residual_ok = use_skip and block_idx != 0
        out = layers.Dense(width, activation="linear")(h)
        out = keras.layers.BatchNormalization()(out)
        out = layers.Dropout(dropout_p)(out)
        out = layers.Activation('relu')(out)
        h = layers.Add()([h, out]) if residual_ok else out

    reconstruction = layers.Dense(p, activation="linear")(h)
    return keras.Model(z_in, reconstruction, name="decoder")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.