Skip to content

Instantly share code, notes, and snippets.

@bhpfelix
Created June 25, 2019 10:52
Show Gist options
  • Save bhpfelix/b60e218a645964dcdd8852b5b46cdf89 to your computer and use it in GitHub Desktop.
Save bhpfelix/b60e218a645964dcdd8852b5b46cdf89 to your computer and use it in GitHub Desktop.
Comparing TensorFlow and PyTorch operations (AvgPool, Conv2d)
import numpy as np
np.random.seed(0)
import torch
import torch.nn as nn
import tensorflow as tf
import matplotlib.pyplot as plt
# TF-Slim is only available via tf.contrib, so this script requires TensorFlow 1.x.
slim = tf.contrib.slim

# Random NHWC input (batch 1, 41x41, 1 channel); x is reused by the PyTorch part.
x = np.random.randn(1, 41, 41, 1)
tf_input = tf.convert_to_tensor(x, dtype=tf.float32)

# 3x3 average pool with stride 1 and 'SAME' padding keeps the 41x41 spatial size.
y = slim.avg_pool2d(tf_input, [3, 3], stride=1, padding='SAME')
# Uncomment to additionally compare a 3x3 all-ones convolution with dilation 12:
# y = slim.conv2d(y, 1, [3, 3], rate=12, weights_initializer=tf.ones_initializer,
#                 padding='SAME', activation_fn=None, normalizer_fn=None)

with tf.Session() as sess:
    # Only needed when the conv2d line above is enabled (it creates variables).
    sess.run(tf.global_variables_initializer())
    # Fetch just the pooled result; the input tensor does not need evaluating.
    y_value = sess.run(y)

plt.matshow(y_value.squeeze())
plt.colorbar()
plt.title('tf')
# PyTorch expects NCHW, so move the channel axis of the NHWC array x to dim 1.
pt_input = torch.from_numpy(x.transpose(0, 3, 1, 2)).float()

# count_include_pad=False leaves the zero padding out of the divisor at the
# borders. Without it, border values differ from TF's 'SAME' avg_pool.
pt_avg = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

# Counterpart of the commented-out slim.conv2d (3x3 kernel, rate/dilation 12).
# The dilated kernel spans 12 * (3 - 1) + 1 = 25 pixels, so padding=12 keeps
# the 41x41 size, as TF's 'SAME' padding does at stride 1.
pt_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=12, dilation=12)
# All-ones weights and zero bias, mirroring the TF ones initializer. nn.init
# fills the existing parameters in place instead of rebinding `.data`.
nn.init.ones_(pt_conv.weight)
nn.init.zeros_(pt_conv.bias)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    pt_out = pt_avg(pt_input)
    # pt_out = pt_conv(pt_out)  # enable together with the TF conv2d line
# Back to NHWC so the result lines up with the TensorFlow output.
pt_out = pt_out.numpy().transpose((0, 2, 3, 1))
print(pt_out.shape)

plt.matshow(pt_out.squeeze())
plt.colorbar()
plt.title('pt')
# Element-wise absolute error between the TensorFlow and PyTorch outputs.
diff = np.abs(y_value.squeeze() - pt_out.squeeze())
# Report the worst-case error numerically. The heatmap alone is hard to judge
# when the differences are at floating-point rounding level.
print('max abs diff:', diff.max())
plt.matshow(diff)
plt.colorbar()
plt.title('diff')
plt.show()
@bhpfelix
Copy link
Author

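You must set `count_include_pad=False` on PyTorch's `AvgPool2d` to match TensorFlow's default behavior. With `padding='SAME'`, TensorFlow's average pooling divides only by the number of in-bounds elements. PyTorch's default (`count_include_pad=True`) also counts the zero padding, so the two results differ along the image borders.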

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment