import os
import math

import numpy as np
import tensorflow as tf
import tflearn
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import tensorlayer as tl

tf.reset_default_graph()

# TF-Slim is a lightweight library for defining, training, and evaluating complex
# models in TensorFlow. Components of tf-slim can be freely mixed with native
# TensorFlow, as well as with other frameworks such as tf.contrib.learn.
slim = tf.contrib.slim

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Hyperparameters (LAMBDA and CENTER_LOSS_ALPHA are only used by the
# commented-out center-loss variant below; the active path trains with the
# ArcFace loss instead)
weight_decay_rate = 5e-4
LAMBDA = 0.5
CENTER_LOSS_ALPHA = 0.5
NUM_CLASSES = 10
NUM_EPOCH = 50

with tf.name_scope('input'):
    input_images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name='input_images')
    labels = tf.placeholder(tf.int64, shape=(None,), name='labels')

global_step = tf.Variable(0, trainable=False, name='global_step')

# Alternative TensorLayer implementation of a deeper backbone, kept for reference:
# def inference(x):
#     w_init_method = tf.contrib.layers.xavier_initializer(uniform=True)
#     # define the network
#     network = tl.layers.InputLayer(x, name='input')
#     network = tl.layers.Conv2d(network, n_filter=64, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv1_1')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn1')
#     network = tl.layers.PReluLayer(network, name='prelu1')
#     network = tl.layers.Conv2d(network, n_filter=64, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv1_2')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn2')
#     network = tl.layers.PReluLayer(network, name='prelu2')
#     network = tl.layers.Conv2d(network, n_filter=64, filter_size=(3, 3), strides=(2, 2), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv1_3')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn3')
#     network = tl.layers.PReluLayer(network, name='prelu3')
#     network = tl.layers.Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv2_1')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn4')
#     network = tl.layers.PReluLayer(network, name='prelu4')
#     network = tl.layers.Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv2_2')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn5')
#     network = tl.layers.PReluLayer(network, name='prelu5')
#     network = tl.layers.Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(2, 2), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv2_3')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn6')
#     network = tl.layers.PReluLayer(network, name='prelu6')
#     network = tl.layers.Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv3_1')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn7')
#     network = tl.layers.PReluLayer(network, name='prelu7')
#     network = tl.layers.Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv3_2')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn8')
#     network = tl.layers.PReluLayer(network, name='prelu8')
#     network = tl.layers.Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(2, 2), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv3_3')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn9')
#     network = tl.layers.PReluLayer(network, name='prelu9')
#     network = tl.layers.FlattenLayer(network, name='flatten')
#     feature = tl.layers.DenseLayer(network, 2, name='fc1')
#     network = tl.layers.PReluLayer(feature, name='prelu10')
#     logits = tl.layers.DenseLayer(network, 10, name='fc2')
#     return logits, feature

def inference(input_images):
    with slim.arg_scope([slim.conv2d], kernel_size=3, padding='SAME'):
        with slim.arg_scope([slim.max_pool2d], kernel_size=2):
            x = slim.conv2d(input_images, num_outputs=32, scope='conv1_1')
            x = slim.conv2d(x, num_outputs=32, scope='conv1_2')
            x = slim.max_pool2d(x, scope='pool1')
            x = slim.conv2d(x, num_outputs=64, scope='conv2_1')
            x = slim.conv2d(x, num_outputs=64, scope='conv2_2')
            x = slim.max_pool2d(x, scope='pool2')
            x = slim.conv2d(x, num_outputs=128, scope='conv3_1')
            x = slim.conv2d(x, num_outputs=128, scope='conv3_2')
            x = slim.max_pool2d(x, scope='pool3')
            x = slim.flatten(x, scope='flatten')
            # 2-D bottleneck so the learned embedding can be plotted directly
            feature = slim.fully_connected(x, num_outputs=2, activation_fn=None, scope='fc1')
            x = tflearn.prelu(feature)
            x = slim.fully_connected(x, num_outputs=10, activation_fn=None, scope='fc2')
    return x, feature
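
# Shape walk-through for a 28x28x1 MNIST input (an illustrative note; it assumes
# slim's defaults, i.e. max_pool2d with stride 2 and VALID padding):
#   conv1_*: 28x28x32 -> pool1: 14x14x32
#   conv2_*: 14x14x64 -> pool2: 7x7x64
#   conv3_*: 7x7x128  -> pool3: 3x3x128
#   flatten: 1152 -> fc1 (feature): 2 -> fc2 (logits): 10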

# Earlier softmax/center-loss training path, kept for reference:
# def build_network(input_images, labels, ratio=0.5):
#     net, features = inference(input_images)
#     with tf.name_scope('loss'):
#         '''
#         with tf.name_scope('center_loss'):
#             center_loss, centers, centers_update_op = get_center_loss(features, labels, CENTER_LOSS_ALPHA, NUM_CLASSES)
#         '''
#         with tf.name_scope('softmax_loss'):
#             softmax_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=net))
#         with tf.name_scope('total_loss'):
#             wd_loss = 0
#             for weights in tl.layers.get_variables_with_name('W_conv2d', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_decay)(weights)
#             for W in tl.layers.get_variables_with_name('resnet_v1_50/E_DenseLayer/W', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_decay)(W)
#             for weights in tl.layers.get_variables_with_name('embedding_weights', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_decay)(weights)
#             for gamma in tl.layers.get_variables_with_name('gamma', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_decay)(gamma)
#             # for beta in tl.layers.get_variables_with_name('beta', True, True):
#             #     wd_loss += tf.contrib.layers.l2_regularizer(args.weight_decay)(beta)
#             for alphas in tl.layers.get_variables_with_name('alphas', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_decay)(alphas)
#             # for bias in tl.layers.get_variables_with_name('resnet_v1_50/E_DenseLayer/b', True, True):
#             #     wd_
#             total_loss = softmax_loss + wd_loss
#             # total_loss = softmax_loss + ratio * center_loss
#     with tf.name_scope('acc'):
#         accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(net, 1), labels), tf.float32))
#     with tf.name_scope('loss/'):
#         tf.summary.scalar('CenterLoss', center_loss)
#         tf.summary.scalar('SoftmaxLoss', softmax_loss)
#         tf.summary.scalar('TotalLoss', total_loss)
#     return net, features, total_loss, accuracy

def arcface_loss(embedding, labels, out_num, w_init=None, s=64., m=0.5):
    '''
    :param embedding: the input embedding vectors
    :param labels: the input labels; shape should be e.g. (batch_size,)
    :param out_num: number of output classes
    :param w_init: initializer for the class-weight matrix
    :param s: scale factor applied to the cosine logits, default 64
    :param m: the additive angular margin, default 0.5
    :return: the final calculated logits; this output is fed directly into
        tf.nn.sparse_softmax_cross_entropy_with_logits
    '''
    cos_m = math.cos(m)
    sin_m = math.sin(m)
    # sin(pi - m) * m, used by the fallback branch below when theta + m would
    # leave [0, pi] ("issue 1" in the original InsightFace implementation)
    mm = sin_m * m
    threshold = math.cos(math.pi - m)
    with tf.variable_scope('arcface_loss'):
        # normalize the embeddings and the class weights
        embedding_norm = tf.norm(embedding, axis=1, keep_dims=True)
        embedding = tf.div(embedding, embedding_norm, name='norm_embedding')
        weights = tf.get_variable(name='embedding_weights',
                                  shape=(embedding.get_shape().as_list()[-1], out_num),
                                  initializer=w_init, dtype=tf.float32)
        weights_norm = tf.norm(weights, axis=0, keep_dims=True)
        weights = tf.div(weights, weights_norm, name='norm_weights')
        # cos(theta + m) = cos(theta)*cos(m) - sin(theta)*sin(m)
        cos_t = tf.matmul(embedding, weights, name='cos_t')
        cos_t2 = tf.square(cos_t, name='cos_2')
        sin_t2 = tf.subtract(1., cos_t2, name='sin_2')
        sin_t = tf.sqrt(sin_t2, name='sin_t')
        cos_mt = s * tf.subtract(tf.multiply(cos_t, cos_m), tf.multiply(sin_t, sin_m), name='cos_mt')
        # this condition keeps theta + m in the range [0, pi]:
        # 0 <= theta + m <= pi  <=>  -m <= theta <= pi - m, i.e. cos_t > cos(pi - m)
        cond_v = cos_t - threshold
        cond = tf.cast(tf.nn.relu(cond_v, name='if_else'), dtype=tf.bool)
        keep_val = s * (cos_t - mm)
        cos_mt_temp = tf.where(cond, cos_mt, keep_val)
        mask = tf.one_hot(labels, depth=out_num, name='one_hot_mask')
        # mask = tf.squeeze(mask, 1)
        inv_mask = tf.subtract(1., mask, name='inverse_mask')
        s_cos_t = tf.multiply(s, cos_t, name='scalar_cos_t')
        # margin-adjusted logit on the target class, plain scaled cosine elsewhere
        output = tf.add(tf.multiply(s_cos_t, inv_mask), tf.multiply(cos_mt_temp, mask), name='arcface_loss_output')
    return output
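
# Minimal sanity check of the margin identity above (an illustrative sketch,
# not part of the original gist): for a unit-norm embedding at angle theta to a
# unit-norm class weight, the adjusted target logit equals s * cos(theta + m).
_theta, _s, _m = math.pi / 4, 64., 0.5
_cos_t = math.cos(_theta)
_sin_t = math.sqrt(1. - _cos_t ** 2)
assert abs(_s * (_cos_t * math.cos(_m) - _sin_t * math.sin(_m)) -
           _s * math.cos(_theta + _m)) < 1e-9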

net, feature = inference(input_images)
w_init_method = tf.contrib.layers.xavier_initializer(uniform=False)
# Note: the 10-D fc2 output (`net`) is used as the embedding for the ArcFace
# head here; the 2-D `feature` from fc1 is kept only for visualization.
logit = arcface_loss(embedding=net, labels=labels, w_init=w_init_method, out_num=10)
inference_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logit, labels=labels))
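
# Written out, this cross-entropy is the ArcFace objective for a sample with
# true class y (an illustrative note):
#   L = -log( exp(s*cos(theta_y + m))
#             / (exp(s*cos(theta_y + m)) + sum_{j != y} exp(s*cos(theta_j))) )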

# L2 weight decay collected over every learnable scope of the network
wd_loss = 0
l2_regularizer = tf.contrib.layers.l2_regularizer(weight_decay_rate)
for scope_name in ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2',
                   'conv3_1', 'conv3_2', 'flatten', 'fc1', 'fc2']:
    for weights in tl.layers.get_variables_with_name(scope_name, True, True):
        wd_loss += l2_regularizer(weights)
total_loss = inference_loss + wd_loss
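
# An alternative worth noting (a sketch under assumptions, not what this gist
# does): attach the regularizer once in the slim arg_scope and let TensorFlow
# collect it, e.g.
#     with slim.arg_scope([slim.conv2d, slim.fully_connected],
#                         weights_regularizer=slim.l2_regularizer(weight_decay_rate)):
#         ...  # build the network
#     wd_loss = tf.losses.get_regularization_loss()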

# Prepare data
mnist = input_data.read_data_sets('/tmp/mnist', reshape=False)

# Optimizer
optimizer = tf.train.AdamOptimizer(0.001)
train_op = optimizer.minimize(total_loss, global_step=global_step)

# Session and summaries (merge_all() returns None here because no summaries
# are ever defined; the FileWriter still records the graph)
summary_op = tf.summary.merge_all()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
writer = tf.summary.FileWriter('/tmp/mnist_log', sess.graph)

# Training with the ArcFace loss
mean_data = np.mean(mnist.train.images, axis=0)
step = sess.run(global_step)
pred = tf.nn.softmax(logit)
acc = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred, axis=1), labels), dtype=tf.float32))

if not os.path.isdir('snapshots'):  # savefig below fails without the directory
    os.makedirs('snapshots')
f = plt.figure(figsize=(16, 9))

while step <= 15000:
    batch_images, batch_labels = mnist.train.next_batch(128)
    if batch_images is None or batch_labels is None:
        exit()
    # run the train op together with the losses and accuracy for logging
    _, in_loss, total_loss_res, wd, train_acc = sess.run(
        [train_op, inference_loss, total_loss, wd_loss, acc],
        feed_dict={
            input_images: batch_images - mean_data,
            labels: batch_labels
        })
    step += 1
    print("Training step:", step)
    # writer.add_summary(summary_str, global_step=step)
    if step % 200 == 0:
        vali_image = mnist.validation.images - mean_data
        vali_acc = sess.run(
            acc,
            feed_dict={
                input_images: vali_image,
                labels: mnist.validation.labels
            })
        print("step: {}, train_acc:{:.4f}, vali_acc:{:.4f}".format(step, train_acc, vali_acc))
        feat = sess.run(feature, feed_dict={input_images: mnist.train.images[:5000] - mean_data})
        print(feat)
        # Draw the 2-D feature scatter plot, one color per digit class
        test_labels = mnist.train.labels[:5000]
        plt.clf()
        c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
             '#ff00ff', '#990000', '#999900', '#009900', '#009999']
        for i in range(10):
            plt.plot(feat[test_labels == i, 0].flatten(), feat[test_labels == i, 1].flatten(), '.', c=c[i])
        plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='lower right')
        plt.title('Visualization of MNIST train_data at step: {}\n'.format(step), fontsize=20)
        plt.grid()
        plt.savefig('snapshots/{}.png'.format(step))
        plt.draw()
        plt.pause(0.01)