import torch

# NOTE: these are methods of a linear-chain CRF module; they rely on
# self.transitions (a (nb_labels, nb_labels) score matrix), self.BOS_TAG_ID
# and self.EOS_TAG_ID being defined on the owning class.


def decode(self, emissions, mask=None):
    """Find the most probable sequence of labels given the emissions using
    the Viterbi algorithm.

    Args:
        emissions (torch.Tensor): Sequence of emissions for each label.
            Shape (batch_size, seq_len, nb_labels) if batch_first is True,
            (seq_len, batch_size, nb_labels) otherwise.
        mask (torch.FloatTensor, optional): Tensor representing valid positions.
            If None, all positions are considered valid.
            Shape (batch_size, seq_len) if batch_first is True,
            (seq_len, batch_size) otherwise.

    Returns:
        torch.Tensor: the viterbi score for each batch.
            Shape of (batch_size,)
        list of lists: the best viterbi sequence of labels for each batch.
    """
    if mask is None:
        mask = torch.ones(emissions.shape[:2], dtype=torch.float, device=emissions.device)

    scores, sequences = self._viterbi_decode(emissions, mask)
    return scores, sequences


def _viterbi_decode(self, emissions, mask):
    """Compute the viterbi algorithm to find the most probable sequence of labels
    given a sequence of emissions.

    Args:
        emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
        mask (torch.FloatTensor): (batch_size, seq_len)

    Returns:
        torch.Tensor: the viterbi score for each batch.
            Shape of (batch_size,)
        list of lists of ints: the best viterbi sequence of labels for each batch.
    """
    batch_size, seq_length, nb_labels = emissions.shape

    # in the first iteration, the scores are the transitions from BOS
    # plus the emissions at the first timestep
    alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]

    backpointers = []

    for i in range(1, seq_length):
        alpha_t = []
        backpointers_t = []

        for tag in range(nb_labels):
            # get the emission for the current tag and broadcast to all labels
            e_scores = emissions[:, i, tag]
            e_scores = e_scores.unsqueeze(1)

            # transitions from something to our tag and broadcast to all batches
            t_scores = self.transitions[:, tag]
            t_scores = t_scores.unsqueeze(0)

            # combine current scores with previous alphas
            scores = e_scores + t_scores + alphas

            # so far this is exactly like the forward algorithm,
            # but now, instead of calculating the logsumexp,
            # we find the highest score and the tag associated with it
            max_score, max_score_tag = torch.max(scores, dim=-1)

            # add the max score for the current tag
            alpha_t.append(max_score)

            # add the max_score_tag to our list of backpointers
            backpointers_t.append(max_score_tag)

        # create a torch matrix from alpha_t
        # (bs, nb_labels)
        new_alphas = torch.stack(alpha_t).t()

        # update alphas where the mask is valid, otherwise keep the current values
        is_valid = mask[:, i].unsqueeze(-1)
        alphas = is_valid * new_alphas + (1 - is_valid) * alphas

        # append the new backpointers
        backpointers.append(backpointers_t)

    # add the scores for the final transition
    last_transition = self.transitions[:, self.EOS_TAG_ID]
    end_scores = alphas + last_transition.unsqueeze(0)

    # get the final most probable score and the final most probable tag
    max_final_scores, max_final_tags = torch.max(end_scores, dim=1)

    # follow the backpointers backwards to recover the best sequence of
    # labels for each sample, respecting its true (unpadded) length
    best_sequences = []
    emission_lengths = mask.int().sum(dim=1)
    for i in range(batch_size):
        # recover the original length of the i-th sample in the batch
        sample_length = emission_lengths[i].item()

        # the tag at the last valid timestep is the final argmax;
        # walk backwards from it through the valid backpointers
        best_tag = max_final_tags[i].item()
        sample_path = [best_tag]
        for backpointers_t in reversed(backpointers[: sample_length - 1]):
            best_tag = backpointers_t[best_tag][i].item()
            sample_path.insert(0, best_tag)

        best_sequences.append(sample_path)

    return max_final_scores, best_sequences
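
For context, `decode` and `_viterbi_decode` are meant to live on a CRF module. Below is a minimal usage sketch, assuming a hypothetical `CRF` wrapper class: the tag ids, the random `transitions` parameter, and the toy shapes are all illustrative, not part of the original gist.

import torch
import torch.nn as nn

# hypothetical tag set: 3 real tags plus BOS/EOS markers (ids 3 and 4)
nb_labels, bos_id, eos_id = 5, 3, 4

class CRF(nn.Module):
    """Illustrative container for the methods above (training would also
    require the forward/partition computation, which is omitted here)."""

    def __init__(self):
        super().__init__()
        self.BOS_TAG_ID = bos_id
        self.EOS_TAG_ID = eos_id
        # transitions[i, j]: score of moving from tag i to tag j
        self.transitions = nn.Parameter(torch.randn(nb_labels, nb_labels))

    # attach the functions above as methods
    decode = decode
    _viterbi_decode = _viterbi_decode

crf = CRF()
emissions = torch.randn(2, 7, nb_labels)  # (batch_size, seq_len, nb_labels)
mask = torch.ones(2, 7)
mask[1, 5:] = 0.0                         # second sample has true length 5
scores, paths = crf.decode(emissions, mask)
print(scores.shape)             # torch.Size([2])
print([len(p) for p in paths])  # [7, 5]

Note how the mask drives both the alpha updates (frozen past a sample's last valid position) and the backtracking (only the first `sample_length - 1` backpointers are followed), so padded positions never influence the decoded path.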