Skip to content

Instantly share code, notes, and snippets.

@jgc128
Created August 13, 2017 15:43
Show Gist options
  • Save jgc128/760af23c4558deb83e1ec80b6f22fa49 to your computer and use it in GitHub Desktop.
Save jgc128/760af23c4558deb83e1ec80b6f22fa49 to your computer and use it in GitHub Desktop.
PyTorch DataParallel Example
import numpy as np
import torch
import torch.nn
import torch.cuda
from torch.autograd import Variable
class Net(torch.nn.Module):
    """Tiny feed-forward network of three stacked linear layers: 10 -> 3 -> 4 -> 5.

    The layers are applied one after another with no activation in between,
    so the whole network is itself a single affine map. That is enough for a
    DataParallel demonstration, but not a useful model architecture.
    """

    def __init__(self):
        super().__init__()
        # A ModuleList (rather than a plain Python list) registers each layer
        # as a submodule, so its parameters show up in .parameters() and are
        # moved by .cuda() / replicated by DataParallel.
        self.net = torch.nn.ModuleList([
            torch.nn.Linear(10, 3),
            torch.nn.Linear(3, 4),
            torch.nn.Linear(4, 5),
        ])

    def forward(self, inputs):
        """Apply every layer in order.

        Args:
            inputs: tensor whose last dimension is 10 (the first layer's
                ``in_features``).

        Returns:
            Tensor with the same leading dimensions and a last dimension of 5.
        """
        for layer in self.net:
            inputs = layer(inputs)
        return inputs
def main():
    """Run one forward pass and loss evaluation, using DataParallel when CUDA is available.

    Builds 15 random samples with 10 features each and integer class labels
    in [0, 5), then wraps ``Net`` in ``torch.nn.DataParallel`` if a CUDA
    device is present. DataParallel splits each input batch across
    ``model.device_ids``. Finally, prints the cross-entropy loss as a Python
    float.
    """
    X = np.random.uniform(-1, 1, (15, 10)).astype(np.float32)
    # CrossEntropyLoss requires int64 (Long) class indices. The default dtype of
    # np.random.randint is platform-dependent (int32 on Windows with NumPy < 2),
    # so request int64 explicitly. The upper bound 5 matches Net's 5 output logits.
    y = np.random.randint(0, 5, (15,)).astype(np.int64)
    print(X.shape)
    print(y.shape)

    model = Net()
    loss = torch.nn.CrossEntropyLoss()
    print('Model:', type(model))
    print('Loss:', type(loss))

    X = torch.from_numpy(X)
    y = torch.from_numpy(y)
    print('X', X.size(), 'y', y.size())

    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model)
        print('Model:', type(model))
        print('Devices:', model.device_ids)
        model = model.cuda()
        loss = loss.cuda()
        X = X.cuda()
        y = y.cuda()
    else:
        print('No devices available')

    # Plain tensors carry autograd state since PyTorch 0.4, so the old
    # autograd.Variable wrapper is unnecessary here.
    outputs = model(X)
    loss_value = loss(outputs, y)
    # The loss is a 0-dim tensor: indexing it (``.data[0]``) raises IndexError
    # on PyTorch >= 0.5, whereas ``.item()`` returns it as a Python float.
    print('Loss:', loss_value.item())
# Entry point: execute the demo only when this file is run as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
(15, 10)
(15,)
Model: <class '__main__.Net'>
Loss: <class 'torch.nn.modules.loss.CrossEntropyLoss'>
X torch.Size([15, 10]) y torch.Size([15])
Model: <class 'torch.nn.parallel.data_parallel.DataParallel'>
Devices: [0, 1, 2]
Loss: 1.6945154666900635
(15, 10)
(15,)
Model: <class '__main__.Net'>
Loss: <class 'torch.nn.modules.loss.CrossEntropyLoss'>
X torch.Size([15, 10]) y torch.Size([15])
No devices available
Loss: 1.71356201171875
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment