Skip to content

Instantly share code, notes, and snippets.

@sumanmichael
Created June 3, 2021 13:35
Show Gist options
  • Star 10 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sumanmichael/4de9dee93f972d47c80c4ade8e149ea6 to your computer and use it in GitHub Desktop.
Save sumanmichael/4de9dee93f972d47c80c4ade8e149ea6 to your computer and use it in GitHub Desktop.
PyTorch Conv2d equivalent of TensorFlow tf.nn.conv2d(..., padding='SAME')
import tensorflow as tf
import torch
from torch import nn
import numpy as np
from functools import reduce
from operator import __add__
class Conv2dSamePadding(nn.Conv2d):
    """``nn.Conv2d`` that zero-pads its input like TensorFlow ``padding='SAME'``.

    TensorFlow's SAME padding yields an output of spatial size
    ``ceil(in_size / stride)``. The total padding along an axis is
    ``max((out_size - 1) * stride + (kernel - 1) * dilation + 1 - in_size, 0)``
    and, when that total is odd, the extra row/column goes on the
    bottom/right. PyTorch's symmetric ``padding`` argument cannot express
    that asymmetry, so the input is padded explicitly before the convolution.

    All constructor arguments are forwarded unchanged to ``nn.Conv2d``.
    Construct it with ``padding=0``: ``_conv_forward`` still applies the
    module's own ``padding`` in addition to the SAME padding added here.
    """

    def __init__(self, *args, **kwargs):
        super(Conv2dSamePadding, self).__init__(*args, **kwargs)
        # With stride 1 the SAME padding is (kernel - 1) * dilation per axis,
        # independent of the input size, so it is built once here.
        # nn.ZeroPad2d takes (left, right, top, bottom): width comes first.
        (k_h, k_w), (d_h, d_w) = self.kernel_size, self.dilation
        self.zero_pad_2d = nn.ZeroPad2d(
            self._split_padding((k_w - 1) * d_w)
            + self._split_padding((k_h - 1) * d_h)
        )

    @staticmethod
    def _split_padding(total):
        """Split ``total`` into (before, after); the odd extra goes after (TF rule)."""
        before = total // 2
        return before, total - before

    @classmethod
    def _same_padding(cls, in_size, kernel, stride, dilation):
        """Return TF-SAME (before, after) padding for one spatial axis."""
        effective_kernel = (kernel - 1) * dilation + 1
        out_size = -(-in_size // stride)  # ceil(in_size / stride)
        total = max((out_size - 1) * stride + effective_kernel - in_size, 0)
        return cls._split_padding(total)

    def forward(self, input):
        """Pad ``input`` TF-SAME style, then run the convolution."""
        if all(s == 1 for s in self.stride):
            padded = self.zero_pad_2d(input)
        else:
            # For stride > 1 the required padding depends on the input size,
            # so it is recomputed from the last two (spatial) dimensions.
            in_h, in_w = input.shape[-2:]
            pad_h = self._same_padding(
                in_h, self.kernel_size[0], self.stride[0], self.dilation[0])
            pad_w = self._same_padding(
                in_w, self.kernel_size[1], self.stride[1], self.dilation[1])
            # F.pad pads the last dimension first: (left, right, top, bottom).
            padded = nn.functional.pad(input, pad_w + pad_h)
        return self._conv_forward(padded, self.weight, self.bias)
# Let's test it against TensorFlow on random data.
# TF layouts: input is NHWC, filter is (kH, kW, in_channels, out_channels).
# A fixed seed keeps any failure reproducible.
rng = np.random.default_rng(0)
val = rng.random((1, 4, 4, 128), dtype=np.float32)
weights = rng.random((2, 2, 128, 512), dtype=np.float32)
tf_in = tf.constant(val)
v = tf.Variable(weights)
tf_out = tf.nn.conv2d(tf_in, v, strides=(1, 1, 1, 1), padding='SAME').numpy()
# PyTorch layouts: input is NCHW, weight is (out_channels, in_channels, kH, kW).
pt_in = torch.tensor(val).permute(0, 3, 1, 2)
pt_l = Conv2dSamePadding(128, 512, 2, 1, 0, bias=False)
pt_l.weight = torch.nn.Parameter(torch.tensor(weights).permute(3, 2, 0, 1))
pt_out = pt_l(pt_in)
# Back to NHWC so the two outputs can be compared element-wise.
pt_out = pt_out.permute(0, 2, 3, 1).detach().numpy()
# Same tolerance as np.allclose(tf_out, pt_out, atol=1e-7) (default rtol=1e-5).
# Unlike a bare `assert`, this is not stripped under `python -O`, also
# checks that the shapes match, and reports the mismatching elements.
np.testing.assert_allclose(tf_out, pt_out, rtol=1e-5, atol=1e-7)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment