Last active
September 30, 2020 12:30
benchmark pytorch resnet18 cifar10 apex amp
'''
Diff https://github.com/kuangliu/pytorch-cifar/blob/master/main.py
'''
from apex import amp
net, optimizer = amp.initialize(net, optimizer, opt_level=args.opt_level)
#if device == 'cuda':
#    net = torch.nn.DataParallel(net)
#    cudnn.benchmark = True
# loss.backward()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
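For context on why the diff replaces `loss.backward()` with a scaled backward pass: mixed-precision training typically multiplies the loss by a scale factor so that small gradients do not underflow in fp16, then divides the gradients by the same factor before the weight update. The sketch below is not Apex code. It only uses Python's standard `struct` module to round values to IEEE half precision. The helper `to_fp16`, the gradient value, and the scale of 65536 are assumptions chosen for illustration.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


# Hypothetical gradient magnitude, below fp16's smallest subnormal (about 6e-8).
grad = 1e-8
# Hypothetical static loss scale (2**16); scaling the loss scales every gradient by it.
loss_scale = 65536.0

unscaled_fp16 = to_fp16(grad)             # stored directly in fp16: underflows to zero
scaled_fp16 = to_fp16(grad * loss_scale)  # scaled first: lands in fp16's normal range
recovered = scaled_fp16 / loss_scale      # unscaled again in higher precision

print("fp16 without scaling:", unscaled_fp16)
print("fp16 with scaling:   ", scaled_fp16)
print("after unscaling:     ", recovered)
```

Without scaling the gradient is lost entirely. With scaling it survives the fp16 round trip with a small rounding error.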
shape of 1 batch
inputs shape torch.Size([128, 3, 32, 32])
targets shape torch.Size([128])
O0 pure fp32 nvprof