Skip to content

Instantly share code, notes, and snippets.

@Delaunay
Created December 21, 2018 16:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Delaunay/c0ac3322e3ec5ec3760a0f9edd97b6b5 to your computer and use it in GitHub Desktop.
Save Delaunay/c0ac3322e3ec5ec3760a0f9edd97b6b5 to your computer and use it in GitHub Desktop.
HDF5 POC
import torchvision
import torchvision.transforms as transforms
import time
import torch
import os
import h5py
import numpy as np
from MixedPrecision.tools.stats import StatStream
def _label_dtype(num_classes):
    """Return the on-disk dtype for the ``label`` dataset.

    ``uint8`` keeps the original compact layout whenever it can represent
    every label; with more than 256 classes it would overflow, so fall back
    to ``int64`` (the dtype PyTorch classification losses expect).
    """
    return np.uint8 if num_classes <= 256 else np.int64


def _write_class_names(output, class_to_idx):
    """Create the ``classes`` dataset: entry ``i`` is the UTF-8 name of class ``i``."""
    names = [b''] * len(class_to_idx)
    for name, index in class_to_idx.items():
        names[index] = name.encode('utf-8')
    # Fixed-width byte strings sized to the longest name (at least 1 byte), so
    # names are never truncated and the dataset has one entry per class.
    width = max([len(name) for name in names] + [1])
    output.create_dataset('classes', data=np.array(names, dtype='S{}'.format(width)))


def preprocess_to_hdf5(transform, input_folder: str, output_file: str):
    """Convert an ``ImageFolder`` directory tree into a single HDF5 file.

    The output file holds three datasets:

    * ``classes`` -- fixed-width UTF-8 byte strings, index -> class folder name;
    * ``label``   -- one integer label per sample;
    * ``data``    -- ``(n, 3, 256, 256)`` uint8 images stored channel-first,
      chunked one sample per chunk so a random read touches a single chunk.

    Load and save throughput are printed every 100 samples and at the end.

    Args:
        transform: applied by ``ImageFolder`` to each image. It must produce
            something ``np.array`` turns into a ``(256, 256, 3)`` uint8 array
            (e.g. a PIL image resized/cropped to 256x256), otherwise the write
            into ``data`` fails.
        input_folder: root directory in ``ImageFolder`` layout.
        output_file: path of the HDF5 file to create (truncated if it exists).
    """
    train_dataset = torchvision.datasets.ImageFolder(input_folder, transform)
    class_to_idx = train_dataset.class_to_idx
    n = len(train_dataset)

    load_time = StatStream(10)
    save_time = StatStream(10)

    # The context manager guarantees the file is closed even if conversion fails.
    with h5py.File(output_file, 'w', libver='latest') as output:
        _write_class_names(output, class_to_idx)

        hdy = output.create_dataset('label', (n,), dtype=_label_dtype(len(class_to_idx)))
        hdx = output.create_dataset(
            'data',
            (n, 3, 256, 256),
            dtype=np.uint8,
            # One chunk per sample for fast random retrieval.
            chunks=(1, 3, 256, 256),
            # Compression (e.g. compression='lzf') is intentionally left off:
            # the original author observed it interfering with parallel readers.
        )

        print('Converting...')
        start = time.perf_counter()
        for index, (x, y) in enumerate(train_dataset):
            loaded = time.perf_counter()
            load_time += loaded - start

            hdy[index] = y
            # (H, W, C) uint8 image -> channel-first (C, H, W) to match 'data'.
            hdx[index] = np.moveaxis(np.array(x, dtype=np.uint8), -1, 0)
            save_time += time.perf_counter() - loaded

            # Averages can still be 0 early on; skip the report rather than
            # divide by zero (both averages are inverted below).
            if index % 100 == 0 and load_time.avg > 0 and save_time.avg > 0:
                print('{:.4f} % Load[avg: {:.4f} img/s sd: {:.4f}] Save[avg: {:.4f} img/s sd: {:.4f}]'.format(
                    index * 100 / n, 1 / load_time.avg, load_time.sd, 1 / save_time.avg, save_time.sd))
            start = time.perf_counter()

    # Tiny datasets may never produce a non-zero average.
    if load_time.avg > 0:
        print('{:.4f} img/s'.format(1 / load_time.avg))
class HDF5Dataset(torch.utils.data.Dataset):
    """Map-style dataset reading the ``data`` and ``label`` datasets of an HDF5 file.

    The HDF5 handle is opened lazily and at most once per process. h5py file
    objects cannot be pickled, so a ``DataLoader`` with spawned workers could
    not ship an eagerly-opened dataset. An HDF5 handle should also not be
    shared with forked workers; each process opens its own instead.
    Consequently ``file``, ``labels`` and ``samples`` are ``None`` until the
    first item is read in the current process.

    Args:
        file_name: path to the HDF5 file.
        transform: optional callable applied to each sample array. For files
            written by ``preprocess_to_hdf5`` a sample is a channel-first uint8
            array.
        target_transform: optional callable applied to each label.
    """

    def __init__(self, file_name: str, transform=None, target_transform=None):
        self.file_name = file_name
        self.transform = transform
        self.target_transform = target_transform
        # Per-process handles, created on demand by _ensure_open().
        self.file = None
        self.labels = None
        self.samples = None
        self._owner_pid = None
        # Read the length once, without keeping a handle the workers would inherit.
        with h5py.File(file_name, 'r', libver='latest', swmr=True) as f:
            self.size = len(f['label'])

    def _ensure_open(self):
        """Open the file in the current process if it is not already open here."""
        pid = os.getpid()
        if self.file is None or self._owner_pid != pid:
            self.file = h5py.File(self.file_name, 'r', libver='latest', swmr=True)
            self.labels = self.file['label']
            self.samples = self.file['data']
            self._owner_pid = pid

    def __getitem__(self, index):
        self._ensure_open()
        # SWMR readers refresh to observe data written after the file was opened.
        self.samples.refresh()
        sample = self.samples[index].astype(np.uint8)
        if self.transform is not None:
            sample = self.transform(sample)

        target = self.labels[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return self.size

    def __getstate__(self):
        # h5py objects are not picklable: ship the dataset without handles and
        # let the receiving process reopen the file on first access.
        state = self.__dict__.copy()
        state['file'] = None
        state['labels'] = None
        state['samples'] = None
        state['_owner_pid'] = None
        return state

    def __del__(self):
        # getattr: __init__ may have failed before the attribute existed.
        handle = getattr(self, 'file', None)
        if handle is not None:
            handle.close()
def _chw_to_hwc(sample):
    """Return a contiguous channel-last copy of a channel-first image array.

    ``preprocess_to_hdf5`` stores images as (C, H, W), while
    ``transforms.ToPILImage`` interprets numpy input as (H, W, C).
    Defined at module level (not a lambda) so the transform stays picklable
    for DataLoader workers.
    """
    return np.ascontiguousarray(np.moveaxis(sample, 0, -1))


def hdf5_loader(args, train=True):
    """Build a ``DataLoader`` over the HDF5 file written by ``preprocess_to_hdf5``.

    Args:
        args: object providing ``data`` (HDF5 path), ``batch_size`` and
            ``workers`` (number of loader processes).
        train: if true, apply random-resized-crop + horizontal-flip
            augmentation and shuffle; otherwise center-crop to 224 and keep
            file order.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding normalized float tensors
        and labels.
    """
    from MixedPrecision.tools.hdf5 import HDF5Dataset

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    if train:
        crop = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
    else:
        # Stored images are 256x256, so a 224 center crop is the eval view.
        crop = [transforms.CenterCrop(224)]

    data_transforms = transforms.Compose(
        # Data is stored as channel-first uint8; ToPILImage needs channel-last.
        [transforms.Lambda(_chw_to_hwc), transforms.ToPILImage()]
        + crop
        + [transforms.ToTensor(), normalize]
    )

    dataset = HDF5Dataset(
        args.data,
        data_transforms)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=bool(train),
        num_workers=args.workers,
        pin_memory=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment