Skip to content

Instantly share code, notes, and snippets.

@mtreviso
Last active March 29, 2019 00:39
Show Gist options
  • Save mtreviso/c69d6d9c78db3f212ecc4ad58bb0ab2a to your computer and use it in GitHub Desktop.
Save mtreviso/c69d6d9c78db3f212ecc4ad58bb0ab2a to your computer and use it in GitHub Desktop.
# find the best sequence of labels for each sample in the batch
best_sequences = []
emission_lengths = mask.int().sum(dim=1)
for i in range(batch_size):
# recover the original sentence length for the i-th sample in the batch
sample_length = emission_lengths[i].item()
# recover the max tag for the last timestep
sample_final_tag = max_final_tags[i].item()
# limit the backpointers until the last but one
# since the last corresponds to the sample_final_tag
sample_backpointers = backpointers[: sample_length - 1]
# follow the backpointers to build the sequence of labels
sample_path = self._find_best_path(i, sample_final_tag, sample_backpointers)
# add this path to the list of best sequences
best_sequences.append(sample_path)
return max_final_scores, best_sequences
def _find_best_path(self, sample_id, best_tag, backpointers):
"""Auxiliary function to find the best path sequence for a specific sample.
Args:
sample_id (int): sample index in the range [0, batch_size)
best_tag (int): tag which maximizes the final score
backpointers (list of lists of tensors): list of pointers with
shape (seq_len_i-1, nb_labels, batch_size) where seq_len_i
represents the length of the ith sample in the batch
Returns:
list of ints: a list of tag indexes representing the bast path
"""
# add the final best_tag to our best path
best_path = [best_tag]
# traverse the backpointers in backwards
for backpointers_t in reversed(backpointers):
# recover the best_tag at this timestep
best_tag = backpointers_t[best_tag][sample_id].item()
# append to the beginning of the list so we don't need to reverse it later
best_path.insert(0, best_tag)
return best_path
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment