Skip to content

Instantly share code, notes, and snippets.

@dragonlock2
Created August 19, 2019 09:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dragonlock2/f2486feb42fa211708a53ec85e45a74b to your computer and use it in GitHub Desktop.
Save dragonlock2/f2486feb42fa211708a53ec85e45a74b to your computer and use it in GitHub Desktop.
Converts LabelImg (Pascal VOC-style XML) annotation files into a single TFRecord file.
import tensorflow as tf
import os
from lxml import etree
from tf.models.research.object_detection.utils import dataset_util
from tf.models.research.object_detection.utils import label_map_util
# Command-line flags. tf.compat.v1.app.flags is used instead of tf.app.flags,
# which no longer exists in TF 2.x; the rest of this script already uses the
# TF2-compatible API (tf.io.gfile, tf.io.TFRecordWriter, tf.compat.v1.app.run).
flags = tf.compat.v1.app.flags
flags.DEFINE_string('annotations', '', 'Path to LabelImg XMLs')
flags.DEFINE_string('pbtxt', '', 'pbtxt mapping id to name')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS
def create_tf_example(example, label_map_dict):
    """Build a tf.train.Example from one LabelImg XML annotation file.

    Args:
      example: File name of the XML annotation, relative to
        ``FLAGS.annotations``.
      label_map_dict: Mapping from class name (the XML ``<name>`` text) to
        integer class id, as returned by ``label_map_util.get_label_map_dict``.

    Returns:
      A ``tf.train.Example`` containing the encoded image bytes, image size,
      file name, and one normalized bounding box plus class text/id per
      ``<object>`` element.

    Raises:
      KeyError: if an object's class name is missing from ``label_map_dict``.
    """
    xml_path = os.path.join(FLAGS.annotations, example)
    # Read bytes rather than text: lxml refuses to parse a Python str that
    # contains an XML encoding declaration, but accepts bytes with one.
    with tf.io.gfile.GFile(xml_path, 'rb') as fid:
        xml = etree.fromstring(fid.read())
    data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

    height = int(data['size']['height'])  # Image height in pixels.
    width = int(data['size']['width'])  # Image width in pixels.
    filename = data['filename'].encode('utf8')

    # The recorded 'path' may be stale (e.g. the dataset was moved) or absent;
    # in that case look for the image next to the XML file instead.
    image_path = data.get('path')
    if not image_path or not tf.io.gfile.exists(image_path):
        image_path = os.path.join(os.path.dirname(xml_path), data['filename'])
    with tf.io.gfile.GFile(image_path, 'rb') as fid:
        encoded_image_data = fid.read()

    # Case-insensitive so that e.g. 'IMG_01.PNG' is not labelled as JPEG.
    if data['filename'].lower().endswith('.png'):
        image_format = b'png'
    else:
        image_format = b'jpeg'

    # Box coordinates are normalized to [0, 1] by the image size; one entry
    # per object in every list.
    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text, classes = [], []
    for obj in data.get('object', []):
        name = obj['name']
        if name not in label_map_dict:
            raise KeyError(
                'Class {!r} in {} is not in the label map'.format(name, xml_path))
        box = obj['bndbox']
        xmins.append(float(box['xmin']) / width)
        ymins.append(float(box['ymin']) / height)
        xmaxs.append(float(box['xmax']) / width)
        ymaxs.append(float(box['ymax']) / height)
        classes_text.append(name.encode('utf8'))
        classes.append(label_map_dict[name])

    return tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
def main(_):
    """Convert every ``*.xml`` file in ``FLAGS.annotations`` into one TFRecord.

    Records are written to ``FLAGS.output_path`` using the class-name-to-id
    mapping read from ``FLAGS.pbtxt``.

    Args:
      _: Remaining command-line arguments supplied by ``app.run``; unused.
    """
    # os.listdir returns entries in arbitrary order; sort them so repeated
    # runs produce the same record order.
    examples = sorted(
        fi for fi in os.listdir(FLAGS.annotations) if fi.lower().endswith('.xml'))
    label_map_dict = label_map_util.get_label_map_dict(FLAGS.pbtxt)
    # The context manager closes the writer even if an example fails to convert.
    with tf.io.TFRecordWriter(FLAGS.output_path) as writer:
        for example in examples:
            tf_example = create_tf_example(example, label_map_dict)
            writer.write(tf_example.SerializeToString())
if __name__ == '__main__':
    # Parse the command-line flags, then call main() with the leftover argv.
    tf.compat.v1.app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment