ML model
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example code for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers
_CSV_COLUMNS = [
    'identitylevel', 'newuser', 'hasphone', 'userage', 'client', 'customLocation',
    'gps', 'locale', 'timeZone', 'textPaste', 'hasPhoto', 'category', 'textLarge',
    'isduplicate', 'result'
]

# The defaults double as type hints for tf.decode_csv: [0]/[1] parse as int,
# [''] as string. 'category' must default to a string so that parsed values
# match the string vocabulary defined in build_model_columns() below.
_CSV_COLUMN_DEFAULTS = [[1], [0], [0], [0], [''], [0],
                        [0], [''], [''], [0], [0], [''], [0],
                        [0], [0]]
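
# Illustrative only: a hypothetical CSV row in _CSV_COLUMNS order, with types
# matching _CSV_COLUMN_DEFAULTS (all values invented):
#   2,0,1,34,iOS,0,1,pt_PT,CET Europe/Lisbon,0,1,31,0,0,1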

_NUM_EXAMPLES = {
    'train': 1500,
    'validation': 500,
}

LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}

def define_wide_deep_flags():
  """Add supervised learning flags, as well as wide-deep model type."""
  flags_core.define_base()
  flags.adopt_module_key_flags(flags_core)

  flags.DEFINE_enum(
      name="model_type", short_name="mt", default="wide_deep",
      enum_values=['wide', 'deep', 'wide_deep'],
      help="Select model topology.")

  flags_core.set_defaults(
      data_dir='/Users/selio01/Documents/TensorFlow/virenv/',
      model_dir='/Users/selio01/Documents/TensorFlow/virenv/modelResult',
      train_epochs=40,
      epochs_between_evals=2,
      batch_size=40)

def build_model_columns():
  """Builds a set of wide and deep feature columns."""
  # Continuous columns.
  identitylevel = tf.feature_column.numeric_column('identitylevel')
  newuser = tf.feature_column.numeric_column('newuser')
  hasphone = tf.feature_column.numeric_column('hasphone')
  userage = tf.feature_column.numeric_column('userage')
  customLocation = tf.feature_column.numeric_column('customLocation')
  gps = tf.feature_column.numeric_column('gps')
  textPaste = tf.feature_column.numeric_column('textPaste')
  hasPhoto = tf.feature_column.numeric_column('hasPhoto')
  textLarge = tf.feature_column.numeric_column('textLarge')
  isduplicate = tf.feature_column.numeric_column('isduplicate')

  locale = tf.feature_column.categorical_column_with_vocabulary_list(
      'locale', [
          "pt_PT", "pt_BR", "pt-PT", "uk_UA", "es", "fr_FR", "es-AR", "pt-BR",
          "es-ES", "en-US", "en_GB", "nodata", "es_ES", "ro_RO", "it_IT",
          "en_US", "pt_AO", "bg_BG", "tr_TR", "es_AR", "es_VE", "ru_RU",
          "en-EN"
      ])

  client = tf.feature_column.categorical_column_with_vocabulary_list(
      'client', ["AND", "PWA", "WEB", "iOS", "nodata"])

  timeZone = tf.feature_column.categorical_column_with_vocabulary_list(
      'timeZone', [
          "nodata", "WET Atlantic/Madeira", "WET Europe/Lisbon",
          "GMT+00:00 Europe/Lisbon", "GMT+00:00 Europe/London",
          "AZOT Atlantic/Azores", "CET Europe/Luxembourg",
          "CET Europe/Belgrade", "GMT-11:00 Pacific/Midway",
          "CET Europe/Madrid", "CET Europe/Amsterdam", "CET Europe/Lisbon",
          "GMT-08:00 America/Tijuana", "GMT+01:00 Europe/Amsterdam",
          "GMT+07:00 Asia/Ho_Chi_Minh", "GMT+00:00 Atlantic/Madeira",
          "GMT+07:00 Asia/Saigon", "GMT+01:00 Europe/Madrid",
          "CET Europe/Brussels", "WET Africa/Casablanca", "GMT+00:00 GMT"
      ])

  category = tf.feature_column.categorical_column_with_vocabulary_list(
      'category', [
          "1", "31", "20", "304", "40", "292", "29", "288", "295", "360",
          "998", "15", "168", "119", "124", "16", "316", "238", "3", "194",
          "138", "779", "167", "1003", "199", "17", "196", "993", "343", "191",
          "34", "165", "193", "689", "795", "997", "394", "264", "163", "791",
          "22", "202", "164", "170", "55", "214", "18", "732", "216", "197",
          "188", "178", "38", "133", "198", "355", "182", "159", "158", "291",
          "10", "36", "179", "1002", "740", "357", "729", "161", "792", "784",
          "176", "672", "162", "1001", "37", "317", "187", "673", "46", "788",
          "358", "157", "33", "175", "789", "166", "262", "289", "282", "781",
          "28", "670", "201", "803", "200", "263", "353", "56", "192", "42",
          "290", "195", "681", "812", "691", "183", "892", "356", "51", "319",
          "685", "395", "215", "354", "189", "741", "277", "785", "796", "794",
          "27", "171", "26", "177", "180", "14", "787", "783", "184", "23",
          "693", "35", "391", "999", "172", "790", "283", "671", "266", "169"
      ])

  # Wide columns and deep columns.
  base_columns = [
      identitylevel, newuser, hasphone, userage, customLocation,
      gps, textPaste, hasPhoto, textLarge, isduplicate
  ]

  crossed_columns = [
      tf.feature_column.crossed_column(
          ['locale', 'timeZone'], hash_bucket_size=1000)
  ]

  wide_columns = base_columns + crossed_columns

  deep_columns = [
      identitylevel, newuser, hasphone, userage, customLocation,
      gps, textPaste, hasPhoto, textLarge, isduplicate,
      tf.feature_column.indicator_column(client),
      tf.feature_column.indicator_column(locale),
      tf.feature_column.indicator_column(timeZone),
      tf.feature_column.indicator_column(category)
  ]

  return wide_columns, deep_columns
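
# Note on the locale x timeZone cross above: crossed_column hashes each
# (locale, timeZone) string pair into one of 1000 buckets, so the linear
# ("wide") part can learn a weight per combination rather than only per
# individual feature.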

def build_estimator(model_dir, model_type):
  """Build an estimator appropriate for the given model type."""
  wide_columns, deep_columns = build_model_columns()
  hidden_units = [100, 75, 50, 25]

  # Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
  # trains faster than GPU for this model.
  run_config = tf.estimator.RunConfig().replace(
      session_config=tf.ConfigProto(device_count={'GPU': 0}))

  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        model_dir=model_dir,
        feature_columns=wide_columns,
        config=run_config)
  elif model_type == 'deep':
    return tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=deep_columns,
        hidden_units=hidden_units,
        config=run_config)
  else:
    return tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=wide_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=hidden_units,
        config=run_config)
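
# Hypothetical direct use, bypassing the flag machinery (path invented):
#   estimator = build_estimator('/tmp/wide_deep_model', 'wide_deep')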

def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run data_download.py and '
      'set the --data_dir argument to the correct path.' % data_file)

  def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('result')
    return features, tf.equal(labels, 1)

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset
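
# Example usage sketch (hypothetical path): shuffled batches of 40, with two
# passes over the file:
#   ds = input_fn('/tmp/data_train1500.csv', num_epochs=2, shuffle=True,
#                 batch_size=40)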

def export_model(model, model_type, export_dir):
  """Export to SavedModel format.

  Args:
    model: Estimator object.
    model_type: string indicating model type: "wide", "deep" or "wide_deep".
    export_dir: directory to export the model.
  """
  wide_columns, deep_columns = build_model_columns()
  if model_type == 'wide':
    columns = wide_columns
  elif model_type == 'deep':
    columns = deep_columns
  else:
    columns = wide_columns + deep_columns
  feature_spec = tf.feature_column.make_parse_example_spec(columns)
  example_input_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
  model.export_savedmodel(export_dir, example_input_fn)
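
# A smoke-test sketch for the export, assuming TF 1.x: Estimator classifiers
# expose a 'serving_default' signature that takes serialized tf.Example protos
# under the 'inputs' key, and export_savedmodel() writes to a timestamped
# subdirectory ('<timestamp>' below is a placeholder):
#   from tensorflow.contrib import predictor
#   predict_fn = predictor.from_saved_model('/path/to/export/<timestamp>')
#   serialized = example_proto.SerializeToString()  # a tf.train.Example
#   print(predict_fn({'inputs': [serialized]}))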

def run_wide_deep(flags_obj):
  """Run Wide-Deep training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # Clean up the model directory if present.
  shutil.rmtree(flags_obj.model_dir, ignore_errors=True)
  model = build_estimator(flags_obj.model_dir, flags_obj.model_type)

  train_file = os.path.join(flags_obj.data_dir, 'data_train1500.csv')
  test_file = os.path.join(flags_obj.data_dir, 'data_test500.csv')

  def train_input_fn():
    return input_fn(
        train_file, flags_obj.epochs_between_evals, True, flags_obj.batch_size)

  def eval_input_fn():
    return input_fn(test_file, 1, False, flags_obj.batch_size)

  loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, batch_size=flags_obj.batch_size,
      tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
                      'loss': loss_prefix + 'head/weighted_loss/Sum'})

  # Train and evaluate the model every `flags.epochs_between_evals` epochs.
  for n in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
    model.train(input_fn=train_input_fn, hooks=train_hooks)
    results = model.evaluate(input_fn=eval_input_fn)

    # Display evaluation metrics.
    print('Results at epoch', (n + 1) * flags_obj.epochs_between_evals)
    print('-' * 60)
    for key in sorted(results):
      print('%s: %s' % (key, results[key]))

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, results['accuracy']):
      break

  # Export the model.
  if flags_obj.export_dir is not None:
    export_model(model, flags_obj.model_type, flags_obj.export_dir)

def main(_):
  run_wide_deep(flags.FLAGS)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  define_wide_deep_flags()
  absl_app.run(main)
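
# Example invocation, assuming this file is saved as wide_deep.py (paths are
# placeholders for your own data and model directories):
#   python wide_deep.py --model_type=wide_deep \
#       --data_dir=/path/to/csvs --model_dir=/path/to/model \
#       --train_epochs=40 --epochs_between_evals=2 --batch_size=40 \
#       --export_dir=/path/to/export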