Skip to content

Instantly share code, notes, and snippets.

@matthewcarbone
Created July 11, 2022 18:35
Show Gist options
Save matthewcarbone/f5201b1c44963ff9453b9cc1d5f768ac to your computer and use it in GitHub Desktop.
Helper for making nice plots in matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits import axes_grid1
class MPLAdjutant:
    """Helper for producing consistently styled matplotlib figures.

    Stores a default DPI and a base figure size (``width`` in inches and
    ``height = width / 1.618``), and offers small utilities for global
    rcParams, figure sizing, padded axis limits, tick spacing, tick/grid
    styling and height-matched colorbars.
    """

    def __init__(self):
        # Written to rcParams['figure.dpi'] by set_defaults().
        self.default_DPI = 250
        # Base figure width in inches; presumably a one-column journal
        # width -- TODO confirm the target publication format.
        self.width = 3.487
        # Height derived from the width via the golden ratio (~1.618).
        self.height = self.width / 1.618

    def set_default_font(self, labelsize=12, usetex=True):
        """Configure STIX fonts and tick/axis-label sizes globally.

        Parameters
        ----------
        labelsize : float
            Font size applied to x tick, y tick and axis labels.
        usetex : bool
            Value written to ``rcParams['text.usetex']``. Defaults to True
            (the historical behaviour); LaTeX rendering requires a working
            LaTeX installation, so pass False where none is available.
        """
        mpl.rcParams['mathtext.fontset'] = 'stix'
        mpl.rcParams['font.family'] = 'STIXGeneral'
        mpl.rcParams['text.usetex'] = usetex
        for group in ('xtick', 'ytick', 'axes'):
            plt.rc(group, labelsize=labelsize)

    def set_defaults(self, labelsize=12, usetex=True):
        """Apply the default DPI and font settings to the global rcParams.

        ``labelsize`` and ``usetex`` are forwarded to ``set_default_font``.
        """
        mpl.rcParams['figure.dpi'] = self.default_DPI
        self.set_default_font(labelsize=labelsize, usetex=usetex)

    def set_size_one_column(self, fig, xwidth=1.0, xheight=1.0):
        """Size ``fig`` to ``width`` x ``height``, scaled per dimension."""
        fig.set_size_inches(self.width * xwidth, self.height * xheight)

    def set_size_square(self, fig, xwidth=1.0, xheight=1.0):
        """Size ``fig`` as a ``height`` x ``height`` square, scaled per
        dimension."""
        fig.set_size_inches(self.height * xwidth, self.height * xheight)

    def set_size_inset(self, fig, xwidth=1.0, xheight=1.0):
        """Size ``fig`` as a square of side ``width / 2``, scaled per
        dimension."""
        fig.set_size_inches(
            self.width * xwidth / 2.0, self.width * xheight / 2.0
        )

    @staticmethod
    def _bin_centers(vmin, vmax, n):
        """Return the centers of ``n`` equal-width bins spanning
        [vmin, vmax]; an empty list when ``n`` is 0."""
        if n == 0:
            # Avoid dividing by zero; no bins means no tick positions.
            return []
        step = (vmax - vmin) / n
        return [vmin + step * ii - step / 2.0 for ii in range(1, n + 1)]

    @staticmethod
    def add_colorbar(
        im, aspect=10, pad_fraction=0.5, integral_ticks=None, **kwargs
    ):
        """Add a vertical colorbar to the right of an image plot.

        The colorbar axes is created with ``mpl_toolkits.axes_grid1`` so
        its height tracks that of ``im.axes``. Adapted from
        https://stackoverflow.com/questions/18195758/set-matplotlib-colorbar-size-to-match-graph

        Parameters
        ----------
        im : matplotlib mappable
            The artist to attach the colorbar to (e.g. the result of
            ``imshow``); its ``axes`` attribute is used for layout.
        aspect : float
            Ratio of the parent axes height to the colorbar width.
        pad_fraction : float
            Gap between the axes and the colorbar, as a fraction of the
            colorbar width.
        integral_ticks : sequence, optional
            Tick labels placed at the centers of ``len(integral_ticks)``
            equal-width bins spanning the colorbar range (useful for
            discrete colormaps).
        **kwargs
            Forwarded to ``Figure.colorbar``.

        Returns
        -------
        matplotlib.colorbar.Colorbar
        """
        divider = axes_grid1.make_axes_locatable(im.axes)
        width = axes_grid1.axes_size.AxesY(im.axes, aspect=1. / aspect)
        pad = axes_grid1.axes_size.Fraction(pad_fraction, width)
        # Appending the colorbar axes may change the pyplot "current axes";
        # restore the caller's so subsequent plt.* calls hit the right one.
        current_ax = plt.gca()
        cax = divider.append_axes("right", size=width, pad=pad)
        plt.sca(current_ax)
        cbar = im.axes.figure.colorbar(im, cax=cax, **kwargs)
        if integral_ticks is not None:
            cbar.set_ticks(
                MPLAdjutant._bin_centers(
                    cbar.vmin, cbar.vmax, len(integral_ticks)
                )
            )
            cbar.set_ticklabels(integral_ticks)
        return cbar

    @staticmethod
    def _set_lims(ax, low, high, which, threshold):
        """Set the x or y limits to [low, high] padded by a relative margin.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes whose limits are set.
        low, high : float
            Data range to display; must satisfy ``high > low``.
        which : {'x', 'y'}
            Which axis to set.
        threshold : float
            Margin added on each side as a fraction of ``high - low``
            (e.g. 0.075 pads both ends by 7.5% of the range).

        Raises
        ------
        ValueError
            If ``high`` is not greater than ``low`` or ``which`` is not
            'x' or 'y'. (Explicit raises rather than ``assert`` so the
            checks survive ``python -O``.)
        """
        # ``not high > low`` (rather than ``high <= low``) also rejects NaN.
        if not high > low:
            raise ValueError(
                f"high ({high!r}) must be greater than low ({low!r})"
            )
        if which not in ('x', 'y'):
            raise ValueError(f"which must be 'x' or 'y', got {which!r}")
        extend = threshold * (high - low)
        setter = ax.set_xlim if which == 'x' else ax.set_ylim
        setter(low - extend, high + extend)

    def set_xlim(self, ax, low, high, threshold=0.075):
        """Set x limits to [low, high] padded by ``threshold`` of the range
        on each side; see ``_set_lims``."""
        self._set_lims(ax, low, high, which='x', threshold=threshold)

    def set_ylim(self, ax, low, high, threshold=0.075):
        """Set y limits to [low, high] padded by ``threshold`` of the range
        on each side; see ``_set_lims``."""
        self._set_lims(ax, low, high, which='y', threshold=threshold)

    def set_xtick_spacing(self, ax, spacing):
        """Place major x ticks at every multiple of ``spacing``."""
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(spacing))

    def set_ytick_spacing(self, ax, spacing):
        """Place major y ticks at every multiple of ``spacing``."""
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(spacing))

    @staticmethod
    def set_grids(
        ax, minorticks=True, grid=False, bottom=True, left=True, right=True,
        top=True
    ):
        """Style ticks inward on the chosen sides and optionally add grids.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to style.
        minorticks : bool
            If True, turn minor ticks on.
        grid : bool
            If True, draw a faint dotted minor grid and a major grid.
        bottom, left, right, top : bool
            Whether ticks (major and minor) are drawn on each side.
        """
        if minorticks:
            ax.minorticks_on()
        ax.tick_params(
            which='both', direction='in', bottom=bottom, left=left,
            top=top, right=right
        )
        if grid:
            ax.grid(which='minor', alpha=0.2, linestyle=':')
            ax.grid(which='major', alpha=0.5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment