Last active
November 12, 2018 20:14
-
-
Save yoel-zeldes/e4f3253756de8397fec7f5cd0837b4ab to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Joint loss for a VAE with an auxiliary digit classifier.
# NOTE: written for TensorFlow 1.x -- `tf.contrib.distributions`, `tf.log`
# and `tf.train.AdamOptimizer` do not exist in TF 2.x.
#
# Assumes the following are already in scope (defined elsewhere in the model):
#   images          -- input batch; presumably flattened to (batch, n_pixels) -- TODO confirm
#   decoded_images  -- decoder output (reconstruction means), same shape as `images`
#   encoder_mu, encoder_var -- inferred posterior mean / variance per latent dim
#   digits          -- integer class labels; digit_logits -- classifier logits
#   params          -- dict with 'decoder_std', 'digit_classification_weight',
#                      'learning_rate'

# Reconstruction term: negative log-likelihood of the inputs under a Gaussian
# centered at the decoder output with fixed scale params['decoder_std'].
loss_reconstruction = -tf.reduce_sum(
    tf.contrib.distributions.Normal(
        decoded_images,
        params['decoder_std']
    ).log_prob(images),
    axis=1
)

# KL divergence between the inferred posterior N(mu, diag(var)) and the
# standard-Gaussian prior. With a diagonal covariance this is analytically
# solvable:  KL = -0.5 * sum(1 + log(var) - mu^2 - var)
loss_prior = -0.5 * tf.reduce_sum(
    1 + tf.log(encoder_var) - encoder_mu ** 2 - encoder_var,
    axis=1
)

# Negative ELBO, averaged over the batch.
loss_auto_encode = tf.reduce_mean(
    loss_reconstruction + loss_prior,
    axis=0
)

# digit_classification_weight trades off the auto-encoding objective against
# the classification objective, since there's a tension between them.
loss_digit_classifier = params['digit_classification_weight'] * tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=digits,
                                                   logits=digit_logits),
    axis=0
)

loss = loss_auto_encode + loss_digit_classifier
train_op = tf.train.AdamOptimizer(params['learning_rate']).minimize(loss)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment