Skip to content

Instantly share code, notes, and snippets.

@EniasCailliau
Created August 23, 2018 10:18
Show Gist options
  • Save EniasCailliau/8e5d89ae13eb1491234f948a3387630b to your computer and use it in GitHub Desktop.
Save EniasCailliau/8e5d89ae13eb1491234f948a3387630b to your computer and use it in GitHub Desktop.
"""
Converts MNIST data to TFRecords file format
"""
import os
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap a single float in a ``tf.train.Feature`` backed by a FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def convert_mnist_fashion_dataset(images, labels, name, directory):
    """Write images and labels to ``<directory>/<name>.tfrecords``.

    Each image becomes one serialized ``tf.train.Example`` with the features
    ``height``, ``width``, ``channels``, ``label`` and ``image_raw``. The
    ``image_raw`` feature holds the raw bytes of the image array in its
    existing dtype; the dtype itself is not recorded.

    Args:
        images: Array of shape ``(N, height, width)`` (stored as one channel)
            or ``(N, height, width, channels)``.
        labels: Sequence of ``N`` labels, each convertible with ``int()``.
        name: Base file name, without the ``.tfrecords`` extension.
        directory: Directory in which the file is written.

    Raises:
        ValueError: If ``images`` is not 3- or 4-dimensional, or if the number
            of labels differs from the number of images.
    """
    # A trailing channel axis is optional; 3-D input keeps the original
    # single-channel behaviour.
    shape = images.shape
    if len(shape) == 3:
        _, height, width = shape
        channels = 1
    elif len(shape) == 4:
        _, height, width, channels = shape
    else:
        raise ValueError(
            f'images must have shape (N, H, W) or (N, H, W, C), got {shape}')

    # Without this check a short label list fails mid-write with IndexError,
    # and a long one is silently truncated.
    if len(labels) != len(images):
        raise ValueError(
            f'got {len(images)} images but {len(labels)} labels')

    filename = os.path.join(directory, name + '.tfrecords')
    print(f'Writing {filename}')

    # tf.python_io exists only in TF 1.x; prefer tf.io and fall back for
    # older 1.x releases that lack tf.io.TFRecordWriter.
    try:
        writer_cls = tf.io.TFRecordWriter
    except AttributeError:
        writer_cls = tf.python_io.TFRecordWriter

    with writer_cls(filename) as writer:
        for image, label in zip(images, labels):
            # tobytes() yields the same bytes as ndarray.tostring(), which is
            # deprecated since NumPy 1.19 and removed in recent NumPy releases.
            image_raw = image.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'channels': _int64_feature(channels),
                'label': _int64_feature(int(label)),
                'image_raw': _bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment