Skip to content

Instantly share code, notes, and snippets.

@scheidan
Last active May 22, 2017 08:12
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save scheidan/aee745a953eb5ba40d5d to your computer and use it in GitHub Desktop.
Save scheidan/aee745a953eb5ba40d5d to your computer and use it in GitHub Desktop.
First draft of a spatial transformer network in Chainer
## -------------------------------------------------------
##
## Elements to implement a spatial transformer network
##
## See: Jaderberg, M., Simonyan, K., Zisserman, A., and Kavukcuoglu,
## K. (2015) Spatial Transformer Networks. arXiv:1506.02025
##
## February 9, 2016 -- Andreas Scheidegger
## andreas.scheidegger@eawag.ch
## -------------------------------------------------------
import numpy as np
from chainer import cuda, Function, Variable
import chainer.functions as F
# ---------------------------------
# Sampling grid generator
def expand_grid(dim):
    """
    Build the homogeneous coordinate grid of an image of shape `dim`.

    `dim` is (height, width). Coordinates are normalized to -1 ... 1 along
    both axes and enumerated row by row. The result is a float32 array of
    shape (3, prod(dim)) whose rows are x, y and a row of ones, so that a
    (2, 3) affine matrix can be applied with a single matrix product.
    """
    n_rows, n_cols = dim[0], dim[1]
    xs, ys = np.meshgrid(np.linspace(-1, 1, n_cols),
                         np.linspace(-1, 1, n_rows))
    ones = np.ones(np.prod(dim))
    return np.vstack((xs.ravel(), ys.ravel(), ones)).astype("float32")
def sampling_grid(A, target_grid):
    """
    Map the target grid to the coordinates at which the input is sampled.

    Applies the affine transformation 'A' (shape (2, 3)) to the homogeneous
    coordinates in 'target_grid' (shape (3, prod(target_dims)), e.g. as built
    by `expand_grid`) and returns their matrix product of shape
    (2, prod(target_dims)).
    """
    return F.matmul(A, target_grid)
# ---------------------------------
# Interpolate at sampling grid
# input U at grid G_sample
# see Lasagne code here:
# https://github.com/Lasagne/Lasagne/blob/master/lasagne/layers/special.py#L232-L309
class Interpolate(Function):
    """Bilinear interpolation of a 2-D feature map at arbitrary sampling points.

    Inputs:
        U    -- feature map of shape (height, width).
        grid -- array of shape (2, n): row 0 holds the x and row 1 the y
                sampling coordinates, normalized to [-1, 1].

    Output:
        V    -- float32 array of shape (n,) with the interpolated values.

    Coordinates outside [-1, 1] are clipped, i.e. edge pixels are repeated.
    The backward pass returns the gradients w.r.t. U and w.r.t. the
    *normalized* grid (the rescaling and the clipping are differentiated too).

    Follows the Lasagne implementation:
    https://github.com/Lasagne/Lasagne/blob/master/lasagne/layers/special.py#L232-L309
    """

    # Coordinates are shrunk by this factor before rescaling so that the end
    # points are always floored to the same side and ``x0 + 1`` stays a valid
    # column index (likewise for rows).
    _SHRINK = 0.9999

    def _neighbourhood(self, xp, U, grid):
        """Locate the 2x2 pixel neighbourhood of every sampling point.

        Returns ``(x0, x1, y0, y1, wx0, wx1, wy0, wy1)``: ``x0``/``x1`` are
        the int32 column indices left/right of each point, ``y0``/``y1`` the
        int32 row indices above/below it, and ``w*`` the float32 linear
        interpolation weights along each axis (``wx0 + wx1 == 1``).
        """
        height, width = U.shape[0], U.shape[1]
        # clip to [-1, 1], then rescale to pixel units [0, width/height - 1]
        x = (grid[0, :].clip(-1, 1) * self._SHRINK + 1) / 2 * (width - 1)
        y = (grid[1, :].clip(-1, 1) * self._SHRINK + 1) / 2 * (height - 1)
        x0f = xp.floor(x)
        y0f = xp.floor(y)
        # weights come from the float floor: mixing int indices with float32
        # would silently promote the weights to float64
        wx0 = ((x0f + 1) - x).astype("float32")
        wx1 = (x - x0f).astype("float32")
        wy0 = ((y0f + 1) - y).astype("float32")
        wy1 = (y - y0f).astype("float32")
        # integer indices are required for array indexing / the CUDA kernels
        x0 = x0f.astype("int32")
        y0 = y0f.astype("int32")
        # x0 + 1 only leaves the map for a 1-pixel wide (or high) input, where
        # its weight is exactly 0; clamp so indexing stays in bounds
        x1 = xp.minimum(x0 + 1, width - 1).astype("int32")
        y1 = xp.minimum(y0 + 1, height - 1).astype("int32")
        return x0, x1, y0, y1, wx0, wx1, wy0, wy1

    def _grid_scale(self, U, grid):
        """Derivative of the pixel coordinates w.r.t. the normalized grid.

        Combines the rescaling factor ``_SHRINK * (size - 1) / 2`` with the
        derivative of ``clip``, which is zero outside [-1, 1].
        """
        height, width = U.shape[0], U.shape[1]
        inside_x = (abs(grid[0, :]) <= 1).astype("float32")
        inside_y = (abs(grid[1, :]) <= 1).astype("float32")
        return (inside_x * (self._SHRINK / 2 * (width - 1)),
                inside_y * (self._SHRINK / 2 * (height - 1)))

    def forward_cpu(self, inputs):
        U, grid = inputs
        x0, x1, y0, y1, wx0, wx1, wy0, wy1 = self._neighbourhood(np, U, grid)
        # vectorized gather of the four neighbours of every sampling point
        V = (wx0 * wy0 * U[y0, x0] + wx1 * wy0 * U[y0, x1] +
             wx0 * wy1 * U[y1, x0] + wx1 * wy1 * U[y1, x1])
        return V.astype("float32"),

    def backward_cpu(self, inputs, grad_outputs):
        U, grid = inputs
        gV, = grad_outputs  # same shape as V, i.e. (n,)
        height, width = U.shape[0], U.shape[1]
        x0, x1, y0, y1, wx0, wx1, wy0, wy1 = self._neighbourhood(np, U, grid)
        u00, u01 = U[y0, x0], U[y0, x1]
        u10, u11 = U[y1, x0], U[y1, x1]
        # --- ggrid: dV/d(pixel coordinate), chained through rescale + clip
        scale_x, scale_y = self._grid_scale(U, grid)
        gx = gV * (wy0 * (u01 - u00) + wy1 * (u11 - u10)) * scale_x
        gy = gV * (wx0 * (u10 - u00) + wx1 * (u11 - u01)) * scale_y
        ggrid = np.vstack((gx, gy)).astype("float32")
        # --- gU: scatter each output gradient onto its four neighbours;
        # np.add.at accumulates correctly when several points share a pixel
        gU = np.zeros((height, width), dtype="float32")
        np.add.at(gU, (y0, x0), wx0 * wy0 * gV)
        np.add.at(gU, (y0, x1), wx1 * wy0 * gV)
        np.add.at(gU, (y1, x0), wx0 * wy1 * gV)
        np.add.at(gU, (y1, x1), wx1 * wy1 * gV)
        return gU, ggrid

    def forward_gpu(self, inputs):
        U, grid = inputs
        x0, x1, y0, y1, wx0, wx1, wy0, wy1 = self._neighbourhood(
            cuda.cupy, U, grid)
        kern = cuda.cupy.ElementwiseKernel(
            'raw T U, T wx0, T wx1, T wy0, T wy1, '
            'int32 x0, int32 x1, int32 y0, int32 y1, int32 N',
            'T V',
            'V = wy0 * (wx0 * U[y0*N+x0] + wx1 * U[y0*N+x1]) + '
            'wy1 * (wx0 * U[y1*N+x0] + wx1 * U[y1*N+x1])',
            'interpolate_fwd'
        )
        V = kern(U, wx0, wx1, wy0, wy1, x0, x1, y0, y1, U.shape[1])
        return V,

    def backward_gpu(self, inputs, grad_outputs):
        U, grid = inputs
        gV, = grad_outputs  # same shape as V, i.e. (n,)
        xp = cuda.cupy
        height, width = U.shape[0], U.shape[1]
        x0, x1, y0, y1, wx0, wx1, wy0, wy1 = self._neighbourhood(xp, U, grid)
        # --- ggrid: dV/d(pixel coordinate), chained through rescale + clip
        grad_kern = xp.ElementwiseKernel(
            'raw T U, T wx0, T wx1, T wy0, T wy1, '
            'int32 x0, int32 x1, int32 y0, int32 y1, T gV, int32 N',
            'T gx, T gy',
            '''
            T u00 = U[y0*N+x0], u01 = U[y0*N+x1];
            T u10 = U[y1*N+x0], u11 = U[y1*N+x1];
            gx = gV * (wy0 * (u01 - u00) + wy1 * (u11 - u10));
            gy = gV * (wx0 * (u10 - u00) + wx1 * (u11 - u01));
            ''',
            'interpolate_bwd_grid'
        )
        gx, gy = grad_kern(U, wx0, wx1, wy0, wy1, x0, x1, y0, y1, gV, width)
        scale_x, scale_y = self._grid_scale(U, grid)
        ggrid = xp.vstack((gx * scale_x, gy * scale_y))
        # --- gU: one atomic scatter-add per sampling point replaces the
        # former O(height * width * n) loop of masked reductions
        gU = xp.zeros((height, width), dtype="float32")
        scatter_kern = xp.ElementwiseKernel(
            'T wx0, T wx1, T wy0, T wy1, '
            'int32 x0, int32 x1, int32 y0, int32 y1, T gV, int32 N',
            'raw T gU',
            '''
            atomicAdd(&gU[y0*N+x0], wx0 * wy0 * gV);
            atomicAdd(&gU[y0*N+x1], wx1 * wy0 * gV);
            atomicAdd(&gU[y1*N+x0], wx0 * wy1 * gV);
            atomicAdd(&gU[y1*N+x1], wx1 * wy1 * gV);
            ''',
            'interpolate_bwd_U'
        )
        scatter_kern(wx0, wx1, wy0, wy1, x0, x1, y0, y1, gV, width, gU)
        return gU, ggrid
# Wrapper
def interpolate(U, grid):
    """Sample the feature map 'U' at the coordinates of 'grid' using bilinear
    interpolation, see:
    Jaderberg, M., Simonyan, K., Zisserman, A., and Kavukcuoglu,
    K. (2015) Spatial Transformer Networks. arXiv:1506.02025

    'U' has shape HxW. 'grid' has shape 2xN: its first row holds the x and
    its second row the y sampling coordinates, normalized to [-1, 1]
    (values outside are clipped). Returns a Variable of shape (N,).
    """
    return Interpolate()(U, grid)
## -------------------------------------------------------
##
## Example spatial transformer
##
## See: Jaderberg, M., Simonyan, K., Zisserman, A., and Kavukcuoglu,
## K. (2015) Spatial Transformer Networks. arXiv:1506.02025
##
## February 8, 2016 -- Andreas Scheidegger
## andreas.scheidegger@eawag.ch
## -------------------------------------------------------
import numpy as np
from chainer import Variable, Chain, cuda, optimizers
import chainer.functions as F
import chainer.links as L
import chainer
import os
import sys
sys.path.append("{}/Dropbox/Projects/Nowcasting/Chainer/Modules".format(os.environ['HOME']))
import generators as g # data generators
import SpatialTransformer as st
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
gpuID = 0  # -1 = CPU, 0 = GPU
if gpuID >= 0:
    device = cuda.get_device(gpuID)
    print(device)
    device.use()

# -----------
# --- create data: uniform noise in [0, 1) with two brighter rectangles
data_in = np.random.random((30, 50)).astype("float32")
data_in[15:22, 20:33] = 2.0
data_in[8:10, 5:10] = 4.0
U = Variable(data_in.astype("float32"))

# label: the input warped with a known affine transformation
A = Variable(np.array([[0.85, -0.1, 0.1], [-0.1, 0.85, 0.2]], dtype="float32"))
dimout = data_in.shape
G_target = Variable(st.expand_grid(dimout))
G = st.sampling_grid(A, G_target)
if gpuID >= 0:
    U.to_gpu()
    G.to_gpu()

dd = st.interpolate(U, G)
dd.to_cpu()
data_out = np.reshape(dd.data, dimout)
print(np.sum(data_out))
# -----------
# --- model to learn the transformation matrix
class AffineMat(chainer.Link):
    """
    Link that owns a (2, 3) affine transformation matrix as its only
    parameter, initialized to the identity transformation.

    In a full spatial transformer the matrix would typically be produced
    by a localization network instead.
    """

    def __init__(self):
        super(AffineMat, self).__init__(A=(2, 3))
        # identity: [[1, 0, 0], [0, 1, 0]]
        self.A.data[...] = np.eye(2, 3, dtype="float32")

    def __call__(self):
        return self.A
class TargetGrid(chainer.Link):
    """
    Holds the fixed, normalized target grid (built by ``st.expand_grid``)
    onto which the affine transformation is applied.

    The grid is stored as a persistent value, not as a parameter. As a
    parameter it received gradients through the sampling-grid matrix product,
    and the optimizer updated it together with the transformation matrix.
    A persistent value still follows ``to_gpu()`` / ``to_cpu()``.
    """

    def __init__(self, dimout):
        super(TargetGrid, self).__init__()
        self.add_persistent('g_target', st.expand_grid(dimout))

    def __call__(self):
        # a fresh Variable each call: no gradient is accumulated for the grid
        return Variable(self.g_target)
class Testmodel(Chain):
    """
    Toy spatial transformer: warps a single image with a learnable affine
    matrix and compares the result against a target image.
    """

    def __init__(self, dimout):
        self.dimout = dimout
        super(Testmodel, self).__init__(
            A=AffineMat(),
            G_target=TargetGrid(dimout),
        )

    def transform(self, data_in, train=False):
        """Warp `data_in` (reshaped to `dimout`) and return shape (1,)+dimout.

        `train` is accepted but currently unused.
        """
        image = F.reshape(data_in, self.dimout)
        sampling_points = st.sampling_grid(self.A(), self.G_target())
        warped = st.interpolate(image, sampling_points)
        return F.reshape(warped, (1,) + self.dimout)

    def loss(self, data_in, data_out):
        """Mean squared error between the warped input and `data_out`."""
        return F.mean_squared_error(self.transform(data_in, train=False),
                                    data_out)
model = Testmodel(dimout)
# move the model to the GPU before handing it to the optimizer
# (the order recommended by Chainer)
if gpuID >= 0:
    model.to_gpu()
optimizer = optimizers.MomentumSGD(lr=0.2, momentum=0.9)
optimizer.setup(model)

data_in_var = Variable(data_in[np.newaxis, :])
data_out_var = Variable(data_out[np.newaxis])
if gpuID >= 0:
    data_in_var.to_gpu()
    data_out_var.to_gpu()

plot_style = dict(interpolation="none", cmap="cubehelix_r", vmin=0, vmax=5)
with PdfPages("test.pdf") as pdf:
    for epoch in range(50):
        print('epoch %d' % epoch)
        # plot input, label and current prediction every second epoch
        if epoch % 2 == 0:
            fig, axes = plt.subplots(nrows=1, ncols=3)
            axes[0].imshow(data_in, **plot_style)
            axes[0].set_title("input")
            axes[1].imshow(data_out, **plot_style)
            axes[1].set_title("label")
            pred = cuda.to_cpu(model.transform(data_in_var).data)
            axes[2].imshow(pred[0, :], **plot_style)
            axes[2].set_title("Epochs: {}".format(epoch))
            # lay out only after images and titles exist so they are
            # taken into account
            fig.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)
        # update; intermediate gradients are not needed, so none are retained
        model.zerograds()
        loss = model.loss(data_in_var, data_out_var)
        print("loss: {}".format(loss.data))
        loss.backward()
        optimizer.update()

print("Estimated transformation matrix:\n {}".format(model.A.A.data))
@scheidan
Copy link
Author

The gradient computation in the GPU implementation is very inefficient! The problem is lines 276 to 288.

@JadenTravnik
Copy link

The really inefficient part can be made 30% faster by replacing lines 276 - 288 with:

        q1 = wx0*wy1*gV
        q2 = wx1*wy1*gV
        q3 = wx1*wy0*gV
        q4 = wx0*wy0*gV

        Q = -1*cuda.cupy.ones((height+1, width+1)).astype("float32")
        for cx in range(width):
            for cy in range(height):
                Q[cy, cx+1] = Q[cy, cx+1] if Q[cy, cx+1]>-1 else (x >= cx) & (x < cx+1) & (y <= cy) & (y > cy-1) # q1
                Q[cy, cx] = Q[cy, cx] if Q[cy, cx]>-1 else (x <= cx) & (x > cx-1) & (y <= cy) & (y > cy-1) #q2
                Q[cy+1, cx] = Q[cy+1, cx] if Q[cy+1, cx]>-1 else (x <= cx) & (x > cx-1) & (y >= cy) & (y < cy+1) # q3
                Q[cy+1, cx+1] = Q[cy+1, cx+1] if Q[cy+1, cx+1]>-1 else (x >= cx) & (x < cx+1) & (y >= cy) & (y < cy+1) #q4

                gU[cy,cx] = cuda.cupy.sum(cuda.cupy.where(Q[cy, cx+1], q1, z)) + \
                            cuda.cupy.sum(cuda.cupy.where(Q[cy, cx], q2, z)) + \
                            cuda.cupy.sum(cuda.cupy.where(Q[cy+1, cx], q3, z)) + \
                            cuda.cupy.sum(cuda.cupy.where(Q[cy+1, cx+1], q4, z))



        return gU, ggrid

This uses dynamic programming to store previously computed values.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment