# Imports required by this snippet. FPS is assumed to come from imutils.video
# (its start/update/stop/elapsed/fps API matches the calls below); load_model,
# PersonDetection, PATH_TO_CKPT and category_index are expected to be defined
# elsewhere in the project.
import argparse
import multiprocessing
import time

import cv2
import numpy as np
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
from imutils.video import FPS
from object_detection.utils import visualization_utils as vis_util


if __name__ == '__main__':
    load_model()

    # Add some arguments to a parser so the script is easy to configure.
    parser = argparse.ArgumentParser()
    parser.add_argument('-src', '--source', dest='video_source', type=int, default=0,
                        help='Device index of the camera.')
    parser.add_argument('-wd', '--width', dest='width', type=int, default=320,
                        help='Width of the frames in the video stream.')
    parser.add_argument('-ht', '--height', dest='height', type=int, default=360,
                        help='Height of the frames in the video stream.')
    args = parser.parse_args()

    logger = multiprocessing.log_to_stderr()
    # Note: multiprocessing.SUBDEBUG only exists on Python 2; on Python 3 use
    # logging.DEBUG instead.
    logger.setLevel(multiprocessing.SUBDEBUG)

    video_capture = cv2.VideoCapture(args.video_source)
    # Query the capture device itself so we always get the real frame size
    # (the --width/--height arguments are parsed but not applied here).
    size = (int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    fps = FPS().start()
    while True:
        ret, frame = video_capture.read()
        if frame is None:
            break
        if ret:
            # Start the per-frame timer before detection so the printed
            # elapsed time covers the detection itself.
            t = time.time()
            boxes, scores, classes, num_detections, image = PersonDetection.GetHumans(frame)
            # humans = poseEstimator.inference(frame)
            # show, _, bboxes, _, _ = TfPoseEstimator.get_skeleton(frame, humans, imgcopy=False)
            # show = TfPoseEstimator.draw_humans(frame, humans, imgcopy=False)
            # imagetoShow = cv2.cvtColor(show, cv2.COLOR_RGB2BGR)
            cv2.imshow('Video', image)
            fps.update()
            print('[INFO] elapsed time: {:.2f}'.format(time.time() - t))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    fps.stop()
    print('[INFO] elapsed time (total): {:.2f}'.format(fps.elapsed()))
    print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))

    # Clean everything up.
    # pool.terminate()
    video_capture.release()
    cv2.destroyAllWindows()
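

# --- Illustrative sketch (not part of the original gist): one plausible shape
# for the load_model() helper called above, assuming the TensorFlow Object
# Detection API's label_map_util. PATH_TO_LABELS and NUM_CLASSES are
# hypothetical placeholders, not values from the original code.
from object_detection.utils import label_map_util

def load_model():
    global category_index
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)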

def detectObjects(image, sess, detection_graph):
    # Expand dimensions since the model expects images with shape [1, None, None, 3].
    image_np_expanded = np.expand_dims(image, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represents the level of confidence for each of the objects.
    # The score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})

    # Visualization of the results of the detection.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image,
        np.squeeze(boxes),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=1)
    return boxes, scores, classes, num_detections, image
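

# --- Illustrative sketch (added for clarity, not in the original gist): the
# boxes returned by detectObjects are normalized [ymin, xmin, ymax, xmax];
# converting them to pixel coordinates for downstream use could look like this.
# min_score is an illustrative threshold, not a value from the original code.
def boxes_to_pixels(boxes, scores, image_shape, min_score=0.5):
    height, width = image_shape[:2]
    pixel_boxes = []
    for box, score in zip(np.squeeze(boxes), np.squeeze(scores)):
        if score < min_score:
            continue
        ymin, xmin, ymax, xmax = box
        pixel_boxes.append((int(xmin * width), int(ymin * height),
                            int(xmax * width), int(ymax * height)))
    return pixel_boxes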

def GetHumans(frame):
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        # Read and parse the frozen detection graph first, then let TensorRT
        # optimize it. The outputs, batch size, workspace size and precision
        # mode below are placeholders to fill in for your model; the tensor
        # names used in detectObjects must survive the conversion.
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
        trt_graph = trt.create_inference_graph(
            input_graph_def=od_graph_def,
            outputs=["your_output_node_names"],
            max_batch_size=your_batch_size,
            max_workspace_size_bytes=max_GPU_mem_size_for_TRT,
            precision_mode="your_precision_mode")
        tf.import_graph_def(trt_graph, name='')
    sess = tf.Session(graph=detection_graph)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return detectObjects(frame_rgb, sess, detection_graph)
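

# --- Illustrative sketch (not in the original gist): GetHumans rebuilds the
# graph and session on every frame, which is expensive. One way to avoid that
# is to build them once and reuse them across frames; _CACHED_GRAPH,
# _CACHED_SESSION and GetHumansCached are illustrative names, not part of the
# original code.
_CACHED_GRAPH = None
_CACHED_SESSION = None

def GetHumansCached(frame):
    global _CACHED_GRAPH, _CACHED_SESSION
    if _CACHED_SESSION is None:
        _CACHED_GRAPH = tf.Graph()
        with _CACHED_GRAPH.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            tf.import_graph_def(od_graph_def, name='')
        _CACHED_SESSION = tf.Session(graph=_CACHED_GRAPH)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return detectObjects(frame_rgb, _CACHED_SESSION, _CACHED_GRAPH)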