Skip to content

Instantly share code, notes, and snippets.

@db434
Created January 9, 2018 12:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save db434/74fcfcf49d52aa08c2335ee916bac9f4 to your computer and use it in GitHub Desktop.
Debugging suspiciously slow grouped convolution in PyTorch.
import argparse
import os
import subprocess
import sys
import torch
def main():
    """Parse command-line options and dispatch to the profiling wrapper or
    the profiled inner run.

    ``--internal`` is set when this script re-invokes itself under nvprof;
    in that case we run the convolution directly, otherwise we act as the
    wrapper that launches and interprets the profile.
    """
    def str_to_bool(value):
        # argparse's type=bool is a trap: bool("False") is True because any
        # non-empty string is truthy.  Interpret the text explicitly instead.
        return value.strip().lower() in ("1", "true", "yes", "y")

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=256)
    parser.add_argument("--img-size", type=int, default=32)
    parser.add_argument("--in-channels", type=int, default=128)
    parser.add_argument("--out-channels", type=int, default=128)
    parser.add_argument("--kernel-size", type=int, default=3)
    # Default behavior (grouped convolution on) is unchanged; passing
    # "--groups False" now actually disables grouping.
    parser.add_argument("--groups", type=str_to_bool, default=True)
    parser.add_argument("--internal", action="store_true")
    args = parser.parse_args()

    if args.internal:
        internal(args)
    else:
        wrapper(args)
def internal(args):
    """The profiled pass: build the convolution and run it while the CUDA
    and NVTX profilers are active.  Requires a CUDA-capable device.
    """
    assert torch.cuda.is_available()
    # Let cuDNN benchmark and select the fastest algorithm for this shape.
    torch.backends.cudnn.benchmark = True

    num_groups = args.in_channels if args.groups else 1
    layer = torch.nn.Conv2d(args.in_channels, args.out_channels,
                            args.kernel_size, groups=num_groups).cuda()

    batch = torch.Tensor(args.batch, args.in_channels, args.img_size,
                         args.img_size)
    batch = torch.autograd.Variable(batch.cuda())

    with torch.cuda.profiler.profile():
        # Warmup CUDA memory allocator and profiler
        layer(batch)
        with torch.autograd.profiler.emit_nvtx():
            layer(batch)
def wrapper(args):
    """Re-run this script under nvprof with profiling switched on, then
    load, print, and delete the resulting trace file.
    """
    # Build an argument list and run with shell=False: the original
    # string-concatenated command broke on arguments containing spaces or
    # shell metacharacters.  sys.executable re-invokes with the same
    # interpreter rather than whatever "python3" is on PATH.
    command = (["nvprof", "--profile-from-start", "off", "-o", "trace.prof",
                "--", sys.executable] + sys.argv + ["--internal"])
    print(" ".join(command))
    subprocess.run(command)

    # Now read the trace file; delete it even if parsing fails so repeated
    # runs don't pick up a stale trace.
    try:
        profile = torch.autograd.profiler.load_nvprof("trace.prof")
        print(profile)
    finally:
        os.remove("trace.prof")
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment