Skip to content

Instantly share code, notes, and snippets.

@sjchoi86
Last active January 25, 2017 07:49
Show Gist options
  • Save sjchoi86/0b2283fba3ebeb80dd1ebe2bbd703ed0 to your computer and use it in GitHub Desktop.
Save sjchoi86/0b2283fba3ebeb80dd1ebe2bbd703ed0 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.framework import ops
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import cPickle as pkl
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import urllib
import os
import tarfile
import skimage
import skimage.io
import skimage.transform
# IPython/Jupyter magic that renders matplotlib figures inline. It is only
# valid inside a notebook cell; a plain Python interpreter rejects this line.
%matplotlib inline
print ("PACKAGES LOADED")
# Fetch the BSDS500 archive (source of MNIST-M background patches) unless a
# complete copy is already on disk.
filelink = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz'
filename = 'data/BSR_bsds500.tgz'
if os.path.isfile(filename):
    print ("[%s] ALREADY EXISTS." % (filename))
else:
    # Writing to 'data/...' fails if the directory does not exist yet.
    data_dir = os.path.dirname(filename)
    if data_dir and not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    print ("DOWNLOADING %s ..." % (filename))
    # Download under a temporary name and rename only after success, so an
    # interrupted download is not mistaken for a complete archive by the
    # isfile() check on the next run.
    tmp_filename = filename + '.part'
    urllib.urlretrieve(filelink, tmp_filename)
    os.rename(tmp_filename, filename)
    print ("DONE")
def compose_image(digit, background, rng=None):
    """Difference-blend a digit and a random patch from a background image.

    A patch with the same height/width as ``digit`` is cut from
    ``background`` at a uniformly random offset, and ``|patch - digit|`` is
    returned as uint8.

    Parameters
    ----------
    digit : array indexed (rows, cols, channels), e.g. the float32 RGB image
        produced by ``mnist_to_img``.
    background : (rows, cols, channels) or grayscale (rows, cols) array that
        is at least as large as ``digit`` in both spatial dimensions.
    rng : object with a ``randint(low, high)`` method, such as
        ``np.random.RandomState``; defaults to the global ``np.random``.

    Returns
    -------
    uint8 array with the broadcast shape of the patch and ``digit``.

    Raises
    ------
    ValueError if ``background`` is smaller than ``digit``.
    """
    if rng is None:
        rng = np.random
    if background.ndim == 2:
        # Grayscale background: add a channel axis so it broadcasts against
        # the digit's channels.
        background = background[:, :, np.newaxis]
    rows, cols = background.shape[:2]
    d_rows, d_cols = digit.shape[:2]
    if rows < d_rows or cols < d_cols:
        raise ValueError("background %r is smaller than digit %r"
                         % (background.shape, digit.shape))
    # randint's upper bound is exclusive; offset rows - d_rows is still a
    # valid patch position (the patch then ends exactly at the border).
    r = rng.randint(0, rows - d_rows + 1)
    c = rng.randint(0, cols - d_cols + 1)
    patch = background[r:r + d_rows, c:c + d_cols]
    # Subtract in floating point: with two uint8 operands the difference
    # would wrap around instead of going negative before np.abs.
    return np.abs(patch.astype(np.float64) - digit).astype(np.uint8)
def mnist_to_img(x):
    """Turn one flattened MNIST digit into a 28x28 three-channel float32 image.

    Every strictly positive pixel becomes 255 and every other pixel 0; the
    single gray plane is then copied into three identical channels.
    """
    mask = np.asarray(x > 0, dtype=np.float32)
    gray = 255 * np.reshape(mask, (28, 28, 1))
    return np.concatenate((gray,) * 3, axis=2)
def create_mnistm(X, backgrounds=None, rng=None):
    """
    Give an array of MNIST digits, blend random background patches to
    build the MNIST-M dataset as described in
    http://jmlr.org/papers/volume17/15-239/15-239.pdf

    Parameters
    ----------
    X : array with one flattened MNIST digit per row (each row is reshaped
        to 28x28 by ``mnist_to_img``).
    backgrounds : sequence of background images; defaults to the
        module-level ``background_data`` list.
    rng : ``np.random.RandomState`` used to pick one background per digit;
        defaults to the module-level ``rand``.

    Returns
    -------
    uint8 array of shape (X.shape[0], 28, 28, 3).

    Raises
    ------
    ValueError if there are no background images to sample from.
    """
    if backgrounds is None:
        backgrounds = background_data
    if rng is None:
        rng = rand
    if len(backgrounds) == 0:
        raise ValueError("create_mnistm needs at least one background image")
    X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8)
    for i in range(X.shape[0]):
        # Draw an index instead of calling rng.choice(backgrounds): choice()
        # first converts the list into an ndarray, which breaks (or yields an
        # object array) when the images have different shapes.
        bg_img = backgrounds[rng.randint(len(backgrounds))]
        X_[i] = compose_image(mnist_to_img(X[i]), bg_img)
    return X_
print ("FUNCTIONS READY")
# Build MNIST-M once and cache it as a pickle; later runs just reuse the file.
mnistm_name = 'data/mnistm.pkl'
if os.path.isfile(mnistm_name):
    print ("[%s] ALREADY EXISTS. " % (mnistm_name))
else:
    mnist = input_data.read_data_sets('data')
    # OPEN BSDS500 (closed automatically when the block ends) and keep only
    # regular image files from the training split; anything else in that
    # directory would just fail to decode.
    with tarfile.open(filename) as tar:
        train_files = [member.name for member in tar.getmembers()
                       if member.isfile()
                       and member.name.startswith('BSR/BSDS500/data/images/train/')
                       and member.name.lower().endswith(('.jpg', '.jpeg', '.png'))]
        print ("WE HAVE [%d] TRAIN FILES" % (len(train_files)))
        # GET BACKGROUND
        print ("GET BACKGROUND FOR MNIST-M")
        background_data = []
        for name in train_files:
            try:
                bg_img = skimage.io.imread(tar.extractfile(name))
            except Exception as e:
                # Best effort: skip images that fail to decode, but say so
                # instead of silently swallowing every error.
                print ("SKIPPING [%s]: %s" % (name, e))
                continue
            background_data.append(bg_img)
    print ("WE HAVE [%d] BACKGROUND DATA" % (len(background_data)))
    if not background_data:
        raise RuntimeError("NO BACKGROUND IMAGES FOUND IN [%s]" % (filename))
    # Fixed seed for the background choice made in create_mnistm.
    rand = np.random.RandomState(42)
    print ("BUILDING TRAIN SET...")
    train = create_mnistm(mnist.train.images)
    print ("BUILDING TEST SET...")
    test = create_mnistm(mnist.test.images)
    print ("BUILDING VALIDATION SET...")
    valid = create_mnistm(mnist.validation.images)
    # SAVE -- protocol -1 is a binary pickle format, so the file must be
    # opened in binary mode.
    print ("SAVE MNISTM DATA TO %s" % (mnistm_name))
    with open(mnistm_name, 'wb') as fout:
        pkl.dump({ 'train': train, 'test': test, 'valid': valid }, fout, -1)
    print ("DONE")
print ("LOADING MNIST")
mnist = input_data.read_data_sets('data', one_hot=True)


def _binary_rgb(images):
    """Binarize flat MNIST rows into (N, 28, 28, 3) uint8 images valued 0 or 255."""
    gray = (images > 0).reshape(-1, 28, 28, 1).astype(np.uint8) * 255
    return np.concatenate([gray, gray, gray], 3)


mnist_train = _binary_rgb(mnist.train.images)
mnist_test = _binary_rgb(mnist.test.images)
mnist_train_label = mnist.train.labels
mnist_test_label = mnist.test.labels
print ("LOADING MNIST-M")
mnistm_name = 'data/mnistm.pkl'
# Binary mode to match the binary pickle; `with` closes the handle.
with open(mnistm_name, 'rb') as fp:
    mnistm = pkl.load(fp)
mnistm_train = mnistm['train']
mnistm_test = mnistm['test']
mnistm_valid = mnistm['valid']
# MNIST-M was generated digit-by-digit from the same MNIST splits, in order,
# so the MNIST labels apply unchanged.
mnistm_train_label = mnist_train_label
mnistm_test_label = mnist_test_label
print ("GENERATING DOMAIN DATA")
# Source (MNIST) images first, then target (MNIST-M) images.
total_train = np.vstack([mnist_train, mnistm_train])
total_test = np.vstack([mnist_test, mnistm_test])
ntrain = mnist_train.shape[0]
ntest = mnist_test.shape[0]
# One-hot domain labels: [1, 0] for MNIST rows, [0, 1] for MNIST-M rows.
# The target half is sized from the MNIST-M arrays themselves.
total_train_domain = np.vstack([np.tile([1., 0.], [ntrain, 1]),
                                np.tile([0., 1.], [mnistm_train.shape[0], 1])])
total_test_domain = np.vstack([np.tile([1., 0.], [ntest, 1]),
                               np.tile([0., 1.], [mnistm_test.shape[0], 1])])
n_total_train = total_train.shape[0]
n_total_test = total_test.shape[0]
# GET PIXEL MEAN -- per-channel mean over both training domains. total_train
# already holds exactly this stack, so avoid building a second copy.
pixel_mean = total_train.mean((0, 1, 2))
# PLOT IMAGES
def imshow_grid(images, shape=(2, 8)):
    """Show the first rows*cols entries of `images` in a borderless grid.

    Parameters
    ----------
    images : indexable collection of images accepted by ``Axes.imshow``.
    shape : (rows, cols) of the grid. The default is a tuple, which avoids
        the mutable-default pitfall; lists such as ``[5, 10]`` still work.
    """
    from mpl_toolkits.axes_grid1 import ImageGrid
    fig = plt.figure()
    grid = ImageGrid(fig, 111, nrows_ncols=tuple(shape), axes_pad=0.05)
    size = shape[0] * shape[1]
    n_show = min(size, len(images))
    for i in range(size):
        grid[i].axis('off')
        # Leave trailing cells blank rather than raising IndexError when
        # fewer than rows*cols images are supplied.
        if i < n_show:
            grid[i].imshow(images[i])
    plt.show()
# Visual sanity check: 5x10 samples from the source (MNIST) and target (MNIST-M) sets.
imshow_grid(images=mnist_train, shape=[5, 10])
imshow_grid(images=mnistm_train, shape=[5, 10])
def print_npshape(x, name):
    """Print the `.shape` of `x`, labelled with `name`."""
    shape_text = str(x.shape)
    print ("SHAPE OF %s IS %s" % (name, shape_text))
# Report the stacked source+target arrays and their domain labels.
print_npshape(x=total_train, name="total_train")
print_npshape(x=total_test, name="total_test")
print_npshape(x=total_train_domain, name="total_train_domain")
print_npshape(x=total_test_domain, name="total_test_domain")
class FlipGradientBuilder(object):
    """Callable that inserts a gradient-reversal layer into the default graph.

    Calling the instance returns ``tf.identity(x)`` whose gradient is
    overridden to ``-l * grad``: the forward pass is unchanged while the
    backward pass flips (and scales) the gradient. Every call registers a
    gradient function under a new name ``FlipGradient<k>``, where ``k`` is
    the running ``num_calls`` counter.
    """

    def __init__(self):
        # Number of gradient functions registered so far; keeps each
        # registration name unique.
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        """Return an identity of ``x`` whose gradient is multiplied by ``-l``.

        ``l`` may be a Python float or a scalar tensor (the script passes
        the ``dw`` placeholder), so the reversal weight can change per step.
        """
        grad_name = "FlipGradient%d" % self.num_calls

        @ops.RegisterGradient(grad_name)
        def _flip_gradients(op, grad):
            # Unary minus instead of tf.neg, which was removed in TF 1.0;
            # Tensor.__neg__ works on both old and new releases.
            return [-grad * l]

        g = tf.get_default_graph()
        with g.gradient_override_map({"Identity": grad_name}):
            y = tf.identity(x)
        self.num_calls += 1
        return y
# Shared gradient-reversal helper used by domain_pred_net; each call
# registers a fresh gradient function name.
flip_gradient = FlipGradientBuilder()
# Graph inputs. Creation order is kept as-is because it determines the
# auto-generated placeholder names.
x = tf.placeholder(tf.uint8, [None, 28, 28, 3])  # images, fed the uint8 0/255 arrays built above
y = tf.placeholder(tf.float32, [None, 10])  # one-hot digit labels (MNIST loaded with one_hot=True)
d = tf.placeholder(tf.float32, [None, 2]) # DOMAIN LABEL: [1, 0] = MNIST, [0, 1] = MNIST-M
lr = tf.placeholder(tf.float32, [])  # scalar learning rate, fed every training step
dw = tf.placeholder(tf.float32, [])  # scalar gradient-reversal weight passed to flip_gradient
# FEATURE EXTRACTOR
def feat_ext_net(x, reuse=False):
    """Shared convolutional feature extractor feeding both prediction heads.

    Creates (or, with ``reuse=True``, reuses) variables under the
    'feat_ext' scope. Pipeline: cast to float32, subtract the global
    ``pixel_mean`` and divide by 255, conv(32, 5x5) -> 2x2 max-pool ->
    conv(48, 5x5) -> 2x2 max-pool -> flatten. Returns the flattened features.
    """
    with tf.variable_scope('feat_ext') as vs:
        if reuse:
            vs.reuse_variables()
        h = tf.cast(x, tf.float32)
        h = (h - pixel_mean) / 255.
        h = slim.conv2d(h, 32, [5, 5], scope='conv1')
        h = slim.max_pool2d(h, [2, 2], scope='pool1')
        h = slim.conv2d(h, 48, [5, 5], scope='conv2')
        h = slim.max_pool2d(h, [2, 2], scope='pool2')
        return slim.flatten(h, scope='flatten3')
# CLASS PREDICTION
def class_pred_net(feat, reuse=False):
    """Digit-classifier head: two 100-unit fully connected layers, then 10 logits.

    Variables live under the 'class_pred' scope (reused when ``reuse=True``).
    The output layer has no activation, so raw logits are returned.
    """
    with tf.variable_scope('class_pred') as vs:
        if reuse:
            vs.reuse_variables()
        h = feat
        for layer_name in ('fc1', 'fc2'):
            h = slim.fully_connected(h, 100, scope=layer_name)
        return slim.fully_connected(h, 10, activation_fn=None, scope='out')
# DOMAIN PREDICTION
def domain_pred_net(feat, reuse=False):
    """Domain-classifier head sitting behind a gradient-reversal layer.

    ``feat`` first goes through ``flip_gradient`` weighted by the global
    ``dw`` placeholder, then one 100-unit fully connected layer and a
    2-logit output layer, all under the 'domain_pred' scope (reused when
    ``reuse=True``). Returns raw logits.
    """
    with tf.variable_scope('domain_pred') as vs:
        if reuse:
            vs.reuse_variables()
        reversed_feat = flip_gradient(feat, dw)  # GRADIENT REVERSAL
        hidden = slim.fully_connected(reversed_feat, 100, scope='fc1')
        return slim.fully_connected(hidden, 2, activation_fn=None, scope='out')
# Wire one shared feature extractor into both heads (first build, so no reuse).
feat_ext = feat_ext_net(x, reuse=False)
class_pred = class_pred_net(feat_ext, reuse=False)
domain_pred = domain_pred_net(feat_ext, reuse=False)
print ("MODEL READY")
# Mean softmax cross-entropy for the digit head and the domain head. Named
# arguments are required from TF 1.0 on and are also accepted by older releases.
class_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=class_pred, labels=y))
domain_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=domain_pred, labels=d))
optm_class = tf.train.MomentumOptimizer(lr, 0.9).minimize(class_loss)
# Thanks to flip_gradient, minimizing domain_loss trains the domain head to
# separate the domains while pushing the shared features the opposite way.
optm_domain = tf.train.MomentumOptimizer(lr, 0.9).minimize(domain_loss)
accr_class = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(class_pred, 1), tf.argmax(y, 1)), tf.float32))
# Domain accuracy compares the DOMAIN head's prediction with the domain label
# (it previously used class_pred, which measures nothing meaningful).
accr_domain = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(domain_pred, 1), tf.argmax(d, 1)), tf.float32))
# Created after the optimizers so any variables they add are initialized too.
init = tf.global_variables_initializer()
print ("FUNCTIONS READY")
sess = tf.Session()
sess.run(init)
print ("SESSION OPENED")
# PARAMETERS
batch_size = 128
display_step = 50  # print losses every `display_step` iterations
training_epochs = 10
# Ceil division: every source sample is seen once per epoch, and no batch is
# empty even when ntrain is a multiple of batch_size.
num_batch = (ntrain + batch_size - 1) // batch_size
total_iter = training_epochs*num_batch
for epoch in range(training_epochs):
    randpermlist = np.random.permutation(ntrain)
    for i in range(num_batch):
        curriter = epoch*num_batch + i
        # Training progress over the WHOLE run, from 0 towards 1. It drives the
        # reversal-weight and learning-rate schedules (Ganin et al., JMLR 2016).
        # Using `i` here restarted the schedules every epoch and kept p < 0.1.
        p = float(curriter) / total_iter
        dw_val = 2. / (1. + np.exp(-10. * p)) - 1
        lr_val = 0.01 / (1. + 10 * p)**0.75
        # CLASS: labelled source (MNIST) batch; numpy slicing clips at ntrain,
        # so the last sample is no longer dropped.
        randidx_class = randpermlist[i*batch_size:(i+1)*batch_size]
        batch_x_class = mnist_train[randidx_class]
        batch_y_class = mnist_train_label[randidx_class, :]
        feeds_class = {x:batch_x_class, y:batch_y_class, lr:lr_val, dw:dw_val}
        _, lossclass_val = sess.run([optm_class, class_loss], feed_dict=feeds_class)
        # DOMAIN: random batch drawn from the mixed source+target pool
        randidx_domain = np.random.permutation(n_total_train)[:batch_size]
        batch_x_domain = total_train[randidx_domain]
        batch_d_domain = total_train_domain[randidx_domain, :]
        feeds_domain = {x:batch_x_domain, d:batch_d_domain, lr:lr_val, dw:dw_val}
        _, lossdomain_val = sess.run([optm_domain, domain_loss], feed_dict=feeds_domain)
        if curriter % display_step == 0 or curriter == total_iter - 1:
            print ("[%d/%d] p: %.3e lossclass_val: %.3e, lossdomain_val: %.3e"
                   % (curriter, total_iter, p, lossclass_val, lossdomain_val))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment