Skip to content

Instantly share code, notes, and snippets.

Created February 12, 2018 00:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/a5b4282a46be7e2971baf690e8cde054 to your computer and use it in GitHub Desktop.
import os
import torch
import torch.distributed as dist
from torch import multiprocessing, nn, randperm
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.autograd import Variable
from torch.optim import SGD
from torchvision import datasets, transforms
from functools import partial
from contextlib import ExitStack, contextmanager
class Data(object):
    """MNIST data provider that shards the training set across workers.

    The training indexes are shuffled once; the first ``val_size`` of them
    are held out (never assigned to a worker) and the remainder is split
    into ``num_workers`` chunks.  Worker ``i`` obtains a DataLoader over
    chunk ``i`` via :meth:`get_train`.
    """

    def __init__(self, batch_size=128, num_workers=4, val_size=8800,
                 data_path='tmp/'):
        """
        batch_size  -- global batch size; each worker draws
                       batch_size // num_workers samples per step.
        num_workers -- number of training processes sharing the data.
        val_size    -- number of shuffled train indexes to hold out.
        data_path   -- root directory for the (downloaded) MNIST data.
        """
        # Options shared by the train and test datasets; the Normalize
        # constants are the conventional MNIST mean/std.
        dataset_opts = dict(
            root=data_path,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.1307,),
                    (0.3081,),
                )
            ])
        )
        self.train_dataset = datasets.MNIST(
            train=True,
            download=True,
            **dataset_opts
        )
        self.test_dataset = datasets.MNIST(
            train=False,
            **dataset_opts
        )
        # Shuffle once; the first val_size indexes are the held-out set.
        indexes = randperm(len(self.train_dataset))
        self.train_indexes = indexes[val_size:]
        self.batch_size = batch_size
        self.num_workers = num_workers
        # chunk() tolerates len(train_indexes) % num_workers != 0 (the
        # last chunk is simply smaller), fixing the old FIXME where
        # .view(num_workers, -1) required exact divisibility.
        self.splits = list(self.train_indexes.chunk(self.num_workers))

    def get_train(self, split_idx):
        """Return a DataLoader over this worker's shard of the train set."""
        # Integer division: the global batch is spread over the workers.
        # (True division produced a float, which DataLoader rejects.)
        worker_batch_size = self.batch_size // self.num_workers
        opts = dict(
            dataset=self.train_dataset,
            batch_size=worker_batch_size,
            sampler=SubsetRandomSampler(self.splits[split_idx]),
        )
        return DataLoader(**opts)
class MnistNet(nn.Module):
    """MLP classifier for MNIST: 784 -> 1024 -> 1024 -> 1024 -> 10.

    Dropout is applied to the input (p=0.2) and after each hidden layer
    (p=0.5).  The output is log_softmax over 10 classes, matching the
    F.nll_loss criterion used by the trainer.
    """

    def __init__(self):
        super(MnistNet, self).__init__()
        self.l1 = nn.Linear(28 * 28, 1024)
        self.l2 = nn.Linear(1024, 1024)
        self.l3 = nn.Linear(1024, 1024)
        self.out = nn.Linear(1024, 10)
        # He initialization for the weight matrices (appropriate for ReLU);
        # biases keep their default init.  Fix: `kaiming_normal` (no
        # underscore) is the deprecated/removed alias — use the in-place
        # `kaiming_normal_`.
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.kaiming_normal_(param)

    def forward(self, x):
        """Flatten (N, 1, 28, 28) images and return (N, 10) log-probs."""
        x = x.view(-1, 28 * 28)
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(self.l1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.l2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.l3(x))
        x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(self.out(x), dim=1)
class Trainer(object):
    """Single-process trainer: runs SGD epochs over this rank's data shard."""

    def __init__(self, rank, lr=0.01, momentum=0.99, nesterov=True,
                 epochs=300, gpu=True, agg_period=2):
        """
        rank       -- worker index; selects the data shard and the GPU.
        lr, momentum, nesterov -- SGD hyper-parameters.
        epochs     -- total number of passes over the shard.
        gpu        -- move model and batches to CUDA when True.
        agg_period -- updates between parameter aggregations (consumed by
                      subclasses through on_update_fn).
        """
        self.rank = rank
        self.gpu = gpu
        self.agg_period = agg_period
        self.train_loader = Data().get_train(self.rank)
        self.model = MnistNet()
        self.loss_fn = F.nll_loss
        # Hook invoked after every optimizer step (e.g. gossip averaging).
        self.on_update_fn = None
        self.start_epoch = 0
        self.epoch = 0
        self.update = 0
        self.num_epochs = epochs
        if self.gpu:
            # This splits models evenly across available GPUs
            self.device_id = self.rank % torch.cuda.device_count()
            with torch.cuda.device(self.device_id):
                self.model.cuda()
        self.optimizer = SGD(
            self.model.parameters(),
            lr=lr,
            momentum=momentum,
            nesterov=nesterov
        )

    def train(self):
        """Run the full training loop, tracking epoch and update counters."""
        for epoch in range(self.start_epoch + 1, self.num_epochs + 1):
            self.epoch = epoch
            for input_data, target in self.train_loader:
                self.update += 1
                self.train_step(input_data, target)
        return

    def train_step(self, input_data, target):
        """One forward/backward/optimizer step on a batch; returns the loss."""
        # Set mode to train:
        self.model.train()
        # Transfer data to GPU
        if self.gpu:
            with torch.cuda.device(self.device_id):
                # Fix: `async` became a reserved word in Python 3.7 (the
                # old `async=True` kwarg is a SyntaxError there); the
                # replacement kwarg is `non_blocking`, which overlaps the
                # host->device copy when the source is pinned.
                input_data = input_data.pin_memory() \
                    .cuda(self.device_id, non_blocking=True)
                target = target.pin_memory() \
                    .cuda(self.device_id, non_blocking=True)
        # Compute loss and gradients for the batch.  (The old Variable
        # wrappers are no-ops on modern PyTorch and have been dropped.)
        self.optimizer.zero_grad()
        out = self.model(input_data)
        loss = self.loss_fn(out, target)
        loss.backward()
        # Gradient descent step
        self.optimizer.step()
        # This is where we potentially aggregate parameters across
        # workers in DistTrainer
        if self.on_update_fn:
            self.on_update_fn()
        # Fix: loss.item() replaces loss.data[0], which raises on modern
        # PyTorch because the loss is a 0-dim tensor.
        print('Epoch: {}, Update: {}, Train loss: {}'
              .format(self.epoch, self.update, loss.item()))
        return loss
class DistTrainer(Trainer):
    # Distributed variant of Trainer: after every `agg_period` updates each
    # worker averages its parameters with a small, randomly chosen set of
    # peers ("push-pull gossip") instead of doing a full all-reduce.

    def __init__(self, *args, **kwargs):
        super(DistTrainer, self).__init__(*args, **kwargs)
        # Trainer.train_step calls this hook after every optimizer step.
        self.on_update_fn = self.gossip_push_pull_reduce
        # Ensure all workers start training with the same parameters.
        self.init_parameters()

    @contextmanager
    def _on_cpu_for_comm(self):
        # Move the model to the CPU for the duration of the `with` body and
        # restore it to the GPU afterwards; the collective calls below
        # operate on the (CPU) param.data tensors.
        # NOTE(review): no try/finally around the yield — an exception in
        # the body would leave the model on the CPU.  Confirm acceptable.
        with ExitStack() as stack:
            if self.gpu:
                stack.enter_context(torch.cuda.device(self.device_id))
                self.model.cpu()
            yield
            if self.gpu:
                self.model.cuda()

    def init_parameters(self):
        # Broadcast rank 0's parameters so every worker starts identical.
        with self._on_cpu_for_comm():
            for param in self.model.parameters():
                dist.broadcast(param.data, src=0)
        return

    @staticmethod
    def _select_random(size, rank):
        """ Select a random peer in the cluster."""
        # A uniform offset in [1, size-1] added to our own rank (mod size)
        # guarantees the selected peer is never ourselves.  Returned as a
        # 1-element float tensor so it can be all_gather'd directly.
        random_peer = torch.ceil(torch.rand(1) * (size-1))
        random_peer = (random_peer + rank) % size
        return random_peer

    def _init_gossip(self):
        """
        Select a peer with uniform probability, and receive every
        other node's peer selection. Once done, determine all peers
        we'd be gossiping with.
        """
        size = dist.get_world_size()
        rank = dist.get_rank()
        own_peer = self._select_random(size, rank)
        # Then, dist.all_gather these selections so all workers know
        # every other worker's selection
        requesting_peers = [
            torch.zeros(1)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(requesting_peers, own_peer)
        # Convert list of 1x1 tensors into 1-D tensor
        requesting_peers = torch.Tensor(list(map(lambda x: x[0],
                                                 requesting_peers)))
        # Then, collect a list of other workers that have selected self
        # to gossip with. This is accomplished by looking at the index
        # where self occurs in the all_gather'd list
        requesting_peers = torch \
            .nonzero(requesting_peers == dist.get_rank()) \
            .view(-1) \
            .tolist()
        # NOTE(review): `.int()[0]` relies on old PyTorch returning a plain
        # Python number when indexing a 1-D tensor; on modern versions this
        # yields a 0-dim tensor, which would break the set membership and
        # the dst=/src= arguments below — confirm before upgrading torch.
        own_peer = own_peer.int()[0]
        peers = set(requesting_peers + [own_peer])
        return peers

    def gossip_push_pull_reduce(self):
        # Only aggregate every `agg_period` updates; all other updates are
        # purely local SGD steps.
        if self.update % self.agg_period != 0:
            return
        peers = self._init_gossip()
        with self._on_cpu_for_comm():
            for param in self.model.parameters():
                # A container to hold async requests
                requests = []
                # We've to exchange self.params with requesting peers
                to_send = param.data
                receive_buffers = [
                    (peer, torch.zeros(param.data.shape))
                    for peer in peers
                ]
                # Symmetric exchange: push our tensor to each peer and pull
                # theirs into a dedicated buffer, all asynchronously.
                for peer, buffer in receive_buffers:
                    requests.append(dist.isend(
                        tensor=to_send,
                        dst=peer
                    ))
                    requests.append(dist.irecv(
                        tensor=buffer,
                        src=peer
                    ))
                # Wait for all the requests to complete
                for r in requests:
                    r.wait()
                # Then compute the average (own params count once, hence +1)
                for _, buffer in receive_buffers:
                    param.data += buffer
                param.data /= len(receive_buffers) + 1
        return
class Cluster(object):
    """Launches num_workers processes and wires each into torch.distributed."""

    def __init__(self, master_addr='127.0.0.1', master_port='29500',
                 backend='tcp', num_workers=4,
                 seed=1234, gpu=True):
        """
        master_addr/master_port -- rendezvous address for process group init.
        backend     -- torch.distributed backend.  NOTE(review): 'tcp' was
                       removed in later PyTorch releases; pass 'gloo' there.
        num_workers -- number of worker processes (the world size).
        seed        -- base RNG seed; worker `rank` uses seed + rank.
        gpu         -- when True, seed CUDA and use the 'spawn' start method.
        """
        super(Cluster, self).__init__()
        self.master_addr = master_addr
        self.master_port = master_port
        self.seed = seed
        self.gpu = gpu
        self.backend = backend
        self.num_workers = num_workers

    def _init_worker(self, rank, fn):
        """Per-process setup (env, seeding, process group), then run fn(rank)."""
        os.environ['MASTER_ADDR'] = self.master_addr
        os.environ['MASTER_PORT'] = self.master_port
        # Below is an attempt to address issue with number of threads:
        # (https://github.com/pytorch/pytorch/issues/975)
        os.environ['OMP_NUM_THREADS'] = '1'
        torch.set_num_threads(1)
        # break symmetry in random seed across workers
        seed = self.seed + rank
        torch.manual_seed(seed)
        if self.gpu:
            # Fix: seed the CUDA RNG here — the original repeated the CPU
            # torch.manual_seed call, leaving the CUDA generator unseeded.
            torch.cuda.manual_seed(seed)
        dist.init_process_group(self.backend, rank=rank,
                                world_size=self.num_workers)
        return fn(rank)

    def run_processes(self, fn):
        """Map fn over ranks 0..num_workers-1, one OS process per rank."""
        if self.gpu:
            # CUDA cannot survive fork(); the 'spawn' start method is required.
            Pool = multiprocessing.get_context('spawn').Pool
        else:
            Pool = multiprocessing.Pool
        partial_init = partial(self._init_worker, fn=fn)
        pool = Pool(
            self.num_workers,
            maxtasksperchild=1,
        )
        return pool.map(
            partial_init,
            range(self.num_workers),
            chunksize=1,
        )
def _work(rank):
    """Worker entry point: build a distributed trainer for this rank and run it."""
    DistTrainer(rank).train()
def main():
    """Spin up a default local cluster and run the distributed training job."""
    Cluster().run_processes(_work)
# Script entry point: only launch the cluster when executed directly.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment