Skip to content

Instantly share code, notes, and snippets.

@rreece
Created May 12, 2021 18:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rreece/cfeda8ddb1dce3a2c1fdb7096d59adb9 to your computer and use it in GitHub Desktop.
Save rreece/cfeda8ddb1dce3a2c1fdb7096d59adb9 to your computer and use it in GitHub Desktop.
Writes float16 data to a tfrecord as raw bytes and reads it back.
"""
Writes float16 data to a tfrecord as raw bytes and reads it back.
Based on:
https://stackoverflow.com/questions/40184812/tensorflow-is-it-possible-to-store-tf-record-sequence-examples-as-float16
"""
import argparse
import numpy as np
import tensorflow as tf
def _write_tfrecord(data_np, path='data.tfrecord'):
    """Serialize an array as raw float16 bytes into a single-example TFRecord.

    The data is stored under the feature key ``'x'`` as one ``BytesList``
    entry, which is the layout ``_parse_tfrecord`` reads back. Only the flat
    element bytes are written (no shape), so it decodes as a 1-D tensor.

    Args:
        data_np: array-like of numbers; converted to little-endian float16
            before serialization.
        path: destination TFRecord file path (default ``'data.tfrecord'``).
    """
    # tf.io.decode_raw interprets bytes as little-endian by default, while
    # ndarray.tobytes() uses the array's own byte order. Coercing to '<f2'
    # guarantees the bytes match what the float16 reader expects, and also
    # prevents float32/float64 input from being misread as float16 values.
    payload = np.asarray(data_np, dtype='<f2').tobytes()
    # A single bytes feature holding the whole array.
    feature = {'x': tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    with tf.io.TFRecordWriter(path) as writer:
        writer.write(example.SerializeToString())
def _parse_tfrecord(example_proto):
    """Decode one serialized ``tf.train.Example`` into float16 tensors.

    Args:
        example_proto: scalar string tensor holding a serialized Example,
            as yielded by ``tf.data.TFRecordDataset`` in ``input_fn``.

    Returns:
        dict mapping each feature key (here only ``'x'``) to a 1-D float16
        tensor reinterpreted from that feature's raw bytes.
    """
    # Each feature is one bytes blob, so it parses as a scalar tf.string.
    spec = {'x': tf.io.FixedLenFeature((), tf.string)}
    raw = tf.io.parse_single_example(example_proto, spec)
    decoded = {}
    for key in spec:
        # Reinterpret the stored bytes as float16 values.
        decoded[key] = tf.io.decode_raw(raw[key], tf.float16)
    return decoded
def input_fn(filenames='data.tfrecord'):
    """Build a dataset that yields one parsed feature dict per stored record.

    Args:
        filenames: a TFRecord file path, or a list of paths. Defaults to
            ``'data.tfrecord'``, the file ``_write_tfrecord`` writes by default.

    Returns:
        A ``tf.data.Dataset`` whose elements are the dicts produced by
        ``_parse_tfrecord`` (key ``'x'`` -> float16 tensor).
    """
    # Accept a single path for convenience; TFRecordDataset is given a list,
    # matching the original call.
    if isinstance(filenames, str):
        filenames = [filenames]
    dataset = tf.data.TFRecordDataset(filenames)
    # Decode each serialized Example into its float16 feature dict.
    return dataset.map(_parse_tfrecord)
def main():
    """Round-trip ten random float16 values through a TFRecord and report.

    Writes the values with ``_write_tfrecord``, reads them back with
    ``input_fn``, and prints each decoded record plus whether it matches the
    original data according to ``np.allclose``.
    """
    original = np.random.rand(10).astype(np.float16)
    _write_tfrecord(original)

    for idx, record in enumerate(input_fn()):
        print(f'DEBUG: i_batch = {idx}', flush=True)
        decoded = record['x'].numpy()
        print('DEBUG: x = ', decoded, flush=True)
        print('DEBUG: allclose = ', np.allclose(decoded, original), flush=True)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment