Skip to content

Instantly share code, notes, and snippets.

@jeakwon
Last active October 28, 2019 04:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jeakwon/1f3eb0afdd2114a0ecf307957d5456e9 to your computer and use it in GitHub Desktop.
Save jeakwon/1f3eb0afdd2114a0ecf307957d5456e9 to your computer and use it in GitHub Desktop.
Easy metrics for PyTorch
import time
import torch
class Tag:
    """Track running loss, top-k accuracy and timing over one pass of a loader.

    Typical use::

        tag = Tag(topk=(1, 5))
        for inputs, targets in tag(train_loader):
            ...
            tag.log(outputs, targets, loss)

    Public attributes, refreshed by every :meth:`log` call:

    * ``iter`` -- number of batches logged since the last ``tag(loader)``.
    * ``time`` -- seconds elapsed since the last ``tag(loader)``.
    * ``tick`` -- seconds elapsed since the previous :meth:`log` call.
    * ``loss`` -- mean of the logged per-batch losses.
    * ``top{k}`` -- running top-k accuracy in percent, one attribute per k.
    """

    def __init__(self, topk=(1,), verbose=True):
        """Configure the tracker.

        Args:
            topk: Iterable of k values for which top-k accuracy is tracked.
            verbose: When true, every :meth:`log` call prints a one-line
                progress message that overwrites the previous one.
        """
        self._topk = tuple(topk)
        self._verbose = verbose
        # Unknown until the tracker is attached to a loader via __call__.
        self._total_iter = None
        self._reset()

    def __call__(self, data_loader):
        """Reset all running statistics and return ``data_loader`` unchanged.

        Wrap the loader in the ``for`` statement so that every epoch starts
        from clean counters.
        """
        try:
            self._total_iter = len(data_loader)
        except TypeError:
            # Loaders without a defined length: show "?" instead of crashing.
            self._total_iter = None
        self._reset()
        return data_loader

    def log(self, outputs, targets, loss):
        """Fold one batch into the running statistics.

        Args:
            outputs: Score tensor; ``topk`` is taken along dimension 1, so
                dimension 0 is treated as the batch dimension.
            targets: Tensor of target class indices, one per row of
                ``outputs``.
            loss: Loss of this batch; a tensor exposing ``.item()`` or a
                plain number.

        Returns:
            dict with one ``'top{k}'`` entry per tracked k, followed by
            ``'iter'``, ``'time'``, ``'tick'`` and ``'loss'``.
        """
        self.iter += 1

        # Read the clock once so time and tick are consistent.
        now = time.time()
        self.time = now - self._time_start
        self.tick = now - self._tick_start
        self._tick_start = now

        # Only the Python float is kept: the tensors (and their autograd
        # graph) are not stored on the instance between calls.
        self._loss_sum += loss.item() if hasattr(loss, 'item') else float(loss)
        self.loss = self._loss_sum / self.iter

        metrics = {}
        for k, accuracy in self._update_topk(outputs, targets).items():
            setattr(self, f'top{k}', accuracy)
            metrics[f'top{k}'] = accuracy
        metrics['iter'] = self.iter
        metrics['time'] = self.time
        metrics['tick'] = self.tick
        metrics['loss'] = self.loss

        if self._verbose:
            self._display()
        return metrics

    @property
    def msg(self):
        """One-line progress summary of the current running statistics."""
        total = '?' if self._total_iter is None else self._total_iter
        Iter = f'{self.iter}/{total}'
        Time = f'{time.strftime("%H:%M:%S", time.gmtime(self.time))}'
        Tick = f'{self.tick:.3f}s/it'
        Loss = f'Loss: {self.loss:.3f}'
        Topk = ' | '.join(
            'top{}: {:.3f}'.format(k, getattr(self, f'top{k}'))
            for k in self._topk
        )
        return f'[{Iter} {Tick} {Time}] {Loss} | {Topk}'

    def _reset(self):
        """Zero every counter and restart both clocks."""
        now = time.time()
        self._time_start = now
        self._tick_start = now
        self._loss_sum = 0.0
        self._seen = 0  # samples logged so far
        self._correct = {k: 0.0 for k in self._topk}  # top-k hits so far
        self.iter = 0
        self.time = 0
        self.tick = 0
        self.loss = 0
        for k in self._topk:
            setattr(self, f'top{k}', 0)

    def _display(self):
        """Print the progress line, returning the cursor to column 0."""
        print(self.msg, end='\r')

    def _update_topk(self, outputs, targets):
        """Accumulate top-k hits for this batch; return running accuracies.

        Accuracy is the percentage of correct predictions over all samples
        seen since the last reset. The actual number of rows in ``outputs``
        is used rather than the loader's nominal batch size, so a short
        final batch is weighted correctly.
        """
        if not self._topk:
            return {}
        with torch.no_grad():
            # One topk call for the largest k serves every smaller k.
            _, pred = outputs.topk(max(self._topk), 1, True, True)
            pred = pred.t()
            hits = pred.eq(targets.view(1, -1).expand_as(pred))
            for k in self._correct:
                # reshape, not view: the transposed tensor is not
                # guaranteed to be contiguous.
                self._correct[k] += hits[:k].reshape(-1).float().sum().item()
            self._seen += outputs.size(0)
        if not self._seen:
            return {k: 0.0 for k in self._correct}
        return {k: 100.0 * n / self._seen for k, n in self._correct.items()}
class Log:
    """Epoch-level helper that stores logging options and yields epoch numbers.

    Usage::

        log = Log(log_dir='runs', ckpt_dir='ckpt')
        for epoch in log(10):   # 1, 2, ..., 10
            ...
    """

    def __init__(self, log_dir=None, display=True, ckpt_dir=None, ):
        """Store the configuration.

        Args:
            log_dir: Value kept on ``self.log_dir``; not used elsewhere in
                this class.
            display: Value kept on ``self.display`` (previously accepted but
                silently discarded).
            ckpt_dir: Value kept on ``self.ckpt_dir``; not used elsewhere in
                this class.
        """
        self.log_dir = log_dir
        self.display = display
        self.ckpt_dir = ckpt_dir

    def __call__(self, epochs):
        """Return the 1-based epoch range ``1..epochs`` (inclusive)."""
        return range(1, epochs + 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment