@jcjohnson
Created May 28, 2017 20:44
```python
import torch


def bilinear_sample(feats, X, Y, idx):
    """
    Perform bilinear sampling on the features in feats using the sampling grid
    given by X and Y.

    Inputs:
    - feats: Tensor (or Variable) holding input feature map, of shape (N, C, H, W)
    - X, Y: Tensors (or Variables) holding x and y coordinates of the sampling
      grids; both have shape (B, HH, WW). Elements in X should be in the
      range [0, W - 1] and elements in Y should be in the range [0, H - 1].
    - idx: LongTensor (or Variable) of shape (B,) mapping elements of the sampling
      grids to elements in feats. In particular idx[i] = j means that X[i], Y[i]
      is a sampling grid for feats[j].

    Returns:
    - out: Tensor (or Variable) of shape (B, C, HH, WW) where out[i] is computed
      by sampling from feats[idx[i]] using the sampling grid (X[i], Y[i]).
    """
    N, C, H, W = feats.size()
    assert X.size() == Y.size()
    B, HH, WW = X.size()
    outs, mask_idxs = [], []
    for i in range(N):
        # Figure out which elements of X and Y correspond to element i of feats.
        # Variables and Tensors were merged in PyTorch 0.4, so the original
        # Tensor-vs-Variable special case reduces to a single code path.
        mask = idx.eq(i)
        BB = int(mask.sum())  # number of grids that sample from feats[i]
        if BB == 0:
            continue
        mask_idx = mask.nonzero()[:, 0]
        x = X.index_select(0, mask_idx)
        y = Y.index_select(0, mask_idx)
        mask_idxs.append(mask_idx)

        # Get the x and y coordinates for the four samples
        x0 = x.floor().clamp(min=0, max=W - 1)
        x1 = (x0 + 1).clamp(min=0, max=W - 1)
        y0 = y.floor().clamp(min=0, max=H - 1)
        y1 = (y0 + 1).clamp(min=0, max=H - 1)

        # In numpy we could do something like feats[i, :, y0, x0] to pull out
        # the elements of feats at coordinates y0 and x0, but when this was
        # written PyTorch didn't yet support this style of indexing. Instead we
        # use the gather method, which only indexes along one dimension at a
        # time; therefore we collapse the features (BB, C, H, W) into
        # (BB, C, H * W) and index along the last dimension. Below we generate
        # linear indices into the collapsed last dimension for each of the four
        # combinations we need.
        y0x0_idx = (W * y0 + x0).view(BB, 1, HH * WW).expand(BB, C, HH * WW)
        y1x0_idx = (W * y1 + x0).view(BB, 1, HH * WW).expand(BB, C, HH * WW)
        y0x1_idx = (W * y0 + x1).view(BB, 1, HH * WW).expand(BB, C, HH * WW)
        y1x1_idx = (W * y1 + x1).view(BB, 1, HH * WW).expand(BB, C, HH * WW)

        # Use gather to pull out the values from feats corresponding to our
        # four samples, then reshape them to (BB, C, HH, WW).
        feats_i_flat = feats[i].reshape(1, C, H * W).expand(BB, C, H * W)
        v1 = feats_i_flat.gather(2, y0x0_idx.long()).view(BB, C, HH, WW)
        v2 = feats_i_flat.gather(2, y1x0_idx.long()).view(BB, C, HH, WW)
        v3 = feats_i_flat.gather(2, y0x1_idx.long()).view(BB, C, HH, WW)
        v4 = feats_i_flat.gather(2, y1x1_idx.long()).view(BB, C, HH, WW)

        # Compute the weights for the four samples from the fractional offsets
        # relative to (x0, y0). Writing 1 - dx rather than x1 - x keeps the
        # weights correct on the last column/row: there x1 is clamped to x0,
        # so x1 - x and x - x0 would both be 0 and the output would be 0.
        dx = x - x0
        dy = y - y0
        w1 = ((1 - dx) * (1 - dy)).view(BB, 1, HH, WW).expand(BB, C, HH, WW)
        w2 = ((1 - dx) * dy).view(BB, 1, HH, WW).expand(BB, C, HH, WW)
        w3 = (dx * (1 - dy)).view(BB, 1, HH, WW).expand(BB, C, HH, WW)
        w4 = (dx * dy).view(BB, 1, HH, WW).expand(BB, C, HH, WW)

        # Multiply the samples by the weights to give our interpolated results.
        cur_out = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
        outs.append(cur_out)

    # outs is grouped by feature map; sorting the concatenated mask indices
    # yields the permutation that restores the original grid order 0..B-1.
    mask_idxs = torch.cat(mask_idxs, 0)
    _, sidx = mask_idxs.sort()
    return torch.cat(outs, 0).index_select(0, sidx)
```
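The gather workaround exists because PyTorch did not support numpy-style indexing when the gist was written. Current PyTorch does, so the per-feature-map loop can be dropped. The following is a minimal sketch of a loop-free variant with the same contract. `bilinear_sample_vectorized` is a hypothetical name introduced here, not part of the gist, and the sketch assumes `idx` is an int64 tensor with values in [0, N) and that the coordinates lie within [0, W - 1] and [0, H - 1].

```python
import torch


def bilinear_sample_vectorized(feats, X, Y, idx):
    """Hypothetical loop-free variant; same inputs/outputs as bilinear_sample.

    Assumes X in [0, W - 1], Y in [0, H - 1], and int64 idx values in [0, N).
    """
    N, C, H, W = feats.shape
    B, HH, WW = X.shape
    # Integer top-left corner; clamping keeps x1/y1 inside the feature map.
    x0 = X.floor().clamp(0, W - 1)
    y0 = Y.floor().clamp(0, H - 1)
    dx = (X - x0).unsqueeze(-1)  # (B, HH, WW, 1): broadcasts over channels
    dy = (Y - y0).unsqueeze(-1)
    x0, y0 = x0.long(), y0.long()
    x1 = (x0 + 1).clamp(max=W - 1)
    y1 = (y0 + 1).clamp(max=H - 1)
    b = idx.view(B, 1, 1)  # broadcasts against the (B, HH, WW) coordinates

    # The index tensors are separated by the ':' slice, so the broadcast
    # index dims come first and each lookup has shape (B, HH, WW, C).
    v1 = feats[b, :, y0, x0]
    v2 = feats[b, :, y1, x0]
    v3 = feats[b, :, y0, x1]
    v4 = feats[b, :, y1, x1]
    out = ((1 - dx) * (1 - dy) * v1 + (1 - dx) * dy * v2
           + dx * (1 - dy) * v3 + dx * dy * v4)
    return out.permute(0, 3, 1, 2).contiguous()  # -> (B, C, HH, WW)


# Usage: two 3-channel 4x5 feature maps and two 1x2 sampling grids.
feats = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).view(2, 3, 4, 5)
X = torch.tensor([[[0.0, 4.0]], [[0.5, 2.0]]])  # (B=2, HH=1, WW=2)
Y = torch.tensor([[[0.0, 3.0]], [[0.0, 1.5]]])
idx = torch.tensor([1, 0])  # grid 0 samples feats[1], grid 1 samples feats[0]
out = bilinear_sample_vectorized(feats, X, Y, idx)  # shape (2, 3, 1, 2)
```

Following NumPy's rules, PyTorch moves the broadcast index dimensions to the front when advanced indices are not adjacent. That is why the channel axis ends up last and is permuted back at the end. This version also does all B grids in one pass instead of one pass per feature map.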
@edgarriba

edgarriba commented Jun 15, 2017

Do you have a usage example for this? It's not clear to me what the idx tensor should be.

@hzxie

hzxie commented Sep 29, 2018

@edgarriba Glad to see you here!
I have the same problem.
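The following is a usage sketch for the question about `idx`, based on the docstring. `idx[b]` is the index of the feature map (for example, the image in a batch) that grid `b` samples from, so several grids can share one feature map. To keep the example self-contained, it uses two hypothetical helpers that are not part of the gist:

- `bilinear_sample_grid` is meant to compute the same thing as `bilinear_sample` for in-range coordinates, using `torch.nn.functional.grid_sample` with `align_corners=True`. It assumes H > 1 and W > 1.
- `box_to_grid` turns an assumed `(x0, y0, x1, y1)` pixel box into an HH x WW grid.

```python
import torch
import torch.nn.functional as F


def bilinear_sample_grid(feats, X, Y, idx):
    """Hypothetical helper: bilinear_sample's contract, expressed via grid_sample."""
    N, C, H, W = feats.shape
    inp = feats.index_select(0, idx)  # feats[idx]: one feature map per grid
    # With align_corners=True, -1 maps to pixel 0 and +1 to pixel W - 1 (H - 1).
    xn = 2 * X / (W - 1) - 1
    yn = 2 * Y / (H - 1) - 1
    grid = torch.stack([xn, yn], dim=-1)  # (B, HH, WW, 2) in (x, y) order
    return F.grid_sample(inp, grid, mode="bilinear",
                         padding_mode="border", align_corners=True)


def box_to_grid(box, HH, WW):
    """Hypothetical helper: evenly spaced HH x WW sample points inside a box."""
    x0, y0, x1, y1 = box  # assumed pixel-coordinate convention
    ys = torch.linspace(y0, y1, HH)
    xs = torch.linspace(x0, x1, WW)
    Y, X = torch.meshgrid(ys, xs, indexing="ij")  # both (HH, WW)
    return X, Y


torch.manual_seed(0)
feats = torch.randn(2, 8, 16, 16)        # N=2 feature maps, C=8, H=W=16
boxes = [(0, 0, 7, 7), (8, 8, 15, 15), (4, 2, 11, 13)]
idx = torch.tensor([0, 0, 1])            # boxes 0 and 1 come from feats[0]
grids = [box_to_grid(b, 4, 4) for b in boxes]
X = torch.stack([g[0] for g in grids])   # (B=3, HH=4, WW=4)
Y = torch.stack([g[1] for g in grids])
out = bilinear_sample_grid(feats, X, Y, idx)  # (3, 8, 4, 4)
```

Under these assumptions, `bilinear_sample(feats, X, Y, idx)` takes the same arguments. Here `idx` has one entry per grid (B = 3), not one per feature map (N = 2).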
