Skip to content

Instantly share code, notes, and snippets.

@pinzhenx
Last active December 7, 2020 06:41
Show Gist options
  • Save pinzhenx/8f62d5076bb04f0fd2108380b22dfbaa to your computer and use it in GitHub Desktop.
Save pinzhenx/8f62d5076bb04f0fd2108380b22dfbaa to your computer and use it in GitHub Desktop.
import torch
import pandas as pd
def profile(m, x, nwarm=10, nrun=300):
    """Profile repeated calls of ``m(x)`` and return the first op's CPU time.

    ``m(x)`` is called ``nwarm`` times outside the profiler as warm-up, then
    ``nrun`` times under ``torch.autograd.profiler.profile``.

    Args:
        m: Callable invoked as ``m(x)``.
        x: Input handed to ``m`` on every call.
        nwarm: Number of un-profiled warm-up calls.
        nrun: Number of profiled calls.

    Returns:
        The ``cpu_time`` of the first entry of ``prof.key_averages()``,
        divided by 1000.
        NOTE(review): presumably this converts microseconds to milliseconds
        -- confirm against the profiler documentation.

    Raises:
        IndexError: If the profiler recorded no events (``key_averages()``
            is empty).
    """
    def call_repeatedly(times):
        for _ in range(times):
            m(x)

    call_repeatedly(nwarm)
    with torch.autograd.profiler.profile(True) as prof:
        call_repeatedly(nrun)

    # Only the first aggregated entry is reported; this relies on the
    # profiler listing the op of interest first.
    return prof.key_averages()[0].cpu_time / 1000
# Convolution problem descriptors, one list per benchmark case.
# A 12-value row is unpacked by the benchmark loop as
#   N, C, H, W, M, kh, kw, str_h, str_w, pad_h, pad_w, g
# and any other row (here: 16 values) as
#   N, C, D, H, W, M, kd, kh, kw, str_d, str_h, str_w, pad_d, pad_h, pad_w, g
S = [
    # PR #40610 (Disable special cases)
    [1, 1024, 14, 14, 2048, 1, 1, 2, 2, 0, 0, 1],
    [1, 512, 28, 28, 512, 3, 3, 2, 2, 1, 1, 32],
    [1, 256, 56, 56, 256, 3, 3, 2, 2, 1, 1, 32],
    [1, 256, 56, 56, 256, 1, 1, 1, 1, 0, 0, 1],
    [1, 128, 56, 56, 256, 1, 1, 1, 1, 0, 0, 1],
    [1, 256, 56, 56, 512, 1, 1, 2, 2, 0, 0, 1],
    [1, 256, 56, 56, 128, 1, 1, 1, 1, 0, 0, 1],
    [1, 1024, 7, 7, 2048, 1, 1, 1, 1, 0, 0, 1],
    [1, 2048, 7, 7, 1024, 1, 1, 1, 1, 0, 0, 1],
    [1, 1024, 14, 14, 1024, 3, 3, 2, 2, 1, 1, 32],
    [1, 1024, 14, 14, 512, 1, 1, 1, 1, 0, 0, 1],
    [1, 256, 28, 28, 256, 3, 3, 1, 1, 1, 1, 32],
    [1, 3, 224, 224, 64, 7, 7, 2, 2, 3, 3, 1],
    [1, 128, 56, 56, 128, 3, 3, 1, 1, 1, 1, 32],
    [1, 1024, 7, 7, 1024, 3, 3, 1, 1, 1, 1, 32],
    [1, 512, 28, 28, 512, 1, 1, 1, 1, 0, 0, 1],
    [1, 512, 28, 28, 256, 1, 1, 1, 1, 0, 0, 1],
    [1, 256, 28, 28, 512, 1, 1, 1, 1, 0, 0, 1],
    [1, 512, 28, 28, 1024, 1, 1, 2, 2, 0, 0, 1],
    [1, 64, 56, 56, 128, 1, 1, 1, 1, 0, 0, 1],
    [1, 64, 56, 56, 256, 1, 1, 1, 1, 0, 0, 1],
    [1, 512, 14, 14, 512, 3, 3, 1, 1, 1, 1, 32],
    [1, 512, 14, 14, 1024, 1, 1, 1, 1, 0, 0, 1],
    [1, 1024, 14, 14, 1024, 1, 1, 1, 1, 0, 0, 1],
    # Issue #35937 (2x slower)
    [1, 512, 4, 4, 512, 3, 3, 1, 1, 1, 1, 1],
    # PR #46675 (fix heuristics)
    [25, 3, 48, 320, 64, 7, 7, 1, 1, 0, 0, 1],
    [1, 3, 384, 288, 64, 7, 7, 1, 1, 0, 0, 1],
    [1, 3, 16, 224, 224, 32, 1, 7, 7, 1, 1, 1, 0, 0, 0, 1],
    [1, 3, 4, 112, 112, 64, 3, 7, 7, 1, 1, 1, 0, 0, 0, 1],
    [1, 256, 8, 14, 14, 256, 3, 3, 3, 1, 1, 1, 0, 0, 0, 1],
]
# Benchmark every problem in S twice -- once with MKLDNN enabled and once with
# it disabled -- verify that both runs produce matching outputs, and collect
# the timings in a table that is printed at the end.
df = pd.DataFrame(columns=['src', 'wei', 'str', 'pad', 'g', 'mkldnn', 'thnn', 'result'])
for P in S:
    print(P)
    if len(P) == 12:
        # 12 values describe a 2-D convolution.
        N, C, H, W, M, kh, kw, str_h, str_w, pad_h, pad_w, g = P
        xsize = [N, C, H, W]
        ksize = [kh, kw]
        strides = [str_h, str_w]
        pads = [pad_h, pad_w]
        conv = torch.nn.Conv2d
    else:
        # Anything else is unpacked as a 16-value 3-D convolution.
        N, C, D, H, W, M, kd, kh, kw, str_d, str_h, str_w, pad_d, pad_h, pad_w, g = P
        xsize = [N, C, D, H, W]
        ksize = [kd, kh, kw]
        strides = [str_d, str_h, str_w]
        pads = [pad_d, pad_h, pad_w]
        conv = torch.nn.Conv3d
    x = torch.rand(xsize)
    m = conv(C, M, ksize, stride=strides, padding=pads, groups=g, bias=True)
    print('src', xsize)
    print('wei', [M, C] + ksize)
    print('str', strides)
    print('pad', pads)
    print('g', g)
    # Use the public context manager instead of torch._C._set_mkldnn_enabled:
    # it restores the previous process-wide setting on exit, so MKLDNN is not
    # left disabled after the benchmark or when an exception is raised.
    with torch.backends.mkldnn.flags(enabled=True):
        mkldnn_t = profile(m, x)
        mkldnn_res = m(x)
    print('MKLDNN time =', mkldnn_t)
    with torch.backends.mkldnn.flags(enabled=False):
        thnn_t = profile(m, x)
        thnn_res = m(x)
    print('THNN time =', thnn_t)
    if mkldnn_t > thnn_t:
        print('\033[31mSLOW\033[0m\n')
        result = 'slow'
    else:
        print('\033[32mFAST\033[0m\n')
        result = 'fast'
    # Both runs use the same module and input, so their outputs must agree.
    assert torch.allclose(mkldnn_res, thnn_res, rtol=1e-5, atol=1e-5), \
        'MKLDNN and non-MKLDNN outputs differ for {}'.format(P)
    df.loc[len(df)] = [xsize, [M, C] + ksize, strides, pads, g, mkldnn_t, thnn_t, result]
print(df.to_string(index=False))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment