Skip to content

Instantly share code, notes, and snippets.

@formigone
Created December 21, 2017 06:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save formigone/439a05d878222e104a17856b46dfbdf5 to your computer and use it in GitHub Desktop.
Save formigone/439a05d878222e104a17856b46dfbdf5 to your computer and use it in GitHub Desktop.
Creating and using TensorFlow TFRecords
def list_to_tfrecord(list, tfrecord_filename):
    """Write a sequence of ``(feature, label)`` pairs to a TFRecord file.

    Each pair is serialized as one ``tf.train.Example`` with two entries:
    ``'x'`` holding the feature values as a float list, and ``'y'`` holding
    the label as a single int64 value.

    Args:
        list: iterable of ``(feature, label)`` tuples. ``feature`` is an
            iterable of numbers; ``label`` is an integer. (This name shadows
            the builtin ``list``; it is kept so keyword callers keep working.)
        tfrecord_filename: path of the TFRecord file to create.
    """
    # Alias the builtin-shadowing parameter so the body reads clearly.
    examples = list
    # Reach the writer through ``tf`` so no separate ``python_io`` import is needed.
    with tf.python_io.TFRecordWriter(tfrecord_filename) as writer:
        for feature, label in examples:
            example = tf.train.Example()
            # Store this pair's own feature values, taken from the loop variable.
            example.features.feature['x'].float_list.value.extend(feature)
            example.features.feature['y'].int64_list.value.append(label)
            writer.write(example.SerializeToString())
def gen_input_fn(tfrecord, epochs=1, batch_size=16, buffer_size=64,
                 feature_shape=(299 * 299,), label_shape=()):
    """Return an ``input_fn`` that reads TFRecords for TensorFlow's Estimator API.

    Each record is parsed for an ``'x'`` float32 feature and a ``'y'`` int64
    label, matching the layout written by ``list_to_tfrecord``.

    Args:
        tfrecord: path of a TFRecord file, or a list/tuple of paths.
        epochs: value passed to ``dataset.repeat`` (number of passes).
        batch_size: number of examples per batch.
        buffer_size: shuffle buffer size; shuffling is skipped when <= 0.
        feature_shape: fixed shape of the ``'x'`` feature.
        label_shape: fixed shape of the ``'y'`` label; ``()`` is a scalar.

    Returns:
        A zero-argument callable that returns a ``(features, labels)`` pair of
        tensors taken from a one-shot iterator over the batched dataset.
    """
    # Use the caller's file(s) rather than a hard-coded filename;
    # accept a single path as well as a collection of paths.
    if isinstance(tfrecord, str):
        filenames = [tfrecord]
    else:
        filenames = list(tfrecord)

    def parse(example):
        """Decode one serialized ``tf.train.Example`` into ``(x, y)`` tensors."""
        spec = {
            'x': tf.FixedLenFeature(feature_shape, tf.float32),
            'y': tf.FixedLenFeature(label_shape, tf.int64),
        }
        # Parse the record this function actually received (its ``example`` argument).
        parsed = tf.parse_single_example(example, spec)
        return parsed['x'], parsed['y']

    def input_fn():
        dataset = tf.contrib.data.TFRecordDataset(filenames)
        dataset = dataset.map(parse)
        if buffer_size > 0:
            dataset = dataset.shuffle(buffer_size)
        dataset = dataset.repeat(epochs)
        dataset = dataset.batch(batch_size)
        features, label = dataset.make_one_shot_iterator().get_next()
        return features, label

    return input_fn
@Auth0rM0rgan
Copy link

There are some errors in your code, such as `python_io`, `example_proto`, ...
Also, please add a README file explaining how to use your Python file.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment