Skip to content

Instantly share code, notes, and snippets.

@ckrapu
Created December 21, 2021 19:31
Show Gist options
  • Save ckrapu/eee1fa5279490b47aa5d3e03511c4f08 to your computer and use it in GitHub Desktop.
Save ckrapu/eee1fa5279490b47aa5d3e03511c4f08 to your computer and use it in GitHub Desktop.
vae-skip-connection
class Sampling(layers.Layer):
    """Sample ``z`` from N(z_mean, exp(z_log_var)) with the reparameterization trick.

    Called on the pair ``(z_mean, z_log_var)`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon``, where ``epsilon`` is standard
    normal noise with the same shape and dtype as ``z_mean``. Because the
    randomness lives only in ``epsilon``, gradients flow through ``z_mean``
    and ``z_log_var``.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Use the full dynamic shape (any rank, unknown batch size) and the
        # input's dtype: the default float32 noise would clash with float16
        # tensors under a mixed-precision policy.
        epsilon = tf.keras.backend.random_normal(
            shape=tf.shape(z_mean), dtype=z_mean.dtype
        )
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
def _hidden_stack(x, width, n_layers, use_skip, dropout_p):
    """Apply ``n_layers - 1`` Dense -> BatchNorm -> Dropout -> ReLU blocks to ``x``.

    Args:
        x: Keras symbolic tensor feeding the stack.
        width: Number of units in every hidden Dense layer.
        n_layers: Total layer count of the enclosing network. The caller's
            output layer counts as one, so ``n_layers - 1`` hidden blocks are
            built here (none when ``n_layers <= 1``).
        use_skip: If True, add a residual connection around every hidden block
            except the first.
        dropout_p: Dropout rate applied after batch normalization.

    Returns:
        The output tensor of the last hidden block, or ``x`` unchanged when no
        blocks are built.
    """
    for i in range(n_layers - 1):
        h = layers.Dense(width, activation="linear")(x)
        h = keras.layers.BatchNormalization()(h)
        h = layers.Dropout(dropout_p)(h)
        h = layers.Activation("relu")(h)
        # The first block maps the input size to ``width``, so a residual sum
        # is only shape-compatible from the second block onward.
        x = layers.Add()([x, h]) if (i > 0 and use_skip) else h
    return x


def basic_encoder(latent_dim, p, width=256, n_layers=5, use_skip=True, dropout_p=0.0):
    """Build a fully-connected VAE encoder.

    Args:
        latent_dim: Size of the latent code.
        p: Number of input features (the model input has shape ``(p,)``).
        width: Units per hidden layer.
        n_layers: Hidden blocks built = ``n_layers - 1``; the z heads follow.
        use_skip: Use residual connections between hidden blocks.
        dropout_p: Dropout rate inside each hidden block.

    Returns:
        A ``keras.Model`` named ``"encoder"`` mapping inputs to
        ``[z_mean, z_log_var, z]``, where ``z`` is drawn by ``Sampling``.
    """
    encoder_inputs = keras.Input(shape=(p,))
    x = _hidden_stack(encoder_inputs, width, n_layers, use_skip, dropout_p)
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])
    return keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")


def basic_decoder(latent_dim, p, width=256, n_layers=5, use_skip=True, dropout_p=0.0):
    """Build a fully-connected VAE decoder.

    Args:
        latent_dim: Size of the latent code (the model input has shape
            ``(latent_dim,)``).
        p: Number of output features.
        width: Units per hidden layer.
        n_layers: Hidden blocks built = ``n_layers - 1``; the output layer follows.
        use_skip: Use residual connections between hidden blocks.
        dropout_p: Dropout rate inside each hidden block.

    Returns:
        A ``keras.Model`` named ``"decoder"`` mapping latent codes to a linear
        ``p``-dimensional output.
    """
    latent_inputs = keras.Input(shape=(latent_dim,))
    x = _hidden_stack(latent_inputs, width, n_layers, use_skip, dropout_p)
    decoder_outputs = layers.Dense(p, activation="linear")(x)
    return keras.Model(latent_inputs, decoder_outputs, name="decoder")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment