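# Virtual adversarial training (VAT) on MNIST with a fully connected network.
# Written against the pre-1.0 TensorFlow API (tf.app.flags, keep_dims, ...) and Python 2.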
import tensorflow as tf
import numpy
from sklearn.datasets import fetch_mldata
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('seed', 1, "initial random seed")
tf.app.flags.DEFINE_string('layer_sizes', '784-1200-600-300-150-10', "layer sizes")
tf.app.flags.DEFINE_integer('batch_size', 100, "the number of examples in a batch")
tf.app.flags.DEFINE_integer('num_epochs', 100, "the number of epochs for training")
tf.app.flags.DEFINE_float('learning_rate', 0.002, "initial learning rate")
tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.9, "learning rate decay factor")
tf.app.flags.DEFINE_float('bn_stats_decay_factor', 0.99,
                          "moving average decay factor for stats on batch normalization")
tf.app.flags.DEFINE_float('epsilon', 2.0, "norm length for (virtual) adversarial training")
tf.app.flags.DEFINE_float('balance_factor', 1.0,
                          "balance factor between neg. log-likelihood and (virtual) adversarial loss")
tf.app.flags.DEFINE_integer('num_power_iterations', 1, "the number of power iterations")
tf.app.flags.DEFINE_float('xi', 1e-6, "small constant for finite difference")
tf.app.flags.DEFINE_bool('stop_grad_of_1st_kl_arg', True, "stop gradient of 1st kl arg")
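
# Scale d so that each row has unit L2 norm; the max-abs pre-scaling keeps the squared sum
# from overflowing.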
def get_normalized_vector(d):
    d /= (1e-12 + tf.reduce_max(tf.abs(d), 1, keep_dims=True))
    d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), 1, keep_dims=True))
    return d
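
# Mean KL divergence KL(q || p) between the softmax distributions defined by two logit tensors.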
def kl_divergence_with_logit(q_logit, p_logit, name="kl_divergence"):
    q = tf.nn.softmax(q_logit)
    p = tf.nn.softmax(p_logit)
    kl = tf.reduce_mean(tf.reduce_sum(q * (tf.log(q) - tf.log(p)), 1))
    return tf.identity(kl, name=name)
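
# Power iteration: approximate the direction around x in which the model's output distribution
# changes most (in KL divergence), then scale it to norm epsilon. This is the VAT perturbation r_vadv.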
def generate_virtual_adversarial_perturbation(x, logit):
    d = tf.random_normal(shape=tf.shape(x))
    for _ in range(FLAGS.num_power_iterations):
        d = FLAGS.xi * get_normalized_vector(d)
        logit_d = forward(x + d, update_batch_stats=False)
        kl = kl_divergence_with_logit(logit, logit_d)
        grad = tf.gradients(kl, [d])[0]
        d = tf.stop_gradient(grad)
    return FLAGS.epsilon * get_normalized_vector(d)
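
# VAT loss: KL divergence between the prediction at x and the prediction at x + r_vadv.
# It depends only on the model outputs, not on the labels y.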
def virtual_adversarial_loss(x, logit, name="vat_loss"):
    r_vadv = generate_virtual_adversarial_perturbation(x, logit)
    if FLAGS.stop_grad_of_1st_kl_arg:
        logit = tf.stop_gradient(logit)
    loss = kl_divergence_with_logit(logit, forward(
        x + r_vadv, update_batch_stats=False))
    return tf.identity(loss, name=name)
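
# Batch normalization with hand-rolled moving averages of the batch statistics;
# update_batch_stats=False lets the extra VAT forward passes use BN without perturbing the averages.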
def batch_normalization(x, dim, is_training=True, update_batch_stats=True, name="bn"):
    mean = tf.reduce_mean(x, 0, keep_dims=True)
    var = tf.reduce_mean(tf.pow(x - mean, 2.0), 0, keep_dims=True)
    avg_mean = tf.get_variable(
        name=name + "_mean",
        shape=[1, dim],
        initializer=tf.constant_initializer(0.0),
        trainable=False
    )
    avg_var = tf.get_variable(
        name=name + "_var",
        shape=[1, dim],
        initializer=tf.constant_initializer(1.0),
        trainable=False
    )
    gamma = tf.get_variable(
        name=name + "_gamma",
        shape=[1, dim],
        initializer=tf.constant_initializer(1.0),
    )
    beta = tf.get_variable(
        name=name + "_beta",
        shape=[1, dim],
        initializer=tf.constant_initializer(0.0),
    )
    if is_training:
        avg_mean_assign_op = tf.no_op()
        avg_var_assign_op = tf.no_op()
        if update_batch_stats:
            avg_mean_assign_op = tf.assign(avg_mean, FLAGS.bn_stats_decay_factor * avg_mean
                                           + (1 - FLAGS.bn_stats_decay_factor) * mean)
            # Bessel's correction for the variance estimate; float() keeps N / (N - 1) from
            # truncating to 1 under Python 2 integer division.
            avg_var_assign_op = tf.assign(avg_var,
                                          FLAGS.bn_stats_decay_factor * avg_var
                                          + (float(FLAGS.batch_size) / (FLAGS.batch_size - 1))
                                          * (1 - FLAGS.bn_stats_decay_factor) * var)
        with tf.control_dependencies([avg_mean_assign_op, avg_var_assign_op]):
            ret = gamma * (x - mean) / tf.sqrt(1e-6 + var) + beta
    else:
        ret = gamma * (x - avg_mean) / tf.sqrt(1e-6 + avg_var) + beta
    return ret
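
# Fully connected network defined by FLAGS.layer_sizes: batch norm on every layer,
# ReLU on hidden layers, linear logits at the output.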
def forward(x, is_training=True, update_batch_stats=True):
    layer_sizes = numpy.asarray(FLAGS.layer_sizes.split('-'), numpy.int32)
    num_layers = len(layer_sizes) - 1
    h = x
    for l, (d1, d2) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        W = tf.get_variable(
            name="W_l" + str(l),
            shape=[d1, d2],
            initializer=tf.random_normal_initializer(mean=0.0, stddev=numpy.sqrt(1.0 / d1))
        )
        b = tf.get_variable(
            name="b_l" + str(l),
            shape=[d2],
            initializer=tf.constant_initializer(0.0)
        )
        lin = tf.matmul(h, W) + b
        lin = batch_normalization(
            lin, d2,
            is_training=is_training,
            update_batch_stats=update_batch_stats,
            name="BN_l" + str(l))
        if l == num_layers - 1:
            h = lin
        else:
            h = tf.nn.relu(lin)
    return h
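
# Fraction of examples whose argmax prediction matches the one-hot label.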
def accuracy(logit, y):
    pred = tf.argmax(logit, 1)
    true = tf.argmax(y, 1)
    return tf.reduce_mean(tf.to_float(tf.equal(pred, true)))
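
# Training graph: cross-entropy plus balance_factor * VAT loss, optimized with Adam under a
# manually decayed learning rate; variables are reused for the extra VAT forward passes.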
def build_training_graph():
    global_step = tf.Variable(
        name="global_step",
        dtype=tf.int32,
        initial_value=0,
        trainable=False,
    )
    layer_sizes = numpy.asarray(FLAGS.layer_sizes.split('-'), numpy.int32)
    x = tf.placeholder(tf.float32, [None, layer_sizes[0]])
    y = tf.placeholder(tf.float32, [None, layer_sizes[-1]])
    logit = forward(x)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y))
    acc = accuracy(logit, y)
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    vat_loss = virtual_adversarial_loss(x, logit)
    loss += FLAGS.balance_factor * vat_loss
    lr = tf.Variable(
        name="learning_rate",
        initial_value=FLAGS.learning_rate,
        trainable=False
    )
    lr_update_op = tf.assign(lr, lr * FLAGS.learning_rate_decay_factor)
    opt = tf.train.AdamOptimizer(learning_rate=lr)
    tvars = tf.trainable_variables()
    grads_and_vars = opt.compute_gradients(loss, tvars)
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    return x, y, loss, acc, train_op, lr_update_op, global_step
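
# Evaluation graph: same network with is_training=False, so BN uses the moving averages.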
def build_eval_graph():
    layer_sizes = numpy.asarray(FLAGS.layer_sizes.split('-'), numpy.int32)
    x = tf.placeholder(tf.float32, [None, layer_sizes[0]])
    y = tf.placeholder(tf.float32, [None, layer_sizes[-1]])
    logit = forward(x, is_training=False)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y))
    acc = accuracy(logit, y)
    return x, y, loss, acc
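
# Convert a vector of integer class labels into an (N, K) one-hot matrix.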
def to_one_hot(x, K):
    N = x.shape[0]
    x = numpy.asarray(x, numpy.int32)
    nx = numpy.zeros((N, K))
    for n in xrange(N):
        nx[n, x[n]] = 1.0
    return nx
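
# Train on the 60k MNIST training images, report test loss/accuracy on the 10k test images
# after every epoch, and decay the learning rate once per epoch.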
def main(_):
    rng = numpy.random.RandomState(FLAGS.seed)
    with tf.Graph().as_default():
        with tf.variable_scope("MNIST_NN") as scope:
            tf.set_random_seed(seed=rng.randint(1234))
            x, y, loss, acc, train_op, lr_update_op, global_step = build_training_graph()
            scope.reuse_variables()
            x_test, y_test, loss_test, acc_test = build_eval_graph()
        init = tf.initialize_all_variables()
        # sv = tf.train.Supervisor(logdir='/tmp/mydir')
        # sess = sv.prepare_or_wait_for_session(0)
        sess = tf.Session()
        sess.run(init)

        mnist = fetch_mldata('MNIST original')
        data = mnist['data'] / 255.
        target = to_one_hot(mnist['target'], 10)
        x_tr = data[:60000]
        x_ts = data[60000:]
        y_tr = target[:60000]
        y_ts = target[60000:]

        def shuffle(x, y):
            assert x.shape[0] == y.shape[0]
            rand_ix = rng.permutation(x.shape[0])
            return x[rand_ix], y[rand_ix]

        print "Training..."
        # Iterate over epochs (the original looped over FLAGS.batch_size, which only gave 100
        # epochs because both flags default to 100).
        for ep in range(FLAGS.num_epochs):
            x_tr, y_tr = shuffle(x_tr, y_tr)
            sum_loss = 0
            sum_acc = 0
            n_iter_per_epoch = 60000 // FLAGS.batch_size
            for i in range(n_iter_per_epoch):
                batch_xs = x_tr[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                batch_ys = y_tr[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                sess.run([train_op, global_step], feed_dict={x: batch_xs, y: batch_ys})
                batch_loss, batch_acc = sess.run([loss, acc], feed_dict={x: batch_xs, y: batch_ys})
                sum_loss += batch_loss
                sum_acc += batch_acc
            print "Epoch:", ep, "CE_loss_train:", sum_loss / n_iter_per_epoch, \
                "ACC_train:", sum_acc / n_iter_per_epoch
            sum_loss = 0
            sum_acc = 0
            n_iter_per_epoch = 10000 // FLAGS.batch_size
            for i in range(n_iter_per_epoch):
                batch_xs = x_ts[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                batch_ys = y_ts[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                batch_loss, batch_acc = sess.run([loss_test, acc_test],
                                                 feed_dict={x_test: batch_xs, y_test: batch_ys})
                sum_loss += batch_loss
                sum_acc += batch_acc
            print "Epoch:", ep, "CE_loss_test:", sum_loss / n_iter_per_epoch, \
                "ACC_test:", sum_acc / n_iter_per_epoch
            # Decay the learning rate and print its new value.
            print sess.run(lr_update_op)


if __name__ == "__main__":
    tf.app.run()