Last active
October 25, 2021 12:31
-
-
Save maurapintor/25a6d80f9f86d36f72a4b2cc8540008f to your computer and use it in GitHub Desktop.
Loads the ImageNet dataset with torchvision and stores a subset of the samples and labels in a target HDF5 file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import h5py as h5py | |
import numpy as np | |
import torch.utils.data | |
from torchvision import datasets, transforms | |
# Root directory of the extracted ImageNet dataset (read with datasets.ImageFolder below).
source_path = 'path-to-the-extracted-imagenet-dataset'
# Destination HDF5 file for the stored subset.
destination_path = 'destination-hdf5-file'

# Subset size: at most batch_size * num_batches samples end up in the output
# (the final batch may be smaller than batch_size).
batch_size = 5
num_batches = 10

# Preprocessing pipeline applied to every image; swap in your own transforms here.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

dataset = datasets.ImageFolder(root=source_path, transform=transform)
# Shuffle so the stored subset is drawn at random rather than in the dataset's fixed order.
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
# Stored dtype of the image tensors. float64 is kept to preserve this script's
# original output format; transforms.ToTensor() yields float32, so np.float32
# would store the same values in half the space.
samples_dtype = np.float64
# Labels are integer class indices, so store them as integers rather than
# floats (readers can use them as targets/indices without casting back).
labels_dtype = np.int64


def _append(dset, batch):
    """Grow the resizable HDF5 dataset ``dset`` along axis 0 and write ``batch`` into the new rows."""
    start = dset.shape[0]
    dset.resize(start + batch.shape[0], axis=0)
    dset[start:] = batch


with h5py.File(destination_path, 'w') as hf:
    for i, (samples, labels) in enumerate(loader):
        if i == num_batches:
            break
        # No device transfer is done above, so the batch tensors are on the CPU;
        # convert explicitly instead of relying on h5py's implicit conversion.
        samples = samples.numpy()
        labels = labels.numpy()
        if i == 0:
            # Create extendable (maxshape None on axis 0) chunked datasets from the
            # first batch. The per-sample shape comes from the data instead of a
            # hard-coded (3, 224, 224), so custom transforms producing a different
            # image size still work.
            hf.create_dataset(
                'samples',
                data=samples,
                dtype=samples_dtype,
                chunks=True,
                maxshape=(None,) + samples.shape[1:],
            )
            hf.create_dataset(
                'labels',
                data=labels,
                dtype=labels_dtype,
                chunks=True,
                maxshape=(None,),
            )
        else:
            # Subsequent batches (the last one may be shorter) are appended.
            _append(hf['samples'], samples)
            _append(hf['labels'], labels)
# Leaving the ``with`` block closes the file, which flushes all pending writes.
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment