
@harsh-99
Created May 6, 2021 13:47
import numpy as np
import torch


def collate_fn(data):
    '''
    We need a custom collate_fn rather than the default one because every
    sentence has a different length, and the default collate_fn does not
    support merging variable-length sequences (including padding).

    Args:
        data: list of tuples (training sequence, label)
    Returns:
        padded_seq: padded sequences, tensor of shape (batch_size, padded_length)
        length: original length of each sequence (without padding), tensor of shape (batch_size)
        label: tensor of shape (batch_size)
    '''
    # Sorting in decreasing order of length is required by pack_padded_sequence
    # (used in the model).
    data.sort(key=lambda x: len(x[0]), reverse=True)
    sequences, label = zip(*data)
    length = [len(seq) for seq in sequences]
    # Allocate a zero-filled (batch_size, max_length) tensor and copy each
    # sequence into its row; positions past the sequence end stay zero (padding).
    padded_seq = torch.zeros(len(sequences), max(length)).long()
    for i, seq in enumerate(sequences):
        end = length[i]
        padded_seq[i, :end] = seq
    return padded_seq, torch.from_numpy(np.array(length)), torch.from_numpy(np.array(label))
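A minimal usage sketch of how this collate_fn could be wired into a `DataLoader`. The toy dataset and its values below are illustrative, not from the gist:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader


def collate_fn(data):
    # Same logic as the gist's collate_fn: sort by decreasing length, then pad.
    data.sort(key=lambda x: len(x[0]), reverse=True)
    sequences, label = zip(*data)
    length = [len(seq) for seq in sequences]
    padded_seq = torch.zeros(len(sequences), max(length)).long()
    for i, seq in enumerate(sequences):
        padded_seq[i, :length[i]] = seq
    return padded_seq, torch.from_numpy(np.array(length)), torch.from_numpy(np.array(label))


# Hypothetical toy dataset: variable-length integer sequences with labels.
dataset = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6, 7, 8, 9]), 0),
]

loader = DataLoader(dataset, batch_size=3, collate_fn=collate_fn)
padded_seq, length, label = next(iter(loader))
# padded_seq has shape (3, 4); rows are sorted by decreasing length,
# so the first row is [6, 7, 8, 9] and shorter rows are zero-padded.
```

Because the batch is sorted in decreasing length order, the returned `length` tensor can be passed directly to `torch.nn.utils.rnn.pack_padded_sequence` (with `batch_first=True`) inside the model, which is why the sort matters.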