Mask out each token, one at a time, in a batch of variable-length sentences (PyTorch)
import torch


# Method excerpted from a model class; `self.params.mask_index` holds the mask-token ID.
def mask_all_out(self, x, lengths):
    """
    Returns a tensor in which every token is masked out, one at a time, respecting sequence lengths.

    Inputs:
        x       : (slen, bsz) LongTensor of token IDs, padded to slen
        lengths : (bsz,) LongTensor of true sequence lengths
    Returns:
        flat_x    : (slen, bsz*slen) LongTensor; each sentence is repeated slen times,
                    with a different position replaced by the mask token in each copy
        _x_real   : (num_non_pad_tokens,) LongTensor with the IDs of the masked (non-pad) tokens
        lengths   : (bsz*slen,) LongTensor; each original length repeated slen times
        pred_mask : (slen, bsz*slen) ByteTensor; 1 wherever a real (non-pad) token was
                    masked and must be predicted
    """
    # Axis naming used in the shape comments below:
    #   mask_token = index of the position being masked (i.e. which copy of the sentence)
    #   word_token = position within the sentence (both axes have size slen)
    params = self.params
    slen, bsz = x.size()
    # seq_mask : bsz x word_token; True at non-pad positions
    seq_mask = (torch.arange(slen, device=lengths.device).repeat(bsz, 1) < lengths.unsqueeze(1))
    # x : bsz x slen
    x = x.transpose(0, 1)
    # repeated_x : mask_token x bsz x word_token
    repeated_x = x.repeat(slen, 1, 1)
    # mask_token x bsz x word_token -> mask_token x word_token x bsz
    repeated_x = repeated_x.permute(0, 2, 1)
    # mask_selection : mask_token x word_token; True on the diagonal, i.e. the position
    # to mask in each copy (note torch.bool, not the Python builtin `bool`)
    mask_selection = torch.eye(slen, dtype=torch.bool, device=repeated_x.device)
    # _x_real : bsz x word_token; original IDs at the masked positions
    _x_real = repeated_x[mask_selection].transpose(0, 1)
    # _x_real : (num_non_pad_tokens,); drop pad positions
    _x_real = _x_real[seq_mask]
    # overwrite the selected positions with the mask token
    repeated_x[mask_selection] = params.mask_index
    # mask_token x word_token x bsz -> bsz x mask_token x word_token
    repeated_x = repeated_x.permute(2, 0, 1)
    # bsz x mask_token x word_token -> (bsz*mask_token) x word_token
    flat_x = repeated_x.reshape(-1, slen)
    # non_padded_mask : (bsz*mask_token) x word_token; the masked position in each copy
    non_padded_mask = mask_selection.repeat(bsz, 1, 1).reshape(-1, slen)
    # pad_mask : (bsz*mask_token) x word_token; non-pad positions, per copy
    pad_mask = seq_mask.repeat(slen, 1, 1).permute(1, 0, 2).reshape(-1, slen)
    # pred_mask : (bsz*mask_token) x word_token; masked AND non-pad
    pred_mask = (non_padded_mask & pad_mask)
    # word_token x (bsz*mask_token)
    flat_x, pred_mask = flat_x.T, pred_mask.transpose(0, 1).byte()
    # each original length repeated slen times: (bsz*slen,)
    lengths = lengths.unsqueeze(dim=1).repeat(1, slen).reshape(-1)
    return flat_x, _x_real, lengths, pred_mask
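
A minimal usage sketch, assuming the method sits on an object whose `params.mask_index` holds the mask-token ID (the `holder` stand-in, pad ID 0, mask ID 4, and all token IDs below are made up for illustration):

import torch
from types import SimpleNamespace

# Hypothetical stand-in for the model: only `params.mask_index` is needed here.
holder = SimpleNamespace(params=SimpleNamespace(mask_index=4))

# Two sentences padded to slen=3: lengths 3 and 2, pad ID 0.
x = torch.tensor([[5, 7],
                  [6, 8],
                  [9, 0]])          # (slen=3, bsz=2)
lengths = torch.tensor([3, 2])

flat_x, x_real, rep_lengths, pred_mask = mask_all_out(holder, x, lengths)

print(flat_x.shape)         # torch.Size([3, 6]) -- 2 sentences x 3 masked copies each
print(x_real)               # tensor([5, 6, 9, 7, 8]) -- the 5 non-pad tokens, in mask order
print(rep_lengths)          # tensor([3, 3, 3, 2, 2, 2])
print(int(pred_mask.sum())) # 5 -- one prediction target per real token

Note that a copy whose masked position falls inside the padding (the third copy of the second sentence above) is still present in flat_x, but contributes no prediction target: pred_mask is zero for that entire copy.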