Created
December 16, 2018 17:08
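import torch


def collate_fn(data):
    """
    default_collate handles images and similar fixed-size data well, but it
    does not support variable-length sequences such as text.
    Pass this function as the collate_fn argument of DataLoader.
    Input: data is a list of (x, y) pairs; its length is batch_size.
    Returns xs, ys, lens, where xs and ys are tensors of shape
    (batch_size, size of x or y) and lens holds the unpadded length of each y.
    """
    # Sort the samples by target length, longest first.
    data.sort(key=lambda x: len(x[1]), reverse=True)
    xs, ys = zip(*data)
    # Merge xs: every x has the same shape, so they can be stacked directly.
    xs = torch.stack(xs, 0)
    # Merge ys: allocate a zero-filled (batch_size, max_len) LongTensor.
    lens = [len(e) for e in ys]
    ys_padded = torch.zeros(len(ys), max(lens)).long()
    # Padding: copy each sequence into its row; the remaining positions stay 0.
    for i, e in enumerate(ys):
        end = lens[i]
        ys_padded[i, :end] = torch.as_tensor(e)
    return xs, ys_padded, lens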
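
The sort-and-pad step can be illustrated without PyTorch. The following is a minimal sketch on plain Python lists, not the gist author's code: the function name `sort_and_pad` and the `pad_value` parameter are hypothetical, introduced only for this example.

```python
def sort_and_pad(batch, pad_value=0):
    """Sort (x, y) pairs by len(y), longest first, and right-pad every y.

    Hypothetical helper for illustration; it mirrors the logic of collate_fn
    but returns plain lists instead of tensors.
    """
    ordered = sorted(batch, key=lambda pair: len(pair[1]), reverse=True)
    lens = [len(y) for _, y in ordered]
    max_len = max(lens)
    xs = [x for x, _ in ordered]
    # Each y is extended with pad_value up to the longest length in the batch.
    ys = [list(y) + [pad_value] * (max_len - len(y)) for _, y in ordered]
    return xs, ys, lens


batch = [("a", [1, 2]), ("b", [3, 4, 5, 6]), ("c", [7])]
xs, ys, lens = sort_and_pad(batch)
print(xs)    # ['b', 'a', 'c']
print(ys)    # [[3, 4, 5, 6], [1, 2, 0, 0], [7, 0, 0, 0]]
print(lens)  # [4, 2, 1]
```

Returning the original lengths alongside the padded batch lets downstream code ignore the padding. The longest-first ordering is presumably there because `torch.nn.utils.rnn.pack_padded_sequence` expects sequences sorted by decreasing length by default.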