Last active
May 2, 2019 07:56
-
-
Save metamath1/8782e0353c195bbc58a73815a29901e4 to your computer and use it in GitHub Desktop.
Experiment on whether 'transpose convolution' and 'backward pass of convolution' are the same
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
x = torch.randn(1 ,1, 4, 4, requires_grad=True) | |
w = torch.randn(1, 1, 3, 3) | |
o = F.conv2d(x, w, stride=2, padding=1) | |
print("x") | |
print(x) | |
print("w") | |
print(w) | |
print("o") | |
print(o) | |
print("\n") | |
# 1. pytorch를 이용한 직접 미분, | |
# 2. 아웃풋을 적당히 변형시켜 웨이트w에 컨벌루션 | |
# 3. transpose conv. | |
# 위 세가지 결과를 비교한다. | |
# 세 결과는 모두 같아야 한다. | |
# pytorch autograd를 이용한 그래디언트 | |
# 미분을 위한 상위 그래디언트는 그냥 o와 같다고 가정한다. | |
grad = torch.autograd.grad(o, x, o, retain_graph=True) | |
print('grad by autograd') | |
print(grad[0]) | |
# 일반적인 컨벌루션의 백워드 연산 | |
# forward pass에 stride가 있으므로 인해 출력 o 사이사이에 0을 넣고 난다음 | |
# 180도 돌린다. | |
# o가 (1,1,4,4)가 되었으므로 full conv위해서 padding=3으로 줘야하는데 | |
# forward pass에 padding이 있으므로 여기서는 padding=2만 준다. | |
o_stride = torch.zeros((1, 1, 4, 4)) | |
o_stride[0,0,0,0] = o[0,0,0,0] | |
o_stride[0,0,0,2] = o[0,0,0,1] | |
o_stride[0,0,2,0] = o[0,0,1,0] | |
o_stride[0,0,2,2] = o[0,0,1,1] | |
o_flip = torch.flip(o_stride, [2, 3]) | |
backward_conv = F.conv2d(w, o_flip, padding=2) | |
print("backward conv1") | |
print(backward_conv) | |
# o를 기준으로 잡고 w를 180도 돌려서 conv해도 결과는 같다. | |
# 이 때는 w가 o_stride에 full conv되어야 하므로 padding=2가 되야 하는데 | |
# forward pass에 padding이 있으므로 여기서는 padding=1만 준다. | |
w_flip = torch.flip(w, [2, 3]) | |
backward_conv2 = F.conv2d(o_stride, w_flip, padding=1) | |
print("backward conv2") | |
print(backward_conv2) | |
# transpose conv 연산 | |
trans_conv = F.conv_transpose2d(o, w, padding=1, stride=2, output_padding=1) | |
print("transpose conv") | |
print(trans_conv) | |
# 세 결과를 비교하면 grad, backward_conv와 trans_conv는 완전히 같다. | |
# 이것으로 conv의 backward와 적당히 변형된 conv연산, transpose conv.가 모두 | |
# 같다는 것을 실험으로 확인했다. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment