Skip to content

Instantly share code, notes, and snippets.

@JotaRata
Last active July 5, 2023 20:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JotaRata/87c50a8c6bb9517316feb29a8a241aeb to your computer and use it in GitHub Desktop.
Save JotaRata/87c50a8c6bb9517316feb29a8a241aeb to your computer and use it in GitHub Desktop.
Matplotlib subplots shortcut
class get_axis:
    """
    Context manager to create a figure with a grid of subplots and label them quickly.

    Parameters:
    - rows, cols (int): Number of subplot rows and columns to draw.
    - xlabel, ylabel (str / list / tuple): Labels for the horizontal and vertical axes.
      A string is applied to every subplot; a list/tuple gives each subplot its own
      label in row-major order. Sequences shorter than rows*cols are padded with ''
      and longer ones are truncated to rows*cols entries.
    - title (str / list / tuple): Subplot title(s); same rules as xlabel/ylabel.
    - axsize (tuple): Size of ONE subplot in inches. Unless figsize is given, the
      figure size is (axsize[0] * cols, axsize[1] * rows).
    - figsize (tuple, optional): Explicit figure size in inches (same as
      figure.figsize); overrides axsize.
    - **figargs: Additional keyword arguments passed to plt.subplots (e.g. sharey).

    Returns (bound by ``with ... as``):
    The single Axes returned by plt.subplots, or a flattened 1-D numpy array of
    Axes when plt.subplots returns an array.

    Usage:
    This class should be used with the ``with`` statement. Note that ``with`` does
    not create a new scope: names assigned inside the block remain defined after it.

    Example 1:
    with get_axis(1, 1, title='Plot 1', figsize=(6, 6)) as ax:
        ax.scatter(data[:, 0], data[:, 1], label='Data 1')
        ax.scatter(data2[:, 0], data2[:, 1], label='Data 2')
        ax.legend()

    Example 2:
    with get_axis(1, 2, title=['Cosine function', 'Squared Cosine function'], xlabel='Time', sharey=True) as axs:
        t = np.linspace(0, 6.28, 20)
        axs[0].plot(t, np.cos(t))
        axs[1].plot(t, np.cos(t) ** 2)
    """

    def __init__(self, rows=1, cols=1, xlabel=None, ylabel=None, title=None, axsize=(10, 6), **figargs):
        self.r = rows
        self.c = cols
        n_axes = rows * cols
        # Per-subplot sequences are fitted up front to exactly one entry per subplot,
        # so too few entries are padded and too many no longer raise.
        self.xlabel = self._fit(xlabel, n_axes)
        self.ylabel = self._fit(ylabel, n_axes)
        self.title = self._fit(title, n_axes)
        # figsize is removed from the kwargs so it is not passed twice to plt.subplots.
        if 'figsize' in figargs:
            self.size = figargs.pop('figsize')
        else:
            self.size = (axsize[0] * cols, axsize[1] * rows)
        self.args = figargs

    @staticmethod
    def _fit(value, n):
        """Return a list/tuple *value* as a list of exactly *n* items (padded with ''
        or truncated); any other value (str, None, ...) is returned unchanged."""
        if not isinstance(value, (list, tuple)):
            return value
        items = list(value)[:n]
        return items + [''] * (n - len(items))

    @staticmethod
    def _apply(axes, setter, value):
        """Call ``getattr(ax, setter)(label)`` on each Axes in *axes*.

        None means "leave unset"; a list (as produced by ``_fit``) supplies one
        label per Axes; any other value (normally a str) is applied to every Axes.
        """
        if value is None:
            return
        labels = value if isinstance(value, list) else [value] * len(axes)
        for axis, label in zip(axes, labels):
            getattr(axis, setter)(label)

    def __enter__(self):
        _, ax = plt.subplots(self.r, self.c, figsize=self.size, **self.args)
        if isinstance(ax, np.ndarray):
            # Flatten so grids of any shape can be indexed as axs[i].
            ax = ax.flatten()
            axes = list(ax)
        else:
            axes = [ax]
        self._apply(axes, 'set_xlabel', self.xlabel)
        self._apply(axes, 'set_ylabel', self.ylabel)
        self._apply(axes, 'set_title', self.title)
        return ax

    def __exit__(self, *args):
        # Nothing to clean up; returning None lets exceptions raised in the block propagate.
        pass
@JotaRata
Copy link
Author

JotaRata commented Nov 24, 2022

Context manager to generate subplots quickly

Parameters:

  • rows, cols (int): Rows and columns to draw.
  • xlabel, ylabel (str / list): Labels for the horizontal and vertical axes. If xlabel/ylabel are lists, each corresponding subplot gets its own label.
  • title (str / list): Title of the plot. If title is a list, each subplot gets its own title.
  • axsize (tuple): Size of a single subplot in inches; used to compute the figure size when figsize is not given.
  • figsize (tuple): Size of the figure in inches (same as figure.figsize).
  • figargs (dict): Additional arguments passed to plt.subplots

Usage:

This class should be used with the `with` statement.

Example 1:

with get_axis(1, 1, title='Plot 1', figsize= (6, 6)) as ax:
  ax.scatter(data[:, 0], data[:, 1], label= 'Data 1')
  ax.scatter(data2[:, 0], data2[:, 1], label= 'Data 2')
  ax.legend()

Example 2:

with get_axis(1, 2, title= ['Cosine function', 'Squared Cosine function'], xlabel= 'Time', sharey= True) as axs:
  t = np.linspace(0, 6.28, 20)    # note: t is still defined after the with block (with does not create a new scope)
  axs[0].plot(t, np.cos(t))
  axs[1].plot(t, np.cos(t) ** 2)

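You can also attach temporary variables to the returned Axes object as attributes, so they stay grouped with the plot they belong to (this works when a single Axes is returned; a numpy array of Axes does not accept new attributes).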

with get_axis(1, 1, title='Plot 1', figsize= (6, 6)) as ax:
  ax.time = np.linspace(0, 2 * np.pi, 100)
  ax.values = np.sin(ax.time)

  ax.scatter(ax.time, ax.values, label= 'Data')
  ax.legend()

[image]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment