Skip to content

Instantly share code, notes, and snippets.

@dusenberrymw
Last active March 22, 2018 04:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dusenberrymw/384fde2fe91a8443b423ec5967062cdb to your computer and use it in GitHub Desktop.
Save dusenberrymw/384fde2fe91a8443b423ec5967062cdb to your computer and use it in GitHub Desktop.
Example ImageNet-style resnet training scenario with synthetic data
"""Example ImageNet-style resnet training scenario with synthetic data.
Author: Mike Dusenberry
"""
import argparse
import numpy as np
import tensorflow as tf
# args
def _positive_int(text):
    """Parse *text* as an int, rejecting values below 1.

    Used as an argparse ``type=`` for sizes and intervals where zero or a
    negative value cannot work (e.g. a zero log interval would divide by
    zero, a zero batch size or class count is rejected by NumPy/TF much
    later, after the whole graph has been built).

    Raises:
        argparse.ArgumentTypeError: if the value is an int below 1.
        ValueError: if *text* is not an int (argparse reports it as an
            invalid value).
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            "expected a positive integer, got {}".format(text))
    return value


# `add_help=False` frees `-h` to mean "example height"; `--help` is re-added
# explicitly as the last option below.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("-n", type=_positive_int, default=175, help="num examples to generate")
parser.add_argument("-h", type=_positive_int, default=224, help="example height")
parser.add_argument("-w", type=_positive_int, default=224, help="example width")
parser.add_argument("-k", type=_positive_int, default=1000, help="num classes")
parser.add_argument("--batch_size", type=_positive_int, default=32, help="batch size")
parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
parser.add_argument("--steps", type=int, default=1000, help="training steps")
parser.add_argument("--log_interval", type=_positive_int, default=100, help="how often to print the loss")
# Not validated: tf.data accepts special negative values (AUTOTUNE) here.
parser.add_argument("--buffer", type=int, default=100, help="size of prefetch buffer in batches")
parser.add_argument("--help", action='help', help="show this help message and exit")
FLAGS = parser.parse_args()
# synthetic data
# `x`: standard-normal float32 "images" of shape (n, h, w, 3).
# `y`: float32 one-hot labels of shape (n, k) for uniformly drawn class ids.
# The images are drawn before the labels, which fixes the RNG call order.
image_shape = (FLAGS.n, FLAGS.h, FLAGS.w, 3)
x = np.random.randn(*image_shape).astype(np.float32)
class_ids = np.random.randint(FLAGS.k, size=FLAGS.n)
# Row i of the identity matrix is the one-hot vector for class i.
y = np.eye(FLAGS.k, dtype=np.float32)[class_ids]
# tf data
# Input pipeline over the in-memory arrays: shuffle with a 100-element buffer,
# batch, loop forever, and prefetch `FLAGS.buffer` batches ahead.
# batch() runs before repeat(), so each pass over the data ends with a partial
# batch whenever n is not a multiple of batch_size.
dataset = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .shuffle(100)
    .batch(FLAGS.batch_size)
    .repeat(-1)
    .prefetch(FLAGS.buffer)
)
iterator = dataset.make_one_shot_iterator()
# Tensors that yield a fresh (images, labels) batch on every `sess.run`.
x_batch, y_batch = iterator.get_next()
# tf model
# ResNet-50 body without its classifier, built directly on the input-pipeline
# tensor. `weights=None` skips the default ImageNet weight download: those
# weights would only be loaded into Keras's own backend session, while
# training runs in a separate `tf.Session` whose variables are all set by
# `tf.global_variables_initializer()`, so they would never be used.
resnet = tf.keras.applications.ResNet50(
    include_top=False, weights=None, input_tensor=x_batch,
    input_shape=(FLAGS.h, FLAGS.w, 3))
# Classification head: flatten the final feature map, project to k logits.
out = tf.layers.flatten(resnet.output)
logits = tf.layers.dense(out, FLAGS.k)
# tf loss
# Softmax cross-entropy of each example against its one-hot label, then the
# mean over the batch.
per_example_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=y_batch, logits=logits)
loss = tf.reduce_mean(per_example_xent)
# tf optimizer
# Adam over the trainable variables (minimize's default var_list).
opt = tf.train.AdamOptimizer(learning_rate=FLAGS.lr)
train_op = opt.minimize(loss)
# saver
# Checkpoint helper over the global variables of the graph built above.
saver = tf.train.Saver()
# init
# New session; every global variable is set from its initializer here.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)
# train loop
# The feed is identical every step: it puts Keras layers whose behavior
# depends on the learning phase into training mode.
train_feed = {tf.keras.backend.learning_phase(): True}
try:
    for i in range(FLAGS.steps):
        # A non-positive interval disables logging instead of raising
        # ZeroDivisionError on the modulus.
        if FLAGS.log_interval > 0 and i % FLAGS.log_interval == 0:
            # Fetch the loss in the same run as the update so the printed
            # value is computed on this step's batch.
            _, loss_value = sess.run([train_op, loss], feed_dict=train_feed)
            print("loss: {}".format(loss_value))
        else:
            sess.run(train_op, feed_dict=train_feed)
    # To checkpoint the trained weights, re-enable:
    # saver.save(sess, "model.ckpt")
finally:
    # Release the session's resources even if training fails midway.
    sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment