Skip to content

Instantly share code, notes, and snippets.

@Franklin-Yao
Created February 23, 2020 16:45
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save Franklin-Yao/bf449cafb32334b300ee6b4ed1459019 to your computer and use it in GitHub Desktop.
Save Franklin-Yao/bf449cafb32334b300ee6b4ed1459019 to your computer and use it in GitHub Desktop.
Save and read images with HDF5.
import os.path as osp
# Root directory of the DomainNet dataset; image paths in the split files are
# resolved relative to it. Machine-specific: adjust for your environment.
dataset_dir = '/home/frankllin/Downloads/DomainNet'
# Directory holding the '<domain>_train.txt' / '<domain>_test.txt' split lists.
split_dir = osp.join(dataset_dir, 'splits_mini')
# Side length in pixels that every image is resized to (square output).
image_size = 84
from PIL import Image
import h5py
import numpy as np
from tqdm import tqdm
def store_many_hdf5(h5file, images, labels):
    """Store an array of images and their labels as two HDF5 datasets.

    Two datasets are created directly under ``h5file``: ``"images"``
    (unsigned 8-bit pixels) and ``"labels"`` (64-bit signed integers).

    Parameters
    ----------
    h5file : h5py.File or h5py.Group
        Container in which the ``"images"`` and ``"labels"`` datasets are
        created.
    images : array_like, shape (N, H, W, C)
        Pixel values in [0, 255], e.g. (N, 84, 84, 3). Any numeric dtype is
        accepted; values are cast to uint8 for storage.
    labels : array_like, shape (N, 1)
        Integer class label for each image.

    Raises
    ------
    ValueError
        If ``images`` and ``labels`` disagree on N, or if any pixel value lies
        outside [0, 255] (it would wrap silently when cast to uint8).
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    if len(images) != len(labels):
        raise ValueError(
            f"got {len(images)} images but {len(labels)} labels"
        )
    # Validate before casting: astype(np.uint8) wraps out-of-range values.
    if images.size and (images.min() < 0 or images.max() > 255):
        raise ValueError("image pixel values must lie in [0, 255]")
    h5file.create_dataset("images", data=images.astype(np.uint8))
    # Labels need a wide integer type: uint8 would corrupt class ids > 255.
    h5file.create_dataset("labels", data=labels.astype(np.int64))
def read_many_hdf5(h5file):
    """Read all images and labels from an HDF5 file or group.

    Parameters
    ----------
    h5file : h5py.File, h5py.Group, or any mapping
        Object exposing ``"images"`` and ``"labels"`` entries, such as the
        group written by ``store_many_hdf5``.

    Returns
    -------
    images : numpy.ndarray of uint8
        Every stored image, loaded fully into memory, e.g. shape (N, 84, 84, 3).
    labels : numpy.ndarray of int64
        The associated integer labels, e.g. shape (N, 1).
    """
    images = np.array(h5file["images"]).astype(np.uint8)
    # int64 rather than uint8 so that class ids > 255 are not truncated.
    labels = np.array(h5file["labels"]).astype(np.int64)
    return images, labels
def _load_split(split_file):
    """Load every image listed in ``split_file`` as resized uint8 RGB arrays.

    Each non-blank line of ``split_file`` holds a relative image path followed
    by an integer label, separated by whitespace. The path is resolved against
    ``dataset_dir`` and the image is resized to ``image_size`` x ``image_size``.

    Returns
    -------
    images : numpy.ndarray, uint8, shape (N, image_size, image_size, 3)
    labels : numpy.ndarray, int64, shape (N, 1)
    """
    with open(split_file, encoding='utf-8') as f:
        entries = [line for line in f.read().splitlines() if line.strip()]
    # uint8 buffer: 8x smaller than np.zeros' float64 default.
    images = np.zeros((len(entries), image_size, image_size, 3), dtype=np.uint8)
    labels = np.zeros((len(entries), 1), dtype=np.int64)
    for i, entry in enumerate(tqdm(entries)):
        # split() tolerates repeated spaces/tabs, unlike split(' ').
        rel_path, label = entry.split()[:2]
        with Image.open(osp.join(dataset_dir, rel_path)) as img:
            # convert('RGB') so grayscale/RGBA/palette images fit (H, W, 3).
            resized = img.convert('RGB').resize((image_size, image_size))
        images[i] = np.asarray(resized)
        labels[i] = int(label)
    return images, labels


def main(domains=None):
    """Build ``miniDomainNet.h5`` from the train split of each domain.

    The file ``<dataset_dir>/miniDomainNet.h5`` is opened in ``'w'`` mode.
    For each domain, a group named after it is created, holding the
    ``images`` and ``labels`` datasets written by ``store_many_hdf5``.

    Parameters
    ----------
    domains : iterable of str, optional
        Domain names to process; defaults to ``['clipart']``. Pass e.g.
        ``['clipart', 'painting', 'real', 'sketch']`` to process all four.
    """
    if domains is None:
        domains = ['clipart']
    # NOTE: only the train split is stored; '<domain>_test.txt' is not read.
    with h5py.File(osp.join(dataset_dir, 'miniDomainNet.h5'), 'w') as h5file:
        for domain in domains:
            print('processing ' + domain)
            train_file = osp.join(split_dir, domain + '_train.txt')
            images, labels = _load_split(train_file)
            store_many_hdf5(h5file.create_group(domain), images, labels)
if __name__ == '__main__':
    # Building the HDF5 file is slow; uncomment to (re)generate it first.
    # main()
    # Context manager closes the file once the data is loaded into memory.
    with h5py.File(osp.join(dataset_dir, 'miniDomainNet.h5'), 'r') as h5file:
        images, _ = read_many_hdf5(h5file['/clipart'])
    # A uint8 (H, W, 3) array is inferred as RGB; show() is the public API.
    Image.fromarray(images[0]).show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment