tf record example: write a small corpus to a TFRecord file and read it back two ways, with the queue-based TFRecordReader and with the tf.data API (TensorFlow 1.x)
import tensorflow as tf
import numpy as np
import sys
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
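# quick illustration (not part of the pipeline below): each helper wraps a raw
# Python value in a tf.train.Feature proto, e.g.
#   _int64_feature(2)            -> Feature(int64_list=Int64List(value=[2]))
#   _bytes_feature(b"\x01\x02")  -> Feature(bytes_list=BytesList(value=[b"\x01\x02"]))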
## data
# corpus is stored as int32 so that tf.decode_raw(..., tf.int32) below recovers the values
corpus = np.array([[[1,2],[0,1]],[[2,4],[5,6]],[[4,2],[0,0]]], dtype=np.int32)
seq_len = np.array([2,2,1], dtype=np.int64)
## write tfrecord file
filename = 'test2.tfrecords'  # address to save the TFRecords file
writer = tf.python_io.TFRecordWriter(filename)

for i in range(len(corpus)):
    feature = {"corpus": _bytes_feature(tf.compat.as_bytes(corpus[i].tostring())),
               "seq_len": _int64_feature(seq_len[i])}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())

writer.close()
sys.stdout.flush()
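# optional sanity check (a minimal sketch, assuming the writer above succeeded):
# iterate over the raw serialized records to confirm that three examples were written
num_records = sum(1 for _ in tf.python_io.tf_record_iterator(filename))
print("records written to %s: %d" % (filename, num_records))  # expected: 3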
## read tfrecord file
corpus_list = []
seq_list = []

with tf.Session() as sess:
    feature = {'corpus': tf.FixedLenFeature([], tf.string),
               'seq_len': tf.FixedLenFeature([], tf.int64)}
    # Create a list of filenames and pass it to a queue (read back the file written above)
    filename_queue = tf.train.string_input_producer([filename], num_epochs=1)
    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)
    # Convert the corpus data from raw bytes back to int32 numbers
    corpus = tf.decode_raw(features['corpus'], tf.int32)
    # Cast the sequence-length data into int32
    seq_len = tf.cast(features['seq_len'], tf.int32)
    corpus = tf.reshape(corpus, [2, 2])
    # Create batches by randomly shuffling tensors
    corpus_, seq_len_ = tf.train.shuffle_batch([corpus, seq_len], batch_size=1, capacity=30, num_threads=1, min_after_dequeue=10)

    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init)
    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(2):
        test_corpus, test_seq_len = sess.run([corpus_, seq_len_])
        corpus_list.append(test_corpus)
        seq_list.append(test_seq_len)

    # Stop the queue runners and wait for their threads to finish
    coord.request_stop()
    coord.join(threads)
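# quick look at what the queue pipeline returned (assuming the loop above ran):
# with batch_size=1, each fetched corpus has shape (1, 2, 2) and each seq_len has shape (1,)
print(corpus_list[0].shape, seq_list[0])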
###### using tf.data iterator ######
## read tfrecord file
def from_tfrecord(serialized):
    features = tf.parse_single_example(
        serialized=serialized,
        features={
            "corpus": tf.FixedLenFeature([], tf.string),
            "seq_len": tf.FixedLenFeature([], tf.int64)
        }
    )
    corpus = tf.reshape(tf.decode_raw(features['corpus'], tf.int32), [2, 2])
    seq_len = tf.cast(features['seq_len'], tf.int32)
    return corpus, seq_len
dataset = tf.data.TFRecordDataset(
    filenames=[filename]).map(from_tfrecord)  # read back the file written above

generated_corpus, generated_seq_len = (dataset
                                       .batch(1)
                                       .repeat()
                                       .make_one_shot_iterator()
                                       .get_next())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c_list, s_list = [], []

    for i in range(6):
        c, s = sess.run([generated_corpus, generated_seq_len])
        c_list.append(c)
        s_list.append(s)
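# quick check of the tf.data results (assuming the loop above ran): with batch(1)
# and repeat(), six iterations cycle through the three stored examples twice
print(len(c_list), c_list[0].shape)  # expected: 6 (1, 2, 2)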