@foowaa
Created December 16, 2018 17:08
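import torch


def collate_fn(data):
    """
    default_collate handles images and other fixed-size data well, but it does
    not support variable-length sequences such as text.
    Pass this function as the collate_fn argument of DataLoader.

    Input: data is a list of (x, y) pairs; the list length is batch_size.
    Returns: xs, ys, lens
        xs, ys: tensors of shape (batch_size, size of x or padded size of y)
        lens: original length of each y, in sorted (descending) order
    """
    # Sort the batch by target length, longest first (without mutating the input).
    data = sorted(data, key=lambda pair: len(pair[1]), reverse=True)
    xs, ys = zip(*data)
    # Merge xs: every x has the same shape, so they can be stacked directly.
    xs = torch.stack(xs, 0)
    # Merge ys: allocate a zero-filled buffer sized to the longest sequence.
    lens = [len(y) for y in ys]
    padded = torch.zeros(len(ys), max(lens), dtype=torch.long)
    # Padding: copy each sequence into its row; the tail stays 0.
    for i, y in enumerate(ys):
        end = lens[i]
        padded[i, :end] = torch.as_tensor(y, dtype=torch.long)[:end]
    return xs, padded, lens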
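
A minimal usage sketch, assuming each x is a fixed-size tensor and each y is a 1-D integer tensor of varying length. The toy `dataset` and its values are hypothetical and only there to show the shapes that come out of the `DataLoader`. The block restates `collate_fn` so it runs on its own.

```python
import torch
from torch.utils.data import DataLoader


def collate_fn(data):
    # Sort by target length (longest first), stack xs, zero-pad ys.
    data = sorted(data, key=lambda pair: len(pair[1]), reverse=True)
    xs, ys = zip(*data)
    xs = torch.stack(xs, 0)
    lens = [len(y) for y in ys]
    padded = torch.zeros(len(ys), max(lens), dtype=torch.long)
    for i, y in enumerate(ys):
        padded[i, :lens[i]] = torch.as_tensor(y, dtype=torch.long)
    return xs, padded, lens


# Hypothetical toy data: three (x, y) pairs with targets of length 2, 4 and 1.
dataset = [
    (torch.ones(3), torch.tensor([1, 2])),
    (torch.ones(3) * 2, torch.tensor([3, 4, 5, 6])),
    (torch.ones(3) * 3, torch.tensor([7])),
]

loader = DataLoader(dataset, batch_size=3, shuffle=False, collate_fn=collate_fn)
xs, ys, lens = next(iter(loader))

print(xs.shape)  # torch.Size([3, 3])
print(ys)        # rows ordered longest first, shorter rows zero-padded
print(lens)      # [4, 2, 1]
```

Because the batch is sorted longest first, `lens` can be passed to utilities that expect descending lengths, such as `torch.nn.utils.rnn.pack_padded_sequence` with its default `enforce_sorted=True`. Note that 0 doubles as the padding value here, so it should not also be a meaningful token id.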