Skip to content

Instantly share code, notes, and snippets.

@mayankgrwl97
Last active February 17, 2021 16:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mayankgrwl97/8be59bb6347715014895c23a49bdd51f to your computer and use it in GitHub Desktop.
Save mayankgrwl97/8be59bb6347715014895c23a49bdd51f to your computer and use it in GitHub Desktop.
import torch
# Demo 1: sampling a 4x4 image on a 4x4 grid of normalized coordinates with
# align_corners=True reproduces the image (identity resampling).

# Named `image` rather than `input` so the builtin input() is not shadowed.
image = torch.arange(4 * 4).view(1, 1, 4, 4).float()  # (N=1, C=1, H=4, W=4)
print(image)
# Expected output:
# tensor([[[[ 0.,  1.,  2.,  3.],
#           [ 4.,  5.,  6.,  7.],
#           [ 8.,  9., 10., 11.],
#           [12., 13., 14., 15.]]]])

# Normalized sampling coordinates, one per output pixel along each axis.
size = 4
d = torch.linspace(-1, 1, size)
# Broadcast the 1-D coordinates into (size, size) planes. This is the same
# layout torch.meshgrid(d, d) yields with "ij" indexing, but it does not depend
# on meshgrid's `indexing` default (calling it without `indexing` warns on
# recent PyTorch releases).
x = d.expand(size, size)               # x[i, j] == d[j]: varies along the width
y = d.unsqueeze(1).expand(size, size)  # y[i, j] == d[i]: varies along the height
# grid_sample wants shape (N, H_out, W_out, 2) with the last axis ordered (x, y).
grid = torch.stack((x, y), dim=2).unsqueeze(0)
print(grid)
# Expected output:
# tensor([[[[-1.0000, -1.0000],
#           [-0.3333, -1.0000],
#           [ 0.3333, -1.0000],
#           [ 1.0000, -1.0000]],
#          [[-1.0000, -0.3333],
#           [-0.3333, -0.3333],
#           [ 0.3333, -0.3333],
#           [ 1.0000, -0.3333]],
#          [[-1.0000,  0.3333],
#           [-0.3333,  0.3333],
#           [ 0.3333,  0.3333],
#           [ 1.0000,  0.3333]],
#          [[-1.0000,  1.0000],
#           [-0.3333,  1.0000],
#           [ 0.3333,  1.0000],
#           [ 1.0000,  1.0000]]]])

# With align_corners=True the extrema -1 and 1 map to the centers of the corner
# pixels, so the 4 evenly spaced coordinates land on the 4 pixel centers.
output = torch.nn.functional.grid_sample(
    image,
    grid,
    padding_mode='reflection',
    mode='bilinear',
    align_corners=True,
)
print(output)
# Expected output:
# tensor([[[[ 0.0000,  1.0000,  2.0000,  3.0000],
#           [ 4.0000,  5.0000,  6.0000,  7.0000],
#           [ 8.0000,  9.0000, 10.0000, 11.0000],
#           [12.0000, 13.0000, 14.0000, 15.0000]]]])
import torch
# Demo 2: bilinear upsampling of a 4x4 image to 7x7 with grid_sample by
# sampling on a denser grid of normalized coordinates.

# Named `image` rather than `input` so the builtin input() is not shadowed.
image = torch.arange(4 * 4).view(1, 1, 4, 4).float()  # (N=1, C=1, H=4, W=4)
print(image)
# Expected output:
# tensor([[[[ 0.,  1.,  2.,  3.],
#           [ 4.,  5.,  6.,  7.],
#           [ 8.,  9., 10., 11.],
#           [12., 13., 14., 15.]]]])

# 7 evenly spaced coordinates per axis -> a 7x7 output.
out_size = 7
d = torch.linspace(-1, 1, out_size)
# Broadcast the 1-D coordinates into (out_size, out_size) planes. This is the
# same layout torch.meshgrid(d, d) yields with "ij" indexing, but it does not
# depend on meshgrid's `indexing` default (calling it without `indexing` warns
# on recent PyTorch releases).
x = d.expand(out_size, out_size)               # x[i, j] == d[j]: varies along the width
y = d.unsqueeze(1).expand(out_size, out_size)  # y[i, j] == d[i]: varies along the height
# grid_sample wants shape (N, H_out, W_out, 2) with the last axis ordered (x, y).
grid = torch.stack((x, y), dim=2).unsqueeze(0)
print(grid)
# Expected output (middle rows elided):
# tensor([[[[-1.0000, -1.0000],
#           [-0.6667, -1.0000],
#           [-0.3333, -1.0000],
#           [ 0.0000, -1.0000],
#           [ 0.3333, -1.0000],
#           [ 0.6667, -1.0000],
#           [ 1.0000, -1.0000]],
#          ....
#          [[-1.0000,  1.0000],
#           [-0.6667,  1.0000],
#           [-0.3333,  1.0000],
#           [ 0.0000,  1.0000],
#           [ 0.3333,  1.0000],
#           [ 0.6667,  1.0000],
#           [ 1.0000,  1.0000]]]])

# With align_corners=True the extrema -1 and 1 map to the centers of the corner
# pixels, so consecutive grid points are half an input pixel apart and every
# other output sample is the bilinear midpoint of two neighboring pixels.
output = torch.nn.functional.grid_sample(
    image,
    grid,
    padding_mode='reflection',
    mode='bilinear',
    align_corners=True,
)
print(output)
# Expected output:
# tensor([[[[ 0.0000,  0.5000,  1.0000,  1.5000,  2.0000,  2.5000,  3.0000],
#           [ 2.0000,  2.5000,  3.0000,  3.5000,  4.0000,  4.5000,  5.0000],
#           [ 4.0000,  4.5000,  5.0000,  5.5000,  6.0000,  6.5000,  7.0000],
#           [ 6.0000,  6.5000,  7.0000,  7.5000,  8.0000,  8.5000,  9.0000],
#           [ 8.0000,  8.5000,  9.0000,  9.5000, 10.0000, 10.5000, 11.0000],
#           [10.0000, 10.5000, 11.0000, 11.5000, 12.0000, 12.5000, 13.0000],
#           [12.0000, 12.5000, 13.0000, 13.5000, 14.0000, 14.5000, 15.0000]]]])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment