@ldcastell
Created July 18, 2018 17:27
Python TensorFlow Object Detection client
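This gist contains two files: a gRPC client that sends an image to a TensorFlow Serving instance hosting an Object Detection API model and visualizes the returned detections, and the `utils` helper module the client imports for image loading and box drawing.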
from __future__ import absolute_import
from __future__ import division
from __future__ import nested_scopes
from __future__ import print_function

import time
import argparse
from argparse import RawTextHelpFormatter

from grpc.beta import implementations
import numpy as np
import tensorflow as tf

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

from utils import load_image_into_numpy_array
from utils import visualize_serving_bounding_boxes

tf.logging.set_verbosity(tf.logging.INFO)


def load_input_tensor(input_image, input_type):
    # Build the request tensor in the format expected by the exported model.
    if input_type == 'image_tensor':
        image_np = load_image_into_numpy_array(input_image)
        image_np_expanded = np.expand_dims(image_np, axis=0).astype(np.float32)
        tensor = tf.contrib.util.make_tensor_proto(image_np_expanded)
    elif input_type == 'encoded_image_string_tensor':
        with open(input_image, 'rb') as f:
            data = f.read()
        tensor = tf.contrib.util.make_tensor_proto(data, shape=[1])
    else:
        raise ValueError("Unsupported input type: %s" % input_type)
    return tensor


def main(args):
    host, port = args.server.split(':')
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    # Create prediction request object
    start_ts = time.time()
    request = predict_pb2.PredictRequest()
    # Specify the model name (must match the name used when the
    # TensorFlow Serving instance was started)
    request.model_spec.name = args.model_name
    input_tensor = load_input_tensor(args.input_image, args.input_type)
    request.inputs['inputs'].CopyFrom(input_tensor)
    tf.logging.info("Image load time: %s sec" % (time.time() - start_ts))

    # Call the prediction server
    start_ts = time.time()
    result = stub.Predict(request, 60.0)  # 60 secs timeout
    tf.logging.info("Inference time: %s sec" % (time.time() - start_ts))

    visualize_serving_bounding_boxes(args.output_directory, args.input_image,
                                     args.label_map, args.max_classes, result)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Object detection client.",
                                     formatter_class=RawTextHelpFormatter)
    parser.add_argument('--server',
                        type=str,
                        required=True,
                        help='PredictionService host:port')
    parser.add_argument('--model_name',
                        type=str,
                        required=True,
                        help='Name of the model')
    parser.add_argument('--input_image',
                        type=str,
                        required=True,
                        help='Path to input image')
    parser.add_argument('--output_directory',
                        type=str,
                        required=True,
                        help='Path to output directory')
    parser.add_argument('--label_map',
                        type=str,
                        required=True,
                        help='Path to label map file')
    parser.add_argument('--max_classes',
                        type=int,
                        default=100,
                        help='Maximum number of classes')
    parser.add_argument('--input_type',
                        choices=['image_tensor',
                                 'encoded_image_string_tensor',
                                 'tf_example'],
                        default='image_tensor',
                        help='Type of input node. Can be '
                             'one of [`image_tensor`, '
                             '`encoded_image_string_tensor`, '
                             '`tf_example`]')
    args = parser.parse_args()
    main(args)
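A hypothetical invocation of the client; the script filename, server address, model name, and paths below are placeholders rather than values taken from the gist:

# Example usage (all values are placeholders):
# python object_detection_client.py \
#     --server localhost:8500 \
#     --model_name my_detector \
#     --input_image /path/to/image.jpg \
#     --output_directory /tmp/detections \
#     --label_map /path/to/label_map.pbtxt \
#     --input_type image_tensor

The second file is the helper module imported above; the `from utils import ...` statements imply it is saved as utils.py next to the client.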
from __future__ import absolute_import
from __future__ import division
from __future__ import nested_scopes
from __future__ import print_function

import os

from PIL import Image
import numpy as np
import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
from object_detection.core.standard_fields import \
    DetectionResultFields as dt_fields


def load_image_into_numpy_array(input_image):
    image = Image.open(input_image)
    (im_width, im_height) = image.size
    image_arr = np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)
    image.close()
    return image_arr


def visualize_serving_bounding_boxes(output_dir,
                                     input_image,
                                     label_map,
                                     max_classes,
                                     serving_result,
                                     line_thickness=4):
    # Unpack the flattened tensors returned by TensorFlow Serving.
    boxes = serving_result.outputs[dt_fields.detection_boxes].float_val
    classes = serving_result.outputs[dt_fields.detection_classes].float_val
    scores = serving_result.outputs[dt_fields.detection_scores].float_val
    # Pass line_thickness by keyword; positionally it would land in
    # min_confidence and suppress every detection.
    visualize_bounding_boxes(output_dir, input_image, label_map, max_classes,
                             boxes, classes, scores,
                             line_thickness=line_thickness)


def visualize_bounding_boxes(output_dir,
                             input_image,
                             label_map,
                             max_classes,
                             boxes,
                             classes,
                             scores,
                             min_confidence=0.5,
                             line_thickness=4):
    image_np = load_image_into_numpy_array(input_image)
    label_map = label_map_util.load_labelmap(label_map)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=max_classes, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    classes = np.squeeze(classes)
    scores = np.squeeze(scores)
    boxes = np.reshape(boxes, (-1, 4))
    # Keep only detections whose class id exists in the label map.
    idx = [i for i, x in enumerate(classes.tolist()) if x in category_index]
    boxes = np.take(boxes, idx, axis=0)
    scores = np.take(scores, idx, axis=0)
    classes = np.take(classes, idx, axis=0)

    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.reshape(boxes, (-1, 4)),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=line_thickness,
        max_boxes_to_draw=300,
        min_score_thresh=min_confidence)

    # Save labeled image
    base_filename = os.path.splitext(os.path.basename(input_image))[0]
    labeled_image_path = os.path.join(output_dir,
                                      base_filename + "_labeled.jpg")
    tf.logging.info('Saving labeled image: %s' % labeled_image_path)
    output_img = Image.fromarray(image_np.astype(np.uint8))
    output_img.save(labeled_image_path)


def draw_bounding_boxes(input_image,
                        category_index,
                        boxes,
                        classes,
                        scores,
                        masks=None,
                        line_thickness=2,
                        min_confidence=0.5,
                        max_classes=100):
    image_np = load_image_into_numpy_array(input_image)
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.array(boxes),
        np.array(classes).astype(np.int32),
        np.array(scores),
        category_index,
        instance_masks=np.array(masks).astype(
            np.uint8) if masks is not None else None,
        use_normalized_coordinates=True,
        line_thickness=line_thickness,
        max_boxes_to_draw=300,
        min_score_thresh=min_confidence)
    return image_np


def load_detection_graph(inference_graph_file):
    """Load object detection graph from model checkpoint.

    Args:
        inference_graph_file (str): Path to frozen detection graph in
            model checkpoint.

    Returns:
        (Graph): TensorFlow graph with object detection model.
    """
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(inference_graph_file, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    return detection_graph
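The client only exercises load_image_into_numpy_array and visualize_serving_bounding_boxes; load_detection_graph and draw_bounding_boxes are not called by it. Below is a minimal sketch of how those two could be combined for local (non-served) inference with a frozen graph, assuming the standard tensor names exported by the Object Detection API ('image_tensor', 'detection_boxes', 'detection_scores', 'detection_classes'); all paths are placeholders.

# Sketch only; not part of the gist. Local inference with a frozen graph,
# assuming the standard tensor names exported by the Object Detection API.
import numpy as np
import tensorflow as tf
from PIL import Image
from object_detection.utils import label_map_util
from utils import (draw_bounding_boxes, load_detection_graph,
                   load_image_into_numpy_array)

# Build the graph and the label map index (placeholder paths).
graph = load_detection_graph('/path/to/frozen_inference_graph.pb')
label_map = label_map_util.load_labelmap('/path/to/label_map.pbtxt')
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=100, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

with tf.Session(graph=graph) as sess:
    image_np = load_image_into_numpy_array('/path/to/image.jpg')
    # Assumed standard input/output tensor names of the exported frozen graph.
    boxes, scores, classes = sess.run(
        [graph.get_tensor_by_name('detection_boxes:0'),
         graph.get_tensor_by_name('detection_scores:0'),
         graph.get_tensor_by_name('detection_classes:0')],
        feed_dict={graph.get_tensor_by_name('image_tensor:0'):
                   np.expand_dims(image_np, axis=0)})

# Draw the detections back onto the image and save the result.
labeled = draw_bounding_boxes('/path/to/image.jpg', category_index,
                              np.squeeze(boxes), np.squeeze(classes),
                              np.squeeze(scores))
Image.fromarray(labeled).save('/tmp/image_labeled.jpg')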