@yufengg
Created October 24, 2017 15:16
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Wide and Deep on TensorFlow (notebook style)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
" http://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction\n",
"\n",
"This notebook uses the `tf.estimator` API in TensorFlow to answer a yes/no question. This is called a binary classification problem: given census data about a person, such as age, gender, education, and occupation (the features), we will try to predict whether or not the person earns more than 50,000 dollars a year (the target label). \n",
"\n",
"Given an individual's information, our model will output a number between 0 and 1, which can be interpreted as the model's certainty that the individual has an annual income of over 50,000 dollars (1 = True, 0 = False).\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports and constants\n",
"First we'll import our libraries and set up some strings for column names. We also print out the version of TensorFlow we are running."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using TensorFlow version 1.3.0\n"
]
}
],
"source": [
"from __future__ import absolute_import\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"\n",
"import time\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"tf.logging.set_verbosity(tf.logging.INFO) \n",
"# Set to INFO for tracking training, default is WARN \n",
"\n",
"print(\"Using TensorFlow version %s\" % (tf.__version__)) \n",
"# This notebook was written for and tested with TF 1.3\n",
"\n",
"CATEGORICAL_COLUMNS = [\"workclass\", \"education\", \n",
" \"marital_status\", \"occupation\", \n",
" \"relationship\", \"race\", \n",
" \"gender\", \"native_country\"]\n",
"\n",
"# Columns of the input csv file\n",
"COLUMNS = [\"age\", \"workclass\", \"fnlwgt\", \"education\", \n",
" \"education_num\", \"marital_status\",\n",
" \"occupation\", \"relationship\", \"race\", \n",
" \"gender\", \"capital_gain\", \"capital_loss\",\n",
" \"hours_per_week\", \"native_country\", \"income_bracket\"]\n",
"\n",
"# Feature columns for input into the model\n",
"FEATURE_COLUMNS = [\"age\", \"workclass\", \"education\", \n",
" \"education_num\", \"marital_status\",\n",
" \"occupation\", \"relationship\", \"race\", \n",
" \"gender\", \"capital_gain\", \"capital_loss\",\n",
" \"hours_per_week\", \"native_country\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pandas data exploration\n",
"We load the data into pandas because it is small enough to manage in memory, and look at some properties."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.read_csv(\"adult.test.csv\", header=None, names=COLUMNS)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>education_num</th>\n",
" <th>marital_status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>gender</th>\n",
" <th>capital_gain</th>\n",
" <th>capital_loss</th>\n",
" <th>hours_per_week</th>\n",
" <th>native_country</th>\n",
" <th>income_bracket</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>25</td>\n",
" <td>Private</td>\n",
" <td>226802</td>\n",
" <td>11th</td>\n",
" <td>7</td>\n",
" <td>Never-married</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Own-child</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>89814</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Farming-fishing</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>28</td>\n",
" <td>Local-gov</td>\n",
" <td>336951</td>\n",
" <td>Assoc-acdm</td>\n",
" <td>12</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Protective-serv</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>44</td>\n",
" <td>Private</td>\n",
" <td>160323</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Husband</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>7688</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>18</td>\n",
" <td>?</td>\n",
" <td>103497</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Never-married</td>\n",
" <td>?</td>\n",
" <td>Own-child</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>34</td>\n",
" <td>Private</td>\n",
" <td>198693</td>\n",
" <td>10th</td>\n",
" <td>6</td>\n",
" <td>Never-married</td>\n",
" <td>Other-service</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>29</td>\n",
" <td>?</td>\n",
" <td>227026</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Never-married</td>\n",
" <td>?</td>\n",
" <td>Unmarried</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>63</td>\n",
" <td>Self-emp-not-inc</td>\n",
" <td>104626</td>\n",
" <td>Prof-school</td>\n",
" <td>15</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Prof-specialty</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>3103</td>\n",
" <td>0</td>\n",
" <td>32</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>24</td>\n",
" <td>Private</td>\n",
" <td>369667</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Never-married</td>\n",
" <td>Other-service</td>\n",
" <td>Unmarried</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>55</td>\n",
" <td>Private</td>\n",
" <td>104996</td>\n",
" <td>7th-8th</td>\n",
" <td>4</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Craft-repair</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>10</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education education_num \\\n",
"0 25 Private 226802 11th 7 \n",
"1 38 Private 89814 HS-grad 9 \n",
"2 28 Local-gov 336951 Assoc-acdm 12 \n",
"3 44 Private 160323 Some-college 10 \n",
"4 18 ? 103497 Some-college 10 \n",
"5 34 Private 198693 10th 6 \n",
"6 29 ? 227026 HS-grad 9 \n",
"7 63 Self-emp-not-inc 104626 Prof-school 15 \n",
"8 24 Private 369667 Some-college 10 \n",
"9 55 Private 104996 7th-8th 4 \n",
"\n",
" marital_status occupation relationship race gender \\\n",
"0 Never-married Machine-op-inspct Own-child Black Male \n",
"1 Married-civ-spouse Farming-fishing Husband White Male \n",
"2 Married-civ-spouse Protective-serv Husband White Male \n",
"3 Married-civ-spouse Machine-op-inspct Husband Black Male \n",
"4 Never-married ? Own-child White Female \n",
"5 Never-married Other-service Not-in-family White Male \n",
"6 Never-married ? Unmarried Black Male \n",
"7 Married-civ-spouse Prof-specialty Husband White Male \n",
"8 Never-married Other-service Unmarried White Female \n",
"9 Married-civ-spouse Craft-repair Husband White Male \n",
"\n",
" capital_gain capital_loss hours_per_week native_country income_bracket \n",
"0 0 0 40 United-States <=50K \n",
"1 0 0 50 United-States <=50K \n",
"2 0 0 40 United-States >50K \n",
"3 7688 0 40 United-States >50K \n",
"4 0 0 30 United-States <=50K \n",
"5 0 0 30 United-States <=50K \n",
"6 0 0 40 United-States <=50K \n",
"7 3103 0 32 United-States >50K \n",
"8 0 0 40 United-States <=50K \n",
"9 0 0 10 United-States <=50K "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head(10)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>fnlwgt</th>\n",
" <th>education_num</th>\n",
" <th>capital_gain</th>\n",
" <th>capital_loss</th>\n",
" <th>hours_per_week</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>16278.000000</td>\n",
" <td>1.627800e+04</td>\n",
" <td>16278.000000</td>\n",
" <td>16278.000000</td>\n",
" <td>16278.000000</td>\n",
" <td>16278.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>38.767416</td>\n",
" <td>1.894312e+05</td>\n",
" <td>10.072368</td>\n",
" <td>1081.769382</td>\n",
" <td>87.915469</td>\n",
" <td>40.390466</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>13.850370</td>\n",
" <td>1.057114e+05</td>\n",
" <td>2.567474</td>\n",
" <td>7584.547894</td>\n",
" <td>403.140665</td>\n",
" <td>12.479308</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>17.000000</td>\n",
" <td>1.349200e+04</td>\n",
" <td>1.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>28.000000</td>\n",
" <td>1.167385e+05</td>\n",
" <td>9.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>40.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>37.000000</td>\n",
" <td>1.778295e+05</td>\n",
" <td>10.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>40.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>48.000000</td>\n",
" <td>2.383840e+05</td>\n",
" <td>12.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>45.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>90.000000</td>\n",
" <td>1.490400e+06</td>\n",
" <td>16.000000</td>\n",
" <td>99999.000000</td>\n",
" <td>3770.000000</td>\n",
" <td>99.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age fnlwgt education_num capital_gain capital_loss \\\n",
"count 16278.000000 1.627800e+04 16278.000000 16278.000000 16278.000000 \n",
"mean 38.767416 1.894312e+05 10.072368 1081.769382 87.915469 \n",
"std 13.850370 1.057114e+05 2.567474 7584.547894 403.140665 \n",
"min 17.000000 1.349200e+04 1.000000 0.000000 0.000000 \n",
"25% 28.000000 1.167385e+05 9.000000 0.000000 0.000000 \n",
"50% 37.000000 1.778295e+05 10.000000 0.000000 0.000000 \n",
"75% 48.000000 2.383840e+05 12.000000 0.000000 0.000000 \n",
"max 90.000000 1.490400e+06 16.000000 99999.000000 3770.000000 \n",
"\n",
" hours_per_week \n",
"count 16278.000000 \n",
"mean 40.390466 \n",
"std 12.479308 \n",
"min 1.000000 \n",
"25% 40.000000 \n",
"50% 40.000000 \n",
"75% 45.000000 \n",
"max 99.000000 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.describe(include=[np.number])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>workclass</th>\n",
" <th>education</th>\n",
" <th>marital_status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>gender</th>\n",
" <th>native_country</th>\n",
" <th>income_bracket</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>16278</td>\n",
" <td>16278</td>\n",
" <td>16278</td>\n",
" <td>16278</td>\n",
" <td>16278</td>\n",
" <td>16278</td>\n",
" <td>16278</td>\n",
" <td>16278</td>\n",
" <td>16278</td>\n",
" </tr>\n",
" <tr>\n",
" <th>unique</th>\n",
" <td>9</td>\n",
" <td>16</td>\n",
" <td>7</td>\n",
" <td>15</td>\n",
" <td>6</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>41</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>top</th>\n",
" <td>Private</td>\n",
" <td>HS-grad</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Prof-specialty</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>freq</th>\n",
" <td>11208</td>\n",
" <td>5283</td>\n",
" <td>7401</td>\n",
" <td>2031</td>\n",
" <td>6521</td>\n",
" <td>13944</td>\n",
" <td>10857</td>\n",
" <td>14659</td>\n",
" <td>12433</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" workclass education marital_status occupation relationship \\\n",
"count 16278 16278 16278 16278 16278 \n",
"unique 9 16 7 15 6 \n",
"top Private HS-grad Married-civ-spouse Prof-specialty Husband \n",
"freq 11208 5283 7401 2031 6521 \n",
"\n",
" race gender native_country income_bracket \n",
"count 16278 16278 16278 16278 \n",
"unique 5 2 41 2 \n",
"top White Male United-States <=50K \n",
"freq 13944 10857 14659 12433 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.describe(include=[np.object])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>fnlwgt</th>\n",
" <th>education_num</th>\n",
" <th>capital_gain</th>\n",
" <th>capital_loss</th>\n",
" <th>hours_per_week</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>age</th>\n",
" <td>1.000000</td>\n",
" <td>-0.076556</td>\n",
" <td>0.019944</td>\n",
" <td>0.076362</td>\n",
" <td>0.055304</td>\n",
" <td>0.077096</td>\n",
" </tr>\n",
" <tr>\n",
" <th>fnlwgt</th>\n",
" <td>-0.076556</td>\n",
" <td>1.000000</td>\n",
" <td>-0.029951</td>\n",
" <td>-0.011705</td>\n",
" <td>0.007396</td>\n",
" <td>-0.003234</td>\n",
" </tr>\n",
" <tr>\n",
" <th>education_num</th>\n",
" <td>0.019944</td>\n",
" <td>-0.029951</td>\n",
" <td>1.000000</td>\n",
" <td>0.130089</td>\n",
" <td>0.083133</td>\n",
" <td>0.134766</td>\n",
" </tr>\n",
" <tr>\n",
" <th>capital_gain</th>\n",
" <td>0.076362</td>\n",
" <td>-0.011705</td>\n",
" <td>0.130089</td>\n",
" <td>1.000000</td>\n",
" <td>-0.031106</td>\n",
" <td>0.089421</td>\n",
" </tr>\n",
" <tr>\n",
" <th>capital_loss</th>\n",
" <td>0.055304</td>\n",
" <td>0.007396</td>\n",
" <td>0.083133</td>\n",
" <td>-0.031106</td>\n",
" <td>1.000000</td>\n",
" <td>0.054926</td>\n",
" </tr>\n",
" <tr>\n",
" <th>hours_per_week</th>\n",
" <td>0.077096</td>\n",
" <td>-0.003234</td>\n",
" <td>0.134766</td>\n",
" <td>0.089421</td>\n",
" <td>0.054926</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age fnlwgt education_num capital_gain capital_loss \\\n",
"age 1.000000 -0.076556 0.019944 0.076362 0.055304 \n",
"fnlwgt -0.076556 1.000000 -0.029951 -0.011705 0.007396 \n",
"education_num 0.019944 -0.029951 1.000000 0.130089 0.083133 \n",
"capital_gain 0.076362 -0.011705 0.130089 1.000000 -0.031106 \n",
"capital_loss 0.055304 0.007396 0.083133 -0.031106 1.000000 \n",
"hours_per_week 0.077096 -0.003234 0.134766 0.089421 0.054926 \n",
"\n",
" hours_per_week \n",
"age 0.077096 \n",
"fnlwgt -0.003234 \n",
"education_num 0.134766 \n",
"capital_gain 0.089421 \n",
"capital_loss 0.054926 \n",
"hours_per_week 1.000000 "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.corr()"
]
},
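{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick added sanity check (assuming `df` is loaded as above), we can also look at the class balance of the target label. The classes are imbalanced, which is useful context when judging accuracy numbers later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fraction of rows in each income bracket\n",
"df[\"income_bracket\"].value_counts(normalize=True)"
]
},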
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Input file parsing\n",
"\n",
"Here we read the file into a pandas DataFrame and use a built-in utility function to generate an input function for us. TensorFlow also has a similar input function for NumPy arrays."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## More about input functions\n",
"\n",
"The input function is how we will feed the input data into the model during training and evaluation. \n",
"The structure that must be returned is a pair, where the first element is a dict of the column names (features) mapped to a tensor of values, and the second element is a tensor of values representing the answers (labels). Recall that a tensor is just a general term for an n-dimensional array.\n",
"\n",
"This could be represented as: `( map(column_name => [Tensor of values]), [Tensor of labels] )`\n",
"\n",
"More concretely, for this particular dataset, something like this:\n",
"\n",
" { \n",
" 'age': [ 39, 50, 38, 53, 28, … ], \n",
" 'marital_status': [ 'Married-civ-spouse', 'Never-married', 'Widowed', 'Widowed' … ],\n",
" ...\n",
" 'gender': ['Male', 'Female', 'Male', 'Male', 'Female', … ], \n",
" } , \n",
" [ 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1]\n",
" \n",
"Additionally, we define which columns of the input data we will treat as categorical vs continuous, using the global `CATEGORICAL_COLUMNS`.\n",
"\n",
"You can try different values for `BATCH_SIZE` to see how they impact your results"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input function configured\n"
]
}
],
"source": [
"BATCH_SIZE = 40\n",
"\n",
"def generate_input_fn(filename, num_epochs=None, shuffle=True, batch_size=BATCH_SIZE):\n",
" df = pd.read_csv(filename, header=None, names=COLUMNS)\n",
" labels = df[\"income_bracket\"].apply(lambda x: \">50K\" in x).astype(int)\n",
" del df[\"fnlwgt\"] # Unused column\n",
" del df[\"income_bracket\"] # Labels column, already saved to labels variable\n",
" \n",
" return tf.estimator.inputs.pandas_input_fn(\n",
" x=df,\n",
" y=labels,\n",
" batch_size=batch_size,\n",
" num_epochs=num_epochs,\n",
" shuffle=shuffle)\n",
"\n",
"print('input function configured')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create Feature Columns\n",
"This section configures the feature columns that describe the input data to the model. There are many parameters here to experiment with to see how they affect the accuracy.\n",
"\n",
"This is where the bulk of the time and energy in making a machine learning model work is often spent, in a process called *feature selection* or *feature engineering*. We choose the features (columns) we will use for training, and apply any additional transformations to them as needed. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sparse Columns\n",
"First we build the sparse columns.\n",
"\n",
"Use `categorical_column_with_vocabulary_list()` for columns whose full set of possible values we know in advance.\n",
"\n",
"Use `categorical_column_with_hash_bucket()` for columns where we want the library to map values to IDs for us automatically."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Categorical columns configured\n"
]
}
],
"source": [
"# The layers module contains many utilities for creating feature columns.\n",
"\n",
"# Categorical base columns.\n",
"# Note: the vocabulary must match the capitalization used in the data (\"Female\", \"Male\")\n",
"gender = tf.feature_column.categorical_column_with_vocabulary_list(key=\"gender\", \n",
" vocabulary_list=[\"Female\", \"Male\"])\n",
"race = tf.feature_column.categorical_column_with_vocabulary_list(key=\"race\",\n",
" vocabulary_list=[\"Amer-Indian-Eskimo\",\n",
" \"Asian-Pac-Islander\",\n",
" \"Black\", \"Other\",\n",
" \"White\"])\n",
"\n",
"education = tf.feature_column.categorical_column_with_hash_bucket(\n",
" \"education\", hash_bucket_size=1000)\n",
"marital_status = tf.feature_column.categorical_column_with_hash_bucket(\n",
" \"marital_status\", hash_bucket_size=100)\n",
"relationship = tf.feature_column.categorical_column_with_hash_bucket(\n",
" \"relationship\", hash_bucket_size=100)\n",
"workclass = tf.feature_column.categorical_column_with_hash_bucket(\n",
" \"workclass\", hash_bucket_size=100)\n",
"occupation = tf.feature_column.categorical_column_with_hash_bucket(\n",
" \"occupation\", hash_bucket_size=1000)\n",
"native_country = tf.feature_column.categorical_column_with_hash_bucket(\n",
" \"native_country\", hash_bucket_size=1000)\n",
"\n",
"print('Categorical columns configured')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Continuous columns\n",
"Second, configure the real-valued columns using `numeric_column()`. "
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Continuous columns configured\n"
]
}
],
"source": [
"# Continuous base columns.\n",
"age = tf.feature_column.numeric_column(\"age\")\n",
"education_num = tf.feature_column.numeric_column(\"education_num\")\n",
"capital_gain = tf.feature_column.numeric_column(\"capital_gain\")\n",
"capital_loss = tf.feature_column.numeric_column(\"capital_loss\")\n",
"hours_per_week = tf.feature_column.numeric_column(\"hours_per_week\")\n",
"\n",
"print('Continuous columns configured')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Transformations\n",
"Now for the interesting stuff. We will employ a couple of techniques to get even more out of the data.\n",
" \n",
"* **bucketizing** turns what would have otherwise been a continuous feature into a categorical one. \n",
"* **feature crossing** allows us to compute a model weight for specific pairings across columns, rather than learning them independently. This essentially encodes related columns together, for situations where the combination of two (or more) specific column values is meaningful. \n",
"\n",
"Only categorical features can be crossed. This is one reason why age has been bucketized.\n",
"\n",
"For example, crossing education and occupation would enable the model to learn about: \n",
"\n",
" education=\"Bachelors\" AND occupation=\"Exec-managerial\"\n",
"\n",
"or perhaps \n",
"\n",
" education=\"Bachelors\" AND occupation=\"Craft-repair\"\n",
"\n",
"We do a few combined features (feature crosses) here. \n",
"\n",
"Add your own, based on your intuitions about the dataset, to try to improve on the model!"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Transformations complete\n"
]
}
],
"source": [
"# Transformations.\n",
"age_buckets = tf.feature_column.bucketized_column(\n",
" age, boundaries=[ 18, 25, 30, 35, 40, 45, 50, 55, 60, 65 ])\n",
"\n",
"education_occupation = tf.feature_column.crossed_column(\n",
" [\"education\", \"occupation\"], hash_bucket_size=int(1e4))\n",
"\n",
"age_race_occupation = tf.feature_column.crossed_column(\n",
" [age_buckets, \"race\", \"occupation\"], hash_bucket_size=int(1e6))\n",
"\n",
"country_occupation = tf.feature_column.crossed_column(\n",
" [\"native_country\", \"occupation\"], hash_bucket_size=int(1e4))\n",
"\n",
"print('Transformations complete')"
]
},
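{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small added illustration of what bucketizing does: NumPy's `digitize` uses the same half-open boundary convention as `bucketized_column`, so we can preview which bucket index a raw age maps to (10 boundaries yield 11 buckets)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration only: ages 17, 39, and 70 fall into buckets 0, 4, and 10\n",
"print(np.digitize([17, 39, 70], [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]))"
]
},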
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Group feature columns into 2 objects\n",
"\n",
"The wide columns are the sparse, categorical columns that we specified, as well as our hashed, bucketized, and feature-crossed columns. \n",
"\n",
"The deep columns are composed of embedded categorical columns along with the continuous real-valued columns. **Column embeddings** transform a sparse, categorical tensor into a low-dimensional and dense real-valued vector. The embedding values are also trained along with the rest of the model. For more information about embeddings, see the TensorFlow tutorial on [Vector Representations of Words](https://www.tensorflow.org/tutorials/word2vec/), or [Word Embedding](https://en.wikipedia.org/wiki/Word_embedding) on Wikipedia.\n",
"\n",
"The higher the dimension of the embedding, the more degrees of freedom the model will have to learn the representations of the features. We are starting with an 8-dimensional embedding for simplicity, but later you can come back and increase the dimensionality if you wish.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"wide and deep columns configured\n"
]
}
],
"source": [
"# Wide columns and deep columns.\n",
"wide_columns = [gender, race, native_country,\n",
" education, occupation, workclass,\n",
" marital_status, relationship,\n",
" age_buckets, education_occupation,\n",
" age_race_occupation, country_occupation]\n",
"\n",
"deep_columns = [\n",
" # Multi-hot indicator columns for columns with fewer possibilities\n",
" tf.feature_column.indicator_column(workclass),\n",
" tf.feature_column.indicator_column(marital_status),\n",
" tf.feature_column.indicator_column(gender),\n",
" tf.feature_column.indicator_column(relationship),\n",
" tf.feature_column.indicator_column(race),\n",
" # Embeddings for categories with more possibilities\n",
" tf.feature_column.embedding_column(education, dimension=8),\n",
" tf.feature_column.embedding_column(native_country, dimension=8),\n",
" tf.feature_column.embedding_column(occupation, dimension=8),\n",
" # Numerical columns\n",
" age,\n",
" education_num,\n",
" capital_gain,\n",
" capital_loss,\n",
" hours_per_week,\n",
"]\n",
"\n",
"print('wide and deep columns configured')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create the model\n",
"\n",
"You can train either a \"wide\" model, a \"deep\" model, or a \"wide and deep\" model, using the classifiers below. Try each one and see what kind of results you get.\n",
"\n",
"* **Wide**: Linear Classifier\n",
"* **Deep**: Deep Neural Net Classifier\n",
"* **Wide & Deep**: Combined Linear and Deep Classifier\n",
"\n",
"The `hidden_units` or `dnn_hidden_units` argument specifies the size of each layer of the deep portion of the network. For example, `[12, 20, 15]` would create a network whose first layer has size 12, second layer size 20, and third layer size 15."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model directory = models/model_WIDE_AND_DEEP_1508857607\n",
"INFO:tensorflow:Using default config.\n",
"INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_tf_random_seed': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_save_checkpoints_steps': None, '_model_dir': 'models/model_WIDE_AND_DEEP_1508857607', '_save_summary_steps': 100}\n",
"estimator built\n"
]
}
],
"source": [
"def create_model_dir(model_type):\n",
" return 'models/model_' + model_type + '_' + str(int(time.time()))\n",
"\n",
"# If new_model=False, pass in the desired model_dir \n",
"def get_model(model_type, new_model=False, model_dir=None):\n",
" if new_model or model_dir is None:\n",
" model_dir = create_model_dir(model_type) # Pass new_model=False and a model_dir to continue training an existing model\n",
" print(\"Model directory = %s\" % model_dir)\n",
" \n",
" m = None\n",
" \n",
" # Linear Classifier\n",
" if model_type == 'WIDE':\n",
" m = tf.estimator.LinearClassifier(\n",
" model_dir=model_dir, \n",
" feature_columns=wide_columns)\n",
"\n",
" # Deep Neural Net Classifier\n",
" if model_type == 'DEEP':\n",
" m = tf.estimator.DNNClassifier(\n",
" model_dir=model_dir,\n",
" feature_columns=deep_columns,\n",
" hidden_units=[100, 50])\n",
"\n",
" # Combined Linear and Deep Classifier\n",
" if model_type == 'WIDE_AND_DEEP':\n",
" m = tf.estimator.DNNLinearCombinedClassifier(\n",
" model_dir=model_dir,\n",
" linear_feature_columns=wide_columns,\n",
" dnn_feature_columns=deep_columns,\n",
" dnn_hidden_units=[100, 70, 50, 25])\n",
" \n",
" print('estimator built')\n",
" \n",
" return m, model_dir\n",
" \n",
"MODEL_TYPE = 'WIDE_AND_DEEP'\n",
"model_dir = create_model_dir(model_type=MODEL_TYPE)\n",
"m, model_dir = get_model(model_type = MODEL_TYPE, model_dir=model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fit the model (train it)\n",
"\n",
"Run `train()` to train the model. You can experiment with the `train_steps` and `BATCH_SIZE` parameters.\n",
"\n",
"This can take some time, depending on the values chosen for `train_steps` and `BATCH_SIZE`.\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Saving checkpoints for 1 into models/model_WIDE_AND_DEEP_1508857607/model.ckpt.\n",
"INFO:tensorflow:loss = 136.401, step = 1\n",
"INFO:tensorflow:global_step/sec: 120.726\n",
"INFO:tensorflow:loss = 17.487, step = 101 (0.835 sec)\n",
"INFO:tensorflow:global_step/sec: 84.6235\n",
"INFO:tensorflow:loss = 44.0411, step = 201 (1.180 sec)\n",
"INFO:tensorflow:global_step/sec: 84.1638\n",
"INFO:tensorflow:loss = 17.4602, step = 301 (1.185 sec)\n",
"INFO:tensorflow:global_step/sec: 84.1106\n",
"INFO:tensorflow:loss = 20.2271, step = 401 (1.194 sec)\n",
"INFO:tensorflow:global_step/sec: 95.5922\n",
"INFO:tensorflow:loss = 16.7382, step = 501 (1.041 sec)\n",
"INFO:tensorflow:global_step/sec: 112.899\n",
"INFO:tensorflow:loss = 16.1211, step = 601 (0.891 sec)\n",
"INFO:tensorflow:global_step/sec: 82.642\n",
"INFO:tensorflow:loss = 17.3099, step = 701 (1.207 sec)\n",
"INFO:tensorflow:global_step/sec: 83.0798\n",
"INFO:tensorflow:loss = 14.5654, step = 801 (1.204 sec)\n",
"INFO:tensorflow:global_step/sec: 84.2676\n",
"INFO:tensorflow:loss = 20.3173, step = 901 (1.188 sec)\n",
"INFO:tensorflow:Saving checkpoints for 1000 into models/model_WIDE_AND_DEEP_1508857607/model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 13.9233.\n",
"training done\n",
"CPU times: user 46.9 s, sys: 9.18 s, total: 56 s\n",
"Wall time: 43.6 s\n"
]
}
],
"source": [
"%%time \n",
"\n",
    "train_file = \"adult.data.csv\"\n",
"# \"gs://cloudml-public/census/data/adult.data.csv\"\n",
"# storage.googleapis.com/cloudml-public/census/data/adult.data.csv\n",
"\n",
"m.train(input_fn=generate_input_fn(train_file), \n",
" steps=1000)\n",
"\n",
"print('training done')"
]
},
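  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the estimator checkpoints into `model_dir`, calling `train()` again resumes from the last saved checkpoint rather than starting over. As a quick sketch (reusing the `generate_input_fn` and `train_file` defined above), you could continue training like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Continue training for another 500 steps;\n",
    "# the estimator resumes from the latest checkpoint in model_dir.\n",
    "m.train(input_fn=generate_input_fn(train_file), steps=500)"
   ]
  },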
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate the accuracy of the model\n",
    "Let's see how the model did. We will evaluate the model on all of the test data."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool.\n",
"WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool.\n",
"INFO:tensorflow:Starting evaluation at 2017-10-24-15:07:36\n",
"INFO:tensorflow:Restoring parameters from models/model_WIDE_AND_DEEP_1508857607/model.ckpt-1000\n",
"INFO:tensorflow:Finished evaluation at 2017-10-24-15:07:43\n",
"INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.835668, accuracy_baseline = 0.763792, auc = 0.880099, auc_precision_recall = 0.728618, average_loss = 0.3701, global_step = 1000, label/mean = 0.236208, loss = 14.8022, prediction/mean = 0.257514\n",
"evaluate done\n",
"\n",
"Accuracy: 0.835668\n"
]
}
],
"source": [
    "test_file = \"adult.test.csv\"\n",
"# \"gs://cloudml-public/census/data/adult.test.csv\"\n",
"# storage.googleapis.com/cloudml-public/census/data/adult.test.csv\n",
"\n",
"results = m.evaluate(input_fn=generate_input_fn(test_file, num_epochs=1, shuffle=False), \n",
" steps=None)\n",
"print('evaluate done')\n",
"print('\\nAccuracy: %s' % results['accuracy'])"
]
},
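  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`evaluate()` returns more metrics than just accuracy. As a quick sketch (using the `results` dict from the cell above), we can print everything it computed, including `auc` and `average_loss`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print every metric returned by evaluate()\n",
    "for key in sorted(results):\n",
    "    print('%s: %s' % (key, results[key]))"
   ]
  },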
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Make a prediction"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>education_num</th>\n",
" <th>marital_status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>gender</th>\n",
" <th>capital_gain</th>\n",
" <th>capital_loss</th>\n",
" <th>hours_per_week</th>\n",
" <th>native_country</th>\n",
" <th>income_bracket</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8000</th>\n",
" <td>35</td>\n",
" <td>Private</td>\n",
" <td>399455</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-spouse-absent</td>\n",
" <td>Other-service</td>\n",
" <td>Unmarried</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>52</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8001</th>\n",
" <td>37</td>\n",
" <td>Private</td>\n",
" <td>52630</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Craft-repair</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8002</th>\n",
" <td>42</td>\n",
" <td>Private</td>\n",
" <td>124692</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Exec-managerial</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>45</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8003</th>\n",
" <td>21</td>\n",
" <td>Private</td>\n",
" <td>278254</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Never-married</td>\n",
" <td>Handlers-cleaners</td>\n",
" <td>Own-child</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8004</th>\n",
" <td>40</td>\n",
" <td>Private</td>\n",
" <td>162098</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Divorced</td>\n",
" <td>Adm-clerical</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education education_num \\\n",
"8000 35 Private 399455 HS-grad 9 \n",
"8001 37 Private 52630 Some-college 10 \n",
"8002 42 Private 124692 Bachelors 13 \n",
"8003 21 Private 278254 Some-college 10 \n",
"8004 40 Private 162098 HS-grad 9 \n",
"\n",
" marital_status occupation relationship race \\\n",
"8000 Married-spouse-absent Other-service Unmarried White \n",
"8001 Married-civ-spouse Craft-repair Husband White \n",
"8002 Married-civ-spouse Exec-managerial Husband White \n",
"8003 Never-married Handlers-cleaners Own-child Black \n",
"8004 Divorced Adm-clerical Not-in-family White \n",
"\n",
" gender capital_gain capital_loss hours_per_week native_country \\\n",
"8000 Female 0 0 52 United-States \n",
"8001 Male 0 0 40 United-States \n",
"8002 Male 0 0 45 United-States \n",
"8003 Male 0 0 40 United-States \n",
"8004 Female 0 0 40 United-States \n",
"\n",
" income_bracket \n",
"8000 <=50K \n",
"8001 <=50K \n",
"8002 >50K \n",
"8003 <=50K \n",
"8004 <=50K "
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a dataframe to wrap in an input function\n",
"df = pd.read_csv(\"adult.test.csv\", header=None, names=COLUMNS)\n",
    "data_predict = df.iloc[8000:8005].copy()  # copy so the del below modifies our slice, not a view of df\n",
"data_predict.head() # show this before deleting, so we know what the labels are"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
    "# Note: running this cell twice raises a KeyError, since these columns will already have been deleted\n",
"del data_predict[\"fnlwgt\"] # Unused column\n",
"del data_predict[\"income_bracket\"] # remove the label column"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from models/model_WIDE_AND_DEEP_1508857607/model.ckpt-1000\n",
"Predictions: ['0'] with probabilities [ 0.94370598 0.05629405]\n",
"\n",
"Predictions: ['0'] with probabilities [ 0.68004459 0.31995538]\n",
"\n",
"Predictions: ['1'] with probabilities [ 0.46427986 0.53572011]\n",
"\n",
"Predictions: ['0'] with probabilities [ 0.94195259 0.05804734]\n",
"\n",
"Predictions: ['0'] with probabilities [ 0.90183949 0.09816054]\n",
"\n"
]
}
],
"source": [
"predict_input_fn = tf.estimator.inputs.pandas_input_fn(\n",
" x=data_predict,\n",
" batch_size=1,\n",
" num_epochs=1,\n",
" shuffle=False)\n",
"\n",
"predictions = m.predict(input_fn=predict_input_fn)\n",
"\n",
"for prediction in predictions:\n",
" print(\"Predictions: {} with probabilities {}\\n\".format(prediction[\"classes\"], prediction[\"probabilities\"]))"
]
},
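  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, we can line the predicted classes up against the true labels we displayed before deleting the `income_bracket` column. This is a sketch that assumes `df` still holds the full test set loaded above (note that `predict()` returns a generator, so we call it again here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each predicted class with the true label from the original dataframe\n",
    "true_labels = df.iloc[8000:8005]['income_bracket'].values\n",
    "for prediction, label in zip(m.predict(input_fn=predict_input_fn), true_labels):\n",
    "    print('predicted: %s  actual: %s' % (prediction['classes'], label))"
   ]
  },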
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Export a model optimized for inference\n",
"We can upload our trained model to the Cloud Machine Learning Engine's Prediction Service, which will take care of serving our model and scaling it. The code below exports our trained model to a `saved_model.pb` file and a `variables` folder where the trained weights are stored. This format is also compatible with TensorFlow Serving.\n",
"\n",
"The `export_savedmodel()` function expects a `serving_input_receiver_fn()`, which returns the mapping from the data that the Prediction Service passes in to the data that should be fed into the trained TensorFlow prediction graph."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from models/model_WIDE_AND_DEEP_1508857607/model.ckpt-1000\n",
"INFO:tensorflow:Assets added to graph.\n",
"INFO:tensorflow:No assets to write.\n",
"INFO:tensorflow:SavedModel written to: models/model_WIDE_AND_DEEP_1508857607/export/1508857859/saved_model.pb\n"
]
},
{
"data": {
"text/plain": [
"'models/model_WIDE_AND_DEEP_1508857607/export/1508857859'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def column_to_dtype(column):\n",
" if column in CATEGORICAL_COLUMNS:\n",
" return tf.string\n",
" else:\n",
" return tf.float32\n",
" \n",
"feature_spec = {\n",
" column: tf.FixedLenFeature(shape=[1], dtype=column_to_dtype(column))\n",
" for column in FEATURE_COLUMNS\n",
"}\n",
"serving_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)\n",
"m.export_savedmodel(export_dir_base=model_dir + '/export', \n",
" serving_input_receiver_fn=serving_fn)"
]
},
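  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each call to `export_savedmodel()` writes into a new timestamped subdirectory under the export base. As a quick sketch (the timestamp will differ on your machine), we can list the exports and inspect the newest one with the `saved_model_cli` tool that ships with TensorFlow:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the timestamped export directories\n",
    "import os\n",
    "export_base = model_dir + '/export'\n",
    "print(os.listdir(export_base))\n",
    "\n",
    "# To inspect the exported signatures from a shell:\n",
    "#   saved_model_cli show --dir <export_base>/<timestamp> --all"
   ]
  },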
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Conclusions\n",
"\n",
    "In this Jupyter notebook, we have configured, created, trained, and evaluated a Wide & Deep machine learning model that combines the strengths of a linear classifier with a deep neural network, using TensorFlow's tf.Estimator module.\n",
"\n",
"With this working example in your toolbelt, you are ready to explore the wide (and deep) world of machine learning with TensorFlow! Some ideas to help you get going:\n",
"* Change the features we used today. Which columns do you think are correlated and should be crossed? Which ones do you think are just adding noise and could be removed to clean up the model?\n",
    "* Swap in an entirely new dataset! There are many datasets available on the web, or you can use one of your own! Check out https://archive.ics.uci.edu/ml to find a dataset that interests you."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}