Skip to content

Instantly share code, notes, and snippets.

@Denisolt
Last active August 27, 2020 02:49
Show Gist options
  • Save Denisolt/6de779cd8f7d8a97d624cf4cf025603e to your computer and use it in GitHub Desktop.
Save Denisolt/6de779cd8f7d8a97d624cf4cf025603e to your computer and use it in GitHub Desktop.
IBM PowerAI Vision Data Export to TFRecord Converter
"""
tensorflow 2.0.2
Usage: python convert.py --input_path=data/ --image_path=data/ --train_split=0.70 --shuffle_seed=1
"""
import xml.etree.ElementTree as ET
import glob
from tqdm import tqdm
from math import floor
import os
import io
import tensorflow as tf
import dataset_util
import numpy as np
from PIL import Image
# Command-line flags; main() reads them through flags.FLAGS.
flags = tf.compat.v1.flags
flags.DEFINE_string(name='input_path', default='', help='Path to input XML files')
flags.DEFINE_string(name='image_path', default='', help='Path to images')
flags.DEFINE_float(name='train_split', default=0.70, help='Train/Test Data Split')
flags.DEFINE_integer(name='shuffle_seed', default=None, help='Specify Random Seed')
# Class name -> integer id; populated on demand by map_class().
label_map = dict()
def train_test_split(files, test_size, seed=None):
    """Shuffle *files* and split them into two lists.

    Despite its name, ``test_size`` is the fraction of items placed in the
    FIRST returned list: ``main`` passes the ``--train_split`` flag here and
    unpacks the result as ``train, test``.

    Args:
        files: Sequence of items (e.g. XML paths) to split. It is copied,
            never shuffled in place.
        test_size: Fraction in ``[0, 1]`` of items that go into the first list.
        seed: Optional integer seed for a reproducible shuffle. ``0`` is a
            valid seed; ``None`` gives an unseeded shuffle.

    Returns:
        ``(first, second)`` lists whose concatenation is a permutation of
        *files*, with ``len(first) == floor(len(files) * test_size)``.

    Raises:
        ValueError: if ``test_size`` is outside ``[0, 1]``.
    """
    if not 0 <= test_size <= 1:
        raise ValueError(f'test_size must be within [0, 1], got {test_size!r}')
    shuffled = list(files)
    # A private RandomState yields the same permutation as seeding the global
    # NumPy RNG, without mutating global state. Unlike the old `if seed` test,
    # seed=0 is honoured here.
    np.random.RandomState(seed).shuffle(shuffled)
    split_index = floor(len(shuffled) * test_size)
    return shuffled[:split_index], shuffled[split_index:]
def map_class(obj_class):
    """Return the integer id for *obj_class*, allocating one on first sight.

    Ids start at 1 and are handed out in first-seen order. The assignment is
    stored in the module-level ``label_map`` dict.
    """
    # setdefault evaluates the candidate id before inserting, so a new class
    # receives len(label_map) + 1 and a known class keeps its existing id.
    return label_map.setdefault(obj_class, len(label_map) + 1)
def print_label_map(mapping=None):
    """Render a label map as ``.pbtxt`` text.

    Despite the name, nothing is printed; the text is returned.

    Args:
        mapping: ``{class_name: id}`` dict to render. Defaults to the
            module-level ``label_map`` that ``map_class`` fills in.

    Returns:
        One ``item { id: ... name: '...' }`` entry per label, in the
        mapping's iteration order, each followed by a blank line. Returns
        ``''`` for an empty mapping.
    """
    if mapping is None:
        mapping = label_map
    entries = []
    for label, identifier in mapping.items():
        # Escape backslashes first, then quotes, so a name like "o'neil" still
        # forms a valid single-quoted text-format string.
        name = str(label).replace('\\', '\\\\').replace("'", "\\'")
        entries.append(f"item {{\n id: {identifier}\n name: '{name}'\n}}\n\n")
    return ''.join(entries)
def converter(file):
    """Parse one PowerAI Vision XML export into a list of bounding-box dicts.

    Args:
        file: Path or file object accepted by ``xml.etree.ElementTree.parse``.

    Returns:
        One dict per ``<object>`` element with keys ``filename``, ``width``,
        ``height``, ``class``, ``xmin``, ``ymin``, ``xmax`` and ``ymax``.
        Width and height come from ``<size>``. The image file stem comes from
        the first object's ``<file_id>``. All numeric fields are ints.

    Raises:
        ValueError: if the document has no ``<object>`` element. Without one,
            the image file name (``file_id``) cannot be determined.
    """
    root = ET.parse(file).getroot()
    members = root.findall('object')
    if not members:
        raise ValueError(f'{file}: no <object> element found; cannot determine file_id')
    filename: str = members[0].find('file_id').text
    size = root.find('size')
    width: int = int(size.find('width').text)
    height: int = int(size.find('height').text)

    objects = []
    for member in members:
        bndbox = member.find('bndbox')
        objects.append({
            'filename': filename,
            'width': width,
            'height': height,
            'class': member.find('name').text,
            'xmin': int(bndbox.find('xmin').text),
            'ymin': int(bndbox.find('ymin').text),
            'xmax': int(bndbox.find('xmax').text),
            'ymax': int(bndbox.find('ymax').text),
        })
    return objects
def int64_feature(value):
    """Wrap a single integer as a Feature holding a one-element int64 list."""
    int64_values = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_values)
def int64_list_feature(value):
    """Wrap a sequence of integers as a Feature holding an int64 list."""
    int64_values = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_values)
def bytes_feature(value):
    """Wrap a single bytes value as a Feature holding a one-element bytes list."""
    byte_values = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_values)
def bytes_list_feature(value):
    """Wrap a sequence of bytes values as a Feature holding a bytes list."""
    byte_values = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=byte_values)
def float_list_feature(value):
    """Wrap a sequence of floats as a Feature holding a float list."""
    float_values = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_values)
def _find_image_extension(filename, image_path):
    """Return the extension (without the dot) of image *filename* in *image_path*.

    Looks for ``<image_path>/<filename>.<ext>``, ignores ``.xml`` files, and
    picks the alphabetically first match so the choice is deterministic.

    Raises:
        FileNotFoundError: if no matching non-XML file exists.
    """
    # Escape both parts so names containing '[', '*' or '?' are matched literally.
    pattern = os.path.join(glob.escape(image_path), glob.escape(filename) + '.*')
    extensions = []
    for path in sorted(glob.glob(pattern)):
        stem, ext = os.path.splitext(os.path.basename(path))
        # '<name>.*' also matches e.g. '<name>.v2.jpg'; keep exact stems only.
        if stem == filename and ext.lower() != '.xml':
            extensions.append(ext[1:])
    if not extensions:
        raise FileNotFoundError(f'No image file found for {filename!r} in {image_path!r}')
    return extensions[0]


def tf_record_gen(objects, image_path):
    """Build a ``tf.train.Example`` for one image and all of its boxes.

    Args:
        objects: Non-empty list of dicts as returned by ``converter`` for a
            single image. All entries share ``filename``/``width``/``height``.
        image_path: Directory containing the image files.

    Returns:
        A ``tf.train.Example`` whose feature keys follow the TF Object
        Detection API layout. Box coordinates are normalized by the decoded
        image size, and class ids are assigned through ``map_class``.

    Raises:
        ValueError: if *objects* is empty, or if the XML width/height differ
            from the decoded image size.
        FileNotFoundError: if no image file matches the annotation.
    """
    if not objects:
        raise ValueError('objects must contain at least one annotation')
    first = objects[0]
    image_format = _find_image_extension(first['filename'], image_path)
    filename = f"{first['filename']}.{image_format}"
    xml_width, xml_height = first['width'], first['height']

    with tf.io.gfile.GFile(os.path.join(image_path, filename), 'rb') as fid:
        encoded_image = fid.read()
    # Only the header is needed for .size; release the PIL handle immediately.
    with Image.open(io.BytesIO(encoded_image)) as image:
        width, height = image.size
    # If either dimension disagrees, the normalized boxes below would be wrong.
    if xml_width != width or xml_height != height:
        raise ValueError(
            'Width and Height on XML File are not consistent with the Actual Image'
            f' ({filename}: XML {xml_width}x{xml_height}, image {width}x{height})'
        )

    xmins = [obj['xmin'] / width for obj in objects]
    xmaxs = [obj['xmax'] / width for obj in objects]
    ymins = [obj['ymin'] / height for obj in objects]
    ymaxs = [obj['ymax'] / height for obj in objects]
    classes_text = [obj['class'].encode('utf8') for obj in objects]
    classes = [map_class(obj['class']) for obj in objects]
    encoded_filename = filename.encode('utf8')

    return tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(encoded_filename),
        'image/source_id': bytes_feature(encoded_filename),
        'image/encoded': bytes_feature(encoded_image),
        'image/format': bytes_feature(image_format.encode('utf8')),
        'image/object/bbox/xmin': float_list_feature(xmins),
        'image/object/bbox/xmax': float_list_feature(xmaxs),
        'image/object/bbox/ymin': float_list_feature(ymins),
        'image/object/bbox/ymax': float_list_feature(ymaxs),
        'image/object/class/text': bytes_list_feature(classes_text),
        'image/object/class/label': int64_list_feature(classes),
    }))
def main(_):
    """Convert every XML export under ``--input_path`` into TFRecords.

    Writes ``train.record`` and ``test.record`` to the current working
    directory. The files are split by ``--train_split`` and shuffled with
    ``--shuffle_seed``. Also writes ``labelmap.pbtxt`` there.

    Raises:
        FileNotFoundError: if ``--input_path`` contains no ``*.xml`` files.
    """
    image_path = flags.FLAGS.image_path
    input_path = flags.FLAGS.input_path
    # os.path.join avoids globbing '/*.xml' when input_path is empty. Sorting
    # makes a given --shuffle_seed produce the same split on every machine,
    # since raw glob order depends on the filesystem.
    xml_files = sorted(glob.glob(os.path.join(input_path, '*.xml')))
    if not xml_files:
        raise FileNotFoundError(f'No XML files found in {input_path!r}')
    train, test = train_test_split(xml_files, flags.FLAGS.train_split, flags.FLAGS.shuffle_seed)

    for split_name, split_files in (('train', train), ('test', test)):
        # The context manager closes (and flushes) the writer even if a file fails.
        with tf.io.TFRecordWriter(f'{split_name}.record') as writer:
            for xml_file in tqdm(split_files):
                objects = converter(xml_file)
                tf_record = tf_record_gen(objects, image_path)
                writer.write(tf_record.SerializeToString())

    with open('labelmap.pbtxt', 'w', encoding='utf-8') as label_file:
        print(print_label_map(), file=label_file)
    print('Successfully created the Train and Test TFRecords and Labelmap')
if __name__ == '__main__':
    # Hand control to TensorFlow's app runner, which parses the flags defined
    # above and then calls main with the remaining argv.
    tf.compat.v1.app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment