Skip to content

Instantly share code, notes, and snippets.

@ponta256
Created Apr 7, 2020
Embed
What would you like to do?
prediction code for pose estimation
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import pprint
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import _init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
from core.function import validate
from utils.utils import create_logger
from utils.utils import get_model_summary
import pprint
import dataset
import models
import cv2
from ptflops import get_model_complexity_info
from core.inference import get_max_preds
def get_keypoints(input_image, model, input_size=(192, 256)):
    '''
    Run the pose model on one image file and return keypoints in image pixels.

    The image is resized to ``input_size`` (cv2 ``dsize`` order: width,
    height) and normalised with the mean/std below. It is then passed through
    ``model``, and the per-joint maxima found by ``get_max_preds`` are
    rescaled from heatmap coordinates back to the original image size.

    Two synthetic keypoints are appended after the model's joints:
      * the midpoint of joints 5 and 6, and
      * the midpoint of joints 11 and 12,
    each with the mean of the two source probabilities. With a 17-joint model
    these become indices 17 and 18, giving 19 keypoints in total.
    (Presumably COCO ordering, i.e. mid-shoulder and mid-hip -- confirm.)

    Args:
        input_image: path of the image file to read with cv2.imread.
        model: pose network, called on a (1, 3, height, width) float tensor.
        input_size: (width, height) the image is resized to before inference.
            The default (192, 256) is the size this script was written for.

    Returns:
        (points, probs): ``points`` is a list of (x, y) tuples in
        original-image pixel coordinates; ``probs`` is the matching list of
        confidence scores.

    Raises:
        IOError: if the image cannot be read.
    '''
    img = cv2.imread(input_image, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising, which previously surfaced as an AttributeError.
        raise IOError('could not read image: {}'.format(input_image))
    height, width = img.shape[:2]
    img = cv2.resize(img, tuple(input_size))

    # NOTE(review): cv2 loads images in BGR order, while this mean/std is the
    # usual ImageNet RGB statistics. If the model was trained on RGB input, a
    # cv2.cvtColor(img, cv2.COLOR_BGR2RGB) is missing here -- confirm against
    # the training pipeline.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    batch = to_tensor(img).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        heatmaps = model(batch)

    # Take the heatmap resolution from the output instead of hard-coding
    # 64x48, so that other input sizes or model strides rescale correctly.
    heatmap_h, heatmap_w = heatmaps.shape[-2], heatmaps.shape[-1]
    preds, maxvals = get_max_preds(heatmaps.clone().cpu().numpy())

    # Only the first (and only) image of the batch is used.
    points = [(pred[0], pred[1]) for pred in preds[0]]
    probs = [val[0] for val in maxvals[0]]

    # Append the two synthetic midpoint keypoints described in the docstring.
    for first, second in ((5, 6), (11, 12)):
        points.append(((points[first][0] + points[second][0]) / 2,
                       (points[first][1] + points[second][1]) / 2))
        probs.append((probs[first] + probs[second]) / 2)

    # Map heatmap coordinates back to the original (pre-resize) image.
    resized = [((px / heatmap_w) * width, (py / heatmap_h) * height) for px, py in points]
    return resized, probs
def draw(image, points, probs, res, threshold):
    '''
    Render the skeleton and keypoint dots onto ``image`` in place.

    A bone is drawn between two keypoints, and a filled dot on a keypoint,
    only when two conditions hold for every keypoint involved:
      * its probability is strictly above ``threshold``, and
      * its rounded pixel position compares greater than (0, 0).
    The second test uses Python's lexicographic tuple ordering, so it rejects
    the origin and any point with a negative x. It does not test each
    coordinate separately.

    Args:
        image: image array handed to cv2.line / cv2.circle; modified in place.
        points: sequence of at least 19 (x, y) pixel positions.
        probs: confidence scores aligned with ``points``.
        res: line thickness of the bones and radius of the dots, in pixels.
        threshold: exclusive lower bound on probability for drawing.
    '''
    origin = (0, 0)
    pixel = [(int(round(pt[0])), int(round(pt[1]))) for pt in points]

    # (start, end) keypoint index pairs, in the order they are drawn.
    bones = (
        (0, 1), (0, 2), (1, 3), (2, 4),
        (0, 17), (17, 5), (17, 6),
        (6, 8), (8, 10), (5, 7), (7, 9),
        (17, 18), (18, 12), (18, 11), (12, 11),
        (12, 14), (14, 16), (11, 13), (13, 15),
    )
    for start, end in bones:
        if not (pixel[start] > origin and pixel[end] > origin):
            continue
        if not (probs[start] > threshold and probs[end] > threshold):
            continue
        cv2.line(image, pixel[start], pixel[end], (255, 255, 0), res)

    for idx in range(19):
        if not (pixel[idx] > origin and probs[idx] > threshold):
            continue
        cv2.circle(image, pixel[idx], res, (255, 0, 0), thickness=-1, lineType=cv2.FILLED)
def _build_arg_parser():
    '''Return the command-line parser used by main().'''
    parser = argparse.ArgumentParser()
    parser.add_argument("image", nargs='?', help="the image that you want to input")
    parser.add_argument("--output", help="the output filename", default="output.jpg")
    # float, not int: the default is 0.1, and with type=int argparse rejected
    # any fractional threshold such as "0.3".
    parser.add_argument("--threshold",
                        help="probability of the keypoint that should appear greater than this threshold",
                        type=float, default=0.1)
    parser.add_argument("--thickness", help="thickness of the line", type=int, default=8)
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')
    parser.add_argument('--fileList',
                        help='',
                        type=str,
                        default=None)
    return parser


def _annotate(image_path, output_path, model, thickness, threshold):
    '''Predict keypoints for image_path, draw them, and write the result to output_path.'''
    keypoints, probs = get_keypoints(image_path, model)
    image = cv2.imread(image_path)
    draw(image, keypoints, probs, thickness, threshold)
    cv2.imwrite(output_path, image)


def main():
    '''
    Command-line entry point: annotate one image, or every image in a list file.

    Builds the model named by ``cfg.MODEL.NAME``, loads the weights from
    ``cfg.TEST.MODEL_FILE`` and writes an annotated copy of each image.

    Output naming:
      * Single-image mode writes to ``--output``.
      * ``--fileList`` mode writes ``<stem>[_<image>]_out.jpg`` next to each
        listed file. The optional positional argument, when given, is used as
        a tag in that name.
    '''
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.image is None and args.fileList is None:
        parser.error('an input image or --fileList is required')
    update_config(cfg, args)

    # NOTE(review): eval() of a string built from the config file executes
    # arbitrary code if that file is untrusted;
    # getattr(models, cfg.MODEL.NAME).get_pose_net is a safer equivalent for
    # undotted model names.
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg, is_train=False)

    # Inference runs on the CPU (neither the model nor the input is moved to a
    # GPU), so map any CUDA-saved tensors to the CPU while loading.
    state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location='cpu')
    # strict=False tolerates key mismatches. Report them, because mismatched
    # parameters silently keep their initial values (for example, when a
    # checkpoint carries a 'module.' prefix).
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible is not None and (incompatible.missing_keys or incompatible.unexpected_keys):
        print('warning: {} missing and {} unexpected keys while loading {}'.format(
            len(incompatible.missing_keys), len(incompatible.unexpected_keys),
            cfg.TEST.MODEL_FILE))

    if args.fileList is not None:
        with open(args.fileList, 'r') as f:
            flist = [line.strip() for line in f if line.strip()]
        # args.image is None when no positional argument was given.
        tag = '_' + args.image if args.image else ''
        for fname in flist:
            print('processing... {}'.format(fname))
            out_name = os.path.splitext(fname)[0] + tag + '_out.jpg'
            _annotate(fname, out_name, model, args.thickness, args.threshold)
    else:
        _annotate(args.image, args.output, model, args.thickness, args.threshold)
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment