Skip to content

Instantly share code, notes, and snippets.

@dkurt
Last active April 2, 2020 14:19
Show Gist options
  • Save dkurt/d9dfa96c0f4e9a09d8018351a9f08ef4 to your computer and use it in GitHub Desktop.
Save dkurt/d9dfa96c0f4e9a09d8018351a9f08ef4 to your computer and use it in GitHub Desktop.
from math import ceil

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
# Input volume size (depth, height, width), channel count and the integer
# upsampling factor applied along each spatial axis.
d = 2
h = 3
w = 4
scale_d = 3
scale_h = 3
scale_w = 2
c = 3

# Reference result: PyTorch's own trilinear resize of a random (1, c, d, h, w)
# volume. torch.autograd.Variable is a no-op wrapper since PyTorch 0.4, so the
# tensor is used directly.
# NOTE(review): `input` shadows the builtin; kept because later code reads it.
input = torch.randn(1, c, d, h, w)
resize = nn.Upsample(
    size=[d * scale_d, h * scale_h, w * scale_w],
    mode='trilinear',
    align_corners=False,
)
resize.eval()
out_resize = resize(input)

# Per-axis transposed-convolution kernel length: 2*s - 1 for an odd scale s,
# 2*s for an even one, so the tent filter spans both neighbouring samples.
kernel_d = 2 * scale_d - scale_d % 2
kernel_h = 2 * scale_h - scale_h % 2
kernel_w = 2 * scale_w - scale_w % 2
print('kernel', kernel_d, kernel_h, kernel_w)

# Padding chosen so that (n - 1) * s - 2 * pad + kernel == n * s, i.e. the
# deconvolution output has exactly the upsampled size.
pad_d = ceil((scale_d - 1) / 2)
pad_h = ceil((scale_h - 1) / 2)
pad_w = ceil((scale_w - 1) / 2)
def bilinear_weighs(channels=None, kernel_size=None):
    """Build depthwise transposed-convolution weights for trilinear upsampling.

    Despite the name, the kernel is trilinear. It is the outer product of three
    1-D tent profiles, one per spatial axis (depth, height, width). Each profile
    samples ``1 - |x / f - coeff|`` with ``f = ceil(k / 2)`` and
    ``coeff = (k - 1) / (2 * f)`` for ``x`` in ``[0, k)``.

    Args:
        channels: Number of per-channel kernels (for ``groups=channels``).
            Defaults to the module-level ``c``.
        kernel_size: ``(kernel_d, kernel_h, kernel_w)``. Defaults to the
            module-level ``kernel_d``, ``kernel_h``, ``kernel_w``.

    Returns:
        Tensor of shape ``(channels, 1, kernel_d, kernel_h, kernel_w)`` in the
        default floating-point dtype.
    """
    if channels is None:
        channels = c
    if kernel_size is None:
        kernel_size = (kernel_d, kernel_h, kernel_w)

    def tent(length):
        # Triangle peaking at the kernel centre (x == (length - 1) / 2).
        # Computed in float64 and cast once at the end.
        f = ceil(length / 2.0)
        coeff = (length - 1) / (2.0 * f)
        return torch.tensor(
            [1 - abs(x / f - coeff) for x in range(length)], dtype=torch.float64
        )

    prof_d, prof_h, prof_w = (tent(n) for n in kernel_size)
    # Separable kernel: weight[z, y, x] = prof_d[z] * prof_h[y] * prof_w[x].
    kernel = prof_d[:, None, None] * prof_h[None, :, None] * prof_w[None, None, :]
    kernel = kernel.to(torch.get_default_dtype())
    # Same kernel for every channel; materialize so conv ops get a dense tensor.
    return kernel.expand(channels, 1, *kernel.shape).contiguous()
weights = bilinear_weighs()

# Arguments shared by the image pass and the coverage pass below.
deconv_kwargs = dict(
    bias=None,
    stride=(scale_d, scale_h, scale_w),
    padding=(pad_d, pad_h, pad_w),
    output_padding=(0, 0, 0),
    groups=c,
    dilation=(1, 1, 1),
)
# conv_transpose3d, not conv_transpose2d: the input is 5-D (N, C, D, H, W) and
# stride/padding/dilation are 3-tuples, which the 2-D op rejects.
deconv = torch.nn.functional.conv_transpose3d(input, weights, **deconv_kwargs)

# Border fix-up. Near each face of the volume some output positions receive
# taps from only one input sample, so their kernel weights sum to less than 1.
# Dividing by the per-position weight sum (the same deconvolution applied to a
# ones tensor) turns them into proper weighted averages. That reproduces the
# edge clamping of align_corners=False. Interior weights already sum to 1.
# Every kernel weight is positive, so the divisor is never zero, for any
# integer scale (including 1).
coverage = torch.nn.functional.conv_transpose3d(
    torch.ones_like(input), weights, **deconv_kwargs
)
deconv = deconv / coverage

print(out_resize[0, 0, 0])
print(deconv[0, 0, 0])
print(out_resize.shape)
print(deconv.shape)
print(np.max(np.abs(out_resize.detach().numpy() - deconv.detach().numpy())))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment