Created
December 21, 2018 16:03
-
-
Save Delaunay/c0ac3322e3ec5ec3760a0f9edd97b6b5 to your computer and use it in GitHub Desktop.
HDF5 POC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision | |
import torchvision.transforms as transforms | |
import time | |
import torch | |
import os | |
import h5py | |
import numpy as np | |
from MixedPrecision.tools.stats import StatStream | |
def preprocess_to_hdf5(transform, input_folder: str, output_file: str, image_size: int = 256):
    """Convert an ``ImageFolder`` directory tree into a single HDF5 file.

    The output file contains three datasets:
      * ``classes`` — byte strings mapping class index -> class name
      * ``label``   — one integer class index per sample
      * ``data``    — uint8 images stored CHW, one chunk per sample

    Args:
        transform: torchvision transform applied to each loaded image.
            It must produce an image of shape (image_size, image_size, 3)
            (HWC, uint8-convertible) — e.g. a Resize/CenterCrop pipeline.
        input_folder: root directory in torchvision ``ImageFolder`` layout.
        output_file: path of the HDF5 file to create (overwritten if present).
        image_size: side length of the stored square images (default 256,
            matching the original hard-coded value).
    """
    train_dataset = torchvision.datasets.ImageFolder(
        input_folder,
        transform)

    output = h5py.File(output_file, 'w', libver='latest')

    # >>>>>>
    # Store an array of strings mapping index -> class name.
    # Size the dataset from the actual number of classes (the original
    # hard-coded 1000) and size the byte strings to fit the longest name
    # (the original 'S9' silently truncated anything longer than 9 bytes).
    cls = sorted(train_dataset.class_to_idx.items(), key=lambda x: x[1])
    max_name = max((len(key.encode('utf-8')) for key, _ in cls), default=1)
    classes = output.create_dataset('classes', (len(cls),), dtype='S{}'.format(max_name))
    for (key, index) in cls:
        classes[index] = np.string_(key)
    # <<<<<<

    n = len(train_dataset)

    # np.int64 instead of the original np.uint8: uint8 wraps for any class
    # index >= 256, which is guaranteed to corrupt labels on a 1000-class
    # dataset.  int64 also matches what torch expects for targets.
    hdy = output.create_dataset('label', (n,), dtype=np.int64)
    hdx = output.create_dataset(
        'data',
        (n, 3, image_size, image_size),
        dtype=np.uint8,
        chunks=(1, 3, image_size, image_size),  # chunk per sample for fast retrieval
        # compression prevents parallel readers somehow
        # compression='lzf'
    )

    load_time = StatStream(10)
    save_time = StatStream(10)
    start = time.time()

    print('Converting...')

    for index, (x, y) in enumerate(train_dataset):
        end = time.time()
        load_time += end - start

        s = time.time()
        # convert to uint8
        x = np.array(x, dtype=np.uint8)
        hdy[index] = y
        # HWC -> CHW to match the (n, 3, H, W) layout of the dataset
        hdx[index] = np.moveaxis(x, -1, 0)
        e = time.time()
        save_time += e - s

        if index % 100 == 0 and load_time.avg > 0:
            print('{:.4f} % Load[avg: {:.4f} img/s sd: {:.4f}] Save[avg: {:.4f} img/s sd: {:.4f}]'.format(
                index * 100 / n, 1 / load_time.avg, load_time.sd, 1 / save_time.avg, save_time.sd))

        start = time.time()

    output.close()

    # Guard the final report: avg can still be 0 on a tiny dataset, and the
    # original unconditionally divided by it.
    if load_time.avg > 0:
        print('{:.4f} img/s'.format(1 / load_time.avg))
class HDF5Dataset(torch.utils.data.Dataset):
    """Dataset backed by an HDF5 file produced by ``preprocess_to_hdf5``.

    Expects the file to contain a ``data`` dataset of uint8 CHW images and
    a ``label`` dataset of class indices.  The file is opened in SWMR mode
    so multiple reader processes (e.g. DataLoader workers) can read it
    concurrently.
    """

    def __init__(self, file_name: str, transform=None, target_transform=None):
        """Open *file_name* read-only and bind its datasets.

        Args:
            file_name: path of the HDF5 file.
            transform: optional callable applied to each sample array.
            target_transform: optional callable applied to each label.
        """
        self.file = h5py.File(file_name, 'r', libver='latest', swmr=True)
        self.transform = transform
        self.target_transform = target_transform
        self.labels = self.file['label']
        self.samples = self.file['data']
        self.size = len(self.file['label'])

    def __getitem__(self, index):
        # refresh() re-reads metadata so SWMR readers see a consistent view.
        self.samples.refresh()

        sample = self.samples[index]
        sample = sample.astype(np.uint8)

        if self.transform is not None:
            sample = self.transform(sample)

        target = self.labels[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return self.size

    def __del__(self):
        # If h5py.File raised inside __init__, this object has no `file`
        # attribute, yet __del__ still runs; the original then raised a
        # spurious AttributeError during garbage collection.
        f = getattr(self, 'file', None)
        if f is not None:
            try:
                f.close()
            except Exception:
                # best-effort cleanup; never raise from a finalizer
                pass
def hdf5_loader(args, train=True):
    """Create a shuffling ``DataLoader`` over an HDF5-backed image dataset.

    Args:
        args: namespace providing ``data`` (HDF5 file path), ``batch_size``
            and ``workers``.
        train: accepted for interface compatibility.
            NOTE(review): currently unused — shuffling and augmentation are
            applied unconditionally; confirm whether an eval path is needed.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding normalized 224x224 crops.
    """
    from MixedPrecision.tools.hdf5 import HDF5Dataset

    # Samples come back as raw uint8 arrays, so rebuild a PIL image before
    # applying the standard ImageNet augmentation + normalization pipeline.
    pipeline = transforms.Compose([
        # data is stored as uint8
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    dataset = HDF5Dataset(args.data, pipeline)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment