Skip to content

Instantly share code, notes, and snippets.

@gngdb
Created September 17, 2019 16:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gngdb/52ada0a479bb74058eddcf480f191a45 to your computer and use it in GitHub Desktop.
Save gngdb/52ada0a479bb74058eddcf480f191a45 to your computer and use it in GitHub Desktop.
Pointwise convolution in PyTorch without using conv2d.
import torch
from torch.nn.functional import conv2d
def pointwise(X, W, bias=None):
    """Compute a 1x1 ("pointwise") convolution without calling ``conv2d``.

    Every spatial location of every example becomes one row of a
    ``(n*h*w, c_in)`` matrix. That matrix is multiplied by the transposed
    ``(c_out, c_in)`` weight matrix, and the result is re-packed into
    ``(n, c_out, h, w)`` layout.

    Args:
        X: input tensor of shape (n, c_in, h, w).
        W: weight tensor of shape (c_out, c_in, 1, 1).
        bias: optional tensor with c_out elements, added to each output
            channel (mirrors the ``bias`` argument of ``conv2d``).

    Returns:
        Tensor of shape (n, c_out, h, w).

    Raises:
        RuntimeError: if W is not (c_out, c_in, 1, 1) with c_in matching X.
    """
    n, c_in, h, w = X.size()  # (n examples, c_in channels, height, width)
    c_out, c_in_w, kh, kw = W.size()  # expected (c_out, c_in, 1, 1)
    # Validate up front instead of letting a reshape fail opaquely; RuntimeError
    # keeps the exception type the original view-based code raised.
    if (c_in_w, kh, kw) != (c_in, 1, 1):
        raise RuntimeError(
            f"expected W of shape ({c_out}, {c_in}, 1, 1) for input with "
            f"{c_in} channels, got {tuple(W.size())}"
        )
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
    W_mat = W.reshape(c_out, c_in)  # squeeze the size-1 kernel dims
    # Flatten spatial dims, move channels last, then stack all locations as rows.
    K = X.reshape(n, c_in, h * w).permute(0, 2, 1).reshape(n * h * w, c_in)
    Y = torch.mm(K, W_mat.t())  # (n*h*w, c_out)
    # Re-pack: rows back to (n, h*w, c_out), channels back to dim 1.
    Y = Y.reshape(n, h * w, c_out).permute(0, 2, 1).reshape(n, c_out, h, w)
    if bias is not None:
        Y = Y + bias.reshape(1, c_out, 1, 1)  # broadcast over n, h, w
    return Y
if __name__ == '__main__':
    # Self-check: compare pointwise() against the reference conv2d on random data.
    torch.manual_seed(0)  # fixed seed so failures are reproducible
    X = torch.randn(4, 3, 4, 4)
    W = torch.randn(16, 3, 1, 1)
    expected = conv2d(X, W)
    actual = pointwise(X, W)
    error = torch.abs(expected - actual)
    print(error.max())
    # Explicit raise rather than `assert`, which is stripped under `python -O`;
    # the relative tolerance allows for float32 rounding differences between
    # torch.mm and conv2d.
    if not torch.allclose(actual, expected, rtol=1e-5, atol=1e-6):
        raise AssertionError(
            f"pointwise() disagrees with conv2d: max abs error {error.max().item()}"
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment