Skip to content

Instantly share code, notes, and snippets.

@ck196
Created February 22, 2016 11:02
Show Gist options
  • Save ck196/198885b45118e2963cbf to your computer and use it in GitHub Desktop.
Save ck196/198885b45118e2963cbf to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse
# Class labels, in the column order of the detector output.
# Index 0 is the background entry, which the detection loop skips.
CLASSES = (
    '__background__',
    'face',
)
def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes onto ``im`` and show it in a window.

    Args:
        im: Image array handed to the OpenCV drawing calls; it is drawn on
            in place.
        class_name: Text written next to the top-left corner of every box.
        dets: 2-D array with one detection per row. The first four columns
            are used as the two box corners (x1, y1, x2, y2) and the last
            column as the score.
        thresh: Minimum score (inclusive) a detection needs to be drawn.

    Returns:
        None. When no detection reaches ``thresh`` the function returns
        before drawing anything or opening a window.
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    for i in inds:
        # OpenCV drawing functions expect integer pixel coordinates, while
        # ``dets`` may hold floats; convert once and reuse for both calls.
        x1, y1, x2, y2 = (int(v) for v in dets[i, :4])
        cv2.rectangle(im, (x1, y1), (x2, y2), (0, 0, 255), 2)
        # Offset the label 15 px downwards so it sits inside the box.
        cv2.putText(im, class_name, (x1, y1 + 15),
                    cv2.FONT_HERSHEY_DUPLEX, 0.6, (255, 255, 0), 1)

    cv2.imshow('image', im)
def demo(net, image_name):
    """Run detection on one demo image and display the surviving boxes.

    The image is read from ``<cfg.ROOT_DIR>/data/demo/<image_name>``, passed
    through ``im_detect``, and for every non-background entry of ``CLASSES``
    the detections are filtered with NMS and handed to ``vis_detections``.

    Args:
        net: Network object passed straight through to ``im_detect``.
        image_name: File name of the image inside the demo data directory.

    Raises:
        IOError: If the image cannot be read from disk.
    """
    # Load the demo image
    im_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo', image_name)
    im = cv2.imread(im_file)
    if im is None:
        # cv2.imread reports failure by returning None instead of raising,
        # so surface the problem here rather than deep inside im_detect.
        raise IOError('Unable to read image: {}'.format(im_file))

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im)
    timer.toc()
    # Format the message before printing so the call behaves the same under
    # the Python 2 print statement and the Python 3 print function.
    print('Detection took {:.3f}s for '
          '{:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    # Start counting at 1 because the background class at index 0 is skipped.
    for cls_ind, cls in enumerate(CLASSES[1:], start=1):
        # Each class owns a group of four box columns and one score column.
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH)
if __name__ == '__main__':
    # Script entry point: load the face model on GPU 0 and run the demo on a
    # fixed list of images from the demo data directory.
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    prototxt = "models/VGG16/face/test.prototxt"
    caffemodel = "/data/kju/facedetect/face.caffemodel"

    # Fail early with a readable message when the weights are missing.
    if not os.path.isfile(caffemodel):
        raise IOError('{:s} not found.'.format(caffemodel))

    caffe.set_mode_gpu()
    caffe.set_device(0)
    cfg.GPU_ID = 0
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    # Warmup on a dummy image
    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for _ in range(2):
        im_detect(net, im)

    im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
                '001763.jpg', '004545.jpg', 'test.jpg']
    for im_name in im_names:
        # Single-argument print(...) calls behave identically on Python 2
        # and Python 3.
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(net, im_name)

    # Wait for a key press so the OpenCV window stays on screen.
    # NOTE(review): the wait is placed after the loop; confirm whether a
    # per-image wait inside the loop was intended instead.
    cv2.waitKey(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment