Skip to content

Instantly share code, notes, and snippets.

@skaae
Created August 29, 2019 15:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save skaae/b5c8d63eb031aa14a6daf5ae6cd6c9b0 to your computer and use it in GitHub Desktop.
Save skaae/b5c8d63eb031aa14a6daf5ae6cd6c9b0 to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
# Two toy records, each holding an independently drawn 10x30 standard-normal
# array under the "img" key.
data_arr = [{"img": np.random.randn(10, 30)} for _ in range(2)]
def get_example_object(data_record):
    """Build a ``tf.train.Example`` from one record.

    The array stored under ``data_record["img"]`` is flattened, cast to
    float32 and written as raw bytes in the "img" feature. Its original
    shape is written to the "img_shp" feature so a reader can reshape the
    decoded bytes back into the array.

    Args:
        data_record: Mapping with an "img" entry holding a NumPy array.

    Returns:
        A ``tf.train.Example`` with an "img" bytes feature and an
        "img_shp" int64 feature.
    """
    img = data_record["img"]
    # tobytes() replaces the deprecated ndarray.tostring() alias; it yields
    # the same raw buffer.
    img_bytes = img.flatten().astype("float32").tobytes()
    feature_key_value_pair = {
        "img": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[img_bytes])),
        "img_shp": tf.train.Feature(
            int64_list=tf.train.Int64List(value=img.shape)),
    }
    features = tf.train.Features(feature=feature_key_value_pair)
    return tf.train.Example(features=features)
# Serialize every record in data_arr and append it to example.tfrecord.
with tf.python_io.TFRecordWriter('example.tfrecord') as writer:
    for record in data_arr:
        writer.write(get_example_object(record).SerializeToString())
def extract_fn(data_record):
    """Parse one serialized example and rebuild its image tensor.

    Args:
        data_record: A serialized example, as passed in by ``dataset.map``.

    Returns:
        The parsed feature dict, with "img" replaced by the float32 tensor
        decoded from its raw bytes and reshaped to the stored "img_shp".
    """
    feature_spec = {
        "img": tf.FixedLenFeature([], tf.string),
        "img_shp": tf.FixedLenFeature([2], tf.int64),
    }
    parsed = tf.parse_single_example(data_record, feature_spec)
    # Decode the raw bytes to a flat float32 vector, then restore the shape.
    flat_img = tf.decode_raw(parsed["img"], tf.float32)
    parsed["img"] = tf.reshape(flat_img, parsed["img_shp"])
    return parsed
# Build the input pipeline: read the TFRecord file and decode each record.
dataset = tf.data.TFRecordDataset(['example.tfrecord'])
dataset = dataset.map(extract_fn)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    try:
        # Keep pulling records until the one-shot iterator is exhausted.
        while True:
            data_record = sess.run(next_element)
            print(data_record)
    except tf.errors.OutOfRangeError:
        # End of the dataset: this is the expected way out of the loop.
        # Catching only this error lets real failures (and Ctrl-C) propagate
        # instead of being silently swallowed.
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment