-
-
Save lgutzwil/da732d25d14c917ddb6626b4a5fa8ed0 to your computer and use it in GitHub Desktop.
Client for DeepLab Export Blog
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse

import grpc
import imageio
import numpy as np
import tensorflow as tf
from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
def call(input_image_path,
         model_name="my_deeplab_model",
         host="0.0.0.0",
         port=8500):
    """Segment one image with a TF Serving model and save the mask as PNG.

    Reads the image at ``input_image_path``, sends it to the
    ``detection_signature`` signature of ``model_name`` over gRPC, and writes
    the returned ``segmentation_map`` output to
    ``input_image_path + ".seg.png"``.

    Args:
        input_image_path: Path of the image to segment.
        model_name: Name under which the model is served.
        host: Host name or IP address of the TF Serving gRPC endpoint.
        port: gRPC port of the TF Serving endpoint.

    Returns:
        The path of the written mask file.

    Raises:
        grpc.RpcError: If the Predict call fails or exceeds its 600 s deadline.
    """
    image_data = imageio.imread(input_image_path)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    request.model_spec.signature_name = "detection_signature"
    # Prepend a batch dimension of size 1 to the decoded image's shape.
    # NOTE(review): the exported model presumably expects 3-channel uint8
    # input; RGBA or grayscale files may be rejected by the server - confirm.
    request.inputs["inputs"].CopyFrom(
        tf.make_tensor_proto(image_data, shape=[1] + list(image_data.shape))
    )

    # The GA gRPC channel is a context manager, so the connection is released
    # even when the RPC raises (the beta channel was never closed).
    with grpc.insecure_channel("{}:{}".format(host, port)) as channel:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        result = stub.Predict(request, timeout=600.0)

    # make_ndarray decodes both the typed repeated fields (int64_val) and the
    # packed tensor_content encoding, and restores the tensor's shape.
    segmentation = tf.make_ndarray(result.outputs["segmentation_map"])
    # Axis 0 is the batch of one; axes 1 and 2 are height and width.
    height, width = segmentation.shape[1], segmentation.shape[2]
    # Stored as uint8 for PNG output: class ids above 255 would wrap around.
    image_mask = np.reshape(segmentation, (height, width)).astype(np.uint8)

    # Save alongside the original image.
    output_path = input_image_path + ".seg.png"
    imageio.imwrite(output_path, image_mask)
    return output_path
if __name__ == "__main__":
    # Command-line entry point: segment a single image with the served model.
    parser = argparse.ArgumentParser(
        description="Send an image to a TF Serving DeepLab model and save "
                    "the segmentation mask next to it as <image>.seg.png.")
    parser.add_argument("input_image_path",
                        help="Path of the image to segment.")
    parser.add_argument("--model_name", default="my_deeplab_model",
                        help="Name of the served model.")
    # Defaults mirror call()'s own defaults so existing invocations behave
    # exactly as before.
    parser.add_argument("--host", default="0.0.0.0",
                        help="Host of the TF Serving gRPC endpoint.")
    parser.add_argument("--port", type=int, default=8500,
                        help="Port of the TF Serving gRPC endpoint.")
    args = parser.parse_args()
    call(args.input_image_path,
         model_name=args.model_name,
         host=args.host,
         port=args.port)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment