preprocess VOC data
"""
Script that executes the first stage of the pipeline for VOC data (filtering)
"""
import argparse
import xml.etree.ElementTree as ET
import os
import random
# Only 'training_split' is used below; 'validation' and 'training' are unused placeholders.
DEFAULT_CONFIG = {
    'validation': [],
    'training': [],
    'training_split': 0.8
}


def get_attributes(obj: ET.Element):
    d = {}
    for a in obj.find("attributes").findall("attribute"):
        a_key = a.find("name").text
        a_val = a.find("value").text
        d[a_key] = a_val
    return d
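# Note: the <attributes>/<attribute>/<name>/<value> nesting parsed above is not part of
# plain Pascal VOC; it matches what annotation tools such as CVAT add to their VOC
# exports (an assumption -- adapt get_attributes if your XML differs). For an object
# annotated like the hypothetical example below, get_attributes returns {"Type": "car"},
# and process_annotation then renames the object to "vehiclecar":
#
#   <object>
#     <name>vehicle</name>
#     <attributes>
#       <attribute>
#         <name>Type</name>
#         <value>car</value>
#       </attribute>
#     </attributes>
#     <bndbox>...</bndbox>
#   </object>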


def process_annotation(ann_file, labels, img_dir, out_dir, dumponly):
    """Process an annotation, filtering out unwanted labels, modifying/adding values where
    needed, and storing it in the suitable output directory.

    Args:
        ann_file: path to the input annotation file
        labels: set of labels to keep in the annotation (empty means keep everything)
        img_dir: path to the directory where the corresponding image is stored (the image
            filename is read from the annotation itself)
        out_dir: path to the directory where the processed annotation is saved
        dumponly: if True, only print the label names found and do not write anything

    Returns: None
    """
    try:
        tree = ET.parse(ann_file)
    except ET.ParseError:
        print(f"0 failed parsing {ann_file}")
        return
    root = tree.getroot()
    # find all 'objects' that need to be filtered -- if all objects need to be filtered,
    # return and skip this annotation
    objects = root.findall("object")
    for o in objects:
        nm = o.find("name")
        attrs = get_attributes(o)
        if 'Type' in attrs:
            nm.text = nm.text + attrs['Type']
        if dumponly:
            print(nm.text)
    if dumponly:
        return
    if len(labels) > 0:
        filtered = list(
            filter(lambda x: x.find("name").text not in labels, objects)
        )  # objects that need to be removed from the tree
        if len(objects) == len(filtered):
            return
        for ann_object in filtered:
            root.remove(ann_object)
    # path to the image referenced by the annotation
    img_file = os.path.join(img_dir, root.find("filename").text)
    # if the image doesn't exist, return (stop processing)
    if not os.path.isfile(img_file):
        print(f"0 image does not exist {ann_file} --> try to find {img_file}")
        return
    # otherwise the image does exist -- modify the annotation so it contains a path to this image
    ann_id = os.path.basename(ann_file)
    # change annotation image directory
    folder = root.find("folder")
    folder.text = ""
    root.find("filename").text = img_file
    print(f"1 writing file {ann_file}")
    # save annotation to output directory
    tree.write(os.path.join(out_dir, ann_id))
    return
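# Minimal sketch of calling process_annotation directly on one file (the paths and the
# label name below are hypothetical; adjust them to your dataset layout):
#
#   process_annotation(
#       ann_file="raw/dataset/Annotations/img_0001.xml",
#       labels={"vehiclecar"},
#       img_dir="raw/dataset/JPEGImages",
#       out_dir="processed/dataset/train_annotations",
#       dumponly=False,
#   )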


def preprocess_voc_data(root_data_dir, labels, out_dir, dumponly):
    """Preprocess a VOC data set.

    Args:
        root_data_dir: path to the base data directory (e.g. raw/dataset)
        labels: labels to keep (pass an empty list if no filtering is necessary)
        out_dir: path to the output directory (e.g. processed/dataset)
        dumponly: if True, only print the label names found and do not write anything
    """
    voc_images = os.path.join(root_data_dir, "JPEGImages")
    annotation_dir = os.path.join(root_data_dir, "Annotations")
    # create output directories
    train_anns_dir = os.path.join(out_dir, "train_annotations")
    val_anns_dir = os.path.join(out_dir, "val_annotations")
    test_anns_dir = os.path.join(out_dir, "test_annotations")
    if not dumponly:
        os.makedirs(train_anns_dir, exist_ok=True)
        os.makedirs(val_anns_dir, exist_ok=True)
        os.makedirs(test_anns_dir, exist_ok=True)
    # iterate over all annotation files and process them;
    # the annotation dir can contain both subdirectories and annotation files
    for file in os.listdir(annotation_dir):
        file = os.path.join(annotation_dir, file)
        if os.path.isdir(file):
            for ann_file in os.listdir(file):
                ann_file = os.path.join(file, ann_file)
                rand = random.random()
                ann_dst = train_anns_dir if rand < DEFAULT_CONFIG["training_split"] else val_anns_dir
                process_annotation(ann_file, labels, voc_images, ann_dst, dumponly)
        else:
            ann_file = file
            rand = random.random()
            ann_dst = train_anns_dir if rand < DEFAULT_CONFIG["training_split"] else val_anns_dir
            process_annotation(ann_file, labels, voc_images, ann_dst, dumponly)
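# Expected on-disk layout, as read and written by preprocess_voc_data (a sketch; the
# raw/ and processed/ directory names are hypothetical, the rest comes from the code):
#
#   raw/dataset/                      <-- root_data_dir
#     JPEGImages/                     images referenced by <filename> in each annotation
#     Annotations/                    VOC XML files, flat or one level of subdirectories
#   processed/dataset/                <-- out_dir
#     train_annotations/              ~80% of the processed annotations (training_split)
#     val_annotations/                the remaining ~20%
#     test_annotations/               created but not populated by this script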


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root_data_dir",
        help="root data directory -- any specified file/directory path is relative to this "
        "root data directory",
        required=True,
    )
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--target_labels", required=True, nargs="*")
    parser.add_argument(
        "--dumponly", action="store_true", help="just dump all labels found, without writing output"
    )
    args = parser.parse_args()
    data_root = os.path.abspath(args.root_data_dir)
    target_labels = args.target_labels
    preprocess_voc_data(
        root_data_dir=data_root,
        labels=target_labels,
        out_dir=os.path.join(data_root, args.output_dir),
        dumponly=args.dumponly,
    )
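# Example invocations (the script filename, paths and label names are hypothetical):
#
#   # dump every label found in the annotations without writing anything:
#   python preprocess_voc.py --root_data_dir raw/dataset \
#       --output_dir processed/dataset --dumponly --target_labels
#
#   # keep only the listed labels and write the random ~80/20 train/val split:
#   python preprocess_voc.py --root_data_dir raw/dataset \
#       --output_dir processed/dataset --target_labels vehiclecar person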