Last active
July 5, 2021 05:19
-
-
Save agermanidis/275b23ad7a10ee89adccf021536bb97e to your computer and use it in GitHub Desktop.
pytorch classification model helpers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import torch | |
def validate(val_loader, model, criterion, use_cuda=False, print_freq=10):
    """
    Evaluate a classification model on the entire validation set.

    Args:
        val_loader: a DataLoader instance for the validation set
        model: the model to evaluate
        criterion: the loss criterion
        use_cuda: run the model on the GPU
        print_freq: log stats every N batches

    Returns:
        The model top-1 precision (average, in percent) on the entire
        validation set.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode (affects dropout / batchnorm layers)
    model.eval()

    end = time.time()
    # torch.no_grad() replaces the long-removed `Variable(..., volatile=True)`
    # pattern: no autograd graph is built during evaluation.
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            if use_cuda:
                # `.cuda(async=True)` is a SyntaxError on Python 3.7+
                # (`async` became a keyword); `non_blocking=True` is the
                # supported equivalent.
                input = input.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss. `.item()` works for both
            # 0-dim and 1-element tensors, unlike the removed
            # `loss.data[0]` / `prec1[0]` indexing.
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f}% ({top1.avg:.3f}%)\t'
                      'Prec@5 {top5.val:.3f}% ({top5.avg:.3f}%)'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg
def train(train_loader, model, criterion, optimizer, epoch, use_cuda=False, print_freq=10):
    """
    Train a classification model for one epoch over the entire training set.

    Args:
        train_loader: a DataLoader instance for the training set
        model: the model to train
        criterion: the loss criterion
        optimizer: the training optimizer
        epoch: the current epoch number (used only for logging)
        use_cuda: run the model on the GPU
        print_freq: log stats every N batches
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode (enables dropout / batchnorm updates)
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        if use_cuda:
            # `.cuda(async=True)` is a SyntaxError on Python 3.7+ (`async`
            # became a keyword); `non_blocking=True` is the supported form.
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # compute output. Tensors carry autograd state directly in modern
        # PyTorch, so the removed Variable wrapper is no longer needed.
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss. `.item()` works for both 0-dim
        # and 1-element tensors, unlike the removed `loss.data[0]` indexing.
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f}% ({top1.avg:.3f}%)\t'
                  'Prec@5 {top5.val:.3f}% ({top5.avg:.3f}%)'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      loss=losses, top1=top1, top5=top5))
def accuracy(output, target, topk=(1,)):
    """
    Given predicted scores and ground-truth labels, calculate top-k
    accuracies.

    Args:
        output: predicted class scores, shape (batch, num_classes)
        target: ground-truth class indices, shape (batch,)
        topk: a tuple with the top-n accuracies to calculate

    Returns:
        A list of scalar tensors, one top-k accuracy (in percent)
        per entry of `topk`.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk highest-scoring classes per sample, then
    # transpose to shape (maxk, batch) so row j holds the j-th guesses
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape(-1) instead of view(-1): `correct[:k]` is a slice of a
        # transposed (non-contiguous) tensor, and view() raises on
        # non-contiguous input in modern PyTorch.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
class AverageMeter(object):
    """Track the most recent value and a running average of a series.

    Exposes four attributes: `val` (last value seen), `sum` (weighted
    total), `count` (number of observations), and `avg` (sum / count).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.count = self.count + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.count
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment