@asimshankar
Created June 20, 2017 00:17
Keras Models --> TensorFlow SavedModel format
# Mostly copied from https://keras.io/applications/#usage-examples-for-image-classification-models
# Changing it to use InceptionV3 instead of ResNet50
from keras.applications.inception_v3 import InceptionV3, preprocess_input, decode_predictions
from keras.preprocessing import image
import numpy as np
model = InceptionV3()
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(299, 299))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
preds = model.predict(x)
print('Predicted:', decode_predictions(preds, top=3)[0])
# And now exporting to the TensorFlow SavedModel format.
# Documentation for the SavedModel format can be found here:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
from keras import backend as K
import tensorflow as tf
signature = tf.saved_model.signature_def_utils.predict_signature_def(
    inputs={'image': model.input}, outputs={'scores': model.output})
builder = tf.saved_model.builder.SavedModelBuilder('/tmp/my_saved_model')
builder.add_meta_graph_and_variables(
    sess=K.get_session(),
    tags=[tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            signature
    })
builder.save()
# This model can be loaded in other languages using the C API:
# TF_SessionOptions* opts = TF_NewSessionOptions();
# TF_Graph* graph = TF_NewGraph();
# TF_Status* status = TF_NewStatus();
# const char* tags[] = {"serve"}; // tf.saved_model.tag_constants.SERVING
# TF_LoadSessionFromSavedModel(opts, NULL, "/tmp/my_saved_model", tags, 1, graph, NULL, status);
#
# TF_LoadSessionFromSavedModel is what the other language bindings use:
# - Java API: https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/SavedModelBundle
# - Go API: https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go#LoadSavedModel
# etc.
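
A quick way to sanity-check the export before handing it to another language binding is to look at what `builder.save()` wrote to disk. The sketch below assumes the standard SavedModel layout: a `saved_model.pb` (or `saved_model.pbtxt`) file next to a `variables/` directory that holds `variables.index`. The helper name `missing_saved_model_entries` and the mock directory are hypothetical and only there for illustration. It uses the standard library alone, so it runs without TensorFlow installed.

```python
import os
import tempfile


def missing_saved_model_entries(export_dir):
    """Return the expected SavedModel entries that are absent from export_dir."""
    missing = []
    # The graph and its signatures are serialized as a binary or text protobuf.
    has_proto = any(
        os.path.isfile(os.path.join(export_dir, name))
        for name in ('saved_model.pb', 'saved_model.pbtxt'))
    if not has_proto:
        missing.append('saved_model.pb')
    # Variable values are stored as a checkpoint under variables/.
    index_path = os.path.join(export_dir, 'variables', 'variables.index')
    if not os.path.isfile(index_path):
        missing.append('variables/variables.index')
    return missing


# Mock directory (hypothetical) so the sketch runs without a real export.
demo_dir = tempfile.mkdtemp()
os.makedirs(os.path.join(demo_dir, 'variables'))
open(os.path.join(demo_dir, 'saved_model.pb'), 'wb').close()
print(missing_saved_model_entries(demo_dir))
```

After running the export above, calling `missing_saved_model_entries('/tmp/my_saved_model')` should return an empty list if the save completed. It only checks that the files exist, not that the graph or signature inside them is valid.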
@Moeletji17

Thanks for sharing. This helped me save my tensorflow.keras models.
