Skip to content

Instantly share code, notes, and snippets.

@nlgranger
Created February 27, 2018 11:12
Show Gist options
  • Save nlgranger/fc0f5559a619641f66b0b279d4ae4121 to your computer and use it in GitHub Desktop.
tensorflow training script that does not generate graph.pbtxt
import numpy as np
import tensorflow as tf
from procnet.utils import sensible_dir
from experiments.SpatialTransformerNetwork.model import spatial_transformer
from experiments.SpatialTransformerNetwork import dataset
def build_model_layers(inputs, nclasses, is_training):
    """Build the classification network on top of a spatial transformer.

    Architecture: spatial transformer -> conv(10,5x5) + maxpool + relu
    -> conv(20,5x5) [+ dropout] + maxpool + relu -> dense(50) [+ dropout]
    + relu -> dense(nclasses) logits.

    Args:
        inputs: input image tensor; reshaped to a (28, 28, 1) layer input,
            so presumably NHWC MNIST-like batches — TODO confirm.
        nclasses: number of output classes (size of the logits layer).
        is_training: Python bool; when True, dropout (keep_prob=0.9) is
            applied after conv2 and dense1. Note this is a graph-build-time
            constant, not a placeholder.

    Returns:
        The un-normalized logits tensor of shape (batch, nclasses).
    """
    initializer = tf.contrib.layers.xavier_initializer()
    inputs = tf.layers.Input((28, 28, 1), tensor=inputs)
    net = spatial_transformer(inputs, grid_dims=(1, 2))
    with tf.name_scope("conv1"):
        net = tf.layers.conv2d(
            net, 10, (5, 5),
            activation=None,
            kernel_initializer=initializer,
            bias_initializer=tf.zeros_initializer)
        # BUG FIX: the original discarded the pooling result
        # (`tf.layers.max_pooling2d(net, 2, 2)` without assignment),
        # so conv1 was never downsampled. Assign it, matching conv2.
        net = tf.layers.max_pooling2d(net, 2, 2)
        net = tf.nn.relu(net)
    with tf.name_scope("conv2"):
        net = tf.layers.conv2d(
            net, 20, (5, 5),
            activation=None,
            kernel_initializer=initializer,
            bias_initializer=tf.zeros_initializer)
        if is_training:
            # TF1 semantics: second arg is keep_prob, i.e. keep 90%.
            net = tf.nn.dropout(net, 0.9)
        net = tf.layers.max_pooling2d(net, 2, 2)
        net = tf.nn.relu(net)
    net = tf.layers.flatten(net)
    with tf.name_scope("dense1"):
        net = tf.layers.dense(
            net, 50,
            activation=None,
            kernel_initializer=initializer,
            bias_initializer=tf.zeros_initializer)
        if is_training:
            net = tf.nn.dropout(net, 0.9)
        net = tf.nn.relu(net)
    with tf.name_scope("dense2"):
        logits = tf.layers.dense(
            net, nclasses,
            activation=None,
            kernel_initializer=initializer,
            bias_initializer=tf.zeros_initializer)
    return logits
def train_input_fn(batch_size, split=50000):
    """Build an endless, shuffled input pipeline over the training split.

    Args:
        batch_size: number of (image, label) pairs per batch.
        split: index separating the training portion (first `split`
            samples) from the hold-out used by ``eval_input_fn``.
            Defaults to 50000 for backward compatibility.

    Returns:
        A ``(features, labels)`` tensor pair from a one-shot iterator
        yielding shuffled, repeating batches.
    """
    features, labels = dataset.train_x[:split], dataset.train_y[:split]
    # NHWC layout for tf.layers.conv2d; presumably 28x28 grayscale images.
    features = features.reshape((-1, 28, 28, 1)).astype(np.float32)
    # Scale to [0, 1] then standardize; 0.1307/0.3081 look like the MNIST
    # mean/std — TODO confirm against the dataset module.
    features = (features / 255 - 0.1307) / 0.3081
    labels = labels.reshape((-1,)).astype(np.int32)
    pairs = tf.data.Dataset.from_tensor_slices((features, labels))
    pairs = pairs.shuffle(len(labels)).batch(batch_size).repeat()
    return pairs.make_one_shot_iterator().get_next()
def eval_input_fn(batch_size, split=50000):
    """Build an endless (un-shuffled) pipeline over the hold-out split.

    Args:
        batch_size: number of (image, label) pairs per batch.
        split: index where the hold-out begins (samples ``split:`` of the
            training arrays). Defaults to 50000 for backward compatibility.

    Returns:
        A ``(features, labels)`` tensor pair from a one-shot iterator
        yielding repeating batches in a fixed order (no shuffling, so
        evaluation is deterministic).
    """
    features, labels = dataset.train_x[split:], dataset.train_y[split:]
    # NHWC layout for tf.layers.conv2d; presumably 28x28 grayscale images.
    features = features.reshape((-1, 28, 28, 1)).astype(np.float32)
    # Same normalization as train_input_fn; 0.1307/0.3081 look like the
    # MNIST mean/std — TODO confirm against the dataset module.
    features = (features / 255 - 0.1307) / 0.3081
    labels = labels.reshape((-1,)).astype(np.int32)
    pairs = tf.data.Dataset.from_tensor_slices((features, labels))
    pairs = pairs.batch(batch_size).repeat()
    return pairs.make_one_shot_iterator().get_next()
def main():
    """Build the graph, then train for one epoch with periodic evaluation.

    Two copies of the model share weights via variable_scope "inference"
    (reuse=False for the training copy, reuse=True for the eval copy),
    each fed by its own tf.data pipeline.
    """
    batch_size = 64
    with tf.Session() as sess:
        # Separate input pipelines for training and validation batches.
        train_inputs, train_labels = train_input_fn(batch_size)
        val_inputs, val_labels = eval_input_fn(batch_size)
        # Training copy of the model: creates the variables.
        with tf.variable_scope("inference", reuse=False):
            train_logits = build_model_layers(train_inputs, 10, True)
        train_loss = tf.losses.sparse_softmax_cross_entropy(
            train_labels, train_logits)
        tf.summary.scalar('loss', train_loss)
        optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
        # NOTE(review): no global step variable is ever created, so
        # tf.train.get_global_step() returns None here and minimize()
        # will not track a step counter — confirm this is intended.
        train_op = optimizer.minimize(
            train_loss, global_step=tf.train.get_global_step())
        # Evaluation copy: reuses the training variables (dropout off).
        with tf.variable_scope("inference", reuse=True):
            val_logits = build_model_layers(val_inputs, 10, False)
        val_loss = tf.losses.sparse_softmax_cross_entropy(
            val_labels, val_logits)
        tf.summary.scalar('loss', val_loss)
        predicted_classes = tf.argmax(val_logits, 1)
        # tf.metrics.accuracy returns a (value, update_op) pair.
        accuracy_op = tf.metrics.accuracy(val_labels, predicted_classes)
        model_dir = sensible_dir(
            "experiments/SpatialTransformerNetwork/checkpoints", "run_")
        # Passing sess.graph writes the graph to the train event file.
        train_writer = tf.summary.FileWriter(model_dir + "/train", sess.graph)
        eval_writer = tf.summary.FileWriter(model_dir + "/eval")
        # NOTE(review): merge_all merges BOTH loss scalars, so every run of
        # `merged` pulls a batch from both pipelines and logs the eval loss
        # into the train event file (and vice versa) — likely unintended.
        merged = tf.summary.merge_all()
        train_writer.flush()
        sess.run(tf.global_variables_initializer())
        # Local variables back the tf.metrics.accuracy counters.
        sess.run(tf.local_variables_initializer())
        # One epoch over the 50000-sample training split.
        for step in range(50000 // batch_size * 1):
            summary, _ = sess.run([merged, train_op])
            train_writer.add_summary(summary, step)
            if (step + 1) % 25 == 0:
                # `acc` holds the (value, update_op) results; it is
                # computed but never printed or logged here.
                summary, acc = sess.run(
                    [merged, accuracy_op])
                eval_writer.add_summary(summary, step)
        train_writer.close()
        eval_writer.close()
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment