Created
May 8, 2020 18:18
-
-
Save MauroPfister/79e9b500a93c028ff16fc12527763c6e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from ops.dcn import DeformConvPack # Needs dcn module from https://github.com/open-mmlab/mmdetection | |
# Seed the global RNG so that the randomly initialised layer weights, and
# therefore the printed comparisons, are reproducible between runs.
torch.manual_seed(0)
def check_grad(epochs=100, learning_rate=1e-4):
    """Compare a grouped deformable conv with a zero-padded full one.

    Two ``DeformConvPack`` layers with 4 input and 4 output channels are
    built: ``conv_group`` with ``groups=4`` and ``conv_full`` with
    ``groups=1``.  The weight of ``conv_full`` is replaced by an all-zero
    tensor whose ``[c, c]`` kernels are copied from ``conv_group``'s
    ``[c, 0]`` kernels.  Both layers are then trained with separate Adam
    optimisers on the same fixed input and target.  Every epoch the function
    prints whether the two predictions agree and whether the per-channel
    weight gradients agree, both judged by ``torch.allclose`` with its
    default tolerances.

    After each backward pass, every gradient entry of ``conv_full.weight``
    outside the ``[c, c]`` kernels is zeroed before the optimiser step, so
    only the kernels copied from ``conv_group`` receive a weight gradient.

    Requires a CUDA device: the layers and tensors are placed on ``'cuda'``.

    Args:
        epochs: Number of optimisation steps to run.
        learning_rate: Learning rate passed to both Adam optimisers.

    Returns:
        None.  Results are reported with ``print`` only.
    """
    device = torch.device('cuda')

    # Construction order is kept (full first, then grouped) so that the
    # random initialisation consumes the seeded RNG in the same order.
    shared_kwargs = dict(in_channels=4,
                         out_channels=4,
                         kernel_size=(3, 3),
                         padding=1,
                         stride=1)
    conv_full = DeformConvPack(groups=1, **shared_kwargs).to(device)
    conv_group = DeformConvPack(groups=4, **shared_kwargs).to(device)

    # Channel indices used to address the [c, c] kernels of the full weight.
    # They never change, so they are built once instead of once per epoch.
    num_channels = conv_full.weight.shape[1]
    idx = list(range(num_channels))

    # Initialization with zero-padded kernels
    with torch.no_grad():
        new_weight = torch.zeros_like(conv_full.weight)
        new_weight[idx, idx, :, :] = conv_group.weight[idx, 0, :, :]
        conv_full.weight = torch.nn.Parameter(new_weight)

    # Fixed input: all ones except one spatial position set to 2.
    x = torch.ones(1, 4, 5, 5, device=device)
    x[:, :, 3, 2] = 2
    # Fixed target: constant 0.1 everywhere.
    y = torch.ones(1, 4, 5, 5, device=device) * 0.1

    loss_fn = torch.nn.MSELoss(reduction='sum').to(device)

    # The optimisers are created after the weight replacement above so that
    # optim_full tracks the new Parameter object.
    optim_full = torch.optim.Adam(conv_full.parameters(), lr=learning_rate)
    optim_group = torch.optim.Adam(conv_group.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        print(f"\nEpoch {epoch:02d}")

        # Forward pass
        y_pred_full = conv_full(x)
        y_pred_group = conv_group(x)
        print(f"Predictions are same: {torch.allclose(y_pred_full, y_pred_group)}")

        # Compute loss
        loss_full = loss_fn(y_pred_full, y)
        loss_group = loss_fn(y_pred_group, y)

        optim_full.zero_grad()
        optim_group.zero_grad()

        # Backward pass
        loss_full.backward()
        loss_group.backward()

        # Zero out conv full gradients everywhere except the [c, c] kernels
        with torch.no_grad():
            grad = conv_full.weight.grad.clone()
            conv_full.weight.grad.zero_()
            conv_full.weight.grad[idx, idx, :, :] = grad[idx, idx, :, :]

        # Check if gradients of conv_full and conv_group are the same.
        # A dedicated loop variable avoids clobbering the epoch counter.
        for channel in range(num_channels):
            conv_full_grad = conv_full.weight.grad[channel, channel, :, :]
            conv_group_grad = conv_group.weight.grad[channel, 0, :, :]
            is_equal = torch.allclose(conv_full_grad, conv_group_grad)
            print(f"Grads in {channel}-th channels are same: {is_equal}")

        # Update parameters
        optim_full.step()
        optim_group.step()
# Run the comparison only when the file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    check_grad()
Hi @palver7, as far as I can see, PyTorch does not have an equivalent of `DeformConvPack`, which includes the calculation of the offset tensor. I assume you are using `DeformConv2d` from `torchvision.ops` and creating another `Conv2d` for the calculation of the offset? If so, are you sure your offset calculation matches the one in `DeformConvPack`?
Also, did you check whether the two predictions are completely different, or whether they are simply different enough to produce `False` with the default tolerances of `torch.allclose`?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @MauroPfister, I downloaded this script and replaced the mmlab DCN with PyTorch's DCN, but every check came out False, as you can see in this screenshot:
Can you help me figure out why I get this result? You said in a PyTorch discussion forum that your hack produced almost the same values as the normal method, but faster. I would like to reproduce it using PyTorch's own deformable convolution.