import os
import numpy as np
import tensorflow as tf
import tflearn
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import tensorlayer as tl
import math
tf.reset_default_graph()
# TF-Slim is a lightweight library for defining, training and evaluating complex models in TensorFlow.
# Components of tf-slim can be freely mixed with native tensorflow,
# as well as other frameworks, such as tf.contrib.learn.
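# (The inference() function below relies on slim.arg_scope to share conv defaults
# such as kernel_size and padding across layers.)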
weight_decay_rate = 5e-4
slim = tf.contrib.slim
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Hyperparameters
LAMBDA = 0.5
CENTER_LOSS_ALPHA = 0.5
NUM_CLASSES = 10
NUM_EPOCH = 50
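# Note: LAMBDA and CENTER_LOSS_ALPHA appear to be leftovers from a center-loss
# variant (see the commented-out build_network below); the ArcFace training path
# in this script does not use them.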
with tf.name_scope('input'):
    input_images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name='input_images')
    labels = tf.placeholder(tf.int64, shape=(None), name='labels')

global_step = tf.Variable(0, trainable=False, name='global_step')
# def inference(x):
#     w_init_method = tf.contrib.layers.xavier_initializer(uniform=True)
#     # define the network
#     network = tl.layers.InputLayer(x, name='input')
#     network = tl.layers.Conv2d(network, n_filter=64, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv1_1')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn1')
#     network = tl.layers.PReluLayer(network, name='prelu1')
#     network = tl.layers.Conv2d(network, n_filter=64, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv1_2')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn2')
#     network = tl.layers.PReluLayer(network, name='prelu2')
#     network = tl.layers.Conv2d(network, n_filter=64, filter_size=(3, 3), strides=(2, 2), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv1_3')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn3')
#     network = tl.layers.PReluLayer(network, name='prelu3')
#     network = tl.layers.Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv2_1')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn4')
#     network = tl.layers.PReluLayer(network, name='prelu4')
#     network = tl.layers.Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv2_2')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn5')
#     network = tl.layers.PReluLayer(network, name='prelu5')
#     network = tl.layers.Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(2, 2), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv2_3')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn6')
#     network = tl.layers.PReluLayer(network, name='prelu6')
#     network = tl.layers.Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv3_1')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn7')
#     network = tl.layers.PReluLayer(network, name='prelu7')
#     network = tl.layers.Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv3_2')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn8')
#     network = tl.layers.PReluLayer(network, name='prelu8')
#     network = tl.layers.Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(2, 2), padding='SAME', act=None,
#                                W_init=w_init_method, name='conv3_3')
#     network = tl.layers.BatchNormLayer(network, act=tf.identity, is_train=True, name='bn9')
#     network = tl.layers.PReluLayer(network, name='prelu9')
#     network = tl.layers.FlattenLayer(network, name='flatten')
#     feature = tl.layers.DenseLayer(network, 2, name='fc1')
#     network = tl.layers.PReluLayer(feature, name='prelu10')
#     feature = tl.layers.DenseLayer(network, 10, name='fc2')
#     return network, feature
def inference(input_images):
    with slim.arg_scope([slim.conv2d], kernel_size=3, padding='SAME'):
        with slim.arg_scope([slim.max_pool2d], kernel_size=2):
            x = slim.conv2d(input_images, num_outputs=32, scope='conv1_1')
            x = slim.conv2d(x, num_outputs=32, scope='conv1_2')
            x = slim.max_pool2d(x, scope='pool1')
            x = slim.conv2d(x, num_outputs=64, scope='conv2_1')
            x = slim.conv2d(x, num_outputs=64, scope='conv2_2')
            x = slim.max_pool2d(x, scope='pool2')
            x = slim.conv2d(x, num_outputs=128, scope='conv3_1')
            x = slim.conv2d(x, num_outputs=128, scope='conv3_2')
            x = slim.max_pool2d(x, scope='pool3')
            x = slim.flatten(x, scope='flatten')
            feature = slim.fully_connected(x, num_outputs=2, activation_fn=None, scope='fc1')
            x = tflearn.prelu(feature)
            x = slim.fully_connected(x, num_outputs=10, activation_fn=None, scope='fc2')
    return x, feature
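# Note: fc1 produces a 2-D embedding, which the training loop below scatter-plots
# directly; fc2 produces the 10-way class logits.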
# def build_network(input_images, labels, ratio=0.5):
#     net, features = inference(input_images)
#     with tf.name_scope('loss'):
#         '''
#         with tf.name_scope('center_loss'):
#             center_loss, centers, centers_update_op = get_center_loss(features, labels, CENTER_LOSS_ALPHA, NUM_CLASSES)
#         '''
#         with tf.name_scope('softmax_loss'):
#             softmax_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
#         with tf.name_scope('total_loss'):
#             wd_loss = 0
#             for weights in tl.layers.get_variables_with_name('W_conv2d', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(weights)
#             for W in tl.layers.get_variables_with_name('resnet_v1_50/E_DenseLayer/W', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(W)
#             for weights in tl.layers.get_variables_with_name('embedding_weights', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(weights)
#             for gamma in tl.layers.get_variables_with_name('gamma', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(gamma)
#             # for beta in tl.layers.get_variables_with_name('beta', True, True):
#             #     wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(beta)
#             for alphas in tl.layers.get_variables_with_name('alphas', True, True):
#                 wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(alphas)
#             # for bias in tl.layers.get_variables_with_name('resnet_v1_50/E_DenseLayer/b', True, True):
#             #     wd_
#             total_loss = softmax_loss + wd_loss
#             # total_loss = softmax_loss + ratio * center_loss
#     with tf.name_scope('acc'):
#         accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))
#     with tf.name_scope('loss/'):
#         tf.summary.scalar('CenterLoss', center_loss)
#         tf.summary.scalar('SoftmaxLoss', softmax_loss)
#         tf.summary.scalar('TotalLoss', total_loss)
#     return net, features, total_loss, accuracy
def arcface_loss(embedding, labels, out_num, w_init=None, s=64., m=0.5):
    '''
    :param embedding: the input embedding vectors
    :param labels: the input labels, e.g. shape (batch_size,)
    :param s: the scale factor, default is 64
    :param out_num: the number of output classes
    :param m: the margin value, default is 0.5
    :return: the final calculated output; this output is fed directly into tf.nn.softmax (or a softmax cross-entropy loss)
    '''
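    # The ops below compute, per sample:
    #   target class logit:  s * cos(theta + m) = s * (cos(theta) * cos(m) - sin(theta) * sin(m))
    #   other class logits:  s * cos(theta)
    # where theta is the angle between the L2-normalized embedding and the class
    # weight vector. If theta + m would exceed pi, the target logit falls back to
    # s * (cos(theta) - mm), with mm = sin(m) * m, to keep the logit monotonic in theta.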
    cos_m = math.cos(m)
    sin_m = math.sin(m)
    mm = sin_m * m  # issue 1
    threshold = math.cos(math.pi - m)
    with tf.variable_scope('arcface_loss'):
        # inputs and weights norm
        embedding_norm = tf.norm(embedding, axis=1, keep_dims=True)
        embedding = tf.div(embedding, embedding_norm, name='norm_embedding')
        weights = tf.get_variable(name='embedding_weights', shape=(embedding.get_shape().as_list()[-1], out_num),
                                  initializer=w_init, dtype=tf.float32)
        weights_norm = tf.norm(weights, axis=0, keep_dims=True)
        weights = tf.div(weights, weights_norm, name='norm_weights')
        # cos(theta + m)
        cos_t = tf.matmul(embedding, weights, name='cos_t')
        cos_t2 = tf.square(cos_t, name='cos_2')
        sin_t2 = tf.subtract(1., cos_t2, name='sin_2')
        sin_t = tf.sqrt(sin_t2, name='sin_t')
        cos_mt = s * tf.subtract(tf.multiply(cos_t, cos_m), tf.multiply(sin_t, sin_m), name='cos_mt')
        # this condition keeps theta + m in the range [0, pi]:
        #    0 <= theta + m <= pi
        #   -m <= theta <= pi - m
        cond_v = cos_t - threshold
        cond = tf.cast(tf.nn.relu(cond_v, name='if_else'), dtype=tf.bool)
        keep_val = s * (cos_t - mm)
        cos_mt_temp = tf.where(cond, cos_mt, keep_val)
        mask = tf.one_hot(labels, depth=out_num, name='one_hot_mask')
        # mask = tf.squeeze(mask, 1)
        inv_mask = tf.subtract(1., mask, name='inverse_mask')
        s_cos_t = tf.multiply(s, cos_t, name='scalar_cos_t')
        output = tf.add(tf.multiply(s_cos_t, inv_mask), tf.multiply(cos_mt_temp, mask), name='arcface_loss_output')
    return output
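# Build the training graph. Note that the ArcFace margin is applied to the 10-way
# fc2 output (`net`), not to the 2-D `feature` from fc1; the 2-D feature is used
# only for the visualization in the training loop below.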
net, feature = inference(input_images)
w_init_method = tf.contrib.layers.xavier_initializer(uniform=False)
logit = arcface_loss(embedding=net, labels=labels, w_init=w_init_method, out_num=10)
inference_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logit, labels=labels))
wd_loss = 0
for scope_name in ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2',
                   'conv3_1', 'conv3_2', 'flatten', 'fc1', 'fc2']:
    for weights in tl.layers.get_variables_with_name(scope_name, True, True):
        wd_loss += tf.contrib.layers.l2_regularizer(weight_decay_rate)(weights)
total_loss = inference_loss + wd_loss
# Prepare data
mnist = input_data.read_data_sets('/tmp/mnist', reshape=False)
# Optimizer
optimizer = tf.train.AdamOptimizer(0.001)
train_op = optimizer.minimize(total_loss, global_step=global_step)
# Session and Summary
summary_op = tf.summary.merge_all()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
writer = tf.summary.FileWriter('/tmp/mnist_log', sess.graph)
# Training loop
mean_data = np.mean(mnist.train.images, axis=0)
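# The per-pixel mean image is subtracted from every batch (training, validation,
# and the visualization subset) as a simple form of input normalization.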
step = sess.run(global_step)
pred = tf.nn.softmax(logit)
acc = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred, axis=1), labels), dtype=tf.float32))
f = plt.figure(figsize=(16,9))
while step <= 15000:
    batch_images, batch_labels = mnist.train.next_batch(128)
    if batch_images is None or batch_labels is None:
        exit()

    # Fetch order matches the unpacked names; the train op's return value is discarded.
    _, total_loss_res, in_loss, wd, train_acc = sess.run(
        [train_op, total_loss, inference_loss, wd_loss, acc],
        feed_dict={
            input_images: batch_images - mean_data,
            labels: batch_labels
        })
    step += 1
    print("Training step: ", step)
    # writer.add_summary(summary_str, global_step=step)

    # if True:
    if step % 200 == 0:
        vali_image = mnist.validation.images - mean_data
        vali_acc = sess.run(
            acc,
            feed_dict={
                input_images: vali_image,
                labels: mnist.validation.labels
            })
        print(("step: {}, train_acc:{:.4f}, vali_acc:{:.4f}".
               format(step, train_acc, vali_acc)))

        feed_dict = {input_images: mnist.train.images[:5000] - mean_data}
        feat = sess.run(feature, feed_dict=feed_dict)
        print(feat)

        # Draw plot of the 2-D features
        test_labels = mnist.train.labels[:5000]
        plt.clf()
        c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
             '#ff00ff', '#990000', '#999900', '#009900', '#009999']
        ax = plt.gca()
        for i in range(10):
            plt.plot(feat[test_labels == i, 0].flatten(), feat[test_labels == i, 1].flatten(), '.', c=c[i])
        plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc="lower right")
        plt.title('Visualization of MNIST train_data at step: {}\n'.format(step), fontsize=20)
        plt.grid()
        plt.savefig('snapshots/{}.png'.format(step))
        plt.draw()
        plt.pause(0.01)