Created
February 24, 2019 23:01
-
-
Save mtreviso/e14eeb29e8e1748185f46cda6dbcf645 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _compute_log_partition(self, emissions, mask):
    """Compute the partition function in log-space using the forward algorithm.

    The forward variable ``alphas[b, j]`` holds the log-sum-exp of the scores
    of all tag paths of batch item ``b`` that end in tag ``j`` at the current
    time step. It is advanced one time step at a time. On padded time steps
    (mask value 0 / False) it is carried over unchanged.

    Args:
        emissions (torch.Tensor): emission scores of shape
            (batch_size, seq_len, nb_labels).
        mask (torch.Tensor): (batch_size, seq_len) mask with 1 / True for real
            tokens and 0 / False for padding. Float, integer and bool masks
            are accepted; any non-zero value counts as valid. The first time
            step is always treated as valid (``mask[:, 0]`` is not read).

    Returns:
        torch.Tensor: the log partition score of each batch item,
            shape (batch_size,).
    """
    seq_length = emissions.shape[1]
    # (nb_labels, nb_labels); transitions[prev, cur] scores moving prev -> cur
    transitions = self.transitions

    # first step: every path starts with a transition out of BOS
    # (batch_size, nb_labels)
    alphas = transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]

    for i in range(1, seq_length):
        # scores[b, prev, cur] =
        #     emissions[b, i, cur] + transitions[prev, cur] + alphas[b, prev]
        # alphas are log-space, so scores are added instead of multiplied.
        scores = (
            emissions[:, i].unsqueeze(1)   # (batch_size, 1, nb_labels)
            + transitions.unsqueeze(0)     # (1, nb_labels, nb_labels)
            + alphas.unsqueeze(2)          # (batch_size, nb_labels, 1)
        )
        # marginalize the previous tag -> (batch_size, nb_labels)
        new_alphas = torch.logsumexp(scores, dim=1)

        # Keep the old alphas on padded positions. torch.where (rather than
        # mask arithmetic) avoids 0 * -inf = nan and also accepts bool masks.
        is_valid = mask[:, i].unsqueeze(-1).bool()
        alphas = torch.where(is_valid, new_alphas, alphas)

    # add the scores for the final transition into EOS
    last_transition = transitions[:, self.EOS_TAG_ID]
    end_scores = alphas + last_transition.unsqueeze(0)
    # return a *log* of sums of exps
    return torch.logsumexp(end_scores, dim=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment