@bwasti, created October 2, 2022

# conv2d backward (gradients w.r.t. input and weight) implemented with forward conv2d calls
import torch
import torch.nn.functional as F
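
# Background (standard convolution-backward identities, stated here for context):
#   dL/dx is a "full" correlation of the (stride-dilated) output gradient with the
#   spatially flipped, in/out-channel-swapped kernel, and
#   dL/dw is a correlation of the input (batch and channel dims swapped) with the
#   output gradient, where the forward stride shows up as a dilation.
# dconv2d below builds both out of plain torch.conv2d calls.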
def dconv2d(grad, x, w, stride, padding, groups):
    batch = grad.shape[0]
    channel_out = grad.shape[1]
    channel_in = x.shape[1]
    k = w.shape[-1]  # assumes a square kernel
    # differentiating w.r.t. x
    # a "full" correlation needs k - 1 padding; subtract the forward padding and the
    # stride - 1 zeros that the dilation below already inserts before each element
    gpad = (k - 1) - (stride - 1) - padding
    dxgrad = grad
    if stride > 1:  # manually dilate the incoming gradient
        dxgrad = dxgrad.reshape(*dxgrad.shape, 1, 1)
        dxgrad = F.pad(dxgrad, (stride - 1, 0, stride - 1, 0)).transpose(3, 4)
        dxgrad = dxgrad.reshape(*grad.shape[:-2], *[stride * d for d in grad.shape[-2:]])
        dxgrad = F.pad(dxgrad, (0, stride - 1, 0, stride - 1))
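        # e.g. with stride == 2, a 2x2 gradient [[a, b], [c, d]] dilates to the 5x5
        # [[0 0 0 0 0],
        #  [0 a 0 b 0],
        #  [0 0 0 0 0],
        #  [0 c 0 d 0],
        #  [0 0 0 0 0]]
        # (elements land at odd offsets; gpad above compensates for the leading zeros)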
    # flip the kernel spatially and swap its in/out channel dims
    dxw = w.flip([2, 3])
    if groups > 1:  # transpose within the groups
        dxw = dxw.reshape(groups, dxw.shape[0] // groups, *dxw.shape[1:])
        dxw = dxw.transpose(1, 2)
        dxw = dxw.reshape(-1, *dxw.shape[2:])
    else:
        dxw = dxw.transpose(0, 1)
    dx = torch.conv2d(dxgrad, dxw, padding=gpad, groups=groups)
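    # size check: a dilated gradient of side stride * g + (stride - 1) convolved
    # with a k-wide kernel at gpad padding yields
    #   stride * g + stride - 1 + 2 * gpad - (k - 1) = stride * (g - 1) + k - 2 * padding,
    # which equals the input side whenever the forward output size divided evenly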
    # differentiating w.r.t. w: swap batch and channel dims on both the gradient
    # and the input so the batch dim becomes the reduction (channel) dim
    dwgrad = grad.transpose(0, 1)
    if groups > 1:
        dwx = x.reshape(x.shape[0], groups, x.shape[1] // groups, *x.shape[2:])
        dwx = dwx.transpose(0, 2)
        dwx = dwx.reshape(dwx.shape[0], -1, *dwx.shape[3:])
    else:
        dwx = x.transpose(0, 1)
    # the forward stride reappears as a dilation in the weight-gradient conv
    dw = torch.conv2d(dwx, dwgrad, padding=padding, dilation=stride, groups=groups)
    dw = dw.transpose(0, 1)
    return dx, dw

def simple():
    print("simple")
    x = torch.randn(1, 1, 4, 4)
    x.requires_grad = True
    w = torch.randn(1, 1, 3, 3)
    w.requires_grad = True
    grad = torch.randn(1, 1, 2, 2)
    y = torch.conv2d(x, w)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 1, 0, 1)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

def padded():
    print("padded")
    x = torch.randn(1, 1, 4, 4)
    x.requires_grad = True
    w = torch.randn(1, 1, 3, 3)
    w.requires_grad = True
    y = torch.conv2d(x, w, padding=1)
    grad = torch.randn(1, 1, 4, 4)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 1, 1, 1)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

def strided():
    print("strided")
    x = torch.randn(1, 1, 5, 5)
    x.requires_grad = True
    w = torch.randn(1, 1, 3, 3)
    w.requires_grad = True
    y = torch.conv2d(x, w, stride=2)
    grad = torch.randn(1, 1, 2, 2)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 2, 0, 1)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

def strided_padded():
    print("strided/padded")
    x = torch.randn(8, 2, 5, 5)
    x.requires_grad = True
    w = torch.randn(4, 2, 3, 3)
    w.requires_grad = True
    y = torch.conv2d(x, w, stride=2, padding=1)
    grad = torch.randn(8, 4, 3, 3)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 2, 1, 1)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

def strided_padded_grouped():
    print("strided/padded/grouped")
    x = torch.randn(7, 4, 5, 5)
    x.requires_grad = True
    w = torch.randn(6, 2, 3, 3)
    w.requires_grad = True
    y = torch.conv2d(x, w, stride=2, padding=1, groups=2)
    grad = torch.randn(7, 6, 3, 3)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 2, 1, 2)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

simple()
padded()
strided()
strided_padded()
strided_padded_grouped()
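
# Extra sanity check: a minimal sketch, assuming torch.nn.grad exposes the
# conv2d_input / conv2d_weight helpers (present in recent PyTorch builds),
# comparing dconv2d against those forward-based backward helpers.
import torch.nn.grad

def cross_check():
    print("cross-check vs torch.nn.grad")
    x = torch.randn(7, 4, 5, 5)
    w = torch.randn(6, 2, 3, 3)
    grad = torch.randn(7, 6, 3, 3)
    dx, dw = dconv2d(grad, x, w, 2, 1, 2)
    ref_dx = torch.nn.grad.conv2d_input(x.shape, w, grad, stride=2, padding=1, groups=2)
    ref_dw = torch.nn.grad.conv2d_weight(x, w.shape, grad, stride=2, padding=1, groups=2)
    torch.testing.assert_close(ref_dx, dx)
    torch.testing.assert_close(ref_dw, dw)
    print("pass")

cross_check()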