Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Converts LabelImg XML annotation files (and the images they reference) into a single TFRecord file.
import tensorflow as tf
import os
from lxml import etree
from tf.models.research.object_detection.utils import dataset_util
from tf.models.research.object_detection.utils import label_map_util
# Command-line flags.
# `tf.app.flags` no longer exists in TensorFlow 2.x; the `tf.compat.v1` path is
# available in both late 1.x and 2.x releases and matches the
# `tf.compat.v1.app.run()` call used as this script's entry point.
flags = tf.compat.v1.app.flags
flags.DEFINE_string('annotations', '', 'Path to LabelImg XMLs')
flags.DEFINE_string('pbtxt', '', 'pbtxt mapping id to name')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS
def create_tf_example(example, label_map_dict):
    """Build a `tf.train.Example` from one LabelImg XML annotation file.

    Args:
        example: File name of the XML annotation, joined onto
            `FLAGS.annotations` to locate the file.
        label_map_dict: Mapping indexed by each object's `name` to obtain the
            integer class id stored in the example.

    Returns:
        A `tf.train.Example` containing the image size, file name, encoded
        image bytes, image format and, per annotated object, the bounding-box
        corners divided by the image width/height plus the class name and id.

    Raises:
        KeyError: If an object's name is missing from `label_map_dict`, or an
            expected field is missing from the parsed annotation.
    """
    xml_path = os.path.join(FLAGS.annotations, example)
    # Read the XML as bytes: lxml refuses already-decoded text that still
    # carries an encoding declaration, whereas bytes are always accepted.
    # The `with` block guarantees the handle is closed.
    with tf.io.gfile.GFile(xml_path, 'rb') as xml_file:
        xml = etree.fromstring(xml_file.read())
    data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

    height = int(data['size']['height'])  # Image height
    width = int(data['size']['width'])  # Image width
    image_name = data['filename']
    filename = image_name.encode('utf8')

    # The image is loaded from the `path` entry recorded in the annotation.
    with tf.io.gfile.GFile(data['path'], 'rb') as image_file:
        encoded_image_data = image_file.read()  # Encoded image bytes

    # Compare the extension case-insensitively so that e.g. "IMG.PNG" is not
    # mislabelled as JPEG.
    image_format = b'png' if image_name.lower().endswith('.png') else b'jpeg'

    # One entry per bounding box; coordinates are divided by the image
    # width/height to normalize them.
    xmins = []  # left x
    xmaxs = []  # right x
    ymins = []  # top y
    ymaxs = []  # bottom y
    classes_text = []  # class name, UTF-8 encoded
    classes = []  # integer class id

    # An annotation without any `object` entry produces empty box lists.
    for obj in data.get('object', []):
        box = obj['bndbox']
        xmins.append(float(box['xmin']) / width)
        ymins.append(float(box['ymin']) / height)
        xmaxs.append(float(box['xmax']) / width)
        ymaxs.append(float(box['ymax']) / height)
        classes_text.append(obj['name'].encode('utf8'))
        classes.append(label_map_dict[obj['name']])

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main(_):
    """Write one TFRecord containing an example per XML file in the input dir.

    Lists `FLAGS.annotations`, keeps the entries ending in ".xml", converts
    each with `create_tf_example` and writes the serialized examples to
    `FLAGS.output_path`.

    Args:
        _: Unused; present because the app runner passes the remaining
            command-line arguments to the entry point.
    """
    # Sorted so the record order is reproducible; os.listdir makes no
    # ordering guarantee.
    examples = sorted(
        name for name in os.listdir(FLAGS.annotations) if name.endswith('.xml')
    )

    # Load the label map before opening the output so that a bad label map
    # fails before any output is written.
    label_map_dict = label_map_util.get_label_map_dict(FLAGS.pbtxt)

    # The context manager closes (and flushes) the writer even if converting
    # one of the examples raises.
    with tf.io.TFRecordWriter(FLAGS.output_path) as writer:
        for example in examples:
            tf_example = create_tf_example(example, label_map_dict)
            writer.write(tf_example.SerializeToString())
# Script entry point: hand control to TensorFlow's app runner, which then
# invokes main().
if __name__ == "__main__":
    tf.compat.v1.app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment