Created
June 6, 2019 18:27
Faster ResNet CPU inference with MKLDNN + PyTorch 1.1
$ pip install torch==1.1.0 torchvision==0.3.0
$ OMP_NUM_THREADS=1 ipython
Python 3.6.7 (default, Oct 22 2018, 11:32:17)
In [1]: import torch
   ...: from torchvision.models import resnet50
In [2]: def forward(m, x):
   ...:     """ resnet without average pooling """
   ...:     x = m.conv1(x)
   ...:     x = m.bn1(x)
   ...:     x = m.relu(x)
   ...:     x = m.maxpool(x)
   ...:     x = m.layer1(x)
   ...:     x = m.layer2(x)
   ...:     x = m.layer3(x)
   ...:     x = m.layer4(x)
   ...:     return x
   ...:
In [3]: x = torch.randn(1, 3, 320, 960)
In [4]: model = resnet50()
In [5]: model.eval();
In [6]: model_mkldnn = resnet50()
In [7]: model_mkldnn._apply(lambda x: x.to_mkldnn() if x.dtype == torch.float32 else x);
In [8]: model_mkldnn.eval();
In [9]: %timeit with torch.no_grad(): _ = forward(model, x)
616 ms ± 683 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [10]: %timeit with torch.no_grad(): _ = forward(model_mkldnn, x.to_mkldnn()).to_dense()
485 ms ± 7.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
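
To put the two timings above in perspective, here is a minimal sketch that turns them into a speedup factor, and that shows one way to reproduce a "mean ± std. dev. of 7 runs, 1 loop each" measurement outside IPython. The helper names `speedup` and `time_ms` are hypothetical and not part of the original session. `time_ms` only approximates what `%timeit` reports, and the workload in the usage lines is a stand-in, not the ResNet forward pass.

```python
import statistics
import timeit


def speedup(baseline_ms, optimized_ms):
    """Return how many times faster the optimized run is than the baseline."""
    return baseline_ms / optimized_ms


def time_ms(fn, repeat=7, number=1):
    """Time fn() roughly the way %timeit does: `repeat` runs of `number` loops.

    Returns (mean, std) of the per-loop time in milliseconds.
    This is a hypothetical helper, an approximation of %timeit, not its implementation.
    """
    totals = timeit.repeat(fn, repeat=repeat, number=number)
    per_loop_ms = [t / number * 1000.0 for t in totals]
    return statistics.mean(per_loop_ms), statistics.pstdev(per_loop_ms)


# Figures taken from the session above: 616 ms (dense) vs 485 ms (MKLDNN).
factor = speedup(616, 485)
reduction = 1 - 485 / 616
print(f"speedup: {factor:.2f}x, time reduction: {reduction:.1%}")

# Stand-in workload; in the session this would be the forward() call under torch.no_grad().
mean_ms, std_ms = time_ms(lambda: sum(range(100_000)))
print(f"{mean_ms:.3f} ms ± {std_ms:.3f} ms per loop")
```

With the numbers from the session, the MKLDNN path is about 1.27x faster, or about 21% less time per forward pass. That holds for this single-thread setup (`OMP_NUM_THREADS=1`) and this input size, and may differ elsewhere.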