Last active
June 17, 2018 02:12
-
-
Save ismaeIfm/eeb24fad2623dfb69ca81bb0f254543f to your computer and use it in GitHub Desktop.
Keras + TensorFlow Serving: Iris example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Manual test client for tensorflow_model_server.""" | |
from __future__ import print_function | |
import os | |
import sys | |
import numpy as np | |
import tensorflow as tf | |
from grpc.beta import implementations | |
from tensorflow.python.platform import flags | |
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2 | |
FLAGS = tf.app.flags.FLAGS


def main(_, host='localhost', port=8500, timeout=5.0):
    """Send one Predict request for a single iris sample and print the reply.

    Args:
      _: argv list handed over by tf.app.run (unused).
      host: hostname of the tensorflow_model_server gRPC endpoint.
      port: gRPC port of the server.
      timeout: RPC deadline in seconds passed to stub.Predict.
    """
    # Prepare request
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'iris'
    # Selects the 'predict_iris' signature exported by the iris.py script.
    request.model_spec.signature_name = 'predict_iris'
    # A batch of one sample with four float32 features -> shape (1, 4).
    features = np.array([[1, 2, 3, 4]], dtype=np.float32)
    request.inputs['input'].CopyFrom(
        tf.contrib.util.make_tensor_proto(features, shape=features.shape))
    # Only ask the server for the 'output' tensor.
    request.output_filter.append('output')

    # Send request
    channel = implementations.insecure_channel(host, port)
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    print(stub.Predict(request, timeout))


if __name__ == '__main__':
    tf.app.run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Usage: iris.py [--training_iteration=x] [--export_version=y] export_dir | |
""" | |
from __future__ import print_function | |
import os | |
import sys | |
import tensorflow as tf | |
from keras import backend as K | |
from keras.layers import Dense, Input | |
from keras.models import Model | |
from keras.utils.np_utils import to_categorical | |
from sklearn.datasets import load_iris | |
from tensorflow.python.saved_model import builder as saved_model_builder | |
from tensorflow.python.saved_model import (signature_constants, | |
signature_def_utils, tag_constants, | |
utils) | |
from tensorflow.python.util import compat | |
# Command-line flags for the iris training/export script.
tf.app.flags.DEFINE_integer(
    'training_iteration', 1000, 'number of training iterations.')
tf.app.flags.DEFINE_integer(
    'export_version', 1, 'version number of the model.')
# NOTE(review): work_dir is defined but not read by main() in this file.
tf.app.flags.DEFINE_string(
    'work_dir', '/tmp', 'Working directory.')

FLAGS = tf.app.flags.FLAGS
def _validate_args():
    """Exit with a message if the command line or flag values are unusable."""
    if len(sys.argv) < 2 or sys.argv[-1].startswith('-'):
        print('Usage: iris.py [--training_iteration=x] '
              '[--export_version=y] export_dir')
        sys.exit(-1)
    if FLAGS.training_iteration <= 0:
        print('Please specify a positive value for training iteration.')
        sys.exit(-1)
    if FLAGS.export_version <= 0:
        print('Please specify a positive value for version number.')
        sys.exit(-1)


def _train_model(sess):
    """Build and fit a 4 -> 10 (relu) -> 3 (softmax) classifier on iris.

    The Keras backend is bound to ``sess`` so the trained variables live in
    that session. Returns the fitted Keras model.
    """
    K.set_session(sess)
    K.set_learning_phase(1)
    iris = load_iris()
    features_in = Input(shape=(4, ), name='input')
    hidden = Dense(10, activation='relu')(features_in)
    probs = Dense(3, activation='softmax', name='output')(hidden)
    model = Model(input=features_in, output=probs)
    model.compile(
        optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    labels = to_categorical(iris.target, 3)
    model.fit(iris.data, labels, nb_epoch=FLAGS.training_iteration)
    return model


def _rebuild_for_inference(model):
    """Return a copy of ``model`` whose learning phase is hard-coded to 0."""
    from keras.models import model_from_json
    # All new operations will be in test mode from now on, so the exported
    # graph does not depend on the learning-phase value used for training.
    K.set_learning_phase(0)
    # Serialize the architecture and weights, then re-create the model.
    config = model.to_json()
    weights = model.get_weights()
    inference_model = model_from_json(config)
    inference_model.set_weights(weights)
    return inference_model


def _export_model(sess, model, export_dir, version):
    """Write ``model`` as a SavedModel under ``export_dir/version``.

    The prediction signature maps each input/output layer name (here
    'input' and 'output') to its tensor. It is registered both under
    'predict_iris' and under the default serving key. With the default key,
    clients that do not set ``model_spec.signature_name`` still resolve it.
    """
    export_path = os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes(str(version)))
    builder = saved_model_builder.SavedModelBuilder(export_path)

    inputs = {
        layer.name: utils.build_tensor_info(layer.output)
        for layer in model.input_layers
    }
    outputs = {
        layer.name: utils.build_tensor_info(layer.output)
        for layer in model.output_layers
    }
    prediction_signature = signature_def_utils.build_signature_def(
        inputs=inputs,
        outputs=outputs,
        method_name=signature_constants.PREDICT_METHOD_NAME)

    # tf.initialize_all_tables() is deprecated; tables_initializer is its
    # drop-in replacement.
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING],
        signature_def_map={
            'predict_iris': prediction_signature,
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                prediction_signature,
        },
        legacy_init_op=legacy_init_op)
    builder.save()


def main(_):
    """Train the iris classifier and export it for tensorflow_model_server.

    The export directory is the last command-line argument. The version
    sub-directory comes from --export_version.
    """
    _validate_args()

    # Train model
    print('Training model...')
    sess = tf.Session()
    model = _train_model(sess)
    print('Done training!')

    serving_model = _rebuild_for_inference(model)
    del model

    _export_model(sess, serving_model, sys.argv[-1], FLAGS.export_version)
    print('Done exporting!')
if __name__ == '__main__':
    # Script entry point: hand control to main() via tf.app.run.
    tf.app.run(main=main)

# Serve the exported model with:
#   tensorflow_model_server --model_name=iris --model_base_path=/path/to/exported/model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hey man, I get the following error when exporting the model this way:

```
root@d843340a2a0b:/serving/tf_serving# /serving/bazel-bin/tf_serving/client --server=localhost:8888 --image=/serving/image/car_snow.jpg
Traceback (most recent call last):
  File "/serving/bazel-bin/tf_serving/client.runfiles/tf_serving/tf_serving/client.py", line 59, in <module>
    tf.app.run()
  File "/serving/bazel-bin/tf_serving/client.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/serving/bazel-bin/tf_serving/client.runfiles/tf_serving/tf_serving/client.py", line 54, in main
    result = stub.Predict(request, 10.0) # 10 secs timeout
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 324, in __call__
    self._request_serializer, self._response_deserializer)
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 210, in _blocking_unary_unary
    raise _abortion_error(rpc_error_call)
grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.FAILED_PRECONDITION, details="Default serving signature key not found.")
```

Do you have any idea why?