# imagenet train pytorch script
# https://github.com/pytorch/vision/blob/master/torchvision/models/__init__.py
import argparse
import os
import sys
import shutil
import time
import random
import datetime
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from utils import convert_secs2time, time_string, time_file_str  # AverageMeter is defined locally below
import torch.optim.lr_scheduler as lr_scheduler
# from models import print_log
import models
from tensorboardX import SummaryWriter
from utils import convert_model, measure_model
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', help='path to directory containing training and validation folders')
parser.add_argument('--train_dir_name', metavar='DIR', help='training set directory name')
parser.add_argument('--val_dir_name', metavar='DIR', help='validation set directory name')
parser.add_argument('--save_dir', type=str, default='./', help='Folder to save checkpoints and log.')
parser.add_argument('--arch', '-a', metavar='ARCH', default='simpnet_imgnet_5m_nodrp_safc_s1',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: simpnet_imgnet_5m_nodrp_safc_s1)')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers (default: 16)')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float, metavar='W', help='weight decay (default: 1e-5)')
parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N', help='print frequency (default: 200)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
args = parser.parse_args()
args.prefix = time_file_str()
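
# Example invocation (a sketch only; the script name, dataset path and directory
# names below are placeholders, not values taken from the original gist):
#
#   python imagenet_train.py /path/to/imagenet \
#       --train_dir_name train --val_dir_name val \
#       --arch simpnet_imgnet_5m_nodrp_safc_s1 \
#       --batch-size 128 --lr 0.1 --workers 16 --save_dir ./snapshots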
def main():
    best_prec1 = 0
    best_prec5 = 0
    writer = SummaryWriter()
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # used for file names, etc
    time_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log = open(os.path.join(args.save_dir, '{}.{}_{}.log'.format(args.arch, args.prefix, time_stamp)), 'w')

    # create model
    print_log("=> creating model '{}'".format(args.arch), log)
    model = models.__dict__[args.arch](1000)
    print_log("=> Model : {}".format(model), log)
    print_log("=> parameter : {}".format(args), log)
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    # optimizer = torch.optim.Adadelta(model.parameters(), weight_decay=args.weight_decay,
    #                                  lr=0.1, rho=0.9)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    IMAGE_SIZE = 224
    summary(model, input_size=(3, IMAGE_SIZE, IMAGE_SIZE))  # prints a per-layer summary to stdout
    n_flops, n_params = measure_model(model, IMAGE_SIZE, IMAGE_SIZE)
    print_log('FLOPs: %.2fM, Params: %.2fM' % (n_flops / 1e6, n_params / 1e6), log)
    ## epoch milestones
    milestones = [30, 60, 90, 130, 150]  # [10, 20, 30, 40, 50, 60]  # [15, 30, 60, 90, 110, 140]
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            if 'best_prec5' in checkpoint:
                best_prec5 = checkpoint['best_prec5']
            else:
                best_prec5 = 0.00
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.eval()
            print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)

    cudnn.benchmark = True
    # Data loading code
    traindir = os.path.join(args.data, args.train_dir_name)  # e.g. 'train'
    valdir = os.path.join(args.data, args.val_dir_name)      # e.g. 'val'
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # meanstd = {
    #     mean = { 0.485, 0.456, 0.406 },
    #     std = { 0.229, 0.224, 0.225 },
    # }
    pca = {
        'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
        'eigvec': torch.Tensor([
            [-0.5675, 0.7192, 0.4009],
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203],
        ])
    }
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),  # formerly transforms.RandomSizedCrop (deprecated)
            # transforms.ColorJitter(
            #     brightness=0.4,
            #     contrast=0.4,
            #     saturation=0.4,
            # ),
            # transforms.Lighting(0.1, pca['eigval'], pca['eigvec']),
            # transforms.ColorNormalize(meanstd),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),  # formerly transforms.Scale (deprecated)
            # t.ColorNormalize(meanstd),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    if args.evaluate:
        validate(val_loader, model, criterion, log)
        return

    filename = os.path.join(args.save_dir, 'checkpoint.{0}.{1}_{2}.pth.tar'.format(args.arch, args.prefix, time_stamp))
    bestname = os.path.join(args.save_dir, 'best.{0}.{1}_{2}.pth.tar'.format(args.arch, args.prefix, time_stamp))

    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = float(optimizer.param_groups[0]['lr'])
        # adjust_learning_rate(optimizer, epoch)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.val * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
        # print_log(' [{:s}] :: {:3d}/{:3d} ----- [{:s}] {:s}'.format(args.arch, epoch, args.epochs, time_string(), need_time), log)
        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:.6f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate)
                  + ' [Best : Accuracy(T1/T5)={:.2f}/{:.2f}, Error={:.2f}/{:.2f}]'.format(best_prec1, best_prec5, 100 - best_prec1, 100 - best_prec5), log)

        # train for one epoch
        tr_prec1, tr_prec5, tr_loss = train(train_loader, model, criterion, optimizer, epoch, log)

        # evaluate on validation set
        prec1, prec5, val_loss = validate(val_loader, model, criterion, log)

        # step the LR scheduler once per epoch, after the optimizer updates
        scheduler.step()

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        best_prec5 = max(prec5, best_prec5)

        writer.add_scalar('learning_rate', current_learning_rate, epoch)
        writer.add_scalar('training/loss', tr_loss, epoch)
        writer.add_scalar('training/Top1', tr_prec1, epoch)
        writer.add_scalar('training/Top5', tr_prec5, epoch)
        writer.add_scalar('validation/loss', val_loss, epoch)
        writer.add_scalar('validation/Top1', prec1, epoch)
        writer.add_scalar('validation/Top5', prec5, epoch)

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
        }, is_best, filename, bestname)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    writer.close()
    log.close()
# -- Lighting noise (AlexNet-style PCA-based noise), Lua original:
# function M.Lighting(alphastd, eigval, eigvec)
#     return function(input)
#         if alphastd == 0 then
#             return input
#         end
#         local alpha = torch.Tensor(3):normal(0, alphastd)
#         local rgb = eigvec:clone()
#             :cmul(alpha:view(1, 3):expand(3, 3))
#             :cmul(eigval:view(1, 3):expand(3, 3))
#             :sum(2)
#             :squeeze()
#         input = input:clone()
#         for i=1,3 do
#             input[i]:add(rgb[i])
#         end
#         return input
#     end
# end
#
# class Lighting(object):
#     """Lighting noise (AlexNet-style PCA-based noise)"""
#     def __init__(self, alphastd, eigval, eigvec):
#         self.alphastd = alphastd
#         self.eigval = eigval
#         self.eigvec = eigvec
#
#     def __call__(self, img):
#         if self.alphastd == 0:
#             return img
#         alpha = img.new().resize_(3).normal_(0, self.alphastd)
#         rgb = self.eigvec.type_as(img).clone() \
#             .mul(alpha.view(1, 3).expand(3, 3)) \
#             .mul(self.eigval.view(1, 3).expand(3, 3)) \
#             .sum(1).squeeze()
#         return img.add(rgb.view(3, 1, 1).expand_as(img))
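
# The commented-out transform above is kept for reference. Below is a runnable
# PyTorch adaptation of that same PCA lighting noise, added here only as a
# sketch: it is not wired into the training pipeline (the corresponding line in
# the train transforms above remains commented out), and the intended usage,
# e.g. Lighting(0.1, pca['eigval'], pca['eigvec']), is an assumption.
class Lighting(object):
    """AlexNet-style PCA-based lighting noise applied to a CHW image tensor."""

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval    # shape (3,): eigenvalues of the RGB covariance
        self.eigvec = eigvec    # shape (3, 3): corresponding eigenvectors

    def __call__(self, img):
        if self.alphastd == 0:
            return img
        # sample one alpha per principal component
        alpha = img.new_empty(3).normal_(0, self.alphastd)
        # per-channel offset: rgb[c] = sum_j eigvec[c, j] * alpha[j] * eigval[j]
        rgb = (self.eigvec.type_as(img)
               * alpha.view(1, 3)
               * self.eigval.type_as(img).view(1, 3)).sum(1)
        return img + rgb.view(3, 1, 1).expand_as(img)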
def train(train_loader, model, criterion, optimizer, epoch, log):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print_log('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch, i, len(train_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses, top1=top1, top5=top5), log)
    return top1.avg, top5.avg, losses.avg
def validate(val_loader, model, criterion, log):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print_log('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                              i, len(val_loader), batch_time=batch_time, loss=losses,
                              top1=top1, top5=top5), log)

    print_log(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss@ {error:.3f}'.format(top1=top1, top5=top5, error=losses.avg), log)
    return top1.avg, top5.avg, losses.avg
def save_checkpoint(state, is_best, filename, bestname):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, bestname)
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def print_log(print_string, log):
    print("{}".format(print_string))
    log.write('{}\n'.format(print_string))
    log.flush()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
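
# A minimal self-check for accuracy(), added as a sketch; it is not part of the
# original gist and is never called. With logits that rank the true class first
# for two of the three samples, top-1 should be ~66.67 and top-3 should be 100.
def _accuracy_selfcheck():
    logits = torch.tensor([[0.9, 0.05, 0.05],   # predicts class 0 (correct)
                           [0.1, 0.8, 0.1],     # predicts class 1 (correct)
                           [0.7, 0.2, 0.1]])    # predicts class 0 (wrong, target is 2)
    targets = torch.tensor([0, 1, 2])
    top1, top3 = accuracy(logits, targets, topk=(1, 3))
    print('top1: {:.2f}  top3: {:.2f}'.format(top1.item(), top3.item()))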
# helper for the Keras-style model summary below (torch and nn are imported above)
from collections import OrderedDict
def summary(model, input_size):
    """Prints a Keras-style per-layer summary (layer name, output shape, parameter count)."""
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)

            m_key = '%s-%i' % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]['input_shape'] = list(input[0].size())
            summary[m_key]['input_shape'][0] = -1
            if isinstance(output, (list, tuple)):
                summary[m_key]['output_shape'] = [[-1] + list(o.size())[1:] for o in output]
            else:
                summary[m_key]['output_shape'] = list(output.size())
                summary[m_key]['output_shape'][0] = -1

            params = 0
            if hasattr(module, 'weight') and hasattr(module.weight, 'size'):
                params += torch.prod(torch.LongTensor(list(module.weight.size()))).item()
                summary[m_key]['trainable'] = module.weight.requires_grad
            if hasattr(module, 'bias') and hasattr(module.bias, 'size'):
                params += torch.prod(torch.LongTensor(list(module.bias.size()))).item()
            summary[m_key]['nb_params'] = params

        if (not isinstance(module, nn.Sequential) and
                not isinstance(module, nn.ModuleList) and
                not (module == model)):
            hooks.append(module.register_forward_hook(hook))

    if torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # check if there are multiple inputs to the network
    if isinstance(input_size[0], (list, tuple)):
        x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
    else:
        x = torch.rand(2, *input_size).type(dtype)

    # create properties
    summary = OrderedDict()
    hooks = []
    # register the hooks
    model.apply(register_hook)
    # make a forward pass
    model(x)
    # remove the hooks
    for h in hooks:
        h.remove()

    print('----------------------------------------------------------------')
    line_new = '{:>20} {:>25} {:>15}'.format('Layer (type)', 'Output Shape', 'Param #')
    print(line_new)
    print('================================================================')
    total_params = 0
    trainable_params = 0
    for layer in summary:
        # layer name, output_shape, nb_params
        line_new = '{:>20} {:>25} {:>15}'.format(layer, str(summary[layer]['output_shape']), '{0:,}'.format(summary[layer]['nb_params']))
        total_params += summary[layer]['nb_params']
        if summary[layer].get('trainable', False):
            trainable_params += summary[layer]['nb_params']
        print(line_new)
    print('================================================================')
    print('Total params: {0:,}'.format(total_params))
    print('Trainable params: {0:,}'.format(trainable_params))
    print('Non-trainable params: {0:,}'.format(total_params - trainable_params))
    print('----------------------------------------------------------------')
    # return summary
if __name__ == '__main__':
    main()