Skip to content

Instantly share code, notes, and snippets.

@scturtle
Last active May 6, 2023 07:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save scturtle/d747f7fcc1fd236fa9bd675cf86e686a to your computer and use it in GitHub Desktop.
Save scturtle/d747f7fcc1fd236fa9bd675cf86e686a to your computer and use it in GitHub Desktop.
conv2d forward and backward implementation in NumPy (im2col fancy-indexing and as_strided variants), verified against PyTorch autograd
import torch
from torch import nn
import numpy as np
# https://github.com/arasdar/DL/blob/master/uri-dl/uri-dl-hw-2/assignment2/cs231n/layers.py
# https://github.com/brandontrabucco/conv-python/blob/master/main.py
def get_output_shape(x, kernel, stride):
    """Return the spatial output size ``(oh, ow)`` of an unpadded convolution.

    Args:
        x: Object with a 4-element ``shape`` whose last two entries are the
            input height and width (e.g. an ``(n, c, h, w)`` tensor/array).
        kernel: ``(kh, kw)`` kernel height and width.
        stride: Step between neighbouring windows, used for both axes.

    Returns:
        ``(oh, ow)``: number of window positions along height and width.

    Raises:
        ValueError: if ``stride`` is not positive, a kernel size is not
            positive, or the kernel does not fit inside the input (which
            would otherwise produce a non-positive output size).
    """
    (_, _, ih, iw), (kh, kw) = x.shape, kernel
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if kh < 1 or kw < 1:
        raise ValueError(f"kernel sizes must be >= 1, got {(kh, kw)}")
    if kh > ih or kw > iw:
        raise ValueError(f"kernel {(kh, kw)} is larger than input {(ih, iw)}")
    # No padding: only windows lying entirely inside the input are counted.
    oh = (ih - kh) // stride + 1
    ow = (iw - kw) // stride + 1
    return oh, ow
def get_im2col_indices(x, kernel, stride):
    """Build fancy-index arrays that gather im2col columns from ``x``.

    Rows enumerate ``(channel, ky, kx)`` in row-major order, and columns
    enumerate output positions ``(oy, ox)`` in row-major order.
    ``x[:, k, i, j]`` then has shape ``(n, ic * kh * kw, oh * ow)``.

    Returns:
        ``(k, i, j)``. ``k`` has shape ``(ic * kh * kw, 1)`` and holds channel
        indices that broadcast across columns. ``i`` and ``j`` both have shape
        ``(ic * kh * kw, oh * ow)`` and hold row and column indices.
    """
    _, channels, _, _ = x.shape
    kh, kw = kernel
    oh, ow = get_output_shape(x, kernel, stride)

    # Offset of each kernel tap inside its window, repeated per channel.
    tap_rows = np.tile(np.repeat(np.arange(kh), kw), channels)
    tap_cols = np.tile(np.arange(kw), kh * channels)

    # Top-left corner of every window, one entry per output position.
    win_rows = stride * np.repeat(np.arange(oh), ow)
    win_cols = stride * np.tile(np.arange(ow), oh)

    i = tap_rows[:, None] + win_rows[None, :]
    j = tap_cols[:, None] + win_cols[None, :]
    k = np.repeat(np.arange(channels), kh * kw)[:, None]
    return k, i, j
# Script 1: conv2d via im2col fancy indexing, checked against PyTorch autograd.
torch.manual_seed(42)
# Single source of truth for the stride, used by both the PyTorch reference
# and the im2col path so the two cannot silently diverge.
stride = 1
x = torch.randint(0, 9, (1, 2, 5, 5), dtype=torch.float32, requires_grad=True)
w = torch.randint(0, 9, (7, 2, 3, 3), dtype=torch.float32, requires_grad=True)
r = nn.functional.conv2d(x, w, stride=stride)
grad = torch.ones_like(r)  # upstream gradient for r: (n, oc, oh, ow)
r.backward(gradient=grad)

n = x.shape[0]
oc = w.shape[0]
kernel = w.shape[2:4]
oh, ow = get_output_shape(x, kernel, stride)
k, i, j = get_im2col_indices(x, kernel, stride)

# Forward: gather every receptive field into a column, then one matmul.
x_col = x[:, k, i, j]  # (n, ic * kh * kw, oh * ow)
# (n, ic * kh * kw, oh * ow) -> (ic * kh * kw, oh * ow * n)
x_col = x_col.permute(1, 2, 0).reshape(x_col.shape[1], -1)
r2 = w.reshape(oc, -1) @ x_col  # (oc, oh * ow * n)
r2 = r2.reshape(oc, oh, ow, n).permute(3, 0, 1, 2)
print("r?", torch.allclose(r, r2))

# Backward w.r.t. w: lay the upstream gradient out in the same
# (oh, ow, n) column order as x_col, then dW = dR @ X_col^T.
grad_mat = grad.permute(1, 2, 3, 0).reshape(oc, -1)  # (oc, oh * ow * n)
dw = (grad_mat @ x_col.T).reshape(w.shape)
print("dw?", torch.allclose(w.grad, dw))

# Backward w.r.t. x: dX_col = W^T @ dR, then scatter-add back (col2im).
# Overlapping windows hit the same pixel, so the unbuffered np.add.at is
# required (a plain fancy-index += would drop repeated indices).
dx_col = w.reshape(oc, -1).T @ grad_mat  # (ic * kh * kw, oh * ow * n)
dx_col = dx_col.detach().numpy()
# (ic * kh * kw, oh * ow * n) -> (n, ic * kh * kw, oh * ow)
dx_col = dx_col.reshape(dx_col.shape[0], -1, n).transpose(2, 0, 1)
dx = np.zeros_like(x.detach().numpy())
np.add.at(dx, (slice(None), k, i, j), dx_col)
dx = torch.tensor(dx)
print("dx?", torch.allclose(x.grad, dx))
import torch
from torch import nn
import numpy as np
# https://github.com/aureliancnx/tinygrad/blob/master/tinygrad/ops.py#L177
def get_output_shape(x, kernel, stride):
    """Return the spatial output size ``(oh, ow)`` of an unpadded convolution.

    Args:
        x: Object with a 4-element ``shape`` whose last two entries are the
            input height and width (e.g. an ``(n, c, h, w)`` tensor/array).
        kernel: ``(kh, kw)`` kernel height and width.
        stride: Step between neighbouring windows, used for both axes.

    Returns:
        ``(oh, ow)``: number of window positions along height and width.

    Raises:
        ValueError: if ``stride`` is not positive, a kernel size is not
            positive, or the kernel does not fit inside the input (which
            would otherwise produce a non-positive output size).
    """
    (_, _, ih, iw), (kh, kw) = x.shape, kernel
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if kh < 1 or kw < 1:
        raise ValueError(f"kernel sizes must be >= 1, got {(kh, kw)}")
    if kh > ih or kw > iw:
        raise ValueError(f"kernel {(kh, kw)} is larger than input {(ih, iw)}")
    # No padding: only windows lying entirely inside the input are counted.
    oh = (ih - kh) // stride + 1
    ow = (iw - kw) // stride + 1
    return oh, ow
# Script 2: conv2d via a zero-copy as_strided im2col view and tensordot,
# checked against PyTorch autograd.
torch.manual_seed(42)
x = torch.randint(0, 9, (1, 2, 5, 5), dtype=torch.float32, requires_grad=True)
w = torch.randint(0, 9, (7, 2, 3, 3), dtype=torch.float32, requires_grad=True)
stride = 2
r = nn.functional.conv2d(x, w, stride=stride)
grad = torch.ones_like(r)  # upstream gradient for r: (n, oc, oh, ow)
r.backward(gradient=grad)

n, ic, ih, iw = x.shape
x_ = x.detach().numpy()
w_ = w.detach().numpy()
# Convert once and use the NumPy copy everywhere below. Feeding the torch
# tensor to np.tensordot only works through implicit conversion, which
# fails for CUDA tensors or tensors that require grad.
grad_ = grad.numpy()
ns, ics, ihs, iws = x_.strides  # byte strides of the (n, ic, ih, iw) array
oc, _, kh, kw = w.shape
oh, ow = get_output_shape(x, (kh, kw), stride)

# Read-only view without copying:
#   x_col[c, p, q, b, y, z] == x_[b, c, y * stride + p, z * stride + q]
x_col = np.lib.stride_tricks.as_strided(
    x_,
    shape=(ic, kh, kw, n, oh, ow),
    strides=(ics, ihs, iws, ns, stride * ihs, stride * iws),
    writeable=False,
)

# Forward: contract (ic, kh, kw) -> (oc, n, oh, ow), then reorder to (n, oc, oh, ow).
r2 = np.tensordot(w_, x_col, ((1, 2, 3), (0, 1, 2)))
r2 = r2.transpose(1, 0, 2, 3)
print("r?", torch.allclose(r, torch.tensor(r2)))

# dW: contract the upstream gradient with the input windows over (n, oh, ow).
dw = np.tensordot(grad_, x_col, ((0, 2, 3), (3, 4, 5)))
print("dw?", torch.allclose(w.grad, torch.tensor(dw)))

# dX: each output position (i, j) adds grad[:, :, i, j] . w onto its window;
# overlapping windows accumulate through +=.
dx = np.zeros_like(x_)
for i in range(oh):
    for j in range(ow):
        ii, jj = i * stride, j * stride
        dx[:, :, ii:ii + kh, jj:jj + kw] += np.tensordot(
            grad_[:, :, i, j], w_, ((1,), (0,))
        )
print("dx?", torch.allclose(x.grad, torch.tensor(dx)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment