Created
December 29, 2017 07:30
-
-
Save markdtw/04ff687ab2e48e916bb24a83ff87b62a to your computer and use it in GitHub Desktop.
Python interface to run inference with FlowNet 2.0 (CVPR'17)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import, division, print_function | |
import pdb | |
import tempfile | |
import argparse | |
import numpy as np | |
import PIL.Image as PILI | |
import caffe | |
def visualize(image, flow, im2W=None):
    """Render a dense optical-flow field as a color image and display it.

    Uses the standard HSV flow visualization: hue encodes flow direction,
    value encodes flow magnitude, saturation is fixed at 255.

    Args:
        image: reference frame, (H, W, 3) uint8 array — used only for its shape.
        flow: (H, W, 2) float array of per-pixel (dx, dy) displacements.
        im2W: unused; kept for backward compatibility with existing callers.
    """
    import cv2
    hsv = np.zeros(image.shape, dtype=np.uint8)
    hsv[..., 1] = 255  # full saturation; hue and value are filled below
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    hsv[..., 0] = ang * 180 / np.pi / 2  # radians -> OpenCV hue range [0, 180)
    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    # BUG FIX: the PIL.Image *module* has no show(); build an Image first.
    PILI.fromarray(rgb).show()
def forward_pass(net, input_dict):
    """Run one inference pass, retrying around a sporadic caffe NaN bug.

    There is a non-deterministic NaN bug in caffe that appears to be a
    race condition, so the forward pass is re-run (up to 3 attempts)
    whenever any blob in the network contains a NaN.

    Args:
        net: a loaded caffe net.
        input_dict: blob-name -> input-array mapping fed to net.forward.

    Returns:
        The squeezed 'predict_flow_final' blob transposed to (H, W, 2),
        e.g. (960, 1920, 2).
    """
    for _attempt in range(3):
        net.forward(**input_dict)
        has_nan = any(
            np.isnan(net.blobs[name].data[...]).any() for name in net.blobs
        )
        if not has_nan:
            break
        print('FOUND NANs, RETRYING...')
    flow = np.squeeze(net.blobs['predict_flow_final'].data).transpose(1, 2, 0)
    return flow  # (960, 1920, 2)
def read_input(img0_path, img1_path, num_blobs, net):
    """Load an image pair and build the net's input blob dictionary.

    Each image is read as an (H, W, C) array, expanded to NCHW, and its
    channel axis flipped RGB -> BGR (caffe's expected channel order).

    Args:
        img0_path: path to the first frame.
        img1_path: path to the second frame.
        num_blobs: number of input blobs to populate (expected 2).
        net: caffe net whose `inputs` names key the returned dict.

    Returns:
        dict mapping input blob name -> (1, C, H, W) BGR array.
    """
    # Fix duplication: the original repeated the load/transpose code per image.
    input_data = []
    for path in (img0_path, img1_path):
        img = np.asarray(PILI.open(path))
        # (H, W, C) -> (1, C, H, W), then swap channels 0 and 2 (RGB -> BGR)
        input_data.append(img[np.newaxis, :, :, :].transpose(0, 3, 1, 2)[:, [2, 1, 0], :, :])
    return {net.inputs[i]: input_data[i] for i in range(num_blobs)}
def load_caffe_net(width, height, prototxt, caffemodel):
    """Instantiate the pre-trained FlowNet caffe net for inference.

    The deploy prototxt is a template containing $TARGET_WIDTH$-style
    placeholders; this fills them in, writes the result to a temp file,
    and loads the net from it on GPU 0.

    Args:
        width: input image width in pixels.
        height: input image height in pixels.
        prototxt: path to the deploy prototxt template.
        caffemodel: path to the trained weights.

    Returns:
        (num_blobs, net): the number of input blobs (2) and the caffe net.
    """
    num_blobs = 2
    divisor = 64.
    # The net's working resolution must be a multiple of 64; SCALE_* lets
    # the prototxt rescale the flow back to the requested target size.
    # Renamed from `vars` — that shadowed the builtin.
    subst = {}
    subst['TARGET_WIDTH'] = width
    subst['TARGET_HEIGHT'] = height
    subst['ADAPTED_WIDTH'] = int(np.ceil(width / divisor) * divisor)
    subst['ADAPTED_HEIGHT'] = int(np.ceil(height / divisor) * divisor)
    subst['SCALE_WIDTH'] = width / float(subst['ADAPTED_WIDTH'])
    subst['SCALE_HEIGHT'] = height / float(subst['ADAPTED_HEIGHT'])
    # Keep `tmp` alive (delete=True removes it on close) until caffe.Net
    # has read it below.
    tmp = tempfile.NamedTemporaryFile(mode='w', delete=True)
    # Fix: the original leaked the template file handle (open().readlines()).
    with open(prototxt) as template:
        for line in template:
            for key, value in subst.items():
                line = line.replace("$%s$" % key, str(value))
            tmp.write(line)
    tmp.flush()
    caffe.set_logging_disabled()
    caffe.set_device(0)
    caffe.set_mode_gpu()
    net = caffe.Net(tmp.name, caffemodel, caffe.TEST)
    return num_blobs, net
if __name__ == '__main__':
    # CLI entry point: compute and display optical flow between two frames.
    parser = argparse.ArgumentParser()
    parser.add_argument('--caffeh5', help='path to model')
    parser.add_argument('--proto', help='path to deploy prototxt template')
    parser.add_argument('--img0', help='image 0 path')
    parser.add_argument('--img1', help='image 1 path')
    args = parser.parse_args()

    # The first frame's dimensions determine the net's input geometry.
    eximg = np.asarray(PILI.open(args.img0))
    height, width = eximg.shape[0], eximg.shape[1]

    num_blobs, net = load_caffe_net(width, height, args.proto, args.caffeh5)
    input_dict = read_input(args.img0, args.img1, num_blobs, net)
    flow = forward_pass(net, input_dict)
    visualize(eximg, flow)
    print('Done')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment