"""Example ImageNet-style resnet training scenario with synthetic data.
Author: Mike Dusenberry
"""
import argparse
import sys
import tensorflow as tf


# NOTE: unused class-based sketch of the same model as `model_fn` below; the `training`
# flag in `__call__` is where dropout, batch norm, etc. would hook in.
#class Model(object):
#  """Model"""
#
#  def __init__(self, data_format, params):
#    """Create a model.
#
#    Args:
#      data_format: Either 'channels_first' or 'channels_last'.
#      params: Hyperparameters, including `c`, the number of classes.
#    """
#    self.conv1 = tf.layers.Conv2D(8, 3, padding='same', data_format=data_format)
#    self.flatten = tf.layers.Flatten()
#    self.fc = tf.layers.Dense(params.c)
#
#  def __call__(self, inputs, training=False):
#    x = self.conv1(inputs)
#    x = self.flatten(x)
#    return self.fc(x)


def model_fn(features, labels, mode, params):
  """The model function."""
  # tf model
  # NOTE: it would be better to have a model class where the `call` function accepts a
  # `training` boolean parameter that could be used for dropout, batch norm, etc. (see the
  # sketch above). Then, this model code, as well as the loss and metrics, would go inside
  # each of the `if` blocks.
  x = features
  x = tf.layers.conv2d(x, 8, 3, padding='same', data_format=params.data_format)
  x = tf.layers.flatten(x)
  logits = tf.layers.dense(x, params.c)  # use `params`, not the global `FLAGS`, in `model_fn`
  preds = tf.argmax(logits, axis=1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': preds,
        'probs': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # tf loss
  loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits))

  # metrics
  acc = tf.metrics.accuracy(labels=labels, predictions=preds)

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops={'accuracy': acc})

  if mode == tf.estimator.ModeKeys.TRAIN:
    # tf optimizer, wrapped for in-graph multi-GPU replication
    opt = tf.train.AdamOptimizer(params.lr)
    opt = tf.contrib.estimator.TowerOptimizer(opt)
    grads_and_vars = opt.compute_gradients(loss)
    train_op = opt.apply_gradients(grads_and_vars, tf.train.get_or_create_global_step())

    # we will log these to the console
    #tf.identity(loss, 'train_loss')  # estimators automatically track loss
    tf.identity(acc[1], 'acc')

    # tensorboard summaries
    #tf.summary.scalar('train_loss', loss)  # estimators automatically create a loss summary
    tf.summary.scalar('accuracy_train', acc[1])

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op)


def input_fn(params, training=False, seed=42):
  """Create a Dataset of synthetic (x, y) data."""
  # synthetic data, generated in channels_last (NHWC) layout
  x = tf.random_normal([params.n, params.h, params.w, 3], seed=seed)
  y = tf.random_uniform([params.n], maxval=params.c, dtype=tf.int32, seed=seed)
  # tf data
  dataset = tf.data.Dataset.from_tensor_slices((x, y))
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.batch(params.batch_size)
  if training:
    dataset = dataset.repeat()
  dataset = dataset.prefetch(params.buffer)
  #iterator = dataset.make_one_shot_iterator()  # not needed -- can return a Dataset directly
  #x_batch, y_batch = iterator.get_next()
  return dataset
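

# A minimal sketch (an assumption, not part of the original gist) of the idea in the NOTE
# inside `main` below: give training and validation their own data so that evaluation does
# not see the training examples. With synthetic data, a distinct seed yields a distinct
# split; the function name and the `seed + 1` offset are illustrative choices.
def split_input_fn(params, training=False, seed=42):
  """Like `input_fn`, but draws a separate synthetic split for evaluation."""
  split_seed = seed if training else seed + 1  # distinct seed -> distinct synthetic examples
  return input_fn(params, training=training, seed=split_seed)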


def main(argv=None):
  # create estimator
  model_function = tf.contrib.estimator.replicate_model_fn(  # in-graph multi-GPU replication
      model_fn, loss_reduction=tf.losses.Reduction.MEAN)
  estimator = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=FLAGS.model_dir,
      params=FLAGS)

  # train
  tensors_to_log = {
      #'train_loss': 'train_loss',  # estimators automatically track loss
      'accuracy': 'acc'
  }
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=FLAGS.log_interval)
  hooks = [logging_hook]

  # NOTE: we could also periodically run the evaluation by having an eval interval param and
  # training for `FLAGS.steps / FLAGS.eval_interval` steps at a time
  #estimator.train(
  #    input_fn=lambda: input_fn(FLAGS, training=True, seed=FLAGS.seed),
  #    steps=FLAGS.steps,
  #    hooks=hooks)

  ## eval
  ## NOTE: normally, we would structure the `input_fn` such that it returned a different
  ## dataset for training and validation, where the latter would not repeat (see the
  ## `split_input_fn` sketch above). However, for this quick demo, we can just set the
  ## steps here.
  #eval_results = estimator.evaluate(
  #    input_fn=lambda: input_fn(FLAGS, training=False, seed=FLAGS.seed))
  #print(eval_results)

  # alternative train + eval setup where the eval is run based on the time since the previous
  # evaluation and the availability of new checkpoints.
  # NOTE: this setup allows for distributed training with no code changes
  train_spec = tf.estimator.TrainSpec(
      input_fn=lambda: input_fn(FLAGS, training=True, seed=FLAGS.seed),
      max_steps=FLAGS.steps,
      hooks=hooks)
  eval_spec = tf.estimator.EvalSpec(
      input_fn=lambda: input_fn(FLAGS, training=False, seed=FLAGS.seed),
      steps=None,
      start_delay_secs=FLAGS.start_delay_secs,
      throttle_secs=FLAGS.throttle_secs)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  # TODO: predict
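  # A minimal sketch for the predict TODO (an assumption, not part of the original gist):
  # `estimator.predict` returns a generator of per-example dicts containing the 'classes'
  # and 'probs' keys defined in the PREDICT branch of `model_fn`.
  #predictions = estimator.predict(
  #    input_fn=lambda: input_fn(FLAGS, training=False, seed=FLAGS.seed))
  #for pred in predictions:
  #  print(pred['classes'], pred['probs'])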


if __name__ == '__main__':
  # args
  # NOTE: `add_help=False` frees up `-h` for the height flag; `--help` is re-added below.
  parser = argparse.ArgumentParser(
      add_help=False, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument("-n", type=int, default=10000, help="num examples to generate")
  parser.add_argument("-h", type=int, default=28, help="example height")
  parser.add_argument("-w", type=int, default=28, help="example width")
  parser.add_argument("-c", type=int, default=10, help="num classes")
  parser.add_argument(
      "--data_format", default="channels_last", choices=["channels_first", "channels_last"],
      help="data format to use")
  parser.add_argument("--batch_size", type=int, default=32, help="batch size")
  parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
  parser.add_argument("--steps", type=int, default=10000, help="training steps")
  parser.add_argument("--log_interval", type=int, default=100, help="how often to print the loss")
  parser.add_argument("--buffer", type=int, default=100, help="size of prefetch buffer in batches")
  parser.add_argument("--start_delay_secs", type=int, default=1,  #120, # TODO: change this back
      help="start evaluating after waiting for this many seconds")
  parser.add_argument("--throttle_secs", type=int, default=10,  #600, # TODO: change this back
      help="do not re-evaluate unless the last evaluation started at least this many seconds "
           "ago; evaluation also will not occur if no new checkpoints are available, so this "
           "is a minimum interval")
  parser.add_argument("--model_dir", default=None,
      help="model directory; defaults to a tmp dir; tensorboard can be pointed at this dir")
  parser.add_argument("--seed", type=int, default=42, help="random seed")
  parser.add_argument("--help", action='help', help="show this help message and exit")
  FLAGS = parser.parse_args()  # TODO: should `tf.flags` be used instead?

  tf.logging.set_verbosity(tf.logging.INFO)
  #tf.app.run(main=main, argv=[sys.argv[0]])
  main()  # tf.app.run is basically useless...
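
# Example invocation (illustrative; the script filename is hypothetical):
#   python estimator_synthetic.py --steps 1000 --batch_size 64 --model_dir /tmp/est_demo
#   tensorboard --logdir /tmp/est_demo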