Skip to content

Instantly share code, notes, and snippets.

@agermanidis
Last active July 5, 2021 05:19
Show Gist options
  • Save agermanidis/275b23ad7a10ee89adccf021536bb97e to your computer and use it in GitHub Desktop.
Save agermanidis/275b23ad7a10ee89adccf021536bb97e to your computer and use it in GitHub Desktop.
pytorch classification model helpers
import time
import torch
def validate(val_loader, model, criterion, use_cuda=False, print_freq=10):
    """
    Evaluate a classification model on the entire validation set.

    Args:
        val_loader: a DataLoader instance for the validation set
        model: the model to evaluate
        criterion: the loss criterion
        use_cuda: move each batch to the GPU before running the model
        print_freq: log stats every N batches

    Returns:
        The model top-1 precision (a Python float, in percent) on the entire
        validation set.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # torch.no_grad() replaces the removed `volatile=True` flag: no autograd
    # graph is recorded during evaluation, which saves memory and time.
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):
            if use_cuda:
                # `async` is a reserved word since Python 3.7; PyTorch renamed
                # this argument to `non_blocking`.
                inputs = inputs.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output = model(inputs)
            loss = criterion(output, target)

            # measure accuracy and record loss
            batch_size = inputs.size(0)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            # .item() converts single-element tensors to Python floats;
            # indexing a 0-dim tensor (e.g. loss.data[0]) raises on modern PyTorch.
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f}% ({top1.avg:.3f}%)\t'
                      'Prec@5 {top5.val:.3f}% ({top5.avg:.3f}%)'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg
def train(train_loader, model, criterion, optimizer, epoch, use_cuda=False, print_freq=10):
    """
    Train a classification model for one pass over the training set.

    Args:
        train_loader: a DataLoader instance for the training set
        model: the model to train
        criterion: the loss criterion
        optimizer: the training optimizer
        epoch: the current epoch number (used only in the log output)
        use_cuda: move each batch to the GPU before running the model
        print_freq: log stats every N batches
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):
        if use_cuda:
            # `async` is a reserved word since Python 3.7; PyTorch renamed
            # this argument to `non_blocking`.
            inputs = inputs.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # compute output (plain tensors track gradients directly; the old
        # torch.autograd.Variable wrapper is no longer needed)
        output = model(inputs)
        loss = criterion(output, target)

        # measure accuracy and record loss
        batch_size = inputs.size(0)
        # detach() keeps the metric computation out of the autograd graph
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        # .item() converts single-element tensors to Python floats;
        # indexing a 0-dim tensor (e.g. loss.data[0]) raises on modern PyTorch.
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # compute gradient and take an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f}% ({top1.avg:.3f}%)\t'
                  'Prec@5 {top5.val:.3f}% ({top5.avg:.3f}%)'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      loss=losses, top1=top1, top5=top5))
def accuracy(output, target, topk=(1,)):
    """
    Given predicted scores and ground-truth labels, calculate top-k accuracies.

    Args:
        output: predicted scores indexed as (batch, class); the top-k is taken
            along dim 1
        target: ground-truth class indices, one per row of ``output``
        topk: a tuple with the top-n accuracies to calculate

    Returns:
        A list with one single-element tensor per entry of ``topk``, holding
        the percentage (0-100) of samples whose target is among the top-k
        predictions.
    """
    # Clamp so that asking for e.g. top-5 on a 3-class problem does not make
    # ``topk`` raise; for k >= num_classes every sample trivially counts.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row j holds each sample's j-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` derives from a transpose and can be non-contiguous, so
        # use reshape instead of view (view raises on recent PyTorch).
        # keepdim=True keeps a 1-element tensor, so callers may use either
        # ``result[0]`` or ``result.item()``.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # sum / count (0 until some weight has been recorded)
        self.sum = 0    # weighted sum of all recorded values
        self.count = 0  # total weight recorded so far

    def update(self, val, n=1):
        """
        Record ``val`` with weight ``n``.

        Args:
            val: the new value (e.g. a per-batch mean)
            n: the weight of ``val`` (e.g. the number of samples it averages)
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against zero-weight updates (e.g. an empty batch) on a fresh
        # meter, which would otherwise divide by zero.
        if self.count:
            self.avg = self.sum / self.count
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment