Skip to content

Instantly share code, notes, and snippets.

@dhruvbird
Created June 20, 2023 14:44
Show Gist options
  • Save dhruvbird/524c99f5adb44c2f720d12fea15e069a to your computer and use it in GitHub Desktop.
Save dhruvbird/524c99f5adb44c2f720d12fea15e069a to your computer and use it in GitHub Desktop.
Showing equivalence of nn.Conv2d and nn.Linear
import torch
from torch import nn
from torch.nn import functional as F
from matplotlib import pyplot as plt
n_out_channels = 1

# Input matrix holding 0..35 in row-major order, so each entry's value
# identifies the position it came from.
mat = torch.arange(0, 36).reshape((6, 6)).float()
print(f"mat:\n{mat}")

# Slide a kernel_size x kernel_size window over `mat` with stride 1 and no
# padding, visiting positions in row-major order, and flatten every window
# into a single row vector.
kernel_size = 3
n_rows, n_cols = mat.shape
patches = []
for i in range(n_rows - kernel_size + 1):
    for j in range(n_cols - kernel_size + 1):
        patches.append(mat[i:i + kernel_size, j:j + kernel_size].reshape(-1))

# One row per window position, one column per element of the window.
patches = torch.stack(patches)
print(f"patches.shape: {patches.shape}")
# Reset the RNG to the same seed before building each layer so that the
# convolution and the linear layer are initialised from the same random draw.
torch.manual_seed(21)
conv = nn.Conv2d(
    in_channels=1,
    out_channels=n_out_channels,
    kernel_size=3,
    stride=1,
    padding=0,
    bias=False,
)

torch.manual_seed(21)
lin = nn.Linear(in_features=9, out_features=n_out_channels, bias=False)
def count_params(m: nn.Module) -> int:
    """Return the total number of scalar elements in the parameters of *m*.

    Every tensor yielded by ``m.parameters()`` contributes its ``numel()``
    to the sum; a module without parameters yields 0.
    """
    return sum(p.numel() for p in m.parameters())
print(f"Conv parameters: {count_params(conv)}")
print(f"Linear parameters: {count_params(lin)}")

# Optional: uncomment to replace the seeded random weights with the fixed
# values 1..9 in both layers.
# conv.weight = nn.Parameter(torch.arange(1, 10).reshape(1, 1, 3, 3).float())
# lin.weight = nn.Parameter(torch.arange(1, 10).reshape(1, 9).float())
print(conv.weight)
print(lin.weight)

# Prepend two singleton axes so the 2-D matrix is a 4-D input for the conv.
conv_mat = conv(mat[None, None, ...])
# Apply the linear layer to every flattened patch at once.
lin_mat = lin(patches)
print(f"conv_mat.shape: {conv_mat.shape}, lin_mat.shape: {lin_mat.shape}")

# Move the output-channel axis first, then lay the per-patch results out on
# the same grid as the convolution output (taken from conv_mat rather than
# hard-coded).
lin_mat = lin_mat.permute(1, 0).reshape(conv_mat.shape)
print(conv_mat)
print(lin_mat)

# Check the equivalence numerically instead of relying on the printed values,
# which can differ in the last displayed digit due to float32 rounding.
# Silent on success; raises AssertionError with a diff summary on mismatch.
torch.testing.assert_close(lin_mat, conv_mat, rtol=1e-5, atol=1e-5)
mat:
tensor([[ 0., 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10., 11.],
[12., 13., 14., 15., 16., 17.],
[18., 19., 20., 21., 22., 23.],
[24., 25., 26., 27., 28., 29.],
[30., 31., 32., 33., 34., 35.]])
patches.shape: torch.Size([16, 9])
Conv parameters: 9
Linear parameters: 9
Parameter containing:
tensor([[[[-0.0176, 0.0757, -0.3253],
[-0.3069, 0.0449, -0.1641],
[ 0.0225, 0.0729, 0.1442]]]], requires_grad=True)
Parameter containing:
tensor([[-0.0176, 0.0757, -0.3253, -0.3069, 0.0449, -0.1641, 0.0225, 0.0729,
0.1442]], requires_grad=True)
conv_mat.shape: torch.Size([1, 1, 4, 4]), lin_mat.shape: torch.Size([16, 1])
tensor([[[[-0.1792, -0.6330, -1.0867, -1.5405],
[-2.9018, -3.3555, -3.8093, -4.2631],
[-5.6244, -6.0781, -6.5319, -6.9856],
[-8.3469, -8.8007, -9.2545, -9.7082]]]],
grad_fn=<ConvolutionBackward0>)
tensor([[[[-0.1792, -0.6330, -1.0867, -1.5405],
[-2.9018, -3.3555, -3.8093, -4.2631],
[-5.6244, -6.0781, -6.5319, -6.9857],
[-8.3469, -8.8007, -9.2545, -9.7082]]]],
grad_fn=<ReshapeAliasBackward0>)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment