Created
November 25, 2017 09:01
-
-
Save messefor/2b6b0dbc2da984a175e952a8a25d0ab0 to your computer and use it in GitHub Desktop.
Detect faces in images and save the cropped face regions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python
"""
Detect faces and save the cropped images.

References:
    http://vaaaaaanquish.hatenablog.com/entry/2016/08/15/193636
    http://dlib.net/face_detector.py.html
"""
import cv2 | |
import os | |
import glob | |
import math | |
import matplotlib.pyplot as plt | |
from matplotlib import patches | |
import dlib | |
# % matplotlib inline | |
def add_bboxes(ax, bbox, rect_opt=None, figsize=(10, 10)):
    """Draw a single rectangle outline on a matplotlib axes.

    Args:
        ax: Axes object the rectangle patch is added to.
        bbox: Sequence ``(left, top, width, height)``.
        rect_opt: Optional dict of ``patches.Rectangle`` keyword arguments
            that override the defaults (red, 2pt outline, no fill).
        figsize: Accepted for compatibility; not used by this function.

    Returns:
        The same ``ax`` that was passed in.
    """
    anchor = bbox[:2]
    box_w, box_h = bbox[2:]

    # Start from the default look, then let the caller's options win.
    style = {'linewidth': 2, 'edgecolor': 'r', 'facecolor': 'none'}
    if rect_opt:
        style.update(rect_opt)

    ax.add_patch(patches.Rectangle(anchor, box_w, box_h, **style))
    return ax
def get_text_pos(img_w, img_h, bbox, loc='in'):
    """Compute where to place an annotation for a bounding box.

    Args:
        img_w: Image width; used to derive the horizontal gap when
            ``loc='right'``.
        img_h: Image height (currently not used in the computation).
        bbox: Sequence ``(left, top, width, height)``.
        loc: ``'in'`` puts the text at the centre of the box,
            ``'right'`` puts it to the right of the box, offset by a tenth
            of the image width.

    Returns:
        Tuple ``(x, y, opt)`` where ``opt`` is a dict of extra text keyword
        arguments (horizontal centring for ``loc='in'``, empty otherwise).

    Raises:
        ValueError: If ``loc`` is not ``'in'`` or ``'right'``.
    """
    left, top, width, height = bbox

    # Both placements are vertically centred on the box.
    y = top + height / 2

    if loc == 'right':
        x = left + width + img_w / 10
        opt = {}
    elif loc == 'in':
        x = left + width / 2
        opt = {'horizontalalignment': 'center'}
    else:
        # Previously this fell through and failed with an unbound-variable
        # error at the return statement.
        raise ValueError(
            "loc must be 'in' or 'right', got {!r}".format(loc))
    return x, y, opt
def adjust_bbox(bbox, w_ratio=0.1, h_ratio=0.2):
    """Return a copy of ``bbox`` enlarged symmetrically around its centre.

    Args:
        bbox: Sequence ``(left, top, width, height)``. It is not modified.
        w_ratio: Fraction of the width added on *each* horizontal side.
        h_ratio: Fraction of the height added on *each* vertical side.

    Returns:
        A new list ``[left, top, width, height]`` describing the padded box.
        The origin may become negative; no clipping is applied here.
    """
    bbox_new = list(bbox)
    pad_w = bbox_new[2] * w_ratio
    pad_h = bbox_new[3] * h_ratio

    # Shift the origin out by one pad and grow the extent by two pads so the
    # box stays centred on the same point.
    bbox_new[0] -= pad_w
    bbox_new[2] += pad_w * 2
    bbox_new[1] -= pad_h
    bbox_new[3] += pad_h * 2
    return bbox_new
def add_facedetected_rect(ax, img_size,
                          dets, scores, indices, show_info=True,
                          text_opt_add=None):
    """Overlay detection rectangles (and optional labels) on an axes.

    Args:
        ax: Axes object to draw on.
        img_size: Tuple ``(height, width)`` of the displayed image.
        dets: Iterable of rectangle objects accepted by ``rect2bbox``.
        scores: Iterable of detection scores, parallel to ``dets``.
        indices: Iterable of detector indices, parallel to ``dets``.
        show_info: When true, write ``idx`` and ``score`` inside each box.
        text_opt_add: Optional dict of extra ``ax.text`` keyword arguments;
            these take precedence over the defaults.

    Returns:
        The axes that was drawn on.
    """
    img_h, img_w = img_size
    # Normalise once instead of on every loop iteration.
    extra_text_opt = {} if text_opt_add is None else text_opt_add

    for rect, score, idx in zip(dets, scores, indices):
        bbox = rect2bbox(rect)
        ax = add_bboxes(ax, bbox)
        if not show_info:
            continue

        text = 'idx: {}\nscore: {:.3f}'.format(idx, score)
        # get_text_pos takes (img_w, img_h, ...); the arguments used to be
        # passed in the opposite order.
        x, y, opt = get_text_pos(img_w, img_h, bbox)

        text_opt = {'color': 'b', 'fontsize': 10}
        text_opt.update(opt)
        text_opt.update(extra_text_opt)
        ax.text(x, y, text, **text_opt)
    return ax
def check_img(img):
    """Validate that ``img`` is a loaded 3-dimensional image array.

    Args:
        img: Object to validate; expected to expose an ``ndim`` attribute.

    Raises:
        AssertionError: If ``img`` is ``None`` or ``img.ndim`` is not 3.
            The exception is raised explicitly (rather than through an
            ``assert`` statement) so the check still runs under
            ``python -O``; callers in this file catch ``AssertionError``.
    """
    if img is None:
        raise AssertionError('image is None')
    if img.ndim != 3:
        raise AssertionError(
            'expected a 3-dimensional image, got ndim={}'.format(img.ndim))
def imread_rgb(path_img):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        path_img: Path of the image file to read.

    Returns:
        The image converted from BGR to RGB.

    Raises:
        IOError: If ``cv2.imread`` returns ``None`` for ``path_img``.
    """
    img = cv2.imread(path_img)
    if img is None:
        # Include the path so the failure is diagnosable from the exception.
        raise IOError('Cannot read image: {}'.format(path_img))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def plot_detected_face(path_images, detector, save_fig=None, show_info=True):
    """Show a grid of images with the detected faces outlined.

    Args:
        path_images: Sequence of image file paths.
        detector: Object whose ``run(img, upsample, threshold)`` returns
            ``(dets, scores, indices)``.
        save_fig: Optional path; when given, the figure is also saved there.
        show_info: When true, annotate every box with its index and score.

    Unreadable or invalid images are reported on stdout and leave their grid
    cell blank. Nothing is plotted when ``path_images`` is empty.
    """
    n = len(path_images)
    if n == 0:
        # A grid with zero rows cannot be created.
        return

    ncols = 5
    nrows = math.ceil(n / ncols)
    figsize = (3 * ncols, 3 * nrows)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    cells = axes.ravel()

    # Hide the frame of every cell up front so unused trailing cells and
    # cells of broken files do not show empty axes.
    for ax in cells:
        ax.axis('off')

    for ax, path_img in zip(cells, path_images):
        try:
            img = imread_rgb(path_img)
            check_img(img)
        except (IOError, AssertionError):
            print('Broken File: {}'.format(path_img))
            continue

        # Arguments: upsample the image once, no threshold adjustment.
        dets, scores, indices = detector.run(img, 1, 0)
        ax.imshow(img)
        add_facedetected_rect(ax, img.shape[:2], dets, scores, indices,
                              show_info)

    plt.subplots_adjust(wspace=0, hspace=0)
    if save_fig:
        plt.savefig(save_fig)
    plt.show()
def bbox2start_end(bbox):
    """Convert ``(left, top, width, height)`` into slice bounds.

    Args:
        bbox: Sequence ``(left, top, width, height)``.

    Returns:
        Tuple ``(x_start, x_end, y_start, y_end)``.
    """
    x0, y0, span_x, span_y = bbox
    return x0, x0 + span_x, y0, y0 + span_y
def get_img_cropped(img, bbox):
    """Return the region of ``img`` covered by ``bbox``.

    Args:
        img: Array indexed as ``img[row, column, ...]``.
        bbox: Sequence ``(left, top, width, height)`` in pixel units.

    Returns:
        The slice ``img[top:top + height, left:left + width]`` restricted to
        the image area. The result is empty when the box lies completely
        outside the image.
    """
    left, top, width, height = bbox

    # A negative start index would make the slice count from the far edge of
    # the array and return an empty or unrelated region, so clamp to 0.
    # Ends past the image border are already truncated by slicing.
    x_start = max(0, int(left))
    y_start = max(0, int(top))
    x_end = max(0, int(left + width))
    y_end = max(0, int(top + height))
    return img[y_start:y_end, x_start:x_end]
def rect2bbox(rect):
    """Convert a rectangle object into ``(left, top, width, height)``.

    Args:
        rect: Object exposing ``left()``, ``top()``, ``right()`` and
            ``bottom()`` accessor methods.

    Returns:
        Tuple ``(left, top, width, height)`` where width and height are the
        differences between the opposite edges.
    """
    x0, y0 = rect.left(), rect.top()
    return x0, y0, rect.right() - x0, rect.bottom() - y0
def save_detected_face(path_images, detector, out_dir,
                       min_h=256, min_w=256, verbose=True):
    """Detect faces in each image and write every large-enough crop to disk.

    Crops are written as ``<stem>_<i><ext>`` inside ``out_dir``, where ``i``
    is the index of the detection within its source image.

    Args:
        path_images: Iterable of image file paths.
        detector: Object whose ``run(img, upsample, threshold)`` returns
            ``(dets, scores, indices)``.
        out_dir: Directory the cropped images are written to.
        min_h: Minimum detection height for a crop to be saved.
        min_w: Minimum detection width for a crop to be saved.
        verbose: When true, print a line for every saved crop.

    Unreadable or invalid images are reported on stdout and skipped.
    """
    n_imgs_saved = 0
    for path_img in path_images:
        try:
            img = imread_rgb(path_img)
            check_img(img)
        except (IOError, AssertionError):
            print('Broken File: {}'.format(path_img))
            continue

        # Arguments: upsample the image once, no threshold adjustment.
        # Scores and indices are not needed for cropping.
        dets, _scores, _indices = detector.run(img, 1, 0)

        stem, ext = os.path.splitext(os.path.basename(path_img))
        for i, rect in enumerate(dets):
            bbox = rect2bbox(rect)
            if bbox[2] < min_w or bbox[3] < min_h:
                continue

            img_cropped = get_img_cropped(img, bbox)
            if img_cropped.size == 0:
                # NOTE(review): presumably happens when the detection lies
                # outside the image; converting an empty array would raise.
                continue

            save_path = os.path.join(out_dir, '{}_{}{}'.format(stem, i, ext))
            written = cv2.imwrite(
                save_path, cv2.cvtColor(img_cropped, cv2.COLOR_RGB2BGR))
            if not written:
                # imwrite reports failure through its return value; do not
                # count or announce a file that was not written.
                print('Failed to save: {}'.format(save_path))
                continue

            n_imgs_saved += 1
            if verbose:
                print('Images #{} saved: {}'.format(n_imgs_saved, save_path))
def main():
    """Crop detected faces for every category directory under ``data/``.

    For each key, images are read from ``data/<key>`` and the crops are
    written to ``data/<key>_cropped``. A key whose input directory does not
    exist is reported and skipped instead of aborting the whole run.
    """
    # Compared against the lower-cased extension, so any capitalisation
    # (.JPG, .Jpeg, .PNG, ...) is accepted.
    IMG_EXT = ('.jpg', '.jpeg', '.png')
    keys = ('glasses', 'man', 'woman', 'no_glass', 'sunglass', 'man_itoh',
            'woman_itoh', 'glasses_itoh')

    # Build the detector once; it does not depend on the key.
    detector = dlib.get_frontal_face_detector()

    for key in keys:
        dir_in = 'data/{}'.format(key)
        dir_out = 'data/{}_cropped'.format(key)

        print("-" * 40)
        print('key: {} START'.format(key))
        print("-" * 40)

        if not os.path.isdir(dir_in):
            print('Input directory not found, skipped: {}'.format(dir_in))
            continue

        os.makedirs(dir_out, exist_ok=True)

        # Collect image files; sorted so runs are reproducible.
        path_images = [
            os.path.join(dir_in, file_name)
            for file_name in sorted(os.listdir(dir_in))
            if os.path.isfile(os.path.join(dir_in, file_name))
            and os.path.splitext(file_name)[-1].lower() in IMG_EXT
        ]

        # To inspect detections visually before saving:
        #     plot_detected_face(path_images[:20], detector)

        # Save cropped images
        save_detected_face(path_images, detector, dir_out)
# Run the batch crop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment