Skip to content

Instantly share code, notes, and snippets.

@daviddao
Last active June 28, 2017 05:04
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save daviddao/c05cc5913c815c1e2154 to your computer and use it in GitHub Desktop.
Save daviddao/c05cc5913c815c1e2154 to your computer and use it in GitHub Desktop.
Tensorflow Spatial Transformer
import tensorflow as tf
def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
    """Spatial Transformer Layer.

    Implements a spatial transformer layer as described in [1]_.
    Based on [2]_ and edited by David Dao for Tensorflow.

    The layer builds a regular sampling grid over the output, maps it through
    the affine transform ``theta`` into input coordinates and bilinearly
    interpolates the input at those positions.

    Parameters
    ----------
    U : Tensor
        Input feature map, e.g. the output of a convolutional net, with
        shape [num_batch, height, width, num_channels].
    theta : Tensor
        Affine transformation parameters, e.g. the output of the
        localisation network. It is reshaped to [num_batch, 2, 3], so it
        should hold 6 values per batch element ([num_batch, 6]).
    downsample_factor : float
        A value of 1 will keep the original size of the image.
        Values larger than 1 will downsample the image.
        Values below 1 will upsample the image.
        example image: height = 100, width = 200
        downsample_factor = 2
        output image will then be 50, 100
    name : str
        Name of the variable scope wrapping all created ops.
    **kwargs
        Accepted for call compatibility; not used.

    Returns
    -------
    Tensor
        float32 tensor of shape
        [num_batch, height // downsample_factor,
         width // downsample_factor, num_channels].

    References
    ----------
    .. [1] Spatial Transformer Networks
           Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
           Submitted on 5 Jun 2015
    .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py

    Notes
    -----
    To initialize the network to the identity transform init
    ``theta`` to :
        identity = np.array([[1., 0., 0.],
                             [0., 1., 0.]])
        identity = identity.flatten()
        theta = tf.Variable(initial_value=identity)

    Grid coordinates in [-1, 1] are mapped onto pixel indices
    [0, size - 1], which matches the ``linspace(-1, 1, size)`` sampling grid,
    so the identity transform with ``downsample_factor=1`` samples exactly
    at the input pixels. Positions outside the input read the nearest edge
    pixel.

    NOTE(review): ``tf.pack``, ``tf.concat(dim, values)`` and
    ``tf.batch_matmul`` are the pre-1.0 TensorFlow spellings; they are kept
    because the rest of this file is written against that API.
    """

    def _repeat(x, n_repeats):
        # Repeat every element of the 1-D int tensor `x` `n_repeats` times
        # in a row: [a, b] -> [a, a, ..., b, b, ...].
        with tf.variable_scope('_repeat'):
            rep = tf.ones(shape=tf.pack([1, n_repeats]))
            rep = tf.cast(rep, 'int32')
            # Outer product (N, 1) x (1, n_repeats) -> (N, n_repeats).
            x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
            return tf.reshape(x, [-1])

    def _interpolate(im, x, y, out_height, out_width):
        # Bilinearly sample `im` at the flat coordinate lists `x`, `y`
        # (given in [-1, 1] grid units). Returns [num_points, channels].
        with tf.variable_scope('_interpolate'):
            # constants
            num_batch = tf.shape(im)[0]
            height = tf.shape(im)[1]
            width = tf.shape(im)[2]
            channels = tf.shape(im)[3]

            x = tf.cast(x, 'float32')
            y = tf.cast(y, 'float32')
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            zero = tf.zeros([], dtype='int32')
            max_y = tf.cast(height - 1, 'int32')
            max_x = tf.cast(width - 1, 'int32')

            # Scale coordinates from [-1, 1] to [0, width - 1] and
            # [0, height - 1]. Scaling by the full width/height would put
            # +1 one pixel past the last valid index.
            x = (x + 1.0) * (width_f - 1.0) / 2.0
            y = (y + 1.0) * (height_f - 1.0) / 2.0

            # Corner coordinates of the cell surrounding each sample. The
            # float versions stay unclipped: they are what the interpolation
            # weights must be measured against.
            x0_f = tf.floor(x)
            y0_f = tf.floor(y)
            x1_f = x0_f + 1.0
            y1_f = y0_f + 1.0

            # Only the indices used for the lookup are clipped to the image.
            x0 = tf.clip_by_value(tf.cast(x0_f, 'int32'), zero, max_x)
            x1 = tf.clip_by_value(tf.cast(x1_f, 'int32'), zero, max_x)
            y0 = tf.clip_by_value(tf.cast(y0_f, 'int32'), zero, max_y)
            y1 = tf.clip_by_value(tf.cast(y1_f, 'int32'), zero, max_y)

            # Offsets into the image flattened to [batch*height*width, C].
            dim2 = width
            dim1 = width * height
            base = _repeat(tf.range(num_batch) * dim1, out_height * out_width)
            base_y0 = base + y0 * dim2
            base_y1 = base + y1 * dim2
            idx_a = base_y0 + x0
            idx_b = base_y1 + x0
            idx_c = base_y0 + x1
            idx_d = base_y1 + x1

            # use indices to lookup pixels in the flat image and restore
            # channels dim
            im_flat = tf.reshape(im, tf.pack([-1, channels]))
            im_flat = tf.cast(im_flat, 'float32')
            Ia = tf.gather(im_flat, idx_a)
            Ib = tf.gather(im_flat, idx_b)
            Ic = tf.gather(im_flat, idx_c)
            Id = tf.gather(im_flat, idx_d)

            # Bilinear weights from the unclipped corners; they always sum
            # to 1, so out-of-range samples reduce to the edge pixel value.
            wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1)
            wb = tf.expand_dims((x1_f - x) * (y - y0_f), 1)
            wc = tf.expand_dims((x - x0_f) * (y1_f - y), 1)
            wd = tf.expand_dims((x - x0_f) * (y - y0_f), 1)
            return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])

    def _meshgrid(height, width):
        # Homogeneous sampling grid of shape [3, height*width] with rows
        # (x_t, y_t, 1).
        with tf.variable_scope('_meshgrid'):
            # This should be equivalent to:
            #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
            #                         np.linspace(-1, 1, height))
            #  ones = np.ones(np.prod(x_t.shape))
            #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
            x_row = tf.transpose(
                tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0])
            y_col = tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1)
            x_t = tf.matmul(tf.ones(shape=tf.pack([height, 1])), x_row)
            y_t = tf.matmul(y_col, tf.ones(shape=tf.pack([1, width])))

            x_t_flat = tf.reshape(x_t, (1, -1))
            y_t_flat = tf.reshape(y_t, (1, -1))
            ones = tf.ones_like(x_t_flat)
            return tf.concat(0, [x_t_flat, y_t_flat, ones])

    def _transform(theta, input_dim, downsample_factor):
        # Apply the affine transform to the grid and resample `input_dim`.
        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            height = tf.shape(input_dim)[1]
            width = tf.shape(input_dim)[2]
            num_channels = tf.shape(input_dim)[3]
            theta = tf.reshape(theta, (-1, 2, 3))
            theta = tf.cast(theta, 'float32')

            # grid of (x_t, y_t, 1), eq (1) in ref [1]
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = tf.cast(height_f // downsample_factor, 'int32')
            out_width = tf.cast(width_f // downsample_factor, 'int32')
            grid = _meshgrid(out_height, out_width)
            # Replicate the same grid once per batch element.
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.pack([num_batch]))
            grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))

            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
            T_g = tf.batch_matmul(theta, grid)
            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
            x_s_flat = tf.reshape(x_s, [-1])
            y_s_flat = tf.reshape(y_s, [-1])

            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat, out_height, out_width)

            return tf.reshape(
                input_transformed,
                tf.pack([num_batch, out_height, out_width, num_channels]))

    with tf.variable_scope(name):
        output = _transform(theta, U, downsample_factor)
        return output
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
def conv2d(x, n_filters,
           k_h=5, k_w=5,
           stride_h=2, stride_w=2,
           stddev=0.02,
           activation=lambda x: x,
           bias=True,
           padding='SAME',
           name="Conv2D"):
    """2D Convolution with options for kernel size, stride, and init deviation.

    Parameters
    ----------
    x : Tensor
        Input tensor to convolve.
    n_filters : int
        Number of filters to apply.
    k_h : int, optional
        Kernel height.
    k_w : int, optional
        Kernel width.
    stride_h : int, optional
        Stride in rows.
    stride_w : int, optional
        Stride in cols.
    stddev : float, optional
        Initialization's standard deviation.
    activation : callable, optional
        Function which applies a nonlinearity to the convolved output.
        Defaults to the identity.
    bias : bool, optional
        If True, a learned per-filter bias is added before the activation.
    padding : str, optional
        'SAME' or 'VALID'
    name : str, optional
        Variable scope to use.

    Returns
    -------
    x : Tensor
        Convolved input, with bias (if enabled) and activation applied.
    """
    with tf.variable_scope(name):
        # The last dimension of the input is used as the number of input
        # channels of the kernel.
        w = tf.get_variable(
            'w', [k_h, k_w, x.get_shape()[-1], n_filters],
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(
            x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
        if bias:
            b = tf.get_variable(
                'b', [n_filters],
                initializer=tf.truncated_normal_initializer(stddev=stddev))
            conv = conv + b
        # Apply the nonlinearity; previously the parameter was accepted but
        # silently ignored.
        return activation(conv)
def linear(x, n_units, scope=None, stddev=0.02,
           activation=lambda x: x):
    """Fully-connected layer without a bias term.

    Parameters
    ----------
    x : Tensor
        Input tensor; its second dimension is used as the number of inputs.
    n_units : int
        Number of output units.
    scope : str, optional
        Variable scope to use; "Linear" when not given.
    stddev : float, optional
        Standard deviation of the random-normal weight initializer.
    activation : callable, optional
        Nonlinearity applied to the matrix product. Defaults to the identity.

    Returns
    -------
    x : Tensor
        ``activation(x @ Matrix)``.
    """
    n_inputs = x.get_shape().as_list()[1]
    scope_name = scope or "Linear"
    with tf.variable_scope(scope_name):
        weights = tf.get_variable(
            "Matrix",
            shape=[n_inputs, n_units],
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(stddev=stddev))
        projected = tf.matmul(x, weights)
        return activation(projected)
# %%
def weight_variable(shape):
    """Create a weight variable initialized to all zeros.

    Parameters
    ----------
    shape : list
        Size of weight variable

    Returns
    -------
    tf.Variable
        Variable of the given shape filled with zeros.
    """
    return tf.Variable(tf.zeros(shape))
# %%
def bias_variable(shape):
    """Create a bias variable initialized from a normal distribution.

    Values are drawn with mean 0.0 and standard deviation 0.01.

    Parameters
    ----------
    shape : list
        Size of bias variable

    Returns
    -------
    tf.Variable
        Variable of the given shape with random-normal initial values.
    """
    return tf.Variable(tf.random_normal(shape, mean=0.0, stddev=0.01))
# Preprocessing
# Load one image and scale its values by 1/255 as float32.
im = ndimage.imread('cat.jpg')
im = (im / 255.).astype('float32')
# Add a leading batch axis instead of reshaping to a hard-coded size: a
# reshape would silently scramble any image that is not laid out exactly as
# (1200, 1600, 3).
im = im[np.newaxis, ...]
_, height, width, channels = im.shape
n_pixels = height * width * channels

# Simulate a batch by repeating the same image.
num_batch = 3
batch = np.repeat(im, num_batch, axis=0)

# Keep `x` a real placeholder so the batch is supplied through feed_dict
# rather than being embedded in the graph as a constant.
x = tf.placeholder(tf.float32, [None, height, width, channels])

# Create localisation network
# A convolutional front end is intentionally left out; the localisation
# network is a single fully-connected layer on the flattened input.
with tf.variable_scope('spatial_transformer_0'):
    # %% Create a fully-connected layer producing the 6 affine parameters:
    n_fc = 6
    # Zero weights make theta equal to the bias for every input.
    W_fc1 = tf.Variable(tf.zeros([n_pixels, n_fc]), name='W_fc1')
    # Bias holds the flattened 2x3 affine matrix with 0.5 on the diagonal
    # and zero translation.
    initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
    initial = initial.flatten().astype('float32')
    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
    x_flatten = tf.reshape(x, [-1, n_pixels])
    h_fc1 = tf.matmul(x_flatten, W_fc1) + b_fc1
    h_trans = transformer(x, h_fc1, downsample_factor=2)

# Run session
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    y = sess.run(h_trans, feed_dict={x: batch})
plt.imshow(y[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment