Skip to content

Instantly share code, notes, and snippets.

@pranjalsatija
Last active July 9, 2018 02:12
Show Gist options
  • Save pranjalsatija/485eed15f50ad0d8ca025067f84fc3f0 to your computer and use it in GitHub Desktop.
Save pranjalsatija/485eed15f50ad0d8ca025067f84fc3f0 to your computer and use it in GitHub Desktop.
import os
import turicreate as tc
import xml.etree.cElementTree as XMLParser
def get_bounding_box(xml):
    """Convert an object's ``<bndbox>`` corners into a center/size dictionary.

    Args:
        xml: An ElementTree element that has a ``bndbox`` child holding
            ``xmin``, ``xmax``, ``ymin`` and ``ymax`` children with numeric text.

    Returns:
        A dict with the keys ``'x'`` and ``'y'`` (center of the box) and
        ``'width'`` and ``'height'`` (its extent), all as floats.

    Raises:
        AttributeError: If ``bndbox`` or one of the four corner elements is missing.
        ValueError: If a corner element's text is not a valid number.
    """
    box = xml.find('bndbox')
    # Read the corners in the same order as before (xmin, xmax, ymin, ymax) so
    # that a malformed file fails on the same element it always did.
    x_min, x_max, y_min, y_max = [
        float(box.find(tag).text) for tag in ('xmin', 'xmax', 'ymin', 'ymax')
    ]
    return {
        'x': (x_min + x_max) / 2.0,
        'y': (y_min + y_max) / 2.0,
        'width': x_max - x_min,
        'height': y_max - y_min,
    }
def get_annotation(path):
    """Return the list of annotated objects for the image at *path*.

    The lookup is done by file name (the last component of *path*) against the
    ``'img_name'`` entries of the module-level ``annotations`` list. When
    several records share the same image name, the first one wins.

    Args:
        path: Path to an image file; only its final component is used.

    Returns:
        The ``'objects'`` value of the first matching annotation record.

    Raises:
        IndexError: If no annotation record exists for the image.
    """
    filename = os.path.split(path)[-1]
    # next() over a generator stops at the first match and works on both
    # Python 2 and Python 3; indexing the result of map() only works on
    # Python 2, where map() returns a list.
    record = next((a for a in annotations if a['img_name'] == filename), None)
    if record is None:
        # Same exception type the original list indexing raised, but with a
        # message that names the offending image.
        raise IndexError('no annotation found for image %r' % filename)
    return record['objects']
path = '../'  # The path to the root of your project directory.
ann_path = os.path.join(path, 'ann')  # The path to the XML annotations.
img_path = os.path.join(path, 'img')  # The path to the images.
sf_path = os.path.join(path, 'data.sframe')  # The path to save the SFrame to.
annotations = []  # One {'img_name': ..., 'objects': [...]} record per XML file.
sf = None
# The call form of print emits the same text on Python 2 and Python 3.
print('Loading annotation files...')
# Loads all of the .xml files in ann_path and parses them to retrieve the name of the image they belong to,
# along with the names of the objects and bounding boxes they contain.
for annotation_file in os.listdir(ann_path):
    if not annotation_file.endswith('.xml'):
        continue
    full_path = os.path.join(ann_path, annotation_file)
    ann_xml = XMLParser.parse(full_path)
    img_name = ann_xml.find('filename').text
    ann_objects = []
    for ann_object in ann_xml.iter('object'):
        bounding_box = get_bounding_box(ann_object)
        # findtext() gives None when <name> is absent and '' when it is empty,
        # so both cases are falsy and can be discarded without raising.
        label = ann_object.findtext('name')
        # Discard objects without a label or valid bounding boxes.
        if label and bounding_box['width'] > 0 and bounding_box['height'] > 0:
            ann_objects.append({
                'coordinates': bounding_box,
                'label': label
            })
    annotations.append({
        'img_name': img_name,
        'objects': ann_objects
    })
# The call form of print emits the same text on Python 2 and Python 3.
print('Loading images...')
# The raw SFrame containing our images. We'll modify this to fit our needs.
# Turi Create will take care of loading all the images inside the img_path directory for us.
# This SFrame, once loaded, has 2 columns: 'path' (the path to the image), and 'image' (the image).
# We will add a third column, called 'annotations'. It'll contain a list of all the objects we want to detect in that
# image, expressed as a list of bounding boxes and labels.
sf = tc.image_analysis.load_images(img_path)
print('Assigning annotations to images...')
# Looks through all the images loaded into sf, finds the appropriate annotation for each one, and sets that row's
# 'annotations' column. This is necessary so we can pass the SFrame to Turi Create for training later.
sf['annotations'] = sf['path'].apply(get_annotation)
print('Saving SFrame...')
sf.save(sf_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment