Skip to content

Instantly share code, notes, and snippets.

@MITsVision
Created June 18, 2020 07:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MITsVision/3d99ed1c4426c4cca392706e7bae98fb to your computer and use it in GitHub Desktop.
Save MITsVision/3d99ed1c4426c4cca392706e7bae98fb to your computer and use it in GitHub Desktop.
import tensorflow as tf
# Relative path (resolved against the current working directory) of the
# TFRecord file that is passed to get_all_records() at the bottom of this file.
train_record = "tf-records/coco_train.record-00000-of-00001"
def read_and_decode(filename_queue, decode_image=False):
    """Build graph ops that read and parse one tf.train.Example from a queue.

    Uses a ``tf.TFRecordReader`` to pull the next serialized record from
    ``filename_queue`` and parses it with ``tf.parse_single_example`` using
    the feature spec below. This is TF1 graph-mode code: the returned values
    are tensors that must be evaluated with ``Session.run`` while queue
    runners are active.

    Args:
        filename_queue: A queue of TFRecord filenames, e.g. the result of
            ``tf.train.string_input_producer``.
        decode_image: If False (the default, the original behaviour), the
            ``image/encoded`` bytes are reinterpreted with ``tf.decode_raw``,
            giving a 1-D uint8 tensor of the *stored file bytes* (not pixels,
            if the record holds an encoded JPEG/PNG). If True, the bytes are
            decoded with ``tf.image.decode_image(..., channels=3)`` into a
            uint8 image tensor.

    Returns:
        A list ``[image, xmin, ymin, xmax, ymax, label, height, width]``:
        ``image`` as described above; ``xmin``/``ymin``/``xmax``/``ymax`` are
        float32 ``SparseTensor``s (one entry per object); ``label`` is a
        string ``SparseTensor``; ``height``/``width`` are int32 scalars.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # (key, value); key unused
    # FixedLenFeature entries have no default, so every record must contain
    # them; VarLenFeature entries are variable-length lists returned as
    # SparseTensors.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/filename': tf.FixedLenFeature([], tf.string),
            'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
            'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
            'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
            'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
            'image/object/class/text': tf.VarLenFeature(tf.string),
        })

    encoded = features['image/encoded']
    if decode_image:
        # Decodes the compressed bytes into pixels. NOTE(review): assumes the
        # stored bytes are a format decode_image understands (JPEG/PNG/GIF/BMP).
        image = tf.image.decode_image(encoded, channels=3)
    else:
        image = tf.decode_raw(encoded, tf.uint8)

    # The parsed dtypes already match the spec, so no casts are needed here.
    xmin = features['image/object/bbox/xmin']
    ymin = features['image/object/bbox/ymin']
    xmax = features['image/object/bbox/xmax']
    ymax = features['image/object/bbox/ymax']
    label = features['image/object/class/text']
    # int64 -> int32 narrowing, kept from the original interface.
    height = tf.cast(features['image/height'], tf.int32)
    width = tf.cast(features['image/width'], tf.int32)
    return [image, xmin, ymin, xmax, ymax, label, height, width]
def get_all_records(FILE, num_records=1):
    """Read records from a TFRecord file and print each image tensor's length.

    Builds an input queue over ``FILE``, parses records with
    ``read_and_decode`` and, for each of ``num_records`` records, prints the
    length of the fetched ``image`` value (with ``read_and_decode``'s default
    ``decode_raw`` path this is the number of stored image bytes).

    Args:
        FILE: Path to a TFRecord file. The name is kept for existing callers.
        num_records: Number of records to read (default 1, matching the
            original hard-coded loop).
    """
    with tf.Session() as sess:
        filename_queue = tf.train.string_input_producer([FILE])
        # read_and_decode returns [image, xmin, ymin, xmax, ymax, label,
        # height, width]; index 0 is the image tensor.
        parsed = read_and_decode(filename_queue)
        image = parsed[0]
        # global_variables_initializer replaces the deprecated
        # initialize_all_variables.
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(num_records):
                # Fetch the image tensor itself. The original fetched
                # [parsed_list], so example[0] was the 8-item list and
                # len() always printed 8.
                example = sess.run(image)
                print(len(example))
        finally:
            # Always stop and join the queue-runner threads, even if
            # sess.run raised, so they do not outlive the session.
            coord.request_stop()
            coord.join(threads)
# Run the reader only when executed as a script, not on import.
if __name__ == "__main__":
    get_all_records(train_record)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment