Last active
July 9, 2018 02:12
-
-
Save pranjalsatija/485eed15f50ad0d8ca025067f84fc3f0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import turicreate as tc | |
import xml.etree.cElementTree as XMLParser | |
def get_bounding_box(xml):
    """Return the center and size of the bounding box described by ``xml``.

    ``xml`` is an element with a ``bndbox`` child, which in turn holds
    ``xmin``, ``xmax``, ``ymin`` and ``ymax`` children whose text parses as a
    number.

    Returns a dictionary with the float entries ``'x'`` and ``'y'`` (the
    center of the box) and ``'width'`` and ``'height'`` (its extent). The
    extents are plain differences, so they are zero or negative for a
    degenerate box; callers are expected to filter those out.

    Raises AttributeError if ``bndbox`` or one of its four children is
    missing, TypeError if a coordinate element is empty, and ValueError if
    its text is not numeric.
    """
    box = xml.find('bndbox')
    x_min, x_max = float(box.find('xmin').text), float(box.find('xmax').text)
    y_min, y_max = float(box.find('ymin').text), float(box.find('ymax').text)

    return {
        'x': (x_min + x_max) / 2.0,
        'y': (y_min + y_max) / 2.0,
        'width': x_max - x_min,
        'height': y_max - y_min,
    }
def get_annotation(path):
    """Return the list of annotated objects for the image at ``path``.

    Only the final component of ``path`` (the file name) is used. It is
    compared against the ``'img_name'`` of each entry in the module-level
    ``annotations`` list, and the ``'objects'`` value of the first matching
    entry is returned.

    Raises IndexError if no entry matches, the same exception type the
    previous ``[0]`` lookup produced, but with a message naming the image.
    """
    filename = os.path.split(path)[-1]

    # A plain loop stops at the first match and does not depend on
    # filter()/map() returning lists, which is only true on Python 2.
    for entry in annotations:
        if entry['img_name'] == filename:
            return entry['objects']

    raise IndexError('no annotation found for image %r' % filename)
path = '../'  # The path to the root of your project directory.
ann_path = os.path.join(path, 'ann')  # The path to the XML annotations.
img_path = os.path.join(path, 'img')  # The path to the images.
sf_path = os.path.join(path, 'data.sframe')  # The path to save the SFrame to.

# One dictionary per annotation file, each with an 'img_name' and an 'objects'
# entry. get_annotation() reads this list as a module-level global.
annotations = []

# Placeholder; rebound to the SFrame of loaded images further down.
sf = None
# Parenthesised so the line is valid on Python 3 as well; with a single string
# argument it prints exactly the same text on Python 2.
print('Loading annotation files...')

# Loads all of the .xml files in ann_path and parses them to retrieve the name
# of the image they belong to, along with the names of the objects and the
# bounding boxes they contain.
for annotation_file in os.listdir(ann_path):
    if not annotation_file.endswith('.xml'):
        continue

    full_path = os.path.join(ann_path, annotation_file)
    ann_xml = XMLParser.parse(full_path)
    img_name = ann_xml.find('filename').text

    ann_objects = []
    for ann_object in ann_xml.iter('object'):
        bounding_box = get_bounding_box(ann_object)

        # findtext() gives None for a missing <name> and '' for an empty one,
        # so unlabeled objects are skipped below instead of raising.
        label = ann_object.findtext('name')

        # Discard objects without a label or valid bounding boxes.
        if label and bounding_box['width'] > 0 and bounding_box['height'] > 0:
            ann_objects.append({
                'coordinates': bounding_box,
                'label': label,
            })

    # Recorded even when ann_objects is empty, so the image still has an entry.
    annotations.append({
        'img_name': img_name,
        'objects': ann_objects,
    })
# The print calls are parenthesised so they are valid on Python 3 as well;
# with a single string argument they print the same text on Python 2.
print('Loading images...')

# The raw SFrame containing our images. We'll modify this to fit our needs.
# Turi Create will take care of loading all the images inside the img_path
# directory for us. This SFrame, once loaded, has 2 columns: 'path' (the path
# to the image), and 'image' (the image). We will add a third column, called
# 'annotations'. It'll contain a list of all the objects we want to detect in
# that image, expressed as a list of bounding boxes and labels.
sf = tc.image_analysis.load_images(img_path)

print('Assigning annotations to images...')

# Looks through all the images loaded into sf, finds the appropriate
# annotation for each one, and sets that row's 'annotations' column. This is
# necessary so we can pass the SFrame to Turi Create for training later.
sf['annotations'] = sf['path'].apply(get_annotation)

print('Saving SFrame...')
sf.save(sf_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment