"""
Blog post:
Taming LSTMs: Variable-sized mini-batches and why PyTorch is good for your health:
https://medium.com/@_willfalcon/taming-lstms-variable-sized-mini-batches-and-why-pytorch-is-good-for-your-health-61d35642972e
"""
import torch


# note: in the blog post this is a method on the LSTM tagger model, hence `self`
def loss(self, Y_hat, Y, X_lengths):
    # TRICK 3 ********************************
    # before we calculate the negative log likelihood, we need to mask out the
    # activations so that the padded items in the output vector are not taken
    # into account. The simplest way to think about this is to flatten ALL
    # sequences into one REALLY long sequence and calculate the loss on that.

    # flatten all the labels
    Y = Y.view(-1)

    # flatten all predictions; Y_hat is expected to contain log-probabilities
    Y_hat = Y_hat.view(-1, self.nb_tags)

    # create a mask that is 1 for real tokens and 0 for padding
    # (this assumes the '<PAD>' tag is mapped to the lowest index)
    tag_pad_token = self.tags['<PAD>']
    mask = (Y > tag_pad_token).float()

    # count how many real (non-padded) tokens we have
    nb_tokens = int(torch.sum(mask).item())

    # pick the log-probability of the correct tag at each position and zero out
    # the padded positions with the mask
    Y_hat = Y_hat[range(Y_hat.shape[0]), Y] * mask

    # compute the cross entropy loss, which now ignores all <PAD> tokens
    ce_loss = -torch.sum(Y_hat) / nb_tokens

    return ce_loss
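

# ----------------------------------------------------------------------------
# Usage sketch (not part of the original gist): a minimal, standalone
# demonstration of the same masking arithmetic on dummy tensors. The toy
# shapes, the fake log-probabilities, and the assumption that '<PAD>' maps to
# index 0 are all illustrative choices, not taken from the blog post's model.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, max_seq_len, nb_tags = 2, 4, 5  # assumed toy dimensions

    # fake log-probabilities, as if produced by log_softmax over the tag set
    Y_hat = torch.randn(batch_size, max_seq_len, nb_tags).log_softmax(dim=-1)

    # fake labels; index 0 is treated as the '<PAD>' tag (assumption)
    Y = torch.tensor([[3, 1, 4, 0],
                      [2, 2, 0, 0]])

    # same steps as in loss(): flatten, mask out padding, average over real tokens
    Y_hat_flat = Y_hat.view(-1, nb_tags)
    Y_flat = Y.view(-1)
    mask = (Y_flat > 0).float()
    nb_tokens = int(mask.sum().item())
    picked = Y_hat_flat[range(Y_hat_flat.shape[0]), Y_flat] * mask
    ce_loss = -picked.sum() / nb_tokens
    print(f"masked cross entropy over {nb_tokens} real tokens: {ce_loss.item():.4f}")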