Created
April 7, 2020 16:04
-
-
Save ponta256/e623a4e3cd8d61964a8b95c98039cade to your computer and use it in GitHub Desktop.
prediction code for pose estimation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import os | |
import pprint | |
import torch | |
import torch.nn.parallel | |
import torch.backends.cudnn as cudnn | |
import torch.optim | |
import torch.utils.data | |
import torch.utils.data.distributed | |
import torchvision.transforms as transforms | |
import _init_paths | |
from config import cfg | |
from config import update_config | |
from core.loss import JointsMSELoss | |
from core.function import validate | |
from utils.utils import create_logger | |
from utils.utils import get_model_summary | |
import pprint | |
import dataset | |
import models | |
import cv2 | |
from ptflops import get_model_complexity_info | |
from core.inference import get_max_preds | |
def get_keypoints(input_image, model):
    '''
    Runs the pose network on one image file and returns its keypoints.
    Input:
        input_image: path of the image file (read with cv2.imread)
        model: network called on a (1, 3, 256, 192) float tensor; its output
            is treated as a batch of heatmaps, one channel per keypoint
    Output:
        (points, probs)
        points: list of (x, y) tuples scaled to the original image size
        probs: list of the matching heatmap peak values
        Two derived entries are appended after the model's own keypoints:
        the midpoint of keypoints 5/6 and the midpoint of keypoints 11/12
        (19 entries in total when the model predicts 17 keypoints).
    Raises:
        IOError: if the image cannot be read
    '''
    img = cv2.imread(input_image, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('could not read image: {}'.format(input_image))
    height, width = img.shape[:2]

    # cv2.resize takes (width, height): the network input is 192 wide, 256 high.
    img = cv2.resize(img, (192, 256))

    # NOTE(review): cv2.imread yields BGR channel order while these look like
    # the usual ImageNet RGB statistics -- confirm this matches how the model
    # was trained before changing anything here.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    batch = to_tensor(img).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        res = model(batch)
    heatmaps = res.clone().cpu().numpy()
    preds, maxvals = get_max_preds(heatmaps)

    # Heatmap resolution comes from the network output instead of being
    # hard-coded (it is 64 x 48 for the 256 x 192 input used above).
    # Assumes get_max_preds returns coordinates in heatmap pixels.
    heat_h, heat_w = heatmaps.shape[2], heatmaps.shape[3]

    probs = [prob[0] for prob in maxvals[0]]
    points = [(arr[0], arr[1]) for arr in preds[0]]

    # Append two synthetic joints: midpoint of 5/6 and midpoint of 11/12
    # (presumably shoulders and hips in COCO ordering -- verify against the
    # dataset definition).
    for first, second in ((5, 6), (11, 12)):
        mid_x = (points[first][0] + points[second][0]) / 2
        mid_y = (points[first][1] + points[second][1]) / 2
        mid_prob = (probs[first] + probs[second]) / 2
        points.append((mid_x, mid_y))
        probs.append(mid_prob)

    # Map heatmap coordinates back onto the original image.
    resized = [((px / heat_w) * width, (py / heat_h) * height) for px, py in points]
    return resized, probs
def draw(image, points, probs, res, threshold):
    '''
    Draws the pose skeleton and its joints onto image, in place.
    Input:
        image: image array handed to cv2.line / cv2.circle
        points: sequence of (x, y) keypoint positions (at least 19 entries)
        probs: per-keypoint probabilities, same order as points
        res: line thickness and joint circle radius
        threshold: a keypoint is drawn only if its probability is greater
    A limb is drawn only when both of its end points pass the checks; all
    limbs are drawn before the joint circles so the circles end up on top.
    '''
    # Pairs of keypoint indices joined by a limb, in drawing order.
    skeleton = (
        (0, 1), (0, 2), (1, 3), (2, 4),
        (0, 17), (17, 5), (17, 6),
        (6, 8), (8, 10), (5, 7), (7, 9),
        (17, 18), (18, 12), (18, 11), (12, 11),
        (12, 14), (14, 16), (11, 13), (13, 15),
    )
    origin = (0, 0)

    pixels = [(int(round(pt[0])), int(round(pt[1]))) for pt in points]

    for limb in skeleton:
        # Tuple comparison is lexicographic; a point equal to the origin
        # is treated as "not detected".
        if not all(pixels[k] > origin for k in limb):
            continue
        if not all(probs[k] > threshold for k in limb):
            continue
        start, end = limb
        cv2.line(image, pixels[start], pixels[end], (255, 255, 0), res)

    for idx in range(19):
        if not (pixels[idx] > origin and probs[idx] > threshold):
            continue
        cv2.circle(image, pixels[idx], res, (255, 0, 0), thickness=-1, lineType=cv2.FILLED)
def main():
    '''
    Command-line entry point: loads the pose model named by the config,
    predicts keypoints and writes an annotated copy of each image.

    Two modes:
      * single image: positional `image` is read and the result is written
        to `--output`.
      * `--fileList`: every non-empty line of the list file is an image
        path; each result is written next to its input as
        `<stem>_<image>_out.jpg`, where the positional `image` argument is
        used as an optional tag (`<stem>_out.jpg` when it is omitted).
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("image", nargs='?', help="the image that you want to input")
    parser.add_argument("--output", help="the output filename", default="output.jpg")
    # The threshold is a probability; type=int rejected values such as 0.5.
    parser.add_argument("--threshold", help="probability of the keypoint that should appear greater than this threshold", type=float, default=0.1)
    parser.add_argument("--thickness", help="thickness of the line", type=int, default=8)
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')
    parser.add_argument('--fileList',
                        help='',
                        type=str,
                        default=None)
    args = parser.parse_args()
    if args.image is None and args.fileList is None:
        # Fail early with a usage message instead of crashing in cv2.imread.
        parser.error('either an image or --fileList is required')

    update_config(cfg, args)

    # NOTE(review): eval() runs whatever expression cfg.MODEL.NAME produces;
    # only use configuration files from trusted sources.
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(cfg, is_train=False)
    # The model is never moved off the CPU in this script, so map the
    # checkpoint to the CPU; this also lets GPU-saved weights load on a
    # machine without CUDA.
    model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE, map_location='cpu'), strict=False)

    # get_keypoints reads the image ignoring the EXIF orientation; the image
    # that gets drawn on must be read the same way or the keypoints are
    # misplaced on rotated photos.
    read_flags = cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION

    if args.fileList is not None:
        with open(args.fileList, 'r') as f:
            # One path per line; blank lines are skipped.
            flist = [line.rstrip('\r\n') for line in f]
        flist = [fname for fname in flist if fname]

        if args.image is None:
            suffix = '_out.jpg'
        else:
            suffix = '_' + args.image + '_out.jpg'

        for fname in flist:
            print('processing... {}'.format(fname))
            keypoints, probs = get_keypoints(fname, model)
            image = cv2.imread(fname, read_flags)
            draw(image, keypoints, probs, args.thickness, args.threshold)
            # splitext removes only the final extension, whatever it is.
            cv2.imwrite(os.path.splitext(fname)[0] + suffix, image)
    else:
        keypoints, probs = get_keypoints(args.image, model)
        image = cv2.imread(args.image, read_flags)
        draw(image, keypoints, probs, args.thickness, args.threshold)
        cv2.imwrite(args.output, image)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment