Skip to content

Instantly share code, notes, and snippets.

@aic25
Forked from allskyee/faster_rcnn_webcam.py
Created April 1, 2019 03:56
Show Gist options
  • Save aic25/f5b08a01a549f7baa8b2d8d4c836cbc8 to your computer and use it in GitHub Desktop.
Save aic25/f5b08a01a549f7baa8b2d8d4c836cbc8 to your computer and use it in GitHub Desktop.
Faster RCNN (ZFnet) detection and classification on image from webcam
#!/usr/bin/env python
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see https://github.com/rbgirshick/py-faster-rcnn/blob/master/LICENSE for details]
# Written by Ross Girshick
# Modified by Sky Chon for webcam use
# --------------------------------------------------------
import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse
import sys
from threading import Thread, Lock
# Class labels in the order the network emits them: column ``i`` of the score
# matrix (and columns ``4*i:4*i+4`` of the box matrix) belongs to CLASSES[i].
# Index 0 is the background pseudo-class and is skipped when drawing; the
# remaining 20 names match the PASCAL VOC object categories.
CLASSES = ('__background__',) + (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
)
class WebcamVideoStream(object):
    """Threaded webcam reader that always exposes the most recent frame.

    A background thread keeps pulling frames from a ``cv2.VideoCapture`` so
    that ``read()`` returns a copy of the latest frame without waiting on the
    camera. The instance can also be used as a context manager, in which case
    leaving the ``with`` block stops the thread and releases the device.
    """

    def __init__(self, src=0, width=320, height=240):
        """Open capture device *src* and request a *width* x *height* feed."""
        self.stream = cv2.VideoCapture(src)
        self.stream.set(self._capture_property('FRAME_WIDTH'), width)
        self.stream.set(self._capture_property('FRAME_HEIGHT'), height)
        # Grab one frame synchronously so read() has data before start().
        (self.grabbed, self.frame) = self.stream.read()
        self.started = False
        self.thread = None
        self.read_lock = Lock()

    @staticmethod
    def _capture_property(name):
        """Return the capture property id for *name* (e.g. ``'FRAME_WIDTH'``).

        Newer OpenCV releases expose ``cv2.CAP_PROP_*``; OpenCV 2.x only has
        ``cv2.cv.CV_CAP_PROP_*``. Try the modern name first and fall back.
        """
        modern = getattr(cv2, 'CAP_PROP_' + name, None)
        if modern is not None:
            return modern
        return getattr(cv2.cv, 'CV_CAP_PROP_' + name)

    def start(self):
        """Start the background capture thread.

        Returns ``self`` so the call can be chained onto the constructor, or
        ``None`` (after printing a notice) if the thread is already running.
        """
        if self.started:
            print("already started!!")
            return None
        self.started = True
        self.thread = Thread(target=self.update, args=())
        # Daemon thread: a crash in the main thread must not leave the
        # process hanging on a capture loop that nobody will ever stop.
        self.thread.daemon = True
        self.thread.start()
        return self

    def update(self):
        """Capture loop run by the background thread until ``stop()``."""
        while self.started:
            # Read outside the lock so readers are never blocked on the camera.
            (grabbed, frame) = self.stream.read()
            with self.read_lock:
                self.grabbed, self.frame = grabbed, frame

    def read(self):
        """Return a private copy of the latest frame.

        Returns ``None`` when no frame is available (the last capture
        attempt produced no image).
        """
        # ``with`` guarantees the lock is released even if copying fails;
        # otherwise the capture thread would block forever on the lock.
        with self.read_lock:
            if self.frame is None:
                return None
            return self.frame.copy()

    def stop(self):
        """Stop the capture thread and wait for it to finish.

        Safe to call when the thread was never started or already stopped.
        """
        self.started = False
        if self.thread is not None:
            self.thread.join()
            self.thread = None

    def __enter__(self):
        """Enter a ``with`` block; returns the stream itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop the capture thread, then release the capture device."""
        # Stop first so the thread is not reading from a released device.
        self.stop()
        self.stream.release()
def parse_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Returns the ``argparse.Namespace`` holding ``src``, ``gpu_id``,
    ``cpu_mode``, ``width`` and ``height``.
    """
    cli = argparse.ArgumentParser(description='Faster R-CNN demo')
    add_option = cli.add_argument

    # Options are registered in the order they should appear in --help.
    add_option('--src', type=int, default=0, dest='src',
               help='video device source [0]')
    add_option('--gpu', type=int, default=0, dest='gpu_id',
               help='GPU device id to use [0]')
    add_option('--cpu', action='store_true', dest='cpu_mode',
               help='Use CPU mode (overrides --gpu)')
    add_option('--width', type=int, default=640, dest='width',
               help='webcam feed width')
    add_option('--height', type=int, default=480, dest='height',
               help='webcam feed height')

    options = cli.parse_args()
    return options
def _draw_detections(frame, scores, boxes, conf_thresh=0.8, nms_thresh=0.3):
    """Draw a labelled box on *frame* for every confident detection.

    For each non-background class, the class's boxes and scores are run
    through non-maximum suppression with *nms_thresh*; surviving detections
    scoring at least *conf_thresh* are drawn in place on *frame*.

    ``scores`` is indexed as ``[:, class]`` and ``boxes`` as
    ``[:, 4*class:4*class+4]``, as returned by ``im_detect``.
    """
    for cls_ind, cls_name in enumerate(CLASSES[1:], 1):  # 0 is background
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, nms_thresh)
        dets = dets[keep, :]
        for det in dets[dets[:, -1] >= conf_thresh]:
            score = det[-1]
            # OpenCV drawing functions need integer pixel coordinates;
            # float32 values are rejected by newer OpenCV releases.
            top_left = (int(det[0]), int(det[1]))
            bottom_right = (int(det[2]), int(det[3]))
            cv2.rectangle(frame, top_left, bottom_right, (0, 255, 0), 3)
            cv2.putText(frame, "%s %f" % (cls_name, score),
                        (int(det[0] + 10), int(det[1] + 10)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)


if __name__ == "__main__":
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    args = parse_args()

    prototxt = os.path.join(cfg.MODELS_DIR, "ZF", "faster_rcnn_end2end",
                            "test.prototxt")
    caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',
                              "ZF_faster_rcnn_final.caffemodel")

    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\nDid you run ./data/script/'
                       'fetch_faster_rcnn_models.sh?').format(caffemodel))

    if args.cpu_mode:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
        cfg.GPU_ID = args.gpu_id
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    # Single-argument print(...) calls behave the same on Python 2 and 3.
    print('\n\nLoaded network {:s}'.format(caffemodel))

    print("starting capture")
    vs = WebcamVideoStream(args.src, args.width, args.height).start()

    # Detection thresholds are constant, so set them once outside the loop.
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3

    try:
        while True:
            frame = vs.read()
            if frame is None:
                # Defensive: nothing to run detection on.
                print("no frame available, stopping")
                break

            # do detection and classification
            timer = Timer()  # fresh timer so total_time is this frame only
            timer.tic()
            scores, boxes = im_detect(net, frame)
            timer.toc()

            # print stats
            print(('Detection took {:.3f}s for '
                   '{:d} object proposals').format(timer.total_time,
                                                   boxes.shape[0]))

            _draw_detections(frame, scores, boxes, CONF_THRESH, NMS_THRESH)

            # show image
            cv2.imshow('webcam', frame)
            # Mask to the low byte: some platforms set modifier bits in the
            # key code, which would make a plain == 27 (ESC) test fail.
            if (cv2.waitKey(1) & 0xFF) == 27:
                break
    finally:
        # Always shut down the capture thread and free the camera, even if
        # detection raised; otherwise the process can hang on exit.
        vs.stop()
        vs.stream.release()
        cv2.destroyAllWindows()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment