Created
October 18, 2018 18:20
-
-
Save hcl14/b21ecb4abed75587587d49397062ce50 to your computer and use it in GitHub Desktop.
Windows camera fix for https://github.com/victordibia/skyfall
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Author: Victor Dibia
## Load hand tracking model, spin up web socket and web application.
from utils import detector_utils as detector_utils | |
from utils import object_id_utils as id_utils | |
import cv2 | |
import tensorflow as tf | |
import multiprocessing | |
from multiprocessing import Queue, Pool | |
from utils.detector_utils import WebcamVideoStream | |
import time | |
import datetime | |
import argparse | |
from threading import Thread | |
from flask import Flask, render_template, request, jsonify | |
from utils import web_socket_server | |
from utils import web_socket_client | |
# ---------------------------------------------------------------------------
# Module-level configuration and service start-up.
# ---------------------------------------------------------------------------

# Detection settings. score_thresh and num_hands_detect are handed to
# detector_utils.get_tags() by the worker; num_classes is handed to
# detector_utils.load_inference_graph().
frame_processed = 0
score_thresh = 0.7
num_hands_detect = 10
num_classes = 1

# Files describing the frozen hand-detection model, loaded by each worker.
label_path = "hand_inference_graph/hand_label_map.pbtxt"
frozen_graph_path = "hand_inference_graph/frozen_inference_graph.pb"

# Object-id bookkeeping shared with id_utils: the dict of objects seen so
# far and the timeout passed to id_utils.refresh_seen_object_list()
# (presumably seconds -- confirm against utils/object_id_utils).
object_refresh_timeout = 3
seen_object_list = {}

# Web socket endpoint: start the server first, then the client that is
# used for sending detection messages to it.
web_socket_port = 5006
web_socket_client_url = "ws://localhost:5006"
web_socket_server.init(web_socket_port)
web_socket_client.socket_init(web_socket_client_url)

# Flask application that serves the control pages.
app = Flask(__name__)
@app.route("/")
def hello():
    """Serve the mouse-control page at the site root."""
    page = render_template('mousecontrol.html')
    return page
@app.route("/hand")
def test():
    """Serve the hand-control page at /hand."""
    page = render_template('handcontrol.html')
    return page
# Worker that processes video frames (used as the multiprocessing.Pool
# initializer in this file's __main__ section, so it runs once per process
# and never returns).
def worker(input_q, output_q, cap_params, frame_processed):
    """Run hand detection on frames from ``input_q`` and emit annotated frames.

    Loads the frozen inference graph, then loops forever: takes a frame from
    ``input_q``, detects hands, assigns ids to the detections, sends them over
    the web socket client, draws the boxes on the frame and puts the frame on
    ``output_q``. A ``None`` frame is forwarded to ``output_q`` unchanged.

    Args:
        input_q: queue of frames to process (``None`` entries are passed
            through).
        output_q: queue receiving one frame for every frame taken from
            ``input_q``.
        cap_params: capture parameters; accepted for interface compatibility
            but not read by this function.
        frame_processed: starting value of a local frame counter; it is
            incremented per processed frame but never returned or shared.
    """
    print(">> loading frozen model for worker")
    detection_graph, sess, category_index = detector_utils.load_inference_graph(
        num_classes, frozen_graph_path, label_path)
    # NOTE(review): the session returned by load_inference_graph is discarded
    # here in favour of a fresh one on the same graph. If the loader already
    # opens a session this leaks it -- confirm against utils/detector_utils.
    sess = tf.Session(graph=detection_graph)
    try:
        while True:
            frame = input_q.get()
            if frame is None:
                # Nothing to detect; still hand the placeholder back so the
                # consumer of output_q is not left waiting.
                output_q.put(frame)
                continue

            # actual detection
            boxes, scores, classes = detector_utils.detect_objects(
                frame, detection_graph, sess)
            tags = detector_utils.get_tags(
                classes, category_index, num_hands_detect, score_thresh,
                scores, boxes, frame)
            if len(tags) > 0:
                id_utils.get_id(tags, seen_object_list)
                web_socket_client.send_message(tags, "hand")

            id_utils.refresh_seen_object_list(
                seen_object_list, object_refresh_timeout)
            detector_utils.draw_box_on_image_id(tags, frame)
            output_q.put(frame)
            frame_processed += 1
    finally:
        # The loop only ends through an exception, so the close must live in
        # a finally block to ever run.
        sess.close()
def launch_webserver():
    """Configure the Flask app root and serve it on all interfaces, port 5005.

    Blocks inside ``app.run`` for as long as the server is running.
    """
    app.config['APPLICATION_ROOT'] = "/static"
    listen_on = {'host': '0.0.0.0', 'port': 5005}
    app.run(**listen_on)
if __name__ == '__main__':
    # Script entry point: parse options, open the camera, start the detection
    # worker pool and the web server, then pump frames through the workers
    # until the user presses 'q' or the camera stops delivering frames.
    parser = argparse.ArgumentParser()
    parser.add_argument('-src', '--source', dest='video_source', type=int,
                        default=0, help='Device index of the camera.')
    parser.add_argument('-nhands', '--num_hands', dest='num_hands', type=int,
                        default=2, help='Max number of hands to detect.')
    parser.add_argument('-fps', '--fps', dest='fps', type=int,
                        default=0, help='Show FPS on detection/display visualization')
    parser.add_argument('-wd', '--width', dest='width', type=int,
                        default=200, help='Width of the frames in the video stream.')
    parser.add_argument('-ht', '--height', dest='height', type=int,
                        default=180, help='Height of the frames in the video stream.')
    parser.add_argument('-ds', '--display', dest='display', type=int,
                        default=1, help='Display the detected images using OpenCV. This reduces FPS')
    parser.add_argument('-num-w', '--num-workers', dest='num_workers', type=int,
                        default=2, help='Number of workers.')
    parser.add_argument('-q-size', '--queue-size', dest='queue_size', type=int,
                        default=5, help='Size of the queue.')
    args = parser.parse_args()

    input_q = Queue(maxsize=args.queue_size)
    output_q = Queue(maxsize=args.queue_size)

    # A plain cv2.VideoCapture is used here in place of the threaded
    # WebcamVideoStream helper. The device index now honours --source
    # (default 0, the previously hard-coded value). The original left a
    # commented-out cv2.CAP_DSHOW hint next to this call.
    video_capture = cv2.VideoCapture(args.video_source)
    if not video_capture.isOpened():
        raise SystemExit(
            "Could not open camera device %d" % args.video_source)
    video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, args.width)
    video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, args.height)

    cap_params = {}
    frame_processed = 0
    # Read the size back: the values actually in effect may differ from the
    # requested ones. VideoCapture.get returns floats.
    width = video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)
    print(width)
    print(height)
    cap_params['im_width'], cap_params['im_height'] = width, height
    cap_params['score_thresh'] = score_thresh

    # max number of hands we want to detect/track
    cap_params['num_hands_detect'] = args.num_hands

    print(cap_params, args)

    # Spin up workers to parallelize detection. `worker` is passed as the
    # pool initializer, so each process runs its detection loop immediately.
    pool = Pool(args.num_workers, worker,
                (input_q, output_q, cap_params, frame_processed))

    start_time = datetime.datetime.now()
    num_frames = 0
    fps = 0
    index = 0

    # run web application
    # NOTE(review): this is a non-daemon thread, so the process stays alive
    # (serving the web app) after the capture loop below has ended.
    Thread(target=launch_webserver).start()

    show_window = args.display > 0
    if show_window:
        # Create the resizable preview window once instead of on every frame.
        cv2.namedWindow("Hand Tracking", cv2.WINDOW_NORMAL)

    try:
        while True:
            ret, frame = video_capture.read()
            if not ret or frame is None:
                # Camera stopped delivering frames.
                print("video end")
                break

            # Mirror the image so on-screen motion matches the user's motion.
            frame = cv2.flip(frame, 1)
            index += 1

            # Workers receive RGB; OpenCV captures and displays BGR.
            input_q.put(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            output_frame = output_q.get()
            if output_frame is None:
                print("video end")
                break
            output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)

            elapsed_time = (datetime.datetime.now() -
                            start_time).total_seconds()
            num_frames += 1
            fps = num_frames / elapsed_time if elapsed_time > 0 else 0.0

            if show_window:
                if args.fps > 0:
                    detector_utils.draw_fps_on_image(
                        "FPS : " + str(int(fps)), output_frame)
                cv2.imshow('Hand Tracking', output_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            elif num_frames == 400:
                # Restart the measurement window every 400 frames.
                num_frames = 0
                start_time = datetime.datetime.now()
            else:
                print("frames processed: ", index,
                      "elapsed time: ", elapsed_time, "fps: ", str(int(fps)))

        elapsed_time = (datetime.datetime.now() -
                        start_time).total_seconds()
        fps = num_frames / elapsed_time if elapsed_time > 0 else 0.0
        print("fps", fps)
    finally:
        # Always stop the workers and free the camera and windows, even if
        # the loop exits through an exception.
        pool.terminate()
        # cv2.VideoCapture is closed with release(); stop() belongs to the
        # WebcamVideoStream helper that is no longer used here.
        video_capture.release()
        cv2.destroyAllWindows()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment