Bálint Turi tubali12345

import torch
import torch.nn as nn
import torch.nn.functional as F


def lengths_to_mask(lengths, max_len=None, dtype=None):
    """
    Converts a "lengths" tensor to its binary mask representation.

    Based on: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397