@yufengg
Last active April 5, 2020 00:42
Jupyter notebook for AI Adventures episode 3
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Copyright 2017 Google Inc.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overview\n",
"All the code we'll look at is in the next cell. We will step through each step after."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"print(tf.__version__)\n",
"\n",
"from tensorflow.contrib.learn.python.learn.datasets import base\n",
"\n",
"# Data files\n",
"IRIS_TRAINING = \"iris_training.csv\"\n",
"IRIS_TEST = \"iris_test.csv\"\n",
"\n",
"# Load datasets.\n",
"training_set = base.load_csv_with_header(filename=IRIS_TRAINING,\n",
" features_dtype=np.float32,\n",
" target_dtype=np.int)\n",
"test_set = base.load_csv_with_header(filename=IRIS_TEST,\n",
" features_dtype=np.float32,\n",
" target_dtype=np.int)\n",
"\n",
"# Specify that all features have real-value data\n",
"feature_name = \"flower_features\"\n",
"feature_columns = [tf.feature_column.numeric_column(feature_name, \n",
" shape=[4])]\n",
"classifier = tf.estimator.LinearClassifier(\n",
" feature_columns=feature_columns,\n",
" n_classes=3,\n",
" model_dir=\"/tmp/iris_model\")\n",
"\n",
"def input_fn(dataset):\n",
" def _fn():\n",
" features = {feature_name: tf.constant(dataset.data)}\n",
" label = tf.constant(dataset.target)\n",
" return features, label\n",
" return _fn\n",
"\n",
"# Fit model.\n",
"classifier.train(input_fn=input_fn(training_set),\n",
" steps=1000)\n",
"print('fit done')\n",
"\n",
"# Evaluate accuracy.\n",
"accuracy_score = classifier.evaluate(input_fn=input_fn(test_set), \n",
" steps=100)[\"accuracy\"]\n",
"print('\\nAccuracy: {0:f}'.format(accuracy_score))\n",
"\n",
"# Export the model for serving\n",
"feature_spec = {'flower_features': tf.FixedLenFeature(shape=[4], dtype=np.float32)}\n",
"\n",
"serving_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)\n",
"\n",
"classifier.export_savedmodel(export_dir_base='/tmp/iris_model' + '/export', \n",
" serving_input_receiver_fn=serving_fn)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.3.0\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"print(tf.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data set\n",
"From https://en.wikipedia.org/wiki/Iris_flower_data_set\n",
"\n",
"3 types of Iris Flowers: \n",
"\n",
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/5/56/Kosaciec_szczecinkowaty_Iris_setosa.jpg/450px-Kosaciec_szczecinkowaty_Iris_setosa.jpg\" style=\"width: 100px; display:inline\"/>\n",
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/4/41/Iris_versicolor_3.jpg/800px-Iris_versicolor_3.jpg\" style=\"width: 150px;display:inline\"/>\n",
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/9/9f/Iris_virginica.jpg/736px-Iris_virginica.jpg\" style=\"width: 150px;display:inline\"/>\n",
"* Iris Setosa\n",
"* Iris Versicolour\n",
"* Iris Virginica\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Columns:\n",
" 1. sepal length in cm \n",
" 2. sepal width in cm \n",
" 3. petal length in cm \n",
" 4. petal width in cm\n",
"\n",
"<img src=\"petal_sepal.png\" style=\"width: 200px;\"/>\n",
"<img src=\"https://storage.googleapis.com/image-uploader/AIA_images/data_table.png\" style=\"width: 450px\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data in"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 6.4000001 2.79999995 5.5999999 2.20000005]\n",
" [ 5. 2.29999995 3.29999995 1. ]\n",
" [ 4.9000001 2.5 4.5 1.70000005]\n",
" [ 4.9000001 3.0999999 1.5 0.1 ]\n",
" [ 5.69999981 3.79999995 1.70000005 0.30000001]\n",
" [ 4.4000001 3.20000005 1.29999995 0.2 ]\n",
" [ 5.4000001 3.4000001 1.5 0.40000001]\n",
" [ 6.9000001 3.0999999 5.0999999 2.29999995]\n",
" [ 6.69999981 3.0999999 4.4000001 1.39999998]\n",
" [ 5.0999999 3.70000005 1.5 0.40000001]\n",
" [ 5.19999981 2.70000005 3.9000001 1.39999998]\n",
" [ 6.9000001 3.0999999 4.9000001 1.5 ]\n",
" [ 5.80000019 4. 1.20000005 0.2 ]\n",
" [ 5.4000001 3.9000001 1.70000005 0.40000001]\n",
" [ 7.69999981 3.79999995 6.69999981 2.20000005]\n",
" [ 6.30000019 3.29999995 4.69999981 1.60000002]\n",
" [ 6.80000019 3.20000005 5.9000001 2.29999995]\n",
" [ 7.5999999 3. 6.5999999 2.0999999 ]\n",
" [ 6.4000001 3.20000005 5.30000019 2.29999995]\n",
" [ 5.69999981 4.4000001 1.5 0.40000001]\n",
" [ 6.69999981 3.29999995 5.69999981 2.0999999 ]\n",
" [ 6.4000001 2.79999995 5.5999999 2.0999999 ]\n",
" [ 5.4000001 3.9000001 1.29999995 0.40000001]\n",
" [ 6.0999999 2.5999999 5.5999999 1.39999998]\n",
" [ 7.19999981 3. 5.80000019 1.60000002]\n",
" [ 5.19999981 3.5 1.5 0.2 ]\n",
" [ 5.80000019 2.5999999 4. 1.20000005]\n",
" [ 5.9000001 3. 5.0999999 1.79999995]\n",
" [ 5.4000001 3. 4.5 1.5 ]\n",
" [ 6.69999981 3. 5. 1.70000005]\n",
" [ 6.30000019 2.29999995 4.4000001 1.29999995]\n",
" [ 5.0999999 2.5 3. 1.10000002]\n",
" [ 6.4000001 3.20000005 4.5 1.5 ]\n",
" [ 6.80000019 3. 5.5 2.0999999 ]\n",
" [ 6.19999981 2.79999995 4.80000019 1.79999995]\n",
" [ 6.9000001 3.20000005 5.69999981 2.29999995]\n",
" [ 6.5 3.20000005 5.0999999 2. ]\n",
" [ 5.80000019 2.79999995 5.0999999 2.4000001 ]\n",
" [ 5.0999999 3.79999995 1.5 0.30000001]\n",
" [ 4.80000019 3. 1.39999998 0.30000001]\n",
" [ 7.9000001 3.79999995 6.4000001 2. ]\n",
" [ 5.80000019 2.70000005 5.0999999 1.89999998]\n",
" [ 6.69999981 3. 5.19999981 2.29999995]\n",
" [ 5.0999999 3.79999995 1.89999998 0.40000001]\n",
" [ 4.69999981 3.20000005 1.60000002 0.2 ]\n",
" [ 6. 2.20000005 5. 1.5 ]\n",
" [ 4.80000019 3.4000001 1.60000002 0.2 ]\n",
" [ 7.69999981 2.5999999 6.9000001 2.29999995]\n",
" [ 4.5999999 3.5999999 1. 0.2 ]\n",
" [ 7.19999981 3.20000005 6. 1.79999995]\n",
" [ 5. 3.29999995 1.39999998 0.2 ]\n",
" [ 6.5999999 3. 4.4000001 1.39999998]\n",
" [ 6.0999999 2.79999995 4. 1.29999995]\n",
" [ 5. 3.20000005 1.20000005 0.2 ]\n",
" [ 7. 3.20000005 4.69999981 1.39999998]\n",
" [ 6. 3. 4.80000019 1.79999995]\n",
" [ 7.4000001 2.79999995 6.0999999 1.89999998]\n",
" [ 5.80000019 2.70000005 5.0999999 1.89999998]\n",
" [ 6.19999981 3.4000001 5.4000001 2.29999995]\n",
" [ 5. 2. 3.5 1. ]\n",
" [ 5.5999999 2.5 3.9000001 1.10000002]\n",
" [ 6.69999981 3.0999999 5.5999999 2.4000001 ]\n",
" [ 6.30000019 2.5 5. 1.89999998]\n",
" [ 6.4000001 3.0999999 5.5 1.79999995]\n",
" [ 6.19999981 2.20000005 4.5 1.5 ]\n",
" [ 7.30000019 2.9000001 6.30000019 1.79999995]\n",
" [ 4.4000001 3. 1.29999995 0.2 ]\n",
" [ 7.19999981 3.5999999 6.0999999 2.5 ]\n",
" [ 6.5 3. 5.5 1.79999995]\n",
" [ 5. 3.4000001 1.5 0.2 ]\n",
" [ 4.69999981 3.20000005 1.29999995 0.2 ]\n",
" [ 6.5999999 2.9000001 4.5999999 1.29999995]\n",
" [ 5.5 3.5 1.29999995 0.2 ]\n",
" [ 7.69999981 3. 6.0999999 2.29999995]\n",
" [ 6.0999999 3. 4.9000001 1.79999995]\n",
" [ 4.9000001 3.0999999 1.5 0.1 ]\n",
" [ 5.5 2.4000001 3.79999995 1.10000002]\n",
" [ 5.69999981 2.9000001 4.19999981 1.29999995]\n",
" [ 6. 2.9000001 4.5 1.5 ]\n",
" [ 6.4000001 2.70000005 5.30000019 1.89999998]\n",
" [ 5.4000001 3.70000005 1.5 0.2 ]\n",
" [ 6.0999999 2.9000001 4.69999981 1.39999998]\n",
" [ 6.5 2.79999995 4.5999999 1.5 ]\n",
" [ 5.5999999 2.70000005 4.19999981 1.29999995]\n",
" [ 6.30000019 3.4000001 5.5999999 2.4000001 ]\n",
" [ 4.9000001 3.0999999 1.5 0.1 ]\n",
" [ 6.80000019 2.79999995 4.80000019 1.39999998]\n",
" [ 5.69999981 2.79999995 4.5 1.29999995]\n",
" [ 6. 2.70000005 5.0999999 1.60000002]\n",
" [ 5. 3.5 1.29999995 0.30000001]\n",
" [ 6.5 3. 5.19999981 2. ]\n",
" [ 6.0999999 2.79999995 4.69999981 1.20000005]\n",
" [ 5.0999999 3.5 1.39999998 0.30000001]\n",
" [ 4.5999999 3.0999999 1.5 0.2 ]\n",
" [ 6.5 3. 5.80000019 2.20000005]\n",
" [ 4.5999999 3.4000001 1.39999998 0.30000001]\n",
" [ 4.5999999 3.20000005 1.39999998 0.2 ]\n",
" [ 7.69999981 2.79999995 6.69999981 2. ]\n",
" [ 5.9000001 3.20000005 4.80000019 1.79999995]\n",
" [ 5.0999999 3.79999995 1.60000002 0.2 ]\n",
" [ 4.9000001 3. 1.39999998 0.2 ]\n",
" [ 4.9000001 2.4000001 3.29999995 1. ]\n",
" [ 4.5 2.29999995 1.29999995 0.30000001]\n",
" [ 5.80000019 2.70000005 4.0999999 1. ]\n",
" [ 5. 3.4000001 1.60000002 0.40000001]\n",
" [ 5.19999981 3.4000001 1.39999998 0.2 ]\n",
" [ 5.30000019 3.70000005 1.5 0.2 ]\n",
" [ 5. 3.5999999 1.39999998 0.2 ]\n",
" [ 5.5999999 2.9000001 3.5999999 1.29999995]\n",
" [ 4.80000019 3.0999999 1.60000002 0.2 ]\n",
" [ 6.30000019 2.70000005 4.9000001 1.79999995]\n",
" [ 5.69999981 2.79999995 4.0999999 1.29999995]\n",
" [ 5. 3. 1.60000002 0.2 ]\n",
" [ 6.30000019 3.29999995 6. 2.5 ]\n",
" [ 5. 3.5 1.60000002 0.60000002]\n",
" [ 5.5 2.5999999 4.4000001 1.20000005]\n",
" [ 5.69999981 3. 4.19999981 1.20000005]\n",
" [ 4.4000001 2.9000001 1.39999998 0.2 ]\n",
" [ 4.80000019 3. 1.39999998 0.1 ]\n",
" [ 5.5 2.4000001 3.70000005 1. ]]\n",
"[2 1 2 0 0 0 0 2 1 0 1 1 0 0 2 1 2 2 2 0 2 2 0 2 2 0 1 2 1 1 1 1 1 2 2 2 2\n",
" 2 0 0 2 2 2 0 0 2 0 2 0 2 0 1 1 0 1 2 2 2 2 1 1 2 2 2 1 2 0 2 2 0 0 1 0 2\n",
" 2 0 1 1 1 2 0 1 1 1 2 0 1 1 1 0 2 1 0 0 2 0 0 2 1 0 0 1 0 1 0 0 0 0 1 0 2\n",
" 1 0 2 0 1 1 0 0 1]\n"
]
}
],
"source": [
"from tensorflow.contrib.learn.python.learn.datasets import base\n",
"\n",
"# Data files\n",
"IRIS_TRAINING = \"iris_training.csv\"\n",
"IRIS_TEST = \"iris_test.csv\"\n",
"\n",
"# Load datasets.\n",
"training_set = base.load_csv_with_header(filename=IRIS_TRAINING,\n",
" features_dtype=np.float32,\n",
" target_dtype=np.int)\n",
"test_set = base.load_csv_with_header(filename=IRIS_TEST,\n",
" features_dtype=np.float32,\n",
" target_dtype=np.int)\n",
"\n",
"print(training_set.data)\n",
"\n",
"print(training_set.target)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature columns and model creation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_tf_random_seed': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_save_checkpoints_steps': None, '_model_dir': '/tmp/iris_model', '_save_summary_steps': 100}\n"
]
}
],
"source": [
"# Specify that all features have real-value data\n",
"feature_name = \"flower_features\"\n",
"feature_columns = [tf.feature_column.numeric_column(feature_name, \n",
" shape=[4])]\n",
"\n",
"classifier = tf.estimator.LinearClassifier(\n",
" feature_columns=feature_columns,\n",
" n_classes=3,\n",
" model_dir=\"/tmp/iris_model\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Input function"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"({'flower_features': <tf.Tensor 'Const:0' shape=(120, 4) dtype=float32>}, <tf.Tensor 'Const_1:0' shape=(120,) dtype=int64>)\n"
]
}
],
"source": [
"def input_fn(dataset):\n",
" def _fn():\n",
" features = {feature_name: tf.constant(dataset.data)}\n",
" label = tf.constant(dataset.target)\n",
" return features, label\n",
" return _fn\n",
"\n",
"print(input_fn(training_set)())\n",
"\n",
"# raw data -> input function -> feature columns -> model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-3000\n",
"INFO:tensorflow:Saving checkpoints for 3001 into /tmp/iris_model/model.ckpt.\n",
"INFO:tensorflow:loss = 8.50027, step = 3001\n",
"INFO:tensorflow:global_step/sec: 827.439\n",
"INFO:tensorflow:loss = 8.40806, step = 3101 (0.123 sec)\n",
"INFO:tensorflow:global_step/sec: 868.825\n",
"INFO:tensorflow:loss = 8.32063, step = 3201 (0.116 sec)\n",
"INFO:tensorflow:global_step/sec: 959.112\n",
"INFO:tensorflow:loss = 8.23757, step = 3301 (0.104 sec)\n",
"INFO:tensorflow:global_step/sec: 844.444\n",
"INFO:tensorflow:loss = 8.15855, step = 3401 (0.118 sec)\n",
"INFO:tensorflow:global_step/sec: 847.278\n",
"INFO:tensorflow:loss = 8.08324, step = 3501 (0.118 sec)\n",
"INFO:tensorflow:global_step/sec: 825.594\n",
"INFO:tensorflow:loss = 8.01139, step = 3601 (0.120 sec)\n",
"INFO:tensorflow:global_step/sec: 882.98\n",
"INFO:tensorflow:loss = 7.94273, step = 3701 (0.114 sec)\n",
"INFO:tensorflow:global_step/sec: 941.876\n",
"INFO:tensorflow:loss = 7.87704, step = 3801 (0.106 sec)\n",
"INFO:tensorflow:global_step/sec: 889.862\n",
"INFO:tensorflow:loss = 7.81412, step = 3901 (0.112 sec)\n",
"INFO:tensorflow:Saving checkpoints for 4000 into /tmp/iris_model/model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 7.75437.\n",
"fit done\n"
]
}
],
"source": [
"# Fit model.\n",
"classifier.train(input_fn=input_fn(training_set),\n",
" steps=1000)\n",
"print('fit done')\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Starting evaluation at 2017-09-14-15:16:40\n",
"INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-4000\n",
"INFO:tensorflow:Evaluation [1/100]\n",
"INFO:tensorflow:Evaluation [2/100]\n",
"INFO:tensorflow:Evaluation [3/100]\n",
"INFO:tensorflow:Evaluation [4/100]\n",
"INFO:tensorflow:Evaluation [5/100]\n",
"INFO:tensorflow:Evaluation [6/100]\n",
"INFO:tensorflow:Evaluation [7/100]\n",
"INFO:tensorflow:Evaluation [8/100]\n",
"INFO:tensorflow:Evaluation [9/100]\n",
"INFO:tensorflow:Evaluation [10/100]\n",
"INFO:tensorflow:Evaluation [11/100]\n",
"INFO:tensorflow:Evaluation [12/100]\n",
"INFO:tensorflow:Evaluation [13/100]\n",
"INFO:tensorflow:Evaluation [14/100]\n",
"INFO:tensorflow:Evaluation [15/100]\n",
"INFO:tensorflow:Evaluation [16/100]\n",
"INFO:tensorflow:Evaluation [17/100]\n",
"INFO:tensorflow:Evaluation [18/100]\n",
"INFO:tensorflow:Evaluation [19/100]\n",
"INFO:tensorflow:Evaluation [20/100]\n",
"INFO:tensorflow:Evaluation [21/100]\n",
"INFO:tensorflow:Evaluation [22/100]\n",
"INFO:tensorflow:Evaluation [23/100]\n",
"INFO:tensorflow:Evaluation [24/100]\n",
"INFO:tensorflow:Evaluation [25/100]\n",
"INFO:tensorflow:Evaluation [26/100]\n",
"INFO:tensorflow:Evaluation [27/100]\n",
"INFO:tensorflow:Evaluation [28/100]\n",
"INFO:tensorflow:Evaluation [29/100]\n",
"INFO:tensorflow:Evaluation [30/100]\n",
"INFO:tensorflow:Evaluation [31/100]\n",
"INFO:tensorflow:Evaluation [32/100]\n",
"INFO:tensorflow:Evaluation [33/100]\n",
"INFO:tensorflow:Evaluation [34/100]\n",
"INFO:tensorflow:Evaluation [35/100]\n",
"INFO:tensorflow:Evaluation [36/100]\n",
"INFO:tensorflow:Evaluation [37/100]\n",
"INFO:tensorflow:Evaluation [38/100]\n",
"INFO:tensorflow:Evaluation [39/100]\n",
"INFO:tensorflow:Evaluation [40/100]\n",
"INFO:tensorflow:Evaluation [41/100]\n",
"INFO:tensorflow:Evaluation [42/100]\n",
"INFO:tensorflow:Evaluation [43/100]\n",
"INFO:tensorflow:Evaluation [44/100]\n",
"INFO:tensorflow:Evaluation [45/100]\n",
"INFO:tensorflow:Evaluation [46/100]\n",
"INFO:tensorflow:Evaluation [47/100]\n",
"INFO:tensorflow:Evaluation [48/100]\n",
"INFO:tensorflow:Evaluation [49/100]\n",
"INFO:tensorflow:Evaluation [50/100]\n",
"INFO:tensorflow:Evaluation [51/100]\n",
"INFO:tensorflow:Evaluation [52/100]\n",
"INFO:tensorflow:Evaluation [53/100]\n",
"INFO:tensorflow:Evaluation [54/100]\n",
"INFO:tensorflow:Evaluation [55/100]\n",
"INFO:tensorflow:Evaluation [56/100]\n",
"INFO:tensorflow:Evaluation [57/100]\n",
"INFO:tensorflow:Evaluation [58/100]\n",
"INFO:tensorflow:Evaluation [59/100]\n",
"INFO:tensorflow:Evaluation [60/100]\n",
"INFO:tensorflow:Evaluation [61/100]\n",
"INFO:tensorflow:Evaluation [62/100]\n",
"INFO:tensorflow:Evaluation [63/100]\n",
"INFO:tensorflow:Evaluation [64/100]\n",
"INFO:tensorflow:Evaluation [65/100]\n",
"INFO:tensorflow:Evaluation [66/100]\n",
"INFO:tensorflow:Evaluation [67/100]\n",
"INFO:tensorflow:Evaluation [68/100]\n",
"INFO:tensorflow:Evaluation [69/100]\n",
"INFO:tensorflow:Evaluation [70/100]\n",
"INFO:tensorflow:Evaluation [71/100]\n",
"INFO:tensorflow:Evaluation [72/100]\n",
"INFO:tensorflow:Evaluation [73/100]\n",
"INFO:tensorflow:Evaluation [74/100]\n",
"INFO:tensorflow:Evaluation [75/100]\n",
"INFO:tensorflow:Evaluation [76/100]\n",
"INFO:tensorflow:Evaluation [77/100]\n",
"INFO:tensorflow:Evaluation [78/100]\n",
"INFO:tensorflow:Evaluation [79/100]\n",
"INFO:tensorflow:Evaluation [80/100]\n",
"INFO:tensorflow:Evaluation [81/100]\n",
"INFO:tensorflow:Evaluation [82/100]\n",
"INFO:tensorflow:Evaluation [83/100]\n",
"INFO:tensorflow:Evaluation [84/100]\n",
"INFO:tensorflow:Evaluation [85/100]\n",
"INFO:tensorflow:Evaluation [86/100]\n",
"INFO:tensorflow:Evaluation [87/100]\n",
"INFO:tensorflow:Evaluation [88/100]\n",
"INFO:tensorflow:Evaluation [89/100]\n",
"INFO:tensorflow:Evaluation [90/100]\n",
"INFO:tensorflow:Evaluation [91/100]\n",
"INFO:tensorflow:Evaluation [92/100]\n",
"INFO:tensorflow:Evaluation [93/100]\n",
"INFO:tensorflow:Evaluation [94/100]\n",
"INFO:tensorflow:Evaluation [95/100]\n",
"INFO:tensorflow:Evaluation [96/100]\n",
"INFO:tensorflow:Evaluation [97/100]\n",
"INFO:tensorflow:Evaluation [98/100]\n",
"INFO:tensorflow:Evaluation [99/100]\n",
"INFO:tensorflow:Evaluation [100/100]\n",
"INFO:tensorflow:Finished evaluation at 2017-09-14-15:16:41\n",
"INFO:tensorflow:Saving dict for global step 4000: accuracy = 0.966667, average_loss = 0.0706094, global_step = 4000, loss = 2.11828\n",
"\n",
"Accuracy: 0.966667\n"
]
}
],
"source": [
"# Evaluate accuracy.\n",
"accuracy_score = classifier.evaluate(input_fn=input_fn(test_set), \n",
" steps=100)[\"accuracy\"]\n",
"print('\\nAccuracy: {0:f}'.format(accuracy_score))"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Estimators review\n",
"\n",
"### Load datasets.\n",
"\n",
" training_data = load_csv_with_header()\n",
"\n",
"### define input functions\n",
"\n",
" def input_fn(dataset)\n",
" \n",
"### Define feature columns\n",
"\n",
" feature_columns = [tf.feature_column.numeric_column(feature_name, shape=[4])]\n",
"\n",
"### Create model\n",
"\n",
" classifier = tf.estimator.LinearClassifier()\n",
"\n",
"### Train\n",
"\n",
" classifier.train()\n",
"\n",
"### Evaluate\n",
"\n",
" classifier.evaluate()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exporting a model for serving predictions\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-4000\n",
"INFO:tensorflow:Assets added to graph.\n",
"INFO:tensorflow:No assets to write.\n",
"INFO:tensorflow:SavedModel written to: /tmp/iris_model/export/1505402201/saved_model.pb\n"
]
},
{
"data": {
"text/plain": [
"'/tmp/iris_model/export/1505402201'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"feature_spec = {'flower_features': tf.FixedLenFeature(shape=[4], dtype=np.float32)}\n",
"\n",
"serving_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)\n",
"\n",
"classifier.export_savedmodel(export_dir_base='/tmp/iris_model' + '/export', \n",
" serving_input_receiver_fn=serving_fn)\n",
"\n",
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
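
The first two stages of the notebook's data flow (raw data -> input function -> feature columns -> model) can be sketched without TensorFlow. This is a minimal illustration in plain Python, not the contrib loader itself: `SAMPLE_CSV`, the local `load_csv_with_header`, and `FEATURE_NAME` are hypothetical stand-ins, and the header layout (sample count, feature count, class names) is an assumption about `iris_training.csv`. Plain lists stand in for the `tf.constant` tensors that the notebook builds.

```python
import csv
import io
from collections import namedtuple

# Mirrors the two fields the notebook reads: training_set.data and training_set.target.
Dataset = namedtuple("Dataset", ["data", "target"])

# Hypothetical three-row excerpt; the values match the first rows printed by the notebook.
# Assumed header layout: sample count, feature count, then the class names.
SAMPLE_CSV = """3,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,2.5,4.5,1.7,2
"""

FEATURE_NAME = "flower_features"  # same key the notebook uses for its feature column


def load_csv_with_header(text):
    """Parse CSV text whose header row carries the sample and feature counts."""
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])
    data, target = [], []
    for row in reader:
        if not row:
            continue  # skip blank lines
        data.append([float(value) for value in row[:n_features]])
        target.append(int(row[-1]))  # the label is the last column
    if len(data) != n_samples:
        raise ValueError("header promised %d rows, found %d" % (n_samples, len(data)))
    return Dataset(data=data, target=target)


def input_fn(dataset):
    """Return a zero-argument function, the shape that train() and evaluate() expect."""
    def _fn():
        features = {FEATURE_NAME: dataset.data}
        labels = dataset.target
        return features, labels
    return _fn


sample_set = load_csv_with_header(SAMPLE_CSV)
features, labels = input_fn(sample_set)()
print(features)
print(labels)
```

The outer `input_fn(dataset)` call only binds the dataset. The estimator later calls the returned `_fn` with no arguments, which is why the notebook passes `input_fn(training_set)` rather than `input_fn` itself.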