Skip to content

Instantly share code, notes, and snippets.

@Kshitij09
Created June 22, 2020 17:58
Show Gist options
  • Save Kshitij09/0b2e77890e905e4a5aebb51b211e3310 to your computer and use it in GitHub Desktop.
Save Kshitij09/0b2e77890e905e4a5aebb51b211e3310 to your computer and use it in GitHub Desktop.
Utility functions for Pytorch Lightning (heavily borrowed from fastai)
import copy
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from IPython.display import display, clear_output
from pytorch_lightning import Callback
from torch import nn
# Per-channel (mean, std) RGB statistics conventionally used to normalize
# images for models pretrained on ImageNet.
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
def subplots(nrows=1, ncols=1, figsize=None, imsize=3, add_vert=0, **kwargs):
    """Create a figure and an array of axes.

    Args:
        nrows, ncols: grid dimensions.
        figsize: explicit figure size; when None it is derived as
            (ncols*imsize, nrows*imsize + add_vert).
        imsize: per-cell size in inches used when figsize is None.
        add_vert: extra vertical inches (e.g. room for a suptitle).
        **kwargs: forwarded to plt.subplots.

    Returns:
        (fig, ax) where ax is always an array of axes, even for a 1x1 grid.
    """
    if figsize is None:
        figsize = (ncols * imsize, nrows * imsize + add_vert)
    fig, ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)
    if nrows * ncols == 1:
        # Bug fix: the original called the bare name `array`, which is never
        # imported in this file (a fastai wildcard-import habit) and raised
        # NameError. Wrap the single axes explicitly with numpy so callers
        # can uniformly call .flatten() / index the result.
        ax = np.array([ax])
    return fig, ax
def get_grid(n, nrows=None, ncols=None, add_vert=0, figsize=None, double=False, title=None, return_fig=False, **kwargs):
    """Return a grid of `n` axes, `nrows` by `ncols`.

    Args:
        n: number of axes the caller needs; surplus grid cells are hidden.
        nrows, ncols: grid shape; defaults to an approximately square layout.
        add_vert: extra vertical figure inches (e.g. for a suptitle).
        figsize: explicit figure size, overriding the derived one.
        double: when True, double both `ncols` and `n` (side-by-side pairs).
        title: optional bold suptitle for the figure.
        return_fig: when True return (fig, axs) instead of just axs.
        **kwargs: forwarded to `subplots`.
    """
    nrows = nrows or int(math.sqrt(n))
    ncols = ncols or int(math.ceil(n / nrows))
    if double:
        ncols *= 2
        n *= 2
    # Bug fix: `add_vert` was accepted but never forwarded, so it was
    # silently ignored; pass it through to `subplots`.
    fig, axs = subplots(nrows, ncols, figsize=figsize, add_vert=add_vert, **kwargs)
    # Turn off the surplus axes and hand back exactly the first `n`.
    axs = [ax if i < n else ax.set_axis_off() for i, ax in enumerate(axs.flatten())][:n]
    if title is not None:
        fig.suptitle(title, weight='bold', size=14)
    return (fig, axs) if return_fig else axs
def min_max_scale(x):
    """Linearly rescale `x` so its minimum maps to 0 and its maximum to 1."""
    lo, hi = x.min(), x.max()
    return (x - lo) / (hi - lo)
def show_one(x, y, ax, classes=None):
    """Draw one image tensor with its class label on matplotlib axes `ax`.

    Args:
        x: image tensor laid out (channels, height, width); it is detached,
           moved to CPU and min-max scaled into [0, 1] for display.
        y: scalar label tensor; `y.item()` indexes into `classes`.
        ax: matplotlib axes to draw on.
        classes: sequence of class names. Defaults to the module-level
           `train_ds.classes` to preserve the original behavior, but can now
           be passed explicitly instead of relying on that global.

    Returns:
        The same `ax`, for chaining.
    """
    if classes is None:
        # NOTE(review): original code depended on a global `train_ds` that is
        # not defined in this file — kept as the fallback for compatibility.
        classes = train_ds.classes
    x = x.permute(1, 2, 0)  # CHW -> HWC, the layout imshow expects
    x = x.cpu().detach().numpy()
    ax.imshow(min_max_scale(x))
    ax.set_title(classes[y.item()])
    ax.axis('off')
    return ax
def show_batch(xb, yb, max_n=10, nrows=None, ncols=None, figsize=None):
    """Plot up to `max_n` (image, label) pairs from batch `xb`, `yb` on a grid.

    Args:
        xb: batch of image tensors.
        yb: batch of label tensors, parallel to `xb`.
        max_n: cap on how many samples to display.
        nrows, ncols, figsize: forwarded to `get_grid`.
    """
    n = min(len(xb), max_n)
    axs = get_grid(n, nrows=nrows, ncols=ncols, figsize=figsize)
    # `get_grid` returns exactly n axes, so the original's extra
    # zip(..., range(max_n)) limiter was redundant; a plain loop also
    # avoids building a throwaway list purely for side effects.
    for x, y, ax in zip(xb, yb, axs):
        show_one(x, y, ax)
class ConcatPool2d(nn.Module):
    """Concatenate adaptive average- and max-pooling of the input.

    Maps an (N, C, H, W) tensor to (N, 2*C): each spatial map is reduced to
    1x1 by both pooling strategies, the results are stacked along the
    channel axis, and the tensor is flattened.
    """

    def __init__(self):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        pooled = torch.cat([self.avgpool(x), self.maxpool(x)], dim=1)
        return torch.flatten(pooled, 1)
def get_resolution(m, image_size=(32, 32)):
    """Feed a dummy 3-channel batch of `image_size` through `m` and return
    the output shape without the batch dimension."""
    dummy = torch.randn(1, 3, *image_size)
    return m(dummy).shape[1:]
def _unwrap(x):
if isinstance(x,torch.Tensor):
return x.item()
return x
class PrintCallback(Callback):
    """Lightning callback that re-renders a metrics table after every epoch.

    Keeps one dict of scalar metrics per finished epoch and displays them as
    a pandas DataFrame in the notebook, clearing the previous output first.
    """

    def __init__(self, log_keys: list):
        # log_keys: metric names used as the DataFrame's column order.
        self.metrics = []
        self.log_keys = log_keys

    def on_epoch_end(self, trainer, pl_module):
        clear_output(wait=True)
        metrics_dict = copy.deepcopy(trainer.callback_metrics)
        # Bug fix: the original `del metrics_dict['loss']` raised KeyError
        # whenever the trainer had not logged a 'loss' key; pop is tolerant.
        metrics_dict.pop('loss', None)
        # Tensors -> plain Python scalars so pandas renders them cleanly.
        metrics_dict = {k: _unwrap(v) for k, v in metrics_dict.items()}
        self.metrics.append(metrics_dict)
        metrics_df = pd.DataFrame.from_records(self.metrics,
                                               columns=self.log_keys)
        display(metrics_df)
def broadcast_vec(ndim, dim, *t):
    """Reshape each 1-D tensor in `t` so it broadcasts along axis `dim`
    of an `ndim`-dimensional tensor (all other axes become size 1)."""
    shape = [1] * ndim
    shape[dim] = -1
    return [vec.view(*shape) for vec in t]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment