Last active
April 29, 2021 15:50
-
-
Save MSWon/124e0d3aa22e38b17c347afdcbf5bcc5 to your computer and use it in GitHub Desktop.
tf record example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
import sys | |
def _bytes_feature(value):
    """Wrap a single value in a tf.train.Feature holding a one-element BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap a single value in a tf.train.Feature holding a one-element FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap a single value in a tf.train.Feature holding a one-element Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
## data
# Three 2x2 integer matrices, one per record. The dtype is pinned to int32
# because the reading code in this file decodes the raw bytes with tf.int32 and
# reshapes each record to [2, 2]; NumPy's platform-dependent default integer
# (int64 on most Linux/macOS builds) would produce 8 int32 values per record
# and make that reshape fail.
corpus = np.array([[[1, 2], [0, 1]], [[2, 4], [5, 6]], [[4, 2], [0, 0]]],
                  dtype=np.int32)
# One int64 scalar per record, stored alongside the matrix.
seq_len = np.array([2, 2, 1], dtype=np.int64)

## write tfrecord file
filename = 'test2.tfrecords'  # address to save the TFRecords file
# The context manager closes the writer even if building a record raises.
with tf.python_io.TFRecordWriter(filename) as writer:
    for matrix, length in zip(corpus, seq_len):
        # tobytes() replaces the deprecated ndarray.tostring(); it already
        # returns bytes, so no extra as_bytes conversion is needed.
        feature = {"corpus": _bytes_feature(matrix.tobytes()),
                   "seq_len": _int64_feature(length)}
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())
sys.stdout.flush()
## read tfrecord file (queue-runner based input pipeline)
corpus_list = []
seq_list = []
with tf.Session() as sess:
    # Schema of one serialized Example: raw matrix bytes plus an int64 scalar.
    feature = {'corpus': tf.FixedLenFeature([], tf.string),
               'seq_len': tf.FixedLenFeature([], tf.int64)}
    # Create a list of filenames and pass it to a queue
    # NOTE(review): this reads 'test.tfrecords', while the writer section of
    # this script saves to 'test2.tfrecords' -- presumably the file comes from
    # an earlier run; confirm it exists before running.
    filename_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=1)
    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)
    # Convert the corpus bytes back into int32 numbers
    corpus = tf.decode_raw(features['corpus'], tf.int32)
    # Cast the sequence length to int32
    seq_len = tf.cast(features['seq_len'], tf.int32)
    # decode_raw yields a flat vector; restore the 2x2 layout
    corpus = tf.reshape(corpus, [2, 2])
    # Creates batches by randomly shuffling tensors
    corpus_, seq_len_ = tf.train.shuffle_batch(
        [corpus, seq_len], batch_size=1, capacity=30, num_threads=1,
        min_after_dequeue=10)
    # Local variables must be initialized too: num_epochs=1 keeps its epoch
    # counter in a local variable.
    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init)
    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(2):
            test_corpus, test_seq_len = sess.run([corpus_, seq_len_])
            corpus_list.append(test_corpus)
            seq_list.append(test_seq_len)
    finally:
        # Stop and join the queue-runner threads while the session is still
        # open so they do not outlive it.
        coord.request_stop()
        coord.join(threads)
###### using tf.data iterator ######
## read tfrecord file
def from_tfrecord(serialized):
    """Parse one serialized tf.train.Example into a (corpus, seq_len) pair.

    The "corpus" bytes are decoded as int32 and reshaped to [2, 2]; the
    "seq_len" int64 scalar is cast to int32.
    """
    schema = {
        "corpus": tf.FixedLenFeature([], tf.string),
        "seq_len": tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized=serialized, features=schema)
    # decode_raw returns a flat vector, so the 2x2 layout is restored here.
    flat_values = tf.decode_raw(parsed['corpus'], tf.int32)
    corpus = tf.reshape(flat_values, [2, 2])
    seq_len = tf.cast(parsed['seq_len'], tf.int32)
    return corpus, seq_len
# Parse every record of both TFRecord files with from_tfrecord.
dataset = tf.data.TFRecordDataset(
    filenames=['test.tfrecords', 'test2.tfrecords']).map(from_tfrecord)

# Batches of one record, repeated indefinitely, consumed through a one-shot
# iterator whose get_next() tensors are fetched below.
single_batches = dataset.batch(1)
endless_batches = single_batches.repeat()
iterator = endless_batches.make_one_shot_iterator()
generated_corpus, generated_seq_len = iterator.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c_list, s_list = [], []
    # Pull six batches from the repeating dataset.
    for i in range(6):
        fetched = sess.run([generated_corpus, generated_seq_len])
        c, s = fetched
        c_list.append(c)
        s_list.append(s)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment