Created
February 24, 2019 22:29
-
-
Save mtreviso/045cac05d38f2caf8127a20a902e4fb9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _compute_scores(self, emissions, tags, mask):
    """Compute the score of the given tag sequences for a batch of emissions.

    The score of one sequence is the sum of:
      * the transition from BOS to its first tag,
      * the emission score of the gold tag at every valid position,
      * the transition between every pair of consecutive valid tags,
      * the transition from its last valid tag to EOS.

    The first position is always scored, regardless of ``mask[:, 0]``.
    The last valid position is taken to be ``mask.sum(1) - 1``, which is
    only the true last position when the valid entries of each row are
    contiguous and start at index 0 (right-padded batches), and which
    requires at least one valid position per row.

    Args:
        emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
        tags (torch.LongTensor): (batch_size, seq_len)
        mask (torch.Tensor): (batch_size, seq_len), non-zero where the
            position is valid. Any dtype that can be cast to the
            emissions dtype is accepted.

    Returns:
        torch.Tensor: Scores for each sequence in the batch, with shape
            (batch_size,). It lives on the same device as the inputs and
            its dtype is the promotion of the emissions and transitions
            dtypes.
    """
    # cast once so that bool/byte/float masks all multiply cleanly
    valid = mask.to(emissions.dtype)

    # index of the last valid position of each sequence; gather expects
    # an int64 index, hence .long() rather than .int()
    last_valid_idx = mask.long().sum(1) - 1
    # squeeze only the gathered dim so a batch of size 1 keeps its shape
    last_tags = tags.gather(1, last_valid_idx.unsqueeze(1)).squeeze(1)

    # emission score of the gold tag at every position:
    # (batch_size, seq_len, nb_labels) -> (batch_size, seq_len)
    e_scores = emissions.gather(2, tags.unsqueeze(2)).squeeze(2)

    # first word: transition from BOS plus its (unmasked) emission score
    first_tags = tags[:, 0]
    scores = self.transitions[self.BOS_TAG_ID, first_tags] + e_scores[:, 0]

    # remaining words: transition from the previous tag to the current one
    # plus the current emission, zeroed out on padded positions.
    # For seq_len == 1 these slices are empty and the sum is zero.
    t_scores = self.transitions[tags[:, :-1], tags[:, 1:]]
    scores = scores + ((e_scores[:, 1:] + t_scores) * valid[:, 1:]).sum(1)

    # transition from the last valid tag to EOS
    scores = scores + self.transitions[last_tags, self.EOS_TAG_ID]
    return scores
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment