Skip to content

Instantly share code, notes, and snippets.

@hsfzxjy
Last active July 25, 2019 11:19
Show Gist options
  • Save hsfzxjy/0f6315cd93692e92eb0b9beb4682f1e9 to your computer and use it in GitHub Desktop.
Save hsfzxjy/0f6315cd93692e92eb0b9beb4682f1e9 to your computer and use it in GitHub Desktop.
Weighted Upsampling
import torch
import torch.nn.functional as F
import numpy as np
def _grid(in_w, in_h, out_w, out_h, x_coerce, y_coerce):
    """Build one ``grid_sample`` grid whose points snap to integer source pixels.

    Output index ``k`` along an axis maps to source position
    ``k * (in_size - 1) / (out_size - 1)`` (the ``align_corners=True``
    convention). A fractional position is snapped to an integer pixel with the
    given coercion function, e.g. ``np.floor`` or ``np.ceil``. The snapped
    pixel is then normalized into ``[-1, 1]``.

    Args:
        in_w, in_h: spatial size of the tensor that will be sampled.
        out_w, out_h: spatial size of the grid to build.
        x_coerce, y_coerce: element-wise (NumPy-array-compatible) rounding
            functions applied to fractional source positions along x and y.

    Returns:
        ``np.ndarray`` of shape ``(out_h, out_w, 2)`` and dtype ``float64``.
        ``[..., 0]`` is the normalized x coordinate and ``[..., 1]`` the
        normalized y coordinate.
    """
    tx = _axis_coords(in_w, out_w, x_coerce)
    ty = _axis_coords(in_h, out_h, y_coerce)
    result = np.empty((out_h, out_w, 2), dtype=np.float64)
    # x depends only on the column and y only on the row, so broadcast each axis.
    result[..., 0] = tx[np.newaxis, :]
    result[..., 1] = ty[:, np.newaxis]
    return result


def _axis_coords(in_size, out_size, coerce):
    """Return the normalized [-1, 1] source coordinate for every output index on one axis."""
    if out_size == 1:
        # (out_size - 1) would be zero. As in align_corners=True interpolation,
        # a single output pixel reads source pixel 0, i.e. normalized -1.
        return np.full(1, -1.0)
    idx = np.arange(out_size)
    scaled = idx * (in_size - 1)
    exact = scaled % (out_size - 1) == 0
    # For exact hits, normalize straight from the output index. This avoids
    # flooring/ceiling a float quotient that may land just below/above an integer.
    coords = idx * 2.0 / (out_size - 1) - 1
    if not exact.all():
        snapped = coerce(scaled[~exact] / (out_size - 1))
        coords[~exact] = snapped / (in_size - 1) * 2.0 - 1
    return coords
def generate_grids(in_w, in_h, out_w, out_h):
    """Return the four snapped sampling grids for an in -> out resize.

    Each grid comes from ``_grid`` with a floor/ceil choice per axis. The list
    order is (floor x, floor y), (floor x, ceil y), (ceil x, floor y),
    (ceil x, ceil y). This matches the (top-left, bottom-left, top-right,
    bottom-right) order used for the last axis of the weights in
    ``weighted_upsample``.
    """
    roundings = (np.floor, np.ceil)
    return [
        _grid(in_w, in_h, out_w, out_h, along_x, along_y)
        for along_x in roundings
        for along_y in roundings
    ]
def weighted_upsample(input_tensor, weights):
    """Upsample by blending the four snapped source pixels around each output pixel.

    Output pixels are mapped onto the input with the ``align_corners=True``
    convention. The four input pixels obtained by flooring/ceiling that
    position along x and y are gathered with ``F.grid_sample``. They are then
    combined with caller-supplied per-pixel weights.

    Args:
        input_tensor: N x C x in_H x in_W tensor to upsample.
        weights: N x out_H x out_W x 4 tensor, ordered (top-left, bottom-left,
            top-right, bottom-right). Weights are applied as given; nothing
            here normalizes them.

    Returns:
        N x C x out_H x out_W weighted sum of the four gathered corner images.
    """
    _, out_h, out_w, _ = weights.size()
    n_batches, _, in_h, in_w = input_tensor.size()
    terms = []
    # generate_grids yields grids in the same corner order as weights[..., idx].
    for idx, np_grid in enumerate(generate_grids(in_w, in_h, out_w, out_h)):
        # grid_sample requires the grid to share the input's dtype and device.
        grid = torch.as_tensor(
            np_grid, dtype=input_tensor.dtype, device=input_tensor.device
        )
        grid = grid.unsqueeze(0).expand(n_batches, -1, -1, -1)
        # The grids hit pixel centres only under align_corners=True. PyTorch
        # >= 1.3 defaults to False, which would sample between pixels.
        corner = F.grid_sample(input_tensor, grid, align_corners=True)
        # (N, out_H, out_W) -> (N, 1, out_H, out_W), broadcast over channels.
        terms.append(corner * weights[..., idx].unsqueeze(1))
    return sum(terms)
if __name__ == '__main__':
    # Demo: upsample a 1 x 1 x 2 x 3 tensor to 6 x 9 using the same corner
    # weights everywhere: (0.05, 0.1, 0.35, 0.5), which sum to 1.
    rows = [
        [1, 1.5, 2],
        [3, 3.5, 4],
    ]
    input_tensor = torch.tensor(rows, dtype=torch.float64).view(1, 1, 2, 3)
    corner_weights = torch.tensor([.1, .2, .7, 1], dtype=torch.float64) / 2
    weights = corner_weights.expand(6, 9, 4).contiguous().view(1, 6, 9, 4)
    print(weights.permute(0, 3, 1, 2))
    print(weighted_upsample(input_tensor, weights))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment