Skip to content

Instantly share code, notes, and snippets.

@avijit9
Created July 7, 2017 06:52
Show Gist options
  • Save avijit9/1c7eebf124a02a555f7626a0fbcd04a5 to your computer and use it in GitHub Desktop.
Save avijit9/1c7eebf124a02a555f7626a0fbcd04a5 to your computer and use it in GitHub Desktop.
PyTorch example on finetuning
from __future__ import print_function
import argparse
import os
import time
import pdb
# Pin the GPU before torch initializes CUDA: enumerate devices in PCI bus
# order and expose only physical GPU 3 to this process. Must be set prior
# to `import torch` to take effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import torchvision.models as models
# Project-local helpers: running averages, checkpointing, LR schedule, top-k accuracy.
from utils import AverageMeter, save_checkpoint, adjust_learning_rate, accuracy
print("Successfully imported all..")
# Every lowercase, callable constructor exposed by torchvision.models
# (resnet18, vgg16, ...) is a valid --arch choice.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
print(model_names)

parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--exp', metavar='EXP', help='experiment name')
parser.add_argument('--data', metavar='DIR', default='/tmp/avijit_dataset/',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('--workers', type=int, default=9,
                    help='number of data loading workers')
parser.add_argument('--batchSize', type=int, default=128, help='input batch size')
parser.add_argument('--imageSize', type=int, default=224, help='image Size')
parser.add_argument('--niter', type=int, default=99,
                    help='no of epochs to train the model')
# Bug fix: the help text claimed default=0.0001 while the actual default is 0.01.
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--outf', default='.',
                    help='folder to output images and model checkpoints')
# NOTE(review): default=True combined with action='store_true' means this
# flag is always True and cannot be switched off from the command line;
# kept as-is for backward compatibility.
parser.add_argument('--pretrained', default=True, dest='pretrained',
                    action='store_true', help='use pre-trained model')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N',
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
opt = parser.parse_args()
print(opt)

# Create the checkpoint/output directory; ignore the error if it already exists.
try:
    os.makedirs(opt.outf)
except OSError:
    pass

# Bug fix: the seed was drawn but never applied, so runs were not
# reproducible despite the original "fix seed" comment.
opt.manualSeed = random.randint(1, 10000)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

best_prec1 = 0
def main():
    """Build the model, wire up the finetuning optimizer, and run the
    train/validate loop for opt.niter epochs, checkpointing the best model.

    Reads the module-level `opt` namespace; updates the module-level
    `best_prec1` as validation improves.
    """
    global best_prec1

    if opt.pretrained:
        print("=> using pre-trained model '{}'".format(opt.arch))
        model = models.__dict__[opt.arch](pretrained=True)
        # Replace the classifier head with an 11-class layer. Sizing it from
        # fc.in_features (instead of the original hard-coded 512) works for
        # every resnet variant, including resnet50+ whose head takes 2048.
        model.fc = nn.Linear(model.fc.in_features, 11)
    else:
        print("=> creating model '{}'".format(opt.arch))
        model = models.__dict__[opt.arch]()

    model = model.cuda()

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> Loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            # Bug fixes: the original assigned opt.start_epoc (typo, so the
            # resumed epoch was silently ignored) and printed opt.evaluate,
            # an attribute that is never defined (AttributeError).
            opt.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                opt.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Fixed input size every batch, so let cudnn benchmark conv algorithms.
    cudnn.benchmark = True

    traindir = os.path.join(opt.data, 'train')
    valdir = os.path.join(opt.data, 'val')
    # ImageNet channel statistics, matching torchvision's pretrained weights.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        dset.ImageFolder(traindir, transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=opt.batchSize, shuffle=True,
        num_workers=opt.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        dset.ImageFolder(valdir, transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=opt.batchSize, shuffle=False,
        num_workers=opt.workers, pin_memory=True)

    criterion = nn.CrossEntropyLoss().cuda()

    # Discriminative finetuning LRs: pretrained backbone at lr/10, the fresh
    # classifier head at the full lr. The original enumerated every resnet
    # submodule by name, including parameterless ones (relu/maxpool/avgpool);
    # splitting on fc membership is equivalent and architecture-agnostic.
    # Bug fix: momentum and weight decay now honour the CLI flags, which the
    # original accepted but silently ignored.
    head_ids = {id(p) for p in model.fc.parameters()}
    backbone_params = [p for p in model.parameters() if id(p) not in head_ids]
    optimizer = torch.optim.SGD(
        [{'params': backbone_params},
         {'params': model.fc.parameters(), 'lr': opt.lr}],
        lr=opt.lr * 0.1, momentum=opt.momentum,
        weight_decay=opt.weight_decay)

    for epoch in range(opt.start_epoch, opt.niter):
        train(train_loader, model, criterion, optimizer, epoch)
        prec1 = validation(val_loader, model, criterion)

        # Track and checkpoint the best top-1 precision seen so far.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({'epoch': epoch + 1,
                         'opt': opt.arch,
                         'state_dict': model.state_dict(),
                         'best_prec1': best_prec1},
                        is_best, opt.exp)
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over train_loader.

    Prints a progress line every opt.print_freq batches; updates the model
    in place via one SGD step per batch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        input = input.cuda()
        # Bug fix: the original used target.cuda(async=True), which is a
        # SyntaxError on Python 3.7+ (`async` became a keyword);
        # non_blocking=True is the replacement and, with pin_memory loaders,
        # overlaps the host-to-device copy with compute.
        target = target.cuda(non_blocking=True)

        # Forward pass.
        output = model(input)
        loss = criterion(output, target)

        # Compute gradients and take an SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Measure accuracy and record loss.
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        # Bug fix: loss.data[0] was removed in PyTorch 0.4; .item() is the
        # supported way to read a scalar loss.
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
def validation(test_loader, model, criterion):
    """Evaluate the model on test_loader and return the average top-1 precision."""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()

    end = time.time()
    # Bug fix: the original built full autograd graphs during evaluation,
    # wasting GPU memory; no gradients are needed here.
    with torch.no_grad():
        for i, (input, target) in enumerate(test_loader):
            input = input.cuda()
            # Bug fix: async=True is a SyntaxError on Python 3.7+; use
            # non_blocking=True instead.
            target = target.cuda(non_blocking=True)

            # Forward pass only.
            output = model(input)
            loss = criterion(output, target)

            # Measure accuracy and record loss.
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            # Bug fix: loss.data[0] was removed in PyTorch 0.4.
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(test_loader), batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))
    return top1.avg
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment