Created
June 28, 2018 05:27
-
-
Save dusenberrymw/683a14212b665b4018a73561e698b7a5 to your computer and use it in GitHub Desktop.
Example ImageNet-style ResNet training scenario using synthetic data and the tf.estimator API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Example ImageNet-style resnet training scenario with synthetic data. | |
Author: Mike Dusenberry | |
""" | |
import argparse | |
import sys | |
import tensorflow as tf | |
#class Model(object): | |
# """Model""" | |
# | |
# def __init__(self, data_format, params): | |
# """Create a model. | |
# | |
# Args: | |
# data_format: Either 'channels_first' or 'channels_last'. | |
# """ | |
# self.conv1 = tf.layers.Conv2D(params.c, 3, padding='same', data_format=data_format) | |
# self.fc = tf.layers.Dense | |
def model_fn(features, labels, mode, params):
    """Build the model and return an `EstimatorSpec` for the requested mode.

    The network is intentionally tiny (one 3x3 conv layer with 8 filters,
    followed by a dense classifier) and stands in for a real model in this
    synthetic-data benchmark.

    Args:
        features: Batch of images as produced by `input_fn`, which generates
            them in channels-last (NHWC) layout.
        labels: Integer class ids for the batch; not used in PREDICT mode.
        mode: One of `tf.estimator.ModeKeys.{PREDICT, EVAL, TRAIN}`.
        params: Hyperparameter namespace. Reads `params.data_format`,
            `params.c` (number of classes) and, in TRAIN mode, `params.lr`
            (Adam learning rate).

    Returns:
        A `tf.estimator.EstimatorSpec` for `mode`.

    Raises:
        ValueError: If `mode` is not PREDICT, EVAL or TRAIN.
    """
    # NOTE: it would be better to have a model class whose `call` method accepts a
    # `training` boolean that could drive dropout, batch norm, etc. The model code,
    # loss and metrics would then live inside the mode-specific branches below.
    x = features
    if params.data_format == 'channels_first':
        # `input_fn` emits NHWC tensors; convert to NCHW so the layout matches the
        # `data_format` handed to the conv layer (otherwise height is read as channels).
        x = tf.transpose(x, [0, 3, 1, 2])
    x = tf.layers.conv2d(x, 8, 3, padding='same', data_format=params.data_format)
    x = tf.layers.flatten(x)
    # Read hyperparameters from `params` rather than the script-level `FLAGS` global,
    # so this function works with whatever params object the Estimator is given.
    logits = tf.layers.dense(x, params.c)
    preds = tf.argmax(logits, axis=1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': preds,
            'probs': tf.nn.softmax(logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # Mean softmax cross-entropy over the batch.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))
    # Streaming accuracy metric: a (value, update_op) pair.
    acc = tf.metrics.accuracy(labels=labels, predictions=preds)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metric_ops={
                'accuracy': acc
            })

    if mode == tf.estimator.ModeKeys.TRAIN:
        opt = tf.train.AdamOptimizer(params.lr)
        # TowerOptimizer pairs with the `replicate_model_fn` (multi-GPU) wrapper
        # that `main` applies to this function.
        opt = tf.contrib.estimator.TowerOptimizer(opt)
        grads_and_vars = opt.compute_gradients(loss)
        train_op = opt.apply_gradients(
            grads_and_vars, tf.train.get_or_create_global_step())
        # Expose training accuracy as a named tensor ('acc', picked up by the logging
        # hook in `main`) and as a TensorBoard summary. The Estimator already logs
        # and summarizes `loss` on its own.
        # NOTE(review): `acc[1]` is the metric's update op and `train_op` does not
        # depend on it, so it presumably only runs on steps where the hook or summary
        # op fetches it; the value is then a running average over those steps only.
        tf.identity(acc[1], 'acc')
        tf.summary.scalar('accuracy_train', acc[1])
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op)

    raise ValueError('Unsupported mode: {}'.format(mode))
def input_fn(params, training=False, seed=42):
    """Build a `tf.data.Dataset` of randomly generated (image, label) batches.

    Args:
        params: Namespace providing `n` (number of examples), `h` and `w`
            (image height and width), `c` (number of classes), `batch_size`,
            and `buffer` (number of batches to prefetch).
        training: If True, examples are shuffled (buffer of 10000) before
            batching and the batched dataset repeats indefinitely; otherwise
            the dataset is a single, unshuffled pass.
        seed: Op-level seed passed to both the image and the label generator.

    Returns:
        A `tf.data.Dataset` of `(images, labels)` batches. Images are drawn
        from a normal distribution with shape [batch, h, w, 3] (channels
        last); labels are int32 values drawn uniformly from [0, c).
    """
    images = tf.random_normal([params.n, params.h, params.w, 3], seed=seed)
    labels = tf.random_uniform([params.n], maxval=params.c, dtype=tf.int32, seed=seed)
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    # Shuffle individual examples (before batching), and only when training.
    ds = ds.shuffle(10000) if training else ds
    ds = ds.batch(params.batch_size)
    ds = ds.repeat() if training else ds
    # The Dataset itself is returned; the Estimator creates the iterator.
    return ds.prefetch(params.buffer)
def main(argv=None):
    """Train and evaluate the model on synthetic data.

    Hyperparameters come from the module-level `FLAGS` namespace (populated by
    the `__main__` block), which is also passed to the Estimator as `params`.

    Args:
        argv: Unused; kept for compatibility with `tf.app.run(main=main)`.
    """
    # Multi-GPU: wrap `model_fn` with `replicate_model_fn`, combining tower
    # losses with a MEAN reduction.
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN)
    estimator = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=FLAGS.model_dir,
        params=FLAGS)

    # Log the training-accuracy tensor named 'acc' in `model_fn` every
    # `FLAGS.log_interval` steps; the Estimator already logs the loss itself.
    logging_hook = tf.train.LoggingTensorHook(
        tensors={'accuracy': 'acc'}, every_n_iter=FLAGS.log_interval)

    # Train with periodic evaluation: evaluation starts after `start_delay_secs`,
    # re-runs at most every `throttle_secs`, and only when a new checkpoint exists.
    # This setup also allows distributed training with no code changes.
    # (Alternative: call `estimator.train(...)` and then `estimator.evaluate(...)`
    # directly, optionally in chunks of `FLAGS.steps / eval_interval` steps.)
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: input_fn(FLAGS, training=True, seed=FLAGS.seed),
        max_steps=FLAGS.steps,
        hooks=[logging_hook])
    # NOTE: the eval input uses the same seed as training, so (with fixed op-level
    # seeds) it regenerates the same synthetic examples; the reported eval accuracy
    # is effectively training-set accuracy. Use a different seed for a held-out set.
    # `steps=None` evaluates one full pass, because the non-training dataset does not
    # repeat.
    eval_spec = tf.estimator.EvalSpec(
        input_fn=lambda: input_fn(FLAGS, training=False, seed=FLAGS.seed),
        steps=None,
        start_delay_secs=FLAGS.start_delay_secs,
        throttle_secs=FLAGS.throttle_secs)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    # TODO: predict
if __name__ == '__main__':
    # Command-line arguments. argparse's built-in help is disabled (`add_help=False`)
    # because `-h` is used for the example height; `--help` is re-added at the end.
    parser = argparse.ArgumentParser(
        add_help=False, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-n", type=int, default=10000, help="num examples to generate")
    parser.add_argument("-h", type=int, default=28, help="example height")
    parser.add_argument("-w", type=int, default=28, help="example width")
    parser.add_argument("-c", type=int, default=10, help="num classes")
    parser.add_argument(
        "--data_format", default="channels_last", choices=["channels_first", "channels_last"],
        help="data format to use")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--steps", type=int, default=10000, help="training steps")
    # This only sets the interval of the accuracy LoggingTensorHook; the
    # Estimator logs the loss on its own schedule.
    parser.add_argument(
        "--log_interval", type=int, default=100,
        help="how often (in steps) to log the training accuracy")
    parser.add_argument("--buffer", type=int, default=100, help="size of prefetch buffer in batches")
    parser.add_argument(
        "--start_delay_secs", type=int, default=1,  # TODO: change back to 120
        help="start evaluating after waiting for this many seconds")
    parser.add_argument(
        "--throttle_secs", type=int, default=10,  # TODO: change back to 600
        help="do not re-evaluate unless the last evaluation was started at least this many "
             "seconds ago. Of course, evaluation does not occur if no new checkpoints are "
             "available, hence, this is the minimum")
    parser.add_argument(
        "--model_dir", default=None,
        help="model directory; defaults to tmp dir; also, tensorboard can be started for this dir")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--help", action='help', help="show this help message and exit")
    # Module-level global read by `main` (and passed to the Estimator as `params`).
    FLAGS = parser.parse_args()  # TODO: should `tf.flags` be used instead?
    tf.logging.set_verbosity(tf.logging.INFO)
    # `main` is called directly rather than via
    # `tf.app.run(main=main, argv=[sys.argv[0]])`: argparse has already handled the
    # command line, so `tf.app.run` adds nothing here.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment