Skip to content

Instantly share code, notes, and snippets.

@jeakwon
Created October 27, 2019 16:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jeakwon/9d0063198a2d34943a950b1391381e3b to your computer and use it in GitHub Desktop.
Save jeakwon/9d0063198a2d34943a950b1391381e3b to your computer and use it in GitHub Desktop.
pytorch tag class for metrics
import time
import torch
class AverageMeter:
    """Keep the latest value together with a count-weighted running mean.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: running total of ``val * n`` across all updates.
        count: running total of the ``n`` weights.
        avg: ``sum / count`` as of the most recent update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Put every statistic back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh ``avg``."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
def accuracy(outputs, targets, topk=(1,)):
    """Compute top-k classification accuracy, in percent, for each k in ``topk``.

    Args:
        outputs: 2-D score tensor indexed as ``(batch, num_classes)``; the
            top-k search runs over dim 1.
        targets: ground-truth class indices, one per sample (flattened before
            comparison).
        topk: iterable of k values to evaluate; any iterable (including a
            generator) is accepted.

    Returns:
        A list with one 1-element float tensor per k, in ``topk`` order,
        holding the percentage of samples whose target is among the k
        highest-scoring classes.
    """
    # Materialise once: ``max`` below would otherwise exhaust a generator
    # before the per-k loop runs.
    topk = tuple(topk)
    with torch.no_grad():
        maxk = max(topk)
        batch_size = targets.size(0)
        _, pred = outputs.topk(maxk, 1, True, True)
        # After the transpose, row i holds every sample's (i+1)-th best class.
        pred = pred.t()
        correct = pred.eq(targets.reshape(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # reshape, not view: ``correct`` inherits the transposed,
            # non-contiguous layout of ``pred``, and ``.view(-1)`` raises on it
            # whenever k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
class Tag:
    """Wrap a data loader and accumulate running loss / top-k accuracy metrics.

    Intended flow: ``for i, batch in tag(loader): ...`` followed by one
    ``tag.log(outputs, targets, loss)`` call per batch.

    Args:
        topk: k values for which running top-k accuracy is tracked.
        verbose: when True, ``log`` prints a one-line progress summary.
    """

    def __init__(self, topk=(1,), verbose=True):
        # Stored as a tuple so it can be iterated repeatedly (per call/batch).
        self._topk = tuple(topk)
        self._verbose = verbose

    def __call__(self, data_loader):
        """Reset all meters and timers, then return ``enumerate(data_loader, start=1)``."""
        self._start = time.time()
        self._batch_start = self._start
        try:
            self._iter_total = len(data_loader)
        except TypeError:
            # Loaders without a length (e.g. over an iterable-style dataset);
            # ``show`` then prints the iteration count on its own.
            self._iter_total = None
        # Nominal batch size, kept only for backward compatibility; ``log``
        # weights each batch by its actual number of targets instead.
        self._batch_size = getattr(data_loader, 'batch_size', None)
        self._iter = 0
        self._elapsed = AverageMeter()
        self._losses = AverageMeter()
        self._top = {k: AverageMeter() for k in self._topk}
        return enumerate(data_loader, start=1)

    def log(self, outputs, targets, loss):
        """Fold one batch into the running meters and return a summary dict.

        Args:
            outputs: model scores for the batch, passed to ``accuracy``.
            targets: ground-truth class indices; ``targets.size(0)`` is used
                as this batch's weight in every running average.
            loss: scalar loss for the batch (1-element tensor or plain number).

        Returns:
            dict with ``top{k}`` running accuracies, ``n_iter``, ``t_batch``
            (seconds since the previous ``log``/``__call__``), ``t_total``
            (seconds since ``__call__``) and ``loss`` (running mean loss).
        """
        self._iter += 1
        # Weight by the real batch length: the final batch can be smaller than
        # the loader's nominal batch_size, and a torch DataLoader built with a
        # batch_sampler reports batch_size=None.
        n = targets.size(0)
        self._losses.update(float(loss), n)
        accs = accuracy(outputs, targets, topk=self._topk)
        for k, acc in zip(self._topk, accs):
            self._top[k].update(acc.item(), n)
        now = time.time()
        self._elapsed.update(now - self._batch_start)
        self._batch_start = now
        record = {f'top{k}': meter.avg for k, meter in self._top.items()}
        record['n_iter'] = self._iter
        record['t_batch'] = self._elapsed.val
        record['t_total'] = now - self._start
        record['loss'] = self._losses.avg
        if self._verbose:
            self.show(record)
        return record

    def show(self, Log):
        """Print a summary dict (as returned by ``log``) on one ``\\r``-terminated line."""
        record = dict(Log)  # copy: the fixed keys are popped below
        n_iter = record.pop('n_iter')
        if self._iter_total:
            progress = '[{i}/{total}]'.format(i=n_iter, total=self._iter_total)
        else:
            progress = '[{}]'.format(n_iter)
        batch = 't_batch: {:.3f}s'.format(record.pop('t_batch'))
        # HH:MM:SS without gmtime, so runs longer than 24 h do not wrap to 00.
        hours, rem = divmod(int(record.pop('t_total')), 3600)
        minutes, seconds = divmod(rem, 60)
        total = 't_total: {:02d}:{:02d}:{:02d}'.format(hours, minutes, seconds)
        loss = 'loss: {:.3f}'.format(record.pop('loss'))
        # Whatever remains are the ``top{k}`` entries.
        topk = [f'{k}: {v:.3f}' for k, v in record.items()]
        # flush: a line ending in '\r' is otherwise held in a line-buffered
        # stdout and the progress line may never appear.
        print(f'{progress} ' + ' | '.join(topk + [loss, batch, total]),
              end='\r', flush=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment