Skip to content

Instantly share code, notes, and snippets.

@dominiek
Created February 8, 2017 00:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dominiek/f3aed0ef279b7de78d9aa0af78c6607d to your computer and use it in GitHub Desktop.
Save dominiek/f3aed0ef279b7de78d9aa0af78c6607d to your computer and use it in GitHub Desktop.
import tensorflow as tf
import os
import json
import subprocess
from scipy.misc import imread, imresize
from scipy import misc
from train import build_forward
from utils.annolist import AnnotationLib as al
from utils.train_utils import add_rectangles, rescale_boxes
from matplotlib import pyplot as plt
import cv2
import argparse
def get_results(args, H):
    """Run the detector on one image and display the predicted boxes.

    Rebuilds the inference graph from the hyperparameter dict ``H``, restores
    the checkpoint ``args.weights``, runs it on ``args.image`` and shows the
    resized image with boxes drawn by ``add_rectangles`` via matplotlib.

    Args:
        args: Parsed command-line namespace. Reads ``weights``, ``image``,
            ``min_conf``, ``tau`` and ``show_suppressed``.
        H: Hyperparameter dict. Reads ``image_height``, ``image_width``,
            ``use_rezoom`` and ``rnn_len``. When rezoom is enabled, it also
            reads ``grid_height``, ``grid_width`` and ``reregress``.

    Returns:
        None; the annotated image is shown with ``plt.show()``.
    """
    tf.reset_default_graph()
    x_in = tf.placeholder(tf.float32, name='x_in',
                          shape=[H['image_height'], H['image_width'], 3])
    # build_forward is given a batch; wrap the single image as a batch of one.
    x_batch = tf.expand_dims(x_in, 0)
    if H['use_rezoom']:
        (pred_boxes, _pred_logits, pred_confidences,
         pred_confs_deltas, pred_boxes_deltas) = build_forward(
             H, x_batch, 'test', reuse=None)
        grid_area = H['grid_height'] * H['grid_width']
        # Replace the base confidences with a softmax over the confidence
        # deltas: flatten to (grid_area * rnn_len, 2) for the softmax, then
        # reshape to (grid_area, rnn_len, 2).
        pred_confidences = tf.reshape(
            tf.nn.softmax(tf.reshape(pred_confs_deltas,
                                     [grid_area * H['rnn_len'], 2])),
            [grid_area, H['rnn_len'], 2])
        if H['reregress']:
            # Add the predicted box deltas to the base box predictions.
            pred_boxes = pred_boxes + pred_boxes_deltas
    else:
        pred_boxes, _pred_logits, pred_confidences = build_forward(
            H, x_batch, 'test', reuse=None)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated; this is its direct
        # replacement. Checkpointed variables are then loaded by restore().
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, args.weights)
        # mode='RGB' always yields a 3-channel array (grayscale, palette and
        # alpha images included), matching the x_in placeholder. The old
        # [:, :, :3] slice raised IndexError on 2-D grayscale images.
        orig_img = imread(args.image, mode='RGB')
        img = imresize(orig_img, (H['image_height'], H['image_width']),
                       interp='cubic')
        np_pred_boxes, np_pred_confidences = sess.run(
            [pred_boxes, pred_confidences], feed_dict={x_in: img})
        new_img, _rects = add_rectangles(
            H, [img], np_pred_confidences, np_pred_boxes,
            use_stitching=True, rnn_len=H['rnn_len'],
            min_conf=args.min_conf, tau=args.tau,
            show_suppressed=args.show_suppressed)
        plt.imshow(new_img)
        plt.show()
def str2bool(value):
    """Convert a command-line string such as 'true' or '0' to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``'False'``) as True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


def main():
    """Parse CLI arguments, load ``hypes.json`` and run detection on one image.

    ``hypes.json`` is read from the same directory as the ``--weights``
    checkpoint and passed to ``get_results`` together with the parsed args.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', required=True,
                        help='checkpoint path; hypes.json is read from '
                             'the same directory')
    # NOTE(review): --expname and --iou_threshold are accepted but not read
    # anywhere in this script; kept for command-line compatibility.
    parser.add_argument('--expname', default='')
    parser.add_argument('--gpu', default=0,
                        help='value exported as CUDA_VISIBLE_DEVICES')
    parser.add_argument('--image', required=True,
                        help='image file to run detection on')
    parser.add_argument('--iou_threshold', default=0.5, type=float)
    parser.add_argument('--tau', default=0.25, type=float)
    parser.add_argument('--min_conf', default=0.2, type=float)
    # type=bool would turn "--show_suppressed False" into True.
    parser.add_argument('--show_suppressed', default=True, type=str2bool)
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # os.path.join avoids resolving to "/hypes.json" (filesystem root) when
    # --weights has no directory component.
    hypes_file = os.path.join(os.path.dirname(args.weights), 'hypes.json')
    with open(hypes_file, 'r') as f:
        H = json.load(f)
    get_results(args, H)
# Only run the CLI when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment