Skip to content

Instantly share code, notes, and snippets.

@wolterlw
Created June 13, 2018 12:47
Show Gist options
Save wolterlw/34f6542bdfcaedb080afab069b73df61 to your computer and use it in GitHub Desktop.
PyTorch dataset template
import torch


class CustomDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset template that works out of the box.

    By default it serves ``(sample, label)`` pairs held in memory. To load
    from disk instead, store file paths in ``__init__`` and read/decode one
    item in ``__getitem__`` (e.g. with ``numpy.fromfile`` or
    ``PIL.Image.open``).

    Args:
        samples: Iterable of ``(sample, label)`` pairs, e.g.
            ``zip(images, labels)``. Defaults to an empty dataset, so
            ``CustomDataset()`` behaves like the original stub (length 0).
        transform: Optional callable applied to each sample before it is
            returned (e.g. a ``torchvision`` transform).
    """

    def __init__(self, samples=None, transform=None):
        # TODO: initialise file paths or a list of file names here when
        # loading lazily from disk instead of from memory.
        # Materialise into a list so one-shot iterables (zip, generators)
        # are supported and the length stays fixed once constructed.
        self.samples = [] if samples is None else list(samples)
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at *index*.

        Raises:
            IndexError: if *index* is out of range. Raising, instead of
                returning None, lets plain ``for x in dataset`` iteration
                stop, because ``Dataset`` has no ``__iter__`` of its own.
        """
        # TODO: 1. read one item from file, 2. preprocess it,
        # 3. return a data pair (e.g. image and label).
        sample, label = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        """Return the number of samples; DataLoader samplers rely on it."""
        return len(self.samples)
# Instantiate the dataset, then hand it to PyTorch's built-in DataLoader,
# which takes care of batching and reshuffling the samples every epoch.
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(
    dataset=custom_dataset,
    batch_size=64,
    shuffle=True,
)

# Walk the mini-batches; each one unpacks into inputs and targets.
for images, labels in train_loader:
    # Put the forward pass, loss computation and optimizer step here.
    ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment