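"""
MNIST training script that compares two data-parallel setups on a fixed partition of
the MNIST training set: PyTorch DistributedDataParallel over torch.distributed
(--all_reduce) and BytePS (--bps). Gradient averaging is done by the respective
DistributedDataParallel wrapper; with --bps the per-epoch loss is additionally
averaged across workers via bps.push_pull.
"""
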
import argparse
import os
import json
import signal
import logging
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
import torchvision.models as models
import datetime
import numpy as np
import sys
import statistics
import threading
from time import sleep
from threading import Thread
from math import ceil
import random
from random import Random
from torch.multiprocessing import Process
from torch.autograd import Variable
from torchvision import datasets, transforms
import byteps.torch as bps


class Partition(object):
    """ Dataset-like object that only accesses a subset of the underlying data. """

    def __init__(self, data, index):
        self.data = data
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]


class DataPartitioner(object):
    """ Partitions a dataset into different chunks. """

    def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
        self.data = data
        self.partitions = []
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = [x for x in range(0, data_len)]
        # rng.shuffle(indexes)
        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[0:part_len])
            indexes = indexes[part_len:]

    def use(self, partition):
        return Partition(self.data, self.partitions[partition])
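
# Illustrative usage (mirrors partition_dataset() below): split a dataset into four
# equal shards and pick one of them; `rank` here is a hypothetical worker index.
#   partitioner = DataPartitioner(dataset, sizes=[0.25] * 4)
#   local_shard = partitioner.use(rank)
#   loader = torch.utils.data.DataLoader(local_shard, batch_size=32, shuffle=False)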


class Net(nn.Module):
    """ Network architecture. """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # randomness here
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
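
# (This matches the small two-conv-layer MNIST network from the standard PyTorch examples.)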


def partition_dataset(args):
    """
    Partitions the imported dataset by calling DataPartitioner()
    :param args: command line input arguments
    :return: training set and batch size
    """
    dataset = datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    )
    bsz = float(args.batch_size)  # honor the --batch-size argument (default: 32)
    partition_sizes = [0.25] * 4
    partition = DataPartitioner(dataset, sizes=partition_sizes)
    partition = partition.use(0)  # to minimize randomness, use the same partition for all workers
    # retrieve a subset of the overall training set by current process rank
    train_set = torch.utils.data.DataLoader(
        partition, batch_size=int(bsz), shuffle=False, num_workers=0, pin_memory=True)
    return train_set, bsz


def update_gradients(model):
    """
    Send the model gradients to be aggregated using all-reduce
    :param model: neural net used for training
    :return: None
    """
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
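
# Note: update_gradients() is only needed when the model is NOT wrapped in
# DistributedDataParallel; DDP (both the torch and byteps variants used below)
# already averages gradients during backward(), which is why the call inside
# train() is disabled.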


def train(args):
    train_set, bsz = partition_dataset(args)
    model = Net()
    if args.all_reduce:
        from torch.nn.parallel import DistributedDataParallel as DDP
    else:  # bps
        from byteps.torch.parallel import DistributedDataParallel as DDP
    model.cuda(device="cuda:0")
    model = DDP(
        model, device_ids=[0]  # single CUDA device (cuda:0) per process
    )
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=args.lr)
    num_batches = ceil(len(train_set.dataset) / float(bsz))
    if False and args.bps:
        # BytePS: broadcast parameters
        bps.broadcast_parameters(model.state_dict(), root_rank=0)
        bps.broadcast_optimizer_state(optimizer, root_rank=0)
    for epoch in range(30):
        epoch_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_set):
            data, target = data.cuda(device="cuda:0"), target.cuda(device="cuda:0")
            #########################################################################
            # Forward & backward pass
            optimizer.zero_grad()
            output = model(data)
            # Net.forward() returns log-probabilities (log_softmax), so use NLLLoss;
            # CrossEntropyLoss would apply log_softmax a second time.
            criterion = nn.NLLLoss()
            loss = criterion(output, target)
            epoch_loss += loss.item()
            loss.backward()
            #########################################################################
            # average the gradients manually if all_reduce; in byteps it is handled by DDP
            if False and args.all_reduce: update_gradients(model)
            #########################################################################
            optimizer.step()
            if (batch_idx % args.log_interval == 0 and args.pindex == 0):
                # only print info on worker 0
                print(f"Worker {args.pindex} job_id {args.job_id} Train Epoch: {epoch} [Iteration {batch_idx}/{len(train_set)}]")
        epoch_loss = round(epoch_loss / num_batches, 5)
        if args.bps:
            epoch_loss = metric_average(args, epoch_loss, 'epoch_loss')
            # the push_pull giving sum instead of average issue is fixed in some commits in late 2020
            # epoch_loss /= 4  # https://github.com/bytedance/byteps/issues/323
            epoch_loss = round(epoch_loss, 5)
        print(f"Worker {args.pindex} job_id {args.job_id} epoch {epoch} loss {epoch_loss}")


def metric_average(args, val, name):
    tensor = torch.tensor(val)
    if args.cuda:
        tensor = tensor.cuda()
    avg_tensor = bps.push_pull(tensor, name=name)
    return avg_tensor.item()
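
# metric_average() averages a scalar (the epoch loss) across workers with
# bps.push_pull; per the comment in train(), older BytePS versions returned the
# sum instead of the average (https://github.com/bytedance/byteps/issues/323).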


def init_processes(args):
    # Seed everything -- see https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    os.environ["PYTHONHASHSEED"] = str(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # benchmark mode selects algorithms non-deterministically
    if args.all_reduce:
        print("=====Initializing process group=====")
        backend = 'gloo' if not args.cuda else 'nccl'
        dist.init_process_group(backend=backend, init_method='env://')
        print("=====Finished initializing process group=====")
    else:  # bps
        bps.init()  # initialize bps
        # pin this process to a GPU (always cuda:0 in this setup)
        torch.cuda.set_device("cuda:0")
    train(args)
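
# Rendezvous note: with --all_reduce, init_method='env://' reads MASTER_ADDR,
# MASTER_PORT, RANK, and WORLD_SIZE from the environment. With --bps, bps.init()
# expects the usual BytePS/DMLC_* environment variables to have been set by its
# launcher (an assumption about the deployment; this script does not set them).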


def parse_args():
    """
    Parses the arguments and invokes the training function
    :return: None
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=32, help="input batch size")
    parser.add_argument(
        "-pindex", type=int, help="pindex of worker process in run_job.py"
    )
    parser.add_argument(
        "-job_id",
        type=str,
        default="53706",
        help="unique identifier of a training job, same as rdzv_id in etcd",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.05,
        metavar="LR",
        help="learning rate (default: 0.05)",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.5,
        metavar="M",
        help="SGD momentum (default: 0.5)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        metavar="S",
        help="random seed (default: 1234)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=100,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--bps",
        action="store_true",
        default=False,
        help="use bps",
    )
    parser.add_argument(
        "--all_reduce",
        action="store_true",
        default=False,
        help="use all_reduce",
    )
    args = parser.parse_args()
    if not (args.all_reduce or args.bps):
        print("Specify a framework to use (bps/all_reduce)")
        exit(1)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    init_processes(args)


if __name__ == "__main__":
    parse_args()
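
# Example launch (illustrative only -- "train_mnist.py", the address, and the port
# are placeholders, not values taken from this script):
#
#   # torch all-reduce path, rank 0 of a 4-worker job:
#   MASTER_ADDR=10.0.0.1 MASTER_PORT=29500 WORLD_SIZE=4 RANK=0 \
#       python train_mnist.py --all_reduce -pindex 0 -job_id 53706
#
#   # BytePS path, run under the BytePS launcher so the DMLC_* variables are set:
#   bpslaunch python train_mnist.py --bps -pindex 0 -job_id 53706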