@mtreviso
Last active March 29, 2019 00:39
import torch

# NOTE: these functions are methods of a CRF module that defines
# self.transitions (a (nb_labels, nb_labels) parameter), self.BOS_TAG_ID
# and self.EOS_TAG_ID.

def decode(self, emissions, mask=None):
    """Find the most probable sequence of labels given the emissions using
    the Viterbi algorithm.

    Args:
        emissions (torch.Tensor): Sequence of emissions for each label.
            Shape (batch_size, seq_len, nb_labels) if batch_first is True,
            (seq_len, batch_size, nb_labels) otherwise.
        mask (torch.FloatTensor, optional): Tensor representing valid positions.
            If None, all positions are considered valid.
            Shape (batch_size, seq_len) if batch_first is True,
            (seq_len, batch_size) otherwise.

    Returns:
        torch.Tensor: the Viterbi score for each sample in the batch.
            Shape of (batch_size,)
        list of lists: the best Viterbi sequence of labels for each sample in the batch.
    """
    if mask is None:
        # consider all positions valid; keep the device consistent with emissions
        mask = torch.ones(emissions.shape[:2], dtype=torch.float, device=emissions.device)

    scores, sequences = self._viterbi_decode(emissions, mask)
    return scores, sequences

def _viterbi_decode(self, emissions, mask):
    """Compute the Viterbi algorithm to find the most probable sequence of labels
    given a sequence of emissions.

    Args:
        emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
        mask (torch.FloatTensor): (batch_size, seq_len)

    Returns:
        torch.Tensor: the Viterbi score for each sample in the batch.
            Shape of (batch_size,)
        list of lists of ints: the best Viterbi sequence of labels for each sample in the batch.
    """
    batch_size, seq_length, nb_labels = emissions.shape

    # in the first iteration, BOS has all of the score: the initial alphas are
    # the transitions leaving BOS plus the emissions at the first timestep
    alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]
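    # shape check: transitions[BOS_TAG_ID, :] is (nb_labels,), unsqueezed to
    # (1, nb_labels) so it broadcasts against emissions[:, 0] of shape
    # (batch_size, nb_labels); alphas is therefore (batch_size, nb_labels)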
    backpointers = []

    for i in range(1, seq_length):
        alpha_t = []
        backpointers_t = []

        for tag in range(nb_labels):

            # get the emission for the current tag and broadcast to all labels
            e_scores = emissions[:, i, tag]
            e_scores = e_scores.unsqueeze(1)

            # transitions from something to our tag and broadcast to all batches
            t_scores = self.transitions[:, tag]
            t_scores = t_scores.unsqueeze(0)

            # combine current scores with previous alphas
            scores = e_scores + t_scores + alphas
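            # shapes: e_scores (batch_size, 1) + t_scores (1, nb_labels)
            # + alphas (batch_size, nb_labels) -> scores (batch_size, nb_labels),
            # where column j holds the score of reaching `tag` from previous tag j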
            # so far this is exactly like the forward algorithm,
            # but now, instead of calculating the logsumexp,
            # we will find the highest score and the tag associated with it
            max_score, max_score_tag = torch.max(scores, dim=-1)
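            # max_score is (batch_size,): the best score of being at `tag` at
            # step i; max_score_tag is (batch_size,): the previous tag that
            # achieved it (the forward algorithm would take
            # torch.logsumexp(scores, dim=-1) here instead)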
            # add the max score for the current tag
            alpha_t.append(max_score)

            # add the max_score_tag to our list of backpointers
            backpointers_t.append(max_score_tag)

        # create a torch matrix from alpha_t
        # (bs, nb_labels)
        new_alphas = torch.stack(alpha_t).t()

        # set alphas if the mask is valid, otherwise keep the current values
        is_valid = mask[:, i].unsqueeze(-1)
        alphas = is_valid * new_alphas + (1 - is_valid) * alphas
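        # for padded timesteps the old alphas are carried forward, so each
        # sequence's last valid alphas are the ones that reach the EOS
        # transition below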
        # append the new backpointers
        backpointers.append(backpointers_t)

    # add the scores for the final transition
    last_transition = self.transitions[:, self.EOS_TAG_ID]
    end_scores = alphas + last_transition.unsqueeze(0)

    # get the final most probable score and the final most probable tag
    max_final_scores, max_final_tags = torch.max(end_scores, dim=1)
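
    # NOTE: the gist is truncated at this point; the lines below are a
    # hedged sketch of the usual backtracking step, not necessarily the
    # author's original code. It follows the backpointers built above,
    # from each sequence's last valid timestep back to the first one.
    best_sequences = []
    emission_lengths = mask.int().sum(dim=1)
    for i in range(batch_size):
        # recover the original length of the i-th sample in the batch
        sample_length = emission_lengths[i].item()

        # start from the tag that maximized the final (EOS) score
        best_tag = max_final_tags[i].item()
        best_path = [best_tag]

        # walk the backpointers from timestep sample_length - 1 back to 1;
        # backpointers[t - 1][tag][i] holds the best previous tag for
        # sample i being at `tag` on timestep t
        for backpointers_t in reversed(backpointers[: sample_length - 1]):
            best_tag = backpointers_t[best_tag][i].item()
            best_path.insert(0, best_tag)

        best_sequences.append(best_path)

    return max_final_scores, best_sequences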
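A minimal usage sketch (not part of the original gist), assuming these methods belong to a CRF nn.Module that defines self.transitions, self.BOS_TAG_ID and self.EOS_TAG_ID, and that batch_first is True. The constructor signature below is hypothetical:

crf = CRF(nb_labels=5, bos_tag_id=3, eos_tag_id=4)  # hypothetical constructor
emissions = torch.randn(2, 7, 5)  # (batch_size=2, seq_len=7, nb_labels=5)
mask = torch.ones(2, 7, dtype=torch.float)  # all positions valid
scores, sequences = crf.decode(emissions, mask)
# scores: tensor of shape (batch_size,) with the best path score per sample
# sequences: list of batch_size lists with the Viterbi label ids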