Skip to content

Instantly share code, notes, and snippets.

@markdtw
Created December 29, 2017 07:30
Show Gist options
  • Save markdtw/04ff687ab2e48e916bb24a83ff87b62a to your computer and use it in GitHub Desktop.
Save markdtw/04ff687ab2e48e916bb24a83ff87b62a to your computer and use it in GitHub Desktop.
Python interface for running inference with FlowNet 2.0 (CVPR'17)
from __future__ import absolute_import, division, print_function
import pdb
import tempfile
import argparse
import numpy as np
import PIL.Image as PILI
import caffe
def visualize(image, flow, im2W=None):
    """Display a dense flow field as a color image.

    Flow direction is mapped to hue and flow magnitude (min-max normalized
    to 0..255) to brightness; saturation is fixed at its maximum.

    Args:
        image: Reference image. Kept for interface compatibility only; the
            output size is taken from ``flow`` so grayscale or RGBA
            reference images do not break the visualization.
        flow: Array indexed as ``flow[..., 0]`` (x component) and
            ``flow[..., 1]`` (y component), i.e. shape (H, W, 2).
        im2W: Unused; kept for interface compatibility.
    """
    import cv2

    flow = np.asarray(flow, dtype=np.float32)
    height, width = flow.shape[:2]

    hsv = np.zeros((height, width, 3), dtype=np.uint8)
    hsv[..., 1] = 255

    # cartToPolar yields the angle in radians; halving the degree value maps
    # the full circle onto OpenCV's 0..180 hue range for 8-bit images.
    mag, ang = cv2.cartToPolar(np.ascontiguousarray(flow[..., 0]),
                               np.ascontiguousarray(flow[..., 1]))
    hsv[..., 0] = ang * 180 / np.pi / 2
    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)

    # PIL expects RGB channel order, so convert to RGB rather than BGR.
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    # PIL.Image has no module-level show(); wrap the array in an Image first.
    PILI.fromarray(rgb).show()
def forward_pass(net, input_dict, max_attempts=3):
    """Run the network forward and return the predicted flow.

    The forward pass is repeated when any blob contains NaNs, because caffe
    occasionally produces NaNs non-deterministically (it seems to be a race
    condition).

    Args:
        net: Network object exposing ``forward(**inputs)`` and a ``blobs``
            mapping whose values have a ``data`` array.
        input_dict: Keyword arguments passed to ``net.forward``.
        max_attempts: Maximum number of forward passes to try (at least one
            is always run). Defaults to 3.

    Returns:
        The ``predict_flow_final`` blob with singleton axes squeezed out and
        the leading axis moved last, e.g. (2, H, W) -> (H, W, 2). If every
        attempt produced NaNs, the last result is still returned and a
        message is printed.
    """
    attempts = max(1, int(max_attempts))
    for attempt in range(1, attempts + 1):
        net.forward(**input_dict)
        # any() stops at the first blob that contains a NaN.
        found_nan = any(np.isnan(net.blobs[name].data[...]).any()
                        for name in net.blobs)
        if not found_nan:
            break
        if attempt < attempts:
            print('FOUND NANs, RETRYING...')
        else:
            # Do not claim a retry that will not happen.
            print('FOUND NANs, GIVING UP AFTER %d ATTEMPTS' % attempts)

    flow = np.squeeze(net.blobs['predict_flow_final'].data).transpose(1, 2, 0)
    return flow
def read_input(img0_path, img1_path, num_blobs, net):
    """Load an image pair and map it onto the network's input blobs.

    Each image is converted to a (1, 3, H, W) array with channels in BGR
    order.

    Args:
        img0_path: Path of the first image.
        img1_path: Path of the second image.
        num_blobs: Number of network inputs to fill (at most two images are
            available).
        net: Network object whose ``inputs`` sequence names the input blobs.

    Returns:
        Dict mapping ``net.inputs[i]`` to the i-th image array.

    Raises:
        ValueError: If the two images do not have the same dimensions.
    """
    def _load_bgr_blob(path):
        # Force 3-channel RGB so grayscale, palette and RGBA files all yield
        # the (H, W, 3) layout the indexing below relies on.
        rgb = np.asarray(PILI.open(path).convert('RGB'))
        # HWC -> 1xCxHxW, then reverse the channel axis (RGB -> BGR).
        return rgb[np.newaxis, :, :, :].transpose(0, 3, 1, 2)[:, [2, 1, 0], :, :]

    input_data = [_load_bgr_blob(img0_path), _load_bgr_blob(img1_path)]
    if input_data[0].shape != input_data[1].shape:
        raise ValueError('image sizes differ: %s vs %s'
                         % (input_data[0].shape[2:], input_data[1].shape[2:]))

    return {net.inputs[blob_idx]: input_data[blob_idx]
            for blob_idx in range(num_blobs)}
def load_caffe_net(width, height, prototxt, caffemodel, gpu_id=0):
    """Load the pre-trained caffe model for inference.

    The prototxt is treated as a template: every ``$KEY$`` tag is replaced
    with a value derived from the requested size, and the result is written
    to a temporary file from which the network is built.

    Args:
        width: Target image width in pixels.
        height: Target image height in pixels.
        prototxt: Path to the deploy prototxt template.
        caffemodel: Path to the trained weights.
        gpu_id: GPU device index passed to ``caffe.set_device``. Defaults
            to 0.

    Returns:
        Tuple ``(num_blobs, net)`` where ``num_blobs`` is 2 and ``net`` is
        the ``caffe.Net`` built in TEST mode.
    """
    num_blobs = 2
    divisor = 64.

    # Round each dimension up to the next multiple of the divisor.
    adapted_width = int(np.ceil(width / divisor) * divisor)
    adapted_height = int(np.ceil(height / divisor) * divisor)
    substitutions = {
        'TARGET_WIDTH': width,
        'TARGET_HEIGHT': height,
        'ADAPTED_WIDTH': adapted_width,
        'ADAPTED_HEIGHT': adapted_height,
        'SCALE_WIDTH': width / float(adapted_width),
        'SCALE_HEIGHT': height / float(adapted_height),
    }

    with open(prototxt) as template:
        proto_text = template.read()
    for key, value in substitutions.items():
        proto_text = proto_text.replace('$%s$' % key, str(value))

    caffe.set_logging_disabled()
    caffe.set_device(gpu_id)
    caffe.set_mode_gpu()

    # The temp file must exist (and be flushed) while caffe parses it; the
    # context manager then removes it deterministically.
    with tempfile.NamedTemporaryFile(mode='w', delete=True) as tmp:
        tmp.write(proto_text)
        tmp.flush()
        net = caffe.Net(tmp.name, caffemodel, caffe.TEST)

    return num_blobs, net
def main():
    """Parse the command line, run inference on an image pair and show the flow."""
    parser = argparse.ArgumentParser()
    # Every argument is needed below, so let argparse report a missing one
    # instead of failing later with an obscure error.
    parser.add_argument('--caffeh5', required=True, help='path to model')
    parser.add_argument('--proto', required=True,
                        help='path to deploy prototxt template')
    parser.add_argument('--img0', required=True, help='image 0 path')
    parser.add_argument('--img1', required=True, help='image 1 path')
    args = parser.parse_args()

    # The first image determines the size the network is configured for.
    eximg = np.asarray(PILI.open(args.img0))
    height, width = eximg.shape[:2]

    num_blobs, net = load_caffe_net(width, height, args.proto, args.caffeh5)
    input_dict = read_input(args.img0, args.img1, num_blobs, net)
    flow = forward_pass(net, input_dict)
    visualize(eximg, flow)
    print('Done')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment