# ImageNet training in PyTorch with native AMP (torch.cuda.amp), DistributedDataParallel,
# channels-last memory format and a CUDA-stream data prefetcher.
# Gist by @ilkarman, created October 8, 2020.
import argparse
import os
import shutil
import time
import socket
import multiprocessing
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.cuda.amp import autocast, GradScaler
# NVIDIA-style replacement for PyTorch's default_collate: instead of calling
# transforms.ToTensor() per image, PIL images are packed into a single uint8
# NCHW tensor here; conversion to float and normalization happen later on the GPU.
def fast_collate(batch, memory_format):
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # Grayscale image: add a channel axis so the rollaxis below works.
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)  # HWC -> CHW
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets
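# Example (a minimal sketch, not part of the training script): paired with a
# DataLoader, fast_collate returns a uint8 tensor of shape (N, 3, H, W) plus an
# int64 target tensor, e.g.:
#
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=8,
#       collate_fn=lambda b: fast_collate(b, torch.channels_last))
#   images, targets = next(iter(loader))   # images.dtype == torch.uint8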
def parse():
    model_names = sorted(
        name
        for name in models.__dict__
        if name.islower()
        and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
    parser.add_argument("data", metavar="DIR", default="", help="path to dataset")
    parser.add_argument(
        "--arch",
        "-a",
        metavar="ARCH",
        default="resnet50",
        choices=model_names,
        help="model architecture: " + " | ".join(model_names) + " (default: resnet50)",
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=4,
        type=int,
        metavar="N",
        help="number of data loading workers per process/GPU (default: 4)",
    )
    parser.add_argument(
        "--epochs",
        default=5,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=0,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=256,
        type=int,
        metavar="N",
        help="mini-batch size per process/GPU (default: 256)",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.1,
        type=float,
        metavar="LR",
        help="Initial learning rate. Will be scaled by <global batch size>/256: "
        "args.lr = args.lr * float(args.batch_size * args.world_size) / 256. "
        "A warmup schedule will also be applied over the first 5 epochs.",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="momentum"
    )
    parser.add_argument(
        "--weight-decay",
        "--wd",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
    )
    parser.add_argument(
        "--print-freq",
        "-p",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        action="store_true",
        help="use pre-trained model",
    )
    parser.add_argument(
        "--prof", default=-1, type=int, help="Only run 10 iterations for profiling."
    )
    parser.add_argument("--deterministic", action="store_true")
    parser.add_argument("--local_rank", default=0, type=int)
    # Note: argparse's type=bool treats any non-empty string as True, so
    # "--channels-last False" still evaluates to True; channels-last memory
    # format is effectively on by default.
    parser.add_argument("--channels-last", type=bool, default=True)
    parser.add_argument("--synthetic", action="store_true", help="Run on fake data")
    parser.add_argument(
        "--validate", action="store_true", help="Run validation during training loop"
    )
    # Note: --resume, --prof and --synthetic are parsed but not acted on elsewhere
    # in this script.
    args = parser.parse_args()
    return args
def main():
    global best_prec1, args
    args = parse()
    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    print("\nNCCL VERSION: {}\n".format(torch.cuda.nccl.version()))
    print("\nCPU Count: {}\n".format(multiprocessing.cpu_count()))
    cudnn.benchmark = True
    best_prec1 = 0
    ngpus_per_node = torch.cuda.device_count()
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)
    args.distributed = False
    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        print("Local rank {}".format(args.local_rank))
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        print("Use GPU: {} for training".format(args.gpu))
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://"
        )
        args.world_size = torch.distributed.get_world_size()
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.0
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model = model.to(memory_format=memory_format)
            print("Batch per GPU: {}".format(args.batch_size))
            print("Total Batch on Node: {}".format(args.batch_size * ngpus_per_node))
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = model.to(memory_format=memory_format)
            model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        model = model.to(memory_format=memory_format)
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    # Data loading code
    crop_size = 224
    val_size = 256
    traindir = os.path.join(args.data, "train")
    # ToTensor()/Normalize() are deliberately omitted from the transforms:
    # fast_collate returns uint8 tensors and data_prefetcher normalizes on the GPU.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(crop_size),
                transforms.RandomHorizontalFlip()
            ]
        ),
    )
    print("Train dataset size: {}".format(len(train_dataset)))
    if args.validate:
        valdir = os.path.join(args.data, "val")
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose(
                [
                    transforms.Resize(val_size),
                    transforms.CenterCrop(crop_size)
                ]
            ),
        )
    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        if args.validate:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    collate_fn = lambda b: fast_collate(b, memory_format)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        collate_fn=collate_fn,
    )
    if args.validate:
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler,
            collate_fn=collate_fn,
        )
    if args.evaluate:
        # Note: --evaluate also requires --validate, otherwise val_loader is undefined.
        validate(val_loader, model, criterion)
        return
    print("Starting training")
    stime = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        stime_epoch = time.time()
        train(train_loader, model, criterion, optimizer, epoch)
    # Validate after full training
    if args.validate:
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)
        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_prec1": best_prec1,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
            )
class data_prefetcher():
    """Overlaps host-to-device copy and normalization of the next batch with
    compute on the current batch by using a separate CUDA stream."""

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # Mean/std are scaled by 255 because fast_collate yields uint8 [0, 255] images.
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        # Copy and normalize the next batch on the side stream.
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target
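# Typical consumption pattern (a sketch; train() and validate() below do exactly this):
#
#   prefetcher = data_prefetcher(loader)
#   input, target = prefetcher.next()
#   while input is not None:
#       ...  # forward/backward on (input, target)
#       input, target = prefetcher.next()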
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    end = time.time()
    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    # Note: a fresh GradScaler is created every epoch, so the loss scale re-warms
    # up at the start of each epoch.
    scaler = GradScaler()
    i = 0
    while input is not None:
        i += 1
        adjust_learning_rate(optimizer, epoch, i, len(train_loader))
        # Forward pass under autocast for mixed precision
        with autocast():
            output = model(input)
            loss = criterion(output, target)
        # Scaled backward pass and optimizer step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        if i % args.print_freq == 0:
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data
            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))
            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch, i, len(train_loader),
                          args.world_size * args.batch_size / batch_time.val,
                          args.world_size * args.batch_size / batch_time.avg,
                          batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))
        input, target = prefetcher.next()
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1
        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data
        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # TODO: Change timings to mirror train().
        if args.local_rank == 0 and i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(val_loader),
                      args.world_size * args.batch_size / batch_time.val,
                      args.world_size * args.batch_size / batch_time.avg,
                      batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))
        input, target = prefetcher.next()
    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg
def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    factor = epoch // 30
    if epoch >= 80:
        factor = factor + 1
    lr = args.lr * (0.1 ** factor)
    # Warmup: linearly ramp the LR over the first 5 epochs
    if epoch < 5:
        lr = lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)
    # if args.local_rank == 0:
    #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
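# Worked example (hypothetical numbers, not from the script itself): with 8 GPUs
# and --batch-size 256, main() scales the base LR to 0.1 * (256 * 8) / 256 = 0.8.
# If len_epoch were 625 steps, warmup at epoch 0, step 1 would give
# 0.8 * (1 + 1 + 0) / (5 * 625) ≈ 5.1e-4, rising linearly to ~0.8 by the end of
# epoch 4; afterwards the LR drops by 10x at epochs 30, 60, 80 and 90.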
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def reduce_tensor(tensor):
    """Average a tensor across all distributed processes."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt


def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]
if __name__ == "__main__":
    main()
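# Example launch commands (a sketch; the file name "main_amp.py" and the GPU count
# are assumptions). The env:// init_method above relies on torch.distributed.launch
# setting the WORLD_SIZE environment variable and passing --local_rank to each process.
#
#   Single GPU:
#     python main_amp.py /path/to/imagenet --epochs 5 -b 256
#
#   Single node, 8 GPUs:
#     python -m torch.distributed.launch --nproc_per_node=8 main_amp.py \
#         /path/to/imagenet --epochs 5 -b 256 --validate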