Created
January 9, 2018 12:47
-
-
Save db434/74fcfcf49d52aa08c2335ee916bac9f4 to your computer and use it in GitHub Desktop.
Debugging suspiciously slow grouped convolution in PyTorch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import subprocess | |
import sys | |
import torch | |
def _str_to_bool(text):
    """Convert a command-line string into a boolean.

    ``argparse``'s ``type=bool`` simply calls ``bool()`` on the raw string, so
    every non-empty value (including "False" and "0") would become True.
    This converter interprets the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    value = text.strip().lower()
    if value in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string is kept falsy, matching what bool("") returned before.
    if value in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(text))


def main():
    """Parse the command line and dispatch to the requested mode.

    Without ``--internal`` the script acts as the outer driver (``wrapper``).
    With ``--internal`` it performs the convolution that is being profiled
    (``internal``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=256)
    parser.add_argument("--img-size", type=int, default=32)
    parser.add_argument("--in-channels", type=int, default=128)
    parser.add_argument("--out-channels", type=int, default=128)
    parser.add_argument("--kernel-size", type=int, default=3)
    # Parsed with an explicit converter so "--groups False" really disables
    # grouping; the default (grouped convolution) is unchanged.
    parser.add_argument("--groups", type=_str_to_bool, default=True)
    parser.add_argument("--internal", action="store_true")
    args = parser.parse_args()

    if args.internal:
        internal(args)
    else:
        wrapper(args)
def internal(args):
    """Run the convolution that the external profiler is recording.

    Builds a ``Conv2d`` layer and an input batch from the parsed arguments,
    moves both to the GPU, runs one warm-up pass, and then runs a second pass
    with NVTX range emission enabled.

    Args:
        args: parsed command-line namespace; reads ``batch``, ``img_size``,
            ``in_channels``, ``out_channels``, ``kernel_size`` and ``groups``.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    # Profiling is switched on: do convolution.
    # Raise explicitly rather than assert, so the check survives `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to profile the convolution")
    torch.backends.cudnn.benchmark = True

    # When grouping is requested, use one group per input channel.
    groups = args.in_channels if args.groups else 1
    conv = torch.nn.Conv2d(args.in_channels, args.out_channels,
                           args.kernel_size, groups=groups)
    conv = conv.cuda()

    # Use defined random values: torch.Tensor(sizes) would hand back
    # uninitialised memory, making runs non-reproducible.
    data = torch.randn(args.batch, args.in_channels, args.img_size,
                       args.img_size)
    # Variable wrapping is kept for compatibility with older PyTorch releases.
    data = torch.autograd.Variable(data.cuda())

    with torch.cuda.profiler.profile():
        conv(data)  # Warmup CUDA memory allocator and profiler
        with torch.autograd.profiler.emit_nvtx():
            conv(data)
def wrapper(args):
    """Re-run this script under nvprof and print the collected profile.

    Switches profiling on, runs this script again with ``--internal`` added,
    then loads and prints the resulting trace. The trace file is removed
    afterwards, even when profiling or loading fails.

    Args:
        args: parsed command-line namespace (unused here; the original
            ``sys.argv`` is forwarded verbatim to the child process).

    Raises:
        subprocess.CalledProcessError: if the nvprof run exits with a
            non-zero status.
    """
    import shlex

    trace_file = "trace.prof"
    # Build an argument list instead of a shell string so that arguments
    # containing spaces or shell metacharacters are forwarded intact.
    # Use the interpreter running this script rather than whatever "python3"
    # happens to resolve to on PATH.
    command = [
        "nvprof", "--profile-from-start", "off", "-o", trace_file, "--",
        sys.executable or "python3",
    ] + sys.argv + ["--internal"]
    print(" ".join(shlex.quote(part) for part in command))

    try:
        # Stop here if profiling fails instead of trying to parse a missing
        # or partial trace.
        subprocess.run(command, check=True)
        # Now read the trace file.
        profile = torch.autograd.profiler.load_nvprof(trace_file)
        print(profile)
    finally:
        try:
            os.remove(trace_file)
        except FileNotFoundError:
            # nvprof never produced a trace; nothing to clean up.
            pass
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment