Skip to content

Instantly share code, notes, and snippets.

@maurapintor
Last active October 25, 2021 12:31
Show Gist options
  • Save maurapintor/25a6d80f9f86d36f72a4b2cc8540008f to your computer and use it in GitHub Desktop.
Save maurapintor/25a6d80f9f86d36f72a4b2cc8540008f to your computer and use it in GitHub Desktop.
Loads the ImageNet dataset with torchvision and stores a subset of its samples and labels in a target HDF5 file.
import itertools

import h5py as h5py
import numpy as np
import torch.utils.data
from torchvision import datasets, transforms
# Input/output locations: edit these before running.
source_path = 'path-to-the-extracted-imagenet-dataset'
destination_path = 'destination-hdf5-file'

# Export size: num_batches batches of batch_size images each.
batch_size = 5
num_batches = 10

# Per-image preprocessing; replace this pipeline with your own if needed.
# Resize(256) rescales the image (shorter side to 256 px), CenterCrop(224)
# keeps the central 224x224 patch, and ToTensor() converts it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)
# Build the image dataset from the folder tree under source_path and wrap it
# in a shuffling loader, so the first num_batches batches are a random subset
# rather than images taken in the dataset's own (on-disk) order.
dataset = datasets.ImageFolder(root=source_path, transform=transform)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)
# Storage dtypes for the HDF5 datasets. ToTensor() produces float32 pixels,
# so float32 stores them without loss at half the size of float64. Labels are
# integer class indices and are kept as integers, so readers do not have to
# cast them back before using them as class targets.
samples_dtype = np.float32
labels_dtype = np.int64


def _append_rows(dset, rows):
    """Grow ``dset`` along axis 0 and write ``rows`` into the new tail.

    ``dset`` must have been created with an unlimited first dimension
    (``maxshape=(None, ...)``), and ``rows`` must match its trailing shape.
    h5py casts ``rows`` to the dataset's dtype on assignment.
    """
    start = dset.shape[0]
    dset.resize(start + rows.shape[0], axis=0)
    # Index from the old length instead of ``-len(rows)``: a zero-row batch
    # would turn that into ``[-0:]``, which addresses the whole dataset.
    dset[start:] = rows


with h5py.File(destination_path, 'w') as hf:
    # islice stops after num_batches batches without pulling (and decoding)
    # one extra batch from the loader, as an enumerate/break loop would.
    for samples, labels in itertools.islice(loader, num_batches):
        samples = samples.numpy()
        labels = labels.numpy()
        if 'samples' not in hf:
            # First batch: create chunked, resizable datasets. The trailing
            # sample shape is taken from the data itself, so changing the
            # transform (e.g. a different crop size) needs no edit here.
            hf.create_dataset('samples', data=samples, dtype=samples_dtype,
                              chunks=True, maxshape=(None, *samples.shape[1:]))
            hf.create_dataset('labels', data=labels, dtype=labels_dtype,
                              chunks=True, maxshape=(None,))
        else:
            _append_rows(hf['samples'], samples)
            _append_rows(hf['labels'], labels)
        # Flush after every batch rather than only when the file is closed.
        hf.flush()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment