@subhadarship
Created February 27, 2020 04:19
collate_fn for PyTorch DataLoader
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np


class MyDataset(Dataset):
    def __init__(self):
        x = np.random.rand(1000, 3)  # 1000 3-dim samples
        # Each sample is stored as a plain Python list of 3 floats
        self.x = [x[i].tolist() for i in range(1000)]
        y = np.random.randint(low=0, high=2, size=(1000,))  # binary labels (0 or 1)
        self.y = [y[i] for i in range(1000)]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # Return one (features, label) pair
        return self.x[idx], self.y[idx]


def collate_fn(batch):
    # batch is a list of (features, label) pairs, one per sampled index
    data_list, label_list = [], []
    for _data, _label in batch:
        data_list.append(_data)
        label_list.append(_label)
    # Features -> float32 tensor of shape (batch_size, 3);
    # labels -> int64 tensor of shape (batch_size,)
    return torch.tensor(data_list, dtype=torch.float32), torch.tensor(label_list, dtype=torch.long)


if __name__ == "__main__":
    dataset = MyDataset()
    print(len(dataset))  # 1000
    print(dataset[-1])   # last (features, label) pair
    dataloader = DataLoader(dataset, batch_size=3, shuffle=False, collate_fn=collate_fn)
    for data, label in dataloader:
        print(type(data))
        print(data)
        print(type(label))
        print(label)
        break  # only inspect the first batch
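To see why a custom collate_fn is useful here, both it and PyTorch's default collation can be called directly on a hand-built batch, i.e. the list of (features, label) pairs that DataLoader passes to collate_fn. This is a minimal sketch: the two sample values are made up for illustration, and collate_fn is repeated so the snippet runs on its own. Because each sample's features are a plain Python list, the default collation zips them feature by feature into a list of 3 tensors instead of one (batch_size, 3) tensor.

```python
import torch
from torch.utils.data.dataloader import default_collate


def collate_fn(batch):
    # Same idea as the gist's collate_fn: stack features and labels separately
    data_list, label_list = [], []
    for _data, _label in batch:
        data_list.append(_data)
        label_list.append(_label)
    return torch.tensor(data_list, dtype=torch.float32), torch.tensor(label_list, dtype=torch.long)


# A hand-built batch of two (features, label) pairs (illustrative values)
batch = [([0.1, 0.2, 0.3], 0), ([0.4, 0.5, 0.6], 1)]

data, label = collate_fn(batch)
print(data.shape)  # torch.Size([2, 3])
print(label)       # tensor([0, 1])

# The default collation transposes the inner lists feature by feature,
# giving a list of 3 tensors of shape (2,) rather than one (2, 3) tensor
default_data, default_label = default_collate(batch)
print(len(default_data), default_data[0].shape)  # 3 torch.Size([2])
print(default_label)                             # tensor([0, 1])
```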

subhadarship commented Feb 27, 2020

Explicitly defining collate_fn is not required if self.x and self.y are already NumPy arrays. With collate_fn=None (the DataLoader default), the default collation returns two tensors for each batch: the stacked features and the labels.
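A minimal sketch of that variant, assuming a hypothetical NumpyDataset class (not part of the gist) that keeps x and y as NumPy arrays. One difference worth noting: np.random.rand produces float64 values, so the default collation yields a float64 feature tensor, unlike the float32 tensor built by the custom collate_fn; calling .float() converts it if the model expects float32.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class NumpyDataset(Dataset):
    """Hypothetical variant of MyDataset that keeps x and y as NumPy arrays."""

    def __init__(self, n=1000):
        self.x = np.random.rand(n, 3)  # float64 array of shape (n, 3)
        # dtype is pinned to int64 so labels collate to torch.int64 on every platform
        self.y = np.random.randint(low=0, high=2, size=(n,), dtype=np.int64)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # Returns a (3,) float64 array and an int64 NumPy scalar
        return self.x[idx], self.y[idx]


dataset = NumpyDataset()
# No collate_fn argument: DataLoader falls back to its default collation,
# which stacks the per-sample arrays and scalars into tensors
dataloader = DataLoader(dataset, batch_size=3, shuffle=False)
data, label = next(iter(dataloader))
print(data.shape, data.dtype)    # torch.Size([3, 3]) torch.float64
print(label.shape, label.dtype)  # torch.Size([3]) torch.int64
data = data.float()  # convert to float32 if the model expects it
```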

@alexzhwqc

Thank you very much!
I am learning PyTorch and didn't know how to use DataLoader to feed 15002828 data samples into the model. Your example taught me how to do that. Thank you a thousand times.
