@arunmallya
Created June 20, 2017 20:46
Exposes bug with DataParallel when using dicts as input
import torch
import torch.nn as nn
from torch.autograd import Variable

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.net = nn.Linear(10, 2)

    def forward(self, inputs):
        return self.net(inputs['data'])


net = nn.DataParallel(SimpleModel()).cuda()
inputs = {'data': Variable(torch.rand(10, 10).cuda())}
outputs = net(inputs)
print(outputs)
"""
# Works on a single device, because DataParallel calls the module directly when only one device is visible.
$ CUDA_VISIBLE_DEVICES=0 python bug.py
Variable containing:
-0.1836 0.2654
-0.3584 0.0049
-0.2587 -0.0808
-0.2482 -0.2587
-0.4238 -0.2014
-0.1964 -0.1709
-0.6334 -0.0843
-0.4466 0.1243
-0.5991 -0.2169
-0.3005 -0.0565
[torch.cuda.FloatTensor of size 10x2 (GPU 0)]
# Fails on multiple devices.
$ CUDA_VISIBLE_DEVICES=0,1 python bug.py
Exception in thread Thread-2:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 22, in _worker
    var_input = var_input[0]
KeyError: 0

Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 22, in _worker
    var_input = var_input[0]
KeyError: 0

Traceback (most recent call last):
  File "bug.py", line 16, in <module>
    outputs = net(inputs)
  File "venv/lib/python3.5/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 61, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 71, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 44, in parallel_apply
    output = results[i]
KeyError: 0
"""
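
The worker traceback above stops at `var_input = var_input[0]`. That suggests the per-device input is indexed positionally, as if it were a tuple or list. The sketch below is plain Python, not PyTorch code. It only reproduces that one indexing step to show why a dict raises `KeyError: 0` where a tuple does not. The names `unwrap_first` and `try_unwrap` are hypothetical helpers made up for this illustration, and the strings stand in for real tensors.

```python
def unwrap_first(var_input):
    """Hypothetical stand-in for the indexing step shown in the traceback."""
    return var_input[0]


def try_unwrap(var_input):
    """Return the unwrapped value, or the error text if indexing fails."""
    try:
        return unwrap_first(var_input)
    except KeyError as err:
        return "KeyError: {}".format(err)


# Placeholder strings are used instead of tensors so this runs without a GPU.
tuple_input = ("tensor-placeholder",)        # positional-style input
dict_input = {"data": "tensor-placeholder"}  # dict input, as in bug.py

print(try_unwrap(tuple_input))  # prints: tensor-placeholder
print(try_unwrap(dict_input))   # prints: KeyError: 0
```

Indexing a dict with `0` is a key lookup, not a positional access, so it fails unless the dict happens to contain the key `0`. This matches the `KeyError: 0` seen in both worker threads, though the sketch says nothing about the rest of the `parallel_apply` logic.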