Created
March 28, 2017 00:27
-
-
Save FuriouslyCurious/81742b8126f07f919522a588147e6086 to your computer and use it in GitHub Desktop.
Soumith ImageNet Example Experiments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import shutil | |
import time | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
import torch.backends.cudnn as cudnn | |
import torch.optim | |
import torch.utils.data | |
import torchvision.transforms as transforms | |
import torchvision.datasets as datasets | |
import torchvision.models as models | |
# Sorted names of every lowercase, non-dunder, callable attribute of
# torchvision.models; used below as the accepted values for --arch.
model_names = sorted(
    name
    for name, candidate in vars(models).items()
    if name.islower() and not name.startswith("__") and callable(candidate)
)
# Command-line interface of the training script. Options are registered in
# the order they should appear in --help.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument(
    'data',
    metavar='DIR',
    help='path to dataset',
)
parser.add_argument(
    '--arch', '-a',
    metavar='ARCH',
    default='resnet18',
    choices=model_names,
    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)',
)
parser.add_argument(
    '-j', '--workers',
    default=4,
    type=int,
    metavar='N',
    help='number of data loading workers (default: 4)',
)
parser.add_argument(
    '--epochs',
    default=90,
    type=int,
    metavar='N',
    help='number of total epochs to run',
)
parser.add_argument(
    '--start-epoch',
    default=0,
    type=int,
    metavar='N',
    help='manual epoch number (useful on restarts)',
)
parser.add_argument(
    '-b', '--batch-size',
    default=256,
    type=int,
    metavar='N',
    help='mini-batch size (default: 256)',
)
parser.add_argument(
    '--lr', '--learning-rate',
    default=0.1,
    type=float,
    metavar='LR',
    help='initial learning rate',
)
parser.add_argument(
    '--momentum',
    default=0.9,
    type=float,
    metavar='M',
    help='momentum',
)
parser.add_argument(
    '--weight-decay', '--wd',
    default=1e-4,
    type=float,
    metavar='W',
    help='weight decay (default: 1e-4)',
)
parser.add_argument(
    '--print-freq', '-p',
    default=10,
    type=int,
    metavar='N',
    help='print frequency (default: 10)',
)
parser.add_argument(
    '--resume',
    default='',
    type=str,
    metavar='PATH',
    help='path to latest checkpoint (default: none)',
)
parser.add_argument(
    '-e', '--evaluate',
    dest='evaluate',
    action='store_true',
    help='evaluate model on validation set',
)
parser.add_argument(
    '--pretrained',
    dest='pretrained',
    action='store_true',
    help='use pre-trained model',
)

# Best top-1 validation precision seen so far; main() rebinds it as a global.
best_prec1 = 0
def main():
    """Parse the command line, build model and data loaders, then train.

    Stores the parsed options in the module-level ``args`` and keeps the best
    top-1 validation precision in the module-level ``best_prec1``. With
    ``--evaluate`` only a single validation pass is run. Otherwise every epoch
    from ``args.start_epoch`` up to ``args.epochs`` is trained, validated and
    checkpointed through ``save_checkpoint``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print("=> Parsing complete...")

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.arch.startswith(('alexnet', 'vgg')):
        # Only the feature extractor is wrapped in DataParallel here.
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
        # NOTE(review): the message below is misleading -- model.cuda() above
        # does move the module to the GPU. Text kept unchanged so existing
        # logs stay comparable.
        print("=> Not putting weights on GPU for VGG / AlexNet")
    else:
        model = torch.nn.DataParallel(model).cuda()
        print("=> Using CUDA DataParallel")

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            # A missing checkpoint is reported but training still starts.
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(
        mean=[0.6492440333064673, 0.6492440333064673, 0.6492440333064673],
        std=[0.235452743560426, 0.235452743560426, 0.235452743560426])

    # transforms.Scale was renamed to transforms.Resize and later removed;
    # prefer Resize and fall back to Scale on old torchvision releases.
    resize = transforms.Resize if hasattr(transforms, 'Resize') else transforms.Scale

    # The RandomSizedCrop(224) / CenterCrop(224) steps are left disabled, as
    # in the original experiment; images are only resized to 512.
    # NOTE(review): without a crop, batching presumably needs every image to
    # end up with the same shape after the resize -- confirm for the dataset.
    print("=> Starting training images loading...")
    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            # transforms.RandomSizedCrop(224),
            resize(512),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    print("=> Starting validation images loading...")
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            resize(512),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    print("=> Loss criterion and optimizer setup")
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        print("=> Validating model...")
        validate(val_loader, model, criterion)
        return

    print("=> Starting training...")
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        print("=> Training Epoch {}".format(epoch))
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        print("=> Validating Epoch {}".format(epoch))
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        # 'epoch' stores the next epoch to run so that --resume continues
        # after the one just finished. Optimizer state is not saved.
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches.
        model: network to optimise; switched to train mode here.
        criterion: loss callable applied to ``(output, target)``.
        optimizer: optimizer whose ``zero_grad``/``step`` drive the update.
        epoch: epoch index, used only in the progress output.

    Progress (timings, loss, prec@1, prec@5) is printed every
    ``args.print_freq`` batches. Nothing is returned.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # `async` is a reserved word since Python 3.7, so the old
        # `cuda(async=True)` spelling no longer parses; `non_blocking` is the
        # replacement keyword.
        target = target.cuda(non_blocking=True)

        # compute output (the images are passed to the model as loaded)
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss; .item() yields plain Python
        # numbers for both 0-dim and single-element tensors
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return the average prec@1.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss callable applied to ``(output, target)``.

    Returns:
        The sample-weighted average top-1 precision over all batches.

    Progress is printed every ``args.print_freq`` batches, followed by a
    final summary line.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # torch.no_grad() replaces the old `volatile=True` Variable flag, so no
    # autograd graph is recorded during evaluation.
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # `async` is a reserved word since Python 3.7; `non_blocking`
            # is the replacement keyword.
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss; .item() yields plain Python
            # numbers for both 0-dim and single-element tensors
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``filename``; keep a copy of it when it is the best.

    Args:
        state: object handed to ``torch.save``.
        is_best: when true, the saved file is also copied to
            ``model_best.pth.tar`` in the current working directory.
        filename: destination path of the checkpoint.
    """
    print("=> Saving model checkpoint...")
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
    """Keeps the most recent value together with a running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as seen ``n`` times and refresh the running mean."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Apply the step schedule: ``args.lr`` scaled by 0.1 per 30 full epochs.

    The resulting rate is written to the ``'lr'`` entry of every group in
    ``optimizer.param_groups``.
    """
    completed_decays = epoch // 30
    decayed_lr = args.lr * (0.1 ** completed_decays)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: 1-D tensor of class indices, one per sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k``, holding the
        percentage of samples whose target is among the ``k`` highest
        scoring classes. Callers may index the element with ``[0]``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed to
    # (maxk, batch) so that row r holds every sample's rank-r guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape() instead of view(): the slice of the transposed result is
        # not contiguous, and view() raises on it for k > 1.
        # keepdim=True keeps a one-element tensor so `[0]` indexing works.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
# Start training only when the file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment