Skip to content

Instantly share code, notes, and snippets.

@mcarilli
Last active July 31, 2020 00:00
Show Gist options
  • Save mcarilli/f6dc2b1ab5fc4af1990ea02b731ed692 to your computer and use it in GitHub Desktop.
Save mcarilli/f6dc2b1ab5fc4af1990ea02b731ed692 to your computer and use it in GitHub Desktop.
Example of batch replay with Amp opt_level=O1 + dynamic gradient scaling

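This example is based on main_amp.py from the Apex ImageNet Amp example and can be run with the same example commands. It demonstrates batch replay (instead of batch skipping) with the dynamic gradient scaling used by Amp: when a gradient overflow is detected, the same batch is run again at the reduced loss scale instead of being dropped.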

Batch replay requires a bit of user-side control flow, but is fairly straightforward.

Search main_amp_replay.py below (Ctrl+F) for "added for batch replay" to see what was changed. There are 5 instances, all inside the training loop in train().

Diffing main_amp_replay.py against main_amp.py from the Apex example directory (e.g. with vimdiff) is also instructive; again, there are only a few differences.

See the "Batch replay" example in the Automatic Mixed Precision RFC for a preview of how I plan for this to work once dynamic gradient scaling is integrated upstream.

import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import numpy as np
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
def fast_collate(batch):
    """Collate ``(image, label)`` samples into a uint8 NCHW tensor and an int64 label tensor.

    Each image must expose ``.size`` as ``(width, height)`` (as PIL images do) and
    be convertible with ``np.asarray``. The output is sized from the FIRST image,
    so all images in the batch must share its dimensions. A 2-D (single-channel)
    image is broadcast across all 3 output channels.
    """
    images = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)

    width, height = images[0].size[0], images[0].size[1]
    stacked = torch.zeros((len(images), 3, height, width), dtype=torch.uint8)

    for idx, image in enumerate(images):
        pixels = np.asarray(image, dtype=np.uint8)
        if pixels.ndim < 3:
            # HW -> HW1 so the channel move below works uniformly.
            pixels = pixels[..., np.newaxis]
        # HWC -> CHW; adding into the zero buffer also broadcasts 1 channel to 3.
        stacked[idx] += torch.from_numpy(np.moveaxis(pixels, 2, 0))

    return stacked, labels
def parse():
    """Define and parse the command-line interface for this training script.

    ``--arch`` choices are the lowercase, non-dunder, callable entries of
    ``torchvision.models``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    model_names = sorted(
        name for name, entry in models.__dict__.items()
        if name.islower() and not name.startswith("__") and callable(entry)
    )

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    add = parser.add_argument

    # Data / model selection.
    add('data', metavar='DIR', help='path to dataset')
    add('--arch', '-a', metavar='ARCH', default='resnet18', choices=model_names,
        help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
    add('-j', '--workers', default=4, type=int, metavar='N',
        help='number of data loading workers (default: 4)')

    # Optimization schedule.
    add('--epochs', default=90, type=int, metavar='N',
        help='number of total epochs to run')
    add('--start-epoch', default=0, type=int, metavar='N',
        help='manual epoch number (useful on restarts)')
    add('-b', '--batch-size', default=256, type=int, metavar='N',
        help='mini-batch size per process (default: 256)')
    add('--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
        help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
    add('--momentum', default=0.9, type=float, metavar='M',
        help='momentum')
    add('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
        help='weight decay (default: 1e-4)')

    # Logging, checkpointing and run modes.
    add('--print-freq', '-p', default=10, type=int, metavar='N',
        help='print frequency (default: 10)')
    add('--resume', default='', type=str, metavar='PATH',
        help='path to latest checkpoint (default: none)')
    add('-e', '--evaluate', dest='evaluate', action='store_true',
        help='evaluate model on validation set')
    add('--pretrained', dest='pretrained', action='store_true',
        help='use pre-trained model')
    add('--prof', default=-1, type=int,
        help='Only run 10 iterations for profiling.')
    add('--deterministic', action='store_true')

    # Distributed / Amp settings.
    add("--local_rank", default=0, type=int)
    add('--sync_bn', action='store_true',
        help='enabling apex sync BN.')
    add('--opt-level', type=str)
    add('--keep-batchnorm-fp32', type=str, default=None)
    add('--loss-scale', type=str, default=None)

    return parser.parse_args()
def main():
    """Script entry point: set up model/optimizer/Amp/data, then train and validate.

    Binds the module-level globals ``args`` (parsed CLI options) and
    ``best_prec1`` (best top-1 validation precision so far), which other
    functions in this file read. Rank 0 writes a checkpoint after every epoch.
    """
    global best_prec1, args

    args = parse()
    print("opt_level = {}".format(args.opt_level))
    print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
    print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    # Distributed mode is enabled when the launcher sets WORLD_SIZE > 1.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            # Without this declaration the assignment below would bind a local that
            # is discarded on return, leaving the global best_prec1 at 0. The next
            # epoch would then mark any result as "best" and overwrite model_best.
            global best_prec1
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.arch == "inception_v3":
        raise RuntimeError("Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    # No ToTensor/normalize here: fast_collate builds uint8 tensors and the
    # prefetcher normalizes on the GPU.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
class data_prefetcher():
    """Iterate a DataLoader while staging the next batch on the GPU ahead of use.

    Each batch is moved to the GPU, cast to float and normalized on a dedicated
    CUDA stream (``self.stream``), so that work can overlap with whatever is
    already queued on the current stream. In this file the loaders use
    ``fast_collate``, which yields uint8 image tensors; that is why the mean/std
    below are scaled by 255. ``next()`` returns ``(None, None)`` once the
    loader is exhausted.
    """
    def __init__(self, loader):
        """Wrap ``loader`` and immediately start prefetching its first batch."""
        self.loader = iter(loader)
        # Side stream on which host->device copies and normalization are issued.
        self.stream = torch.cuda.Stream()
        # Per-channel mean/std in the 0-255 pixel range, shaped (1, 3, 1, 1) to
        # broadcast over an NCHW batch.
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        """Fetch the next CPU batch and enqueue its GPU copy and normalization.

        Sets ``self.next_input``/``self.next_target`` to the (still in-flight)
        GPU tensors, or to ``None`` when the loader is exhausted.
        """
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        # if record_stream() doesn't work, another option is to make sure device inputs are created
        # on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use by the main stream
        # at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            # non_blocking copies can only overlap with compute when the source is
            # pinned; the DataLoaders in this file use pin_memory=True.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the staged ``(input, target)`` and start prefetching the following batch.

        Returns ``(None, None)`` after the last batch.
        """
        # Order the current stream after all work queued on the side stream, so
        # the returned tensors are fully copied and normalized before use.
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        # These tensors were allocated on the side stream; record_stream tells the
        # caching allocator they are also used on the current stream, so their
        # memory is not handed out again while that stream may still need it.
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch with Amp loss scaling and batch replay.

    When Amp detects a gradient overflow during ``amp.scale_loss``'s exit, the
    same batch is run forward/backward again, instead of being skipped.
    Reads the module-level ``args``. Metrics are reduced and printed every
    ``args.print_freq`` iterations. When ``args.prof >= 0``, profiling starts at
    iteration ``args.prof``, and the process exits 10 iterations later.

    The replay loop has no retry limit: it exits only after a backward pass in
    which Amp leaves ``optimizer.step`` unpatched.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1
        if args.prof >= 0 and i == args.prof:
            print("Profiling begun at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStart()

        if args.prof >= 0: torch.cuda.nvtx.range_push("Body of iteration {}".format(i))

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        replay_batch = True # added for batch replay
        while replay_batch: # added for batch replay
            # compute output
            if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
            output = model(input)
            if args.prof >= 0: torch.cuda.nvtx.range_pop()
            loss = criterion(output, target)

            # compute gradient and do SGD step
            optimizer.zero_grad()

            default_optimizer_step = optimizer.step # added for batch replay

            if args.prof >= 0: torch.cuda.nvtx.range_push("backward")
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            if args.prof >= 0: torch.cuda.nvtx.range_pop()

            # added for batch replay
            # If Amp detects an overflow, it patches optimizer.step during the "with amp.scale_loss"
            # context manager's exit. In other words, to discover if there was an overflow, we can
            # check if optimizer.step was left unpatched. If so, we don't need to replay.
            # Compare with == rather than `is`: unless step is stored as an instance
            # attribute, every `optimizer.step` access builds a new bound-method object,
            # so `is` would be False on every pass and the batch would replay forever.
            # Bound methods compare equal when they wrap the same function on the same
            # object; Amp's overflow patch installs a different callable.
            if optimizer.step == default_optimizer_step:
                replay_batch = False
            else:
                print("Found overflow, replaying this batch")

            if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()")
            # Comment added for batch replay
            # If an overflow was detected, "optimizer.step" is the patched call, which does
            # nothing but print the scale adjustment and restore optimizer.step to default_optimizer_step.
            # That's why step() is called here unconditionally.
            # Note that calling step() here within the replay loop differs from the implementation of
            # batch replay in the planned upstream API, in which scaler.step(optimizer) is called outside
            # the replay loop.
            optimizer.step()
            if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if i%args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))

            torch.cuda.synchronize()
            batch_time.update((time.time() - end)/args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch, i, len(train_loader),
                       args.world_size*args.batch_size/batch_time.val,
                       args.world_size*args.batch_size/batch_time.avg,
                       batch_time=batch_time,
                       loss=losses, top1=top1, top5=top5))

        if args.prof >= 0: torch.cuda.nvtx.range_push("prefetcher.next()")
        input, target = prefetcher.next()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # Pop range "Body of iteration {}".format(i)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if args.prof >= 0 and i == args.prof + 10:
            print("Profiling ended at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStop()
            quit()
def validate(val_loader, model, criterion):
    """Evaluate ``model`` over ``val_loader`` and return the mean top-1 precision.

    Loss and precision are averaged across processes with ``reduce_tensor`` when
    ``args.distributed`` is set. Rank 0 prints progress every
    ``args.print_freq`` batches; every rank prints the final summary line.
    """
    batch_time, losses, top1, top5 = (AverageMeter() for _ in range(4))

    model.eval()  # evaluation mode: e.g. BatchNorm uses running statistics
    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    images, labels = prefetcher.next()
    step = 0
    while images is not None:
        step += 1

        with torch.no_grad():
            output = model(images)
            loss = criterion(output, labels)

        # Precision and loss for this batch, averaged over ranks if distributed.
        prec1, prec5 = accuracy(output.data, labels, topk=(1, 5))
        if not args.distributed:
            reduced_loss = loss.data
        else:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)

        batch_len = images.size(0)
        for meter, value in ((losses, reduced_loss), (top1, prec1), (top5, prec5)):
            meter.update(to_python_float(value), batch_len)

        # Wall-clock time for this batch.
        batch_time.update(time.time() - end)
        end = time.time()

        # TODO: Change timings to mirror train().
        should_log = args.local_rank == 0 and step % args.print_freq == 0
        if should_log:
            global_batch = args.world_size * args.batch_size
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      step, len(val_loader),
                      global_batch / batch_time.val,
                      global_batch / batch_time.avg,
                      batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))

        images, labels = prefetcher.next()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is true, the freshly written file is also copied to
    ``model_best.pth.tar`` in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter:
    """Keep the latest value of a metric plus its count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed for ``n`` samples and refresh the average.

        ``n`` weights ``val`` in the running sum, e.g. a per-batch mean
        weighted by batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """Set every param group's LR from a step-decay schedule with linear warmup.

    The base LR is the module-level ``args.lr``. It is divided by 10 every 30
    epochs (``epoch // 30``), with one extra division by 10 from epoch 80 on.
    During epochs 0-4 it is further scaled by
    ``(1 + step + epoch * len_epoch) / (5 * len_epoch)``.
    Note: ``train()`` in this file passes a 1-based ``step``, so its first
    warmup step already uses ``2 / (5 * len_epoch)`` of the base LR.
    The original author's note says this schedule should yield 76% converged
    accuracy with batch size 256.
    """
    decay_power = epoch // 30 + (1 if epoch >= 80 else 0)
    lr = args.lr * (0.1 ** decay_power)

    # Linear warmup over the first 5 epochs.
    if epoch < 5:
        lr = lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)

    for group in optimizer.param_groups:
        group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: 2-D scores with one row per sample; the top-k is taken along
            dim 1 (presumably class logits).
        target: ground-truth class index for each row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list with one single-element float tensor per k, holding
        ``100 * (#rows whose target is among their k highest scores) / batch``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row j holds each sample's j-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`, so
        # `.view(-1)` fails for k > 1 on current PyTorch; `.reshape(-1)` copies
        # only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def reduce_tensor(tensor):
    """Return the average of ``tensor`` across all processes in the default group.

    Works on a clone, so the caller's tensor is left untouched. The summed result
    is divided by the module-level ``args.world_size``. Requires an initialized
    ``torch.distributed`` process group.
    """
    rt = tensor.clone()
    # dist.reduce_op is a deprecated alias that emits a FutureWarning on each
    # access; dist.ReduceOp is the supported enum.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt
if __name__ == "__main__":
    # Run training/evaluation only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment