Skip to content

Instantly share code, notes, and snippets.

@wxianfeng
Created March 28, 2018 07:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wxianfeng/ff03ac7ad516ede72a82f2dd8497f7f4 to your computer and use it in GitHub Desktop.
Save wxianfeng/ff03ac7ad516ede72a82f2dd8497f7f4 to your computer and use it in GitHub Desktop.
tensorflow 使用模型检测照片物体
# Usage:   python demo/infer.py <source image path> <model file> <result output file>
# Example: python demo/infer.py /data/photos/car_trash.jpeg demo/output/frozen_inference_graph.pb demo/result.json
import sys
sys.path.append('..')
import os
import time
import tensorflow as tf
import numpy as np
import json
from PIL import Image
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt
from utils import label_map_util
from utils import visualization_utils as vis_util
# Three positional arguments are required: the image to analyse, the frozen
# inference graph (.pb) and the path the JSON result is written to.
# sys.argv[0] is the script name, hence the check against 4.
if len(sys.argv) < 4:
    # sys.exit() with a string prints it to stderr and exits with status 1,
    # so callers can detect the misuse.
    sys.exit(
        'Usage: python {} test_image_path checkpoint_path output_path'.format(
            sys.argv[0]))

PATH_TEST_IMAGE = sys.argv[1]  # image to run detection on
PATH_TO_CKPT = sys.argv[2]     # serialized frozen graph
PATH_OUTPUT = sys.argv[3]      # destination of the JSON result
# --- Static configuration ---------------------------------------------------
PATH_TO_LABELS = 'data/pascal_label_map.pbtxt'  # relative to the working directory
NUM_CLASSES = 21       # passed as max_num_classes when reading the label map
IMAGE_SIZE = (18, 12)  # figure size for plotting; not used by the active code path

# --- Mutable script state ---------------------------------------------------
# Graph that the frozen model is imported into further down.
detection_graph = tf.Graph()
# Per-image results, keyed by image id; filled in after inference.
test_annos = {}

# --- Label map -> category lookup -------------------------------------------
# category_index is indexed by class id in get_results() to obtain the
# category's 'name'.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=NUM_CLASSES,
    use_display_name=True,
)
category_index = label_map_util.create_category_index(categories)
def get_results(boxes, classes, scores, category_index, im_width, im_height,
                min_score_thresh=.5):
    """Convert raw detections into JSON-serialisable result dictionaries.

    Args:
        boxes: Iterable of ``(ymin, xmin, ymax, xmax)`` boxes. The values are
            multiplied by the image size, so they are presumably normalised
            to ``[0, 1]`` -- confirm against the model's output format.
        classes: Class ids parallel to ``boxes``; each is used as a key into
            ``category_index``.
        scores: Confidence scores parallel to ``boxes``.
        category_index: Mapping from class id to a dict that has a ``'name'``
            entry.
        im_width: Factor applied to the x coordinates.
        im_height: Factor applied to the y coordinates.
        min_score_thresh: Only detections whose score is strictly greater
            than this value are kept.

    Returns:
        A list with one dict per kept detection, of the form
        ``{'bbox': {'xmax', 'xmin', 'ymax', 'ymin'}, 'category', 'score'}``.
        All numeric values are built-in ``float`` objects.

    Raises:
        KeyError: If a kept detection's class id is missing from
            ``category_index``.
    """
    results = []
    for box, class_id, score in zip(boxes, classes, scores):
        if not score > min_score_thresh:
            continue
        ymin, xmin, ymax, xmax = box
        results.append({
            'bbox': {
                # Cast to built-in float: NumPy float32 scalars (which the
                # products can be, depending on NumPy's promotion rules) are
                # rejected by json.dump.
                'xmax': float(xmax * im_width),
                'xmin': float(xmin * im_width),
                'ymax': float(ymax * im_height),
                'ymin': float(ymin * im_height),
            },
            'category': category_index[class_id]['name'],
            'score': float(score),
        })
    return results
# Read the serialized frozen graph from disk; the handle is released as soon
# as the bytes are in memory.
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()

# Deserialize the bytes into a GraphDef message.
od_graph_def = tf.GraphDef()
od_graph_def.ParseFromString(serialized_graph)

# Import into detection_graph with an empty name scope so tensors keep the
# names looked up later (e.g. 'image_tensor:0').
with detection_graph.as_default():
    tf.import_graph_def(od_graph_def, name='')
# allow_growth asks TensorFlow to allocate GPU memory on demand instead of
# reserving it all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with detection_graph.as_default():
    with tf.Session(graph=detection_graph, config=config) as sess:
        start_time = time.time()
        print(time.ctime())

        # Force 3-channel RGB so grayscale, palette and RGBA inputs all
        # produce a (height, width, 3) array; the context manager closes the
        # underlying file once the pixels have been copied.
        with Image.open(PATH_TEST_IMAGE) as image:
            image_np = np.array(image.convert('RGB')).astype(np.uint8)
        # NumPy image arrays are (rows, columns, channels): height comes first.
        im_height, im_width, _ = image_np.shape
        # The graph is fed a batch, so add a leading batch axis of size 1.
        image_np_expanded = np.expand_dims(image_np, axis=0)

        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        fetches = [
            detection_graph.get_tensor_by_name('detection_boxes:0'),
            detection_graph.get_tensor_by_name('detection_scores:0'),
            detection_graph.get_tensor_by_name('detection_classes:0'),
            detection_graph.get_tensor_by_name('num_detections:0'),
        ]
        boxes, scores, classes, num_detections = sess.run(
            fetches, feed_dict={image_tensor: image_np_expanded})
        print('{} elapsed time: {:.3f}s'.format(
            time.ctime(), time.time() - start_time))

        # Select the single batch entry by index rather than np.squeeze, which
        # would also collapse the detection axis when it has length 1.
        # Assumes the outputs carry a leading batch axis, as the original
        # np.squeeze usage implies -- confirm against the exported model.
        test_annos["001"] = {'objects': get_results(
            boxes[0], classes[0].astype(np.int32), scores[0], category_index,
            im_width, im_height)}

# Wrap the per-image results in the top-level 'imgs' envelope and persist them.
test_annos = {'imgs': test_annos}
print(test_annos)
with open(PATH_OUTPUT, 'w') as fd:
    json.dump(test_annos, fd)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment