@MauroPfister
Created May 8, 2020 18:18
import torch

from ops.dcn import DeformConvPack  # Needs the dcn module from https://github.com/open-mmlab/mmdetection

torch.manual_seed(0)


def check_grad():
    device = torch.device('cuda')

    conv_full = DeformConvPack(in_channels=4,
                               out_channels=4,
                               kernel_size=(3, 3),
                               padding=1,
                               groups=1,
                               stride=1).to(device)

    conv_group = DeformConvPack(in_channels=4,
                                out_channels=4,
                                kernel_size=(3, 3),
                                padding=1,
                                groups=4,
                                stride=1).to(device)

    # Initialize conv_full with zero-padded kernels so that its block-diagonal
    # weights match the grouped kernels of conv_group
    with torch.no_grad():
        new_weight = torch.zeros_like(conv_full.weight)
        idx = list(range(conv_full.weight.shape[1]))
        new_weight[idx, idx, :, :] = conv_group.weight[idx, 0, :, :]
        conv_full.weight = torch.nn.Parameter(new_weight)

    x = torch.ones(1, 4, 5, 5).to(device)
    x[:, :, 3, 2] = 2
    y = torch.ones(1, 4, 5, 5).to(device) * 0.1

    loss_fn = torch.nn.MSELoss(reduction='sum').to(device)
    learning_rate = 1e-4
    optim_full = torch.optim.Adam(conv_full.parameters(), lr=learning_rate)
    optim_group = torch.optim.Adam(conv_group.parameters(), lr=learning_rate)

    epochs = 100
    for epoch in range(epochs):
        print(f"\nEpoch {epoch:02d}")

        # Forward pass
        y_pred_full = conv_full(x)
        y_pred_group = conv_group(x)
        print(f"Predictions are the same: {torch.allclose(y_pred_full, y_pred_group)}")

        # Compute loss
        loss_full = loss_fn(y_pred_full, y)
        loss_group = loss_fn(y_pred_group, y)

        optim_full.zero_grad()
        optim_group.zero_grad()

        # Backward pass
        loss_full.backward()
        loss_group.backward()

        # Zero out the off-diagonal gradients of conv_full so that only the
        # block-diagonal weights are updated
        with torch.no_grad():
            grad = conv_full.weight.grad.clone()
            conv_full.weight.grad.zero_()
            idx = list(range(conv_full.weight.shape[1]))
            conv_full.weight.grad[idx, idx, :, :] = grad[idx, idx, :, :]

        # Check if the gradients of conv_full and conv_group are the same
        for i in range(conv_full.weight.shape[1]):
            conv_full_grad_i = conv_full.weight.grad[i, i, :, :]
            conv_group_grad_i = conv_group.weight.grad[i, 0, :, :]
            is_equal = torch.allclose(conv_full_grad_i, conv_group_grad_i)
            print(f"Grads in channel {i} are the same: {is_equal}")

        # Update parameters
        optim_full.step()
        optim_group.step()


if __name__ == "__main__":
    check_grad()
@MauroPfister (Author)

Hi @palver7, as far as I can see, PyTorch does not have an equivalent of DeformConvPack that also computes the offset tensor. I assume you are using DeformConv2d from torchvision.ops and creating a separate Conv2d to compute the offsets? If so, are you sure your offset calculation matches the one in DeformConvPack?
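
For reference, a minimal sketch of what I mean (the class name and initialization below are my assumptions, not code from your setup): torchvision's DeformConv2d expects the offsets as a second input, so you need a separate Conv2d that predicts 2 * kernel_height * kernel_width offsets per deformable group. As far as I know, DeformConvPack zero-initializes its offset conv, so a differently initialized offset conv alone can already make the two outputs diverge.

    import torch
    import torch.nn as nn
    from torchvision.ops import DeformConv2d  # requires torchvision >= 0.6

    class DeformConvWithOffset(nn.Module):
        """Hypothetical wrapper that mimics DeformConvPack: a Conv2d predicts
        the offsets, which are then passed to torchvision's DeformConv2d."""

        def __init__(self, in_channels, out_channels, kernel_size=3,
                     stride=1, padding=1, groups=1, deform_groups=1):
            super().__init__()
            # Two offsets (x and y) per kernel location and deformable group
            self.conv_offset = nn.Conv2d(in_channels,
                                         deform_groups * 2 * kernel_size * kernel_size,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding)
            # Zero-init the offset conv so the layer starts as a plain convolution
            # (this should match what DeformConvPack does, as far as I know)
            nn.init.zeros_(self.conv_offset.weight)
            nn.init.zeros_(self.conv_offset.bias)
            self.conv = DeformConv2d(in_channels, out_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     groups=groups,
                                     bias=False)

        def forward(self, x):
            offset = self.conv_offset(x)
            return self.conv(x, offset)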

Also, did you check whether the two predictions are completely different, or whether they are merely different enough to produce False with the default tolerances of torch.allclose?
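
For example (using the variable names from the gist above; the tolerances are just illustrative):

    # Inspect the actual difference instead of relying on torch.allclose with
    # its default tolerances (rtol=1e-05, atol=1e-08)
    diff = (y_pred_full - y_pred_group).abs()
    print(f"max abs diff:  {diff.max().item():.3e}")
    print(f"mean abs diff: {diff.mean().item():.3e}")
    print(torch.allclose(y_pred_full, y_pred_group, rtol=1e-4, atol=1e-6))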
