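# 1-late SGD for the PyTorch ImageNet example with Horovod.
#
# Each optimizer step applies the worker-averaged gradients computed in the
# previous iteration, so the gradient allreduce overlaps with the next
# forward/backward pass instead of blocking the step.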
from __future__ import print_function

import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torchvision import datasets, transforms, models
import horovod.torch as hvd
import tensorboardX
from tqdm import tqdm
# Training settings
parser = argparse.ArgumentParser(description='PyTorch ImageNet Example',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--train-dir', default=os.path.expanduser('~/imagenet/train'),
                    help='path to training data')
parser.add_argument('--val-dir', default=os.path.expanduser('~/imagenet/validation'),
                    help='path to validation data')
parser.add_argument('--log-dir', default='./logs',
                    help='tensorboard log directory')
# Must contain an `{epoch}` placeholder; it is filled by both the save and
# the resume logic.
parser.add_argument('--checkpoint-format', default='./checkpoint-{epoch}.pth.tar',
                    help='checkpoint file format')
parser.add_argument('--batch-size', type=int, default=32,
                    help='input batch size for training')
parser.add_argument('--val-batch-size', type=int, default=32,
                    help='input batch size for validation')
parser.add_argument('--epochs', type=int, default=90,
                    help='number of epochs to train')
# Per-GPU learning rate; the optimizer uses base_lr * hvd.size().
parser.add_argument('--base-lr', type=float, default=0.0125,
                    help='learning rate for a single GPU')
parser.add_argument('--warmup-epochs', type=float, default=5,
                    help='number of warmup epochs')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--wd', type=float, default=0.00005,
                    help='weight decay')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42,
                    help='random seed')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Horovod: initialize the library. hvd.rank(), hvd.size() and every
# collective used later in this script require it.
hvd.init()
torch.manual_seed(args.seed)

if args.cuda:
    # Horovod: pin GPU to local rank.
    torch.cuda.set_device(hvd.local_rank())
    torch.cuda.manual_seed(args.seed)

# Let cuDNN benchmark and pick the fastest convolution algorithms.
cudnn.benchmark = True
# If set > 0, will resume training from a given checkpoint.
# Scan from the last epoch downwards and stop at the first checkpoint found,
# so the most recent one wins.
resume_from_epoch = 0
for try_epoch in range(args.epochs, 0, -1):
    if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)):
        resume_from_epoch = try_epoch
        break

# Horovod: broadcast resume_from_epoch from rank 0 (which will have
# checkpoints) to other ranks.
resume_from_epoch = hvd.broadcast(torch.tensor(resume_from_epoch), root_rank=0,
                                  name='resume_from_epoch').item()
# Horovod: only the first worker (rank 0) prints progress bars and writes
# TensorBoard logs; every other worker stays silent and has no writer.
verbose = int(hvd.rank() == 0)
log_writer = None
if hvd.rank() == 0:
    log_writer = tensorboardX.SummaryWriter(args.log_dir)
# Extra DataLoader options only worth paying for when feeding a GPU.
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
# Standard ImageNet training augmentation plus per-channel normalization.
train_dataset = \
    datasets.ImageFolder(args.train_dir,
                         transform=transforms.Compose([
                             transforms.RandomResizedCrop(224),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])
                         ]))
# Horovod: use DistributedSampler to partition data among workers. Manually specify
# `num_replicas=hvd.size()` and `rank=hvd.rank()`.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
# Deterministic evaluation preprocessing: resize, center crop, normalize.
val_dataset = \
    datasets.ImageFolder(args.val_dir,
                         transform=transforms.Compose([
                             transforms.Resize(256),
                             transforms.CenterCrop(224),
                             transforms.ToTensor(),
                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])
                         ]))
# Horovod: shard the validation set across workers as well; Metric averages
# the per-worker results.
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.val_batch_size,
                                         sampler=val_sampler, **kwargs)
# Set up standard ResNet-50 model.
model = models.resnet50()

if args.cuda:
    # Move model to GPU (the device pinned to this worker's local rank).
    model.cuda()

# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(), lr=args.base_lr * hvd.size(),
                      momentum=args.momentum, weight_decay=args.wd)
# Custom Distributed Optimizer that does 1-late gradient.
#
# Instead of waiting for this iteration's allreduce before step(), every
# step() applies the averaged gradient produced by the *previous* iteration's
# allreduce, so communication overlaps with the next forward/backward pass.
# These methods are copied into a dynamically created subclass of the wrapped
# optimizer's class (see DistributedOptimizer). That is why they call
# super(self.__class__, self) rather than zero-argument super(): the latter
# would bind to _DistributedOptimizer, which `self` is not an instance of.
class _DistributedOptimizer(torch.optim.Optimizer):
    def __init__(self, params, named_parameters=None):
        super(self.__class__, self).__init__(params)

        if named_parameters is not None:
            named_parameters = list(named_parameters)
        else:
            named_parameters = []

        # make sure that named_parameters are tuples
        if any(not isinstance(p, tuple) for p in named_parameters):
            raise ValueError('named_parameters should be a sequence of '
                             'tuples (name, parameter), usually produced by '
                             'model.named_parameters().')

        # Parameter tensor -> name; the name labels its allreduce operation.
        self._parameter_names = {v: k for k, v
                                 in sorted(named_parameters)}
        # Parameter -> handle of its in-flight asynchronous allreduce.
        self._handles = {}
        # Hold references to the gradient-accumulator nodes so the hooks
        # registered on them are not garbage collected.
        self._grad_accs = []
        # Parameter -> averaged gradient from the previous iteration.
        self._last_grad = {}
        # With a single worker there is nothing to reduce: behave like the
        # plain wrapped optimizer.
        if hvd.size() > 1:
            self._register_hooks()

    def _register_hooks(self):
        """Install a post-accumulation hook on every trainable parameter."""
        for param_group in self.param_groups:
            for p in param_group['params']:
                if p.requires_grad:
                    # expand_as creates a view whose grad_fn leads back to the
                    # parameter's gradient accumulator; a hook there fires once
                    # p.grad has been accumulated during backward().
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_hook(p))
                    self._grad_accs.append(grad_acc)

    def _make_hook(self, p):
        """Build the backward hook that rotates p's 1-late gradient."""
        def hook(*ignore):
            assert not p.grad.requires_grad
            name = self._parameter_names.get(p)
            # Copy the fresh gradient: allreduce_async_ works in place and its
            # buffer must survive the next zero_grad()/backward().
            new_grad = p.grad.detach().clone()
            if p in self._handles:
                # Wait for the allreduce started in the previous iteration;
                # its result is the gradient the upcoming step() applies.
                last_grad = hvd.synchronize(self._handles[p])
            else:
                # First iteration: nothing has been reduced yet.
                last_grad = torch.zeros_like(new_grad)
            self._last_grad[p] = last_grad
            handle = hvd.allreduce_async_(new_grad, average=True, name=name)
            self._handles[p] = handle
        return hook

    def step(self, closure=None):
        """Apply the previous iteration's averaged gradients, then step."""
        for p, lg in self._last_grad.items():
            p.grad.copy_(lg)
        return super(self.__class__, self).step(closure)
def DistributedOptimizer(optimizer, named_parameters=None):
    """Wrap a torch.optim.Optimizer so it trains with 1-late averaged gradients.

    Each trainable parameter gets an autograd hook. Once `loss.backward()` has
    accumulated that parameter's gradient, the hook does two things:

    1. It waits for the allreduce started for the parameter in the previous
       iteration and records the worker-averaged result.
    2. It launches an asynchronous allreduce of the fresh gradient.

    `step()` then applies the gradients recorded in (1). Communication thus
    overlaps with the following forward/backward pass instead of blocking
    `step()`. On the very first iteration no reduced gradient exists yet, so a
    zero gradient is applied. With a single worker (`hvd.size() == 1`) no hooks
    are installed, and the result behaves like the wrapped optimizer.

    Because `step()` replaces each gradient with the recorded one, in-place edits
    to `p.grad` made between `backward()` and `step()` (e.g. gradient clipping)
    do not affect the update. Unlike `hvd.DistributedOptimizer`, no
    `synchronize()` method is provided.

    Example::

        optimizer = optim.SGD(model.parameters(), lr=0.01)
        optimizer = DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())

    Args:
        optimizer: Optimizer to use for computing gradients and applying updates.
        named_parameters: A sequence of (name, parameter) tuples. Used for naming
            allreduce operations. Typically just `model.named_parameters()`.

    Returns:
        A new optimizer whose class is a dynamically created subclass of
        `type(optimizer)`. It is built from `optimizer.param_groups`
        (hyperparameters are kept; per-parameter state such as momentum
        buffers is not copied).
    """
    # We dynamically create a new class that inherits from the optimizer that was passed in.
    # The goal is to override the `step()` method with an allreduce implementation.
    cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
               dict(_DistributedOptimizer.__dict__))
    return cls(optimizer.param_groups, named_parameters)
# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = DistributedOptimizer(
    optimizer, named_parameters=model.named_parameters())

# Restore from a previous checkpoint, if resume_from_epoch > 0.
# Horovod: restore on the first worker which will broadcast weights to other workers.
if resume_from_epoch > 0 and hvd.rank() == 0:
    filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
    # map_location='cpu' lets a GPU-saved checkpoint load even without CUDA;
    # load_state_dict copies the values onto the live tensors' devices.
    checkpoint = torch.load(filepath, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
def train(epoch):
    """Run one training epoch over this worker's shard of the training set.

    Loss and accuracy are averaged across batches and workers via Metric and,
    on rank 0, written to TensorBoard.

    Args:
        epoch: zero-based epoch index. Used for the LR schedule, sampler
            shuffling, the progress-bar label and the TensorBoard step.
    """
    model.train()
    # DistributedSampler derives its shuffle from the epoch; without this call
    # every epoch would visit the samples in the same order.
    train_sampler.set_epoch(epoch)
    train_loss = Metric('train_loss')
    train_accuracy = Metric('train_accuracy')

    with tqdm(total=len(train_loader),
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=not verbose) as t:
        for batch_idx, (data, target) in enumerate(train_loader):
            adjust_learning_rate(epoch, batch_idx)

            if args.cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            # backward() fires the per-parameter allreduce hooks; step() then
            # applies the averaged gradients from the previous iteration.
            loss.backward()
            optimizer.step()

            # detach(): the running metric must not keep the autograd graph alive.
            train_loss.update(loss.detach())
            train_accuracy.update(accuracy(output, target))
            t.set_postfix({'loss': train_loss.avg.item(),
                           'accuracy': 100. * train_accuracy.avg.item()})
            t.update(1)

    if log_writer:
        log_writer.add_scalar('train/loss', train_loss.avg, epoch)
        log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)
def validate(epoch):
    """Evaluate the model on this worker's shard of the validation set.

    Loss and accuracy are averaged across batches and workers via Metric and,
    on rank 0, written to TensorBoard.

    Args:
        epoch: zero-based epoch index, used for the progress-bar label and the
            TensorBoard step.
    """
    # Put layers such as BatchNorm into inference mode; train() leaves the
    # model in training mode.
    model.eval()
    val_loss = Metric('val_loss')
    val_accuracy = Metric('val_accuracy')

    with tqdm(total=len(val_loader),
              desc='Validate Epoch #{}'.format(epoch + 1),
              disable=not verbose) as t:
        with torch.no_grad():
            for data, target in val_loader:
                if args.cuda:
                    data, target = data.cuda(), target.cuda()
                output = model(data)

                val_loss.update(F.cross_entropy(output, target))
                val_accuracy.update(accuracy(output, target))
                t.set_postfix({'loss': val_loss.avg.item(),
                               'accuracy': 100. * val_accuracy.avg.item()})
                # Advance the bar; total=len(val_loader) counts batches.
                t.update(1)

    if log_writer:
        log_writer.add_scalar('val/loss', val_loss.avg, epoch)
        log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch)
# Horovod: using `lr = base_lr * hvd.size()` from the very beginning leads to worse final
# accuracy. Scale the learning rate `lr = base_lr` ---> `lr = base_lr * hvd.size()` during
# the first `args.warmup_epochs` epochs (5 by default).
# See https://arxiv.org/abs/1706.02677 for details.
# After the warmup reduce learning rate by 10 after 30, 60 and 80 epochs.
def adjust_learning_rate(epoch, batch_idx):
    """Set the learning rate of every param group for one training batch.

    Args:
        epoch: zero-based epoch index.
        batch_idx: zero-based batch index within the epoch. Only used during
            warmup, where the learning rate is ramped per batch.
    """
    if epoch < args.warmup_epochs:
        # Fractional epoch reached at the end of this batch.
        epoch += float(batch_idx + 1) / len(train_loader)
        # Linear ramp of the multiplier from 1/size (lr == base_lr) up to 1
        # (lr == base_lr * size) at the end of warmup.
        lr_adj = 1. / hvd.size() * (epoch * (hvd.size() - 1) / args.warmup_epochs + 1)
    elif epoch < 30:
        lr_adj = 1.
    elif epoch < 60:
        lr_adj = 1e-1
    elif epoch < 80:
        lr_adj = 1e-2
    else:
        lr_adj = 1e-3
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.base_lr * hvd.size() * lr_adj
def accuracy(output, target):
    """Return the top-1 accuracy of `output` against `target` as a CPU float tensor.

    `output` holds one row of class scores per sample; the predicted class is
    the index of the highest score in each row (dimension 1).
    """
    _, predicted = output.max(1, keepdim=True)
    hits = predicted.eq(target.view_as(predicted))
    return hits.cpu().float().mean()
def save_checkpoint(epoch):
    """Save model and optimizer state after finishing `epoch` (rank 0 only).

    The file is named with `epoch + 1`: the resume scan looks for
    `checkpoint_format.format(epoch=N)` and restarts the epoch loop at index N,
    i.e. at the epoch following the one saved here.

    Args:
        epoch: zero-based index of the epoch that just completed.
    """
    if hvd.rank() == 0:
        filepath = args.checkpoint_format.format(epoch=epoch + 1)
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, filepath)
# Horovod: average metrics from distributed training.
class Metric(object):
    """Running mean of a scalar metric, averaged across all Horovod workers.

    Each `update()` allreduces the value across workers and adds the result to
    a running sum. `avg` is that sum divided by the number of updates.
    """

    def __init__(self, name):
        # Also used to name the allreduce operation in update().
        self.name = name
        self.sum = torch.tensor(0.)
        self.n = torch.tensor(0.)

    def update(self, val):
        # detach() so a loss tensor does not keep its autograd graph alive
        # inside the running sum for the whole epoch.
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.n += 1

    @property
    def avg(self):
        # Read as an attribute by callers (e.g. `metric.avg.item()`).
        return self.sum / self.n
# Main loop: continue from the restored epoch (0 for a fresh run); each epoch
# trains, evaluates, then checkpoints on rank 0.
for epoch in range(resume_from_epoch, args.epochs):
    train(epoch)
    validate(epoch)
    save_checkpoint(epoch)
# End of script.