
@eqy
Created December 28, 2022 04:30
cuDNN depthwise conv example
import torch
import time

torch.backends.cudnn.benchmark = True

iters = 10

# Depthwise convs: groups == in_channels == 64. The 4th positional argument
# of Conv2d is stride, so both modules use kernel_size=3 and stride=3.
conv = torch.nn.Conv2d(64, 64, 3, 3, groups=64, dtype=torch.half, device='cuda')
convb = torch.nn.Conv2d(64, 64, 3, 3, groups=64, dtype=torch.bfloat16, device='cuda')
data = torch.randn(16, 64, 1024, 1024, dtype=torch.half, device='cuda')
datab = torch.randn(16, 64, 1024, 1024, dtype=torch.bfloat16, device='cuda')

# half
# warmup (with cudnn.benchmark=True the first call also runs algorithm selection)
out = conv(data)
torch.cuda.synchronize()
t1 = time.time()
for _ in range(iters):
    out = conv(data)
# CUDA kernels launch asynchronously; wait for them before reading the clock
torch.cuda.synchronize()
t2 = time.time()
print(f"half took {(t2-t1)/iters} per iteration")

# bfloat16
# warmup
outb = convb(datab)
torch.cuda.synchronize()
t1 = time.time()
for _ in range(iters):
    outb = convb(datab)
torch.cuda.synchronize()
t2 = time.time()
print(f"bfloat16 took {(t2-t1)/iters} per iteration")
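The warmup / synchronize / time pattern above is repeated once per dtype. Below is a minimal sketch of factoring it into a helper. `bench` and its `sync` parameter are hypothetical names, not part of PyTorch. The sketch uses `time.perf_counter` (a monotonic, high-resolution clock) instead of `time.time`, and it takes the synchronization function as an argument so it can run without a GPU.

```python
import time
from typing import Callable


def bench(fn: Callable[[], object], iters: int = 10, warmup: int = 1,
          sync: Callable[[], None] = lambda: None) -> float:
    """Return mean wall-clock seconds per call of fn (hypothetical helper).

    `sync` should block until queued asynchronous work has finished,
    e.g. torch.cuda.synchronize when fn launches CUDA kernels.
    """
    # Warmup calls are not timed (with cudnn.benchmark=True the first call
    # for a new input shape also runs algorithm selection).
    for _ in range(warmup):
        fn()
    sync()  # make sure warmup work is done before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # wait for asynchronous work before stopping the clock
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    # CPU-only demonstration: a list append stands in for a conv call.
    calls = []
    avg = bench(lambda: calls.append(1), iters=10, warmup=2)
    print(f"{len(calls)} calls, {avg:.3e} s per iteration")
```

With the objects from the script above, the half-precision measurement would become `bench(lambda: conv(data), iters, sync=torch.cuda.synchronize)`, and the bfloat16 one the same call with `convb(datab)`.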

cchan commented Dec 28, 2022

Changing 4 lines (replacing the module calls with torch.nn.functional.conv2d calls) produces slower times for both precisions and also produces the bfloat16 vs. float16 timing mismatch:

import torch
import time

torch.backends.cudnn.benchmark = True

iters = 10

conv = torch.nn.Conv2d(64, 64, 3, 3, groups=64, dtype=torch.half, device='cuda')
convb = torch.nn.Conv2d(64, 64, 3, 3, groups=64, dtype=torch.bfloat16, device='cuda')
data = torch.randn(16, 64, 1024, 1024, dtype=torch.half, device='cuda')
datab = torch.randn(16, 64, 1024, 1024, dtype=torch.bfloat16, device='cuda')

# half
# warmup
out = torch.nn.functional.conv2d(data, conv.weight, groups=64)
torch.cuda.synchronize()
t1 = time.time()
for _ in range(iters):
  out = torch.nn.functional.conv2d(data, conv.weight, groups=64)
torch.cuda.synchronize()
t2 = time.time()
print(f"half took {(t2-t1)/iters} per iteration")

# bfloat16
# warmup
outb = torch.nn.functional.conv2d(datab, convb.weight, groups=64)
torch.cuda.synchronize()
t1 = time.time()
for _ in range(iters):
  outb = torch.nn.functional.conv2d(datab, convb.weight, groups=64)
torch.cuda.synchronize()
t2 = time.time()
print(f"bfloat16 took {(t2-t1)/iters} per iteration")
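One likely contributor to the slower functional timings is the stride. `torch.nn.Conv2d(64, 64, 3, 3, ...)` passes its fourth positional argument as `stride`, so the module convolutions use stride 3. The `torch.nn.functional.conv2d` calls above use the default `stride=1` and no bias. The sketch below applies the output-size formula from the `torch.nn.Conv2d` documentation to estimate how much extra work that implies. It is a back-of-the-envelope count of multiply-accumulates only and ignores memory traffic. The helper names are hypothetical.

```python
def conv_out_size(size: int, kernel: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a convolution along one dimension
    (same formula as in the torch.nn.Conv2d docs)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def depthwise_macs(batch: int, channels: int, size: int, kernel: int,
                   stride: int) -> int:
    """Multiply-accumulates for a square, unpadded depthwise conv
    (groups == channels): each output element reads kernel*kernel
    inputs from a single channel."""
    out = conv_out_size(size, kernel, stride)
    return batch * channels * out * out * kernel * kernel


if __name__ == "__main__":
    # Shapes from the gist: batch 16, 64 channels, 1024x1024, 3x3 kernel.
    for stride in (3, 1):
        out = conv_out_size(1024, 3, stride)
        macs = depthwise_macs(16, 64, 1024, 3, stride)
        print(f"stride={stride}: output {out}x{out}, {macs:,} MACs")
    ratio = depthwise_macs(16, 64, 1024, 3, 1) / depthwise_macs(16, 64, 1024, 3, 3)
    print(f"stride 1 does {ratio:.2f}x the arithmetic of stride 3")
```

For these shapes the per-channel output is 341x341 with stride 3 versus 1022x1022 with stride 1, roughly 9x the arithmetic. To time like with like, the functional calls can forward the module's settings, e.g. `torch.nn.functional.conv2d(data, conv.weight, conv.bias, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups)`.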
