Multi-task losses computation
import torch
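# NOTE (assumed context): this snippet continues an earlier one which defined
# `persona`, `history`, the gold `reply`, its tokenized `words` and `segments`,
# the segment list `sequence`, the `tokenizer` and the helper `build_inputs`.
# A hypothetical recap, kept in comments because the authoritative definitions
# live in the previous snippet:
#
#   from pytorch_pretrained_bert import OpenAIGPTTokenizer
#   tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
#   tokenizer.set_special_tokens(['<bos>', '<eos>', '<speaker1>', '<speaker2>', '<pad>'])
#   persona = [["i", "like", "playing", "football", "."]]
#   history = [["hello", "how", "are", "you", "?"]]
#   reply = ["great", "to", "hear"]
#   words, segments, _, sequence = build_inputs(persona, history, reply)
#   words = tokenizer.convert_tokens_to_ids(words)
#   segments = tokenizer.convert_tokens_to_ids(segments)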
# Let's add a distractor to our previously defined persona, history and reply
distractor = ["sorry", "to", "hear", "that"]

# Build & tokenize inputs ending with our distractor like we did with the gold reply
words_distractor, segments_distractor, _, _ = build_inputs(persona, history, distractor)
words_distractor = tokenizer.convert_tokens_to_ids(words_distractor)
segments_distractor = tokenizer.convert_tokens_to_ids(segments_distractor)

# Prepare our language modeling targets: keep only the reply segment, -1 on the rest
lm_targets = ([-1] * sum(len(s) for s in sequence[:-1])) \
             + [-1] + tokenizer.convert_tokens_to_ids(sequence[-1][1:])
lm_distractor = [-1] * len(words_distractor)
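# Worked example with illustrative lengths: if `sequence` holds segments of
# lengths [10, 6, 4] (persona, history, then delimiter + 3-token reply),
# lm_targets is 10 + 6 = 16 times -1 (context masked from the LM loss),
# one more -1 for the reply's leading delimiter, then the ids of the 3 reply
# tokens; -1 entries are ignored by the cross-entropy loss (ignore_index=-1).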
# Store the position of the last tokens for the next-sentence prediction loss
last_token = len(words) - 1
last_token_distractor = len(words_distractor) - 1
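# (The multiple-choice head will read the hidden state at these last-token
# positions to score the gold reply against the distractor.)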
# Now we can pad reply and distractor inputs and targets to the same length
padding_length = max(len(words), len(words_distractor))

def pad(x, padding):
    return x + [padding] * (padding_length - len(x))

(words, words_distractor,
 segments, segments_distractor) = [pad(x, tokenizer.convert_tokens_to_ids('<pad>'))
                                   for x in (words, words_distractor,
                                             segments, segments_distractor)]
(lm_targets, lm_distractor) = [pad(x, -1) for x in (lm_targets, lm_distractor)]
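# Illustrative sanity check: every padded list now shares the same length
for seq in (words, words_distractor, segments, segments_distractor,
            lm_targets, lm_distractor):
    assert len(seq) == padding_length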
# And gather reply and distractor inputs to build the input tensors:
# word tokens
input_ids = torch.tensor([[words, words_distractor]], dtype=torch.long)
# segment tokens
token_type_ids = torch.tensor([[segments, segments_distractor]], dtype=torch.long)
# Position tokens can be created automatically by the model as (0, 1, ..., N)
# Last token locations
mc_token_ids = torch.tensor([[last_token, last_token_distractor]], dtype=torch.long)
# Language modeling labels
lm_labels = torch.tensor([[lm_targets, lm_distractor]], dtype=torch.long)
# Next-sentence prediction labels
mc_labels = torch.tensor([0], dtype=torch.long)  # Gold reply is first (index 0)
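# Forward-pass sketch (assumption: pytorch-pretrained-bert's
# OpenAIGPTDoubleHeadsModel, whose forward returns the language-modeling and
# multiple-choice losses when both label tensors are given; double-check the
# signature against the library version you use):
from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel

model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')
model.set_num_special_tokens(5)  # resize embeddings for the 5 special tokens assumed above
lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids)
# Weighted sum of the two losses; the coefficients here are illustrative
total_loss = 2.0 * lm_loss + 1.0 * mc_loss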