Skip to content

Instantly share code, notes, and snippets.

@Zsailer
Last active February 20, 2018 19:58
Show Gist options
  • Save Zsailer/70d47bedb0529be762f0 to your computer and use it in GitHub Desktop.
Save Zsailer/70d47bedb0529be762f0 to your computer and use it in GitHub Desktop.
Make matplotlib plots pretty and standard
import matplotlib
# Softer replacements for matplotlib's single-letter color codes.
_SOFT_COLORS = {
    'b': '#0066CC',
    'r': '#CC0000',
    'm': '#660066',
    'g': '#009933',
    'c': '#009999',
    'y': '#FFCC00',
    'k': '#333333',
}
# Fraction of each axis span added as padding on both ends of the view.
_EXTRA_LIMIT_FRAC = 0.05
# Line width applied to the visible spines and to the ticks.
_SPINE_WIDTH = 1.35
# Line width applied to every restyled data line.
_LINE_WIDTH = 1.5


def prettify(ax, legend_loc=4):
    """A simple wrapper to make matplotlib figures prettier.

    Pads the view limits, keeps only the bottom/left spines (bounded to the
    original limits), points ticks outward, recolors data lines and
    errorbars with softer shades, and rebuilds an existing legend without a
    frame. The axes is styled only once: the first call sets
    ``ax.prettify = True`` and later calls return immediately.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes to style in place.
    legend_loc : int or str, default 4
        ``loc`` passed to ``ax.legend`` when an existing legend is rebuilt
        (matplotlib code 4 is 'lower right').

    Returns
    -------
    None
    """
    # Only prettify the first time: the limit padding is cumulative, so a
    # second pass would keep widening the view.
    if hasattr(ax, "prettify"):
        return
    ax.prettify = True

    _style_axes(ax)
    has_errorbars = _style_data(ax)
    _style_legend(ax, legend_loc, has_errorbars)


def _style_axes(ax):
    """Pad the limits, keep only bottom/left spines, and drop off-spine ticks."""
    # Read limits and ticks before changing the view; padding the limits may
    # change what the tick locator reports afterwards.
    xlimits = list(ax.get_xlim())
    ylimits = list(ax.get_ylim())
    xticks = list(ax.get_xticks())
    yticks = list(ax.get_yticks())

    # Extend the view by _EXTRA_LIMIT_FRAC of each span on both sides.
    xextra = _EXTRA_LIMIT_FRAC * (xlimits[1] - xlimits[0])
    yextra = _EXTRA_LIMIT_FRAC * (ylimits[1] - ylimits[0])
    ax.set_xlim(xlimits[0] - xextra, xlimits[1] + xextra)
    ax.set_ylim(ylimits[0] - yextra, ylimits[1] + yextra)

    # Remove the right and top spines; bound the remaining two to the
    # original (unpadded) limits so they stop short of the padded edges.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_bounds(xlimits[0], xlimits[1])
    ax.spines['left'].set_bounds(ylimits[0], ylimits[1])
    ax.spines['bottom'].set_linewidth(_SPINE_WIDTH)
    ax.spines['left'].set_linewidth(_SPINE_WIDTH)

    # Only show ticks on the left and bottom spines, facing outward.
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(direction='out', width=_SPINE_WIDTH)

    # Keep only ticks that land on the bounded spines (works for empty tick
    # lists and for inverted axes).
    ax.set_xticks(_ticks_within(xticks, xlimits))
    ax.set_yticks(_ticks_within(yticks, ylimits))


def _ticks_within(ticks, limits):
    """Return the ticks inside ``limits`` (given in either order), inclusive."""
    low, high = min(limits), max(limits)
    return [tick for tick in ticks if low <= tick <= high]


def _soften(color):
    """Return the soft shade for a single-letter color code, else ``color``."""
    # Line colors may also be hex strings, 'C0'-style cycle references or
    # RGB(A) tuples; only the single-letter keys of _SOFT_COLORS are remapped,
    # everything else is passed through unchanged instead of raising KeyError.
    if isinstance(color, str):
        return _SOFT_COLORS.get(color, color)
    return color


def _style_data(ax):
    """Recolor/thicken data lines and errorbars; return True if any errorbars."""
    children = ax.get_children()
    # Exact type matches, so subclasses of these artist types are not restyled.
    errorbars = [child for child in children
                 if type(child) is matplotlib.collections.LineCollection]
    lines = [child for child in children
             if type(child) is matplotlib.lines.Line2D]

    # Original color of each consecutive group of three lines, used to color
    # the errorbar collection with the same index. Presumably each errorbar
    # call contributes a data line plus two cap lines -- TODO confirm; the
    # pairing breaks when caps are absent or plain lines are interleaved.
    line_color = {}
    for index, line in enumerate(lines):
        color = line.get_color()
        line_color[index // 3] = color
        soft = _soften(color)
        # Set line, markers and marker edges to the same color.
        line.set_color(soft)
        line.set_markerfacecolor(soft)
        line.set_markeredgecolor(soft)
        line.set_markerfacecoloralt(soft)
        line.set_linewidth(_LINE_WIDTH)

    for index, collection in enumerate(errorbars):
        # Leave collections without a matching line group untouched instead
        # of raising KeyError.
        if index in line_color:
            collection.set_color(_soften(line_color[index]))
    return bool(errorbars)


def _style_legend(ax, legend_loc, has_errorbars):
    """Rebuild an existing legend frameless, with a small font, at legend_loc."""
    if ax.get_legend() is None:
        return
    if has_errorbars:
        handles, labels = ax.get_legend_handles_labels()
        # Errorbar handles are matplotlib Containers (tuple subclasses) whose
        # first item is the data line; unwrap them one by one so plain
        # Line2D handles can be mixed in without defeating the unwrap.
        handles = [h[0] if isinstance(h, tuple) and h else h for h in handles]
        ax.legend(handles, labels, numpoints=1, frameon=False,
                  loc=legend_loc, fontsize="small")
    else:
        ax.legend(frameon=False, loc=legend_loc, fontsize="small")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment