@ptrblck / Created June 27, 2019
import torch
import torch.nn as nn
import pandas as pd
import time
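# Enable the cuDNN autotuner: with benchmark=True cuDNN times several
# convolution algorithms for each new input shape and caches the fastest one;
# deterministic=False also allows non-deterministic algorithms to be selected.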
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
print('PyTorch version {}'.format(torch.__version__))
print('CUDA version {}'.format(torch.version.cuda))
print('cuDNN version {}'.format(torch.backends.cudnn.version()))
print('cuDNN deterministic {}'.format(torch.backends.cudnn.deterministic))
print('cuDNN benchmark {}'.format(torch.backends.cudnn.benchmark))
# Define shape iterations
batch_sizes = [1, 8, 16, 32, 64, 128]
in_channels = [32*2**i for i in range(6)]
heights = [112//(2**i) for i in range(5)]
strides = [1, 2]
kernel_sizes = [1, 3]
dgrad = True
nb_warmup_iters = 50
nb_iters = 1000
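# Warmup iterations are discarded so the cuDNN autotuner has already picked
# its algorithms before timing starts; each reported time is the mean over
# nb_iters timed runs.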
# store results in DataFrame
results = pd.DataFrame()
columns = ['cudnn',
           'batch_size',
           'in_channels',
           'out_channels',
           'w_h',
           'kernel_size',
           'stride',
           'pad',
           'groups',
           'time_fwd',
           'time_bwd',
           'time_all']
cudnn_version = torch.backends.cudnn.version()
for batch_size in batch_sizes:
    for c in in_channels:
        for h in heights:
            for s in strides:
                for k in kernel_sizes:
                    if batch_size == 128 and c == 1024 and h == 112:
                        # this configuration runs out of memory, skip it
                        continue
                    w = h
                    x = torch.randn(batch_size, c, h, w, device='cuda', dtype=torch.half, requires_grad=dgrad)
                    pad = k // 2
                    conv = nn.Conv2d(
                        in_channels=c,
                        out_channels=c,
                        kernel_size=k,
                        stride=s,
                        padding=pad,
                        groups=c,  # depthwise: groups == in_channels == out_channels
                        bias=False).half().to('cuda')
                    print('Testing [N, C, H, W]=[{}, {}, {}, {}], kH/kW={}, stride={}, pad={}'.format(
                        *x.size(), k, s, pad))
                    # Perform some dummy iterations to warm up cudnn.benchmark
                    for _ in range(nb_warmup_iters):
                        output = conv(x)
                    # Perform warmup for the backward pass
                    g0 = torch.rand_like(output)
                    for _ in range(nb_warmup_iters):
                        output = conv(x)
                        output.backward(g0)
                    # Profile forward pass
                    torch.cuda.synchronize()
                    start = time.time()
                    for _ in range(nb_iters):
                        output = conv(x)
                    torch.cuda.synchronize()
                    end = time.time()
                    fwd_time = (end - start) / nb_iters
                    # Profile forward + backward pass together; the backward time
                    # is recovered below by subtracting the forward-only time
                    torch.cuda.synchronize()  # Probably not necessary here
                    start = time.time()
                    for _ in range(nb_iters):
                        output = conv(x)
                        x.grad = None
                        conv.weight.grad = None
                        output.backward(g0)
                    torch.cuda.synchronize()
                    end = time.time()
                    all_time = (end - start) / nb_iters
                    bwd_time = all_time - fwd_time
                    tmp_df = pd.DataFrame(
                        [[cudnn_version, batch_size, c, c, h, k, s, pad, c, fwd_time, bwd_time, all_time]],
                        columns=columns)
                    # DataFrame.append was removed in pandas 2.0; concat works on old and new versions
                    results = pd.concat([results, tmp_df], ignore_index=True)
results.to_csv('cudnn_v100_1.4.csv', index=False)
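# Quick sanity check (a minimal sketch): reload the CSV written above and print
# the ten slowest configurations; the filename and column names are the ones
# hardcoded earlier in this script.
df = pd.read_csv('cudnn_v100_1.4.csv')
print(df.sort_values('time_all', ascending=False)
        .head(10)[['batch_size', 'in_channels', 'w_h', 'kernel_size', 'stride', 'time_all']])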