Skip to content

Instantly share code, notes, and snippets.

@metamath1
Last active May 2, 2019 07:56
Show Gist options
  • Save metamath1/8782e0353c195bbc58a73815a29901e4 to your computer and use it in GitHub Desktop.
Save metamath1/8782e0353c195bbc58a73815a29901e4 to your computer and use it in GitHub Desktop.
Experiment on whether 'transpose convolution' and 'backward pass of convolution' are the same
import torch
import torch.nn.functional as F
x = torch.randn(1 ,1, 4, 4, requires_grad=True)
w = torch.randn(1, 1, 3, 3)
o = F.conv2d(x, w, stride=2, padding=1)
print("x")
print(x)
print("w")
print(w)
print("o")
print(o)
print("\n")
# 1. pytorch를 이용한 직접 미분,
# 2. 아웃풋을 적당히 변형시켜 웨이트w에 컨벌루션
# 3. transpose conv.
# 위 세가지 결과를 비교한다.
# 세 결과는 모두 같아야 한다.
# pytorch autograd를 이용한 그래디언트
# 미분을 위한 상위 그래디언트는 그냥 o와 같다고 가정한다.
grad = torch.autograd.grad(o, x, o, retain_graph=True)
print('grad by autograd')
print(grad[0])
# 일반적인 컨벌루션의 백워드 연산
# forward pass에 stride가 있으므로 인해 출력 o 사이사이에 0을 넣고 난다음
# 180도 돌린다.
# o가 (1,1,4,4)가 되었으므로 full conv위해서 padding=3으로 줘야하는데
# forward pass에 padding이 있으므로 여기서는 padding=2만 준다.
o_stride = torch.zeros((1, 1, 4, 4))
o_stride[0,0,0,0] = o[0,0,0,0]
o_stride[0,0,0,2] = o[0,0,0,1]
o_stride[0,0,2,0] = o[0,0,1,0]
o_stride[0,0,2,2] = o[0,0,1,1]
o_flip = torch.flip(o_stride, [2, 3])
backward_conv = F.conv2d(w, o_flip, padding=2)
print("backward conv1")
print(backward_conv)
# o를 기준으로 잡고 w를 180도 돌려서 conv해도 결과는 같다.
# 이 때는 w가 o_stride에 full conv되어야 하므로 padding=2가 되야 하는데
# forward pass에 padding이 있으므로 여기서는 padding=1만 준다.
w_flip = torch.flip(w, [2, 3])
backward_conv2 = F.conv2d(o_stride, w_flip, padding=1)
print("backward conv2")
print(backward_conv2)
# transpose conv 연산
trans_conv = F.conv_transpose2d(o, w, padding=1, stride=2, output_padding=1)
print("transpose conv")
print(trans_conv)
# 세 결과를 비교하면 grad, backward_conv와 trans_conv는 완전히 같다.
# 이것으로 conv의 backward와 적당히 변형된 conv연산, transpose conv.가 모두
# 같다는 것을 실험으로 확인했다.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment