Skip to content

Instantly share code, notes, and snippets.

@Efreeto
Created October 13, 2017 22:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Efreeto/9c6994694297da8203ed292e2107e068 to your computer and use it in GitHub Desktop.
Save Efreeto/9c6994694297da8203ed292e2107e068 to your computer and use it in GitHub Desktop.
# Modified from https://github.com/fxia22/stn.pytorch/blob/master/script/test.py
from __future__ import print_function
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
# from modules.stn import STN
# from modules.gridgen import AffineGridGen, CylinderGridGen, CylinderGridGenV2, DenseAffine3DGridGen, DenseAffine3DGridGen_rotate
import torch.nn.functional as F
import time
# Smoke test / rough timing for F.affine_grid and F.grid_sample (forward and
# backward) on CPU and, when available, on CUDA device 0.

# Larger, benchmark-sized alternatives:
# nframes = 64
# height = 64
# width = 128
# channels = 64
nframes = 4
height = 6
width = 12
channels = 3

# Passed explicitly so affine_grid/grid_sample do not emit their
# "default align_corners changed" UserWarning. False is the default since
# PyTorch 1.3 (the keyword does not exist in older releases).
ALIGN_CORNERS = False


def _elapsed(start, sync_cuda=False):
    """Return the seconds elapsed since ``start`` (a ``time.perf_counter()`` value).

    CUDA kernels are launched asynchronously, so with ``sync_cuda=True`` the
    host first waits for all queued GPU work; without that the measurement
    would only cover kernel-launch overhead.
    """
    if sync_cuda:
        torch.cuda.synchronize()
    return time.perf_counter() - start


# Random images in [0, 1) and sampling grids in [-1, 1) (grid_sample's
# normalized coordinate range). Clones become the autograd leaves so the
# source tensors themselves never require grad.
inputImages = torch.rand(nframes, channels, height, width)
grids = torch.rand(nframes, height, width, 2) * 2 - 1
input1 = inputImages.clone().requires_grad_()
input2 = grids.clone().requires_grad_()

# Identity affine transform: one 2x3 matrix [[1, 0, 0], [0, 1, 0]] per frame.
theta = torch.eye(2, 3).repeat(nframes, 1, 1).requires_grad_()
print(theta)

grid = F.affine_grid(theta, torch.Size([nframes, channels, height, width]),
                     align_corners=ALIGN_CORNERS)
print(grid.size())
grid.backward(grid.detach())
print(theta.grad.size())

# --- CPU -----------------------------------------------------------------
start = time.perf_counter()
out = F.grid_sample(input1, input2, align_corners=ALIGN_CORNERS)
print(out.size(), 'time:', _elapsed(start))
start = time.perf_counter()
# The upstream gradient must have out's shape (N, C, H_grid, W_grid), which
# differs from input1's whenever the grid size differs from the image size.
out.backward(torch.ones_like(out))
print(input1.grad.size(), 'time:', _elapsed(start))

# --- CUDA ----------------------------------------------------------------
if torch.cuda.is_available():
    with torch.cuda.device(0):
        # Fresh GPU leaves: calling .cuda() on input1 directly would yield a
        # non-leaf whose backward also copies gradients back to the CPU leaf.
        gpu_input1 = input1.detach().cuda().requires_grad_()
        gpu_input2 = input2.detach().cuda().requires_grad_()
        torch.cuda.synchronize()  # ensure the host->device copies are finished
        # NOTE: the first kernel launch may include one-off warm-up cost.
        start = time.perf_counter()
        gpu_out = F.grid_sample(gpu_input1, gpu_input2,
                                align_corners=ALIGN_CORNERS)
        print(gpu_out.size(), 'time:', _elapsed(start, sync_cuda=True))
        start = time.perf_counter()
        gpu_out.backward(torch.ones_like(gpu_out))
        print('time:', _elapsed(start, sync_cuda=True))
else:
    print('CUDA not available; skipping GPU timing.')
# s2 = STN(layout = 'BCHW')
# input1, input2 = Variable(inputImages.transpose(2,3).transpose(1,2), requires_grad=True), Variable(grids.transpose(2,3).transpose(1,2), requires_grad=True)
# input1.data.uniform_()
# input2.data.uniform_(-1,1)
# start = time.time()
# out = s2(input1, input2)
# print(out.size(), 'time:', time.time() - start)
# start = time.time()
# out.backward(input1.data)
# print(input1.grad.size(), 'time:', time.time() - start)
# with torch.cuda.device(1):
# input1 = input1.cuda()
# input2 = input2.cuda()
# start = time.time()
# out = s2(input1, input2)
# print(out.size(), 'time:', time.time() - start)
# start = time.time()
# out.backward(input1.data.cuda())
# print('time:', time.time() - start)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment