@swyoon
Last active November 29, 2022 06:39
From numpy ndarray to tfrecords
import numpy as np
import tensorflow as tf

__author__ = "Sangwoong Yoon"

def np_to_tfrecords(X, Y, file_path_prefix, verbose=True):
    """
    Converts a Numpy array (or two Numpy arrays) into a tfrecord file.
    For supervised learning, feed training inputs to X and training labels to Y.
    For unsupervised learning, only feed training inputs to X, and feed None to Y.
    The length of the first dimension of X and Y should be the number of samples.

    Parameters
    ----------
    X : numpy.ndarray of rank 2
        Numpy array of training inputs. Its dtype should be float32, float64, or int64.
        If X has a higher rank, it should be reshaped to rank 2 before being fed to this function.
    Y : numpy.ndarray of rank 2 or None
        Numpy array of training labels. Its dtype should be float32, float64, or int64.
        None if there is no label array.
    file_path_prefix : str
        The path and name of the resulting tfrecord file, without the '.tfrecords' extension.
    verbose : bool
        If True, progress is reported.

    Raises
    ------
    ValueError
        If the input dtype is not float32, float64, or int64.
    """
    def _dtype_feature(ndarray):
        """Match the appropriate tf.train.Feature class to the dtype of ndarray."""
        assert isinstance(ndarray, np.ndarray)
        dtype_ = ndarray.dtype
        if dtype_ == np.float64 or dtype_ == np.float32:
            return lambda array: tf.train.Feature(float_list=tf.train.FloatList(value=array))
        elif dtype_ == np.int64:
            return lambda array: tf.train.Feature(int64_list=tf.train.Int64List(value=array))
        else:
            raise ValueError("The dtype should be float32, float64, or int64. "
                             "Instead got {}".format(ndarray.dtype))
    assert isinstance(X, np.ndarray)
    assert len(X.shape) == 2  # If X has a higher rank, it should be
                              # reshaped before being fed to this function.
    assert isinstance(Y, np.ndarray) or Y is None

    # Pick the appropriate tf.train.Feature constructor for each dtype.
    dtype_feature_x = _dtype_feature(X)
    if Y is not None:
        assert X.shape[0] == Y.shape[0]
        assert len(Y.shape) == 2
        dtype_feature_y = _dtype_feature(Y)

    # Generate the tfrecord writer.
    result_tf_file = file_path_prefix + '.tfrecords'
    writer = tf.python_io.TFRecordWriter(result_tf_file)
    if verbose:
        print("Serializing {:d} examples into {}".format(X.shape[0], result_tf_file))

    # Iterate over each sample and serialize it as a protobuf message.
    for idx in range(X.shape[0]):
        x = X[idx]

        d_feature = {}
        d_feature['X'] = dtype_feature_x(x)
        if Y is not None:
            d_feature['Y'] = dtype_feature_y(Y[idx])

        features = tf.train.Features(feature=d_feature)
        example = tf.train.Example(features=features)
        serialized = example.SerializeToString()
        writer.write(serialized)
    writer.close()

    if verbose:
        print("Writing {} done!".format(result_tf_file))
#################################
##     Test and Use Cases      ##
#################################

# 1-1. Saving a dataset with inputs and labels (supervised learning)
xx = np.random.randn(10, 5)
yy = np.random.randn(10, 1)
np_to_tfrecords(xx, yy, 'test1', verbose=True)

# 1-2. Check that the data is stored correctly.
# Open the saved file and inspect the first entry.
for serialized_example in tf.python_io.tf_record_iterator('test1.tfrecords'):
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    x_1 = np.array(example.features.feature['X'].float_list.value)
    y_1 = np.array(example.features.feature['Y'].float_list.value)
    break

# The numbers may differ slightly because of floating-point error.
print(xx[0])
print(x_1)
print(yy[0])
print(y_1)

# 2. Saving a dataset with only inputs (unsupervised learning)
xx = np.random.randn(100, 100)
np_to_tfrecords(xx, None, 'test2', verbose=True)
@yuan8421

Neat! Thank you, from a beginner.

@nairouz

nairouz commented May 23, 2018

Thank you.

@filmo

filmo commented Jun 13, 2018

Might want to extend _dtype_feature to recognize uint8, which is a common image datatype, and then pass this as a bytes_list:

tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

Storing images as uint8 will be 4x smaller than storing them as float32.
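
For reference, a minimal sketch of that extra branch (an assumption, not part of the original gist; _uint8_feature, row.tobytes(), and the reader-side np.frombuffer are illustrative):

def _uint8_feature(row):
    # row is a 1-D uint8 ndarray; tobytes() yields its raw buffer
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[row.tobytes()]))

# Reader side: the dtype (and the original shape) must be known out-of-band, e.g.
# flat_row = np.frombuffer(raw_bytes, dtype=np.uint8)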

@chisnova

chisnova commented Jul 5, 2018

Very good code for a tfrecord tutorial, thank you.

@TyJK

TyJK commented Aug 26, 2018

Thank you. I've been tearing my hair out trying to confirm that my data/labels were written to the record properly, and this made it so much easier.

@seismann

Very useful code. Thank you!

@ypxie

ypxie commented Dec 23, 2019

Very useful code. Not sure why returning an anonymous function from _dtype_feature works?
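
For what it's worth, it works because the closure captures the branch decision: _dtype_feature inspects the dtype once and returns a per-row Feature constructor, which the writer loop then calls for every sample. A two-line sketch using the gist's own function:

make_feature = _dtype_feature(X)  # dtype checked once; returns e.g. the FloatList lambda
feature = make_feature(X[0])      # the returned lambda builds a tf.train.Feature per row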

@tchaye59

tchaye59 commented May 4, 2020

Thank you very much. But your code is very slow for huge datasets; I think this is because you write the records one by one. I don't know if it is possible to write the whole dataset using tf.train.SequenceExample?
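
One possible mitigation, sketched under the assumption that the per-Example protobuf overhead is the bottleneck (CHUNK and the file name are illustrative; X is a rank-2 float array as in the gist): pack many rows into each record and reshape on the reader side.

CHUNK = 1024
writer = tf.python_io.TFRecordWriter('chunked.tfrecords')
for start in range(0, X.shape[0], CHUNK):
    block = X[start:start + CHUNK]  # up to CHUNK rows at once
    d_feature = {'X': tf.train.Feature(
        float_list=tf.train.FloatList(value=block.reshape(-1)))}
    example = tf.train.Example(features=tf.train.Features(feature=d_feature))
    writer.write(example.SerializeToString())
writer.close()
# The reader must reshape each record back to (-1, X.shape[1]).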

@tchaye59

If someone is looking for a faster solution, check my Kaggle kernel

@TheCrazyT

If someone is looking for a faster solution, check my Kaggle kernel

@tchaye59
Well, that notebook (version 17) has an error.
But I'm glad that someone is working on a faster version :-)

@tchaye59

Wait 10 minutes; I just sent a new commit.

@AnisIdowu1

Thanks for this.

Please, after saving test1, how can I recover the whole array, i.e., xx and not just x_1? Thanks in advance.
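
For completeness, one way to rebuild the full arrays from test1.tfrecords, sticking to the same TF 1.x API the gist uses (xx_rec and yy_rec are just illustrative names):

xs, ys = [], []
for serialized in tf.python_io.tf_record_iterator('test1.tfrecords'):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    xs.append(np.array(example.features.feature['X'].float_list.value))
    ys.append(np.array(example.features.feature['Y'].float_list.value))
xx_rec = np.stack(xs)  # shape (10, 5), up to float32 rounding
yy_rec = np.stack(ys)  # shape (10, 1)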

@apicquot

What about just using:
tensor = tf.convert_to_tensor(array)
result = tf.io.serialize_tensor(tensor)
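
That route also keeps the tensor's shape in the serialized string, which sidesteps the rank-2 restriction; a minimal round-trip sketch with TF 2.x eager APIs (the array here is illustrative, and parse_tensor still needs the dtype):

import numpy as np
import tensorflow as tf

array = np.random.randn(10, 5).astype(np.float32)
serialized = tf.io.serialize_tensor(tf.convert_to_tensor(array))
restored = tf.io.parse_tensor(serialized, out_type=tf.float32)
assert np.allclose(array, restored.numpy())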

@patmorli

patmorli commented Feb 4, 2021

How would 1-2 look with the tf updates? tf_record_iterator does not work anymore, and I can't figure out how to read the file back.
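
In TF 2.x, tf.data.TFRecordDataset can stand in for the removed iterator; a sketch of step 1-2 under that assumption (eager mode):

for serialized in tf.data.TFRecordDataset('test1.tfrecords').take(1):
    example = tf.train.Example()
    example.ParseFromString(serialized.numpy())
    x_1 = np.array(example.features.feature['X'].float_list.value)
    y_1 = np.array(example.features.feature['Y'].float_list.value)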

@patchy631

Why is it important that len(X.shape) == 2?

@songssssss

Thanks for sharing! May I know what happens if I have an X with higher dimensions? Is it still possible to convert it to tfrecords? Thanks!

@swyoon
Author

swyoon commented Jan 22, 2022

It's been so long since I wrote this code, and personally I have moved to PyTorch.
I have no idea if this would work for tensors with a higher rank.
My personal guess is that it should probably work if you comment out the assertions (L45, L53).
Could you try it? @songssssss
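
In the meantime, the workaround the docstring already implies is to flatten to rank 2 before writing and reshape after reading; a sketch (the 100x28x28 shape and 'test3' name are just examples):

images = np.random.randn(100, 28, 28)       # rank-3 input
flat = images.reshape(images.shape[0], -1)  # rank 2: (100, 784)
np_to_tfrecords(flat, None, 'test3', verbose=True)
# After reading a record back, restore the shape with .reshape(28, 28).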

@songssssss

Thanks for your reply! I tried to use tf.data.Dataset instead and it solved my memory issue.
