Last active
August 27, 2020 02:49
-
-
Save Denisolt/6de779cd8f7d8a97d624cf4cf025603e to your computer and use it in GitHub Desktop.
IBM PowerAI Vision Data Export to TFRecord Converter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
tensorflow 2.0.2 | |
Usage: python convert.py --input_path=data/ --image_path=data/ --train_split=0.70 --shuffle_seed=1 | |
""" | |
import xml.etree.ElementTree as ET | |
import glob | |
from tqdm import tqdm | |
from math import floor | |
import os | |
import io | |
import tensorflow as tf | |
import dataset_util | |
import numpy as np | |
from PIL import Image | |
# Command-line flags, declared through TensorFlow's v1-compatible flags module.
flags = tf.compat.v1.flags
flags.DEFINE_string(name='input_path', default='', help='Path to input XML files')
flags.DEFINE_string(name='image_path', default='', help='Path to images')
flags.DEFINE_float(name='train_split', default=0.70, help='Train/Test Data Split')
flags.DEFINE_integer(name='shuffle_seed', default=None, help='Specify Random Seed')

# Class name -> integer id; populated lazily by map_class() as classes are encountered.
label_map = {}
def train_test_split(files, test_size, seed=None):
    """Shuffle *files* in place and split them into two lists.

    Args:
        files: Mutable sequence (e.g. a list of paths). It is shuffled in
            place, so the caller's list is reordered.
        test_size: Fraction (0.0-1.0) of the shuffled items placed in the
            FIRST returned list. Despite the name, the caller in this file
            passes the train fraction and unpacks the result as
            ``train, test``.
        seed: Optional integer seed for NumPy's global random state. Any
            integer, including 0, makes the shuffle reproducible; ``None``
            leaves the global state untouched.

    Returns:
        A tuple ``(first, second)`` where ``first`` holds the leading
        ``floor(len(files) * test_size)`` shuffled items and ``second``
        holds the rest.
    """
    # Compare against None explicitly: a seed of 0 is valid and must not be skipped.
    if seed is not None:
        np.random.seed(seed)
    np.random.shuffle(files)
    split_index = floor(len(files) * test_size)
    return files[:split_index], files[split_index:]
def map_class(obj_class):
    """Return the integer id for *obj_class*, registering it on first use.

    Ids are handed out sequentially starting at 1, in the order classes are
    first seen, and are stored in the module-level ``label_map`` dict.
    """
    try:
        return label_map[obj_class]
    except KeyError:
        # First time this class is seen: give it the next sequential id.
        class_id = len(label_map) + 1
        label_map[obj_class] = class_id
        return class_id
def print_label_map():
    """Render the module-level ``label_map`` as label-map text and return it.

    Despite the name, nothing is printed here: the function returns one
    ``item { id: ... name: '...' }`` entry per registered class, each
    followed by a blank line, in the dict's insertion order.
    """
    entries = [
        f"item {{\n id: {identifier}\n name: '{label}'\n}}\n\n"
        for label, identifier in label_map.items()
    ]
    return ''.join(entries)
def converter(file):
    """Parse one annotation XML file into a list of per-object dicts.

    Args:
        file: Path or file object accepted by ``ElementTree.parse``.

    Returns:
        One dict per ``<object>`` element with the keys ``filename``,
        ``width``, ``height``, ``class``, ``xmin``, ``ymin``, ``xmax`` and
        ``ymax``. ``filename`` is the ``<file_id>`` text of the first
        object, and ``width``/``height`` come from the ``<size>`` element;
        all three are repeated on every dict. Sizes and box coordinates are
        converted to ``int``.

    Raises:
        ValueError: If the document contains no ``<object>`` element, so no
            file id can be determined.
    """
    root = ET.parse(file).getroot()
    members = root.findall('object')
    if not members:
        # Without an object there is no <file_id>; fail with a clear message
        # instead of an AttributeError on None.
        raise ValueError(f'No <object> elements found in annotation file: {file}')

    filename = members[0].find('file_id').text
    size = root.find('size')
    # Element text is a string; convert once rather than per object.
    width = int(size.find('width').text)
    height = int(size.find('height').text)

    objects = []
    for member in members:
        bndbox = member.find('bndbox')
        objects.append({
            'filename': filename,
            'width': width,
            'height': height,
            'class': member.find('name').text,
            'xmin': int(bndbox.find('xmin').text),
            'ymin': int(bndbox.find('ymin').text),
            'xmax': int(bndbox.find('xmax').text),
            'ymax': int(bndbox.find('ymax').text),
        })
    return objects
def int64_feature(value):
    """Wrap a single integer in a ``Feature`` holding a one-element ``Int64List``."""
    int64_list = tf.compat.v1.train.Int64List(value=[value])
    return tf.compat.v1.train.Feature(int64_list=int64_list)
def int64_list_feature(value):
    """Wrap an iterable of integers in a ``Feature`` holding an ``Int64List``."""
    int64_list = tf.compat.v1.train.Int64List(value=value)
    return tf.compat.v1.train.Feature(int64_list=int64_list)
def bytes_feature(value):
    """Wrap a single bytes value in a ``Feature`` holding a one-element ``BytesList``."""
    bytes_list = tf.compat.v1.train.BytesList(value=[value])
    return tf.compat.v1.train.Feature(bytes_list=bytes_list)
def bytes_list_feature(value):
    """Wrap an iterable of bytes values in a ``Feature`` holding a ``BytesList``."""
    bytes_list = tf.compat.v1.train.BytesList(value=value)
    return tf.compat.v1.train.Feature(bytes_list=bytes_list)
def float_list_feature(value):
    """Wrap an iterable of floats in a ``Feature`` holding a ``FloatList``."""
    float_list = tf.compat.v1.train.FloatList(value=value)
    return tf.compat.v1.train.Feature(float_list=float_list)
def tf_record_gen(objects, image_path):
    """Build a ``tf.train.Example`` for one image and its annotated objects.

    Args:
        objects: Non-empty list of dicts as produced by ``converter``; every
            dict describes one bounding box on the same image and carries
            ``filename`` (without extension), ``width``, ``height``,
            ``class``, ``xmin``, ``ymin``, ``xmax`` and ``ymax``.
        image_path: Directory that contains the image file.

    Returns:
        A ``tf.train.Example`` with the encoded image bytes, its size,
        filename and format, plus per-object box coordinates (normalised by
        the actual image width/height), class names and class ids.

    Raises:
        ValueError: If *objects* is empty, or if the width or height recorded
            in the XML differs from the actual image size.
        FileNotFoundError: If no non-XML file named ``<filename>.*`` exists
            in *image_path*.
    """
    if not objects:
        raise ValueError('Cannot build a TFRecord example from an empty object list')

    filename = objects[0]['filename']
    # Build the pattern with os.path.join so it works whether or not
    # image_path ends with a separator; escape the fixed part so characters
    # such as '[' in a path are not treated as wildcards.
    pattern = glob.escape(os.path.join(image_path, filename)) + '.*'
    extensions = [
        match.rsplit('.', 1)[-1]
        for match in glob.glob(pattern)
        # The annotation XML may sit next to the image under the same stem.
        if match.rsplit('.', 1)[-1] != 'xml'
    ]
    if not extensions:
        raise FileNotFoundError(
            f'No image file found for {filename!r} in {image_path!r}')
    image_format = extensions[0]
    filename += f'.{image_format}'

    # Size recorded in the annotation, checked against the real image below.
    w = objects[0]['width']
    h = objects[0]['height']

    with tf.io.gfile.GFile(os.path.join(image_path, filename), 'rb') as fid:
        encoded_image = fid.read()
    with Image.open(io.BytesIO(encoded_image)) as image:
        width, height = image.size

    # Either dimension disagreeing means the boxes cannot be trusted.
    if w != width or h != height:
        raise ValueError('Width and Height on XML File are not consistent with the Actual Image')

    image_format = image_format.encode('utf8')
    filename = filename.encode('utf8')

    # Box coordinates are stored as fractions of the image size.
    xmins = [obj['xmin'] / width for obj in objects]
    xmaxs = [obj['xmax'] / width for obj in objects]
    ymins = [obj['ymin'] / height for obj in objects]
    ymaxs = [obj['ymax'] / height for obj in objects]
    classes_text = [obj['class'].encode('utf8') for obj in objects]
    classes = [map_class(obj['class']) for obj in objects]

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(filename),
        'image/source_id': bytes_feature(filename),
        'image/encoded': bytes_feature(encoded_image),
        'image/format': bytes_feature(image_format),
        'image/object/bbox/xmin': float_list_feature(xmins),
        'image/object/bbox/xmax': float_list_feature(xmaxs),
        'image/object/bbox/ymin': float_list_feature(ymins),
        'image/object/bbox/ymax': float_list_feature(ymaxs),
        'image/object/class/text': bytes_list_feature(classes_text),
        'image/object/class/label': int64_list_feature(classes),
    }))
    return tf_example
def main(_):
    """Convert every annotation XML under ``--input_path`` into TFRecords.

    Splits the XML files into two groups using ``--train_split`` and
    ``--shuffle_seed``, writes ``train.record`` and ``test.record`` in the
    current working directory, then writes the collected class ids to
    ``labelmap.pbtxt``.

    Args:
        _: Unused positional argument supplied by the app runner.
    """
    image_path = flags.FLAGS.image_path
    # os.path.join keeps an empty --input_path relative to the current
    # directory (a plain '/*.xml' would point at the filesystem root).
    # Sorting gives the shuffle a stable starting order, so a fixed
    # --shuffle_seed yields the same split on every filesystem.
    xml_files = sorted(glob.glob(os.path.join(flags.FLAGS.input_path, '*.xml')))
    train, test = train_test_split(
        xml_files, flags.FLAGS.train_split, flags.FLAGS.shuffle_seed)

    for split_name, split_files in (('train', train), ('test', test)):
        # The context manager closes the writer even if a conversion fails.
        with tf.io.TFRecordWriter(f'{split_name}.record') as writer:
            for file in tqdm(split_files):
                objects = converter(file)
                tf_record = tf_record_gen(objects, image_path)
                writer.write(tf_record.SerializeToString())

    # The label map is only complete after every file has been converted.
    with open('labelmap.pbtxt', 'w') as label_map_file:
        print(print_label_map(), file=label_map_file)
    print('Successfully created the Train and Test TFRecords and Labelmap')
if __name__ == '__main__':
    # Script entry point: hand main() to TensorFlow's v1 app runner.
    tf.compat.v1.app.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment