# Source: GitHub Gist bwasti/1c2679f3495bea2c376045d7380b3dae
# Author: @bwasti (created September 30, 2022 23:06)
# examples of backward passes implemented with fwd functions
import torch
import torch.nn.functional as F
def simple():
    """Check conv2d's backward pass (stride 1, no padding) using forward ops only.

    For ``y = conv2d(x, w)`` with upstream gradient ``g``:

    * ``dL/dx`` is the "full" correlation of ``g`` with the spatially flipped
      kernel: ``conv2d(pad(g, k - 1), w.flip([2, 3]))``.
    * ``dL/dw`` is the input correlated with the gradient: ``conv2d(x, g)``.

    Both are compared with autograd via ``torch.testing.assert_close``, which
    raises ``AssertionError`` on mismatch. Shapes are fixed to one batch and
    one channel (1x1x4x4 input, 1x1x3x3 kernel); the identities as written
    rely on that.
    """
    print("simple")
    x = torch.randn(1, 1, 4, 4, requires_grad=True)
    w = torch.randn(1, 1, 3, 3, requires_grad=True)
    y = torch.conv2d(x, w)  # (1, 1, 2, 2)

    # A random upstream gradient (instead of all ones) is not symmetric under
    # flips/transposes, so orientation mistakes below would be caught.
    grad = torch.randn_like(y)
    y.backward(grad)

    # dL/dx: pad by k - 1 = 2 on every side, then correlate with the flipped kernel.
    z = torch.conv2d(F.pad(grad, (2, 2, 2, 2)), w.flip([2, 3]))
    # dL/dw: correlate the input with the upstream gradient.
    r = torch.conv2d(x, grad)

    torch.testing.assert_close(x.grad, z)
    torch.testing.assert_close(w.grad, r)
    print("pass")
def padded():
    """Check conv2d's backward pass with ``padding=1`` using forward ops only.

    For ``y = conv2d(x, w, padding=p)`` with a k x k kernel and upstream
    gradient ``g``:

    * ``dL/dx`` is ``conv2d(pad(g, k - 1 - p), w.flip([2, 3]))``; here
      k - 1 - p = 3 - 1 - 1 = 1.
    * ``dL/dw`` is the padded input correlated with ``g``:
      ``conv2d(x, g, padding=p)``.

    Both are compared with autograd via ``torch.testing.assert_close``, which
    raises ``AssertionError`` on mismatch. Shapes are fixed to one batch and
    one channel (1x1x4x4 input, 1x1x3x3 kernel).
    """
    print("padded")
    x = torch.randn(1, 1, 4, 4, requires_grad=True)
    w = torch.randn(1, 1, 3, 3, requires_grad=True)
    y = torch.conv2d(x, w, padding=1)  # (1, 1, 4, 4)

    # Random (non-symmetric) upstream gradient so orientation bugs are detectable.
    grad = torch.randn_like(y)
    y.backward(grad)

    # dL/dx: pad by k - 1 - p = 1, then correlate with the flipped kernel.
    z = torch.conv2d(F.pad(grad, (1, 1, 1, 1)), w.flip([2, 3]))
    # dL/dw: correlate the (zero-padded) input with the upstream gradient.
    r = torch.conv2d(x, grad, padding=1)

    torch.testing.assert_close(x.grad, z)
    torch.testing.assert_close(w.grad, r)
    print("pass")
def strided():
    """Check conv2d's backward pass with ``stride=2`` using forward ops only.

    For ``y = conv2d(x, w, stride=2)`` (5x5 input, 3x3 kernel -> 2x2 output)
    with upstream gradient ``g``:

    * ``dL/dx``: insert stride - 1 = 1 zero between the elements of ``g``,
      pad so each ``g[i, j]`` lands at (2i + 2, 2j + 2) of a 7x7 map, then
      correlate with the flipped kernel to get a 5x5 result.
    * ``dL/dw``: the forward stride becomes a dilation on ``g``:
      ``conv2d(x, g, dilation=2)``.

    The zero interleave is built without scatter/indexing: each element is
    reshaped to its own 1x1 block, padded to ``[[0, 0], [0, g]]``, and the
    blocks are re-tiled with a transpose + reshape.

    Both results are compared with autograd via ``torch.testing.assert_close``,
    which raises ``AssertionError`` on mismatch. Shapes are fixed to one batch
    and one channel.
    """
    print("strided")
    x = torch.randn(1, 1, 5, 5, requires_grad=True)
    w = torch.randn(1, 1, 3, 3, requires_grad=True)
    y = torch.conv2d(x, w, stride=2)  # (1, 1, 2, 2)

    # Random (non-symmetric) upstream gradient: with all ones, a wrong axis
    # order in the interleave below would go unnoticed.
    grad = torch.randn_like(y)
    y.backward(grad)

    # Give each gradient element its own trailing 1x1 block: (1, 1, i, j, 1, 1).
    dilated = grad.reshape(1, 1, 2, 2, 1, 1)
    # Pad one zero before it on both trailing axes: each block is [[0, 0], [0, g]].
    dilated = F.pad(dilated, (1, 0, 1, 0))
    # Reorder to (i, a, j, b) and flatten: grad[i, j] sits at (2i + 1, 2j + 1) of 4x4.
    dilated = dilated.transpose(3, 4).reshape(1, 1, 4, 4)
    # Pad 1 before / 2 after -> 7x7 with grad[i, j] at (2i + 2, 2j + 2); correlate -> 5x5.
    z = torch.conv2d(F.pad(dilated, (1, 2, 1, 2)), w.flip([2, 3]))
    # dL/dw: stride on the forward pass turns into dilation on the gradient.
    r = torch.conv2d(x, grad, dilation=2)

    torch.testing.assert_close(x.grad, z)
    torch.testing.assert_close(w.grad, r)
    print("pass")
# Run every check in order; each prints its name and then "pass" on success.
for check in (simple, padded, strided):
    check()
# Strided explanation:
#
# make this
# x x
# x x
# into this
# x, x, x, x
# then pad one zero before each element along both axes (each x becomes [[0, 0], [0, x]])
# 0 0, 0 0, 0 0, 0 0
# 0 x, 0 x, 0 x, 0 x
# then transpose and reshape back into a 4x4 grid
# 0 0 0 0
# 0 x 0 x
# 0 0 0 0
# 0 x 0 x
# then pad 1 on the top/left and 2 on the bottom/right (giving 7x7)
# 0 0 0 0 0 0 0
# 0 0 0 0 0 0 0
# 0 0 x 0 x 0 0
# 0 0 0 0 0 0 0
# 0 0 x 0 x 0 0
# 0 0 0 0 0 0 0
# 0 0 0 0 0 0 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment