Skip to content

Instantly share code, notes, and snippets.

@alkalait
Last active July 18, 2021 19:12
Show Gist options
  • Save alkalait/1497032fb601997efd9be4b90dddc63b to your computer and use it in GitHub Desktop.
Save alkalait/1497032fb601997efd9be4b90dddc63b to your computer and use it in GitHub Desktop.
show: pretty `imshow`s with commonly used arguments
# Author: Freddie Kalaitzis
# License: MIT
# Source: https://gist.github.com/alkalait/1497032fb601997efd9be4b90dddc63b
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import xarray as xr
sns.set_style('white')
try:
    from matplotlib.axes._subplots import Subplot
except ImportError:
    # Newer matplotlib releases removed the private `matplotlib.axes._subplots`
    # module; there every Axes created by `plt.subplots` is a plain `Axes`.
    from matplotlib.axes import Axes as Subplot
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch import Tensor
from typing import Optional, Union, List
# Array types accepted by `show` (converted to NumPy before plotting).
Array = Union[np.ndarray, Tensor, xr.DataArray]
## TODO: colorbar issue with float single-channel images.
def show(
    x: Union[Array, List[Array]],
    ax: Optional[Union[Subplot, List[Subplot]]] = None,
    rows: Optional[int] = None,
    columns: Optional[int] = None,
    order: str = 'R',
    title: Optional[Union[str, List[str]]] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    normalize: bool = False,
    fontsize: int = 15,
    colorbar: bool = False,
    cmap: Optional[str] = None,
    axis: bool = True,
    **figure_kwargs,
) -> None:
    '''
    Does a bunch of `plt.imshow`s with commonly used arguments.

    Each element of `x` is an `np.ndarray`, `torch.Tensor` or `xr.DataArray`
    (`None` or a list is drawn as a blank placeholder panel) and is expanded
    into panels by rank:

    * ndim <= 2 -- one panel, plotted as-is.
    * ndim == 3 -- one panel; transposed from (C, H, W) to (H, W, C).
    * ndim >= 4 -- leading axes are flattened and every (C, H, W) slice
      becomes its own panel.

    Args:
        x: an array, or a list/tuple of arrays, to plot.
        ax: axes to draw into -- a single axes, a list, or an array of axes.
            When omitted, a new `rows` x `columns` figure is created.
        rows, columns: grid shape of the new figure; a missing value is
            derived from the number of panels.
        order: 'R' fills the grid row by row, 'C' column by column.
        title: one title string, or one title per panel (a list, or an
            array/tensor of numbers formatted with 4 decimals). A single
            string titles the first panel. Panels without a title fall back
            to the `time` coordinate of an `xr.DataArray` input, if present.
        vmin, vmax: colour limits; default to each panel's min / max.
        normalize: map [vmin, vmax] linearly onto [0, 1] for each
            non-constant panel.
        fontsize: title font size.
        colorbar: add a colorbar to single-channel panels. Panels with at
            most 20 distinct values get one tick per value.
        cmap: colormap name; defaults to matplotlib's current default map.
        axis: whether to draw axis lines and ticks.
        **figure_kwargs: forwarded to `plt.subplots`. `figsize` may be a
            number (size of one panel, default 3) or a (width, height) pair
            for the whole figure.

    Raises:
        ValueError: if `order` is not 'R' or 'C'.
        NotImplementedError: if an element of `x` has an unsupported type.
    '''
    if order not in ['C', 'R']:
        raise ValueError(f"Values expected for `order`: ['C'|'R']. Got '{order}'.")

    images, dates = _collect_images(x)
    n = len(images)
    rows, columns = _grid_shape(n, rows, columns, order)

    created = ax is None
    if created:
        size = figure_kwargs.get('figsize', 3)
        if np.isscalar(size):
            # A scalar `figsize` is the size of one panel; scale it to the grid.
            figure_kwargs['figsize'] = (columns * size, rows * size)
        _, ax = plt.subplots(rows, columns, tight_layout=True, squeeze=False, **figure_kwargs)
        if order == 'C':
            ax = ax.T  # so flattening walks the grid column by column
    if isinstance(ax, (list, tuple, np.ndarray)):
        axes = np.array(ax, dtype=object).ravel()
    else:
        axes = [ax]  # a single axes object

    for g, ax_, t, date in zip(images, axes, _titles(title, n), dates):
        _draw_panel(
            g, ax_, t or date,
            cmap=cmap, vmin=vmin, vmax=vmax, normalize=normalize,
            fontsize=fontsize, colorbar=colorbar, axis=axis,
        )

    if created:
        # Hide grid cells left over when the panels do not fill rows x columns.
        for ax_ in axes[n:]:
            ax_.axis(False)


def _to_numpy(x):
    '''Convert an xarray / torch / NumPy array to a NumPy-style array.'''
    if isinstance(x, xr.DataArray):
        return x.data
    if isinstance(x, Tensor):
        return x.cpu().detach().numpy()
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, list) and len(x) == 0:
        return np.array([0])
    raise NotImplementedError(f'Not implemented for type {type(x).__name__}.')


def _collect_images(x) -> tuple:
    '''Expand `x` into (images, dates) lists holding exactly one entry per panel.'''
    if not isinstance(x, (list, tuple)):
        x = [x]
    images, dates = [], []
    for g in x:
        if isinstance(g, (list, type(None))):
            g = np.array([])  # drawn as a blank placeholder panel
        time = None
        if hasattr(g, 'time'):
            time = g.time
            if time.ndim == 0:
                time = time.expand_dims(dim='time')
            time = list(time.to_index().astype(str))
        g = _to_numpy(g)
        if g.ndim <= 2:
            panels = [g]
        elif g.ndim == 3:
            panels = [g.transpose(1, 2, 0)]
        else:
            if g.ndim >= 5:
                # (B, T, ..., C, H, W) -> (B*T*..., C, H, W); time labels no longer apply.
                g = g.reshape(-1, *g.shape[-3:])
                time = None
            panels = [img.transpose(1, 2, 0) for img in g]
        if time is None or len(time) != len(panels):
            # Keep dates only when there is exactly one per panel, so that
            # later inputs' dates stay aligned with their own panels.
            time = [None] * len(panels)
        images.extend(panels)
        dates.extend(time)
    return images, dates


def _grid_shape(n, rows, columns, order):
    '''Fill in whichever of `rows` / `columns` is missing for `n` panels (never 0).'''
    if rows is None and columns is None:
        k = max(n, 1)
        return (1, k) if order == 'R' else (k, 1)
    if rows is None:
        rows = max(1, int(np.ceil(n / columns)))
    elif columns is None:
        columns = max(1, int(np.ceil(n / rows)))
    return rows, columns


def _titles(title, n):
    '''Return exactly `n` per-panel titles, padding with None.'''
    if title is None:
        titles = []
    elif isinstance(title, str):
        titles = [title]  # a single string titles the first panel only
    elif isinstance(title, (xr.DataArray, np.ndarray, Tensor)):
        titles = [f'{v:.4f}' for v in np.ravel(_to_numpy(title))]
    else:
        titles = [None if t is None else str(t) for t in title]
    # Padding ensures panels beyond the supplied titles are still drawn by `zip`.
    return (titles + [None] * n)[:n]


def _draw_panel(g, ax_, title, *, cmap, vmin, vmax, normalize, fontsize, colorbar, axis):
    '''Draw a single image `g` on `ax_` with optional normalisation and colorbar.'''
    g = np.nan_to_num(g)
    unique_vals = np.unique(g)
    if len(unique_vals) == 0:
        # Empty input: draw a placeholder tile without axes.
        ax_.imshow([[1]], cmap='gray', vmin=0, vmax=1)
        ax_.axis(False)
        return

    ticks = unique_vals  # colorbar tick positions, in the plotted value range
    if g.dtype.type is np.bool_:
        # Two-level colormap for boolean masks. It is kept local so it does
        # not leak into panels drawn after this one.
        panel_cmap = plt.get_cmap('binary_r', 2)
        lo, hi = 0, 1
    else:
        panel_cmap = cmap or plt.get_cmap().name
        lo = float(g.min()) if vmin is None else vmin
        hi = float(g.max()) if vmax is None else vmax
        if normalize and len(unique_vals) > 1 and hi != lo:
            # Work in float so subtracting `lo` cannot wrap integer images.
            g = (g.astype(float) - lo) / (hi - lo)
            ticks = (unique_vals.astype(float) - lo) / (hi - lo)
            lo, hi = 0, 1
    im = ax_.imshow(g, interpolation='nearest', cmap=panel_cmap, vmin=lo, vmax=hi)
    ax_.axis(axis)
    ax_.set_title(title, fontsize=fontsize)

    if colorbar and (g.ndim == 2 or g.shape[-1] == 1):
        cax = make_axes_locatable(ax_).append_axes('right', size='5%', pad=0.05)
        if len(unique_vals) <= 20:
            # Few distinct values: one tick per value, labelled with the
            # original (pre-normalisation) value. `ticks` and `unique_vals`
            # always have the same length.
            cbar = plt.colorbar(im, cax=cax, orientation='vertical', ticks=ticks)
            cbar.ax.set_yticklabels(
                [f"{v:.0f}" if v % 1 == 0 else f"{v:.3f}" for v in unique_vals]
            )
        else:
            plt.colorbar(im, cax=cax, orientation='vertical')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment