NOTE: This is a question I found on StackOverflow which I’ve archived here, because the answer is so effing phenomenal.
If you are not into long explanations, see [Paolo Bergantino’s answer][2].
import torch
import torch.nn as nn
import torch.nn.functional as F


def lengths_to_mask(lengths, max_len=None, dtype=None):
    """
    Converts a "lengths" tensor to its binary mask representation.
    Based on: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397