Skip to content

Instantly share code, notes, and snippets.

@cmaspi
Created February 29, 2024 10:51
Show Gist options
  • Save cmaspi/71bd9bdefec986bd7c0e2893e8b7049a to your computer and use it in GitHub Desktop.
Save cmaspi/71bd9bdefec986bd7c0e2893e8b7049a to your computer and use it in GitHub Desktop.
My smart plotter to reduce some redundancy in code (uses matplotlib)
import matplotlib.pyplot as plt
from typing import List, Tuple
import math
import numpy as np
class SmartPlot:
@staticmethod
def plot_single_image(image,
title,
colorbar=False):
plt.imshow(image)
plt.title(title)
if colorbar:
plt.colorbar()
plt.show()
@staticmethod
def get_dimensions(N):
sqrt = round(math.sqrt(N))
if sqrt * sqrt == N:
return (sqrt, sqrt)
for m in range(5, 0, -1):
if N % m == 0:
return (N//m, m)
@classmethod
def plot_multi_images(cls,
images: List,
titles: List,
figsize: Tuple[int, int],
axis: List | str = 'on',
cmaps: List | str = 'virdis',
suptitle: str | None = None,
dimensions: Tuple[int, int] | None = None,
colorbar: bool | List[bool] = False,
return_fig_axis: bool = False
):
num_images = len(images)
if dimensions is None:
dimensions = cls.get_dimensions(num_images)
if type(colorbar) is bool:
colorbar = [colorbar] * num_images
if type(axis) is str:
axis = [axis] * num_images
if type(cmaps) is str:
cmaps = [cmaps] * num_images
fig, ax = plt.subplots(*dimensions, figsize=figsize)
ax = np.array(ax).flatten()
for i, (img, title, cbar) in enumerate(zip(images, titles, colorbar)):
t = ax[i].imshow(img, cmap=cmaps[i])
if title:
ax[i].set_title(title)
if cbar:
plt.colorbar(t)
ax[i].axis(axis[i])
if suptitle:
fig.suptitle(suptitle)
if return_fig_axis:
return fig, ax
else:
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment