-
-
Save ismaeIfm/eeb24fad2623dfb69ca81bb0f254543f to your computer and use it in GitHub Desktop.
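"""Manual test client for tensorflow_model_server.

Sends one Iris feature vector to the 'iris' model (localhost:8500 by default)
and prints the PredictResponse.
"""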
from __future__ import print_function | |
import os | |
import sys | |
import numpy as np | |
import tensorflow as tf | |
from grpc.beta import implementations | |
from tensorflow.python.platform import flags | |
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2 | |
# Parsed command-line flags; main() below does not read any of them.
FLAGS = tf.app.flags.FLAGS


def main(_, host='localhost', port=8500, timeout=5.0):
    """Send one Iris prediction request to tensorflow_model_server and print it.

    Args:
      _: Remaining command-line arguments passed in by ``tf.app.run``; unused.
      host: Hostname of the model server.
      port: gRPC port of the model server.
      timeout: Request deadline in seconds.
    """
    # Target the 'iris' model and its 'predict_iris' signature.
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'iris'
    request.model_spec.signature_name = 'predict_iris'

    # A single sample with four float32 features: shape (1, 4).
    features = np.asarray([[1, 2, 3, 4]], dtype=np.float32)
    request.inputs['input'].CopyFrom(
        tf.contrib.util.make_tensor_proto(features, shape=features.shape))
    # Ask the server to return only the 'output' tensor.
    request.output_filter.append('output')

    # Plaintext (insecure) gRPC channel to the server.
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    print(stub.Predict(request, timeout))
if __name__ == '__main__':
    # Hand control to TensorFlow's app runner, which handles flag parsing
    # and then invokes main().
    tf.app.run(main=main)
"""Train a small Keras classifier on the Iris dataset and export it as a SavedModel.

Usage: iris.py [--training_iteration=x] [--export_version=y] export_dir
"""
from __future__ import print_function | |
import os | |
import sys | |
import tensorflow as tf | |
from keras import backend as K | |
from keras.layers import Dense, Input | |
from keras.models import Model | |
from keras.utils.np_utils import to_categorical | |
from sklearn.datasets import load_iris | |
from tensorflow.python.saved_model import builder as saved_model_builder | |
from tensorflow.python.saved_model import (signature_constants, | |
signature_def_utils, tag_constants, | |
utils) | |
from tensorflow.python.util import compat | |
# Command-line flags; main() and its helpers read them through FLAGS.
tf.app.flags.DEFINE_integer('training_iteration', 1000,
                            'number of training iterations.')
tf.app.flags.DEFINE_integer('export_version', 1,
                            'version number of the model.')
# NOTE(review): work_dir is registered but never read by this script.
tf.app.flags.DEFINE_string('work_dir', '/tmp', 'Working directory.')
FLAGS = tf.app.flags.FLAGS


def _validate_args():
    """Print a message and exit with status -1 if arguments are invalid."""
    if len(sys.argv) < 2 or sys.argv[-1].startswith('-'):
        print('Usage: iris.py [--training_iteration=x] '
              '[--export_version=y] export_dir')
        sys.exit(-1)
    if FLAGS.training_iteration <= 0:
        print('Please specify a positive value for training iteration.')
        sys.exit(-1)
    if FLAGS.export_version <= 0:
        print('Please specify a positive value for version number.')
        sys.exit(-1)


def _train_model():
    """Build, compile and fit the Iris classifier in training mode.

    Returns:
      The trained Keras ``Model``.
    """
    K.set_learning_phase(1)
    iris = load_iris()
    inputs = Input(shape=(4, ), name='input')
    hidden = Dense(10, activation='relu')(inputs)
    predictions = Dense(3, activation='softmax', name='output')(hidden)
    model = Model(input=inputs, output=predictions)
    model.compile(
        optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    # One-hot encode the integer targets into 3 classes.
    labels = to_categorical(iris.target, 3)
    model.fit(iris.data, labels, nb_epoch=FLAGS.training_iteration)
    return model


def _rebuild_for_inference(model):
    """Return a copy of *model* rebuilt with the learning phase fixed to 0.

    Layers such as Dropout or BatchNormalization behave differently at train
    and test time; re-creating the model from its JSON config after
    ``K.set_learning_phase(0)`` builds the new graph ops in test mode.
    """
    from keras.models import model_from_json
    K.set_learning_phase(0)  # all new operations are built in test mode
    config = model.to_json()
    weights = model.get_weights()
    new_model = model_from_json(config)
    new_model.set_weights(weights)
    return new_model


def _export_model(sess, model, export_path):
    """Save *sess*'s graph and variables as a SavedModel serving *model*.

    The prediction signature is registered both as 'predict_iris' and under
    the default serving key, so clients that do not set ``signature_name``
    can still call it.
    """
    builder = saved_model_builder.SavedModelBuilder(export_path)
    # Signature tensors are keyed by Keras layer name (presumably 'input' and
    # 'output', the names given in _train_model).
    inputs = {
        layer.name: utils.build_tensor_info(layer.output)
        for layer in model.input_layers
    }
    outputs = {
        layer.name: utils.build_tensor_info(layer.output)
        for layer in model.output_layers
    }
    prediction_signature = signature_def_utils.build_signature_def(
        inputs=inputs,
        outputs=outputs,
        method_name=signature_constants.PREDICT_METHOD_NAME)
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING],
        signature_def_map={
            'predict_iris': prediction_signature,
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                prediction_signature,
        },
        legacy_init_op=legacy_init_op)
    builder.save()


def main(_):
    """Train the Iris classifier and export it for TensorFlow Serving.

    The export base directory is the last command-line argument; the model
    is written to ``<export_dir>/<FLAGS.export_version>``. Exits with status
    -1 on invalid arguments or if that directory already exists.

    Args:
      _: Remaining command-line arguments passed in by ``tf.app.run``; unused
        (``sys.argv`` is read directly).
    """
    _validate_args()
    export_path = os.path.join(
        compat.as_bytes(sys.argv[-1]),
        compat.as_bytes(str(FLAGS.export_version)))
    # SavedModelBuilder refuses to write into an existing directory; fail
    # fast with a readable message instead of after a full training run.
    if tf.gfile.Exists(export_path):
        print('Export directory already exists: %s' %
              compat.as_text(export_path))
        sys.exit(-1)

    print('Training model...')
    sess = tf.Session()
    K.set_session(sess)
    model = _train_model()
    print('Done training!')

    new_model = _rebuild_for_inference(model)
    del model

    _export_model(sess, new_model, export_path)
    print('Done exporting!')
if __name__ == '__main__':
    # Hand control to TensorFlow's app runner, which handles flag parsing
    # and then invokes main().
    tf.app.run(main=main)
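# Serve the exported model with the command below. --model_base_path is the
# export_dir argument (without the version subdirectory), and --port matches
# the port used by the test client:
# tensorflow_model_server --port=8500 --model_name=iris --model_base_path=/path/to/exported/model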
Sorry, my mistake: the model is rebuilt so that it uses a different learning phase, as explained in https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html. In this example it isn't strictly necessary, since the model has no Dropout, BatchNormalization, or similar layers, but I kept it so the code works in the general case.
Hi, when I export the model this way I get the following error:

```
root@d843340a2a0b:/serving/tf_serving# /serving/bazel-bin/tf_serving/client --server=localhost:8888 --image=/serving/image/car_snow.jpg
Traceback (most recent call last):
  File "/serving/bazel-bin/tf_serving/client.runfiles/tf_serving/tf_serving/client.py", line 59, in <module>
    tf.app.run()
  File "/serving/bazel-bin/tf_serving/client.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/serving/bazel-bin/tf_serving/client.runfiles/tf_serving/tf_serving/client.py", line 54, in main
    result = stub.Predict(request, 10.0) # 10 secs timeout
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 324, in __call__
    self._request_serializer, self._response_deserializer)
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 210, in _blocking_unary_unary
    raise _abortion_error(rpc_error_call)
grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.FAILED_PRECONDITION, details="Default serving signature key not found.")
```

Do you have any idea why?
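Try exporting your Keras model as described in tensorflow/serving#310 (comment). The error most likely occurs because the model is exported only under the custom 'predict_iris' signature key. A client that does not set `signature_name` looks up the default serving key (`serving_default`), which is missing. Either set `request.model_spec.signature_name = 'predict_iris'` in your client, or also register the signature under `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY` when exporting.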
Thanks for the example. Why did you create a new_model? It's not used anywhere.