Skip to content

Instantly share code, notes, and snippets.

@mtreviso
Created February 24, 2019 23:01
Show Gist options
  • Save mtreviso/e14eeb29e8e1748185f46cda6dbcf645 to your computer and use it in GitHub Desktop.
Save mtreviso/e14eeb29e8e1748185f46cda6dbcf645 to your computer and use it in GitHub Desktop.
def _compute_log_partition(self, emissions, mask):
    """Compute the partition function in log-space using the forward-algorithm.

    The recursion keeps, for every batch item, one log-score per label
    (the "alphas"): the log of the summed exponentiated scores of all
    label sequences that end in that label at the current timestep.

    Args:
        emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
        mask (torch.Tensor): (batch_size, seq_len). Non-zero / True marks a
            valid timestep. Float, integer and boolean masks are accepted.
            The first timestep is always treated as valid, because
            ``emissions[:, 0]`` is consumed regardless of ``mask[:, 0]``.

    Returns:
        torch.Tensor: the partition scores for each batch.
            Shape of (batch_size,)
    """
    _, seq_length, _ = emissions.shape

    # transitions[prev, cur], broadcast over the batch: (1, nb_labels, nb_labels)
    transitions = self.transitions.unsqueeze(0)

    # in the first iteration, BOS will have all the scores
    # (bs, nb_labels)
    alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]

    for i in range(1, seq_length):
        # scores[b, prev, cur] = emission of `cur` at step i
        #                        + transition prev -> cur
        #                        + alpha of `prev` from the previous step.
        # Everything is in log space, so we add instead of multiplying.
        # (bs, 1, nb_labels) + (1, nb_labels, nb_labels) + (bs, nb_labels, 1)
        scores = emissions[:, i].unsqueeze(1) + transitions + alphas.unsqueeze(2)

        # marginalise over the previous label -> (bs, nb_labels)
        new_alphas = torch.logsumexp(scores, dim=1)

        # set alphas if the mask is valid, otherwise keep the current values.
        # torch.where *selects* values instead of blending them arithmetically:
        # `0 * -inf` is NaN, so a multiplicative mask would corrupt the result
        # as soon as any score is -inf (e.g. hard-forbidden transitions).
        is_valid = mask[:, i].bool().unsqueeze(-1)
        alphas = torch.where(is_valid, new_alphas, alphas)

    # add the scores for the final transition
    last_transition = self.transitions[:, self.EOS_TAG_ID]
    end_scores = alphas + last_transition.unsqueeze(0)

    # return a *log* of sums of exps
    return torch.logsumexp(end_scores, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment