import os
import argparse
import tensorflow as tf
# global variables
width = 28
height = 28
n_class = 10
# tfrecord parser
def _parse_function(example_proto):
    features = {
        'height': tf.FixedLenFeature((), tf.int64, default_value=height),
        'width': tf.FixedLenFeature((), tf.int64, default_value=width),
        'depth': tf.FixedLenFeature((), tf.int64, default_value=1),
        'label': tf.FixedLenFeature((), tf.int64, default_value=0),
        'image_raw': tf.FixedLenFeature((), tf.string, default_value="")}
    parsed_features = tf.parse_single_example(example_proto, features)
    images = parsed_features["image_raw"]
    labels = parsed_features["label"]
    images = tf.decode_raw(images, tf.uint8)
    images.set_shape([width * height])
    images = tf.cast(images, tf.float32) * (1. / 255)
    # Keep labels as integer class ids here; one-hot encoding happens once,
    # in the graph below (encoding twice would break the feed shapes).
    labels = tf.cast(labels, tf.int32)
    return images, labels
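
# For reference, a minimal sketch of the writer side that produces records the
# parser above can read. The helper name (_make_example) and the example writer
# path are illustrative assumptions, not part of the original script.
def _make_example(image_bytes, label):
    # Build a tf.train.Example whose feature keys match _parse_function.
    def _int64(v):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[v]))
    def _bytes(v):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))
    return tf.train.Example(features=tf.train.Features(feature={
        'height': _int64(height),
        'width': _int64(width),
        'depth': _int64(1),
        'label': _int64(label),
        'image_raw': _bytes(image_bytes)}))
# Usage sketch:
#   with tf.python_io.TFRecordWriter('train.tfrecords') as writer:
#       writer.write(_make_example(img.tostring(), lbl).SerializeToString())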
if __name__ =='__main__':
parser = argparse.ArgumentParser()
# hyperparameters sent by the client are passed as command-line arguments to the script.
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--optimizer', type=str, default='sgd', metavar='O',
help='optimizer (default: sgd, alternative: [sgd, adam])')
parser.add_argument('--adam-lr', type=float, default=0.001, metavar='ALR',
help='learning rate for adam (default: 0.001)')
parser.add_argument('--sgd-lr', type=float, default=0.01, metavar='SLR',
help='learning rate for SGD (default: 0.01)')
# input data and model directories
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
parser.add_argument('--data-dir', type=str, default=os.environ.get('SM_INPUT_DIR'))
args, _ = parser.parse_known_args()
data_dir = os.path.join(args.data_dir, "data", "training")
model_dir = args.model_dir
training_filename = [os.path.join(data_dir, "train.tfrecords")]
validation_filename = [os.path.join(data_dir, "validation.tfrecords")]
n_epoch = args.epochs
batch_size = args.batch_size
# training data
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function) # Parse the record into tensors.
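    # Note: for non-toy training one would usually insert a shuffle here,
    # before batching, e.g. (the buffer size is an illustrative choice):
    #   dataset = dataset.shuffle(buffer_size=10000)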
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()
    # Data fed into the network
    inputs = tf.placeholder(tf.float32, [None, width * height])
    labels = tf.placeholder(tf.int32, shape=[None])  # integer class ids
    # model definition
    hidden1 = tf.layers.dense(inputs=inputs, units=1024, activation=tf.nn.relu)
    hidden2 = tf.layers.dense(inputs=hidden1, units=512, activation=tf.nn.relu)
    logits = tf.layers.dense(inputs=hidden2, units=n_class)
    prob = tf.nn.softmax(logits)
    # optimization & loss
    # Use a distinct name for the one-hot tensor so the `labels` placeholder
    # can still be fed in the training loop below.
    onehot_labels = tf.one_hot(labels, n_class)
    loss = tf.losses.softmax_cross_entropy(onehot_labels, logits,
                                           reduction=tf.losses.Reduction.MEAN)
    global_step = tf.Variable(0, trainable=False, name='global_step')
    if args.optimizer == "sgd":
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=args.sgd_lr)
    elif args.optimizer == "adam":
        optimizer = tf.train.AdamOptimizer(learning_rate=args.adam_lr)
    else:
        raise ValueError("Unsupported optimizer: {}".format(args.optimizer))
    training_op = optimizer.minimize(loss, global_step=global_step)
    accuracy = tf.reduce_mean(tf.cast(
        tf.equal(tf.argmax(onehot_labels, axis=1), tf.argmax(logits, axis=1)),
        tf.float32))
    # Initialization
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    # Training loop
    with tf.Session() as sess:
        sess.run(init)  # also initializes global_step
        for i in range(n_epoch):
            # Training
            sess.run(iterator.initializer, feed_dict={filenames: training_filename})
            train_loss = 0
            train_accuracy = 0
            train_iter = 0
            while True:
                try:
                    batch = sess.run(next_batch)
                    _, _loss, _accuracy = sess.run(
                        [training_op, loss, accuracy],
                        feed_dict={inputs: batch[0], labels: batch[1]})
                    train_loss += _loss
                    train_accuracy += _accuracy
                    train_iter += 1
                except tf.errors.OutOfRangeError:
                    break
            # Validation (forward pass only: do not run training_op here,
            # otherwise the model trains on the validation set)
            sess.run(iterator.initializer, feed_dict={filenames: validation_filename})
            val_loss = 0
            val_accuracy = 0
            val_iter = 0
            while True:
                try:
                    batch = sess.run(next_batch)
                    _loss, _accuracy = sess.run(
                        [loss, accuracy],
                        feed_dict={inputs: batch[0], labels: batch[1]})
                    val_loss += _loss
                    val_accuracy += _accuracy
                    val_iter += 1
                except tf.errors.OutOfRangeError:
                    break
            # Result for each epoch
            avg_train_loss = train_loss / train_iter
            avg_train_accuracy = train_accuracy / train_iter
            avg_val_loss = val_loss / val_iter
            avg_val_accuracy = val_accuracy / val_iter
            print("Epoch: {}, Training loss: {}, Training accuracy: {}, "
                  "Validation loss: {}, Validation accuracy: {}".format(
                      i, avg_train_loss, avg_train_accuracy,
                      avg_val_loss, avg_val_accuracy))
        # Save the model in a predefined directory, which is uploaded to S3.
        # The export must be in a format that TensorFlow Serving can read;
        # a simple way is tf.saved_model.simple_save.
        tf.saved_model.simple_save(
            sess,
            os.path.join(args.model_dir, 'model/1'),
            inputs={'input_image': inputs},
            outputs={'predictions': prob})
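
# ---------------------------------------------------------------------------
# Usage sketches (illustrative; not executed as part of this script).
#
# 1) Smoke-testing the export locally with tf.contrib.predictor (TF 1.x).
#    The export path assumes the simple_save call above; 'input_image' and
#    'predictions' are the signature keys defined there.
#
#   from tensorflow.contrib import predictor
#   predict_fn = predictor.from_saved_model(os.path.join(model_dir, 'model/1'))
#   probs = predict_fn({'input_image': images_batch})['predictions']
#
# 2) Launching this script with the SageMaker Python SDK (script mode).
#    Bucket, role, instance type, and framework version are assumptions;
#    the hyperparameter keys match the argparse flags above, and the
#    'training' channel matches the data_dir layout (SM_INPUT_DIR/data/training).
#
#   from sagemaker.tensorflow import TensorFlow
#   estimator = TensorFlow(entry_point='train.py',
#                          role=role,
#                          train_instance_count=1,
#                          train_instance_type='ml.m5.xlarge',
#                          framework_version='1.12',
#                          py_version='py3',
#                          script_mode=True,
#                          hyperparameters={'epochs': 10,
#                                           'batch-size': 64,
#                                           'optimizer': 'adam'})
#   estimator.fit({'training': 's3://<bucket>/<prefix>'})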