Last active
April 24, 2019 12:50
-
-
Save pgmmpk/b64901e3bd77ed58c00be83bd170d982 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Demonstrates what seems to be a performance bug (or puzzle) in PyTorch:
conv2d operation on trained weights is 20-50x slower than using random weights
'''
import contextlib
import os
import time
import urllib.request

import torch
from torch import nn
def download_as(url, filename):
    """Fetch *url* and store the response body as *filename*.

    A no-op when *filename* already exists, so repeated runs do not
    re-download the file.
    """
    if not os.path.exists(filename):
        print('Downloading', url, 'as', filename)
        payload = urllib.request.urlopen(url).read()
        with open(filename, 'wb') as out:
            out.write(payload)
@contextlib.contextmanager
def timeit(message=''):
    """Context manager that prints *message* and the wall-clock time spent in the body.

    The elapsed time is printed even if the body raises.
    """
    # perf_counter is monotonic and has the highest available resolution;
    # time.time() can jump backwards/forwards when the system clock is adjusted,
    # which would corrupt interval measurements.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(message, 'Time elapsed: ', (time.perf_counter() - start))
def printit(header, x):
    """Print *header* on its own line, then summary statistics of tensor *x*."""
    print(header)
    print('\tshape:', x.shape)
    print('\tdtype:', x.dtype)
    print('\tany nans?:', torch.isnan(x).any().item())
    # The four scalar statistics share one formatting path.
    for label, stat in (('min', x.min()),
                        ('max', x.max()),
                        ('mean', x.mean()),
                        ('std', x.std())):
        print('\t' + label + ':', stat.item())
# 18Mb file with trained weights:
download_as(
    'https://drive.google.com/uc?export=download&id=15Z6ACJKRep9Wjl_heyG1WS34Vsb_IUZn',
    'trained_weights_numpy.pkl'
)

# NOTE(review): unpickling data fetched from the network is unsafe if the
# source is untrusted -- pickle can execute arbitrary code on load.
import pickle

with open('trained_weights_numpy.pkl', 'rb') as f:
    z = pickle.load(f)

weight = torch.tensor(z)
printit('Trained weight', weight)

# Random convolution weights with the same layout as the trained ones:
# conv2d expects weight shaped (out_channels, in_channels, kH, kW).
weight_random = torch.zeros(512, 1024, 3, 3)
weight_random.normal_(std=0.001)
printit('Good (random) weight', weight_random)

# Random input batch: 1 sample, 1024 channels, 32x32 spatial.
x = torch.zeros(1, 1024, 32, 32)
x.normal_()
printit('Input vector', x)

# Same operation on the same input -- only the weights differ.
with timeit('Fast conv2d (random weights)'):
    torch.nn.functional.conv2d(x, weight_random)

with timeit('Slow conv2d (trained weights)'):
    torch.nn.functional.conv2d(x, weight)
Problem solved. See discussion on github
Let me summarize the problem and solution:
The problem is that floating-point operations on CPU can become very slow when the numbers are "denormal" (also called "subnormal"). This means the values are extremely small in magnitude — for single precision, smaller than roughly 1.2e-38. Such values are very unusual in practice.
Workaround is to force CPU to treat these numbers as zeroes. At the beginning of your code add torch.set_flush_denormal(True)
(this may not work on some older Intel CPUs though).
Alternative workaround is to "manually" remove denormals from your weights:
weight = ...
mask = (weight.abs() >= 1.e-32).float()  # keep normal-sized values, zero out the tiny (denormal) ones
weight = weight * mask
Better solution is to find out why your training ended up with such weird weights, and fix it!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Here is the run log: