Loss for VAE of 1d array
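The encoder below calls a Sampling layer that is not defined in this gist; here is a minimal sketch, assuming the standard reparameterization layer from the Keras VAE example, together with the imports the rest of the snippets rely on.

import numpy as np
import tensorflow as tf
from tensorflow import keras


class Sampling(keras.layers.Layer):
    """Uses (z_mean, z_log_var) to sample z via the reparameterization trick."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon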
def build_encoder_with_sampling_layer(shape=(6,)):
    encoder_inputs = keras.layers.Input(shape=shape, name="input_layer")
    x = keras.layers.Dense(5, activation="relu", name="h1")(encoder_inputs)
    x = keras.layers.Dense(5, activation="relu", name="h2")(x)
    x = keras.layers.Dense(4, activation="relu", name="h3")(x)
    z_mean = keras.layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = keras.layers.Dense(latent_dim, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])
    encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
    return encoder
def build_decoder():
    decoder_inputs = keras.layers.Input(shape=(latent_dim,))
    x = keras.layers.Dense(4, activation="relu", name="d1")(decoder_inputs)
    x = keras.layers.Dense(5, activation="relu", name="d2")(x)
    x = keras.layers.Dense(5, activation="relu", name="d3")(x)
    # No activation in the output layer
    x = keras.layers.Dense(6, activation=None, name="d4")(x)
    decoder = keras.Model(decoder_inputs, x, name="decoder")
    return decoder
# Skeleton of compute_loss, to be matched against Kapil's ELBO/KLD video
def compute_loss(encoder, decoder, x):
    # The encoder built above returns (z_mean, z_log_var, z); the sampled z is
    # discarded here because z is re-drawn with reparameterize().
    mean, logvar, _ = encoder(x)
    z = reparameterize(mean, logvar)
    x_logit = decoder(z)
    mse = tf.keras.metrics.mean_squared_error(x, x_logit)
    logpx_z = "???"
    logpz = "???"
    logqz_x = "???"
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)
batch_size = 50
latent_dim = 2
encoder = build_encoder_with_sampling_layer()
decoder = build_decoder()
vae = VAE(encoder, decoder)
inputs_data = tf.random.normal(shape=(batch_size, 6))
train_inputs = inputs_data[0:batch_size//2]
test_inputs = inputs_data[batch_size//2:]
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(train_inputs, epochs=250, batch_size=8)
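# test_inputs is built above but never used; a minimal held-out check, assuming the
# vae trained by the fit() call above, could look like this (illustrative only):
_, _, z_test = vae.encoder(test_inputs)
reconstruction_test = vae.decoder(z_test)
test_mse = tf.reduce_mean(tf.keras.metrics.mean_squared_error(test_inputs, reconstruction_test))
print("Held-out reconstruction MSE:", float(test_mse))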
# My VAE, which led to a lot of confusion
def compute_loss(encoder, decoder, x):
    mean, logvar, _ = encoder(x)  # encoder returns (z_mean, z_log_var, z)
    z = reparameterize(mean, logvar)
    x_logit = decoder(z)
    # MSE stands in for the reconstruction term log p(x|z)
    mse = tf.keras.metrics.mean_squared_error(x, x_logit)
    logpx_z = -mse
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    # Negative Monte Carlo estimate of the ELBO
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)
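# compute_loss above calls log_normal_pdf, which this gist does not define. Assuming
# it is the helper from the TensorFlow CVAE tutorial, it would look like this:
def log_normal_pdf(sample, mean, logvar, raxis=1):
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(
        -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
        axis=raxis)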
@tf.function
def train_step(encoder, decoder, x, optimizer_enc, optimizer_dec):
    # Two nested tapes so encoder and decoder gradients can be applied
    # with separate optimizers.
    with tf.GradientTape() as tape_enc:
        with tf.GradientTape() as tape_dec:
            loss = compute_loss(encoder, decoder, x)
    gradients_enc = tape_enc.gradient(loss, encoder.trainable_variables)
    gradients_dec = tape_dec.gradient(loss, decoder.trainable_variables)
    optimizer_enc.apply_gradients(zip(gradients_enc, encoder.trainable_variables))
    optimizer_dec.apply_gradients(zip(gradients_dec, decoder.trainable_variables))
def reparameterize(mean, logvar):
    # Reparameterization trick: z = mean + sigma * eps, with eps ~ N(0, I).
    # tf.shape handles a dynamic batch dimension inside tf.function.
    eps = tf.random.normal(shape=tf.shape(mean))
    return eps * tf.exp(logvar * .5) + mean
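# The gist has no driver for the custom train_step above; a hypothetical loop,
# reusing train_inputs from the Keras-fit section, might look like this:
optimizer_enc = tf.keras.optimizers.Adam(1e-3)
optimizer_dec = tf.keras.optimizers.Adam(1e-3)
dataset = tf.data.Dataset.from_tensor_slices(train_inputs).batch(8)
for epoch in range(250):
    for batch in dataset:
        train_step(encoder, decoder, batch, optimizer_enc, optimizer_dec)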
# Francois Chollet's VAE with a modified loss function: MSE as the reconstruction loss
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker
        ]

    def reparameterize(self, mean, logvar):
        # Not used by train_step (the encoder's Sampling layer already draws z);
        # kept for completeness, with eps drawn from a standard normal.
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * .5) + mean

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Reduce the per-sample MSE to a scalar so it can be added to the KL term
            reconstruction_loss = tf.reduce_mean(
                tf.keras.metrics.mean_squared_error(data, reconstruction))
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result()
        }
gowrishankarin commented Dec 29, 2021

Please ignore total_mess_up.py for now.
