-
-
Save Geoyi/0b4f304143b7480b2897f94cf3587a67 to your computer and use it in GitHub Desktop.
From numpy ndarray to tfrecords
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
__author__ = "Sangwoong Yoon" | |
def np_to_tfrecords(X, Y, file_path_prefix, verbose=True):
    """
    Converts a Numpy array (or two Numpy arrays) into a tfrecord file.

    Each row of X (and the matching row of Y, when given) is written as one
    tf.train.Example with the feature key 'X' (and 'Y').
    For supervised learning, feed training inputs to X and training labels to Y.
    For unsupervised learning, only feed training inputs to X, and feed None to Y.
    The length of the first dimensions of X and Y should be the number of samples.

    Parameters
    ----------
    X : numpy.ndarray of rank 2
        Numpy array for training inputs. Its dtype should be float32, float64, or int64.
        If X has a higher rank, it should be reshaped before being fed to this function.
        Floating point values are stored in a tf.train.FloatList, which holds
        32-bit floats, so float64 inputs lose precision when read back.
    Y : numpy.ndarray of rank 2 or None
        Numpy array for training labels. Its dtype should be float32, float64, or int64.
        None if there is no label array.
    file_path_prefix : str
        The path and name of the resulting tfrecord file to be generated, without '.tfrecords'
    verbose : bool
        If true, progress is reported on standard output.

    Raises
    ------
    ValueError
        If the dtype of X or Y is not float64, float32 or int64.
    AssertionError
        If X is not a rank-2 ndarray, or Y is neither None nor a rank-2 ndarray
        with as many rows as X.
    """
    def _dtype_feature(ndarray):
        """Return a callable wrapping a 1-D array in the tf.train.Feature matching its dtype."""
        assert isinstance(ndarray, np.ndarray)
        dtype_ = ndarray.dtype
        if dtype_ == np.float64 or dtype_ == np.float32:
            return lambda array: tf.train.Feature(float_list=tf.train.FloatList(value=array))
        elif dtype_ == np.int64:
            return lambda array: tf.train.Feature(int64_list=tf.train.Int64List(value=array))
        else:
            raise ValueError(
                "The dtype of the input should be float64, float32 or int64. "
                "Instead got {}".format(dtype_))

    assert isinstance(X, np.ndarray)
    assert len(X.shape) == 2  # If X has a higher rank,
                              # it should be reshaped before being fed to this function.
    assert isinstance(Y, np.ndarray) or Y is None

    # Pick the tf.train.Feature builder for each array from its dtype.
    # Nothing is written yet, so an unsupported dtype fails before a file is created.
    dtype_feature_x = _dtype_feature(X)
    dtype_feature_y = None
    if Y is not None:
        assert X.shape[0] == Y.shape[0]
        assert len(Y.shape) == 2
        dtype_feature_y = _dtype_feature(Y)

    # Resolve the writer class: recent TensorFlow exposes it as tf.io.TFRecordWriter,
    # while older 1.x releases only provide tf.python_io.TFRecordWriter.
    try:
        writer_cls = tf.io.TFRecordWriter
    except AttributeError:
        writer_cls = tf.python_io.TFRecordWriter

    # Generate tfrecord writer
    result_tf_file = file_path_prefix + '.tfrecords'
    num_samples = X.shape[0]
    if verbose:
        print("Serializing {:d} examples into {}".format(num_samples, result_tf_file))

    writer = writer_cls(result_tf_file)
    try:
        # Iterate over each sample and serialize it as a ProtoBuf Example.
        for idx in range(num_samples):
            d_feature = {'X': dtype_feature_x(X[idx])}
            if dtype_feature_y is not None:
                d_feature['Y'] = dtype_feature_y(Y[idx])

            features = tf.train.Features(feature=d_feature)
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
    finally:
        # Always close the writer so buffered records are flushed and the
        # file handle is released, even if serialization fails part-way.
        writer.close()

    if verbose:
        print("Writing {} done!".format(result_tf_file))
################################# | |
## Test and Use Cases ## | |
################################# | |
# NOTE: these examples run at import time and write 'test1.tfrecords' and
# 'test2.tfrecords' into the current working directory.

# 1-1. Saving a dataset with input and label (supervised learning)
xx = np.random.randn(10, 5)
yy = np.random.randn(10, 1)
np_to_tfrecords(xx, yy, 'test1', verbose=True)

# 1-2. Check if the data is stored correctly
# open the saved file and check the first entries
# Resolve a record reader that exists in the installed TensorFlow: newer
# releases expose it under tf.compat.v1.io, older 1.x ones under tf.python_io.
try:
    record_iterator = tf.compat.v1.io.tf_record_iterator
except AttributeError:
    record_iterator = tf.python_io.tf_record_iterator

for serialized_example in record_iterator('test1.tfrecords'):
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    x_1 = np.array(example.features.feature['X'].float_list.value)
    y_1 = np.array(example.features.feature['Y'].float_list.value)
    break  # only the first record is needed for the comparison

# the numbers may be slightly different because of the floating point error.
print(xx[0])
print(x_1)
print(yy[0])
print(y_1)

# 2. Saving a dataset with only inputs (unsupervised learning)
xx = np.random.randn(100, 100)
np_to_tfrecords(xx, None, 'test2', verbose=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment