Last active
December 3, 2019 17:48
-
-
Save nizhib/9f6c43ef1faab23f1b1ff606685813c2 to your computer and use it in GitHub Desktop.
Parallel model + criterion wrappers for PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
import torch | |
from torch.nn.modules import Module | |
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather | |
from torch.nn.parallel.replicate import replicate | |
from torch.nn.parallel.parallel_apply import parallel_apply | |
__all__ = ['DataParallelModel', 'DataParallelCriterion'] | |
class DataParallelModel(Module):
    """Replicate a module across GPUs and run its forward on scattered inputs.

    Behaves like ``torch.nn.DataParallel`` except that the per-device outputs
    are returned as a list instead of being gathered onto a single device, so
    a companion parallel criterion can compute the loss on each GPU in place.
    Without CUDA it degrades to a transparent pass-through wrapper.
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        """
        Args:
            module: the module to parallelize.
            device_ids: CUDA device indices to use (default: all devices).
            output_device: device a gather would target (default: first id).
            dim: dimension along which inputs are scattered.
        """
        super(DataParallelModel, self).__init__()
        if not torch.cuda.is_available():
            # CPU-only: keep the module, mark "no devices" for forward().
            self.module = module
            self.device_ids = []
            return
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.output_device = device_ids[0] if output_device is None else output_device
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
        """Scatter inputs, run the replicas, and return per-device outputs."""
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        scattered, scattered_kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*scattered[0], **scattered_kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(scattered)])
        # Deliberately no gather here: each output stays on its own device
        # so DataParallelCriterion can consume it without a round-trip.
        return self.parallel_apply(replicas, scattered, scattered_kwargs)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
class DataParallelCriterion(Module):
    """Data-parallel criterion for already-scattered model outputs.

    Companion to ``DataParallelModel``: that wrapper returns a list of
    per-device outputs; this one scatters the targets to match, replicates
    the criterion, computes the loss on each device, and gathers the results
    onto ``output_device``.
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        """
        Args:
            module: criterion to parallelize (e.g. a loss ``nn.Module``).
            device_ids: CUDA device indices to use (default: all devices).
            output_device: device the final loss is gathered to (default:
                first device id).
            dim: dimension along which targets are scattered.
        """
        super(DataParallelCriterion, self).__init__()
        if not torch.cuda.is_available():
            # CPU-only: keep the criterion, mark "no devices" for forward().
            self.module = module
            self.device_ids = []
            return
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.output_device = output_device
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, inputs, *targets, **kwargs):
        """Compute the loss.

        Args:
            inputs: model outputs — already scattered (one entry per device)
                in the multi-device case, the raw output otherwise.
            *targets: targets, scattered here to mirror the inputs.
        """
        # Fix: mirror DataParallelModel's no-CUDA fallback. Previously this
        # path crashed with AttributeError on the unset `self.dim`.
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        # inputs should already be scattered; scatter the targets instead
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        inputs = [(inp,) for inp in inputs]
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, targets, kwargs)
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, targets, kwargs):
        return criterion_parallel_apply(
            replicas, inputs, targets, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        # FIXME: dirty hack — torch 0.4.0 fails to gather 0-dim scalars,
        # so promote them to 1-element tensors before gathering.
        if not len(outputs[0].shape):
            outputs = [output.unsqueeze(0) for output in outputs]
        return gather(outputs, output_device, dim=self.dim)
def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    """Apply each criterion replica to its (input, target) pair in parallel.

    Mirrors ``torch.nn.parallel.parallel_apply`` but forwards both the
    scattered inputs and the scattered targets to each replica.

    Args:
        modules: criterion replicas, one per device.
        inputs: sequence of positional-argument tuples, one per replica.
        targets: sequence of target tuples, one per replica.
        kwargs_tup: optional per-replica keyword-argument dicts.
        devices: optional device ids; inferred from each input when None.

    Returns:
        List of per-replica outputs, in replica order.

    Raises:
        Any exception raised inside a replica (re-raised in the caller).
    """
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    # Capture the caller's grad mode so worker threads inherit it.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                output = module(*input, *target, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            # Store the exception; it is re-raised in the caller thread below.
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target, kwargs, device),
                                    )
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # Bug fix: this call previously omitted targets[0], shifting
        # kwargs_tup[0] into the `target` slot and devices[0] into `kwargs`,
        # so every single-device invocation failed.
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
def get_a_var(obj):
    """Return the first ``torch.Tensor`` found in ``obj``, searching
    recursively through lists, tuples, and dict items; ``None`` if absent."""
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, (list, tuple)):
        candidates = obj
    elif isinstance(obj, dict):
        candidates = obj.items()
    else:
        return None
    for item in candidates:
        found = get_a_var(item)
        if isinstance(found, torch.Tensor):
            return found
    return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment