Last active
March 29, 2019 00:39
-
-
Save mtreviso/c69d6d9c78db3f212ecc4ad58bb0ab2a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# find the best sequence of labels for each sample in the batch | |
best_sequences = [] | |
emission_lengths = mask.int().sum(dim=1) | |
for i in range(batch_size): | |
# recover the original sentence length for the i-th sample in the batch | |
sample_length = emission_lengths[i].item() | |
# recover the max tag for the last timestep | |
sample_final_tag = max_final_tags[i].item() | |
# limit the backpointers until the last but one | |
# since the last corresponds to the sample_final_tag | |
sample_backpointers = backpointers[: sample_length - 1] | |
# follow the backpointers to build the sequence of labels | |
sample_path = self._find_best_path(i, sample_final_tag, sample_backpointers) | |
# add this path to the list of best sequences | |
best_sequences.append(sample_path) | |
return max_final_scores, best_sequences | |
def _find_best_path(self, sample_id, best_tag, backpointers): | |
"""Auxiliary function to find the best path sequence for a specific sample. | |
Args: | |
sample_id (int): sample index in the range [0, batch_size) | |
best_tag (int): tag which maximizes the final score | |
backpointers (list of lists of tensors): list of pointers with | |
shape (seq_len_i-1, nb_labels, batch_size) where seq_len_i | |
represents the length of the ith sample in the batch | |
Returns: | |
list of ints: a list of tag indexes representing the bast path | |
""" | |
# add the final best_tag to our best path | |
best_path = [best_tag] | |
# traverse the backpointers in backwards | |
for backpointers_t in reversed(backpointers): | |
# recover the best_tag at this timestep | |
best_tag = backpointers_t[best_tag][sample_id].item() | |
# append to the beginning of the list so we don't need to reverse it later | |
best_path.insert(0, best_tag) | |
return best_path |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment