Created
May 8, 2020 18:18
-
-
Save MauroPfister/79e9b500a93c028ff16fc12527763c6e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

from ops.dcn import DeformConvPack  # Needs dcn module from https://github.com/open-mmlab/mmdetection

# Seed the global RNG once at import time so that runs are repeatable.
torch.manual_seed(0)
def check_grad(epochs: int = 100, learning_rate: float = 1e-4) -> None:
    """Compare a grouped ``DeformConvPack`` with a zero-padded dense one.

    Two 4-channel, 3x3 ``DeformConvPack`` layers are created on the CUDA
    device: ``conv_full`` with ``groups=1`` and ``conv_group`` with
    ``groups=4``. The weight of ``conv_full`` is replaced by a tensor that is
    zero everywhere except slices ``[c, c]``, which are copied from
    ``conv_group.weight[c, 0]``. Both layers are then trained with separate
    Adam optimizers on the same input/target pair. On every epoch the function
    prints whether the two predictions agree and whether the per-channel
    weight gradients agree (``torch.allclose`` with its default tolerances).

    Args:
        epochs: Number of forward/backward/update iterations to run.
        learning_rate: Learning rate handed to both Adam optimizers.
    """
    device = torch.device('cuda')

    # conv_full is built before conv_group, so each layer's random
    # initialisation draws from the seeded RNG in a fixed order.
    conv_full = DeformConvPack(in_channels=4,
                               out_channels=4,
                               kernel_size=(3, 3),
                               padding=1,
                               groups=1,
                               stride=1).to(device)
    conv_group = DeformConvPack(in_channels=4,
                                out_channels=4,
                                kernel_size=(3, 3),
                                padding=1,
                                groups=4,
                                stride=1).to(device)

    # Channel indices used to address the [c, c] slices of the dense weight.
    # The weight shape does not change, so compute them once.
    # NOTE(review): assumes the weight layout is
    # (out_channels, in_channels // groups, kH, kW) as in torch.nn.Conv2d -
    # confirm against the dcn module.
    channel_idx = list(range(conv_full.weight.shape[1]))

    # Initialization with zero-padded kernels
    with torch.no_grad():
        new_weight = torch.zeros_like(conv_full.weight)
        new_weight[channel_idx, channel_idx, :, :] = conv_group.weight[channel_idx, 0, :, :]
        conv_full.weight = torch.nn.Parameter(new_weight)

    # Constant input with a single perturbed pixel, and a constant target.
    x = torch.ones(1, 4, 5, 5, device=device)
    x[:, :, 3, 2] = 2
    y = torch.ones(1, 4, 5, 5, device=device) * 0.1

    loss_fn = torch.nn.MSELoss(reduction='sum').to(device)

    # The optimizers are created after conv_full.weight was replaced, so they
    # track the new Parameter object.
    optim_full = torch.optim.Adam(conv_full.parameters(), lr=learning_rate)
    optim_group = torch.optim.Adam(conv_group.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        print(f"\nEpoch {epoch:02d}")

        # Forward pass
        y_pred_full = conv_full(x)
        y_pred_group = conv_group(x)
        print(f"Predictions are same: {torch.allclose(y_pred_full, y_pred_group)}")

        # Compute loss
        loss_full = loss_fn(y_pred_full, y)
        loss_group = loss_fn(y_pred_group, y)

        optim_full.zero_grad()
        optim_group.zero_grad()

        # Backward pass
        loss_full.backward()
        loss_group.backward()

        # Zero out conv full gradients everywhere except the [c, c] slices,
        # so only the entries mirrored from conv_group receive a gradient.
        with torch.no_grad():
            grad = conv_full.weight.grad.clone()
            conv_full.weight.grad.zero_()
            conv_full.weight.grad[channel_idx, channel_idx, :, :] = grad[channel_idx, channel_idx, :, :]

        # Check if gradients of conv_full and conv_group are the same
        for channel in channel_idx:
            conv_full_grad = conv_full.weight.grad[channel, channel, :, :]
            conv_group_grad = conv_group.weight.grad[channel, 0, :, :]
            is_equal = torch.allclose(conv_full_grad, conv_group_grad)
            print(f"Grads in {channel}-th channels are same: {is_equal}")

        # Update parameters
        optim_full.step()
        optim_group.step()
if __name__ == "__main__":
    # Run the comparison only when this file is executed as a script.
    check_grad()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @palver7, as far as I can see, PyTorch does not have an equivalent of `DeformConvPack`, which includes the calculation of the offset tensor. I assume you are using `DeformConv2d` from `torchvision.ops` and creating another `Conv2d` for the calculation of the offset? If so, are you sure your offset calculation matches the one in `DeformConvPack`?

Also, did you check whether the two predictions are completely different, or whether they are simply different enough to produce `False` with the default tolerance values of `torch.allclose`?