Created
June 22, 2020 17:58
-
-
Save Kshitij09/0b2e77890e905e4a5aebb51b211e3310 to your computer and use it in GitHub Desktop.
Utility functions for PyTorch Lightning (heavily borrowed from fastai)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from IPython.display import display, clear_output
from pytorch_lightning import Callback
from torch import nn
# Standard ImageNet normalisation constants as a (mean, std) pair, each listed
# per R, G, B channel -- the values torchvision's pretrained models expect.
imagenet_stats = (
    [0.485, 0.456, 0.406],  # mean
    [0.229, 0.224, 0.225],  # std
)
def subplots(nrows=1, ncols=1, figsize=None, imsize=3, add_vert=0, **kwargs):
    """Create a figure with an ``nrows`` x ``ncols`` grid of axes.

    Args:
        nrows, ncols: grid dimensions.
        figsize: explicit figure size; when None it is derived as
            ``(ncols * imsize, nrows * imsize + add_vert)``.
        imsize: size of one grid cell, used only when ``figsize`` is None.
        add_vert: extra height added to the derived figure size.
        **kwargs: forwarded to ``matplotlib.pyplot.subplots``.

    Returns:
        ``(fig, ax)`` where ``ax`` is always a numpy array of Axes, even for a
        1x1 grid, so callers can uniformly call ``ax.flatten()``.
    """
    if figsize is None:
        figsize = (ncols * imsize, nrows * imsize + add_vert)
    fig, ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)
    # For a 1x1 grid plt.subplots returns a bare Axes (unless squeeze=False was
    # passed); wrap it so the return type is consistently an ndarray.
    if nrows * ncols == 1 and not isinstance(ax, np.ndarray):
        ax = np.array([ax])
    return fig, ax
def get_grid(n, nrows=None, ncols=None, add_vert=0, figsize=None, double=False, title=None, return_fig=False, **kwargs):
    """Return a grid of ``n`` matplotlib axes laid out ``nrows`` by ``ncols``.

    Missing dimensions are derived from ``n``:
    - with neither given, the grid is roughly square (``int(sqrt(n))`` rows);
    - with only one given, the other is chosen so the grid holds at least ``n`` cells.
    Cells beyond the first ``n`` are hidden.

    Args:
        n: number of axes wanted; must be positive.
        nrows, ncols: optional grid dimensions.
        add_vert: extra figure height forwarded to ``subplots``. It is only
            used when ``figsize`` is None.
        figsize: explicit figure size, forwarded to ``subplots``.
        double: double both the column count and the number of returned axes.
        title: optional bold figure title.
        return_fig: also return the figure.
        **kwargs: forwarded to ``subplots``.

    Returns:
        A list of the first ``n`` Axes, or ``(fig, axes)`` when ``return_fig``.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n < 1:
        raise ValueError(f"n must be a positive number of axes, got {n}")
    if nrows:
        ncols = ncols or int(math.ceil(n / nrows))
    elif ncols:
        # Derive rows from the requested columns. Using int(sqrt(n)) here could
        # leave fewer than n cells (e.g. n=10, ncols=2 -> 3x2 = 6 axes).
        nrows = int(math.ceil(n / ncols))
    else:
        nrows = int(math.sqrt(n))
        ncols = int(math.ceil(n / nrows))
    if double:
        ncols *= 2
        n *= 2
    fig, axs = subplots(nrows, ncols, figsize=figsize, add_vert=add_vert, **kwargs)
    axs = list(axs.flatten())
    # Hide the unused trailing cells so they don't render as empty frames.
    for ax in axs[n:]:
        ax.set_axis_off()
    axs = axs[:n]
    if title is not None:
        fig.suptitle(title, weight="bold", size=14)
    return (fig, axs) if return_fig else axs
def min_max_scale(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Works on numpy arrays and torch tensors. A constant input (max == min)
    maps to all zeros instead of producing NaN from a 0/0 division.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        # Constant input: (x - lo) is already all zeros; dividing by 1 keeps
        # the floating-point result type of the general case.
        span = 1
    return (x - lo) / span
def show_one(x, y, ax, classes=None):
    """Draw one image tensor on ``ax``, titled with its class label.

    Args:
        x: image tensor, assumed channel-first (C, H, W); it is permuted to
            (H, W, C) for matplotlib and min-max scaled to [0, 1] for display.
            Single-channel images are drawn in grayscale.
        y: tensor holding a single class index (read with ``y.item()``).
        ax: matplotlib Axes to draw on.
        classes: optional sequence mapping class index -> display name.
            When omitted, a module-level ``train_ds.classes`` is used if such a
            global exists (the original behaviour). Otherwise the raw index is
            shown.

    Returns:
        ``ax``.
    """
    img = x.detach().cpu().permute(1, 2, 0).numpy()
    cmap = None
    if img.shape[-1] == 1:
        # imshow rejects (H, W, 1) arrays; drop the channel axis and use gray.
        img = img[..., 0]
        cmap = "gray"
    ax.imshow(min_max_scale(img), cmap=cmap)
    idx = y.item()
    if classes is None:
        # Backward compatibility: previously this always read a global `train_ds`.
        classes = getattr(globals().get("train_ds"), "classes", None)
    ax.set_title(str(idx) if classes is None else classes[idx])
    ax.axis("off")
    return ax
def show_batch(xb, yb, max_n=10, nrows=None, ncols=None, figsize=None):
    """Plot up to ``max_n`` (image, label) pairs from a batch on a grid.

    Args:
        xb: batch of image tensors, each drawn with ``show_one``.
        yb: matching batch of label tensors.
        max_n: maximum number of samples to draw.
        nrows, ncols, figsize: grid layout options forwarded to ``get_grid``.

    Returns nothing; the figure is drawn as a side effect.
    """
    count = min(len(xb), max_n)
    grid = get_grid(count, nrows=nrows, ncols=ncols, figsize=figsize)
    # `grid` already holds at most `count` axes, so zip stops after max_n samples.
    for img, label, ax in zip(xb, yb, grid):
        show_one(img, label, ax)
    # NOTE(review): the axes are deliberately not returned (the original left
    # `return axs` commented out) -- presumably to avoid echoing the list in
    # notebooks; confirm before changing.
class ConcatPool2d(nn.Module):
    """Global average- and max-pool a feature map, then concatenate the results.

    For an input of shape (N, C, H, W) the output has shape (N, 2*C): the first
    C columns are the per-channel spatial means, the last C the spatial maxima.
    """

    def __init__(self):
        super().__init__()
        # Both pools collapse the spatial extent to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        pooled = (self.avgpool(x), self.maxpool(x))
        return torch.cat(pooled, dim=1).flatten(start_dim=1)
def get_resolution(m, image_size=(32, 32), in_channels=3):
    """Return the per-sample output shape of ``m`` for a random image input.

    A single random sample of shape (1, in_channels, *image_size) is fed
    through ``m`` and the output shape without the batch dimension is returned.

    When ``m`` is an ``nn.Module``:
    - the probe runs in eval mode under ``torch.no_grad()``, so BatchNorm
      running statistics are not updated and batch-size-1 BatchNorm1d does not
      fail;
    - every submodule's original train/eval flag is restored afterwards;
    - the input is created on the device (and floating dtype) of the module's
      first parameter.
    Any other callable is invoked as-is.

    Args:
        m: model or callable taking a (1, C, H, W) tensor.
        image_size: spatial size (H, W) of the probe input.
        in_channels: number of input channels C.

    Returns:
        ``torch.Size`` of the output minus its leading batch dimension.
    """
    tensor_kwargs = {}
    modes = {}
    if isinstance(m, nn.Module):
        param = next(m.parameters(), None)
        if param is not None:
            tensor_kwargs["device"] = param.device
            if param.is_floating_point():
                tensor_kwargs["dtype"] = param.dtype
        modes = {mod: mod.training for mod in m.modules()}
        m.eval()
    try:
        with torch.no_grad():
            out = m(torch.randn(1, in_channels, *image_size, **tensor_kwargs))
    finally:
        # Restore each submodule's own flag (some may have been frozen in eval).
        for mod, was_training in modes.items():
            mod.training = was_training
    return out.shape[1:]
def _unwrap(x):
    """Convert a tensor to a plain Python value; return anything else unchanged.

    A single-element tensor becomes a Python scalar via ``.item()``. Any other
    tensor (several or zero elements) becomes a nested list via ``.tolist()``
    instead of raising, so it can still be stored in a DataFrame.
    """
    if isinstance(x, torch.Tensor):
        return x.item() if x.numel() == 1 else x.tolist()
    return x
class PrintCallback(Callback):
    """Callback that re-renders a table of per-epoch metrics in a notebook.

    At each ``on_epoch_end`` the current ``trainer.callback_metrics`` are
    processed as follows:
    - the ``'loss'`` entry, if present, is dropped;
    - the remaining values are converted to plain Python values;
    - the resulting dict is appended to ``self.metrics``.
    The full history is then displayed as a pandas DataFrame restricted to
    ``log_keys``, replacing the previous cell output.
    """

    def __init__(self, log_keys: list):
        super().__init__()
        self.metrics = []  # one dict of converted metrics per epoch
        self.log_keys = log_keys  # DataFrame columns to show, in this order

    def on_epoch_end(self, trainer, pl_module):
        # Build a fresh dict instead of deep-copying the metrics. Deep-copying
        # tensors is wasteful and fails for non-leaf tensors that require grad.
        # 'loss' is skipped if present rather than assumed (avoids KeyError).
        metrics_dict = {
            k: _unwrap(v)
            for k, v in trainer.callback_metrics.items()
            if k != "loss"
        }
        self.metrics.append(metrics_dict)
        metrics_df = pd.DataFrame.from_records(self.metrics, columns=self.log_keys)
        # wait=True defers clearing until new output arrives, avoiding flicker.
        clear_output(wait=True)
        display(metrics_df)
def broadcast_vec(ndim, dim, *t):
    """View each tensor in ``t`` so it broadcasts along axis ``dim`` of an ``ndim``-D tensor.

    Every output has shape ``[1, ..., -1 (at dim), ..., 1]`` with ``ndim``
    entries, via ``Tensor.view``. ``dim`` may be negative.

    Returns:
        A list with one reshaped view per input tensor (empty if none given).
    """
    target = [1 for _ in range(ndim)]
    target[dim] = -1
    views = []
    for vec in t:
        views.append(vec.view(*target))
    return views
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment