Local prediction with Saved Model
Created July 27, 2017 08:46 (gist enakai00/856c39cf9b6d00dbd12f3fb7e814936f)
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an example of making local predictions with a model exported in SavedModel format.\n",
    "\n",
    "This example uses the model from the babyweight tutorial:\n",
    "\n",
    "https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/blogs/babyweight/babyweight.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0.0\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Suppose that you have trained the model locally in Datalab and that the model is exported under:\n",
    "\n",
    "`babyweight_trained/export/Servo/1501143087336`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code loads the model into a session and extracts the inputs and outputs of its default serving signature (`serving_default`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "export_dir = 'babyweight_trained/export/Servo/1501143087336'\n",
    "\n",
    "# Load the exported graph and restore its variables into a new session.\n",
    "sess = tf.Session()\n",
    "meta_graph = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)\n",
    "\n",
    "# Look up the default serving signature and its input/output maps.\n",
    "model_signature = meta_graph.signature_def['serving_default']\n",
    "input_signature = model_signature.inputs\n",
    "output_signature = model_signature.outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check the input and output key names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[u'cigarette_use', u'gestation_weeks', u'is_male', u'mother_race', u'plurality', u'alcohol_use', u'mother_married', u'mother_age']\n",
      "[u'outputs']\n"
     ]
    }
   ],
   "source": [
    "print(list(input_signature.keys()))\n",
    "print(list(output_signature.keys()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each key has a corresponding tensor in the graph, and you can retrieve it by name with the following code:\n",
    "```python\n",
    "sess.graph.get_tensor_by_name(input_signature['cigarette_use'].name)\n",
    "sess.graph.get_tensor_by_name(input_signature['gestation_weeks'].name)\n",
    "...\n",
    "sess.graph.get_tensor_by_name(output_signature['outputs'].name)\n",
    "```"
   ]
  },
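  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's an example of building a `feed_dict` dictionary and an output tensor that you can pass to `sess.run()`. Each feature value is wrapped in a one-element list because the model's inputs take a batch of examples; here the batch contains a single example."
   ]
  },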
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "features = {\n",
    "    'is_male': 'True',\n",
    "    'mother_age': 26.0,\n",
    "    'mother_race': 'Asian Indian',\n",
    "    'plurality': 1.0,\n",
    "    'gestation_weeks': 39,\n",
    "    'mother_married': 'True',\n",
    "    'cigarette_use': 'False',\n",
    "    'alcohol_use': 'False'\n",
    "}\n",
    "\n",
    "# Map each input tensor to a one-element batch holding its feature value.\n",
    "feed_dict = {\n",
    "    sess.graph.get_tensor_by_name(input_signature[key].name): [val]\n",
    "    for key, val in features.items()\n",
    "}\n",
    "\n",
    "output = sess.graph.get_tensor_by_name(output_signature['outputs'].name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you can run inference with them. The result is a NumPy array with one predicted value per example in the batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 6.39090729], dtype=float32)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sess.run(output, feed_dict=feed_dict)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
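The key-name check in the notebook only prints the signature keys. A minimal sketch of turning that into a pre-flight check, assuming you copy the printed input keys into a Python set by hand: `check_features` and `EXPECTED_INPUT_KEYS` are hypothetical names, not part of TensorFlow. A signature input missing from the feature dictionary would most likely make `sess.run` fail because that input is never fed, and an extra key has no matching input to look up in the `feed_dict` comprehension.

```python
# Input keys printed by the notebook's key-name check (copied by hand).
EXPECTED_INPUT_KEYS = frozenset([
    'cigarette_use', 'gestation_weeks', 'is_male', 'mother_race',
    'plurality', 'alcohol_use', 'mother_married', 'mother_age',
])


def check_features(features, expected_keys=EXPECTED_INPUT_KEYS):
    """Hypothetical helper: return (missing, unexpected) as sorted lists.

    missing:    signature inputs that `features` does not provide.
    unexpected: keys in `features` that the signature does not define.
    """
    provided = set(features)
    missing = sorted(set(expected_keys) - provided)
    unexpected = sorted(provided - set(expected_keys))
    return missing, unexpected


# The feature dictionary from the notebook.
features = {
    'is_male': 'True',
    'mother_age': 26.0,
    'mother_race': 'Asian Indian',
    'plurality': 1.0,
    'gestation_weeks': 39,
    'mother_married': 'True',
    'cigarette_use': 'False',
    'alcohol_use': 'False',
}

missing, unexpected = check_features(features)
if missing or unexpected:
    raise ValueError('missing: %s, unexpected: %s' % (missing, unexpected))

# A misspelled key shows up as unexpected, and the real key as missing.
print(check_features({'is_male': 'True', 'mother_ages': 26.0})[1])  # ['mother_ages']
```

Running the check before building `feed_dict` turns a typo in a feature name into a clear error message instead of a failure deep inside the session call.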
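The `.name` stored in each signature entry is a tensor name, which TensorFlow writes as `<op_name>:<output_index>`; `get_tensor_by_name` rejects a bare operation name. A small sketch that splits such a name, which can help when inspecting a signature by hand. `split_tensor_name` and the example names below are hypothetical and are not taken from the babyweight export.

```python
def split_tensor_name(name):
    """Hypothetical helper: split '<op_name>:<output_index>' into its parts.

    It only parses the string; it does not look anything up in a graph.
    """
    op_name, sep, index = name.rpartition(':')
    if not sep or not op_name or not index.isdigit():
        raise ValueError('expected "<op_name>:<output_index>", got %r' % name)
    return op_name, int(index)


# Made-up names in the usual format; op names may contain '/' scopes.
print(split_tensor_name('Placeholder_3:0'))         # ('Placeholder_3', 0)
print(split_tensor_name('dnn/head/predictions:0'))  # ('dnn/head/predictions', 0)
```

Splitting on the last colon keeps scoped operation names such as `dnn/head/predictions` intact.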
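The `feed_dict` cell wraps every value in a one-element list, which is a batch of one example. A sketch of preparing several examples for a single call, assuming the exported inputs accept batches of any size (the one-element lists suggest a batch dimension, but check your own export): `to_batch` is a hypothetical helper that turns a list of feature dictionaries into one list of values per key, in example order. The second example uses made-up values.

```python
def to_batch(instances):
    """Hypothetical helper: convert feature dicts into {key: [values...]}.

    Every instance must have exactly the same keys, and each value list
    follows the order of `instances`.
    """
    if not instances:
        raise ValueError('need at least one instance')
    keys = set(instances[0])
    for i, instance in enumerate(instances):
        if set(instance) != keys:
            raise ValueError('instance %d has different keys' % i)
    return {key: [instance[key] for instance in instances] for key in keys}


instances = [
    # The example from the notebook.
    {'is_male': 'True', 'mother_age': 26.0, 'mother_race': 'Asian Indian',
     'plurality': 1.0, 'gestation_weeks': 39, 'mother_married': 'True',
     'cigarette_use': 'False', 'alcohol_use': 'False'},
    # A second, made-up example.
    {'is_male': 'False', 'mother_age': 32.0, 'mother_race': 'Asian Indian',
     'plurality': 1.0, 'gestation_weeks': 40, 'mother_married': 'True',
     'cigarette_use': 'False', 'alcohol_use': 'False'},
]

batch = to_batch(instances)
print(batch['mother_age'])  # [26.0, 32.0]
```

With the notebook's session, this batch would replace the one-element lists: `{sess.graph.get_tensor_by_name(input_signature[key].name): values for key, values in batch.items()}`. If the export does accept batches, `sess.run(output, feed_dict=...)` should return one prediction per instance, in the order of `instances`.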