Skip to content

Instantly share code, notes, and snippets.

@mtreviso
Created February 24, 2019 22:29
Show Gist options
  • Save mtreviso/045cac05d38f2caf8127a20a902e4fb9 to your computer and use it in GitHub Desktop.
Save mtreviso/045cac05d38f2caf8127a20a902e4fb9 to your computer and use it in GitHub Desktop.
def _compute_scores(self, emissions, tags, mask):
    """Compute the score of the given tag sequence for each batch element.

    The score of one sequence is the sum of:
      * the transition from BOS to its first tag,
      * the emission score of every valid (unmasked) position,
      * the transition between each pair of consecutive tags whose second
        element is valid,
      * the transition from its last valid tag to EOS.

    Args:
        emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
        tags (torch.LongTensor): (batch_size, seq_len)
        mask (torch.Tensor): (batch_size, seq_len), non-zero on valid
            positions. Valid positions must be left-aligned (padding at
            the end) because the last valid index is derived from the
            number of valid positions. Position 0 is always scored, so
            every sequence needs at least one valid position.

    Returns:
        torch.Tensor: Scores for each batch.
            Shape of (batch_size,), with the dtype and device of
            `emissions`.
    """
    batch_size, seq_length = tags.shape

    # allocate on the same device / dtype as the emissions; a plain
    # torch.zeros(batch_size) lives on the CPU and would make the in-place
    # additions below fail as soon as the inputs are on a GPU
    scores = emissions.new_zeros(batch_size)

    # cast the mask once so that bool / int masks multiply cleanly with
    # the floating point scores inside the loop
    valid = mask.to(emissions.dtype)

    # save first and last tags to be used later
    first_tags = tags[:, 0]
    last_valid_idx = mask.long().sum(1) - 1
    # squeeze only the gathered dimension: a bare squeeze() would also drop
    # the batch dimension when batch_size == 1
    last_tags = tags.gather(1, last_valid_idx.unsqueeze(1)).squeeze(1)

    # transition from BOS to the first tag of each sequence
    t_scores = self.transitions[self.BOS_TAG_ID, first_tags]

    # [unary] emission score of the first tag of each sequence:
    # for every batch element b, pick emissions[b, 0, first_tags[b]]
    e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze(1)

    # the score for a word is just the sum of both scores
    scores += e_scores + t_scores

    # now do the same for each remaining word
    for i in range(1, seq_length):
        # instead of stopping each sequence at its first padded position,
        # stay vectorized over the batch and zero out padded positions
        # with an element-wise multiplication by the mask
        is_valid = valid[:, i]

        previous_tags = tags[:, i - 1]
        current_tags = tags[:, i]

        # emission and transition scores, computed as for the first word
        e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze(1)
        t_scores = self.transitions[previous_tags, current_tags]

        # apply the mask and accumulate
        scores += (e_scores + t_scores) * is_valid

    # add the transition from the last valid tag to EOS for each sequence
    scores += self.transitions[last_tags, self.EOS_TAG_ID]

    return scores
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment