Skip to content

Instantly share code, notes, and snippets.

@asross
Last active September 7, 2018 14:44
Show Gist options
  • Save asross/9d1fdf286ac0f1d2af6a5b17d83dc5d9 to your computer and use it in GitHub Desktop.
Save asross/9d1fdf286ac0f1d2af6a5b17d83dc5d9 to your computer and use it in GitHub Desktop.
A helper to plot grids of graphs in matplotlib.pyplot
"""
Examples:

    with figure_grid(5, 3) as grid:
        grid.next()
        # plot something
        grid.next()
        # plot something
        # ...etc

    with figure_grid(10, 4) as grid:
        for i, axis in enumerate(grid.each_subplot()):
            # plot something on `axis`
            pass
"""
import matplotlib.pyplot as plt
class figure_grid():
    """Context manager that fills a ``rows`` x ``cols`` grid of subplots.

    Subplots are added one at a time, left to right and top to bottom, with
    ``next_subplot()`` (alias ``next()``) or by iterating ``each_subplot()``.
    On exit the last subplot is finished, the layout is tightened and
    ``plt.show()`` is called.

    Args:
        rows: number of subplot rows.
        cols: number of subplot columns.
        rowheight: height of each row in inches; the figure is
            ``rowheight * rows`` inches tall.
        rowwidth: total figure width in inches.
        after_each: zero-argument callable run once a subplot is finished,
            i.e. just before the next subplot is added and once more on exit
            (only if at least one subplot exists). The string ``'legend'`` is
            shorthand for drawing a ``loc='best'`` legend on the subplot that
            was just finished.
        title: optional figure-level title, forwarded to ``title()``.
    """

    def __init__(self, rows, cols, rowheight=3, rowwidth=16, after_each=lambda: None, title=None):
        self.rows = rows
        self.cols = cols
        self.rh = rowheight
        self.rw = rowwidth
        self.fig = plt.figure(figsize=(rowwidth, rowheight * self.rows))
        # Number of subplots added so far; also the 1-based grid index of the
        # most recently added one.
        self.subplots = 0
        # Most recently added axes, or None before the first next_subplot().
        self.ax = None
        if after_each == 'legend':
            # Target the subplot just finished explicitly instead of
            # plt.legend(), which acts on whatever axes pyplot considers
            # "current" (possibly in another figure).
            after_each = lambda: self.ax.legend(loc='best')
        self.after_each = after_each
        if title is not None:
            self.title(title)

    def next_subplot(self, **kwargs):
        """Finish the current subplot (if any) and add the next one.

        Keyword arguments are forwarded to ``Figure.add_subplot``.

        Returns:
            The newly created axes.

        Raises:
            Whatever ``Figure.add_subplot`` raises, e.g. when the grid is
            already full.
        """
        if self.subplots:
            self.after_each()
        index = self.subplots + 1
        self.ax = self.fig.add_subplot(self.rows, self.cols, index, **kwargs)
        # Only count the subplot once it was actually created, so a failed
        # add_subplot leaves the grid state consistent.
        self.subplots = index
        return self.ax

    next = next_subplot

    def each_subplot(self, **kwargs):
        """Yield the remaining subplots of the grid one at a time.

        Continues after any subplots already added with ``next_subplot()``,
        so it never runs past the ``rows * cols`` cells. Keyword arguments
        are forwarded to every ``next_subplot()`` call.
        """
        # Re-check the bound each step so interleaved next() calls are
        # accounted for as well.
        while self.subplots < self.rows * self.cols:
            yield self.next_subplot(**kwargs)

    def title(self, title, fontsize=18, y=None, va='top', **kwargs):
        """Set a figure-level title via ``Figure.suptitle``.

        When ``y`` is omitted, the title's top edge (``va='top'``) is placed
        0.175 inches above the top of the figure: ``y`` is a figure fraction
        and the figure is ``rh * rows`` inches tall. Extra keyword arguments
        are forwarded to ``suptitle``.
        """
        if y is None:
            y = 1 + 0.175 / (self.rh * self.rows)
        self.fig.suptitle(title, y=y, va=va, fontsize=fontsize, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, _type, _value, _traceback):
        """Finish the last subplot, tighten the layout and show the figure.

        Exceptions raised inside the ``with`` body are not suppressed.
        """
        # Mirror next_subplot(): only finish a subplot that actually exists,
        # otherwise e.g. the 'legend' hook would have no axes to act on.
        if self.subplots:
            self.after_each()
        self.fig.tight_layout()
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment