Skip to content

Instantly share code, notes, and snippets.

@charlee
Created October 24, 2017 05:16
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save charlee/fb19922c4f986be5947b54242e7ab09a to your computer and use it in GitHub Desktop.
Save charlee/fb19922c4f986be5947b54242e7ab09a to your computer and use it in GitHub Desktop.
Read example from TFRecord
def read_and_decode(filename_queue):
    """Read one serialized example from a TFRecord queue and decode it.

    Args:
        filename_queue: Queue of TFRecord file names, as accepted by
            ``tf.TFRecordReader.read``.

    Returns:
        A ``(image, label)`` pair of tensors. ``image`` is a flat float32
        tensor of ``IMAGE_SIZE * IMAGE_SIZE`` values rescaled from
        [0, 255] to [-0.5, 0.5]; ``label`` is an int32 scalar.
    """
    # Each record stores the raw image bytes and an integer class label.
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }

    record_reader = tf.TFRecordReader()
    _key, serialized_example = record_reader.read(filename_queue)
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # The bytes decode to a 1-D uint8 tensor of unknown length; pin the
    # static shape so downstream batching knows the element count.
    pixels = tf.decode_raw(parsed['image'], tf.uint8)
    pixels.set_shape([IMAGE_SIZE * IMAGE_SIZE])

    # Map [0, 255] onto [0, 1], then shift so values are centred on zero.
    unit_scaled = tf.cast(pixels, tf.float32) * (1. / 255)
    image = unit_scaled - 0.5

    label = tf.cast(parsed['label'], tf.int32)
    return image, label
def _list_training_tfrecords(train_dir, max_tfrecords_count):
    """Collect ``training-*.tfrecords`` paths under ``train_dir``.

    At most ``max_tfrecords_count`` files are taken from each directory;
    a negative count means "no limit".
    """
    selected = []
    for root, dirs, files in os.walk(train_dir):
        # os.walk yields entries in arbitrary order; sort both so the
        # traversal, and therefore the chosen subset, is reproducible.
        dirs.sort()
        tfrecords = [
            os.path.join(root, f)
            for f in sorted(files)
            if f.startswith('training-') and f.endswith('.tfrecords')
        ]
        if max_tfrecords_count >= 0:
            tfrecords = tfrecords[:max_tfrecords_count]
        selected.extend(tfrecords)
    return selected


def train_input_fn(train_dir, batch_size=100, max_tfrecords_count=-1):
    """Build shuffled training batches from TFRecord files in ``train_dir``.

    Walks ``train_dir`` recursively and, from every directory (presumably
    one per class -- confirm against the dataset layout), takes up to
    ``max_tfrecords_count`` files named ``training-*.tfrecords``.

    Args:
        train_dir: Root directory to search for TFRecord files.
        batch_size: Number of examples per returned batch.
        max_tfrecords_count: Maximum number of files to use from each
            directory. ``-1`` (or any negative value) means use all of them.

    Returns:
        An ``(images, labels)`` pair of batched tensors produced by
        ``tf.train.shuffle_batch``.

    Raises:
        ValueError: If no matching TFRecord file is found.
    """
    # NOTE: an earlier version truncated the list to its first entry
    # (``filename_list[0:1]``), which made ``max_tfrecords_count``
    # meaningless; every selected file is now fed to the queue.
    filename_list = _list_training_tfrecords(train_dir, max_tfrecords_count)
    if not filename_list:
        # string_input_producer rejects an empty list with a generic
        # ValueError; fail earlier with a message that names the directory.
        raise ValueError(
            'No training-*.tfrecords files found under {!r}'.format(train_dir)
        )

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(filename_list)
        image, label = read_and_decode(filename_queue)
        images, labels = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=2,
            # Keep 1000 examples buffered for mixing, plus headroom for a
            # few batches being assembled by the enqueue threads.
            capacity=1000 + 3 * batch_size,
            min_after_dequeue=1000,
        )
    return images, labels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment