Skip to content

Instantly share code, notes, and snippets.

@zhreshold
Created May 8, 2019 23:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zhreshold/5f44dcb1a00f84bc1d981465f1c1f0e2 to your computer and use it in GitHub Desktop.
Save zhreshold/5f44dcb1a00f84bc1d981465f1c1f0e2 to your computer and use it in GitHub Desktop.
GluonCV cam demo pose
from __future__ import division
import argparse, time, logging, os, math, tqdm, cv2
import numpy as np
import mxnet as mx
from mxnet import gluon, nd, image
from mxnet.gluon.data.vision import transforms
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("TkAgg")
import gluoncv as gcv
from gluoncv import data
from gluoncv.data import mscoco
from gluoncv.model_zoo import get_model
from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord
from gluoncv.utils.viz import plot_image, plot_keypoints
# Command-line options for the webcam pose-estimation demo.
parser = argparse.ArgumentParser(
    description='Real-time human pose estimation on webcam frames')
# Default matches the detector the demo has always run.
parser.add_argument('--detector', type=str, default='ssd_512_mobilenet1.0_coco',
                    help='name of the detection model to use')
parser.add_argument('--pose-model', type=str, default='simple_pose_resnet50_v1b',
                    help='name of the pose estimation model to use')
parser.add_argument('--num-frames', type=int, default=100,
                    help='Number of frames to capture')
opt = parser.parse_args()
def cv_plot_image(img, **kwargs):
    """Show an RGB image in the OpenCV window named 'demo'.

    Parameters
    ----------
    img : numpy.ndarray or mxnet.nd.NDArray or None
        Image in RGB channel order (assumed HxWx3 uint8 -- confirm with
        callers). ``None`` is a no-op.
    **kwargs
        Ignored; accepted so callers can pass matplotlib-style ``ax=``.

    Returns
    -------
    None
    """
    # `not img` raises for multi-element arrays (ambiguous truth value),
    # so test explicitly for the "nothing to show" case.
    if img is None:
        return
    # cv2 functions operate on numpy arrays, not MXNet NDArrays.
    if isinstance(img, mx.nd.NDArray):
        img = img.asnumpy()
    # OpenCV displays images in BGR order.
    canvas = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imshow('demo', canvas)
    # Give the HighGUI event loop a moment so the window refreshes.
    cv2.waitKey(1)
def cv_plot_bbox(img, bboxes, scores=None, labels=None, thresh=0.5,
                 class_names=None, colors=None, ax=None,
                 reverse_rgb=False, absolute_coordinates=True):
    """Draw bounding boxes onto ``img`` in place using OpenCV.

    Text labels are not rendered. ``ax`` and ``reverse_rgb`` are accepted for
    signature compatibility with matplotlib-style plotting helpers but are
    not used.

    Parameters
    ----------
    img : numpy.ndarray
        Image to draw on; modified in place by ``cv2.rectangle``.
    bboxes : numpy.ndarray or mxnet.nd.NDArray
        One box per row as ``(xmin, ymin, xmax, ymax, ...)``; only the first
        four columns are used.
    scores : numpy.ndarray or mxnet.nd.NDArray, optional
        Per-box scores; boxes scoring below ``thresh`` are skipped.
    labels : numpy.ndarray or mxnet.nd.NDArray, optional
        Per-box class ids; boxes with a negative id are skipped.
    thresh : float
        Minimum score for a box to be drawn.
    class_names : sequence, optional
        When given, box colors come from the 'hsv' colormap indexed by
        ``cls_id / len(class_names)``; otherwise random colors are used.
    colors : dict, optional
        Mapping ``cls_id -> color`` with components in [0, 1]. Missing
        entries are added to this dict.
    absolute_coordinates : bool
        If False, box coordinates are fractions of the image width/height.

    Returns
    -------
    numpy.ndarray
        ``img`` itself.

    Raises
    ------
    ValueError
        If ``labels`` or ``scores`` does not have one entry per box.
    """
    import random

    if labels is not None and len(bboxes) != len(labels):
        raise ValueError('The length of labels and bboxes mismatch, {} vs {}'
                         .format(len(labels), len(bboxes)))
    if scores is not None and len(bboxes) != len(scores):
        raise ValueError('The length of scores and bboxes mismatch, {} vs {}'
                         .format(len(scores), len(bboxes)))
    if len(bboxes) < 1:
        return img

    # cv2 and the `.flat` indexing below need plain numpy arrays.
    if isinstance(bboxes, mx.nd.NDArray):
        bboxes = bboxes.asnumpy()
    if isinstance(labels, mx.nd.NDArray):
        labels = labels.asnumpy()
    if isinstance(scores, mx.nd.NDArray):
        scores = scores.asnumpy()

    if not absolute_coordinates:
        # Scale a float copy so a caller-owned numpy array is never mutated.
        bboxes = np.array(bboxes, dtype=np.float64)
        height, width = img.shape[0], img.shape[1]
        bboxes[:, (0, 2)] *= width
        bboxes[:, (1, 3)] *= height

    # Per-class color cache, filled lazily as new class ids are seen.
    if colors is None:
        colors = dict()
    for i, bbox in enumerate(bboxes):
        if scores is not None and scores.flat[i] < thresh:
            continue
        if labels is not None and labels.flat[i] < 0:
            continue
        cls_id = int(labels.flat[i]) if labels is not None else -1
        if cls_id not in colors:
            if class_names is not None:
                colors[cls_id] = plt.get_cmap('hsv')(cls_id / len(class_names))
            else:
                colors[cls_id] = (random.random(), random.random(), random.random())
        xmin, ymin, xmax, ymax = [int(x) for x in bbox[:4]]
        # Colors are stored in [0, 1]; cv2 expects 0-255 components.
        bcolor = [x * 255 for x in colors[cls_id]]
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), bcolor, 3)
    return img
def cv_plot_keypoints(img, coords, confidence, class_ids, bboxes, scores,
                      box_thresh=0.5, keypoint_thresh=0.2, **kwargs):
    """Draw person boxes and pose skeletons on ``img`` and show the result.

    Detections whose class id is 0 (treated as 'person') are drawn with
    ``cv_plot_bbox``. A limb is drawn between two keypoints only when both
    keypoint confidences exceed ``keypoint_thresh``. The annotated image is
    displayed in the OpenCV window 'demo'.

    Parameters
    ----------
    img : numpy.ndarray
        RGB image to draw on; modified in place by the cv2 drawing calls.
    coords : numpy.ndarray or mxnet.nd.NDArray
        Keypoint coordinates indexed as ``coords[person, joint, (x, y)]``.
    confidence : numpy.ndarray or mxnet.nd.NDArray
        Keypoint confidences indexed as ``confidence[person, joint, 0]``.
    class_ids, bboxes, scores : numpy.ndarray or mxnet.nd.NDArray
        Detector outputs with a leading batch axis; only batch 0 is used.
    box_thresh : float
        Score threshold forwarded to ``cv_plot_bbox``.
    keypoint_thresh : float
        Minimum confidence for a keypoint to count as visible.
    **kwargs
        Forwarded to ``cv_plot_bbox``.

    Returns
    -------
    None
    """
    def as_numpy(arr):
        # OpenCV and the numpy indexing below need plain numpy arrays.
        return arr.asnumpy() if isinstance(arr, mx.nd.NDArray) else arr

    coords = as_numpy(coords)
    confidence = as_numpy(confidence)
    class_ids = as_numpy(class_ids)
    bboxes = as_numpy(bboxes)
    scores = as_numpy(scores)

    joint_visible = confidence[:, :, 0] > keypoint_thresh
    # Limbs as pairs of keypoint indices (presumably the COCO 17-keypoint
    # ordering used by the simple_pose models -- confirm if the model changes).
    joint_pairs = [[0, 1], [1, 3], [0, 2], [2, 4],
                   [5, 6], [5, 7], [7, 9], [6, 8], [8, 10],
                   [5, 11], [6, 12], [11, 12],
                   [11, 13], [12, 14], [13, 15], [14, 16]]

    # Keep only class-0 detections from the first image of the batch.
    person_ind = class_ids[0] == 0
    img = cv_plot_bbox(img, bboxes[0][person_ind[:, 0]],
                       scores[0][person_ind[:, 0]], thresh=box_thresh, **kwargs)

    # One color per limb from the 'cool' colormap. The colormap returns RGBA
    # in [0, 1]; keep only R, G, B and scale to 0-255 for cv2.
    limb_colors = [tuple(int(c * 255) for c in plt.cm.cool(v)[:3])
                   for v in np.linspace(0, 1, len(joint_pairs))]

    for pts, visible in zip(coords, joint_visible):
        for color, (j1, j2) in zip(limb_colors, joint_pairs):
            if visible[j1] and visible[j2]:
                pt1 = (int(pts[j1, 0]), int(pts[j1, 1]))
                pt2 = (int(pts[j2, 0]), int(pts[j2, 1]))
                cv2.line(img, pt1, pt2, color, 3)

    # OpenCV displays images in BGR order.
    canvas = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imshow('demo', canvas)
    cv2.waitKey(1)
def keypoint_detection(img, detector, pose_net, ctx=mx.cpu(), axes=None):
    """Detect people in one frame, estimate their poses and display the result.

    Parameters
    ----------
    img : mxnet.nd.NDArray
        RGB frame (the demo loop builds it as a uint8 NDArray).
    detector : gluon block
        Object detector returning ``(class_ids, scores, bboxes)``.
    pose_net : gluon block
        Pose estimator mapping the cropped person batch to heatmaps.
    ctx : mxnet.Context
        Device the detector input is moved to.
    axes : object, optional
        Forwarded as ``ax=`` to the plotting helper.

    Returns
    -------
    object
        Whatever the plotting helper returns (``None`` for the OpenCV
        helpers in this demo).
    """
    # Resize/normalize for the detector. `img` is rebound to the resized
    # original image (presumably a uint8 numpy array, per the gluoncv preset).
    x, img = gcv.data.transforms.presets.yolo.transform_test(img, short=512, max_size=350)
    x = x.as_in_context(ctx)
    class_IDs, scores, bounding_boxs = detector(x)

    # Crop and resize each detected person to the pose network's input size.
    pose_input, upscale_bbox = detector_to_simple_pose(
        img, class_IDs, scores, bounding_boxs, output_shape=(256, 192), ctx=ctx)
    if len(upscale_bbox) > 0:
        predicted_heatmap = pose_net(pose_input)
        pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)
        axes = cv_plot_keypoints(img, pred_coords, confidence, class_IDs,
                                 bounding_boxs, scores, box_thresh=0.5,
                                 keypoint_thresh=0.2, ax=axes)
    else:
        # Nobody detected: show the resized frame itself. (This used the
        # module-global `frame`, which only exists when run as a script.)
        axes = cv_plot_image(img, ax=axes)
    return axes
if __name__ == '__main__':
    ctx = mx.cpu()
    # Use the models selected on the command line instead of hard-coded names.
    detector = get_model(opt.detector, pretrained=True, ctx=ctx)
    # Restrict the detector to the 'person' class, reusing its person weights.
    detector.reset_class(classes=['person'], reuse_weights={'person': 'person'})
    net = get_model(opt.pose_model, pretrained=True, ctx=ctx)

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise SystemExit('Could not open camera device 0')
    time.sleep(1)  # let the camera autofocus
    axes = None
    try:
        for _ in range(opt.num_frames):
            ret, frame = cap.read()
            if not ret:
                # No frame delivered (camera unplugged/closed): stop cleanly
                # instead of crashing in cvtColor on a None frame.
                break
            # OpenCV captures BGR; the models and plot helpers expect RGB.
            frame = mx.nd.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).astype('uint8')
            axes = keypoint_detection(frame, detector, net, ctx, axes=axes)
    finally:
        # Always free the camera device and close the display window.
        cap.release()
        cv2.destroyAllWindows()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment