Skip to content

Instantly share code, notes, and snippets.

@bikcrum
Created February 21, 2024 21:44
Show Gist options
  • Save bikcrum/9a7c34fb65df602dcbf013ac834dbdc8 to your computer and use it in GitHub Desktop.
Save bikcrum/9a7c34fb65df602dcbf013ac834dbdc8 to your computer and use it in GitHub Desktop.
Implementation of N-dimensional (ND) convolution using NumPy
import numpy as np
import torch
from torch import nn
import tqdm
def _expand_param(value, n, name):
    """Return ``value`` as a tuple of ``n`` ints, one per spatial dimension.

    A scalar is repeated ``n`` times; a sequence must already have length ``n``.

    Raises:
        ValueError: if a sequence of the wrong length is given.
    """
    if np.ndim(value) == 0:
        return (int(value),) * n
    values = tuple(int(v) for v in value)
    if len(values) != n:
        raise ValueError(f"{name} has {len(values)} entries, expected {n} (one per spatial dim)")
    return values


def convNd(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=1, weight=None, bias=None):
    """Build an N-dimensional convolution (cross-correlation, kernel not flipped) in NumPy.

    The returned function is meant to match ``torch.nn.Conv1d/2d/3d`` with default
    dilation/groups and zero padding, but works for any number of spatial dims.

    Args:
        in_channels: C; only used to shape the default all-ones ``weight``.
        out_channels: O; used to shape the default ``weight`` and ``bias``.
        kernel_size: int (same for every spatial dim) or one int per spatial dim.
        stride: int or one int per spatial dim; every value must be >= 1.
        padding: int or one int per spatial dim; implicit zeros on both sides.
        weight: array of shape (O, C, *kernel_size); defaults to all ones.
        bias: array-like of shape (O,); defaults to all ones.

    Returns:
        ``func(x)`` mapping an array of shape (B, C, *dims) to (B, O, *out_dims), with
        ``out_dims[i] = (dims[i] + 2*padding[i] - kernel_size[i]) // stride[i] + 1``
        (0 when the kernel does not fit even with padding).

    Raises (from ``func``):
        ValueError: on a rank, kernel-size, channel-count or stride inconsistency.
    """
    if bias is None:
        bias = np.ones(out_channels)  # (O,)
    bias = np.asarray(bias)

    def func(x):
        x = np.asarray(x)  # (B, C, *dims)
        n = x.ndim - 2  # number of spatial dims
        if n < 0:
            raise ValueError(f"expected input of shape (B, C, *dims), got shape {x.shape}")
        ks = _expand_param(kernel_size, n, "kernel_size")
        strides = _expand_param(stride, n, "stride")
        pads = _expand_param(padding, n, "padding")
        if any(s < 1 for s in strides):
            raise ValueError(f"stride must be >= 1 in every dim, got {strides}")

        # Built per call: with a scalar kernel_size the kernel rank is only known from x.
        w = np.ones((out_channels, in_channels, *ks)) if weight is None else np.asarray(weight)
        if w.ndim != n + 2 or tuple(w.shape[2:]) != ks:
            raise ValueError(f"weight shape {w.shape} does not match kernel_size {ks} for a {n}-d input")
        if w.shape[1] != x.shape[1]:
            raise ValueError(f"input has {x.shape[1]} channels but weight expects {w.shape[1]}")

        # Gather sliding windows one spatial dim at a time; each dim D_i becomes (O_i, K_i).
        out = x
        for i, (size, k, s, p) in enumerate(zip(x.shape[2:], ks, strides, pads)):
            # idx[j, t] = input position read by output position j at kernel offset t.
            idx = np.arange(-p, size - k + p + 1, s)[:, None] + np.arange(k)
            inside = (idx >= 0) & (idx < size)
            # Padding positions gather element 0 and are then zeroed by the mask.
            # Earlier dims were already split in two, so dim i now sits at axis 2*i + 2.
            out = np.take(out, np.where(inside, idx, 0), axis=2 * i + 2)
            out = out * inside.reshape(inside.shape + (1,) * (n - i - 1))
        # out: (B, C, O1, K1, ..., On, Kn)

        # After inserting an O axis at 1, C and the K_i sit at axes 2, 4, ..., 2n+2.
        axes = tuple(range(2, 2 * n + 3, 2))
        # (B, 1, C, O1, K1, ...) * (O, C, 1, K1, ...) -> (B, O, C, O1, K1, ...), summed to (B, O, O1, ...).
        out = (np.expand_dims(out, axis=1) * np.expand_dims(w, axis=axes[:-1])).sum(axis=axes)
        return out + bias.reshape(-1, *(1,) * n)

    return func
def test():
tq = tqdm.tqdm(range(100))
for _ in tq:
CONV_DIM = np.random.randint(1, 4)
B = np.random.randint(1, 3)
H = np.random.randint(1, 20)
W = np.random.randint(1, 10)
D = np.random.randint(1, 20)
padding = np.random.randint(1, 5)
stride = np.random.randint(1, 5)
in_channels = np.random.randint(1, 5)
out_channels = np.random.randint(1, 5)
# Target conv function from tensor for testing
match CONV_DIM:
case 1:
a = np.random.rand(B, in_channels, H).astype(np.float32)
kernel_size = (np.random.randint(1, H + 1),)
conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
stride=stride,
padding=padding)
case 2:
a = np.random.rand(B, in_channels, H, W).astype(np.float32)
kernel_size = (np.random.randint(1, H + 1), np.random.randint(1, W + 1))
conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
stride=stride,
padding=padding)
case 3:
a = np.random.rand(B, in_channels, H, W, D).astype(np.float32)
kernel_size = (np.random.randint(1, H + 1), np.random.randint(1, W + 1), np.random.randint(1, D + 1))
conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
stride=stride,
padding=padding)
case _:
raise NotImplemented
conv_ = convNd(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, weight=conv.weight.data.numpy(), bias=conv.bias.data.numpy())
delta = np.max(np.abs(conv(torch.tensor(a)).detach().numpy() - conv_(a)))
tq.set_description(
f'CONV_DIM:{CONV_DIM}, B:{B}, H:{H}, W:{W}, D:{D}, padding:{padding}, kernel:{kernel_size}, stride:{stride}, in_channels:{in_channels}, out_channels:{out_channels}, delta:{0}')
assert delta < 1e-5, delta
if __name__ == "__main__":
    # Run the randomized self-test only when executed as a script, not on import.
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment