Last active
May 4, 2018 14:26
-
-
Save fede1608/6a4ead84189ee6f1bc3cf6864ebcf29d to your computer and use it in GitHub Desktop.
ML model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Example code for TensorFlow Wide & Deep Tutorial using tf.estimator API.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import shutil | |
from absl import app as absl_app | |
from absl import flags | |
import tensorflow as tf # pylint: disable=g-bad-import-order | |
from official.utils.flags import core as flags_core | |
from official.utils.logs import hooks_helper | |
from official.utils.misc import model_helpers | |
# Single source of truth for the CSV layout: one (column name, default) pair
# per field, in file order. The Python type of each default fixes the dtype
# `tf.decode_csv` assigns to the column (int -> int32, str -> string); the
# value itself is substituted when a field is empty.
_CSV_SCHEMA = (
    ('identitylevel', 1),
    ('newuser', 0),
    ('hasphone', 0),
    ('userage', 0),
    ('client', ''),
    ('customLocation', 0),
    ('gps', 0),
    ('locale', ''),
    ('timeZone', ''),
    ('textPaste', 0),
    ('hasPhoto', 0),
    # Parsed as a string on purpose: the 'category' feature column is built
    # from a vocabulary of strings, and feeding it an int32 tensor is rejected
    # with a "Column dtype and SparseTensors dtype must be compatible" error.
    ('category', '1'),
    ('textLarge', 0),
    ('isduplicate', 0),
    # Label column; input_fn pops it from the features and compares it to 1.
    ('result', 0),
)

# Column names, in the order they appear in each CSV row.
_CSV_COLUMNS = [column for column, _ in _CSV_SCHEMA]

# `record_defaults` for tf.decode_csv: one single-element list per column.
_CSV_COLUMN_DEFAULTS = [[default] for _, default in _CSV_SCHEMA]

# Data set sizes. input_fn uses the 'train' count as its shuffle buffer size.
_NUM_EXAMPLES = {
    'train': 1500,
    'validation': 500,
}

# Prefix put in front of the loss tensor names that run_wide_deep logs, keyed
# by model type. Model types without an entry get no prefix.
LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def define_wide_deep_flags():
  """Register the command-line flags used by this script.

  Defines the base flags through `flags_core`, adopts them as key flags of
  this module, adds the `--model_type` / `-mt` selector, and finally overrides
  the flag defaults with the values this experiment is normally run with.
  """
  flags_core.define_base()
  flags.adopt_module_key_flags(flags_core)

  topologies = ['wide', 'deep', 'wide_deep']
  flags.DEFINE_enum(
      name="model_type",
      short_name="mt",
      default="wide_deep",
      enum_values=topologies,
      help="Select model topology.")

  # Defaults for this experiment; each one can still be overridden on the
  # command line. NOTE(review): the two paths are machine-specific.
  experiment_defaults = {
      'data_dir': '/Users/selio01/Documents/TensorFlow/virenv/',
      'model_dir': '/Users/selio01/Documents/TensorFlow/virenv/modelResult',
      'train_epochs': 40,
      'epochs_between_evals': 2,
      'batch_size': 40,
  }
  flags_core.set_defaults(**experiment_defaults)
# Vocabularies of the categorical features. The position of a value inside its
# list is the integer id the feature column assigns to it, so the order of the
# entries is kept exactly as it was.
_LOCALE_VOCABULARY = [
    "pt_PT", "pt_BR", "pt-PT", "uk_UA", "es", "fr_FR", "es-AR", "pt-BR",
    "es-ES", "en-US", "en_GB", "nodata", "es_ES", "ro_RO", "it_IT", "en_US",
    "pt_AO", "bg_BG", "tr_TR", "es_AR", "es_VE", "ru_RU", "en-EN",
]

_CLIENT_VOCABULARY = ["AND", "PWA", "WEB", "iOS", "nodata"]

# The entries used to be written with a JSON-style "\/". Python does not
# recognise that escape, so a literal backslash ended up inside every value
# (and newer interpreters warn about the invalid escape). Plain slashes are
# used so that zone ids such as "Europe/Lisbon" can match.
# NOTE(review): assumes the CSV files contain unescaped slashes - confirm.
_TIME_ZONE_VOCABULARY = [
    "nodata",
    "WET Atlantic/Madeira",
    "WET Europe/Lisbon",
    "GMT+00:00 Europe/Lisbon",
    "GMT+00:00 Europe/London",
    "AZOT Atlantic/Azores",
    "CET Europe/Luxembourg",
    "CET Europe/Belgrade",
    "GMT-11:00 Pacific/Midway",
    "CET Europe/Madrid",
    "CET Europe/Amsterdam",
    "CET Europe/Lisbon",
    "GMT-08:00 America/Tijuana",
    "GMT+01:00 Europe/Amsterdam",
    "GMT+07:00 Asia/Ho_Chi_Minh",
    "GMT+00:00 Atlantic/Madeira",
    "GMT+07:00 Asia/Saigon",
    "GMT+01:00 Europe/Madrid",
    "CET Europe/Brussels",
    "WET Africa/Casablanca",
    "GMT+00:00 GMT",
]

# Category ids are matched as strings, so the parsed 'category' feature has to
# be a string tensor as well.
_CATEGORY_VOCABULARY = [
    "1", "31", "20", "304", "40", "292", "29", "288", "295", "360",
    "998", "15", "168", "119", "124", "16", "316", "238", "3", "194",
    "138", "779", "167", "1003", "199", "17", "196", "993", "343", "191",
    "34", "165", "193", "689", "795", "997", "394", "264", "163", "791",
    "22", "202", "164", "170", "55", "214", "18", "732", "216", "197",
    "188", "178", "38", "133", "198", "355", "182", "159", "158", "291",
    "10", "36", "179", "1002", "740", "357", "729", "161", "792", "784",
    "176", "672", "162", "1001", "37", "317", "187", "673", "46", "788",
    "358", "157", "33", "175", "789", "166", "262", "289", "282", "781",
    "28", "670", "201", "803", "200", "263", "353", "56", "192", "42",
    "290", "195", "681", "812", "691", "183", "892", "356", "51", "319",
    "685", "395", "215", "354", "189", "741", "277", "785", "796", "794",
    "27", "171", "26", "177", "180", "14", "787", "783", "184", "23",
    "693", "35", "391", "999", "172", "790", "283", "671", "266", "169",
]

# Features fed to the model as plain numeric values, in column order.
_NUMERIC_FEATURES = [
    'identitylevel', 'newuser', 'hasphone', 'userage', 'customLocation',
    'gps', 'textPaste', 'hasPhoto', 'textLarge', 'isduplicate',
]


def build_model_columns():
  """Builds a set of wide and deep feature columns.

  Returns:
    A `(wide_columns, deep_columns)` tuple of lists of feature columns.
    `wide_columns` holds the numeric columns plus a hashed cross of
    'locale' and 'timeZone'. `deep_columns` holds the same numeric columns
    followed by indicator (one-hot) columns for 'client', 'locale',
    'timeZone' and 'category'.
  """
  # Continuous columns, shared by the wide and the deep part.
  base_columns = [
      tf.feature_column.numeric_column(key) for key in _NUMERIC_FEATURES
  ]

  # Categorical columns backed by the fixed vocabularies above.
  from_vocabulary = tf.feature_column.categorical_column_with_vocabulary_list
  locale = from_vocabulary('locale', _LOCALE_VOCABULARY)
  client = from_vocabulary('client', _CLIENT_VOCABULARY)
  timeZone = from_vocabulary('timeZone', _TIME_ZONE_VOCABULARY)
  category = from_vocabulary('category', _CATEGORY_VOCABULARY)

  # The cross is keyed on the raw feature names, so it hashes the raw string
  # values rather than going through the vocabularies.
  crossed_columns = [
      tf.feature_column.crossed_column(
          ['locale', 'timeZone'], hash_bucket_size=1000)
  ]

  wide_columns = base_columns + crossed_columns

  # A DNN needs dense inputs, hence the indicator wrappers.
  deep_columns = base_columns + [
      tf.feature_column.indicator_column(client),
      tf.feature_column.indicator_column(locale),
      tf.feature_column.indicator_column(timeZone),
      tf.feature_column.indicator_column(category),
  ]

  return wide_columns, deep_columns
def build_estimator(model_dir, model_type):
  """Create the estimator that corresponds to `model_type`.

  Args:
    model_dir: directory handed to the estimator as its `model_dir`.
    model_type: 'wide' builds a LinearClassifier, 'deep' a DNNClassifier;
      any other value builds a DNNLinearCombinedClassifier.

  Returns:
    The configured `tf.estimator` classifier.
  """
  linear_columns, dnn_columns = build_model_columns()
  layer_sizes = [100, 75, 50, 25]

  # Hide every GPU from the session so the model runs on CPU, which trains
  # faster than GPU for this model.
  cpu_only_session = tf.ConfigProto(device_count={'GPU': 0})
  run_config = tf.estimator.RunConfig().replace(
      session_config=cpu_only_session)

  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        model_dir=model_dir,
        feature_columns=linear_columns,
        config=run_config)

  if model_type == 'deep':
    return tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=dnn_columns,
        hidden_units=layer_sizes,
        config=run_config)

  return tf.estimator.DNNLinearCombinedClassifier(
      model_dir=model_dir,
      linear_feature_columns=linear_columns,
      dnn_feature_columns=dnn_columns,
      dnn_hidden_units=layer_sizes,
      config=run_config)
def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Build the `tf.data` input pipeline for one CSV file.

  Args:
    data_file: path of the CSV file to read; it must exist.
    num_epochs: number of times the data set is repeated.
    shuffle: whether to shuffle the lines before they are parsed.
    batch_size: number of examples per batch.

  Returns:
    A dataset of `(features, labels)` batches, where `features` maps each
    column name except 'result' to its parsed tensor and `labels` is the
    boolean tensor `result == 1`.
  """
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run data_download.py and '
      'set the --data_dir argument to the correct path.' % data_file)

  def parse_csv(value):
    # Split one CSV line into per-column tensors and separate out the label.
    print('Parsing', data_file)
    fields = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    parsed = dict(zip(_CSV_COLUMNS, fields))
    target = parsed.pop('result')
    return parsed, tf.equal(target, 1)

  # Extract lines from the input file using the Dataset API.
  lines = tf.data.TextLineDataset(data_file)

  if shuffle:
    lines = lines.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  examples = lines.map(parse_csv, num_parallel_calls=5)

  # Repeating happens after shuffling, not before, so that separate epochs
  # do not blend together.
  return examples.repeat(num_epochs).batch(batch_size)
def export_model(model, model_type, export_dir):
  """Write the trained estimator to `export_dir` in SavedModel format.

  Args:
    model: Estimator object to export.
    model_type: string indicating the model type: "wide", "deep" or
      "wide_deep". Any value other than "wide" or "deep" exports with both
      column sets.
    export_dir: directory to export the model to.
  """
  wide_columns, deep_columns = build_model_columns()

  # Choose the columns the exported graph has to parse from tf.Example input.
  columns_by_type = {'wide': wide_columns, 'deep': deep_columns}
  columns = columns_by_type.get(model_type, wide_columns + deep_columns)

  parse_spec = tf.feature_column.make_parse_example_spec(columns)
  build_receiver_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn)
  serving_input_fn = build_receiver_fn(parse_spec)
  model.export_savedmodel(export_dir, serving_input_fn)
def run_wide_deep(flags_obj):
  """Train the model, evaluating it at regular intervals, then export it.

  The model directory is deleted first, so every run starts from scratch.
  Training proceeds in cycles of `epochs_between_evals` epochs, each followed
  by an evaluation whose metrics are printed. The loop ends early once the
  accuracy passes `stop_threshold`.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # Remove results of a previous run, if there are any.
  shutil.rmtree(flags_obj.model_dir, ignore_errors=True)
  model = build_estimator(flags_obj.model_dir, flags_obj.model_type)

  data_dir = flags_obj.data_dir
  train_file = os.path.join(data_dir, 'data_train1500.csv')
  test_file = os.path.join(data_dir, 'data_test500.csv')

  batch_size = flags_obj.batch_size
  epochs_per_cycle = flags_obj.epochs_between_evals

  def train_input_fn():
    # One training call covers a whole cycle of epochs, shuffled.
    return input_fn(train_file, epochs_per_cycle, True, batch_size)

  def eval_input_fn():
    # Evaluation is a single, unshuffled pass over the test file.
    return input_fn(test_file, 1, False, batch_size)

  # Tensor names of the losses depend on the model type.
  loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
  logged_tensors = {
      'average_loss': loss_prefix + 'head/truediv',
      'loss': loss_prefix + 'head/weighted_loss/Sum',
  }
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, batch_size=batch_size, tensors_to_log=logged_tensors)

  # Train and evaluate the model every `flags.epochs_between_evals` epochs.
  num_cycles = flags_obj.train_epochs // epochs_per_cycle
  for cycle in range(1, num_cycles + 1):
    model.train(input_fn=train_input_fn, hooks=train_hooks)
    results = model.evaluate(input_fn=eval_input_fn)

    # Display evaluation metrics.
    print('Results at epoch', cycle * epochs_per_cycle)
    print('-' * 60)
    for key in sorted(results):
      print('%s: %s' % (key, results[key]))

    reached_target = model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, results['accuracy'])
    if reached_target:
      break

  # Export the model when an export directory was given.
  if flags_obj.export_dir is not None:
    export_model(model, flags_obj.model_type, flags_obj.export_dir)
def main(_):
  """Entry point handed to absl: run the training loop with the parsed flags.

  The positional argument passed by `absl_app.run` is ignored.
  """
  parsed_flags = flags.FLAGS
  run_wide_deep(parsed_flags)
if __name__ == '__main__':
  # Log at INFO level while the script runs.
  tf.logging.set_verbosity(tf.logging.INFO)
  # The flags are registered first, before control is handed to absl below.
  define_wide_deep_flags()
  absl_app.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment