@yenchenlin
Created December 26, 2016 13:20
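# Virtual adversarial training (VAT) on MNIST: a fully connected network with
# batch normalization, trained on cross-entropy plus a KL-based virtual
# adversarial smoothness loss (written against the pre-1.0 TensorFlow API).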
import tensorflow as tf
import numpy
from sklearn.datasets import fetch_mldata

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('seed', 1, "initial random seed")
tf.app.flags.DEFINE_string('layer_sizes', '784-1200-600-300-150-10', "layer sizes")
tf.app.flags.DEFINE_integer('batch_size', 100, "the number of examples in a batch")
tf.app.flags.DEFINE_integer('num_epochs', 100, "the number of epochs for training")
tf.app.flags.DEFINE_float('learning_rate', 0.002, "initial learning rate")
tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.9, "learning rate decay factor")
tf.app.flags.DEFINE_float('bn_stats_decay_factor', 0.99,
                          "moving average decay factor for stats on batch normalization")
tf.app.flags.DEFINE_float('epsilon', 2.0, "norm length for (virtual) adversarial training")
tf.app.flags.DEFINE_float('balance_factor', 1.0,
                          "balance factor between neg. log-likelihood and (virtual) adversarial loss")
tf.app.flags.DEFINE_integer('num_power_iterations', 1, "the number of power iterations")
tf.app.flags.DEFINE_float('xi', 1e-6, "small constant for finite difference")
tf.app.flags.DEFINE_bool('stop_grad_of_1st_kl_arg', True, "stop gradient of 1st kl arg")

def get_normalized_vector(d):
    # Scale by the max absolute value first (for numerical stability), then L2-normalize.
    d /= (1e-12 + tf.reduce_max(tf.abs(d), 1, keep_dims=True))
    d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), 1, keep_dims=True))
    return d

def kl_divergence_with_logit(q_logit, p_logit, name="kl_divergence"):
    # Mean over the batch of KL(q || p), with both distributions given as logits.
    q = tf.nn.softmax(q_logit)
    p = tf.nn.softmax(p_logit)
    kl = tf.reduce_mean(tf.reduce_sum(q * (tf.log(q) - tf.log(p)), 1))
    return tf.identity(kl, name=name)

def generate_virtual_adversarial_perturbation(x, logit):
    # Power iteration: starting from a random direction d, repeatedly take the
    # gradient of the KL divergence at a finite-difference step xi * d to find
    # the direction the prediction is most sensitive to, then scale it to epsilon.
    d = tf.random_normal(shape=tf.shape(x))
    for _ in range(FLAGS.num_power_iterations):
        d = FLAGS.xi * get_normalized_vector(d)
        logit_d = forward(x + d, update_batch_stats=False)
        kl = kl_divergence_with_logit(logit, logit_d)
        grad = tf.gradients(kl, [d])[0]
        d = tf.stop_gradient(grad)
    return FLAGS.epsilon * get_normalized_vector(d)

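# The VAT loss below is the KL divergence between the model's predictions on x
# and on x + r_vadv; adding it to the supervised loss penalizes sharp changes of
# the output distribution in the adversarial direction.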
def virtual_adversarial_loss(x, logit, name="vat_loss"):
    r_vadv = generate_virtual_adversarial_perturbation(x, logit)
    if FLAGS.stop_grad_of_1st_kl_arg:
        logit = tf.stop_gradient(logit)
    loss = kl_divergence_with_logit(logit, forward(
        x + r_vadv, update_batch_stats=False))
    return tf.identity(loss, name=name)

def batch_normalization(x, dim, is_training=True, update_batch_stats=True, name="bn"):
    mean = tf.reduce_mean(x, 0, keep_dims=True)
    var = tf.reduce_mean(tf.pow(x - mean, 2.0), 0, keep_dims=True)
    avg_mean = tf.get_variable(
        name=name + "_mean",
        shape=[1, dim],
        initializer=tf.constant_initializer(0.0),
        trainable=False
    )
    avg_var = tf.get_variable(
        name=name + "_var",
        shape=[1, dim],
        initializer=tf.constant_initializer(1.0),
        trainable=False
    )
    gamma = tf.get_variable(
        name=name + "_gamma",
        shape=[1, dim],
        initializer=tf.constant_initializer(1.0),
    )
    beta = tf.get_variable(
        name=name + "_beta",
        shape=[1, dim],
        initializer=tf.constant_initializer(0.0),
    )
    if is_training:
        avg_mean_assign_op = tf.no_op()
        avg_var_assign_op = tf.no_op()
        if update_batch_stats:
            avg_mean_assign_op = tf.assign(
                avg_mean,
                FLAGS.bn_stats_decay_factor * avg_mean
                + (1 - FLAGS.bn_stats_decay_factor) * mean)
            # Unbiased (Bessel-corrected) variance; cast batch_size to float so
            # the correction factor is not truncated by integer division.
            avg_var_assign_op = tf.assign(
                avg_var,
                FLAGS.bn_stats_decay_factor * avg_var
                + (float(FLAGS.batch_size) / (FLAGS.batch_size - 1))
                * (1 - FLAGS.bn_stats_decay_factor) * var)
        with tf.control_dependencies([avg_mean_assign_op, avg_var_assign_op]):
            ret = gamma * (x - mean) / tf.sqrt(1e-6 + var) + beta
    else:
        ret = gamma * (x - avg_mean) / tf.sqrt(1e-6 + avg_var) + beta
    return ret

def forward(x, is_training=True, update_batch_stats=True):
    layer_sizes = numpy.asarray(FLAGS.layer_sizes.split('-'), numpy.int32)
    num_layers = len(layer_sizes) - 1
    h = x
    for l, (d1, d2) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        W = tf.get_variable(
            name="W_l" + str(l),
            shape=[d1, d2],
            initializer=tf.random_normal_initializer(mean=0.0, stddev=numpy.sqrt(1.0 / d1))
        )
        b = tf.get_variable(
            name="b_l" + str(l),
            shape=[d2],
            initializer=tf.constant_initializer(0.0)
        )
        lin = tf.matmul(h, W) + b
        lin = batch_normalization(
            lin, d2,
            is_training=is_training,
            update_batch_stats=update_batch_stats,
            name="BN_l" + str(l))
        if l == num_layers - 1:
            h = lin  # the last layer outputs logits (no ReLU)
        else:
            h = tf.nn.relu(lin)
    return h

def accuracy(logit, y):
    pred = tf.argmax(logit, 1)
    true = tf.argmax(y, 1)
    return tf.reduce_mean(tf.to_float(tf.equal(pred, true)))

def build_training_graph():
    global_step = tf.Variable(
        name="global_step",
        dtype=tf.int32,
        initial_value=0,
        trainable=False,
    )
    layer_sizes = numpy.asarray(FLAGS.layer_sizes.split('-'), numpy.int32)
    x = tf.placeholder(tf.float32, [None, layer_sizes[0]])
    y = tf.placeholder(tf.float32, [None, layer_sizes[-1]])
    logit = forward(x)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y))
    acc = accuracy(logit, y)
    # Reuse the same weights for the virtual adversarial forward passes.
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    vat_loss = virtual_adversarial_loss(x, logit)
    loss += FLAGS.balance_factor * vat_loss
    lr = tf.Variable(
        name="learning_rate",
        initial_value=FLAGS.learning_rate,
        trainable=False
    )
    lr_update_op = tf.assign(lr, lr * FLAGS.learning_rate_decay_factor)
    opt = tf.train.AdamOptimizer(learning_rate=lr)
    tvars = tf.trainable_variables()
    grads_and_vars = opt.compute_gradients(loss, tvars)
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    return x, y, loss, acc, train_op, lr_update_op, global_step

def build_eval_graph():
    layer_sizes = numpy.asarray(FLAGS.layer_sizes.split('-'), numpy.int32)
    x = tf.placeholder(tf.float32, [None, layer_sizes[0]])
    y = tf.placeholder(tf.float32, [None, layer_sizes[-1]])
    # Evaluation uses the moving-average batch-norm statistics (is_training=False).
    logit = forward(x, is_training=False)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y))
    acc = accuracy(logit, y)
    return x, y, loss, acc

def to_one_hot(x, K):
    N = x.shape[0]
    x = numpy.asarray(x, numpy.int32)
    nx = numpy.zeros((N, K))
    for n in xrange(N):
        nx[n, x[n]] = 1.0
    return nx

def main(_):
    rng = numpy.random.RandomState(FLAGS.seed)
    with tf.Graph().as_default():
        with tf.variable_scope("MNIST_NN") as scope:
            tf.set_random_seed(seed=rng.randint(1234))
            x, y, loss, acc, train_op, lr_update_op, global_step = build_training_graph()
            scope.reuse_variables()
            x_test, y_test, loss_test, acc_test = build_eval_graph()
        init = tf.initialize_all_variables()
        # sv = tf.train.Supervisor(logdir='/tmp/mydir')
        # sess = sv.prepare_or_wait_for_session(0)
        sess = tf.Session()
        sess.run(init)

        mnist = fetch_mldata('MNIST original')
        data = mnist['data'] / 255.
        target = to_one_hot(mnist['target'], 10)
        x_tr = data[:60000]
        x_ts = data[60000:]
        y_tr = target[:60000]
        y_ts = target[60000:]

        def shuffle(x, y):
            assert x.shape[0] == y.shape[0]
            rand_ix = rng.permutation(x.shape[0])
            return x[rand_ix], y[rand_ix]

        print "Training..."
        for ep in range(FLAGS.num_epochs):
            x_tr, y_tr = shuffle(x_tr, y_tr)
            sum_loss = 0
            sum_acc = 0
            n_iter_per_epoch = 60000 // FLAGS.batch_size
            for i in range(n_iter_per_epoch):
                batch_xs = x_tr[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                batch_ys = y_tr[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                sess.run([train_op, global_step], feed_dict={x: batch_xs, y: batch_ys})
                batch_loss, batch_acc = sess.run([loss, acc], feed_dict={x: batch_xs, y: batch_ys})
                sum_loss += batch_loss
                sum_acc += batch_acc
            print "Epoch:", ep, "CE_loss_train:", sum_loss / n_iter_per_epoch, \
                "ACC_train:", sum_acc / n_iter_per_epoch

            # Evaluate on the 10,000 held-out test examples with the evaluation graph.
            sum_loss = 0
            sum_acc = 0
            n_iter_per_epoch = 10000 // FLAGS.batch_size
            for i in range(n_iter_per_epoch):
                batch_xs = x_ts[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                batch_ys = y_ts[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                batch_loss, batch_acc = sess.run([loss_test, acc_test],
                                                 feed_dict={x_test: batch_xs, y_test: batch_ys})
                sum_loss += batch_loss
                sum_acc += batch_acc
            print "Epoch:", ep, "CE_loss_test:", sum_loss / n_iter_per_epoch, \
                "ACC_test:", sum_acc / n_iter_per_epoch
            # Decay the learning rate at the end of each epoch.
            print sess.run(lr_update_op)


if __name__ == "__main__":
    tf.app.run()
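The script reads its hyperparameters through tf.app.flags, so they can be overridden on the command line. A minimal invocation, assuming the gist is saved locally as vat_mnist.py (the filename is not part of the gist), would be:

python vat_mnist.py --num_epochs 100 --epsilon 2.0 --balance_factor 1.0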