import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.framework import ops
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import cPickle as pkl
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import urllib
import os
import tarfile
import skimage
import skimage.io
import skimage.transform
%matplotlib inline
print ("PACKAGES LOADED")
# DOWNLOAD BSDS500 (USED AS BACKGROUND IMAGES FOR MNIST-M)
filelink = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz'
filename = 'data/BSR_bsds500.tgz'
if os.path.isfile(filename):
    print ("[%s] ALREADY EXISTS." % (filename))
else:
    print ("DOWNLOADING %s ..." % (filename))
    urllib.urlretrieve(filelink, filename)
    print ("DONE")
def compose_image(digit, background):
    """Difference-blend a digit and a random patch from a background image."""
    w, h, _ = background.shape
    dw, dh, _ = digit.shape
    x = np.random.randint(0, w - dw)
    y = np.random.randint(0, h - dh)
    bg = background[x:x+dw, y:y+dh]
    return np.abs(bg - digit).astype(np.uint8)

def mnist_to_img(x):
    """Binarize MNIST digit and convert to RGB."""
    x = (x > 0).astype(np.float32)
    d = x.reshape([28, 28, 1]) * 255
    return np.concatenate([d, d, d], 2)
def create_mnistm(X):
    """
    Given an array of MNIST digits, blend random background patches to
    build the MNIST-M dataset as described in
    http://jmlr.org/papers/volume17/15-239/15-239.pdf
    """
    X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8)
    for i in range(X.shape[0]):
        bg_img = rand.choice(background_data)  # RANDOM BSDS500 BACKGROUND
        d = mnist_to_img(X[i])
        d = compose_image(d, bg_img)
        X_[i] = d
    return X_
print ("FUNCTIONS READY")
mnistm_name = 'data/mnistm.pkl'
if os.path.isfile(mnistm_name):
    print ("[%s] ALREADY EXISTS. " % (mnistm_name))
else:
    mnist = input_data.read_data_sets('data')
    # OPEN BSDS500
    f = tarfile.open(filename)
    train_files = []
    for name in f.getnames():
        if name.startswith('BSR/BSDS500/data/images/train/'):
            train_files.append(name)
    print ("WE HAVE [%d] TRAIN FILES" % (len(train_files)))
    # GET BACKGROUND
    print ("GET BACKGROUND FOR MNIST-M")
    background_data = []
    for name in train_files:
        try:
            fp = f.extractfile(name)
            bg_img = skimage.io.imread(fp)
            background_data.append(bg_img)
        except:
            continue
    print ("WE HAVE [%d] BACKGROUND DATA" % (len(background_data)))
    rand = np.random.RandomState(42)
    print ("BUILDING TRAIN SET...")
    train = create_mnistm(mnist.train.images)
    print ("BUILDING TEST SET...")
    test = create_mnistm(mnist.test.images)
    print ("BUILDING VALIDATION SET...")
    valid = create_mnistm(mnist.validation.images)
    # SAVE (BINARY MODE FOR THE PICKLE)
    print ("SAVE MNISTM DATA TO %s" % (mnistm_name))
    with open(mnistm_name, 'wb') as f:
        pkl.dump({ 'train': train, 'test': test, 'valid': valid }, f, -1)
    print ("DONE")
print ("LOADING MNIST") | |
mnist = input_data.read_data_sets('data', one_hot=True) | |
mnist_train = (mnist.train.images > 0).reshape(55000, 28, 28, 1).astype(np.uint8) * 255 | |
mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3) | |
mnist_test = (mnist.test.images > 0).reshape(10000, 28, 28, 1).astype(np.uint8) * 255 | |
mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3) | |
mnist_train_label = mnist.train.labels | |
mnist_test_label = mnist.test.labels | |
print ("LOADING MNIST-M") | |
mnistm_name = 'data/mnistm.pkl' | |
mnistm = pkl.load(open(mnistm_name)) | |
mnistm_train = mnistm['train'] | |
mnistm_test = mnistm['test'] | |
mnistm_valid = mnistm['valid'] | |
mnistm_train_label = mnist_train_label | |
mnistm_test_label = mnist_test_label | |
print ("GENERATING DOMAIN DATA") | |
total_train = np.vstack([mnist_train, mnistm_train]) | |
total_test = np.vstack([mnist_test, mnistm_test]) | |
ntrain = mnist_train.shape[0] | |
ntest = mnist_test.shape[0] | |
total_train_domain = np.vstack([np.tile([1., 0.], [ntrain, 1]), np.tile([0., 1.], [ntrain, 1])]) | |
total_test_domain = np.vstack([np.tile([1., 0.], [ntest, 1]), np.tile([0., 1.], [ntest, 1])]) | |
n_total_train = total_train.shape[0] | |
n_total_test = total_test.shape[0] | |
# GET PIXEL MEAN | |
pixel_mean = np.vstack([mnist_train, mnistm_train]).mean((0, 1, 2)) | |
# PLOT IMAGES
def imshow_grid(images, shape=[2, 8]):
    from mpl_toolkits.axes_grid1 import ImageGrid
    fig = plt.figure()
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
    size = shape[0] * shape[1]
    for i in range(size):
        grid[i].axis('off')
        grid[i].imshow(images[i])
    plt.show()
imshow_grid(mnist_train, shape=[5, 10])
imshow_grid(mnistm_train, shape=[5, 10])
def print_npshape(x, name):
    print ("SHAPE OF %s IS %s" % (name, x.shape,))
print_npshape(total_train, "total_train")
print_npshape(total_test, "total_test")
print_npshape(total_train_domain, "total_train_domain")
print_npshape(total_test_domain, "total_test_domain")
# GRADIENT REVERSAL LAYER: IDENTITY ON THE FORWARD PASS,
# MULTIPLIES THE INCOMING GRADIENT BY -l ON THE BACKWARD PASS
class FlipGradientBuilder(object):
    def __init__(self):
        self.num_calls = 0
    def __call__(self, x, l=1.0):
        grad_name = "FlipGradient%d" % self.num_calls
        @ops.RegisterGradient(grad_name)
        def _flip_gradients(op, grad):
            return [tf.negative(grad) * l]
        g = tf.get_default_graph()
        with g.gradient_override_map({"Identity": grad_name}):
            y = tf.identity(x)
        self.num_calls += 1
        return y
flip_gradient = FlipGradientBuilder()
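# Quick sanity check of the gradient-reversal layer (illustrative sketch, not
# part of the original gist): the forward pass should return the input
# unchanged, while the gradient of a downstream sum should come back
# multiplied by -l.
_grl_x = tf.constant([1., 2., 3.])
_grl_y = flip_gradient(_grl_x, l=0.5)
_grl_grad = tf.gradients(tf.reduce_sum(_grl_y), _grl_x)[0]
with tf.Session() as _grl_sess:
    # EXPECT FORWARD [1, 2, 3] AND GRADIENT [-0.5, -0.5, -0.5]
    print (_grl_sess.run([_grl_y, _grl_grad]))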
# PLACEHOLDERS
x = tf.placeholder(tf.uint8, [None, 28, 28, 3])  # INPUT IMAGE
y = tf.placeholder(tf.float32, [None, 10])       # CLASS LABEL (ONE-HOT)
d = tf.placeholder(tf.float32, [None, 2])        # DOMAIN LABEL
lr = tf.placeholder(tf.float32, [])              # LEARNING RATE
dw = tf.placeholder(tf.float32, [])              # DOMAIN-ADAPTATION WEIGHT (GRL LAMBDA)
# FEATURE EXTRACTOR
def feat_ext_net(x, reuse=False):
    with tf.variable_scope('feat_ext') as scope:
        if reuse:
            scope.reuse_variables()
        x = (tf.cast(x, tf.float32) - pixel_mean) / 255.
        net = slim.conv2d(x, 32, [5, 5], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')
        net = slim.conv2d(net, 48, [5, 5], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        feat = slim.flatten(net, scope='flatten3')
    return feat

# CLASS PREDICTION
def class_pred_net(feat, reuse=False):
    with tf.variable_scope('class_pred') as scope:
        if reuse:
            scope.reuse_variables()
        net = slim.fully_connected(feat, 100, scope='fc1')
        net = slim.fully_connected(net, 100, scope='fc2')
        net = slim.fully_connected(net, 10, activation_fn=None, scope='out')
    return net

# DOMAIN PREDICTION
def domain_pred_net(feat, reuse=False):
    with tf.variable_scope('domain_pred') as scope:
        if reuse:
            scope.reuse_variables()
        feat = flip_gradient(feat, dw)  # GRADIENT REVERSAL
        net = slim.fully_connected(feat, 100, scope='fc1')
        net = slim.fully_connected(net, 2, activation_fn=None, scope='out')
    return net

feat_ext = feat_ext_net(x)
class_pred = class_pred_net(feat_ext)
domain_pred = domain_pred_net(feat_ext)
print ("MODEL READY")
# LOSSES, OPTIMIZERS, AND ACCURACIES
class_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=class_pred, labels=y))
domain_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=domain_pred, labels=d))
optm_class = tf.train.MomentumOptimizer(lr, 0.9).minimize(class_loss)
optm_domain = tf.train.MomentumOptimizer(lr, 0.9).minimize(domain_loss)
accr_class = tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(class_pred, 1), tf.arg_max(y, 1)), tf.float32))
accr_domain = tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(domain_pred, 1), tf.arg_max(d, 1)), tf.float32))
init = tf.global_variables_initializer()
print ("FUNCTIONS READY")
sess = tf.Session()
sess.run(init)
print ("SESSION OPENED")
# PARAMETERS
batch_size = 128
display_step = 50
training_epochs = 10
num_batch = int(ntrain/batch_size)+1
total_iter = training_epochs*num_batch
for epoch in range(training_epochs):
    randpermlist = np.random.permutation(ntrain)
    for i in range(num_batch):
        curriter = epoch*num_batch + i
        # TRAINING PROGRESS p IN [0, 1]; GRL WEIGHT AND LEARNING-RATE
        # SCHEDULES FOLLOW THE DANN PAPER
        p = float(curriter) / total_iter
        dw_val = 2. / (1. + np.exp(-10. * p)) - 1
        lr_val = 0.01 / (1. + 10 * p)**0.75
        # CLASS (SOURCE-ONLY MINIBATCH)
        randidx_class = randpermlist[i*batch_size:min((i+1)*batch_size, ntrain)]
        batch_x_class = mnist_train[randidx_class]
        batch_y_class = mnist_train_label[randidx_class, :]
        feeds_class = {x: batch_x_class, y: batch_y_class, lr: lr_val, dw: dw_val}
        _, lossclass_val = sess.run([optm_class, class_loss], feed_dict=feeds_class)
        # DOMAIN (MIXED SOURCE + TARGET MINIBATCH)
        randidx_domain = np.random.permutation(n_total_train)[:batch_size]
        batch_x_domain = total_train[randidx_domain]
        batch_d_domain = total_train_domain[randidx_domain, :]
        feeds_domain = {x: batch_x_domain, d: batch_d_domain, lr: lr_val, dw: dw_val}
        _, lossdomain_val = sess.run([optm_domain, domain_loss], feed_dict=feeds_domain)
        if curriter % display_step == 0:
            print ("[%d/%d] p: %.3e lossclass_val: %.3e, lossdomain_val: %.3e"
                   % (curriter, total_iter, p, lossclass_val, lossdomain_val))
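# Evaluation sketch (not part of the original gist): measure source / target
# classification accuracy and domain-classifier accuracy with the ops defined
# above, using the test arrays built earlier (mnist_test, mnistm_test,
# total_test). The helper `eval_in_batches` is hypothetical and evaluates in
# minibatches to keep memory bounded.
def eval_in_batches(accr_op, feed_x, feed_label, label_ph, nbatch=100):
    accs = []
    for start in range(0, feed_x.shape[0], nbatch):
        feeds = {x: feed_x[start:start+nbatch],
                 label_ph: feed_label[start:start+nbatch]}
        accs.append(sess.run(accr_op, feed_dict=feeds))
    return np.mean(accs)
print ("SOURCE (MNIST)   TEST ACCURACY: %.3f"
       % (eval_in_batches(accr_class, mnist_test, mnist_test_label, y)))
print ("TARGET (MNIST-M) TEST ACCURACY: %.3f"
       % (eval_in_batches(accr_class, mnistm_test, mnistm_test_label, y)))
print ("DOMAIN CLASSIFIER ACCURACY:     %.3f"
       % (eval_in_batches(accr_domain, total_test, total_test_domain, d)))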