Skip to content

Instantly share code, notes, and snippets.

@dneprDroid
Created December 10, 2021 20:54
Show Gist options
  • Save dneprDroid/235d30d563e7b0b8688b0d1616085f91 to your computer and use it in GitHub Desktop.
Save dneprDroid/235d30d563e7b0b8688b0d1616085f91 to your computer and use it in GitHub Desktop.
# Side length that TestModel.resize_grid upsamples the sampling grid to
# (passed as size=(IN_WH, IN_WH) to F.interpolate). The name suggests it is
# also the input image's width/height -- TODO confirm against the caller.
IN_WH: int = 512
# Side length of the incoming sampling grid according to the shape notes in
# TestModel.resize_grid ([1, GRID_WH, GRID_WH, 2]); no code visible in this
# excerpt reads it.
GRID_WH: int = 256
class TestModel(nn.Module):
    """Sample an input with a grid that is first upsampled to a fixed size.

    ``forward`` enlarges a coarse sampling grid with nearest-neighbour
    interpolation and then hands it to ``F.grid_sample``.
    """

    def forward(self, x, grid):
        """Resample ``x`` at the locations given by the upsampled ``grid``.

        Args:
            x: Input tensor passed to ``F.grid_sample``. Because the grid is
                4-D, ``F.grid_sample`` requires ``x`` to be ``[N, C, H, W]``.
            grid: Sampling grid laid out as ``[N, H, W, 2]``; the original
                shape notes use ``[1, GRID_WH, GRID_WH, 2]``.

        Returns:
            The output of ``F.grid_sample``. Its spatial size follows the
            resized grid, i.e. ``IN_WH`` x ``IN_WH``.
        """
        grid_resized = self.resize_grid(grid)
        # Spell out align_corners: False is what PyTorch >= 1.3 uses when the
        # argument is omitted, but omitting it emits a UserWarning on every
        # call and leaves the result dependent on the library default.
        return F.grid_sample(x, grid_resized, align_corners=False)

    def resize_grid(self, grid, size=None):
        """Upsample a sampling grid with nearest-neighbour interpolation.

        Args:
            grid: Tensor laid out as ``[N, H, W, 2]`` (coordinates last).
            size: Target ``(height, width)``. Defaults to ``(IN_WH, IN_WH)``
                when omitted, which is the original fixed behaviour.

        Returns:
            Tensor laid out as ``[N, size[0], size[1], 2]``.
        """
        if size is None:
            size = (IN_WH, IN_WH)
        # F.interpolate resizes the trailing spatial dims, so move the two
        # coordinates to the channel axis first:
        # [N, H, W, 2] => [N, 2, H, W]
        channels_first = grid.permute(0, 3, 1, 2)
        # [N, 2, H, W] => [N, 2, size[0], size[1]]
        channels_first = F.interpolate(
            channels_first,
            size=size,
            mode='nearest'
        )
        # Back to the coordinates-last layout that F.grid_sample expects:
        # [N, 2, size[0], size[1]] => [N, size[0], size[1], 2]
        return channels_first.permute(0, 2, 3, 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment