@nishi-t (created May 17, 2018 14:57)
Testing AvgPooling with count_include_pad
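The script below checks TVM's topi.nn.pool average pooling against torch.nn.AvgPool2d for both settings of count_include_pad. The flag only matters where the pooling window overlaps the zero padding: with count_include_pad=True the window sum is divided by the full kernel area, so padded zeros drag the average down, while with count_include_pad=False it is divided by the number of real (unpadded) input elements in the window. For example, a 2x2 window sitting on the corner of a 1-padded all-ones input sums to 1, so True yields 1/4 = 0.25 and False yields 1/1 = 1.0.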
import numpy as np
import torch
import tvm
import topi
from topi.util import get_const_tuple


def verify_pool(n, ic, ih, kh, sh, padding, ceil_mode, count_include_pad=True):
    # Square input, kernel, and stride.
    iw = ih
    kw = kh
    sw = sh

    # TVM compute: avg pool followed by relu. The relu is numerically a
    # no-op here (the uniform [0, 1) input keeps every average nonnegative),
    # but it exercises schedule_pool with a fused elementwise op.
    A = tvm.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                     pool_type='avg', ceil_mode=ceil_mode,
                     count_include_pad=count_include_pad)
    B = topi.nn.relu(B)
    dtype = A.dtype
    a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)

    # PyTorch reference result for the same pooling configuration.
    avg_pool = torch.nn.AvgPool2d(kernel_size=(kh, kw), stride=(sh, sw),
                                  padding=padding, ceil_mode=ceil_mode,
                                  count_include_pad=count_include_pad)
    b_torch_np = avg_pool(torch.Tensor(a_np)).numpy()

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_pool(B)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_torch_np, rtol=1e-5)

    for device in ['cuda', 'llvm']:
        check_device(device)


def test_pool():
    # (n, ic, ih, kh, sh, padding, ceil_mode, count_include_pad)
    verify_pool(1, 256, 32, 2, 2, [0, 0], False, True)
    verify_pool(1, 256, 31, 4, 4, [1, 2], False, True)
    verify_pool(1, 256, 32, 4, 4, [1, 2], False, False)
    verify_pool(1, 256, 31, 6, 6, [3, 3], False, False)
    verify_pool(1, 256, 31, 6, 6, [0, 0], False, False)


if __name__ == "__main__":
    test_pool()
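
For reference, here is a minimal NumPy sketch of the count_include_pad semantics being tested, stripped down to stride 1 and symmetric zero padding. The helper name avg_pool2d_ref and these simplifications are illustrative assumptions, not part of the TVM or PyTorch APIs exercised above.

import numpy as np

def avg_pool2d_ref(x, k, pad, count_include_pad=True):
    # Illustrative sketch only (stride 1, symmetric zero padding, pad < k);
    # not TVM's or PyTorch's implementation.
    xp = np.pad(x, pad, mode='constant')  # zero-pad all four sides
    oh = xp.shape[0] - k + 1
    ow = xp.shape[1] - k + 1
    out = np.empty((oh, ow))
    for i in range(oh):
        for j in range(ow):
            window = xp[i:i + k, j:j + k]
            if count_include_pad:
                # Divide by the full kernel area, padded zeros included.
                out[i, j] = window.sum() / (k * k)
            else:
                # Divide only by positions overlapping the unpadded input.
                h = min(i + k, pad + x.shape[0]) - max(i, pad)
                w = min(j + k, pad + x.shape[1]) - max(j, pad)
                out[i, j] = window.sum() / (h * w)
    return out

x = np.ones((3, 3))
print(avg_pool2d_ref(x, k=2, pad=1, count_include_pad=True)[0, 0])   # 0.25
print(avg_pool2d_ref(x, k=2, pad=1, count_include_pad=False)[0, 0])  # 1.0

Running it on a 1-padded all-ones input reproduces the corner values from the example above: 0.25 with count_include_pad=True and 1.0 with False.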