@madlag
Created September 30, 2020 17:35
PyTorch CUDA speed test for various data types, with and without AMP
#!/usr/bin/env python
# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/
# Written by Francois Fleuret <francois@fleuret.org>
# Modified by François Lagunas <francois.lagunas@m4x.org>
import time, torch

if torch.cuda.is_available():
    device = torch.device('cuda')
    sync = torch.cuda.synchronize
else:
    device = torch.device('cpu')
    sync = lambda: None  # no-op sync on CPU

# Matrix sizes: (d1 x d2) @ (d2 x d3)
d1, d2, d3 = 2048 * 8 * 4, 2048 * 4, 2048 * 4
iterations = 100

def test(prefix):
    for t in [torch.float32, torch.float16, torch.bfloat16]:
        try:
            a = torch.rand(d1, d2, device=device, dtype=t)
            b = torch.rand(d2, d3, device=device, dtype=t)
            sync()
            start_time = time.perf_counter()
            for i in range(iterations):
                c = torch.mm(a, b)
            sync()
            duration = time.perf_counter() - start_time
            nb_flop = float(iterations * d1 * d2 * d3 * 2)  # 1 multiply-and-add is 2 ops
            speed = nb_flop / duration
            for u in ['', 'K', 'M', 'G', 'T', 'P']:
                if speed < 1e3: break
                speed /= 1e3
            print(f'{prefix} {speed:.02f} {u}flops with {t} on {device}')
        except Exception:
            print(f'{prefix} {t} is not available on {device}')

test("AMP off")

with torch.cuda.amp.autocast():
    test("AMP on")
# Results on an RTX 3090
#AMP off 34.66 Tflops with torch.float32 on cuda
#AMP off 77.28 Tflops with torch.float16 on cuda
#AMP off 78.06 Tflops with torch.bfloat16 on cuda
#AMP on 73.69 Tflops with torch.float32 on cuda
#AMP on 76.48 Tflops with torch.float16 on cuda
#AMP on 74.64 Tflops with torch.bfloat16 on cuda
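
# For reference, a quick sanity check of the FLOP arithmetic used in test().
# This is only a sketch: the sizes simply mirror d1, d2, d3 above, and
# example_duration_s is an illustrative placeholder, not a measurement.
def flop_sanity_check():
    d1, d2, d3 = 2048 * 8 * 4, 2048 * 4, 2048 * 4  # 65536, 8192, 8192
    iterations = 100
    # One (d1 x d2) @ (d2 x d3) matmul performs d1*d2*d3 multiply-adds,
    # i.e. 2 * d1 * d2 * d3 floating-point operations.
    total_flop = iterations * 2 * d1 * d2 * d3     # ~8.8e14 FLOP over 100 matmuls
    example_duration_s = 25.0                      # placeholder duration in seconds
    print(f'{total_flop / example_duration_s / 1e12:.2f} Tflops')  # ~35 Tflops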
@maxidl commented Nov 6, 2020

Thanks for this modified script. Here are some additional numbers (on CUDA 11.0 though, not 11.1):

Results on V100 32GB PCIe:
AMP off 12.75 Tflops with torch.float32 on cuda
AMP off 85.36 Tflops with torch.float16 on cuda
AMP off torch.bfloat16 is not available on cuda
AMP on 78.11 Tflops with torch.float32 on cuda
AMP on 88.55 Tflops with torch.float16 on cuda
AMP on 86.93 Tflops with torch.bfloat16 on cuda

Results on A100 PCIe:
AMP off 72.51 Tflops with torch.float32 on cuda
AMP off 213.91 Tflops with torch.float16 on cuda
AMP off 207.26 Tflops with torch.bfloat16 on cuda
AMP on 197.10 Tflops with torch.float32 on cuda
AMP on 202.85 Tflops with torch.float16 on cuda
AMP on 192.95 Tflops with torch.bfloat16 on cuda
