Mask each token out in a batch of sentences of variable length (PyTorch)
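The idea in miniature (a sketch; the mask index value of 0 and the toy token IDs are arbitrary, not taken from the gist):

# One sentence x = [5, 6, 7] (slen = 3) with mask_index = 0 expands into
# three masked copies:
#   [0, 6, 7]
#   [5, 0, 7]
#   [5, 6, 0]
# so every token becomes a prediction target exactly once; padded
# positions are excluded via `lengths`.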
import torch


def mask_all_out(self, x, lengths):
    """
    Returns tensors with every token masked out one at a time, respecting sequence lengths.
    Inputs:
        x       : (slen, bsz) LongTensor
        lengths : (bsz) LongTensor
    Returns:
        flat_x    : (slen, bsz*slen) LongTensor holding `slen` copies of each sentence,
                    each with a different position overwritten by the mask token
        _x_real   : (num_non_pad_tokens) LongTensor with the original token IDs of the
                    masked (non-pad) positions
        lengths   : (bsz*slen) LongTensor with each sentence's length repeated `slen` times
        pred_mask : (slen, bsz*slen) ByteTensor with 1 at every masked, non-pad position,
                    i.e. every position that needs a prediction
    """
    params = self.params
    slen, bsz = x.size()
    # seq_mask : bsz x word_token
    seq_mask = (torch.arange(slen, device=lengths.device).repeat(bsz, 1) < lengths.unsqueeze(1))
    # x : bsz x slen
    x = x.transpose(0, 1)
    # repeated_x : mask_token x bsz x word_token
    repeated_x = x.repeat(slen, 1, 1)
    # mask_token x bsz x word_token -> mask_token x word_token x bsz
    repeated_x = repeated_x.permute(0, 2, 1)
    # mask_selection : mask_token x word_token (True on the diagonal)
    mask_selection = torch.eye(slen, dtype=torch.bool, device=repeated_x.device)
    # _x_real : mask_token x bsz; keep this orientation so that the row-major
    # order of the extracted tokens matches the position-major order in which
    # the returned (slen, bsz*slen) pred_mask selects positions downstream
    _x_real = repeated_x[mask_selection]
    # _x_real : num_non_pad_tokens
    _x_real = _x_real[seq_mask.transpose(0, 1)]
    # overwrite each diagonal position with the mask token
    repeated_x[mask_selection] = params.mask_index
    # mask_token x word_token x bsz -> bsz x mask_token x word_token
    repeated_x = repeated_x.permute(2, 0, 1)
    # bsz x mask_token x word_token -> (bsz * mask_token) x word_token
    flat_x = repeated_x.reshape(-1, slen)
    # non_padded_mask : (bsz * mask_token) x word_token
    non_padded_mask = mask_selection.repeat(bsz, 1, 1).reshape(-1, slen)
    # pad_mask : (bsz * mask_token) x word_token
    pad_mask = seq_mask.repeat(slen, 1, 1).permute(1, 0, 2).reshape(-1, slen)
    # pred_mask : (bsz * mask_token) x word_token
    pred_mask = (non_padded_mask & pad_mask)
    # word_token x (bsz * mask_token)
    flat_x, pred_mask = flat_x.T, pred_mask.transpose(0, 1).byte()
    lengths = lengths.unsqueeze(dim=1).repeat(1, slen).reshape(-1)
    return flat_x, _x_real, lengths, pred_mask
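A minimal usage sketch, hedged: the gist defines mask_all_out as a method, so the stand-in `masker` namespace below (carrying only `params.mask_index`) and the toy token IDs are hypothetical, not from the gist.

import torch
from types import SimpleNamespace

# Stand-in for the host object; only params.mask_index is actually read.
masker = SimpleNamespace(params=SimpleNamespace(mask_index=0))

# Two sentences padded to slen = 4 (pad ID 1), with lengths 4 and 2.
x = torch.tensor([[5, 9],
                  [6, 7],
                  [8, 1],
                  [3, 1]])               # (slen, bsz)
lengths = torch.tensor([4, 2])

flat_x, x_real, rep_lengths, pred_mask = mask_all_out(masker, x, lengths)
print(flat_x.shape)     # torch.Size([4, 8])  -> (slen, bsz * slen)
print(x_real.shape)     # torch.Size([6])     -> 4 + 2 non-pad tokens
print(rep_lengths)      # tensor([4, 4, 4, 4, 2, 2, 2, 2])
print(pred_mask.sum())  # tensor(6)  -> one prediction per non-pad token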