Created
June 13, 2018 12:47
-
-
Save wolterlw/34f6542bdfcaedb080afab069b73df61 to your computer and use it in GitHub Desktop.
PyTorch dataset template
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CustomDataset(torch.utils.data.Dataset):
    """Skeleton for a custom map-style PyTorch dataset.

    This is a template: the three methods below are placeholders that must
    be filled in before the dataset yields any data. As written, the
    dataset reports a length of 0 and ``__getitem__`` returns ``None``.
    """

    def __init__(self) -> None:
        """Set up whatever is needed to locate the samples."""
        super().__init__()
        # TODO
        # 1. Initialize file paths or a list of file names.

    def __getitem__(self, index):
        """Return the sample stored at position ``index``.

        Placeholder implementation: currently returns ``None``.
        """
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        # You should change 0 to the total size of your dataset.
        return 0
# You can then use the prebuilt data loader.
# The loader is configured to draw shuffled batches of 64 samples from the
# dataset.
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(
    dataset=custom_dataset,
    batch_size=64,
    shuffle=True,
)

# NOTE(review): while the dataset's placeholder __len__ still returns 0 this
# loop has nothing to iterate over; some PyTorch versions may also reject
# shuffling an empty dataset -- confirm once the dataset is implemented.
for images, labels in train_loader:
    # Training code should be written here.
    pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment