Skip to content

Instantly share code, notes, and snippets.

@tawnkramer
Last active June 26, 2017 20:38
Show Gist options
  • Save tawnkramer/d1c69c1d26c011220b7fbe36645feb01 to your computer and use it in GitHub Desktop.
Save tawnkramer/d1c69c1d26c011220b7fbe36645feb01 to your computer and use it in GitHub Desktop.
'''
Author: Tawn Kramer
Date: June 22, 2017
Brief:
This will crop the set of images in a LISA sign data set, resize to a constant dimension,
and save out a pickle file. You can download the data here:
http://cvrr.ucsd.edu/LISA/lisa-traffic-sign-dataset.html
The output of this data set is a pickle file. Considering the output is 15MB while the input is 7GB,
I will host the product here if someone is interested in trying this:
https://s3.amazonaws.com/tawn-train/udacity_carnd/traffic_signs/lisa_train.tar.gz
You may want to refer to my training script:
https://gist.github.com/tawnkramer/5d72fc62d800fc96504db09c6408cb5e
And check notes there.
'''
from __future__ import print_function
import os
import sys
import numpy as np
from PIL import Image
import csv
import pickle
import math
import random
import matplotlib.pyplot as plt
#semicolon-delimited annotation table to read
annotations_filename = 'allAnnotations.csv'

#annotation tag -> integer class id, numbered in order of first appearance
classes = {}
#class id -> annotation tag (the inverse lookup of classes)
class_names = []
#next class id to hand out
iClass = 0
#one entry per sign to extract: source file, class id and bounding box
jobs = []

with open(annotations_filename, mode="rt") as infile:
    for row in csv.DictReader(infile, delimiter=';'):
        #only rows whose combined flag column reads '0,0' or '1,0' are kept;
        #going by the column name, that presumably means "not on another road"
        if row['Occluded,On another road'] not in ('0,0', '1,0'):
            continue

        label = row['Annotation tag']
        if label not in classes:
            classes[label] = iClass
            class_names.append(label)
            iClass += 1

        #bounding box as (left, upper, right, lower) in pixels
        corners = tuple(int(row[column]) for column in ('Upper left corner X',
                                                        'Upper left corner Y',
                                                        'Lower right corner X',
                                                        'Lower right corner Y'))
        jobs.append({'file': row['Filename'],
                     'class': classes[label],
                     'bbox': corners})

print("Images:", len(jobs))
print("Classes:", len(classes))
#write the class id -> sign name lookup table as a two column csv
with open('signnames.csv', "wt") as outfile:
    outfile.write('ClassId,SignName\n')
    for class_id, sign_name in enumerate(class_names):
        outfile.write("%d,%s\n" % (class_id, sign_name))
print("wrote signnames.csv")
#dataset dimensions derived from the parsed annotations
num_samples = len(jobs)
num_classes = len(classes)

#size of every output sample; a depth of 3 selects RGB and 1 selects grayscale
img_width, img_height, img_depth = 32, 32, 3

#pad_small_images controls what happens to crops smaller than the output size.
#when it is False such crops are simply scaled up.
pad_small_images = True

#when pad_with_black is False, the crop area is enlarged around the sign so that
#more of the surrounding picture is included.
#when pad_with_black is True, the small crop is pasted centered onto a black
#image, so the edges of the result are black.
pad_with_black = False

#enable verbose to get more details to console
verbose = False

#enable show_images to see each image after cropping and resize operations
show_images = False

#output buffers: one class id and one image per sample
labels = np.empty(num_samples, dtype=np.uint8)
features = np.empty((num_samples, img_width, img_height, img_depth),
                    dtype=np.uint8)
def largest(a, b):
    '''
    return the greater of a and b
    when neither is greater than the other, b is returned
    '''
    return a if a > b else b
def expand_bbox_to_square(bbox):
    '''
    grow a (left, upper, right, lower) bounding box until it is square
    the shorter side is extended to the length of the longer side, with the
    extra pixels split between both ends; when the split is uneven the
    right / lower edge receives the larger share
    if pad_small_images is set and pad_with_black is not, a box whose longer
    side is below img_width is grown to img_width on both axes instead
    returns the modified bounding box as a new tuple; it is not clamped,
    so the result may have negative coordinates
    '''
    left, upper, right, lower = bbox[:4]
    width = right - left
    height = lower - upper

    #target edge length of the square
    side = largest(width, height)
    if pad_small_images and not pad_with_black and side < img_width:
        #include more of the surrounding picture rather than scaling up later
        side = img_width

    def split_slack(extent):
        #how far to move the leading and trailing edge along one axis
        if side > extent:
            half = (side - extent) / 2.0
            return -int(math.floor(half)), int(math.ceil(half))
        return 0, 0

    shift_left, shift_right = split_slack(width)
    shift_upper, shift_lower = split_slack(height)

    return (left + shift_left,
            upper + shift_upper,
            right + shift_right,
            lower + shift_lower)
def get_image(job):
    '''
    load the picture for one job, crop it to the squared sign bounding box
    and bring it to img_width x img_height pixels

    job is a dict with a 'file' entry (path of the source picture) and a
    'bbox' entry (left, upper, right, lower)

    crops narrower than img_width are either pasted centered onto a black
    image (pad_small_images and pad_with_black both set) or scaled up;
    crops wider than img_width are shrunk with a bicubic filter

    returns the pixels as a numpy array reshaped to
    (img_width, img_height, img_depth)
    '''
    img = Image.open(job['file'])

    #normalize the channel layout so it agrees with img_depth
    if img_depth == 1:
        img = img.convert('L')
    elif img_depth == 3:
        img = img.convert('RGB')

    bbox = expand_bbox_to_square(job['bbox'])
    img = img.crop(bbox)

    if show_images:
        plt.imshow(np.array(img))
        plt.show()

    if verbose:
        print('cropped', job['bbox'], bbox, img.size)

    if img.size[0] < img_width:
        if pad_small_images and pad_with_black:
            new_size = (img_width, img_height)
            old_size = img.size
            #a new image is filled with zeros, i.e. black; reusing the mode
            #of the crop keeps it matched to the conversion done above
            new_im = Image.new(img.mode, new_size)
            #paste needs integer pixel offsets; true division would hand it
            #floats on python 3, so use floor division
            new_im.paste(img, ((new_size[0] - old_size[0]) // 2,
                               (new_size[1] - old_size[1]) // 2))
            img = new_im
        else:
            img = img.resize((img_width, img_height))
    elif img.size[0] > img_width:
        #thumbnail shrinks in place and keeps the aspect ratio
        img.thumbnail((img_width, img_height), Image.BICUBIC)

    arr = np.array(img)
    return arr.reshape(img_width, img_height, img_depth)
#crop every annotated sign and fill the label / feature arrays
for i, job in enumerate(jobs):
    labels[i] = job['class']
    img = get_image(job)

    if show_images:
        print('class', class_names[job['class']])
        imgplot = plt.imshow(img)
        plt.show()

    features[i] = img

    #progress indicator: one dot for every 100 images
    if i % 100 == 0:
        print('.', end="")
        sys.stdout.flush()
print('done.')

#save labels and features together in a single pickle file
outfilename = 'train.p'
data = {'labels': labels, 'features': features}
with open(outfilename, 'wb') as outfile:
    pickle.dump(data, outfile)
print('saved', outfilename)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment