Skip to content

Instantly share code, notes, and snippets.

@ushahid
Last active January 30, 2021 22:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ushahid/48efbdc04bc6e6726f8cec6df92bc904 to your computer and use it in GitHub Desktop.
Save ushahid/48efbdc04bc6e6726f8cec6df92bc904 to your computer and use it in GitHub Desktop.
import torch
from abc import abstractmethod
from callback.cb_base import CallbackBase
from itertools import chain
class ProcessResultsCallback(CallbackBase):
    """Accumulate per-batch ``id``/``pred``/``gt`` step outputs and process them in bulk.

    For every mode listed in ``modes`` ('train', 'val', 'test') the callback
    appends the ``id``, ``pred`` and ``gt`` entries of each batch's step output
    to per-mode buffers. The buffers are cleared at the start of every epoch of
    that mode, so only the most recent epoch is retained. Buffered results are
    concatenated and passed to :meth:`process_results`:

    * in ``on_train_end`` for 'train' and 'val' (skipped entirely if a keyboard
      interrupt was received), and
    * in ``on_test_epoch_end`` for 'test'.

    A mode for which no batch was collected is skipped rather than raising.
    Subclasses must implement :meth:`process_results`.
    """

    def __init__(self, modes=('train', 'val', 'test')):
        """Create the callback.

        Args:
            modes: mode names to collect results for; any subset of
                'train', 'val', 'test'. A single mode may also be given as a
                plain string. The value is copied, so mutating the argument
                afterwards has no effect.
        """
        super().__init__()
        # list() on a bare string would split it into characters.
        self.modes = [modes] if isinstance(modes, str) else list(modes)
        # Per-mode buffers: mode name -> list with one entry per batch.
        self.ident = dict()
        self.pred = dict()
        self.gt = dict()
        # Cleared by on_keyboard_interrupt so that on_train_end does not
        # process a partially collected run.
        self.uninterrupted = True

    # NOTE(review): @abstractmethod is only enforced at instantiation time if
    # CallbackBase uses abc.ABCMeta -- not visible from here, confirm.
    @abstractmethod
    def process_results(self, mode, ident, pred, gt, trainer, pl_module):
        """Handle the concatenated results of one mode.

        Args:
            mode: 'train', 'val' or 'test'.
            ident: flat list built by chaining the per-batch ``id`` iterables.
            pred: per-batch predictions concatenated along dim 0, detached, on CPU.
            gt: per-batch ground truth concatenated along dim 0, detached, on CPU.
            trainer: the trainer passed to the hook that triggered processing.
            pl_module: the module passed to the hook that triggered processing.
        """
        raise NotImplementedError

    def _reset(self, mode):
        """Clear (or create) the buffers for ``mode``."""
        self.ident[mode] = []
        self.pred[mode] = []
        self.gt[mode] = []

    def _append(self, mode, ident, pred, gt):
        """Buffer one batch of results for ``mode``.

        ``pred`` and ``gt`` are detached before being stored. If they are still
        attached to the autograd graph, keeping them for a whole epoch would keep
        every batch's graph alive. Their values are unchanged, and _concat
        detaches the concatenated result anyway.
        """
        self.ident[mode].append(ident)
        self.pred[mode].append(pred.detach())
        self.gt[mode].append(gt.detach())

    def _has_results(self, mode):
        """Return True if at least one batch has been buffered for ``mode``."""
        return bool(self.pred.get(mode))

    def _concat(self, mode):
        """Concatenate the buffered batches of ``mode``.

        Returns:
            ``(ident, pred, gt)``. ``ident`` is a flat list of the chained
            per-batch ids. ``pred`` and ``gt`` are the per-batch tensors
            concatenated along dim 0, detached and moved to CPU.
        """
        ident = list(chain.from_iterable(self.ident[mode]))
        pred = torch.cat(self.pred[mode], dim=0).detach().cpu()
        gt = torch.cat(self.gt[mode], dim=0).detach().cpu()
        return ident, pred, gt

    def _process_mode(self, mode, trainer, pl_module):
        """Concatenate, clear and process ``mode`` if it is enabled and non-empty."""
        if mode not in self.modes or not self._has_results(mode):
            # Nothing was collected (e.g. no validation epoch ran): skip instead
            # of raising KeyError / failing in torch.cat on an empty list.
            return
        ident, pred, gt = self._concat(mode)
        # Drop the per-batch references before the (possibly slow) processing.
        self._reset(mode)
        self.process_results(mode, ident, pred, gt, trainer, pl_module)

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Remember the interrupt so on_train_end skips result processing."""
        self.uninterrupted = False

    def on_train_epoch_start(self, trainer, pl_module):
        """Start a fresh 'train' buffer for this epoch."""
        if 'train' in self.modes:
            self._reset('train')

    def on_validation_epoch_start(self, trainer, pl_module):
        """Start a fresh 'val' buffer for this validation epoch."""
        if 'val' in self.modes:
            self._reset('val')

    def on_test_epoch_start(self, trainer, pl_module):
        """Start a fresh 'test' buffer for this test epoch."""
        if 'test' in self.modes:
            self._reset('test')

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Buffer the ``extra`` dict of the training step output."""
        if 'train' in self.modes:
            # NOTE(review): the [0][0] indexing presumably unwraps the nested
            # lists this Lightning version wraps training-step outputs in
            # (e.g. per optimizer / per split) -- confirm for the version in use.
            outputs = outputs[0][0]['extra']
            self._append('train', outputs['id'], outputs['pred'], outputs['gt'])

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Buffer the ``id``/``pred``/``gt`` entries of the validation step output."""
        if 'val' in self.modes:
            self._append('val', outputs['id'], outputs['pred'], outputs['gt'])

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Buffer the ``id``/``pred``/``gt`` entries of the test step output."""
        if 'test' in self.modes:
            self._append('test', outputs['id'], outputs['pred'], outputs['gt'])

    def on_train_end(self, trainer, pl_module):
        """Process the buffered 'train' and 'val' results once fitting ends.

        Nothing is processed after a keyboard interrupt.
        """
        if not self.uninterrupted:
            return
        for mode in ('train', 'val'):
            self._process_mode(mode, trainer, pl_module)

    def on_test_epoch_end(self, trainer, pl_module):
        """Process the buffered 'test' results at the end of the test epoch."""
        # The original additionally required 'val' to be enabled here, which
        # silently dropped test results when only 'test' was requested.
        self._process_mode('test', trainer, pl_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment