@alsrgv
Last active January 12, 2022 12:04
Horovod-PyTorch with Apex (look for "# Apex")
from __future__ import print_function

import argparse
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torchvision import models
import horovod.torch as hvd
import timeit
import numpy as np

# Apex
from apex import amp

# Benchmark settings
parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                    help='use fp16 compression during allreduce')
parser.add_argument('--model', type=str, default='resnet50',
                    help='model to benchmark')
parser.add_argument('--batch-size', type=int, default=32,
                    help='input batch size')
parser.add_argument('--num-warmup-batches', type=int, default=10,
                    help='number of warm-up batches that don\'t count towards benchmark')
parser.add_argument('--num-batches-per-iter', type=int, default=10,
                    help='number of batches per benchmark iteration')
parser.add_argument('--num-iters', type=int, default=10,
                    help='number of benchmark iterations')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

hvd.init()

if args.cuda:
    # Horovod: pin GPU to local rank.
    torch.cuda.set_device(hvd.local_rank())

cudnn.benchmark = True

# Set up standard model.
model = getattr(models, args.model)()

if args.cuda:
    # Move model to GPU.
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     compression=compression)

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# Apex
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Set up fixed fake data
data = torch.randn(args.batch_size, 3, 224, 224)
target = torch.LongTensor(args.batch_size).random_() % 1000
if args.cuda:
    data, target = data.cuda(), target.cuda()


def benchmark_step():
    optimizer.zero_grad()
    output = model(data)
    loss = F.cross_entropy(output, target)
    # Apex
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
        optimizer.synchronize()
    with optimizer.skip_synchronize():
        optimizer.step()


def log(s, nl=True):
    if hvd.rank() != 0:
        return
    print(s, end='\n' if nl else '')


log('Model: %s' % args.model)
log('Batch size: %d' % args.batch_size)
device = 'GPU' if args.cuda else 'CPU'
log('Number of %ss: %d' % (device, hvd.size()))

# Warm-up
log('Running warmup...')
timeit.timeit(benchmark_step, number=args.num_warmup_batches)

# Benchmark
log('Running benchmark...')
img_secs = []
for x in range(args.num_iters):
    time = timeit.timeit(benchmark_step, number=args.num_batches_per_iter)
    img_sec = args.batch_size * args.num_batches_per_iter / time
    log('Iter #%d: %.1f img/sec per %s' % (x, img_sec, device))
    img_secs.append(img_sec)

# Results
img_sec_mean = np.mean(img_secs)
img_sec_conf = 1.96 * np.std(img_secs)
log('Img/sec per %s: %.1f +-%.1f' % (device, img_sec_mean, img_sec_conf))
log('Total img/sec on %d %s(s): %.1f +-%.1f' %
    (hvd.size(), device, hvd.size() * img_sec_mean, hvd.size() * img_sec_conf))
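The gist itself does not include a launch command. As a usage note, assuming the script is saved under a hypothetical name such as pytorch_synthetic_benchmark_apex.py, a typical multi-GPU run would use Horovod's launcher, for example:

horovodrun -np 4 python pytorch_synthetic_benchmark_apex.py --fp16-allreduce

or an mpirun invocation like the one shown in the comments further down.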
@un-knight

What does img_sec_conf = 1.96 * np.std(img_secs) in line 115 mean? Why multiply the std by 1.96?

@alsrgv

alsrgv commented Jun 4, 2019

@un-knight, I use it to get a 95% confidence interval - http://onlinestatbook.com/2/estimation/mean.html

@un-knight

@alsrgv Then why don't you divide img_sec_conf by sqrt(num_iters)? According to your reference, we should use the standard error of the mean, which equals 1.96 * np.std(img_secs) / math.sqrt(args.num_iters).

@alsrgv

alsrgv commented Jun 6, 2019

@un-knight, I don't think that's necessary, since np.std already computes the standard deviation - see https://docs.scipy.org/doc/numpy/reference/generated/numpy.std.html. That said, I'm not a statistician and I could be totally wrong :-)

@un-knight

@alsrgv I think you are right; I misunderstood.
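For reference, a minimal sketch (not part of the gist) of the two intervals discussed above, using hypothetical throughput numbers:

import numpy as np

img_secs = np.array([1215.0, 1232.0, 1221.0, 1228.0, 1219.0])  # hypothetical per-iteration img/sec values

std = np.std(img_secs)
n = len(img_secs)

# Interval reported by the gist: roughly the spread of individual iterations (assuming normality).
iteration_spread = 1.96 * std

# Standard error of the mean: roughly the uncertainty of the mean throughput itself.
mean_uncertainty = 1.96 * std / np.sqrt(n)

print('mean %.1f img/sec, iteration spread +-%.1f, mean uncertainty +-%.1f'
      % (img_secs.mean(), iteration_spread, mean_uncertainty))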

@qingyu-wang

Something goes wrong when I set --fp16-allreduce; the error is shown below:

Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/root/tsukiko2/test/fp16_hvd.py", line 292, in <module>
    timeit.timeit(benchmark_step, number=args.num_warmup_batches)
  File "/usr/lib/python3.6/timeit.py", line 233, in timeit
    return Timer(stmt, setup, timer, globals).timeit(number)
  File "/usr/lib/python3.6/timeit.py", line 178, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "/root/tsukiko2/test/fp16_hvd.py", line 274, in benchmark_step
    optimizer.synchronize()
  File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 157, in synchronize
    p.grad.set_(self._compression.decompress(output, ctx))
RuntimeError: set_storage is not allowed on Tensor created from .data or .detach()
Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/root/tsukiko2/test/fp16_hvd.py", line 292, in <module>
    timeit.timeit(benchmark_step, number=args.num_warmup_batches)
  File "/usr/lib/python3.6/timeit.py", line 233, in timeit
    return Timer(stmt, setup, timer, globals).timeit(number)
  File "/usr/lib/python3.6/timeit.py", line 178, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "/root/tsukiko2/test/fp16_hvd.py", line 273, in benchmark_step
    scaled_loss.backward()
  File "/usr/local/lib/python3.6/dist-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 139, in hook
    handle, ctx = self._allreduce_grad_async(p)
  File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 122, in _allreduce_grad_async
    handle = allreduce_async_(tensor_compressed, average=True, name=name)
  File "/usr/local/lib/python3.6/dist-packages/horovod/torch/mpi_ops.py", line 176, in allreduce_async_
    return _allreduce_async(tensor, tensor, average, name)
  File "/usr/local/lib/python3.6/dist-packages/horovod/torch/mpi_ops.py", line 81, in _allreduce_async
    name.encode() if name is not None else _NULL)
RuntimeError: Horovod has been shut down. This was caused by an exception on one of the ranks or an attempt to allreduce, allgather or broadcast a tensor after one of the ranks finished execution. If the shutdown was caused by an exception, you should see the exception in the log before the first shutdown message.

And when I run my own demo (code, launch command, and error below), I get this error:

from __future__ import division
from __future__ import print_function

import argparse
import time
import os
import sys

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
# Horovod
import horovod.torch as hvd

parser = argparse.ArgumentParser()
parser.add_argument("--apex", action="store_true")
parser.add_argument("--opt_level", type=str, default="O1")
parser.add_argument("--fp16_allreduce", action="store_true")

args = parser.parse_args()

hvd.init()

world_size = hvd.size()
world_rank = hvd.rank()
local_rank = hvd.local_rank()

APEX = args.apex
if APEX:
    import apex
    if world_rank == 0:
        print("use apex")

DETERMINISTIC = True
if DETERMINISTIC:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(4)
    torch.cuda.manual_seed_all(4)
    torch.set_printoptions(precision=10)
else:
    cudnn.benchmark = True

time.sleep(0.1 * world_rank)
print("init [%2s/%2s]" % (world_rank, world_size))
time.sleep(0.1 * (world_size - world_rank))

torch.cuda.set_device(local_rank)
if world_rank == 0:
    print("set_device")

assert torch.backends.cudnn.enabled

model = torchvision.models.resnet50()
if world_rank == 0:
    print("model")

if APEX:
    SYNC_BN = True
    if SYNC_BN:
        model = apex.parallel.convert_syncbn_model(model)
        if world_rank == 0:
            print("convert_syncbn_model")

model = model.cuda()
if world_rank == 0:
    print("model.cuda")

optimizer = torch.optim.SGD(
    model.parameters(), 0.01,
    momentum=0.9,
    weight_decay=1e-4
)
if world_rank == 0:
    print("optimizer")

FP16_ALLREDUCE = args.fp16_allreduce
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    compression=hvd.Compression.fp16 if FP16_ALLREDUCE else hvd.Compression.none
)
if world_rank == 0:
    print("hvd.DistributedOptimizer")

hvd.broadcast_parameters(model.state_dict(), root_rank=0)
if world_rank == 0:
    print("hvd.broadcast_parameters")

hvd.broadcast_optimizer_state(optimizer, root_rank=0)
if world_rank == 0:
    print("hvd.broadcast_optimizer_state")

if APEX:
    OPT_LEVEL = args.opt_level
    model, optimizer = apex.amp.initialize(model, optimizer,
        opt_level=OPT_LEVEL
    )
    if world_rank == 0:
        print("apex.amp.initialize: %s" % OPT_LEVEL)

criterion = nn.CrossEntropyLoss().cuda()

torch.manual_seed(4)
torch.cuda.manual_seed_all(4)

_inputs = (torch.LongTensor(64, 3, 224, 224).random_().cuda() % 255).float().add_(-127.5).mul_(1/255)
targets = (torch.LongTensor(64).random_().cuda() % 1000)

_time = time.time()
epochs = 12
batches = 5
for epoch_idx in range(epochs):
    for batch_idx in range(batches):

        if world_rank == 0:
            print("\n\nepoch: %s, batch: %s" % (epoch_idx, batch_idx))

        seed = epoch_idx * batches + batch_idx
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        inputs = _inputs + (torch.LongTensor(64, 3, 224, 224).random_().cuda() % 255).float().add_(-127.5).mul_(1/25500)
        if world_rank == 0:
            print(inputs[0, :, 0, 0])

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()

        if APEX:
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
                optimizer.synchronize()
            with optimizer.skip_synchronize():
                optimizer.step()
        else:
            loss.backward()
            optimizer.step()

        # If tensor requires gradient, then
        # tensor.cpu().detach() constructs the .cpu autograd edge, which soon gets destructed since the result is not stored.
        # tensor.detach().cpu() does not do this.
        # However, this is very fast so virtually they are the same.

        if world_rank == 0:
            print(targets.detach().cpu().numpy())
            print(outputs.detach().cpu().numpy().argmax(axis=1))

if world_rank == 0:
    print("time: %s" % (time.time() - _time))
/usr/local/bin/mpirun \
--allow-run-as-root \
-np 2 \
-H localhost:2 \
-bind-to none -map-by slot \
-x LD_LIBRARY_PATH -x PATH -x PYTHONPATH \
-mca pml ob1 -mca btl ^openib \
python -u -m test.fp16_hvd --apex --opt_level O2 --fp16_allreduce
Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/root/tsukiko2/test/fp16_hvd.py", line 169, in <module>
    optimizer.synchronize()
  File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 157, in synchronize
    p.grad.set_(self._compression.decompress(output, ctx))
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'source'
Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/root/tsukiko2/test/fp16_hvd.py", line 169, in <module>
    optimizer.synchronize()
  File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 157, in synchronize
    p.grad.set_(self._compression.decompress(output, ctx))
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'source'

@Richie-yan

Richie-yan commented Aug 3, 2020

Hi, @alsrgv @qingyu-wang
When I set --fp16-allreduce, I get the following error:

<stderr>: optimizer.synchronize()
<stderr>: File "/usr/local/lib64/python3.6/site-packages/horovod/torch/__init__.py", line 178, in synchronize
<stderr>: optimizer.synchronize()
<stderr>: File "/usr/local/lib64/python3.6/site-packages/horovod/torch/__init__.py", line 178, in synchronize
<stderr>: p.grad.set_(self._compression.decompress(output, ctx))
<stderr>:RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach()

How can I solve this problem?

@gongjingcs

(quoting @Richie-yan's report above)

I also had this problem.
