Data Parallelism in PyTorch for modules and losses
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu
## Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co
## Copyright (c) 2017-2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.scatter_gather import gather
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

torch_ver = torch.__version__[:3]

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
           'patch_replication_callback']

def allreduce(*inputs):
    """Cross-GPU all-reduce autograd operation for calculating the mean and
    variance in SyncBN.
    """
    return AllReduce.apply(*inputs)

class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        inputs = [inputs[i:i + num_inputs]
                  for i in range(0, len(inputs), num_inputs)]
        # sort before reduce sum
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                  for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])

class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        return Broadcast.apply(ctx.target_gpus, gradOutput)

class DistributedDataParallelModel(DistributedDataParallel):
    """Implements data parallelism at the module level for the DistributedDataParallel module.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backward pass,
    gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use a compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        return outputs

class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backward pass,
    gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use a compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

class DataParallelCriterion(DataParallel):
    """
    Calculate the loss on multiple GPUs, which balances the memory usage.
    The targets are split across the specified devices by chunking in
    the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        # inputs should already be scattered (one entry per device, as returned
        # by DataParallelModel); scatter the targets instead
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
        #return Reduce.apply(*outputs) / len(outputs)
        #return self.gather(outputs, self.output_device).mean()
        # Returns the per-device losses gathered on output_device, not reduced
        # to a scalar; reduce them (e.g. with .mean()) before calling backward().
        return self.gather(outputs, self.output_device)

def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                if not isinstance(target, (list, tuple)):
                    target = (target,)
                # tuple() so that a list output can be concatenated with the target tuple
                output = module(*(tuple(input) + tuple(target)), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # single replica: run in the current thread (the target must be passed too)
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
#
class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created
    by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead
    of calling the callback of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)

def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object to add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
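The worker/thread structure of `_criterion_parallel_apply` can be shown without PyTorch or GPUs. It runs one thread per replica, stores results by index under a lock, and re-raises any worker exception in device order. The sketch below is a plain-Python model built on several assumptions. Each "replica" is an ordinary function, and there is no device selection or grad-mode propagation. `parallel_apply_sketch` and `mse` are hypothetical names that do not exist in this gist or in PyTorch.

```python
import threading


def parallel_apply_sketch(fns, inputs, targets):
    """Call fns[i](inputs[i], targets[i]) in one thread per index (hypothetical helper).

    Mirrors the structure of _criterion_parallel_apply: results are stored
    by index under a lock, and a worker's exception is re-raised in the
    calling thread, in index order, after all threads have finished.
    """
    assert len(fns) == len(inputs) == len(targets)
    lock = threading.Lock()
    results = {}

    def _worker(i, fn, inp, tgt):
        try:
            out = fn(inp, tgt)
            with lock:
                results[i] = out
        except Exception as e:  # hand the error back to the calling thread
            with lock:
                results[i] = e

    threads = [threading.Thread(target=_worker, args=(i, f, x, t))
               for i, (f, x, t) in enumerate(zip(fns, inputs, targets))]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    outputs = []
    for i in range(len(inputs)):
        if isinstance(results[i], Exception):
            raise results[i]
        outputs.append(results[i])
    return outputs


def mse(preds, tgts):
    """Toy per-chunk loss: mean squared error over plain lists."""
    return sum((p - t) ** 2 for p, t in zip(preds, tgts)) / len(preds)


# Two "devices", each holding its own chunk of predictions and targets.
losses = parallel_apply_sketch([mse, mse], [[1.0, 2.0], [3.0, 5.0]], [[1.0, 1.0], [3.0, 3.0]])
print(losses)  # [0.5, 2.0]
```

As in the gist, storing results by index keeps the returned losses in device order no matter which thread finishes first. Re-raising in the caller surfaces a worker's error instead of silently dropping it.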
@imirzadeh commented Dec 2, 2018

Hi, I think we have to import DistributedDataParallel with "from torch.nn.parallel.distributed import DistributedDataParallel", because the class on line 66 inherits from it.

@mjc14 commented Dec 4, 2018

Hi, I have a question about lines 160-162: why did you change Reduce.apply(*outputs) / len(outputs) to self.gather(outputs, self.output_device)? It is different from the original code. Does it give the same result?
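Here is some arithmetic that bears on this question, as a plain-Python sketch rather than PyTorch code. `Reduce.apply(*outputs) / len(outputs)` and `self.gather(outputs, self.output_device).mean()` both produce the mean of the per-device losses. The active line returns the gathered per-device losses without reducing them. Whether that mean equals the single-GPU loss depends on the chunk sizes. The sketch makes two assumptions. Each per-device loss is a per-sample mean over its chunk, and the batch is split into ceil-sized contiguous chunks the way `torch.chunk` splits it. `chunk` and `mean` are illustrative helpers.

```python
def chunk(seq, n):
    """Split seq into contiguous chunks of size ceil(len/n), like torch.chunk.

    The last chunk may be smaller when len(seq) is not a multiple of n.
    """
    size = -(-len(seq) // n)  # ceiling division
    return [seq[i:i + size] for i in range(0, len(seq), size)]


def mean(xs):
    return sum(xs) / len(xs)


per_sample_loss = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Equal-sized chunks: the mean of the per-device means equals the full-batch mean.
per_device = [mean(c) for c in chunk(per_sample_loss, 2)]
print(per_device, mean(per_device), mean(per_sample_loss))  # [2.0, 5.0] 3.5 3.5

# Unequal chunks (batch of 5 on 2 devices): the two quantities diverge.
odd_batch = per_sample_loss[:5]
per_device_odd = [mean(c) for c in chunk(odd_batch, 2)]
print(mean(per_device_odd), mean(odd_batch))  # 3.25 3.0
```

This is one reason the docstrings ask for a batch size that is an integer multiple of the number of GPUs.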

@DanZhao1027 commented May 7, 2019

Thanks to imirzadeh: your "from torch.nn.parallel.distributed import DistributedDataParallel" fixed my first problem. But I have another question. I got the error "TypeError: Broadcast function not implemented for CPU tensors." Is there any way to solve this error? Thanks again.

@ImMrMa commented Jun 29, 2019

Replying to @mjc14's question about lines 160-162 (why Reduce.apply(*outputs) / len(outputs) was changed to self.gather(outputs, self.output_device)):

I ran into this problem too. I want each backward pass to start from the loss value on each device, rather than from the mean of the losses. I don't know how to implement this. Do you have any idea?

@sajidrahman commented Jul 25, 2019

Hi @thomwolf,

I'm trying to use load balancing in a multi-GPU environment, following your tutorial published on Medium. I'm fine-tuning GPT-2 small for a classification task. Here are the steps I've followed so far:

  1. Copy parallel.py into the local directory.
  2. Add from torch.nn.parallel.distributed import DistributedDataParallel to the parallel.py file (otherwise I get the error 'DistributedDataParallel' not found).
  3. After loading the GPT2Pretrained model, define the parallel model:
     model = DataParallelModel(model, device_ids=[0, 1])
     parallel_loss = DataParallelCriterion(model, device_ids=[0,1])
  4. Now, during training, I get the following error. The complete stack trace is as follows:

AssertionError                            Traceback (most recent call last)
in
     19
     20 # losses = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels)
---> 21 losses = parallel_loss(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels)
     22
     23 lm_loss, clf_loss = losses

~/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~/github_repos/pytorch-pretrained-BERT/examples/parallel.py in forward(self, inputs, *targets, **kwargs)
    158             return self.module(inputs, *targets[0], **kwargs[0])
    159         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 160         outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
    161         #return Reduce.apply(*outputs) / len(outputs)
    162         #return self.gather(outputs, self.output_device).mean()

~/github_repos/pytorch-pretrained-BERT/examples/parallel.py in _criterion_parallel_apply(modules, inputs, targets, kwargs_tup, devices)
    165
    166 def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
--> 167     assert len(modules) == len(inputs)
    168     assert len(targets) == len(inputs)
    169     if kwargs_tup:

AssertionError:

From the stack trace, I'm not sure why the number of modules needs to equal the number of inputs. Am I missing something here? I'm using Python 3.6 with PyTorch 1.1.0. Any help or pointers would be highly appreciated. Thanks!
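The assertion can be traced to the length bookkeeping in the gist's code. `DataParallelCriterion.forward` builds one replica per entry of `inputs`, capped at `len(device_ids)` (`self.device_ids[:len(inputs)]`). It also scatters the targets across the devices. So it expects `inputs` to be the un-gathered output of `DataParallelModel`, with one entry per device, and not a raw batch tensor whose length is the batch size. The plain-Python sketch below mimics only those lengths. It assumes the targets are split into one ceil-sized chunk per device. `chunk` and `criterion_arg_lengths` are hypothetical helpers, not part of the gist or PyTorch.

```python
def chunk(seq, n):
    """Split seq into at most n contiguous, ceil-sized chunks (like torch.chunk)."""
    size = -(-len(seq) // n)  # ceiling division
    return [seq[i:i + size] for i in range(0, len(seq), size)]


def criterion_arg_lengths(device_ids, inputs, target):
    """Mimic the lengths checked by _criterion_parallel_apply (hypothetical helper).

    Returns (n_replicas, n_inputs, n_targets); the gist asserts that all three match.
    """
    targets = chunk(target, len(device_ids))    # targets are scattered per device
    n_replicas = len(device_ids[:len(inputs)])  # one replica per entry of inputs
    return n_replicas, len(inputs), len(targets)


device_ids = [0, 1]
batch_targets = list(range(8))

# Intended use: inputs is the un-gathered model output, one entry per device.
per_device_outputs = [["out0"], ["out1"]]
print(criterion_arg_lengths(device_ids, per_device_outputs, batch_targets))  # (2, 2, 2)

# Passing a raw batch of 8 samples as `inputs` makes len(inputs) == 8.
raw_batch = list(range(8))
print(criterion_arg_lengths(device_ids, raw_batch, batch_targets))  # (2, 8, 2): mismatch
```

In the second case `len(modules) == len(inputs)` is 2 == 8, which matches the `AssertionError` above.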

@Luonic commented Jul 27, 2019

@sajidrahman I have not used this gist, but it seems strange to me that you pass the model to the parallel criterion in parallel_loss = DataParallelCriterion(model, device_ids=[0,1]). I think your loss-calculation class (inheriting from nn.Module) should go there.

@crowegian commented Aug 23, 2019

@sajidrahman, were you able to fix your issue? I'm working with the huggingface code for BERT and getting the same error.

@Luonic, I think you are right. I wrapped the loss function in DataParallelCriterion and got the same error.

@alexanderhucheerful commented Oct 20, 2019

I ran into this problem:

File "train.py", line 159, in
    loss.backward()
File "/usr/hujh/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 118, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/usr/hujh/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 87, in backward
    grad_tensors = _make_grads(tensors, grad_tensors)
File "/usr/hujh/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 28, in _make_grads
    raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

I really don't know how to solve it. I think the issue is that loss.backward does not distribute the loss scalar to each GPU, and I don't know how to fix that.

@ChunshengLin commented Oct 29, 2019

@alexhuooxx, were you able to fix your issue? I got the same error.

@AlexHex7 commented Dec 9, 2019

Hi, I get a user warning with PyTorch v1.2.0:

torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector. warnings.warn('Was asked to gather along dimension 0, but all '

Also, if I use two GPUs, the loss is a list [loss_1, loss_2], so I need to sum or average them before calling loss.backward().
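The choice between sum and mean affects the gradient scale. Below is a plain-Python sketch, not PyTorch code, using a hypothetical one-parameter linear model with a mean-squared-error loss. `grad_mse` is an illustrative helper that computes the analytic gradient. With equal-sized chunks, backpropagating the mean of the per-device losses reproduces the single-GPU full-batch gradient. Backpropagating their sum multiplies that gradient by the number of devices.

```python
def grad_mse(w, xs, ys):
    """Analytic d/dw of mean((w*x - y)**2) for a toy linear model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 2.0, 5.0]
n_devices = 2

# Single-GPU reference: gradient of the mean loss over the whole batch.
full_batch = grad_mse(w, xs, ys)

# Equal-sized chunks, one per "device" (batch size is a multiple of n_devices).
size = len(xs) // n_devices
per_device = [grad_mse(w, xs[i:i + size], ys[i:i + size])
              for i in range(0, len(xs), size)]

grad_if_mean = sum(per_device) / n_devices  # like backward() on losses.mean()
grad_if_sum = sum(per_device)               # like backward() on losses.sum()

print(full_batch, grad_if_mean, grad_if_sum)  # -9.0 -9.0 -18.0
```

Gradients are additive. Summing the per-device losses therefore gives the same gradient as backpropagating each per-device loss separately and letting the gradients accumulate. Taking the mean only rescales that gradient by `1 / n_devices`, which is why it matches a single-GPU run at the same learning rate.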

@ericaliu0610 commented Jul 11, 2020

@AlexHex7 (replying to your note that with two GPUs the loss is a list [loss_1, loss_2] that must be summed or averaged before loss.backward): I feel that if the loss list is reduced to a mean loss, this code will not balance the load on a multi-GPU machine 🤔
I'm not sure whether anyone has successfully run this code and solved the unbalanced memory usage problem.

@mikeyEcology commented Sep 1, 2020

When I run the code here following the guidance from this post, I get this error when I try to train:

line 345, in _conv_forward
return F.conv2d(input, weight, self.bias, self.stride,
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

This is how I'm using the code from here:

if torch.cuda.device_count() > 1:
    model = DataParallelModel(model, device_ids=[0, 1])
    loss_func = DataParallelCriterion(loss_func, device_ids=[0, 1])
@linksboy commented Nov 11, 2021

@sajidrahman, were you able to fix your issue? I got the same error.
