Minimized TensorFlow regression model
import os.path
import time

import numpy as np
import tensorflow as tf

from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import summary_io

# Basic model parameters as external flags.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('num_steps', 10000, 'Number of steps to run trainer.')
flags.DEFINE_integer('hidden_layers', 5, 'Number of hidden layers.')
flags.DEFINE_integer('hidden_layers_size', 64, 'Number of units in hidden layers.')
flags.DEFINE_integer('batch_size', 10, 'Batch size.')

def input_pipeline(batch_size):
    # Toy data set: ten scalar examples and their regression targets, held in constants.
    label = tf.constant([[10.1], [1.3], [12.5], [372.0], [443.0], [0.3], [23.45], [1.42], [0.0], [56.7]], dtype=tf.float32)
    example = tf.constant([[11.0], [11.0], [13.0], [370.0], [422.0], [125.0], [25.0], [15.0], [100.0], [55.0]], dtype=tf.float32)
    # Shuffle individual rows, then reassemble them into batches via input queues.
    example, label = tf.train.slice_input_producer(
        [example, label],
        shuffle=True,
        capacity=30 + 3 * batch_size)
    example_batch, label_batch = tf.train.batch(
        [example, label],
        batch_size=batch_size,
        capacity=30 + 3 * batch_size)
    return example_batch, label_batch
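
# A minimal sketch (commented out so it does not run during training) of how one
# could inspect a single batch from this pipeline, assuming the TF 0.x/early-1.x
# queue-runner API that slice_input_producer and batch rely on:
#
#   examples, labels = input_pipeline(batch_size=4)
#   with tf.Session() as sess:
#       coord = tf.train.Coordinator()
#       threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#       print(sess.run([examples, labels]))  # two arrays, each of shape (4, 1)
#       coord.request_stop()
#       coord.join(threads)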

def input_fn():
    return input_pipeline(FLAGS.batch_size)

def model_fn(features, target):
    # features = tf.Print(features, [features], message="This is features: ")
    # target = tf.Print(target, [target], message="This is target: ")
    # return {}, tf.constant(0), tf.group(features, target)
    # Stack of fully connected ReLU layers followed by a single linear output unit.
    out = tf.contrib.layers.stack(features,
                                  tf.contrib.layers.fully_connected,
                                  [FLAGS.hidden_layers_size] * FLAGS.hidden_layers,
                                  activation_fn=tf.nn.relu)
    prediction = tf.contrib.layers.fully_connected(
        out,
        num_outputs=1,
        activation_fn=None)
    loss = tf.nn.l2_loss(prediction - target)
    train_op = tf.contrib.layers.optimize_loss(
        loss,
        tf.contrib.framework.get_global_step(),
        optimizer='Adagrad',
        learning_rate=FLAGS.learning_rate)
    # Track a streaming RMSE alongside the training op so it appears in summaries.
    rmse, update_op = tf.contrib.metrics.streaming_root_mean_squared_error(prediction, target)
    tf.scalar_summary("rmse", rmse)
    return prediction, loss, tf.group(train_op, update_op)

def create_eval_metrics():
    return {'rmse': tf.contrib.metrics.streaming_root_mean_squared_error}

def main(_):
    # Write checkpoints and summaries to a fresh, timestamped model directory.
    model_dir = time.strftime("/tmp/tf_regression/%m-%d-%H-%M-%S", time.gmtime())
    estimator = tf.contrib.learn.Estimator(model_fn=model_fn, model_dir=model_dir)
    estimator.fit(
        input_fn=input_fn,
        steps=FLAGS.num_steps,
        monitors=[])


if __name__ == '__main__':
    tf.app.run()
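
The create_eval_metrics helper above is defined but never called. A minimal sketch, assuming the tf.contrib.learn Estimator API this gist targets (which accepts a dict of metric functions via the metrics argument of evaluate), of how an evaluation pass reporting the streaming RMSE could be added after estimator.fit(...) in main(); the step count of 100 is an arbitrary choice for illustration:

    results = estimator.evaluate(
        input_fn=input_fn,
        steps=100,
        metrics=create_eval_metrics())
    print(results)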