Created
May 6, 2018 12:17
-
-
Save codescv/50b34c3c84bf328e57ad814a47cd2c4e to your computer and use it in GitHub Desktop.
Single-node TensorFlow model (wide/linear classifier on the Census Income dataset)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
_CSV_COLUMNS = [ | |
'age', 'workclass', 'fnlwgt', 'education', 'education_num', | |
'marital_status', 'occupation', 'relationship', 'race', 'gender', | |
'capital_gain', 'capital_loss', 'hours_per_week', 'native_country', | |
'income_bracket' | |
] | |
_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''], | |
[0], [0], [0], [''], ['']] | |
_NUM_EXAMPLES = { | |
'train': 32561, | |
'validation': 16281, | |
} | |
def build_model_columns():
    """Build the wide (linear) and deep feature columns for the census data.

    Returns:
      A ``(wide_columns, deep_columns)`` tuple of ``tf.feature_column``
      lists: sparse/crossed columns for the linear part, and numeric plus
      densified categorical columns for the deep part.
    """
    # --- Continuous (numeric) features, used directly by the deep part ---
    age = tf.feature_column.numeric_column('age')
    education_num = tf.feature_column.numeric_column('education_num')
    capital_gain = tf.feature_column.numeric_column('capital_gain')
    capital_loss = tf.feature_column.numeric_column('capital_loss')
    hours_per_week = tf.feature_column.numeric_column('hours_per_week')

    # --- Categorical features with fixed, known vocabularies ---
    vocab_column = tf.feature_column.categorical_column_with_vocabulary_list
    education = vocab_column('education', [
        'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
        'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
        '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])
    marital_status = vocab_column('marital_status', [
        'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
        'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])
    relationship = vocab_column('relationship', [
        'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
        'Other-relative'])
    workclass = vocab_column('workclass', [
        'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
        'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])

    # Open-ended vocabulary: hash the values instead of enumerating them.
    occupation = tf.feature_column.categorical_column_with_hash_bucket(
        'occupation', hash_bucket_size=1000)

    # Bucketize age so the linear model can learn a non-monotonic response.
    age_buckets = tf.feature_column.bucketized_column(
        age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

    # Wide part: sparse base columns followed by feature crosses.
    wide_columns = [
        education, marital_status, relationship, workclass, occupation,
        age_buckets,
        tf.feature_column.crossed_column(
            ['education', 'occupation'], hash_bucket_size=1000),
        tf.feature_column.crossed_column(
            [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
    ]

    # Deep part: raw numerics plus dense encodings of the sparse columns.
    deep_columns = [
        age,
        education_num,
        capital_gain,
        capital_loss,
        hours_per_week,
        tf.feature_column.indicator_column(workclass),
        tf.feature_column.indicator_column(education),
        tf.feature_column.indicator_column(marital_status),
        tf.feature_column.indicator_column(relationship),
        # Embedding demonstrates a learned dense representation for a
        # high-cardinality (hashed) column.
        tf.feature_column.embedding_column(occupation, dimension=8),
    ]
    return wide_columns, deep_columns
def input_fn(data_file, num_epochs=None, shuffle=True, batch_size=128):
    """Generate an input dataset for the Estimator.

    Args:
      data_file: Path to a census CSV file (e.g. ``adult.data``).
      num_epochs: Number of passes over the data; ``None`` repeats forever.
      shuffle: Whether to shuffle records before batching.
      batch_size: Number of examples per batch.

    Returns:
      A ``tf.data.Dataset`` yielding ``(features, labels)`` batches, where
      ``features`` maps column names to tensors and ``labels`` is a boolean
      tensor that is True for the '>50K' income bracket.

    Raises:
      FileNotFoundError: If ``data_file`` does not exist.
    """
    # Validate with a real exception rather than `assert`: asserts are
    # stripped under `python -O`, which would defer the failure to an
    # opaque error deep inside the input pipeline.
    if not tf.gfile.Exists(data_file):
        raise FileNotFoundError(
            '%s not found. Please make sure you have run data_download.py and '
            'set the --data_dir argument to the correct path.' % data_file)

    def parse_csv(value):
        # Parse one CSV line into a {column_name: tensor} dict and pop the
        # label column off as the binary target.
        print('Parsing', data_file)
        columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))
        labels = features.pop('income_bracket')
        return features, tf.equal(labels, '>50K')

    # Extract lines from input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file)

    if shuffle:
        # NOTE(review): the buffer is sized to the *training* split even when
        # another file is passed in, so only 'train' gets a full shuffle —
        # confirm this is intended before reusing for other splits.
        dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

    dataset = dataset.map(parse_csv, num_parallel_calls=5)

    # We call repeat after shuffling, rather than before, to prevent separate
    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    return dataset
def build_model(filename):
    """Assemble the linear ("wide") training graph for one CSV file.

    Args:
      filename: Path to the training CSV file.

    Returns:
      A dict with keys:
        'train': {'train_op', 'loss'} tensors to run per training step.
        'init': {'global', 'local'} lists of initializer ops.
        'cols_to_vars': mapping from feature column to its model variables.
    """
    wide_columns, _ = build_model_columns()

    # One-shot iterator over the input pipeline; get_next() yields one batch.
    batch = input_fn(filename).make_one_shot_iterator().get_next()
    features, labels = batch

    # Linear model over the wide columns; capture the column -> variable
    # mapping so callers can inspect what was created.
    cols_to_vars = {}
    logits = tf.feature_column.linear_model(
        features=features,
        feature_columns=wide_columns,
        cols_to_vars=cols_to_vars)
    predictions = tf.reshape(tf.nn.sigmoid(logits), (-1,))

    loss = tf.losses.log_loss(labels=labels, predictions=predictions)
    optimizer = tf.train.FtrlOptimizer(
        learning_rate=0.1,
        l1_regularization_strength=0.1,
        l2_regularization_strength=0.1)
    train_op = optimizer.minimize(loss)

    model = {}
    model['train'] = {'train_op': train_op, 'loss': loss}
    model['init'] = {
        'global': [tf.global_variables_initializer()],
        # Local init includes table init for the vocabulary lookup tables.
        'local': [tf.local_variables_initializer(), tf.tables_initializer()],
    }
    model['cols_to_vars'] = cols_to_vars
    return model
def _print_column_variables(cols_to_vars):
    """Print the model variables created for each feature column."""
    for col, var in cols_to_vars.items():
        print('Column: ', col)
        print('Variable:', var)
        print('-' * 50)


def main():
    """Build the census linear model, then train it for 999 steps."""
    # Build the graph once, before any session exists.
    model = build_model(filename='census_data/adult.data')

    # Show which variables back each feature column.
    _print_column_variables(model['cols_to_vars'])

    with tf.Session() as sess:
        # Runs global/local variable initializers and table initialization.
        sess.run(model['init'])
        for step in range(1, 1000):
            result = sess.run(model['train'])
            print('step =', step, 'loss =', result['loss'])


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment