Created May 17, 2018 14:57
Save nishi-t/f1c283f176b073b038338dc2ff044ecf to your computer and use it in GitHub Desktop.
Testing AvgPooling with count_include_pad
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import torch

import tvm
import topi
from topi.util import get_const_tuple
def verify_pool(n, ic, ih, kh, sh, padding, ceil_mode, count_include_pad=True):
    """Check TOPI average pooling (followed by ReLU) against torch.nn.AvgPool2d.

    Builds a square-input, square-kernel ``topi.nn.pool`` graph with
    ``pool_type='avg'`` and a ReLU on top. It runs the graph on each target in
    ``['cuda', 'llvm']`` that is enabled, and compares the result with a
    PyTorch reference computed on the same random input.

    Parameters
    ----------
    n, ic : int
        Batch size and channel count of the (n, ic, ih, iw) input.
    ih : int
        Input height; the input width ``iw`` is set equal to it.
    kh : int
        Kernel height; the kernel width ``kw`` is set equal to it.
    sh : int
        Stride along height; the stride along width ``sw`` is set equal to it.
    padding : list of int
        Passed unchanged to both ``topi.nn.pool`` and ``torch.nn.AvgPool2d``.
        The callers pass two values, presumably [pad_h, pad_w]. TODO: confirm
        against the topi.nn.pool signature of the TVM version in use.
    ceil_mode : bool
        Forwarded to both implementations (output size uses ceil vs floor).
    count_include_pad : bool
        Forwarded to both implementations (whether padded cells count toward
        the averaging denominator).

    Raises
    ------
    AssertionError
        If a target's output differs from the reference beyond ``rtol=1e-5``.
    """
    # Only square inputs/kernels/strides are exercised by this test.
    iw, kw, sw = ih, kh, sh

    A = tvm.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                     pool_type='avg', ceil_mode=ceil_mode,
                     count_include_pad=count_include_pad)
    B = topi.nn.relu(B)
    dtype = A.dtype

    a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)

    # PyTorch reference. Use kw/sw (not kh/sh twice) so the reference keeps
    # matching the TVM graph if the square-only restriction is ever lifted.
    avg_pool = torch.nn.AvgPool2d(kernel_size=(kh, kw), stride=(sh, sw),
                                  padding=padding, ceil_mode=ceil_mode,
                                  count_include_pad=count_include_pad)
    # from_numpy keeps a_np's dtype, whereas torch.Tensor() would cast it to
    # the default float type. no_grad avoids building an autograd graph.
    with torch.no_grad():
        b_torch_np = avg_pool(torch.from_numpy(a_np)).numpy()
    # The TVM graph applies ReLU after pooling; mirror it in the reference.
    b_torch_np = np.maximum(b_torch_np, 0.0)

    def check_device(device):
        """Build and run the pooling graph on ``device``; skip if unavailable."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_pool(B)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_torch_np, rtol=1e-5)

    for device in ['cuda', 'llvm']:
        check_device(device)
def test_pool():
    """Run verify_pool over padded and unpadded cases, with and without
    count_include_pad.

    Each case tuple is
    (n, ic, ih, kh, sh, padding, ceil_mode, count_include_pad).
    """
    cases = [
        (1, 256, 32, 2, 2, [0, 0], False, True),
        (1, 256, 31, 4, 4, [1, 2], False, True),
        (1, 256, 32, 4, 4, [1, 2], False, False),
        (1, 256, 31, 6, 6, [3, 3], False, False),
        (1, 256, 31, 6, 6, [0, 0], False, False),
    ]
    for case in cases:
        verify_pool(*case)
# Allow running this file directly as a script, outside a test runner.
if __name__ == "__main__":
    test_pool()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.