@daviddao · Created January 27, 2018 18:31
Distributed model parallelism with PyTorch
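The snippet splits a network across two processes. Rank 0 runs the first half of the model on its GPU and sends the intermediate activations (a 320-dimensional feature vector per example) to rank 1, which runs the second half, computes the loss, and sends the gradient of the loss with respect to those activations back so that rank 0 can complete its own backward pass. Each process holds only its own partition in `model`, along with its own `train_loader` and `args`; the process-group setup is sketched after the training function.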
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


def train(epoch):
    # `model`, `train_loader` and `args` are globals; each rank holds its own
    # partition of the network (rank 0 the first half, rank 1 the second half,
    # whose input is 320-dimensional) and its own optimizer.
    model.train()
    model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    if dist.get_rank() == 0:
        # Rank 0: forward the data through the first half, ship the activations
        # to rank 1, then backprop with the gradient it sends back.
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            data = Variable(data.cuda())
            output = model(data)
            # Send this batch's 320-dimensional activations to rank 1.
            dist.send(output.data.cpu(), dst=1)
            # Receive d(loss)/d(activations) computed on rank 1.
            input_from_part2 = torch.FloatTensor(data.size(0), 320)
            dist.recv(tensor=input_from_part2, src=1)
            output.backward(input_from_part2.cuda())
            optimizer.step()
    else:
        # Rank 1: receive the activations, finish the forward pass, compute the
        # loss, and send the gradient w.r.t. the activations back to rank 0.
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            target = Variable(target.cuda())
            output_from_part1 = torch.FloatTensor(data.size(0), 320)
            dist.recv(tensor=output_from_part1, src=0)
            # Treat the received activations as a leaf so their gradient is retained.
            input = Variable(output_from_part1.cuda(), requires_grad=True)
            output = model(input)
            loss = F.nll_loss(output, target)
            loss.backward()
            # Return d(loss)/d(activations) so rank 0 can continue its backward pass.
            dist.send(input.grad.data.cpu(), dst=0)
            optimizer.step()
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.data[0]))
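
The training function assumes the process group is already initialized and that each of the two processes has built its own half of the network. A minimal launch sketch could look like the following; everything in it (the address, port, environment variable, and the `build_partition` helper) is an illustrative assumption, not part of the original gist. Note that point-to-point `dist.send`/`dist.recv` requires a backend that supports them: the `tcp` backend in the PyTorch 0.3-era release this gist targets, or `gloo` in current releases.

import os
import torch.distributed as dist

# Start one process per rank, e.g.:
#   RANK=0 python train_pipeline.py &
#   RANK=1 python train_pipeline.py
rank = int(os.environ['RANK'])

dist.init_process_group(backend='tcp',                        # 'gloo' on recent PyTorch
                        init_method='tcp://127.0.0.1:23456',  # placeholder address/port
                        rank=rank,
                        world_size=2)

# Hypothetical helper returning this rank's half of the network: rank 0 gets the
# convolutional front that emits 320 features, rank 1 gets the classifier head.
model = build_partition(rank)

for epoch in range(1, args.epochs + 1):   # args as in the standard MNIST example
    train(epoch)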