Loss for a VAE on a 1-D array
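The encoder below uses a Sampling layer that the gist never defines. A minimal sketch, assuming the standard reparameterization layer from Francois Chollet's Keras VAE example:

import tensorflow as tf
from tensorflow import keras

class Sampling(keras.layers.Layer):
    """Reparameterization trick: draws z ~ N(z_mean, exp(z_log_var))."""
    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Standard-normal noise with the same shape as z_mean
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon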
def build_encoder_with_sampling_layer(shape=(6,)):
    encoder_inputs = keras.layers.Input(shape=shape, name="input_layer")
    x = keras.layers.Dense(5, activation="relu", name="h1")(encoder_inputs)
    x = keras.layers.Dense(5, activation="relu", name="h2")(x)
    x = keras.layers.Dense(4, activation="relu", name="h3")(x)
    # latent_dim is a module-level global, set in the training script below
    z_mean = keras.layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = keras.layers.Dense(latent_dim, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])
    encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
    return encoder
def build_decoder():
    decoder_inputs = keras.layers.Input(shape=(latent_dim,))
    x = keras.layers.Dense(4, activation="relu", name="d1")(decoder_inputs)
    x = keras.layers.Dense(5, activation="relu", name="d2")(x)
    x = keras.layers.Dense(5, activation="relu", name="d3")(x)
    # No activation in the output layer: reconstructions are unbounded reals
    x = keras.layers.Dense(6, activation=None, name="d4")(x)
    decoder = keras.Model(decoder_inputs, x, name="decoder")
    return decoder
# Skeleton of compute_loss, to be filled in to match Kapil's ELBO/KLD video
def compute_loss(encoder, decoder, x):
    mean, logvar, _ = encoder(x)  # the encoder returns [z_mean, z_log_var, z]
    z = reparameterize(mean, logvar)
    x_logit = decoder(z)
    mse = tf.keras.metrics.mean_squared_error(x, x_logit)
    logpx_z = "???"
    logpz = "???"
    logqz_x = "???"
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)
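For reference, the three placeholders make up the single-sample Monte Carlo estimate of the ELBO used in the TensorFlow CVAE tutorial: ELBO ~= log p(x|z) + log p(z) - log q(z|x), evaluated at one z drawn from q(z|x); the function returns its negation as the loss to minimize.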
# Build and train the VAE on random 1-D arrays of length 6
batch_size = 50
latent_dim = 2
encoder = build_encoder_with_sampling_layer()
decoder = build_decoder()
vae = VAE(encoder, decoder)

# Random data, split in half for training and testing
inputs_data = tf.random.normal(shape=(batch_size, 6))
train_inputs = inputs_data[0:batch_size // 2]
test_inputs = inputs_data[batch_size // 2:]

vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(train_inputs, epochs=250, batch_size=8)
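A quick post-training sanity check on the held-out half (a sketch, not part of the original gist):

# Encode and decode the held-out inputs, then report the mean reconstruction error
z_mean, z_log_var, z = vae.encoder(test_inputs)
reconstruction = vae.decoder(z)
print(tf.reduce_mean(tf.keras.metrics.mean_squared_error(test_inputs, reconstruction)))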
# My VAE loss, which led to a lot of confusion
def compute_loss(encoder, decoder, x):
    mean, logvar, _ = encoder(x)  # the encoder returns [z_mean, z_log_var, z]
    z = reparameterize(mean, logvar)
    x_logit = decoder(z)
    mse = tf.keras.metrics.mean_squared_error(x, x_logit)
    # The CVAE tutorial uses cross-entropy here (logpx_z = -tf.reduce_sum(cross_ent));
    # MSE stands in for it since the data is real-valued
    logpx_z = -mse
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)
@tf.function
def train_step(encoder, decoder, x, optimizer_enc, optimizer_dec):
    # Two tapes so the encoder and decoder can use separate optimizers
    with tf.GradientTape() as tape_enc:
        with tf.GradientTape() as tape_dec:
            loss = compute_loss(encoder, decoder, x)
    gradients_enc = tape_enc.gradient(loss, encoder.trainable_variables)
    gradients_dec = tape_dec.gradient(loss, decoder.trainable_variables)
    optimizer_enc.apply_gradients(zip(gradients_enc, encoder.trainable_variables))
    optimizer_dec.apply_gradients(zip(gradients_dec, decoder.trainable_variables))

def reparameterize(mean, logvar):
    # z = mean + sigma * eps, with eps ~ N(0, I)
    eps = tf.random.normal(shape=mean.shape)
    return eps * tf.exp(logvar * .5) + mean
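This file also calls log_normal_pdf, which the gist never defines. A minimal sketch, assuming the helper from the TensorFlow CVAE tutorial:

import numpy as np

def log_normal_pdf(sample, mean, logvar, raxis=1):
    # Log-density of a diagonal Gaussian, summed over the latent axis
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(
        -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
        axis=raxis)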
# Francois Chollet's VAE, with a modified loss function that uses MSE as the reconstruction loss
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker
        ]
    def reparameterize(self, mean, logvar):
        # (Unused: the encoder's Sampling layer already reparameterizes.)
        # eps must be standard normal; the scaling by exp(logvar / 2) happens below
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * .5) + mean
    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Reduce MSE to a scalar so it adds cleanly to the scalar KL term
            reconstruction_loss = tf.reduce_mean(
                tf.keras.metrics.mean_squared_error(data, reconstruction))
            # Analytic KL divergence between N(z_mean, exp(z_log_var)) and N(0, I)
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result()
        }
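Once trained, the decoder can generate new arrays from the prior (a sketch, not in the original gist):

# Draw latent codes from the standard-normal prior and decode them
z_samples = tf.random.normal(shape=(5, latent_dim))
generated = vae.decoder(z_samples)
print(generated.shape)  # (5, 6)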
Please ignore total_mess_up.py for now.