Created May 6, 2021 13:47
import numpy as np
import torch

def collate_fn(data):
    '''
    We should build a custom collate_fn rather than using the default collate_fn,
    because every sentence has a different length and the default collate_fn
    does not support merging variable-length sequences (with padding).
    Args:
        data: list of tuples (training sequence, label)
    Returns:
        padded_seq - padded sequences, tensor of shape (batch_size, padded_length)
        length - original length of each sequence (without padding), tensor of shape (batch_size)
        label - tensor of shape (batch_size)
    '''
    # Sorting in decreasing order of length is required for pack_padded_sequence
    # (used in the model).
    data.sort(key=lambda x: len(x[0]), reverse=True)
    sequences, label = zip(*data)
    length = [len(seq) for seq in sequences]
    # Zero-initialized tensor wide enough for the longest sequence in the batch.
    padded_seq = torch.zeros(len(sequences), max(length)).long()
    for i, seq in enumerate(sequences):
        end = length[i]
        padded_seq[i, :end] = seq
    return padded_seq, torch.from_numpy(np.array(length)), torch.from_numpy(np.array(label))
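A minimal sketch of how the outputs above would typically feed into `pack_padded_sequence` downstream (the reason the batch must be sorted in decreasing length). The `Embedding(10, 4)` and `GRU(4, 8)` sizes here are illustrative assumptions, not part of the original model:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# A padded batch as collate_fn would produce it: sorted by length, zero-padded.
padded_seq = torch.tensor([[1, 2, 3], [4, 5, 0]])
length = torch.tensor([3, 2])

# Embed the token ids, then pack so the RNN skips the padding positions.
# Embedding/GRU sizes are arbitrary for this sketch.
embedded = torch.nn.Embedding(10, 4)(padded_seq)   # (batch, max_len, emb_dim)
packed = pack_padded_sequence(embedded, length, batch_first=True)
output, _ = torch.nn.GRU(4, 8, batch_first=True)(packed)

# Unpack back to a padded tensor; the original lengths are recovered.
unpacked, out_lengths = pad_packed_sequence(output, batch_first=True)
# unpacked has shape (batch, max_len, hidden) = (2, 3, 8)
```

Because `collate_fn` already sorts the batch, `pack_padded_sequence` can be called with its default `enforce_sorted=True`.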