@enakai00
Created July 27, 2017 08:46
Local prediction with Saved Model
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook shows how to make local predictions with a model exported in the SavedModel format.\n",
"\n",
"In this example, you use the model from the babyweight tutorial:\n",
"\n",
"https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/blogs/babyweight/babyweight.ipynb"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0.0\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"print(tf.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Suppose that you have trained the model locally on Datalab and that it has been exported to:\n",
"\n",
"`babyweight_trained/export/Servo/1501143087336`"
]
},
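{
"cell_type": "markdown",
"metadata": {},
"source": [
"The directory name `1501143087336` appears to be an export timestamp, so repeated exports would typically add more numeric subdirectories next to it. If you prefer not to hard-code the name, the following minimal sketch picks the newest one. It assumes that every all-digit subdirectory of `babyweight_trained/export/Servo` holds a complete export. `latest_export_dir` is a hypothetical helper, not a TensorFlow API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import os\n",
"\n",
"def latest_export_dir(base_dir):\n",
"    # Hypothetical helper: return the path of the subdirectory of base_dir\n",
"    # with the largest all-digit name (assumed to be the newest timestamp).\n",
"    versions = [name for name in os.listdir(base_dir)\n",
"                if name.isdigit() and os.path.isdir(os.path.join(base_dir, name))]\n",
"    if not versions:\n",
"        raise ValueError('no timestamped exports found under %s' % base_dir)\n",
"    # Compare numerically, not as strings.\n",
"    return os.path.join(base_dir, max(versions, key=int))\n",
"\n",
"# Usage, assuming the directory layout described above:\n",
"# export_dir = latest_export_dir('babyweight_trained/export/Servo')"
]
},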
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following code loads the model and extracts its signatures."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"export_dir = 'babyweight_trained/export/Servo/1501143087336'\n",
"\n",
"sess = tf.Session()\n",
"# Restore the graph and variables saved under the SERVING tag into sess.\n",
"meta_graph = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)\n",
"# The default signature describes the model's named inputs and outputs.\n",
"model_signature = meta_graph.signature_def['serving_default']\n",
"input_signature = model_signature.inputs\n",
"output_signature = model_signature.outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check the input and output key names."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[u'cigarette_use', u'gestation_weeks', u'is_male', u'mother_race', u'plurality', u'alcohol_use', u'mother_married', u'mother_age']\n",
"[u'outputs']\n"
]
}
],
"source": [
"print(input_signature.keys())\n",
"print(output_signature.keys())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each key corresponds to a Tensor in the graph, which you can retrieve by its name:\n",
"```\n",
"sess.graph.get_tensor_by_name(input_signature['cigarette_use'].name)\n",
"sess.graph.get_tensor_by_name(input_signature['gestation_weeks'].name)\n",
"...\n",
"sess.graph.get_tensor_by_name(output_signature['outputs'].name)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following code builds a `feed_dict` dictionary for a single instance and retrieves the output tensor; both can be passed to `sess.run()`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"features = {\n",
" 'is_male': 'True',\n",
" 'mother_age': 26.0,\n",
" 'mother_race': 'Asian Indian',\n",
" 'plurality': 1.0,\n",
" 'gestation_weeks': 39,\n",
" 'mother_married': 'True',\n",
" 'cigarette_use': 'False',\n",
" 'alcohol_use': 'False'\n",
" }\n",
"\n",
"# Map each input tensor to a batch containing a single value.\n",
"feed_dict = {\n",
" sess.graph.get_tensor_by_name(input_signature[key].name): [val]\n",
" for key, val in features.items()\n",
"}\n",
"\n",
"output = sess.graph.get_tensor_by_name(output_signature['outputs'].name)"
]
},
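{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above feeds a batch of one instance by wrapping each value in a single-element list. To score several instances in one `sess.run()` call, you can feed one list per input key, where the i-th element of every list belongs to the i-th instance. This assumes the input placeholders accept a variable batch size, which the one-element output array below suggests but this notebook does not verify. The following minimal sketch builds those lists and checks that every instance has exactly the signature's input keys. `to_feed_columns` is a hypothetical helper name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def to_feed_columns(instances, input_keys):\n",
"    # Hypothetical helper: turn a list of feature dicts into one list of\n",
"    # values per input key, preserving the order of the instances.\n",
"    expected = set(input_keys)\n",
"    columns = {key: [] for key in input_keys}\n",
"    for i, instance in enumerate(instances):\n",
"        # Fail early instead of feeding a batch with a missing or extra key.\n",
"        if set(instance) != expected:\n",
"            raise ValueError('instance %d has keys %s, expected %s'\n",
"                             % (i, sorted(instance), sorted(expected)))\n",
"        for key in input_keys:\n",
"            columns[key].append(instance[key])\n",
"    return columns\n",
"\n",
"# Usage with the session objects defined above (not executed here):\n",
"# columns = to_feed_columns([features, features], list(input_signature.keys()))\n",
"# batch_feed_dict = {\n",
"#     sess.graph.get_tensor_by_name(input_signature[key].name): values\n",
"#     for key, values in columns.items()\n",
"# }\n",
"# sess.run(output, feed_dict=batch_feed_dict)  # expected: one value per instance"
]
},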
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now you can run inference with them. The result contains one predicted weight per instance in the batch."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 6.39090729], dtype=float32)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sess.run(output, feed_dict=feed_dict)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}