@pannous
Last active July 30, 2017 16:15
Implementation of DenseNet (Densely Connected Convolutional Networks, https://arxiv.org/abs/1608.06993) in TensorFlow.
#!/usr/bin/python
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
mnist = read_data_sets("/tmp/data/", one_hot=True)
force_gpu = False
debug = False # histogram_summary ...
# _cpu='/cpu:0'
default_learning_rate=0.01
decay_step = 3000
decay_size = 0.95
# dropout = 0.6
dropout = None  # None to disable dropout; a non-zero number enables dropout and sets the keep rate
batch_size=64
_cpu='/cpu:0'
tensorboard_logs = '/tmp/tensorboard-logs/'
# $(sleep 5; open http://0.0.0.0:6006) & tensorboard --debug --logdir=/tmp/tensorboard-logs/
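
# closest_unitary() is referenced by net.dense() below but is not defined anywhere in this gist.
# A minimal sketch, assuming the intended behavior is "project the matrix onto the nearest
# (semi-)orthogonal matrix" via SVD (A = U S Vh -> U Vh), a common orthogonal weight initialization:
def closest_unitary(A):
    U, _, Vh = np.linalg.svd(A, full_matrices=False)
    return np.dot(U, Vh).astype(np.float32)  # float32 so the tf.matmul dtypes match
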
class net():
    def input_width(self, data):
        return 28 * 28

    def __init__(self, model, data, name=0, learning_rate=default_learning_rate, batch_size=batch_size):
        # device = '/GPU:0' if gpu else '/cpu:0'
        # device = None  # auto
        # print("Using device ", device)
        # with tf.device(device):
        if True:
            self.session = sess = session = tf.Session()
            # self.session = sess = session = tf.Session(config=tf.ConfigProto(log_device_placement=True))
            self.model = model
            self.data = data  # assigned to self.x=net.input via train
            self.batch_size = batch_size
            self.layers = []
            self.last_width = self.input_width(data)
            self.learning_rate = learning_rate
            # if not name: name = model.__name__
            # if name and os.path.exists(name):
            #     return self.load_model(name)
            self.generate_model(model)

    def generate_model(self, model, name=''):
        if not model: return self
        with tf.name_scope('state'):
            self.keep_prob = tf.placeholder(tf.float32)  # 1 for testing! else 1 - dropout
            self.train_phase = tf.placeholder(tf.bool, name='train_phase')
            self.global_step = tf.Variable(0)  # don't set or feed global_step manually; it is passed to minimize() below so TensorFlow increments it
        with tf.name_scope('data'):
            n_input = 28 * 28
            n_classes = 10
            self.x = x = self.input = tf.placeholder(tf.float32, [None, n_input])
            self.last_layer = x
            self.y = y = self.target = tf.placeholder(tf.float32, [None, n_classes])
            if not force_gpu: tf.image_summary("mnist", tf.reshape(self.x, [-1, 28, 28, 1], "mnist_images"))
        with tf.name_scope('model'):
            model(self)
            if self.last_width != n_classes: self.classifier()  # 10 classes auto

    def add(self, layer):
        self.layers.append(layer)
        self.last_layer = layer
        self.last_shape = layer.get_shape()

    def reshape(self, shape):
        self.last_layer = tf.reshape(self.last_layer, shape)
        self.last_shape = shape
        self.last_width = shape[-1]

    def batchnorm(self):
        from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
        with tf.name_scope('batchnorm') as scope:
            input = self.last_layer
            # mean, var = tf.nn.moments(input, axes=[0, 1, 2])
            # self.batch_norm = tf.nn.batch_normalization(input, mean, var, offset=1, scale=1, variance_epsilon=1e-6)
            # self.last_layer = self.batch_norm
            train_op = batch_norm(input, is_training=True, center=False, updates_collections=None, scope=scope)
            test_op = batch_norm(input, is_training=False, updates_collections=None, center=False, scope=scope, reuse=True)
            self.add(tf.cond(self.train_phase, lambda: train_op, lambda: test_op))

    # Fully connected layer
    def dense(self, hidden=1024, depth=1, act=tf.nn.tanh, dropout=False, parent=-1):
        if parent == -1: parent = self.last_layer
        shape = self.last_layer.get_shape()
        if shape and len(shape) > 2:
            self.last_width = int(shape[1] * shape[2] * shape[3])
            print("reshaping ", shape, "to", self.last_width)
            parent = tf.reshape(parent, [-1, self.last_width])
        width = hidden
        while depth > 0:
            with tf.name_scope('Dense_{:d}'.format(hidden)) as scope:
                print("Dense ", self.last_width, width)
                nr = len(self.layers)
                if self.last_width == width:
                    # near-orthogonal initialization for square layers, see closest_unitary() above
                    U = closest_unitary(np.random.rand(self.last_width, width) / (self.last_width + width))
                    weights = tf.Variable(U, name="weights_dense_" + str(nr))
                else:
                    weights = tf.Variable(tf.random_uniform([self.last_width, width], minval=-1. / width, maxval=1. / width), name="weights_dense")
                bias = tf.Variable(tf.random_uniform([width], minval=-1. / width, maxval=1. / width), name="bias_dense")
                dense1 = tf.matmul(parent, weights, name='dense_' + str(nr)) + bias
                tf.histogram_summary('dense_' + str(nr), dense1)
                tf.histogram_summary('weights_' + str(nr), weights)
                tf.histogram_summary('bias_' + str(nr), bias)
                tf.histogram_summary('dense_' + str(nr) + '/sparsity', tf.nn.zero_fraction(dense1))
                tf.histogram_summary('weights_' + str(nr) + '/sparsity', tf.nn.zero_fraction(weights))
                if act: dense1 = act(dense1)
                # if norm: dense1 = self.norm(dense1, lsize=1)  # SHAPE!
                if dropout: dense1 = tf.nn.dropout(dense1, self.keep_prob)
                self.layers.append(dense1)
                self.last_layer = parent = dense1
                self.last_width = width
            depth = depth - 1
        self.last_shape = [-1, width]  # dense

    # Convolution Layer
    def conv(self, shape, act=tf.nn.relu, pool=True, dropout=False, norm=True, name=None):  # True why dropout bad in tensorflow??
        with tf.name_scope('conv'):
            print("input shape ", self.last_shape)
            print("conv shape ", shape)
            width = shape[-1]
            # filters = tf.Variable(tf.random_uniform(shape, minval=-1. / width, maxval=1. / width), name="filters")
            filters = tf.Variable(tf.random_normal(shape))  # positive weights help with image classification
            _bias = tf.Variable(tf.random_normal([shape[-1]]))
            conv1 = tf.nn.bias_add(tf.nn.conv2d(self.last_layer, filter=filters, strides=[1, 1, 1, 1], padding='SAME'), _bias)
            if debug: tf.histogram_summary('conv_' + str(len(self.layers)), conv1)
            if act: conv1 = act(conv1)
            if pool: conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
            if norm: conv1 = tf.nn.lrn(conv1, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
            if debug: tf.histogram_summary('norm_' + str(len(self.layers)), conv1)
            if dropout: conv1 = tf.nn.dropout(conv1, self.keep_prob)
            print("output shape ", conv1.get_shape())
            self.add(conv1)

    def classifier(self, classes=10):
        """ Define loss and optimizer """
        with tf.name_scope('prediction'):
            if self.last_width != classes:
                # print("Automatically adding dense prediction")
                self.dense(hidden=classes, act=False, dropout=False)
        with tf.name_scope('classifier'):
            y_ = self.target
            manual_cost_formula = False  # True
            if manual_cost_formula:
                # prediction = y = self.last_layer = tf.nn.softmax(self.last_layer)
                # self.cost = cross_entropy = -tf.reduce_sum(y_ * tf.log(y + 1e-10))  # against NaN!
                prediction = y = tf.nn.log_softmax(self.last_layer)
                self.cost = cross_entropy = -tf.reduce_sum(y_ * y)
            elif classes > 100:
                print("using sampled_softmax_loss")
                y = prediction = self.last_layer
                self.cost = tf.reduce_mean(tf.nn.sampled_softmax_loss(y, y_))  # for big vocab; not used for MNIST (classes=10)
            else:
                y = prediction = self.last_layer
                self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))  # prediction, target
            with tf.device(_cpu): tf.scalar_summary('cost', self.cost)
            # self.cost = tf.Print(self.cost, [self.cost], "debug cost : ")
            # learning_scheme = self.learning_rate
            learning_scheme = tf.train.exponential_decay(self.learning_rate, self.global_step, decay_step, decay_size)
            self.optimizer = tf.train.AdamOptimizer(learning_scheme).minimize(self.cost, global_step=self.global_step)  # pass global_step so the exponential decay actually advances
            # Evaluate model
            correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(self.target, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
            if not force_gpu: tf.scalar_summary('accuracy', self.accuracy)

    def addLayer(self, nChannels, nOutChannels, dropout):
        ident = self.last_layer
        self.batchnorm()
        # self.add(tf.nn.relu(ident))  # nChannels ?
        # norm is only tested for truthiness inside conv(), so passing tf.nn.relu simply enables LRN
        self.conv([3, 3, nChannels, nOutChannels], pool=False, dropout=dropout, norm=tf.nn.relu)  # None
        # dense connectivity: concatenate the block input with the new feature maps along the channel axis
        concat = tf.concat(3, [ident, self.last_layer])
        print("concat ", concat.get_shape())
        self.add(concat)

    def addTransition(self, nChannels, nOutChannels, dropout):
        self.batchnorm()
        self.add(tf.nn.relu(self.last_layer))
        self.conv([1, 1, nChannels, nOutChannels], pool=True, dropout=dropout, norm=None)  # pool (2, 2)

    def buildDenseConv(self):
        depth = 3 * 1 + 4
        if (depth - 4) % 3: raise Exception("Depth must be 3N + 4! (4,7,10,...)")
        N = (depth - 4) // 3  # layers in each dense block
        # channels before entering the first dense block
        # set it to be comparable with the growth rate?
        nChannels = 16
        growthRate = 12
        self.conv([3, 3, 1, nChannels])  # prepare 16 filters with 3x3 view -> 28x28 just as input
        for i in range(N):  # 1st block
            self.addLayer(nChannels, growthRate, dropout)
            nChannels = nChannels + growthRate
        self.addTransition(nChannels, nChannels, dropout)
        for i in range(N):  # 2nd block
            self.addLayer(nChannels, growthRate, dropout)
            nChannels = nChannels + growthRate
        self.addTransition(nChannels, nChannels, dropout)
        for i in range(N):  # 3rd block
            self.addLayer(nChannels, growthRate, dropout)
            nChannels = nChannels + growthRate
        # no transition, but densely connected layers.
        self.batchnorm()
        self.add(tf.nn.relu(self.last_layer))
        # self.add(tf.nn.max_pool(self.last_layer, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding='SAME'))
        # self.reshape([-1, nChannels * 4 * 4])
        self.add(tf.nn.max_pool(self.last_layer, ksize=[1, 4, 4, 1], strides=[1, 2, 2, 1], padding='SAME'))
        self.reshape([-1, nChannels * 4])  # the final feature map is 2x2, so 4 * nChannels values per image

    def next_batch(self, batch_size=10):
        return self.data.train.next_batch(batch_size)

    def train(self, steps=-1, dropout=None, display_step=10, test_step=200):  # epochs=-1,
        steps = 9999999 if steps == -1 else steps
        session = self.session
        # with tf.device(_cpu):
        # import tensorflow.contrib.layers as layers
        # t = tf.verify_tensor_all_finite(t, msg)
        tf.add_check_numerics_ops()
        self.summaries = tf.merge_all_summaries()
        self.summary_writer = tf.train.SummaryWriter(tensorboard_logs, session.graph)
        if not dropout: dropout = 1.  # keep all
        x = self.x
        y = self.y
        keep_prob = self.keep_prob
        session.run([tf.initialize_all_variables()])
        step = 1  # show first
        while step < steps:
            # print("step %d \r" % step)  # end=' ')
            batch_xs, batch_ys = self.next_batch(self.batch_size)
            # tf.train.shuffle_batch_join(example_list, batch_size, capacity=min_queue_size + batch_size * 16, min_queue_size)
            # Fit training using batch data
            feed_dict = {x: batch_xs, y: batch_ys, keep_prob: dropout, self.train_phase: True}
            loss, _ = session.run([self.cost, self.optimizer], feed_dict=feed_dict)
            if step % test_step == 0: self.test(step)
            if step % display_step == 0:
                # Calculate batch accuracy, loss
                feed = {x: batch_xs, y: batch_ys, keep_prob: 1., self.train_phase: False}
                acc = session.run(self.accuracy, feed_dict=feed)
                # acc, summary = session.run([self.accuracy, self.summaries], feed_dict=feed)
                # self.summary_writer.add_summary(summary, step)  # only test summaries for smoother curve
                print("\rStep {:d} Loss= {:.6f} Accuracy= {:.3f}".format(step, loss, acc), end=' ')
                if str(loss) == "nan": return print("\nLoss gradient explosion, exiting!!!")  # restore!
            step += 1
        print("\nOptimization Finished!")
        self.test(step, number=10000)  # final test

    def inputs(self, data):
        self.inputs, self.labels = load_data()  # note: load_data() is not defined in this gist; this helper is unused

    def test(self, step, number=400):  # 256
        session = sess = self.session
        run_metadata = tf.RunMetadata()
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        # Calculate accuracy for the first `number` MNIST test images
        test_labels = self.data.test.labels[:number]
        test_images = self.data.test.images[:number]
        feed_dict = {self.x: test_images, self.y: test_labels, self.keep_prob: 1., self.train_phase: False}
        accuracy, summary = self.session.run([self.accuracy, self.summaries], feed_dict=feed_dict)
        # accuracy, summary = session.run([self.accuracy, self.summaries], feed_dict, run_options, run_metadata)
        print('\t' * 3 + "Test Accuracy:", accuracy)
        # self.summary_writer.add_run_metadata(run_metadata, 'step #%03d' % step)
        self.summary_writer.add_summary(summary, global_step=step)

def dense(net):  # best with lr ~0.001
    # type: (layer.net) -> None
    # net.batchnorm()  # start lower, else no effect
    # net.dense(400, act=None)  # ~95% we can do better:
    net.dense(400, act=tf.nn.tanh)  # 0.996 YAY only 0.985 on full set, Step 5000 flat
    return  # 0.957% without any model!!

def alex(net):
    # type: (layer.net) -> None
    print("Building Alex-net")
    net.reshape(shape=[-1, 28, 28, 1])  # Reshape input pictures
    # net.batchnorm()
    net.conv([3, 3, 1, 64])
    net.conv([3, 3, 64, 128])
    net.conv([3, 3, 128, 256])
    net.dense(1024, act=tf.nn.relu)
    net.dense(1024, act=tf.nn.relu)
    # OH, it does converge!!

def denseConv(net):
    # type: (layer.net) -> None
    print("Building dense-net")
    net.reshape(shape=[-1, 28, 28, 1])  # Reshape input picture
    # net.batchnorm()
    # net.conv([3, 3, 1, 64])
    net.buildDenseConv()
    net.classifier()  # 10 classes auto

# net = net(dense, data=mnist, learning_rate=0.01)  # ,'mnist' baseline
# _net = net(alex, data=mnist, learning_rate=0.001)  # ,'mnist'
_net = net(model=denseConv, data=mnist, learning_rate=0.001)
# _net.train(50000, dropout=keep_rate, display_step=1, test_step=1)  # debug
_net.train(50000, dropout=dropout, display_step=1, test_step=20)  # gpu
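
# Channel bookkeeping of the DenseNet built above (a rough trace, assuming the defaults
# depth = 7, nChannels = 16, growthRate = 12, hence N = 1 layer per dense block):
# each addLayer() concatenates growthRate new feature maps onto its input, so the channel
# count grows 16 -> 28 -> 40 -> 52 across the three blocks; the pooling in the initial conv,
# the two transitions, and the final 4x4 max-pool shrink the 28x28 input to 2x2 maps, so the
# final reshape flattens nChannels * 4 = 208 values per image before the 10-way classifier.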