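Defining the loss function

`calculate_loss` sums three terms: an image reconstruction error, the negative log-likelihood of the ground-truth caption tokens, and the KL divergence between the latent Gaussian (`mean`, `log_std`) and a standard normal prior. It returns the total loss together with the caption term.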
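Written out, the objective takes roughly the following form. Here $B$ is the batch size, $T$ is `MAX_CAPTION_LEN`, $y_{i,t}$ is the token index stored in `captions_transformed[i, t]`, $p_{i,t,v}$ is `caption_prob[i, t, v]`, $\mu$ is `mean`, $\sigma = \exp(\text{log\_std})$, and $\mathcal{L}_{\text{rec}}$ is whatever `criterion` computes. The KL term is the standard closed form for a diagonal Gaussian measured against $\mathcal{N}(0, I)$, and it carries a factor of $\tfrac{1}{2}$:

$$
\mathcal{L} = \mathcal{L}_{\text{rec}}(\hat{x}, x) \;-\; \sum_{i=1}^{B}\sum_{t=1}^{T}\log p_{i,t,y_{i,t}} \;+\; \frac{1}{2}\sum_{i,j}\left(\mu_{ij}^{2} + \sigma_{ij}^{2} - 1 - 2\log\sigma_{ij}\right)
$$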
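```python
import numpy as np
import torch

# `criterion` (the image reconstruction loss) and MAX_CAPTION_LEN are assumed to be defined elsewhere.
def calculate_loss(reconstructed, caption_prob, images, captions_transformed, mean, log_std):
    size = captions_transformed.shape[0]
    reconstruction_error = criterion(reconstructed, images)
    # Probability assigned to each ground-truth token, shape (batch, MAX_CAPTION_LEN).
    likelihoods = torch.stack([
        caption_prob[i, np.arange(MAX_CAPTION_LEN), captions_transformed[i]] for i in range(size)])
    # Despite the name, this holds the negative log-likelihood of the captions.
    log_likelihoods = -torch.log(likelihoods).sum()
    # Closed-form KL(N(mean, exp(log_std)^2) || N(0, I)), summed over all elements.
    KL_divergence = -0.5 * (1 + 2 * log_std - mean.pow(2) - torch.exp(2 * log_std)).sum()
    return reconstruction_error + log_likelihoods + KL_divergence, log_likelihoods
```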
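The following sketch offers a numerical sanity check of the KL term. For a single latent dimension, it compares the closed form with a Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$, where $q = \mathcal{N}(\mu, \sigma^2)$ and $p = \mathcal{N}(0, 1)$. It uses only the Python standard library rather than PyTorch. The helper names `kl_closed_form` and `kl_monte_carlo`, the sample count, and the seed are illustrative assumptions and do not come from the original gist.

```python
import math
import random


def kl_closed_form(mu: float, log_std: float) -> float:
    """KL(N(mu, sigma^2) || N(0, 1)) with sigma = exp(log_std), for one latent dimension."""
    return -0.5 * (1 + 2 * log_std - mu ** 2 - math.exp(2 * log_std))


def kl_monte_carlo(mu: float, log_std: float, n: int = 200_000, seed: int = 0) -> float:
    """Estimate E_q[log q(z) - log p(z)] by sampling z ~ q = N(mu, sigma^2)."""
    rng = random.Random(seed)
    sigma = math.exp(log_std)
    total = 0.0
    for _ in range(n):
        z = rng.gauss(mu, sigma)
        # Log densities of q and of the N(0, 1) prior; the shared -0.5*log(2*pi) constant cancels.
        log_q = -log_std - 0.5 * ((z - mu) / sigma) ** 2
        log_p = -0.5 * z ** 2
        total += log_q - log_p
    return total / n


if __name__ == "__main__":
    # Illustrative values for one latent dimension (hypothetical, not from the gist).
    print("closed form:", kl_closed_form(0.5, -0.3))
    print("Monte Carlo:", kl_monte_carlo(0.5, -0.3))
```

For these inputs the closed form evaluates to about 0.1994, and the Monte Carlo estimate should land close to it. An expression without the $\tfrac{1}{2}$ factor would give about twice that value, which would double the weight of the KL term relative to the other two terms.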