pytorch tag class for metrics

Utilities for tracking training/evaluation metrics in PyTorch: an AverageMeter for running averages, a top-k accuracy helper, and a Tag wrapper that times iterations over a DataLoader and logs loss and accuracy.
import time

import torch


class AverageMeter:
    """Tracks the latest value, sum, count, and running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(outputs, targets, topk=(1,)):
    """Computes the top-k accuracy (as a percentage) for each k in `topk`."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = targets.size(0)

        _, pred = outputs.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(targets.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape (rather than view) handles the non-contiguous slice when k > 1
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class Tag:
    """Wraps a DataLoader to time each batch and accumulate loss/accuracy stats."""

    def __init__(self, topk=(1,), verbose=True):
        self._topk = topk
        self._verbose = verbose

    def __call__(self, data_loader):
        # Reset timers and meters, then hand back an enumerated loader (1-indexed).
        self._start = time.time()
        self._batch_start = time.time()
        self._iter_total = len(data_loader)
        self._batch_size = data_loader.batch_size
        self._iter = 0
        self._elapsed = AverageMeter()
        self._losses = AverageMeter()
        self._top = {}
        for i in self._topk:
            self._top[i] = AverageMeter()
        return enumerate(data_loader, start=1)

    def log(self, outputs, targets, loss):
        # Update meters for this batch and return the current running averages.
        self._iter += 1
        self._losses.update(loss.item(), self._batch_size)
        accs = accuracy(outputs, targets, topk=self._topk)
        for k, acc in zip(self._topk, accs):
            self._top[k].update(acc.item(), self._batch_size)
        self._elapsed.update(time.time() - self._batch_start)
        self._batch_start = time.time()

        Log = {f'top{k}': v.avg for k, v in self._top.items()}
        Log['n_iter'] = self._iter
        Log['t_batch'] = self._elapsed.val
        Log['t_total'] = time.time() - self._start
        Log['loss'] = self._losses.avg
        if self._verbose:
            self.show(Log)
        return Log

    def show(self, Log):
        # Render a single progress line, overwriting it in place with '\r'.
        Log = Log.copy()
        if self._iter_total:
            Iter = '[{i}/{total}]'.format(i=Log.pop('n_iter'), total=self._iter_total)
        else:
            Iter = '[{}]'.format(Log.pop('n_iter'))
        Batch = 't_batch: {:.3f}s'.format(Log.pop('t_batch'))
        Total = 't_total: {}'.format(time.strftime("%H:%M:%S", time.gmtime(Log.pop('t_total'))))
        Loss = 'loss: {:.3f}'.format(Log.pop('loss'))
        Topk = [f'{k}: {v:.3f}' for k, v in Log.items()]
        print(f'{Iter} ' + ' | '.join(Topk + [Loss, Batch, Total]), end='\r')
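
A minimal usage sketch. The names `model`, `criterion`, and `val_loader` are hypothetical placeholders for a standard PyTorch model, loss function, and DataLoader; they are not part of the gist itself.

    # Hypothetical evaluation loop; model, criterion, and val_loader are placeholders.
    tag = Tag(topk=(1, 5), verbose=True)

    model.eval()
    with torch.no_grad():
        for i, (inputs, targets) in tag(val_loader):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            log = tag.log(outputs, targets, loss)  # dict of running averages

    # Each call to tag.log() prints (and overwrites) a progress line of the form:
    # [i/total] top1: ... | top5: ... | loss: ... | t_batch: ...s | t_total: HH:MM:SS

Because `Tag.__call__` simply returns `enumerate(data_loader, start=1)` after resetting its meters, the wrapper drops into an existing loop without changing its shape, and `log()` both updates the meters and returns the current averages as a dict for further use.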