Skip to content

Instantly share code, notes, and snippets.

@Ushk
Created October 14, 2020 16:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ushk/8831348c071ddea6a90c7dddad124d14 to your computer and use it in GitHub Desktop.
Save Ushk/8831348c071ddea6a90c7dddad124d14 to your computer and use it in GitHub Desktop.
Pytorch1.6_Cuda11.0.py
import os
import random
import time
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from apex import amp
from torch.cuda.amp import autocast
def set_seed(seed: Optional[int]):
    """Seed every random number generator used here so runs are reproducible.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``,
    exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode with
    benchmark autotuning switched off.

    When ``seed`` is ``None`` the function does nothing: no generator is
    reseeded and the cuDNN flags keep whatever values they already had.

    :param seed: an integer of your choosing, or ``None`` to leave all state untouched
    """
    if seed is None:
        return

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN speed for repeatability: fixed algorithms, no autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    # NOTE: hash randomization is fixed at interpreter start-up, so this only
    # influences Python processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
# Benchmark: forward + backward of a single 3x3 convolution in three setups:
#   1. plain fp32,
#   2. torch.cuda.amp.autocast,
#   3. NVIDIA apex amp with opt_level "O2".
# Average per-iteration time is printed in microseconds.
set_seed(5)

layer1 = nn.Conv2d(3, 64, 3, padding=1, bias=False).cuda()
layer2 = nn.Conv2d(3, 64, 3, padding=1, bias=False).cuda()
opt2 = torch.optim.Adam(layer2.parameters(), lr=1e-3)
# apex wraps the model/optimizer pair; only layer2 goes through apex.
layer2, opt2 = amp.initialize(layer2, opt2, opt_level="O2")

foo = torch.randn((4, 3, 16, 16)).cuda()
# fp16 batch of the same shape; not used by any of the timed loops below.
bar = torch.randn((4, 3, 16, 16)).cuda().half()

# Warm-up forward pass so one-time CUDA/cuDNN start-up cost is not timed.
_ = layer1(foo)

nruns = 10000


def _time_steps(step, runs):
    """Call ``step`` ``runs`` times and return each call's duration in seconds.

    CUDA kernels are launched asynchronously, so the device is synchronized
    both before the clock starts and before it stops; otherwise only the
    launch overhead would be measured, not the GPU work itself.

    :param step: zero-argument callable performing one forward/backward pass
    :param runs: number of timed repetitions
    :return: list of ``runs`` wall-clock durations in seconds
    """
    durations = []
    for _run in range(runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        step()
        torch.cuda.synchronize()
        durations.append(time.perf_counter() - start)
    return durations


# Gradients are never zeroed, so every backward after the first accumulates
# into ``.grad``; this is the same for all three variants.
def _fp32_step():
    """One plain fp32 forward/backward pass through ``layer1``."""
    out = layer1(foo)
    out.mean().backward()


def _autocast_step():
    """One forward/backward pass through ``layer1`` under native autocast."""
    # Forward and loss run under autocast; backward is issued outside the
    # context, as the PyTorch AMP examples do.
    with autocast():
        out = layer1(foo)
        loss = out.mean()
    loss.backward()


def _apex_step():
    """One forward/backward pass through the apex-initialized ``layer2``."""
    out = layer2(foo)
    loss = out.mean()
    # apex scales the loss before backward.
    with amp.scale_loss(loss, opt2) as scaled_loss:
        scaled_loss.backward()


fp32_times = _time_steps(_fp32_step, nruns)
auto_fp32_times = _time_steps(_autocast_step, nruns)
apex_times = _time_steps(_apex_step, nruns)

print(f"Pure fp32 time {1e6*sum(fp32_times)/len(fp32_times)}")
print(f"Autocast fp32 time {1e6*sum(auto_fp32_times)/len(auto_fp32_times)}")
print(f"Apex time {1e6*sum(apex_times)/len(apex_times)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment