
@etienne87
Created July 13, 2021 11:13
"""
Local Convolutions in pytorch
"""
import torch
import torch.nn.functional as F
b,c,h,w = 2,3,32,32
d = 3
n = h*w
padding = 1
kernel_size = 3
stride = 1
win = kernel_size**2*c
x = torch.randn(b,c,h,w)
x = F.unfold(x, kernel_size, stride, padding=padding)
weights = torch.zeros((b,win,n,d), dtype=x.dtype)
y = torch.einsum('bcn,bcnd->bdn', x, weights)
y = y.view(b,d,h,w)
print(y.shape)
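The script above indexes its weights by sample as well as by location. A locally connected layer in the usual sense shares its filters across the batch, which only requires dropping `b` from the weight operand of the einsum. The sketch below wraps that variant in a module; the name `LocallyConnected2d`, its constructor arguments and the weight initialisation are assumptions for illustration, not part of the gist or of PyTorch. As a sanity check it copies one ordinary conv kernel to every location and compares the result with `F.conv2d`. The two should agree because `F.unfold` flattens each patch in (channel, kernel row, kernel column) order, the same order as `kernel.reshape(out_channels, -1)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LocallyConnected2d(nn.Module):
    """Hypothetical locally connected 2D layer built on F.unfold.

    weight has shape (win, n, d): one filter per output location,
    shared across the batch (the gist also indexes weights by sample).
    """

    def __init__(self, in_channels, out_channels, in_h, in_w,
                 kernel_size=3, stride=1, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.out_h = (in_h + 2 * padding - kernel_size) // stride + 1
        self.out_w = (in_w + 2 * padding - kernel_size) // stride + 1
        win = in_channels * kernel_size ** 2
        n = self.out_h * self.out_w
        # Assumed initialisation: scaled Gaussian, zero bias per output location
        self.weight = nn.Parameter(torch.randn(win, n, out_channels) / win ** 0.5)
        self.bias = nn.Parameter(torch.zeros(out_channels, self.out_h, self.out_w))

    def forward(self, x):
        # (b, win, n): one flattened receptive field per output location
        cols = F.unfold(x, self.kernel_size, stride=self.stride, padding=self.padding)
        # Location-specific filters, no batch index on the weights
        y = torch.einsum('bcn,cnd->bdn', cols, self.weight)
        return y.reshape(x.shape[0], -1, self.out_h, self.out_w) + self.bias


# Sanity check: identical filters at every location reduce to a normal conv.
torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
layer = LocallyConnected2d(3, 4, 8, 8, kernel_size=3, stride=2, padding=1)
kernel = torch.randn(4, 3, 3, 3)  # (out_channels, in_channels, kH, kW)
with torch.no_grad():
    shared = kernel.reshape(4, -1).t()  # (win, d), same patch order as F.unfold
    layer.weight.copy_(shared.unsqueeze(1).expand_as(layer.weight))
y = layer(x)
ref = F.conv2d(x, kernel, stride=2, padding=1)
print(y.shape, torch.allclose(y, ref, atol=1e-4))
```

Once the filters differ per location the layer no longer matches any single convolution; that weight sharing is exactly what the extra `n` dimension gives up, at the cost of `n` times more parameters than the equivalent `nn.Conv2d`.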