Last active: September 30, 2016 23:07
Save LaurentMazare/0cc6318acc8eba41752a4c2ae392071f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
def unpickle(file):
    """Load one pickled CIFAR batch file and normalise its image rows.

    Args:
        file: path of the pickle file to read (name kept for backward
            compatibility even though it shadows a Python 2 builtin).

    Returns:
        The unpickled object. When it contains a ``'data'`` entry, that entry
        is replaced by a float array of shape (N, 32*32*3) whose rows hold the
        pixels in height-width-channel order, divided by 256.

    Security: pickle can execute arbitrary code while loading, so only call
    this on trusted files.
    """
    try:
        import cPickle as pickle  # Python 2
        load_kwargs = {}
    except ImportError:
        import pickle  # Python 3
        # Batches pickled under Python 2 hold byte strings; 'latin1' maps
        # them one-to-one onto str so keys such as 'data' stay usable.
        load_kwargs = {'encoding': 'latin1'}
    # `with` closes the handle even if unpickling raises.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, **load_kwargs)
    if 'data' in batch:
        # Each row is treated as channel-planar (3, 32, 32); transpose to
        # (32, 32, 3) so the flattened row interleaves R, G, B per pixel.
        images = batch['data'].reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
        batch['data'] = images.reshape(-1, 32 * 32 * 3) / 256.
    return batch
def load_data_one(f):
    """Read a single batch file and report how many samples it holds.

    Args:
        f: path of the batch file.

    Returns:
        ``(data, labels)``: the image rows as normalised by ``unpickle`` and
        the file's ``'labels'`` entry, unchanged.
    """
    batch = unpickle(f)
    data = batch['data']
    labels = batch['labels']
    # A parenthesised single-argument print behaves the same on Python 2
    # and 3 (the bare print statement is a SyntaxError on Python 3).
    print("Loading %s: %d" % (f, len(data)))
    return data, labels
def load_data(files, data_dir, label_count):
    """Load several batch files, concatenate them, and one-hot the labels.

    Args:
        files: non-empty sequence of batch file names, relative to data_dir.
        data_dir: directory containing the batch files.
        label_count: number of classes; sets the width of each one-hot row.

    Returns:
        ``(data, labels)``. ``data`` stacks every file's image rows in the
        order given. ``labels`` is a float array of shape (N, label_count)
        with 1.0 in the column of each sample's class. A label outside
        ``range(label_count)`` yields an all-zero row.

    Raises:
        ValueError: if ``files`` is empty.
    """
    import os

    if not files:
        raise ValueError("load_data needs at least one batch file")
    data_parts, label_parts = [], []
    for f in files:
        data_n, labels_n = load_data_one(os.path.join(data_dir, f))
        data_parts.append(data_n)
        label_parts.append(np.asarray(labels_n))
    # One concatenate at the end instead of np.append per file, which
    # recopied the whole accumulated array on every iteration.
    data = np.concatenate(data_parts, axis=0)
    labels = np.concatenate(label_parts, axis=0)
    # Vectorised one-hot: compare every label against every class index.
    one_hot = (labels[:, np.newaxis] == np.arange(label_count)).astype(float)
    return data, one_hot
def weight_variable(shape, stddev=0.01):
    """Create a trainable weight variable drawn from a truncated normal.

    Args:
        shape: variable shape, e.g. [k, k, in_ch, out_ch] for a convolution
            kernel or [in_dim, out_dim] for the final dense layer.
        stddev: standard deviation of the initial values. Defaults to the
            previously hard-coded 0.01, so existing callers are unaffected.

    Returns:
        A ``tf.Variable`` initialised from ``tf.truncated_normal``.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=stddev))
def bias_variable(shape, value=0.01):
    """Create a trainable bias variable filled with a constant.

    Args:
        shape: variable shape, typically [out_channels] or [label_count].
        value: initial value of every element. Defaults to the previously
            hard-coded 0.01, so existing callers are unaffected.

    Returns:
        A ``tf.Variable`` initialised from ``tf.constant``.
    """
    return tf.Variable(tf.constant(value, shape=shape))
def conv2d(input, in_features, out_features, kernel_size, stride):
    """Square 2-D convolution followed by a learned per-channel bias.

    Args:
        input: 4-D image tensor, expected NHWC (as produced by the
            [-1, 32, 32, 3] reshape in run_model).
        in_features: number of input channels.
        out_features: number of output channels (filters).
        kernel_size: side length of the square kernel.
        stride: step applied in both spatial dimensions.

    Returns:
        The 'SAME'-padded convolution output with the bias added.
    """
    # Kernel is created before the bias, matching the original creation order.
    kernel = weight_variable([kernel_size, kernel_size, in_features, out_features])
    bias = bias_variable([out_features])
    strides = [1, stride, stride, 1]
    convolved = tf.nn.conv2d(input, kernel, strides, padding='SAME')
    return convolved + bias
def basic_block(input, in_features, out_features, stride, is_training):
    """Residual block: conv-BN-ReLU-conv-BN on the main path, plus a shortcut.

    Args:
        input: NHWC tensor with ``in_features`` channels.
        in_features: channel count of ``input``.
        out_features: channel count of the block output; must be at least
            ``in_features`` because the shortcut widens by zero padding.
        stride: spatial stride of the first convolution and of the
            shortcut's average pooling.
        is_training: boolean tensor forwarded to batch normalisation.

    Returns:
        Tensor with ``out_features`` channels, downsampled by ``stride``.

    Raises:
        ValueError: if ``out_features < in_features``.

    NOTE: with tf.contrib.layers.batch_norm's default ``updates_collections``,
    the moving mean/variance updates are only registered in
    tf.GraphKeys.UPDATE_OPS; the training op must depend on them.
    """
    if out_features < in_features:
        raise ValueError("basic_block cannot narrow %d -> %d channels"
                         % (in_features, out_features))
    shortcut = input
    if stride != 1:
        # Parameter-free spatial downsampling of the identity path.
        shortcut = tf.nn.avg_pool(shortcut, [1, stride, stride, 1],
                                  [1, stride, stride, 1], 'VALID')
    if out_features != in_features:
        # Zero-pad the channel axis to match the main path. Splitting as
        # (d // 2, d - d // 2) also handles odd differences. Padding is
        # applied whenever widths differ, even when stride == 1.
        extra = out_features - in_features
        shortcut = tf.pad(shortcut, [[0, 0], [0, 0], [0, 0],
                                     [extra // 2, extra - extra // 2]])
    current = conv2d(input, in_features, out_features, 3, stride)
    current = tf.contrib.layers.batch_norm(current, is_training=is_training)
    current = tf.nn.relu(current)
    current = conv2d(current, out_features, out_features, 3, 1)
    current = tf.contrib.layers.batch_norm(current, is_training=is_training)
    # Deliberately no ReLU after the addition; the author cites
    # http://torch.ch/blog/2016/02/04/resnets.html for this choice.
    return current + shortcut
def block_stack(input, in_features, out_features, stride, depth, is_training):
    """Chain residual blocks into one stage of the network.

    Only the first block may change the channel count or apply ``stride``.
    The remaining ``depth - 1`` blocks keep the shape. At least one block is
    always built, even when ``depth`` is below 1.

    Returns:
        Output tensor of the last block in the stage.
    """
    # (input width, stride) for each block in the stage.
    layer_specs = [(in_features, stride)] + [(out_features, 1)] * (depth - 1)
    current = input
    for block_in, block_stride in layer_specs:
        current = basic_block(current, block_in, out_features, block_stride, is_training)
    return current
def _build_model(image_dim, label_count):
    """Build the ResNet graph; return (graph, dict of tensors/ops run_model uses)."""
    graph = tf.Graph()
    with graph.as_default():
        xs = tf.placeholder("float", shape=[None, image_dim])
        ys = tf.placeholder("float", shape=[None, label_count])
        lr = tf.placeholder("float", shape=[])
        is_training = tf.placeholder("bool", shape=[])
        # NOTE(review): the reshape hard-codes 32x32x3 images regardless of
        # image_dim; image_dim must equal 32*32*3.
        current = tf.reshape(xs, [-1, 32, 32, 3])
        current = conv2d(current, 3, 16, 3, 1)
        current = tf.nn.relu(current)
        # dimension is 32x32x16
        current = block_stack(current, 16, 16, 1, 3, is_training)
        current = block_stack(current, 16, 32, 2, 3, is_training)
        # dimension is 16x16x32
        current = block_stack(current, 32, 64, 2, 3, is_training)
        # dimension is 8x8x64
        # Global average pooling over the height and width axes.
        current = tf.reduce_mean(current, reduction_indices=[1, 2], name="avg_pool")
        final_dim = 64
        current = tf.reshape(current, [-1, final_dim])
        Wfc = weight_variable([final_dim, label_count])
        bfc = bias_variable([label_count])
        logits = tf.matmul(current, Wfc) + bfc
        # Cross entropy is computed from logits (numerically stable) and
        # averaged over the batch only. The old -reduce_mean(ys * log(softmax))
        # also averaged over the class axis, shrinking the reported loss and
        # the gradients by 1/label_count.
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=ys))
        # batch_norm only registers its moving mean/variance updates in
        # UPDATE_OPS. Make each training step run them so inference-mode
        # batch norm (is_training=False) has real statistics.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_step = tf.train.MomentumOptimizer(lr, 0.9).minimize(cross_entropy)
        # argmax over logits equals argmax over their softmax.
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(ys, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        init = tf.initialize_all_variables()
    ops = {
        'xs': xs, 'ys': ys, 'lr': lr, 'is_training': is_training,
        'train_step': train_step, 'cross_entropy': cross_entropy,
        'accuracy': accuracy, 'init': init,
    }
    return graph, ops


def _evaluate(session, ops, eval_data, eval_labels, eval_batch_size):
    """Return (accuracy, cross entropy) over all of eval_data in inference mode.

    Each chunk's result is weighted by its size, so a short final chunk does
    not skew the averages. Returns NaNs for an empty evaluation set.
    """
    total = len(eval_data)
    if total == 0:
        return float('nan'), float('nan')
    acc_sum, ce_sum = 0.0, 0.0
    for begin in range(0, total, eval_batch_size):
        end = min(begin + eval_batch_size, total)
        ce, acc = session.run(
            [ops['cross_entropy'], ops['accuracy']],
            feed_dict={ops['xs']: eval_data[begin:end],
                       ops['ys']: eval_labels[begin:end],
                       ops['is_training']: False})
        acc_sum += acc * (end - begin)
        ce_sum += ce * (end - begin)
    return acc_sum / total, ce_sum / total


def run_model(data, image_dim, label_count, batch_size=128, num_steps=80000,
              lr_boundaries=(3000, 4000), lr_values=(0.1, 0.01, 0.001),
              eval_batch_size=1000):
    """Train the ResNet with momentum SGD and periodically evaluate it.

    Prints ``step [train_ce, train_acc]`` every 50 steps. Every 250 steps it
    also prints the accuracy and cross entropy on ``data['test_*']``, computed
    with batch norm in inference mode. Early evaluations may lag while the
    moving averages warm up.

    Args:
        data: dict with 'train_data', 'train_labels', 'test_data' and
            'test_labels' arrays; labels are one-hot rows (as load_data makes).
        image_dim: flattened image size fed to the input placeholder.
        label_count: number of classes.
        batch_size: training mini-batch size.
        num_steps: number of training steps (mini-batches, not epochs).
        lr_boundaries: ascending steps at which the learning rate moves to
            the next entry of ``lr_values``.
        lr_values: learning rates; exactly one more entry than lr_boundaries.
        eval_batch_size: chunk size used when evaluating the test split.

    Raises:
        ValueError: if lr_values and lr_boundaries have inconsistent lengths.
    """
    if len(lr_values) != len(lr_boundaries) + 1:
        raise ValueError("lr_values needs exactly one more entry than lr_boundaries")
    graph, ops = _build_model(image_dim, label_count)
    train_data, train_labels = data['train_data'], data['train_labels']
    test_data, test_labels = data['test_data'], data['test_labels']
    num_train = len(train_data)
    with tf.Session(graph=graph) as session:
        session.run(ops['init'])
        start_batch = 0
        for step in range(1, num_steps + 1):
            # NOTE(review): with the defaults the rate reaches 0.001 at step
            # 4000 of 80000. This is kept from the original schedule; confirm
            # it is intended.
            learning_rate = lr_values[sum(1 for b in lr_boundaries if step >= b)]
            end = start_batch + batch_size
            feed = {ops['xs']: train_data[start_batch:end],
                    ops['ys']: train_labels[start_batch:end],
                    ops['lr']: learning_rate,
                    ops['is_training']: True}
            start_batch = end
            # Wrap once the next batch would run past the data. '>' (not
            # '>=') keeps a batch that ends exactly at the last sample.
            if start_batch + batch_size > num_train:
                start_batch = 0
            _, ce, acc = session.run(
                [ops['train_step'], ops['cross_entropy'], ops['accuracy']],
                feed_dict=feed)
            train_metrics = [ce, acc]
            if step % 50 == 0:
                print("%d %s" % (step, train_metrics))
            if step % 250 == 0:
                test_acc, test_ce = _evaluate(session, ops, test_data,
                                              test_labels, eval_batch_size)
                print("%d %s %s %s" % (step, train_metrics, test_acc, test_ce))
def run(data_dir='data', num_validation=5000, seed=None):
    """Load the training batches, hold out a shuffled subset, and train.

    Reads CIFAR-10-style files (batches.meta, data_batch_1..5) from
    ``data_dir``. The held-out set is carved from the shuffled *training*
    batches; no separate test batch is read. It is passed to run_model under
    the 'test_*' keys.

    Args:
        data_dir: directory holding batches.meta and data_batch_1..5.
        num_validation: number of shuffled training samples held out.
        seed: optional integer seed for the shuffle. None keeps using NumPy's
            global random state, as before.
    """
    import os

    image_size = 32
    image_dim = image_size * image_size * 3
    meta = unpickle(os.path.join(data_dir, 'batches.meta'))
    label_names = meta['label_names']
    label_count = len(label_names)
    train_files = ['data_batch_%d' % d for d in range(1, 6)]
    data, labels = load_data(train_files, data_dir, label_count)
    rng = np.random if seed is None else np.random.RandomState(seed)
    pi = rng.permutation(len(data))
    data, labels = data[pi], labels[pi]
    train_data, train_labels = data[num_validation:], labels[num_validation:]
    test_data, test_labels = data[:num_validation], labels[:num_validation]
    # Single formatted-string prints behave the same on Python 2 and 3.
    print("Train: %s %s" % (np.shape(train_data), np.shape(train_labels)))
    print("Test: %s %s" % (np.shape(test_data), np.shape(test_labels)))
    data = {'train_data': train_data,
            'train_labels': train_labels,
            'test_data': test_data,
            'test_labels': test_labels}
    run_model(data, image_dim, label_count)
# Train only when executed as a script, so importing this module (e.g. to
# reuse load_data or the model helpers) does not start a training run.
if __name__ == '__main__':
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.