Created
May 17, 2019 08:58
-
-
Save qmaruf/cec27bed38b352e90d9faf9820bf059c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ntpath
import os
import sys
import tarfile
import zipfile
from collections import defaultdict
from distutils.version import StrictVersion
from glob import glob
from io import StringIO

import matplotlib
matplotlib.use('TkAgg')  # the backend must be selected before pyplot is imported
from matplotlib import pyplot as plt
import cv2
import numpy as np
import six.moves.urllib as urllib
import tensorflow as tf
from PIL import Image
# Put the local TensorFlow models checkout on the import path.  The first
# entry serves the ``object_detection`` package import below; the bare
# ``utils`` imports further down presumably resolve through the second entry.
sys.path.extend([
    "/media/quazi/DATADRIVE1/tensorflow/models/research/",
    "/media/quazi/DATADRIVE1/tensorflow/models/research/object_detection/",
])

from object_detection.utils import ops as utils_ops

# Refuse to continue on a TensorFlow release older than 1.9.0.
if StrictVersion('1.9.0') > StrictVersion(tf.__version__):
    raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')

from utils import label_map_util
from utils import visualization_utils as vis_util
class ObjectDetection:
    """Batched inference with a frozen TensorFlow 1.x object-detection graph.

    The constructor loads the label map and imports the frozen graph once.
    ``run_inference_for_multiple_images`` then opens a session per call and
    pushes the images through the graph in batches.
    """

    def __init__(self, path_to_frozen_graph=None, path_to_labels=None):
        """Load the label map and import the frozen inference graph.

        Args:
            path_to_frozen_graph: path to a ``frozen_inference_graph.pb`` file,
                e.g. ``.../faster_rcnn_resnet101_kitti_2018_01_28/frozen_inference_graph.pb``.
            path_to_labels: path to the matching label map ``.pbtxt`` file,
                e.g. ``.../kitti_label_map.pbtxt``.

        Both paths are used immediately, so the ``None`` defaults are only
        placeholders and must be overridden by the caller.
        """
        self.path_to_frozen_graph = path_to_frozen_graph
        self.path_to_labels = path_to_labels
        self.category_index = label_map_util.create_category_index_from_labelmap(
            self.path_to_labels, use_display_name=True)

        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.path_to_frozen_graph, 'rb') as fid:
                serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            # name='' keeps the original tensor names (no "import/" prefix),
            # which the lookups in run_inference_for_multiple_images rely on.
            tf.import_graph_def(od_graph_def, name='')

    def read_images(self, paths, resize=True, size=(300, 300)):
        """Read images from disk with OpenCV and stack them into one array.

        Args:
            paths: iterable of image file paths.
            resize: when true, every image is resized to ``size`` so images of
                different dimensions can be stacked into a single batch.
            size: ``(width, height)`` handed to ``cv2.resize``; defaults to
                the previously hard-coded 300x300.

        Returns:
            ``np.ndarray`` built from the loaded images, in the channel order
            produced by ``cv2.imread`` (BGR).  With ``resize=False`` the
            images are stacked as-is, so they should share one shape.

        Raises:
            IOError: if a path cannot be read as an image.
        """
        images = []
        for path in paths:
            img = cv2.imread(path)
            if img is None:
                # cv2.imread reports failure by returning None instead of
                # raising; fail here with the offending path rather than
                # later with an opaque resize/stacking error.
                raise IOError('Could not read image: %s' % path)
            if resize:
                img = cv2.resize(img, size)
            images.append(img)
        return np.array(images)

    def run_inference_for_multiple_images(self, image_paths, batch_size=16):
        """Run the detector over ``image_paths`` in batches.

        Args:
            image_paths: sequence of image file paths (must support slicing).
            batch_size: number of images fed to the graph per ``sess.run``.

        Returns:
            A list with one dict per input image, in input order, holding
            ``num_detections`` (int), ``detection_classes`` (int64 array),
            ``detection_boxes`` and ``detection_scores`` for that image.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
            IOError: if one of the images cannot be read.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

        detections = []
        graph = self.detection_graph

        # Only fetch outputs that the loaded graph really provides:
        # 'detection_masks' is absent from box-only models, and asking for a
        # missing tensor makes get_tensor_by_name raise KeyError.
        available_names = {output.name
                           for op in graph.get_operations()
                           for output in op.outputs}
        tensor_dict = {}
        for key in ['num_detections', 'detection_boxes', 'detection_scores',
                    'detection_classes', 'detection_masks']:
            tensor_name = key + ':0'
            if tensor_name in available_names:
                tensor_dict[key] = graph.get_tensor_by_name(tensor_name)

        # Loop-invariant: look the input placeholder up once.
        image_tensor = graph.get_tensor_by_name('image_tensor:0')

        with tf.Session(graph=graph) as sess:
            for start in range(0, len(image_paths), batch_size):
                image_paths_batch = image_paths[start:start + batch_size]
                images = self.read_images(image_paths_batch)
                # cv2.imread yields BGR, while the detection graph expects
                # RGB input, so flip the channel axis for the feed only.
                rgb_images = np.ascontiguousarray(images[..., ::-1])
                output_dict = sess.run(tensor_dict,
                                       feed_dict={image_tensor: rgb_images})
                print(images.shape)
                for j in range(images.shape[0]):
                    detections.append({
                        'num_detections': int(output_dict['num_detections'][j]),
                        'detection_classes': output_dict['detection_classes'][j].astype(np.int64),
                        'detection_boxes': output_dict['detection_boxes'][j],
                        'detection_scores': output_dict['detection_scores'][j],
                    })
        return detections
# Example run: detect objects in ten COCO val2017 images with an SSD
# Inception v2 model and save a copy of each image with boxes drawn on it.
object_detection = ObjectDetection(
    path_to_frozen_graph='/media/quazi/DATADRIVE1/data/ssd_inception_v2_coco/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb',
    path_to_labels='/media/quazi/DATADRIVE1/data/ssd_inception_v2_coco/mscoco_label_map.pbtxt')

# glob() returns files in arbitrary, OS-dependent order; sort so that the
# same ten images are picked on every run.
image_paths = sorted(glob('/media/quazi/DATADRIVE1/data/coco/val2017/*.jpg'))[:10]
detections = object_detection.run_inference_for_multiple_images(image_paths, batch_size=16)

for image_index, (image_path, detection) in enumerate(zip(image_paths, detections)):
    # Re-read at full resolution: the boxes are scaled by the image's own
    # height and width below, so they can be drawn on the unresized image.
    img = object_detection.read_images([image_path], resize=False)[0]
    img_height, img_width, _ = img.shape
    for obj_bbox, obj_score in zip(detection['detection_boxes'], detection['detection_scores']):
        if obj_score < 0.5:
            # Skip low-confidence detections.
            continue
        ymin, xmin, ymax, xmax = obj_bbox
        ymin = int(ymin * img_height)
        ymax = int(ymax * img_height)
        xmin = int(xmin * img_width)
        xmax = int(xmax * img_width)
        img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 0), 5)
    cv2.imwrite('./img_%d.jpg' % image_index, img)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment