#!/usr/bin/env python2.7
# Simple TFServing example; based on
# https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_client.py
# Added simpler MNIST loading and removed some of the complexity.
"""A client that talks to tensorflow_model_server loaded with the MNIST model.

The client loads the MNIST data set via Keras, repeatedly queries the service
with a test image, and prints the predictions and the request/response timing.

Typical usage example:
    mnist_client.py --num_tests=100 --server=localhost:9000
"""
from __future__ import print_function

import time

import grpc
import numpy as np
import tensorflow as tf
from keras.datasets import mnist
# Only needed if loading test images from disk (see the commented-out example
# in do_inference below).
from keras.preprocessing import image
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
tf.app.flags.DEFINE_integer('concurrency', 1,
                            'maximum number of concurrent inference requests')
tf.app.flags.DEFINE_integer('num_tests', 100, 'Number of test images')
tf.app.flags.DEFINE_string('server', '', 'PredictionService host:port')
tf.app.flags.DEFINE_string('work_dir', '/tmp', 'Working directory.')
FLAGS = tf.app.flags.FLAGS

_counter = 0  # number of responses received so far
_start = 0    # wall-clock time when the first response arrived


def _callback(result_future):
  """Callback invoked when an asynchronous Predict RPC completes.

  Prints the predicted digit and its confidence, and reports the total time
  once all responses have arrived.

  Args:
    result_future: Result future of the RPC.
  """
  global _counter
  global _start
  exception = result_future.exception()
  if exception:
    print(exception)
  else:
    # print("From Callback", result_future.result().outputs['dense_2/Softmax:0'])
    if _start == 0:
      _start = time.time()
    # The output tensor key depends on how the model was exported; here it is
    # the Keras model's softmax layer 'dense_2/Softmax:0'.
    response = np.array(
        result_future.result().outputs['dense_2/Softmax:0'].float_val)
    prediction = np.argmax(response)
    _counter += 1
    if (_counter % 100) == 0:  # print every 100 responses
      print("[", _counter, "] From Callback Predicted Result is ", prediction,
            "confidence= ", response[prediction])
    if _counter == FLAGS.num_tests:
      end = time.time()
      print("Time for ", FLAGS.num_tests, " is ", end - _start)


def do_inference(hostport, work_dir, concurrency, num_tests):
  """Tests PredictionService with asynchronous requests.

  Sends the same MNIST image num_tests times and lets _callback report the
  predictions as the responses arrive.

  Args:
    hostport: Host:port address of the PredictionService.
    work_dir: The full path of working directory for test data set.
    concurrency: Maximum number of concurrent requests.
    num_tests: Number of test images to use.

  Raises:
    IOError: An error occurred processing test data set.
  """
  channel = grpc.insecure_channel(hostport)
  stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
  request = predict_pb2.PredictRequest()
  request.model_spec.name = 'mnist'
  request.model_spec.signature_name = 'serving_default'
  (X_train, y_train), (X_test, y_test) = mnist.load_data()
  # Reshape to (num_samples, channels, rows, cols), the layout the served
  # model expects.
  X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
  X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
  # For loading images from disk instead:
  # img = image.load_img('./data/mnist_png/testing/0/10.png', target_size=(28, 28))
  # x = image.img_to_array(img)
  # request.inputs['input_image'].CopyFrom(
  #     tf.contrib.util.make_tensor_proto(image2, shape=[1, 1, image2.size]))
  x = X_train[4545][0]
  print("Shape is ", x.shape, " Label is ", y_train[4545])
  start = time.time()
  for _ in range(num_tests):
    x = x.astype(np.float32)
    request.inputs['input_image'].CopyFrom(
        tf.contrib.util.make_tensor_proto(x, shape=[1, 1, 28, 28]))
    # result_counter.throttle()
    result_future = stub.Predict.future(request, 10.25)  # 10.25 second timeout
    result_future.add_done_callback(_callback)
  end = time.time()
  print("Time to Send ", num_tests, " is ", end - start)
  time.sleep(10)  # give the outstanding callbacks time to fire
  # If things are wrong the callback will not come, so uncomment below to
  # force the result, or to see what the bug is.
  # res = result_future.result()
  # response = np.array(res.outputs['dense_2/Softmax:0'].float_val)
  # prediction = np.argmax(response)
  # print("Predicted Result is ", prediction, "Detection Probability= ", response[prediction])


def main(_):
  if FLAGS.num_tests > 20000:
    print('num_tests should not be greater than 20k')
    return
  if not FLAGS.server:
    print('please specify server host:port')
    return
  # do_inference prints predictions and timing via the callback; it does not
  # return a value in this simplified client.
  do_inference(FLAGS.server, FLAGS.work_dir,
               FLAGS.concurrency, FLAGS.num_tests)


if __name__ == '__main__':
  print("hello from TFServing client slim")
  tf.app.run()