Skip to content

Instantly share code, notes, and snippets.

Last active November 30, 2023 15:42
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gnperdue/b905a9c2dd4c08b53e0539d6aa3d3dc6 to your computer and use it in GitHub Desktop.
Save gnperdue/b905a9c2dd4c08b53e0539d6aa3d3dc6 to your computer and use it in GitHub Desktop.
TensorFlow Dataset `from_generator` reading HDF5
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
import os
import h5py
import numpy as np
import tensorflow as tf
# Path to the Fashion-MNIST test-set HDF5 file, relative to the user's home
# directory. NOTE(review): the download URL that followed "wget" was lost
# from this copy of the script -- TODO restore it.
# expanduser('~') reads HOME when it is set (same result as before) but
# does not raise KeyError on platforms where HOME is undefined.
TFILE = os.path.join(
    os.path.expanduser('~'), 'Dropbox/Data/RandomData/hdf5/fashion_test.hdf5'
)
class FashionHDF5Reader(object):
    """Read batches of Fashion-MNIST images and one-hot labels from HDF5.

    The file must contain the datasets ``fashion/images`` and
    ``fashion/labels``. Call ``openf`` before ``get_examples`` and ``closef``
    when done.
    """

    def __init__(self, hdf5_file):
        """Remember the path of *hdf5_file*; the file is not opened yet."""
        self._file = hdf5_file
        self._f = None  # open h5py.File handle, or None when closed
        self._nevents = 0
        self.class_names = [
            'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
            'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
        ]
        # Width of the one-hot label vectors: one column per class name.
        self._nlabels = len(self.class_names)

    def openf(self):
        """Open the HDF5 file read-only and return the number of examples.

        Any handle this reader already holds is closed first, so calling
        ``openf`` twice does not leak a file handle.
        """
        if self._f is not None:
            self._f.close()
        self._f = h5py.File(self._file, 'r')
        self._nevents = self._f['fashion/labels'].shape[0]
        return self._nevents

    def closef(self):
        """Close the HDF5 file if it is open; otherwise print a notice."""
        # An explicit None check replaces the old `except AttributeError`,
        # which could also swallow unrelated attribute errors from close().
        if self._f is None:
            print('hdf5 file is not open yet.')
            return
        self._f.close()
        self._f = None

    def get_examples(self, start_idx, stop_idx):
        """Return ``(images, one_hot_labels)`` for rows [start_idx, stop_idx).

        Images have axis 1 moved to the last position (presumably
        channels-first (N, 1, 28, 28) -> channels-last (N, 28, 28, 1);
        TODO confirm the on-disk layout). Labels are flattened and one-hot
        encoded into a ``uint8`` array of shape (N, number of classes).

        Raises:
            RuntimeError: if the file has not been opened with ``openf``.
        """
        if self._f is None:
            raise RuntimeError('hdf5 file is not open; call openf() first.')
        image = self._f['fashion/images'][start_idx:stop_idx]
        image = np.moveaxis(image, 1, -1)
        label = self._f['fashion/labels'][start_idx:stop_idx].reshape([-1])
        oh_label = np.zeros((label.size, self._nlabels), dtype=np.uint8)
        # Fancy indexing: row i gets a single 1 in column label[i].
        oh_label[np.arange(label.size), label] = 1
        return image, oh_label
def _make_fashion_generator_fn(file_name, batch_size):
    """Make a generator function that yields batches from an HDF5 file.

    Each call of the returned function opens its own ``FashionHDF5Reader``
    and yields ``(images, one_hot_labels)`` tuples of at most ``batch_size``
    examples; the final batch may be smaller. The file is closed when the
    generator finishes or is closed early. Because each call opens a fresh
    reader, the function can be invoked repeatedly (e.g. once per iterator).

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    def example_generator_fn():
        reader = FashionHDF5Reader(file_name)
        nevents = reader.openf()
        try:
            for start_idx in range(0, nevents, batch_size):
                # Clip so the last (partial) batch does not run past the end.
                stop_idx = min(start_idx + batch_size, nevents)
                yield reader.get_examples(start_idx, stop_idx)
        finally:
            # Runs on normal exhaustion and on early generator close/GC.
            reader.closef()

    return example_generator_fn
def make_fashion_dset(file_name, batch_size, shuffle=False):
    """Build a ``tf.data.Dataset`` of pre-batched Fashion-MNIST examples.

    Each element is already a whole batch: images as float32 with shape
    (batch, 28, 28, 1) and one-hot labels as uint8 with shape (batch, 10).
    Do not call ``batch()`` on the result. The batch dimension is declared
    unknown (None) because the last batch may hold fewer than ``batch_size``
    examples, and a fixed size would make TF reject it at runtime.

    If ``shuffle`` is true, whole batches are shuffled with a 10-batch
    buffer; examples inside a batch keep their file order.
    """
    dgen = _make_fashion_generator_fn(file_name, batch_size)
    features_shape = [None, 28, 28, 1]
    labels_shape = [None, 10]
    ds = tf.data.Dataset.from_generator(
        dgen, (tf.float32, tf.uint8),
        (tf.TensorShape(features_shape), tf.TensorShape(labels_shape))
    )
    # We are grabbing an entire "batch" per element, so don't call `batch()`.
    if shuffle:
        ds = ds.shuffle(10)
    # Prefetch last so it overlaps production of the final element stream.
    ds = ds.prefetch(10)
    return ds
def make_fashion_iterators(file_name, batch_size, shuffle=False):
    """Return the ``(images, labels)`` tensors that produce the next batch.

    The dataset comes from ``make_fashion_dset`` and is consumed through a
    one-shot iterator (TF 1.x graph-mode API).
    """
    iterator = make_fashion_dset(
        file_name, batch_size, shuffle
    ).make_one_shot_iterator()
    next_images, next_labels = iterator.get_next()
    return next_images, next_labels
# Demo driver: walk the whole dataset once, printing per-batch shapes/dtypes.
# Guarded so that importing this module does not run the pipeline.
if __name__ == '__main__':
    images, labels = make_fashion_iterators(TFILE, 11)
    with tf.Session() as sess:
        total_batches = 0
        total_examples = 0
        while True:
            # Keep the try body minimal: only sess.run signals end-of-data.
            # Any other error now propagates with a full traceback instead
            # of being swallowed by a blanket `except Exception`.
            try:
                im, ls = sess.run([images, labels])
            except tf.errors.OutOfRangeError:
                print('end of dataset at total_batches={}, '
                      'total_examples={}'.format(total_batches,
                                                 total_examples))
                break
            print('{}, {}, {}, {}'.format(
                im.shape, im.dtype, ls.shape, ls.dtype
            ))
            total_batches += 1
            total_examples += ls.shape[0]
Copy link

okurman commented Nov 22, 2023

This is great sample code, which I am stealing. Thanks!
I have a small question: did you get a chance to compare the speed of indexing a whole batch from the matrix like this versus yielding each element and using `dataset.batch()` instead?
I wonder how that choice would interact with `cache` and `interleave` combinations.

Copy link

No, I didn't really do any performance optimization or checking here.

Copy link

This is pretty old... TF1.x code, by the look of it (`tf.Session()`). I have some code (also very old, LOL) that does this in a more TF2 way here, I think, although it is not a nice, self-contained example (it is mixed in with a "real" project).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment